main function

Future<void> main()

Implementation

Future<void> main() async {
  const lr = 0.002;
  final model = MultiLayerPerceptron(lr); // 784 pixel inputs → 10 class outputs
  const imagePath = 'lib/loader/mnist/train-images.idx3-ubyte';
  const labelPath = 'lib/loader/mnist/train-labels.idx1-ubyte';

  final mnist = await MNISTDataset.load(imagePath, labelPath);
  final images = normalizeImages(mnist.images);
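  // (Assumed: normalizeImages scales raw 0–255 pixel bytes into [0.0, 1.0].)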

  print('Loaded ${images.length} images and ${mnist.labels.length} labels.');
  // print("Labels: ${mnist.labels}");

  // Wrap each normalized 28×28 image (784 doubles) in a ValueVector.
  final inputs = List.generate(
      images.length, (int i) => ValueVector.fromDoubleList(images[i].toList()));
  // One-hot encode each label into a 10-dimensional target vector.
  final targets = List.generate(
      images.length,
      (int i) => ValueVector.fromDoubleList(
          List.generate(10, (j) => mnist.labels[i] == j ? 1.0 : 0.0)));
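  // For example, label 3 becomes [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].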

  const epochs = 400;

  for (int epoch = 0; epoch < epochs; epoch++) {
    final losses = <Value>[];

    // Accumulate loss over a small, fixed mini-batch: only the first few
    // samples are used each epoch, and the data set is never shuffled.
    int samples = 0;
    for (int i = 0; i < inputs.length; i++) {
      final yPred = model.forward(inputs[i]);
      final yTrue = targets[i];

      // Cross-entropy between the prediction and the one-hot target,
      // rather than squared error.
      losses.add(yPred.crossEntropy(yTrue));

      if (samples > 32) break; // cap the mini-batch size
      samples++;
    }

    final totalLoss = losses.reduce((a, b) => a + b);
    totalLoss.backward(); // backpropagate through the summed batch loss

    // Gradient descent step on every parameter.
    model.updateWeights();

    print('Epoch $epoch | Loss = ${totalLoss.data.toStringAsFixed(10)}');

    // Clear the accumulated gradients before the next epoch.
    model.zeroGrad();
  }

  // Print the model's raw output vector for every training input. The
  // gradients were cleared by zeroGrad() above and a forward pass does not
  // touch them, so no per-input reset is needed.
  for (final input in inputs) {
    print('Output: ${model.forward(input)}');
    print('');
  }
}
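The loop above always trains on the same leading samples. A minimal sketch of a shuffled mini-batch variant follows, assuming the model, inputs, and targets built in main and the same Value/ValueVector API; trainEpochShuffled, batchSize, and rng are illustrative names that do not exist in the library:

import 'dart:math';

// Sketch only: one epoch over a freshly shuffled mini-batch, assuming the
// MultiLayerPerceptron / ValueVector API used in main above. batchSize and
// rng are hypothetical parameters, not part of the library.
void trainEpochShuffled(
    MultiLayerPerceptron model,
    List<ValueVector> inputs,
    List<ValueVector> targets,
    int batchSize,
    Random rng) {
  // Shuffle a copy of the index list and take the first batchSize entries.
  final indices = List<int>.generate(inputs.length, (i) => i)..shuffle(rng);
  final losses = <Value>[
    for (final i in indices.take(batchSize))
      model.forward(inputs[i]).crossEntropy(targets[i])
  ];

  // Backpropagate through the summed batch loss, then step and reset.
  losses.reduce((a, b) => a + b).backward();
  model.updateWeights();
  model.zeroGrad();
}

It would be called once per epoch in place of the fixed sampling loop, e.g. trainEpochShuffled(model, inputs, targets, 32, Random(42)).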