trainModel function
void trainModel(
  MultiLayerPerceptron model,
  List<Float32List> images,
  List<int> labels, {
  int epochs = 5,
  int batchSize = 32,
})
Implementation
void trainModel(
    MultiLayerPerceptron model, List<Float32List> images, List<int> labels,
    {int epochs = 5, int batchSize = 32}) {
  final random = Random();
  for (int epoch = 0; epoch < epochs; epoch++) {
    final losses = <Value>[];
    // Reset gradients accumulated from the previous update.
    model.zeroGrad();
    // Draw a mini-batch of randomly sampled training examples.
    for (int i = 0; i < batchSize && i < images.length; i++) {
      final index = random.nextInt(images.length);
      final input = images[index];
      final target = List<double>.filled(10, 0.0);
      target[labels[index]] = 1.0; // One-hot encoding
      // Forward pass and per-sample mean squared error.
      final yPred = model.forward(ValueVector.fromFloat32List(input));
      final yTrue = ValueVector.fromDoubleList(target);
      final diff = yPred - yTrue;
      final squared = diff.squared();
      final sampleLoss = squared.mean();
      losses.add(sampleLoss);
    }
    // Average the loss over the mini-batch and backpropagate.
    final totalLoss = losses.reduce((a, b) => a + b);
    final avgLoss = totalLoss * (1.0 / losses.length);
    avgLoss.backward();
    // Gradient descent: apply the accumulated gradients to the weights.
    model.updateWeights();
    print('Epoch $epoch complete. Average Loss: $avgLoss');
  }
}
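
For context, a minimal sketch of how trainModel might be called is shown below. The MultiLayerPerceptron constructor arguments and the loadMnistImages/loadMnistLabels helpers are assumptions made for illustration, not part of this library's documented API; substitute the project's actual network configuration and data-loading code.

import 'dart:typed_data';

void main() {
  // Hypothetical network shape: 784 input pixels, one hidden layer, 10 classes.
  final model = MultiLayerPerceptron([784, 64, 10]);

  // Hypothetical loaders returning normalized pixel vectors and integer labels.
  final List<Float32List> images = loadMnistImages('train-images-idx3-ubyte');
  final List<int> labels = loadMnistLabels('train-labels-idx1-ubyte');

  trainModel(model, images, labels, epochs: 5, batchSize: 32);
}

Running this prints one average-loss line per epoch from inside trainModel.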