main function
void
main()
Implementation
void main() {
// 1. Setup Network
final model = MLP(2, [4, 1]);
final double learningRate = 0.5; // High LR for XOR
// 2. Data (XOR)
final xData = Tensor.fromList([4, 2], [0,0, 0,1, 1,0, 1,1]);
final target = Tensor.fromList([4, 1], [0, 1, 1, 0]);
print('Starting GPU MLP Training...');
for (int epoch = 0; epoch <= 1000; epoch++) {
// This list tracks every tensor created in this iteration
List<Tensor> tracker = [];
// Forward
final pred = model.forward(xData, tracker);
// Loss = (pred - target)^2
final diff = pred - target;
final loss = diff.pow(2.0);
tracker.addAll([diff, loss]);
// Backward
model.zeroGrad();
loss.backward();
// Update (On GPU)
model.step(learningRate);
if (epoch % 100 == 0) {
print("Epoch $epoch, Loss: ${loss.data[0].toStringAsFixed(6)}");
}
// --- CRITICAL: GPU MEMORY CLEANUP ---
// Dispose all intermediate tensors created this epoch
for (var t in tracker) {
t.dispose();
}
}
// Final Test
List<Tensor> dummy = [];
final finalPred = model.forward(xData, dummy);
print("\nFinal Results:");
print(finalPred.data);
// Final cleanup
finalPred.dispose();
xData.dispose();
target.dispose();
}