main function
void
main()
Implementation
void main() {
// 1. Setup Network
final model = MLP(768, [20, 10]);
final double learningRate = 0.5; // High LR for XOR
// final optimizer = Adam(model.parameters(), lr: learningRate);
// 2. Data (XOR)
final xData = Tensor.random([1, 768]);
final target = [0, 0, 0, 0, 1, 0, 0, 0, 0, 0];
print('Starting GPU MLP Training...');
for (int epoch = 0; epoch <= 2000; epoch++) {
// This list tracks every tensor created in this iteration
List<Tensor> tracker = [];
// Backward
model.zeroGrad();
// Forward
final logits = model.forward(xData, tracker);
// Loss = (pred - target)^2
// final diff = pred - target;
final loss = logits.crossEntropy(target);
tracker.addAll([loss]);
loss.backward();
// Update (On GPU)
model.step(learningRate);
// optimizer.step();
if (epoch % 100 == 0) {
print("Epoch $epoch, Loss: ${loss.data[0].toStringAsFixed(6)}");
}
// --- CRITICAL: GPU MEMORY CLEANUP ---
// Dispose all intermediate tensors created this epoch
for (var t in tracker) {
t.dispose();
}
}
// Final Test
List<Tensor> dummy = [];
final finalPred = model.forward(xData, dummy);
print("\nFinal Results:");
print(finalPred.printMatrix());
// Final cleanup
finalPred.dispose();
xData.dispose();
// target.dispose();
}