
Inference #

🚀 Install now | 📖 Documentation

🚧 Development Status: This package is under active development. The core API is implemented and functional, but some advanced features are still being added. Currently supports basic model loading and inference with Candle and Linfa engines.

Zero-setup machine learning inference for Flutter applications.

Inference brings the full power of modern ML engines (Candle, Linfa) to Flutter with a unified, developer-friendly API. Load models from anywhere (assets, URLs, Hugging Face Hub), run predictions on any platform, and even train models on-device, all with just a few lines of code.


Features #

  • 🚀 Zero Configuration: Install and start using ML models immediately
  • 🌐 Universal Loading: Load from assets, URLs, files, or Hugging Face Hub
  • 🤗 Hugging Face Ready: Direct integration with Hugging Face Hub models
  • 🔧 Unified API: One interface for PyTorch and classical ML models
  • 📱 Cross-Platform: Android, iOS, Windows, macOS, Linux support
  • ⚡ Hardware Acceleration: Automatic GPU/NPU detection and optimization
  • 🎯 Auto-Detection: Intelligent engine selection based on model format
  • 💾 Smart Caching: Automatic model caching with size management
  • 🧠 On-Device Training: Train classical ML models directly on device
  • 🔒 Type-Safe: Full Dart type safety with comprehensive error handling
  • 📊 Rich I/O: Built-in support for images, text, tensors, and audio

Quick Start #

Installation #

Add inference to your pubspec.yaml:

dependencies:
  inference: ^1.0.0

Or install via command line:

flutter pub add inference

Basic Usage #

import 'package:inference/inference.dart';

// Load any ML model with automatic engine detection
final model = await InferenceSession.load('assets/model.safetensors');

// Make predictions with type-safe inputs
final input = await ImageInput.fromAsset('assets/test_image.jpg');
final result = await model.predict(input);

// Access results with convenience methods
final topPrediction = result.topK(1).first;
print('Prediction: ${topPrediction.classIndex} (${topPrediction.confidence})');

// Clean up resources
model.dispose();

Model Loading Options #

Inference supports multiple ways to load models, giving you maximum flexibility:

1. Asset Loading (Bundled Models) #

// Load from app assets
final model = await InferenceSession.load('assets/models/classifier.safetensors');

2. URL Loading (Remote Models) #

// Load from any URL with automatic caching
final model = await InferenceSession.loadFromUrl(
  'https://example.com/models/classifier.safetensors',
  cache: true, // Enable caching (default)
);

// Custom cache key for organization
final model = await InferenceSession.loadFromUrl(
  'https://example.com/large_model.safetensors',
  cache: true,
  cacheKey: 'production_classifier_v2',
);

3. Hugging Face Hub Integration #

// Load directly from Hugging Face
final detector = await InferenceSession.loadFromHuggingFace(
  'qualcomm/EasyOCR',
  filename: 'EasyOCR.safetensors',
);

// Specify model revision
final model = await InferenceSession.loadFromHuggingFace(
  'microsoft/DialoGPT-medium',
  filename: 'pytorch_model.safetensors',
  revision: 'v1.0.0',
);

4. File System Loading #

// Load from local file system
final model = await InferenceSession.loadFromFile('/path/to/downloaded/model.safetensors');

5. Cache Management #

// Check cache size
final sizeBytes = await InferenceSession.getCacheSize();
print('Cache size: ${(sizeBytes / 1024 / 1024).toStringAsFixed(1)} MB');

// Clear cache when needed
await InferenceSession.clearCache();

Examples #

Real-World OCR with EasyOCR #

import 'dart:io';

import 'package:inference/inference.dart';

class EasyOCRPipeline {
  late InferenceSession _detector;
  late InferenceSession _recognizer;
  
  Future<void> initialize() async {
    // Load both models from Hugging Face
    _detector = await InferenceSession.loadFromHuggingFace(
      'qualcomm/EasyOCR',
      filename: 'EasyOCR.safetensors', // Text detection model (79.2 MB)
    );
    
    _recognizer = await InferenceSession.loadFromHuggingFace(
      'qualcomm/EasyOCR',
      filename: 'EasyOCRRecognizer.safetensors', // Text recognition model (14.7 MB)
    );
  }
  
  Future<List<String>> extractText(String imagePath) async {
    // Step 1: Detect text regions
    final imageInput = await ImageInput.fromFile(File(imagePath));
    final detectionResult = await _detector.predict(imageInput);
    
    // Step 2: Extract text from each region.
    // parseDetectionResult and parseRecognitionResult are app-specific
    // helpers (not shown) that convert raw model outputs.
    final textRegions = parseDetectionResult(detectionResult);
    final extractedTexts = <String>[];
    
    for (final region in textRegions) {
      final regionInput = await ImageInput.fromCrop(imageInput, region);
      final recognitionResult = await _recognizer.predict(regionInput);
      final text = parseRecognitionResult(recognitionResult);
      extractedTexts.add(text);
    }
    
    return extractedTexts;
  }
  
  void dispose() {
    _detector.dispose();
    _recognizer.dispose();
  }
}

Image Classification with Model Download #

import 'dart:io';

import 'package:image_picker/image_picker.dart';
import 'package:inference/inference.dart';

class ImageClassifier {
  late InferenceSession _model;
  
  Future<void> initialize() async {
    // Download and cache MobileNet from a remote source
    _model = await InferenceSession.loadFromUrl(
      'https://example.com/models/mobilenet_v2.safetensors',
      cache: true,
      cacheKey: 'mobilenet_v2_imagenet',
    );
  }
  
  Future<String> classifyImage() async {
    // Get image from camera
    final picker = ImagePicker();
    final image = await picker.pickImage(source: ImageSource.camera);
    if (image == null) return 'No image selected';
    
    // Create input and predict
    final input = await ImageInput.fromFile(File(image.path));
    final result = await _model.predict(input);
    
    // Get top prediction
    final prediction = result.topK(1).first;
    return 'Class: ${prediction.classIndex}, Confidence: ${(prediction.confidence * 100).toStringAsFixed(1)}%';
  }
  
  void dispose() => _model.dispose();
}

Text Sentiment Analysis #

import 'package:inference/inference.dart';

class SentimentAnalyzer {
  late InferenceSession _model;
  
  Future<void> initialize() async {
    _model = await InferenceSession.loadWithCandle('assets/bert_sentiment.safetensors');
  }
  
  Future<Map<String, dynamic>> analyzeSentiment(String text) async {
    final input = NLPInput(text);
    final result = await _model.predict(input);
    
    final isPositive = result.scalar > 0.5;
    final confidence = isPositive ? result.scalar : 1 - result.scalar;
    
    return {
      'sentiment': isPositive ? 'positive' : 'negative',
      'confidence': confidence,
      'score': result.scalar,
    };
  }
  
  void dispose() => _model.dispose();
}

On-Device Training #

import 'package:inference/inference.dart';

class OnDeviceTrainer {
  Future<InferenceSession> trainClustering(List<List<double>> data) async {
    // Train K-means clustering on device
    final model = await InferenceSession.trainLinfa(
      data: data,
      algorithm: 'kmeans',
      params: {
        'n_clusters': 3,
        'max_iterations': 100,
        'tolerance': 1e-4,
      },
    );
    
    return model;
  }
  
  Future<int> predictCluster(InferenceSession model, List<double> point) async {
    final input = TensorInput(point, [point.length]);
    final result = await model.predict(input);
    return result.argmax;
  }
}

Batch Processing #

import 'dart:io';

import 'package:inference/inference.dart';

class BatchProcessor {
  late InferenceSession _model;
  
  Future<void> initialize() async {
    _model = await InferenceSession.load('assets/classifier.safetensors');
  }
  
  Future<List<InferenceResult>> processImages(List<String> imagePaths) async {
    // Create inputs for all images
    final inputs = await Future.wait(
      imagePaths.map((path) => ImageInput.fromFile(File(path))),
    );
    
    // Process all images in a single batch for better performance
    return await _model.predictBatch(inputs);
  }
  
  void dispose() => _model.dispose();
}

API Reference #

Core Classes #

InferenceSession

The main interface for ML inference sessions.

class InferenceSession {
  // Asset loading
  static Future<InferenceSession> load(String modelPath);
  
  // URL loading with caching
  static Future<InferenceSession> loadFromUrl(String url, {bool cache = true, String? cacheKey});
  
  // File system loading
  static Future<InferenceSession> loadFromFile(String filePath);
  
  // Hugging Face Hub integration
  static Future<InferenceSession> loadFromHuggingFace(String modelId, {required String filename, String? revision});
  
  // Engine-specific loading
  static Future<CandleSession> loadWithCandle(String modelPath);
  
  // On-device training
  static Future<LinfaSession> trainLinfa({
    required List<List<double>> data,
    required String algorithm,
    Map<String, dynamic>? params,
  });
  
  // Cache management
  static Future<void> clearCache();
  static Future<int> getCacheSize();
  
  // Inference methods
  Future<InferenceResult> predict(InferenceInput input);
  Future<List<InferenceResult>> predictBatch(List<InferenceInput> inputs);
  
  // Resource management
  void dispose();
  
  // Properties
  List<TensorSpec> get inputSpecs;
  List<TensorSpec> get outputSpecs;
  String get engine;
}
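
A small sketch of inspecting a loaded session's metadata through these properties (the asset path is illustrative):

import 'package:inference/inference.dart';

Future<void> describeModel() async {
  final model = await InferenceSession.load('assets/models/classifier.safetensors');
  // Engine name and tensor specs are read-only properties of the session.
  print('Engine: ${model.engine}');
  for (final spec in model.inputSpecs) {
    print('Input spec: $spec');
  }
  for (final spec in model.outputSpecs) {
    print('Output spec: $spec');
  }
  model.dispose();
}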

Input Types #

ImageInput

For computer vision models.

class ImageInput extends InferenceInput {
  // Constructors
  ImageInput({required Uint8List bytes, required int width, required int height, required int channels});
  
  // Convenience factories
  static Future<ImageInput> fromFile(File file);
  static Future<ImageInput> fromAsset(String assetPath);
  static Future<ImageInput> fromBytes(Uint8List bytes);
  factory ImageInput.fromPixels({required Float32List pixels, required int width, required int height, required int channels});
}
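
A minimal sketch of wrapping already-preprocessed pixel data with fromPixels; the 224×224 RGB shape is an illustrative assumption, not a requirement of the API:

import 'dart:typed_data';

import 'package:inference/inference.dart';

Future<InferenceResult> classifyPixels(
    InferenceSession model, Float32List pixels) async {
  // Wrap raw pixel values that were normalized elsewhere in the app.
  // The width, height, and channel count below are illustrative.
  final input = ImageInput.fromPixels(
    pixels: pixels,
    width: 224,
    height: 224,
    channels: 3,
  );
  return model.predict(input);
}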

NLPInput

For natural language processing models.

class NLPInput extends InferenceInput {
  NLPInput(String text, {String? tokenizer, List<int>? tokenIds});
  
  // Pre-tokenized input
  factory NLPInput.fromTokens(List<int> tokens);
}
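
If you run your own tokenizer, fromTokens accepts token IDs directly; a hedged sketch where the IDs come from an app-side tokenizer (not part of this package):

import 'package:inference/inference.dart';

Future<double> scoreTokens(InferenceSession model, List<int> tokenIds) async {
  // tokenIds are assumed to be produced by your own tokenizer.
  final input = NLPInput.fromTokens(tokenIds);
  final result = await model.predict(input);
  return result.scalar; // e.g. a single classification score
}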

TensorInput

For direct tensor data.

class TensorInput extends InferenceInput {
  TensorInput(List<double> data, List<int> shape);
  
  // Convenience factories
  factory TensorInput.fromList(List<List<double>> data);
  factory TensorInput.from3D(List<List<List<double>>> data);
}
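
A short sketch of wrapping tabular feature rows with the fromList factory (the feature values are illustrative):

import 'package:inference/inference.dart';

Future<int> classifyRow(InferenceSession model) async {
  // Two feature rows; the tensor shape is derived from the nested lists.
  final input = TensorInput.fromList([
    [5.1, 3.5, 1.4, 0.2],
    [6.7, 3.0, 5.2, 2.3],
  ]);
  final result = await model.predict(input);
  return result.argmax; // index of the highest-scoring output
}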

AudioInput

For audio processing models.

class AudioInput extends InferenceInput {
  AudioInput({required Float32List samples, required int sampleRate});
  
  static Future<AudioInput> fromFile(File file);
}
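
Dedicated audio architectures are still on the roadmap, but the input type can already be constructed; a hedged sketch assuming a compatible audio model:

import 'dart:io';

import 'package:inference/inference.dart';

Future<InferenceResult> runAudioModel(InferenceSession model, String wavPath) async {
  // fromFile decodes the audio file into samples for the model.
  final input = await AudioInput.fromFile(File(wavPath));
  return model.predict(input);
}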

Results #

InferenceResult

Contains prediction results with convenience accessors.

class InferenceResult {
  // Raw data access
  Float32List get data;
  List<int> get shape;
  String get dataType;
  
  // Convenience accessors
  double get scalar;                           // Single value
  List<double> get vector;                     // 1D array
  List<List<double>> get matrix;               // 2D array
  
  // Classification helpers
  int get argmax;                              // Index of maximum value
  List<ClassificationResult> topK([int k]);   // Top K predictions
  List<ClassificationResult> topKSoftmax([int k]); // Top K with softmax
}
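
A brief sketch of the classification helpers in practice; the label list is app-provided and illustrative:

import 'package:inference/inference.dart';

void printTop3(InferenceResult result, List<String> labels) {
  // topKSoftmax applies softmax to the raw outputs before ranking.
  for (final prediction in result.topKSoftmax(3)) {
    final name = prediction.className ?? labels[prediction.classIndex];
    print('$name: ${(prediction.confidence * 100).toStringAsFixed(1)}%');
  }
}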

ClassificationResult

Individual classification prediction.

class ClassificationResult {
  final int classIndex;
  final double confidence;
  final String? className;
}

Engine-Specific Sessions #

CandleSession

For PyTorch models with HuggingFace integration.

class CandleSession extends InferenceSession {
  // HuggingFace integration
  static Future<CandleSession> fromHuggingFace({
    required String repo,
    String? revision,
    String? filename,
  });
  
  // PyTorch model loading
  static Future<CandleSession> fromPyTorch(String safetensorsPath);
  
  // Custom architectures
  static Future<CandleSession> fromArchitecture({
    required String architecture,
    required String weightsPath,
  });
  
  // Device management (read-only properties)
  bool get isCudaAvailable;
  bool get isMklAvailable;
  String get device;
}
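
A small sketch of checking which backend Candle picked after loading; the repository and filename reuse the EasyOCR example from above:

import 'package:inference/inference.dart';

Future<void> reportCandleBackend() async {
  final session = await CandleSession.fromHuggingFace(
    repo: 'qualcomm/EasyOCR',
    filename: 'EasyOCR.safetensors',
  );
  // Read-only device information exposed by the Candle engine.
  print('Device: ${session.device}');
  print('CUDA available: ${session.isCudaAvailable}');
  print('MKL available: ${session.isMklAvailable}');
  session.dispose();
}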

LinfaSession

For classical ML algorithms with on-device training.

class LinfaSession extends InferenceSession {
  // Clustering
  static Future<LinfaSession> trainKMeans({
    required List<List<double>> data,
    required int numClusters,
    int maxIterations = 100,
    double tolerance = 1e-4,
  });
  
  // Regression
  static Future<LinfaSession> trainLinearRegression({
    required List<List<double>> features,
    required List<double> targets,
    double? l1Ratio,
    double? l2Ratio,
  });
  
  // Classification
  static Future<LinfaSession> trainSVM({
    required List<List<double>> features,
    required List<int> labels,
    String kernel = 'rbf',
    Map<String, dynamic>? params,
  });
  
  // Decision trees
  static Future<LinfaSession> trainDecisionTree({
    required List<List<double>> features,
    required List<int> labels,
    int? maxDepth,
    int? minSamplesSplit,
  });
  
  // Model persistence
  Future<Uint8List> serialize();
  static Future<LinfaSession> deserialize(Uint8List bytes);
}
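
A hedged sketch of persisting a trained model across launches with serialize/deserialize; the file-path handling is illustrative and not part of the package:

import 'dart:io';

import 'package:inference/inference.dart';

Future<void> saveModel(LinfaSession model, String path) async {
  // Serialize the trained model to bytes and write them to disk.
  final bytes = await model.serialize();
  await File(path).writeAsBytes(bytes);
}

Future<LinfaSession> restoreModel(String path) async {
  // Read the bytes back and rebuild the session without retraining.
  final bytes = await File(path).readAsBytes();
  return LinfaSession.deserialize(bytes);
}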

Supported Formats #

Engine | Formats                 | Use Cases                                                | Status
Candle | .safetensors, .pt, .pth | PyTorch models, HuggingFace models, computer vision, NLP | ✅ Core functionality implemented
Linfa  | Training data           | Classical ML, on-device training, small datasets         | ✅ Basic training implemented

Current Model Architecture Support #

Currently Supported:

  • ✅ Generic SafeTensors Loading: Any SafeTensors model can be loaded
  • ✅ BERT Models: Text classification and NLP tasks
  • ✅ ResNet Models: Image classification
  • ✅ K-means Clustering: On-device training

In Development (see Model Wrappers Roadmap):

  • 🚧 Llama Models: Text generation
  • 🚧 Whisper Models: Speech recognition
  • 🚧 GPT-2 Models: Text generation
  • 🚧 20+ additional architectures: Comprehensive model support

Platform Support #

Platform | Candle | Linfa
Android  | ✅      | ✅
iOS      | ✅      | ✅
Windows  | ✅      | ✅
macOS    | ✅      | ✅
Linux    | ✅      | ✅

Performance Tips #

Memory Management #

Always dispose of sessions when done:

final model = await InferenceSession.load('model.safetensors');
try {
  // Use model...
} finally {
  model.dispose(); // Important: prevents memory leaks
}

Batch Processing #

Use batch predictions for better performance:

// ✅ Good: Process multiple inputs together
final results = await model.predictBatch(inputs);

// ❌ Avoid: Processing inputs one by one
final results = <InferenceResult>[];
for (final input in inputs) {
  results.add(await model.predict(input));
}

GPU Acceleration #

GPU acceleration is detected and used automatically when available:

// Automatically detect and use best execution provider
final model = await InferenceSession.load('model.safetensors');
// GPU will be used automatically if available

Smart Caching #

Models downloaded from URLs are cached automatically:

// First load: downloads and caches
final model1 = await InferenceSession.loadFromUrl('https://example.com/model.safetensors');

// Second load: uses cache (much faster)
final model2 = await InferenceSession.loadFromUrl('https://example.com/model.safetensors');

// Manage cache size
final sizeBytes = await InferenceSession.getCacheSize();
if (sizeBytes > 500 * 1024 * 1024) { // If cache > 500MB
  await InferenceSession.clearCache();
}

Troubleshooting #

Common Issues #

Model loading fails

  • Verify the model file exists and is accessible
  • Check that the model format is supported (.safetensors, .pt, .pth)
  • Ensure sufficient memory is available
  • For URL loading: check internet connectivity and URL validity
  • For Hugging Face models: verify repository and filename exist

Slow inference

  • GPU acceleration is automatically enabled when available
  • Use batch processing for multiple inputs
  • Ensure model is properly optimized

Memory issues

  • Always call dispose() on sessions
  • Avoid loading multiple large models simultaneously
  • Monitor cache size and clear when necessary

Error Handling #

try {
  final model = await InferenceSession.load('model.safetensors');
  final result = await model.predict(input);
} on ModelLoadException catch (e) {
  print('Failed to load model: ${e.message}');
} on PredictionException catch (e) {
  print('Prediction failed: ${e.message}');
} on UnsupportedFormatException catch (e) {
  print('Format not supported: ${e.message}');
}

Contributing #

We welcome contributions! Please see our Contributing Guide for details.

Development Setup #

  1. Clone the repository
  2. Install dependencies: flutter pub get
  3. Run tests: flutter test
  4. Run example: cd example && flutter run

License #

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgments #

  • Candle: Rust-based PyTorch implementation
  • Linfa: Rust machine learning toolkit
  • Flutter Rust Bridge: Seamless Rust-Flutter integration

Made with ❤️ by the Flutter community
