trainModel function

void trainModel(
  MultiLayerPerceptron model,
  List<Float32List> images,
  List<int> labels, {
  int epochs = 5,
  int batchSize = 32,
})

Implementation

/// Trains [model] on [images]/[labels] with mini-batch SGD.
///
/// Runs [epochs] epochs; each epoch samples one mini-batch of up to
/// [batchSize] images (uniformly, with replacement), averages the
/// per-sample mean-squared-error losses, back-propagates, and applies
/// one weight update. Targets are one-hot encoded over 10 classes, so
/// labels must be in the range 0..9.
void trainModel(
    MultiLayerPerceptron model, List<Float32List> images, List<int> labels,
    {int epochs = 5, int batchSize = 32}) {
  // Nothing to sample from — avoid RangeError in nextInt and
  // reduce() on an empty loss list.
  if (images.isEmpty) return;
  assert(images.length == labels.length,
      'images and labels must be the same length');

  // One Random for the whole run; constructing a new Random per sample
  // re-seeds from the clock and wastes work.
  final rng = Random();

  for (int epoch = 0; epoch < epochs; epoch++) {
    final losses = <Value>[];

    // Reset accumulated gradients before this epoch's batch.
    model.zeroGrad();

    // Draw exactly one mini-batch. (The old loop over all images broke
    // only after collecting batchSize + 1 losses, then divided the sum
    // by batchSize — an off-by-one in the average.)
    final samples = min(batchSize, images.length);
    for (int i = 0; i < samples; i++) {
      // nextInt(n) already yields 0..n-1; the previous `length - 1`
      // bound could never pick the last image.
      final rand = rng.nextInt(images.length);
      final input = images[rand];

      // One-hot encode the target label over the 10 output classes.
      final target = List<double>.filled(10, 0.0);
      target[labels[rand]] = 1.0;

      // Forward pass + mean-squared-error loss for this sample.
      final yPred = model.forward(ValueVector.fromFloat32List(input));
      final yTrue = ValueVector.fromDoubleList(target);
      losses.add((yPred - yTrue).squared().mean());
    }

    // Average over the number of losses actually collected.
    final totalLoss = losses.reduce((a, b) => a + b);
    final avgLoss = totalLoss * (1.0 / losses.length);

    // Back-propagate through the averaged loss, then take one
    // gradient-descent step.
    avgLoss.backward();
    model.updateWeights();

    print('Epoch $epoch complete. Average Loss: $avgLoss');
  }
}