# flutter_ml
A deep learning library for Dart, built from the ground up with a pure Dart autograd engine.
This project provides a Keras-inspired high-level API (`SNetwork`) for building and training neural networks, powered by a custom automatic differentiation engine written purely in Dart.
## A Library for Learning
This package is designed for learning. To get started, I recommend looking directly at the `Tensor` class: it contains the bare-bones building blocks of the autograd engine and reveals exactly how gradient propagation works.
The entire structure is open. All math operations, layers, and optimizers are written in simple, readable Dart, which makes the system easy to debug, experiment with, and understand.
This transparency comes at a performance cost (CUDA support and model saving/loading are currently in the works), but it makes the engine a perfect tool for understanding why things work the way they do.
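To make the idea concrete, here is a tiny, self-contained sketch of reverse-mode autodiff on scalars in plain Dart. The `Value` class, its operators, and `backward()` are illustrative inventions in the spirit of the library's `Tensor` class, not its actual API:

```dart
// A minimal scalar autograd node. Each operation records its inputs
// and a closure that propagates gradients backwards through it.
class Value {
  double data;
  double grad = 0.0;
  final List<Value> _parents;
  void Function() _backward = () {};

  Value(this.data, [this._parents = const []]);

  Value operator +(Value other) {
    final out = Value(data + other.data, [this, other]);
    out._backward = () {
      grad += out.grad;       // d(out)/d(this)  = 1
      other.grad += out.grad; // d(out)/d(other) = 1
    };
    return out;
  }

  Value operator *(Value other) {
    final out = Value(data * other.data, [this, other]);
    out._backward = () {
      grad += other.data * out.grad; // d(out)/d(this)  = other
      other.grad += data * out.grad; // d(out)/d(other) = this
    };
    return out;
  }

  /// Topologically sort the graph, then run every node's
  /// backward closure in reverse order.
  void backward() {
    final topo = <Value>[];
    final visited = <Value>{};
    void build(Value v) {
      if (visited.add(v)) {
        for (final p in v._parents) build(p);
        topo.add(v);
      }
    }

    build(this);
    grad = 1.0;
    for (final v in topo.reversed) {
      v._backward();
    }
  }
}

void main() {
  final a = Value(2.0);
  final b = Value(3.0);
  final c = a * b + a; // c = a*b + a = 8
  c.backward();
  print(c.data); // 8.0
  print(a.grad); // dc/da = b + 1 = 4.0
  print(b.grad); // dc/db = a = 2.0
}
```

Note how the graph is built dynamically as the expression `a * b + a` is evaluated; this is the same "define-by-run" idea the engine shares with PyTorch.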
## Features
This package includes a wide range of features, from the low-level engine to high-level model APIs.
- Pure Dart Autograd Engine: Automatically computes gradients for any model.
- Dynamic Computational Graphs: Graphs are built on the fly, just like in PyTorch.
- Familiar Layer API: A general `Layer` structure inspired by Keras and PyTorch.
- High-Level Model API: The `SNetwork` class for simple model stacking, compiling, and training.
- Detailed Graph Printouts: A custom logger can print the entire computational graph.
- Modern Architecture Support: Includes all building blocks for Transformer networks.
- Comprehensive Components: A wide array of layers, optimizers, and activation functions are ready to use.
## Full Component List
- **Core Layers**
  - `DenseLayer`: Standard fully-connected layer for `Vector` data.
  - `DenseLayerMatrix`: Batch-processed fully-connected layer for `Matrix` data.
- **Recurrent Layers**
  - `RNN`: A simple Recurrent Neural Network (RNN) layer.
  - `LSTMLayer`: A Long Short-Term Memory (LSTM) recurrent layer.
  - `DualLSTMLayer`: A custom hierarchical LSTM with two internal tiers (fast and slow).
  - `MultiTierLSTMLayer`: A generalized, hierarchical LSTM with an arbitrary number of configured timescales.
  - `GeneralizedChainedScaleLayer`: A self-contained, multi-scale recurrent layer for processing high-frequency data.
- **Convolutional Layers**
  - `Conv2DLayer`: A 2D convolutional layer for 2D `Matrix` inputs.
  - `ConvLSTMLayer`: A Convolutional LSTM layer for spatiotemporal data (like video).
- **Pooling Layers**
  - `MaxPooling1DLayer` / `MaxPooling2DLayer`
  - `AveragePooling2DLayer`
  - `GlobalAveragePoolingLayer` / `GlobalAveragePooling1D`
- **Transformer Layers**
  - `EmbeddingLayer` / `EmbeddingLayerMatrix`: Converts token indices to dense vectors.
  - `PositionalEncoding`: Injects sinusoidal position information.
  - `SingleHeadAttention`: A single head of the self-attention mechanism.
  - `MultiHeadAttention`: Runs multiple `SingleHeadAttention` heads in parallel.
  - `TransformerEncoderBlock`: A full encoder block (attention + feed-forward).
- **Normalization & Utility Layers**
  - `BatchNorm1D` / `BatchNorm2D`: Batch Normalization for 1D and 3D data.
  - `LayerNormalization` / `LayerNormalizationVector`: Layer Normalization for `Matrix` and `Vector` data.
  - `DropoutLayer` / `DropoutLayerMatrix`: Regularization layer for `Vector` and `Matrix` inputs.
  - `FlattenLayer`: Reshapes a `Matrix` into a `Vector`.
  - `ReshapeVectorToMatrixLayer`: Reshapes a `Vector` to a 1x1 `Matrix`.
- **Activation Layers** (as standalone layers)
  - `ReLULayer` / `ReLULayerMatrix`
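Several of the normalization layers above share one core computation. As a reference, layer normalization over a single vector can be sketched in plain Dart; this shows the math only, and the parameter names (`gamma`, `beta`, `eps`) are assumptions, not the library's signature:

```dart
import 'dart:math' as math;

/// Layer normalization over one vector: subtract the mean, divide by
/// the standard deviation, then apply a learned scale (gamma) and
/// shift (beta). Illustrative sketch, not the library's API.
List<double> layerNorm(List<double> x,
    {double gamma = 1.0, double beta = 0.0, double eps = 1e-5}) {
  final mean = x.reduce((a, b) => a + b) / x.length;
  final variance =
      x.map((v) => (v - mean) * (v - mean)).reduce((a, b) => a + b) / x.length;
  final std = math.sqrt(variance + eps);
  return x.map((v) => gamma * (v - mean) / std + beta).toList();
}

void main() {
  // The output is centered at zero with (roughly) unit variance.
  print(layerNorm([1.0, 2.0, 3.0]));
}
```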
## Available Optimizers
- `SGD`: The standard Stochastic Gradient Descent.
- `Momentum`: `SGD` with the addition of a momentum (velocity) term.
- `NAG`: Nesterov Accelerated Gradient, an improvement on `Momentum`.
- `Adagrad`: An adaptive optimizer well suited to sparse data.
- `RMSprop`: An adaptive optimizer that performs well with RNNs.
- `Adam`: The most common general-purpose adaptive optimizer.
- `AMSGrad`: A variant of `Adam` that fixes a potential convergence issue.
- `AdamW`: A variant of `Adam` that decouples weight decay (L2 regularization) from the gradient update.
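The difference between plain `SGD` and `Momentum` comes down to one extra state variable per parameter. A sketch of both update rules on a single scalar parameter (the function and class names here are illustrative, not the library's API):

```dart
/// Plain SGD: step directly against the gradient.
///   w := w - lr * grad
double sgdStep(double w, double grad, double lr) => w - lr * grad;

/// Momentum: accumulate a velocity that smooths successive gradients.
///   v := beta * v - lr * grad
///   w := w + v
class MomentumState {
  double velocity = 0.0;

  double step(double w, double grad, {double lr = 0.1, double beta = 0.9}) {
    velocity = beta * velocity - lr * grad;
    return w + velocity;
  }
}

void main() {
  // With zero initial velocity, the first Momentum step equals plain SGD.
  print(sgdStep(1.0, 0.5, 0.1)); // ~0.95
  final m = MomentumState();
  print(m.step(1.0, 0.5));       // ~0.95, but later steps accelerate
}
```

The real optimizers apply the same arithmetic across every entry of every parameter tensor; `NAG`, `Adam`, and the rest extend this pattern with additional per-parameter state.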
## Available Activation Functions
- `ReLU`
- `LeakyReLU`
- `ELU`
- `Sigmoid` (for binary classification outputs)
- `Softmax` (for multi-class classification outputs)
- `SiLU` / `Swish`
- `Mish`
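For reference, the math behind several of these activations is small enough to write out directly. This is a plain-Dart sketch of the standard definitions, not the library's implementation:

```dart
import 'dart:math' as math;

// Standard activation function definitions (scalar form).
double relu(double x) => x > 0 ? x : 0.0;
double leakyRelu(double x, [double alpha = 0.01]) => x > 0 ? x : alpha * x;
double sigmoid(double x) => 1.0 / (1.0 + math.exp(-x));
double silu(double x) => x * sigmoid(x); // a.k.a. Swish

/// Softmax over a vector: shift by the max for numerical stability,
/// exponentiate, then normalize so the outputs sum to 1.
List<double> softmax(List<double> xs) {
  final maxX = xs.reduce(math.max);
  final exps = xs.map((x) => math.exp(x - maxX)).toList();
  final sum = exps.reduce((a, b) => a + b);
  return exps.map((e) => e / sum).toList();
}

void main() {
  print(sigmoid(0.0)); // 0.5
  print(relu(-2.0));   // 0.0
  print(softmax([1.0, 2.0, 3.0])); // three probabilities summing to 1
}
```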
## Quick Start: The XOR Example
This example shows the full workflow: defining a model, compiling it, and training it to solve the classic XOR problem.
```dart
/* A complete, runnable example of training a simple network
   to solve the XOR problem. */
void main() {
  // --- 1. Define XOR Dataset ---
  final List<Vector> xorInputs = [
    [0.0, 0.0],
    [0.0, 1.0],
    [1.0, 0.0],
    [1.0, 1.0]
  ];
  final List<Vector> xorTargets = [
    [0.0], // 0 XOR 0 = 0
    [1.0], // 0 XOR 1 = 1
    [1.0], // 1 XOR 0 = 1
    [0.0]  // 1 XOR 1 = 0
  ];

  // --- 2. Build the SNetwork (Simple Sequential Model) ---
  final SNetwork model = SNetwork([
    // Hidden layer (2 -> 2)
    DenseLayer(2, activation: ReLU()),
    // Output layer (2 -> 1)
    DenseLayer(1, activation: Sigmoid())
  ], name: 'XOR-Net');

  // --- 3. Build & Compile the Network ---
  // This first call to `predict` is necessary to build the layers
  // and initialize their parameters before the optimizer needs them.
  model.predict(Tensor<Vector>(xorInputs[0]));

  // The `model.parameters` list is now populated.
  final SGD optimizer = SGD(model.parameters, learningRate: 0.1);
  model.compile(configuredOptimizer: optimizer);

  // --- 4. Train the Network ---
  print('Training ${model.name} for 5000 epochs...');
  model.fit(xorInputs, xorTargets, epochs: 5000, debug: true);

  // --- 5. Evaluate and Test ---
  print('\n--- Testing Predictions ---');
  for (int i = 0; i < xorInputs.length; i++) {
    final Tensor<Vector> inputTensor = Tensor<Vector>(xorInputs[i]);
    final Tensor<Vector> prediction =
        model.predict(inputTensor) as Tensor<Vector>;
    final int target = xorTargets[i][0].toInt();
    final double rawOutput = prediction.value[0];
    final int predictedClass = (rawOutput > 0.5) ? 1 : 0;
    print('Input: ${xorInputs[i]}, Target: $target, Predicted: $predictedClass');
  }
}
```
## Advanced Showcase: Building a Transformer
This engine is powerful enough to run modern architectures. The library includes all the necessary building blocks to create a full Transformer Encoder from scratch.
This example shows how to build a sentiment classifier by stacking the available Transformer components.
```dart
/* This code snippet shows how all the building blocks are
   assembled into a single SNetwork for sentiment analysis. */

// --- 1. Define Model Hyperparameters ---
int vocabSize = 15;         // How many unique words in our vocabulary
int dModel = 16;            // The "width" of the model (embedding dimension)
int numHeads = 2;           // Number of attention heads
int dff = 32;               // Hidden dimension of the feed-forward network
int maxSequenceLength = 10; // Max sentence length for positional encoding

// --- 2. Assemble the SNetwork ---
SNetwork sentimentClassifier = SNetwork([
  // 1. Convert word indices to vectors
  EmbeddingLayer(vocabSize, dModel),
  // 2. Add word-order information
  PositionalEncoding(maxSequenceLength, dModel),
  // 3. Run through two Transformer blocks
  TransformerEncoderBlock(dModel, numHeads, dff),
  TransformerEncoderBlock(dModel, numHeads, dff),
  // 4. Pool the final sequence into a single vector
  GlobalAveragePooling1D(),
  // 5. Classify the vector (0.0 = negative, 1.0 = positive)
  DenseLayer(1, activation: Sigmoid()),
]);

// --- 3. Build, Compile, and Train ---
// (Build the model with a dummy input)
sentimentClassifier.predict(Tensor<Vector>([1, 2, 3]));

// (Compile with an optimizer)
sentimentClassifier.compile(
    configuredOptimizer:
        Adam(sentimentClassifier.parameters, learningRate: 0.01));

// (Train the model)
// sentimentClassifier.fit(inputs, targets, epochs: 100);
```
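The `PositionalEncoding` layer in the stack above injects sinusoidal position information. For reference, the standard sinusoidal formulation from "Attention Is All You Need" can be written in plain Dart; this is a sketch of the math, not necessarily what the library's layer does internally:

```dart
import 'dart:math' as math;

/// Standard sinusoidal positional encoding:
///   PE(pos, 2i)   = sin(pos / 10000^(2i / dModel))
///   PE(pos, 2i+1) = cos(pos / 10000^(2i / dModel))
/// Returns a [maxLen x dModel] table added to the embeddings.
List<List<double>> sinusoidalEncoding(int maxLen, int dModel) {
  final pe = List.generate(maxLen, (_) => List.filled(dModel, 0.0));
  for (var pos = 0; pos < maxLen; pos++) {
    for (var i = 0; i < dModel; i += 2) {
      final angle = pos / math.pow(10000, i / dModel);
      pe[pos][i] = math.sin(angle);
      if (i + 1 < dModel) pe[pos][i + 1] = math.cos(angle);
    }
  }
  return pe;
}

void main() {
  final pe = sinusoidalEncoding(10, 16);
  print(pe[0][0]); // sin(0) = 0.0
  print(pe[0][1]); // cos(0) = 1.0
}
```

Because each dimension oscillates at a different frequency, every position gets a unique pattern, and relative offsets correspond to simple rotations of these values.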
## ❗️ Status: Work in Progress
This library is in the early stages of development. It is a large, ongoing project. The API is not yet stable and is subject to change.
The entire engine is written in pure Dart. While this is fantastic for learning and debugging, it is not yet optimized for performance. This is an experimentation and learning package, not a production-ready training framework (yet).
## Installation
This package is not yet published on pub.dev. To use it, add it to your pubspec.yaml as a git dependency:
```yaml
dependencies:
  flutter_ml:
    git:
      url: https://github.com/your-username/flutter_ml.git
      ref: main
```
## Contributing & Future Roadmap
This project is a long-term effort. Contributions are extremely welcome. If you are interested in machine learning or high-performance Dart, feel free to open an issue or submit a pull request.
**Roadmap:**
- Performance optimization (the biggest task).
- Add more loss functions (e.g., `CrossEntropy`).
- Add more metrics (e.g., `Accuracy`, `Precision`).
- Improve the `SNetwork.fit()` method with `batch_size`, validation splits, and callbacks.
- Long-term: explore Dart FFI for GPU acceleration via C++ libraries.
**How to Contribute:**
- If you find a bug or have a feature request, please file an issue.
- Feel free to fork the repository and submit a pull request.
## Libraries
- activationFunctions/activation_funciton
- activationFunctions/elu
- activationFunctions/leakyRelu
- activationFunctions/mish
- activationFunctions/relu
- activationFunctions/sigmoid
- activationFunctions/silu
- activationFunctions/softmax
- autogradEngine/tensor
- diagnosysTools/logger
- examples/hemingshakes
- examples/stock_predict
- flutter_ml
- layertypes/averagePooling
- layertypes/batchNormalizationLayer
- layertypes/conv2d
- layertypes/convlstmLayer
- layertypes/denseLayer
- layertypes/dropout
- layertypes/duallstm
- layertypes/flattenLayer
- layertypes/layer
- layertypes/lstmLayer
- layertypes/maxPooling
- layertypes/multilstm
- layertypes/reluLayer
- layertypes/rnnLayer
- layertypes/singleHeadAttentionLayer
- layertypes/trendmodelLayer
- nets/snet
- optimizers/adagrad
- optimizers/adam
- optimizers/adamw
- optimizers/amsgrad
- optimizers/nag
- optimizers/optimizers
- optimizers/rmsprop
- optimizers/sgd
- optimizers/sgdmomentum
- transformer/embeddingLayer
- transformer/finalModel
- transformer/globalpoolingLayer
- transformer/layerNormalization
- transformer/multiHeadAttentionLayer
- transformer/positionalEncodingLayer
- transformer/transformer
- transformer/transformerEncodingLayer