main function
void
main()
Implementation
/// Trains a [MultiLayerPerceptron] on the MNIST training files and prints the
/// model's prediction for the first training image next to its label.
///
/// Returns a [Future] rather than `void` so that errors raised after the first
/// `await` are delivered through the returned future.
Future<void> main() async {
  const trainImagePath = 'lib/loader/mnist/train-images.idx3-ubyte';
  const trainLabelPath = 'lib/loader/mnist/train-labels.idx1-ubyte';

  final mnist = await MNISTDataset.load(trainImagePath, trainLabelPath);
  print(
      'Loaded ${mnist.images.length} images and ${mnist.labels.length} labels.');

  // Initialize the neural network.
  print('Creating model');
  // NOTE(review): the single positional argument is presumably the learning
  // rate. An earlier commented-out sketch listed inputSize: 784 (28x28 images
  // flattened), hiddenSizes: [128, 128], outputSize: 10 and
  // learningRate: 0.01 -- confirm against the constructor.
  final model = MultiLayerPerceptron(0.05);

  // Train the model.
  // NOTE(review): the result of trainModel is not awaited. If it returns a
  // Future, the prediction below runs before training has finished -- confirm.
  print('Training model');
  trainModel(model, mnist.images, mnist.labels, epochs: 3000, batchSize: 32);

  // Test on the first image: the predicted digit is the index of the first
  // output whose value equals the maximum output value.
  final prediction =
      model.forward(ValueVector.fromFloat32List(mnist.images[0]));
  final outputs = prediction.values;
  final highestOutput = outputs.map((n) => n.data).reduce(max);
  final predictedDigit = outputs.indexWhere((v) => v.data == highestOutput);

  print('Predicted digit: $predictedDigit');
  print('Expected digit: ${mnist.labels[0]}');
}