`main` function

void main()

Implementation

/// Trains a [MultiLayerPerceptron] on the MNIST training set and prints the
/// model's prediction for the first training image.
///
/// Loads images and labels from the IDX files under `lib/loader/mnist/`,
/// trains with [trainModel], then prints the predicted digit (the index of
/// the largest output value) next to the expected label.
///
/// Note: the classified sample comes from the *training* files, so the
/// printed result is a sanity check, not a measure of generalization.
Future<void> main() async {
  const trainImagePath = 'lib/loader/mnist/train-images.idx3-ubyte';
  const trainLabelPath = 'lib/loader/mnist/train-labels.idx1-ubyte';

  final mnist = await MNISTDataset.load(trainImagePath, trainLabelPath);
  print('Loaded ${mnist.images.length} images and '
      '${mnist.labels.length} labels.');

  // Without at least one image/label pair there is nothing to train on, and
  // the sanity check below would throw a RangeError on index 0.
  if (mnist.images.isEmpty || mnist.labels.isEmpty) {
    print('No training data loaded; aborting.');
    return;
  }

  // Initialize Neural Network.
  print('Creating model');
  // NOTE(review): the single positional argument appears to be the learning
  // rate, and the layer sizes seem to be fixed inside MultiLayerPerceptron —
  // TODO confirm against its constructor.
  final model = MultiLayerPerceptron(0.05);

  // Train the model.
  print('Training model');
  trainModel(model, mnist.images, mnist.labels, epochs: 3000, batchSize: 32);

  // Classify the first training image as a quick sanity check.
  final prediction =
      model.forward(ValueVector.fromFloat32List(mnist.images[0]));
  final scores = <num>[for (final v in prediction.values) v.data];

  print('Predicted digit: ${_argmax(scores)}');
  print('Expected digit: ${mnist.labels[0]}');
}

/// Returns the index of the largest element of [values] (first one on ties).
///
/// Done in a single pass rather than `reduce(max)` followed by
/// `indexWhere((v) => v == maxVal)`: that form scans twice and yields -1
/// whenever any value is NaN, because `max` propagates NaN and NaN never
/// compares equal to itself. Returns -1 for an empty list.
int _argmax(List<num> values) {
  if (values.isEmpty) return -1;
  var best = 0;
  for (var i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}