dart_onnx 0.1.1 — add `dart_onnx: ^0.1.1` to the dependencies in your pubspec.yaml.

A cross-platform Dart package for running ONNX models using ONNX Runtime via Dart FFI.

example/dart_onnx_example.dart

/// Example: Running SmolLM2-135M inference with dart_onnx.
///
/// This example loads the quantized SmolLM2-135M ONNX model and runs a single
/// forward pass to predict the next token for a short prompt.
///
/// ## Setup
///
/// First, download the model files (≈138 MB total):
///
/// ```
/// dart run tool/download_model.dart
/// ```
///
/// Then run this example from the package root:
///
/// ```
/// dart run example/dart_onnx_example.dart
/// ```
library;

import 'dart:io';
import 'dart:typed_data';
import 'package:dart_onnx/dart_onnx.dart';

// Model hyper-parameters, mirroring SmolLM2-135M's config.json on HuggingFace.

/// Transformer layer count (`num_hidden_layers` in config.json).
const kNumLayers = 30;

/// Key/value head count used by grouped-query attention
/// (`num_key_value_heads` in config.json).
const kNumKvHeads = 3;

/// Per-head feature dimension (`head_dim` in config.json).
const kHeadDim = 64;

/// Number of entries in the tokenizer vocabulary.
const kVocabSize = 49152;

/// Token ID the model emits to signal end-of-sequence.
const kEosTokenId = 0;

// Pre-tokenized form of the prompt "Hello, I am".
//
// The IDs come from SmolLM2's GPT-2-style BPE vocabulary (tokenizer.json) and
// only hold for this exact prompt. Real applications should tokenize
// arbitrary text with a tokenizer library that reads `tokenizer.json`.
const kPromptTokenIds = <int>[
  12906, // "Hello"
  13, // ","
  309, // " I"
  837, // " am"
];

/// Loads the quantized SmolLM2 model, runs one forward pass on a fixed
/// prompt, and prints the top-5 next-token candidates plus the greedy pick.
///
/// Exits with status 1 if the model file has not been downloaded yet.
/// Throws [StateError] if the model output has no `logits` entry or the logits
/// tensor does not have the expected `[batch, seq_len, vocab_size]` layout.
void main() {
  // Resolve the model relative to this script (not the working directory) so
  // the example runs from any location.
  final scriptDir = File.fromUri(Platform.script).parent;
  final modelPath = '${scriptDir.path}/onnx/model_quantized.onnx';

  if (!File(modelPath).existsSync()) {
    stderr.writeln('Model not found at: $modelPath');
    stderr.writeln(
      'Run `dart run tool/download_model.dart` first to download it.',
    );
    exit(1);
  }

  // Initialize ONNX Runtime environment.
  // NOTE(review): `env` is never explicitly released; if DartONNX exposes a
  // dispose/release method, call it after `session.dispose()` below.
  print('Initializing ONNX Runtime...');
  final env = DartONNX(loggingLevel: DartONNXLoggingLevel.warning);
  print('ORT version : ${env.ortVersion}');

  // Load the session. Providers are listed in order of preference.
  print('Loading model: $modelPath');
  final session = DartONNXSession.fromFile(
    env,
    modelPath,
    executionProviders: [
      DartONNXExecutionProvider.coreML, // Apple Neural Engine (optional)
      DartONNXExecutionProvider.cpu, // Fallback
    ],
  );

  print('Input  names : ${session.inputNames}');
  print('Output names : ${session.outputNames}');

  final seqLen = kPromptTokenIds.length;
  const batch = 1;

  // Every tensor placed in `inputs` is disposed in the outer `finally`, so
  // native memory is released even if inference or post-processing throws.
  final inputs = <String, DartONNXTensor>{};
  try {
    // input_ids  [1, seqLen]  — the token IDs of the prompt.
    inputs['input_ids'] = DartONNXTensor.int64(
      data: Int64List.fromList(kPromptTokenIds),
      shape: [batch, seqLen],
    );

    // attention_mask  [1, seqLen]  — all 1s (no padding).
    inputs['attention_mask'] = DartONNXTensor.int64(
      data: Int64List.fromList(List.filled(seqLen, 1)),
      shape: [batch, seqLen],
    );

    // position_ids  [1, seqLen]  — 0, 1, …, seqLen-1. Allocated only when the
    // model declares this input, so no tensor is created and then leaked.
    if (session.inputNames.contains('position_ids')) {
      inputs['position_ids'] = DartONNXTensor.int64(
        data: Int64List.fromList(List.generate(seqLen, (i) => i)),
        shape: [batch, seqLen],
      );
    }

    // Empty KV cache for every `past_key_values.*` input the model declares.
    // Shape: [batch, num_kv_heads, 0, head_dim] — a zero-length sequence
    // dimension means "no cached tokens yet".
    final kvInputNames = session.inputNames
        .where((n) => n.startsWith('past_key_values.'))
        .toList();
    if (kvInputNames.isNotEmpty) {
      print('Providing ${kvInputNames.length} empty past_key_value tensors...');
      for (final name in kvInputNames) {
        inputs[name] = DartONNXTensor.float32(
          data: Float32List(0),
          shape: [batch, kNumKvHeads, 0, kHeadDim],
        );
      }
    }

    // Run inference.
    print('\nRunning forward pass on prompt: "Hello, I am"');
    print('Prompt token IDs: $kPromptTokenIds\n');

    final stopwatch = Stopwatch()..start();
    final outputs = session.run(inputs);
    stopwatch.stop();

    print('Inference time: ${stopwatch.elapsedMilliseconds} ms');

    try {
      final logitsTensor = outputs['logits'];
      if (logitsTensor == null) {
        throw StateError(
          'Expected "logits" in outputs but got: ${outputs.keys}',
        );
      }

      // logits shape: [batch, seq_len, vocab_size]. seq_len and vocab_size
      // are read from the tensor rather than assumed to equal the prompt
      // length and kVocabSize: an export that emits only last-position
      // logits has seq_len == 1, and a wrong constant would mis-slice.
      final logitsData = logitsTensor.data as Float32List;
      final logitsShape = logitsTensor.shape;

      print('Logits shape  : $logitsShape');

      if (logitsShape.length != 3) {
        throw StateError('Expected rank-3 logits but got shape $logitsShape');
      }
      final outSeqLen = logitsShape[1];
      final vocabSize = logitsShape[2];
      if (outSeqLen < 1 ||
          vocabSize < 1 ||
          logitsData.length < outSeqLen * vocabSize) {
        throw StateError('Logits data does not match shape $logitsShape');
      }
      if (vocabSize != kVocabSize) {
        print('Note: model vocab size $vocabSize differs from $kVocabSize.');
      }

      // Last-token logits for batch 0 → index [0, outSeqLen-1, :]. A view
      // avoids copying the vocab-sized slice; it is only read before the
      // output tensors are disposed below.
      final lastTokenOffset = (outSeqLen - 1) * vocabSize;
      final lastTokenLogits = Float32List.sublistView(
        logitsData,
        lastTokenOffset,
        lastTokenOffset + vocabSize,
      );

      final top5 = _topK(lastTokenLogits, 5);

      print('\nTop-5 next-token predictions (token_id → logit_score):');
      for (final (tokenId, score) in top5) {
        print('  token $tokenId → ${score.toStringAsFixed(4)}');
      }

      // Greedy next token = highest-scoring entry (vocabSize >= 1 checked
      // above, so `top5` is non-empty).
      final nextTokenId = top5.first.$1;
      print('\nGreedy next token ID : $nextTokenId');
      if (nextTokenId == kEosTokenId) {
        print('(EOS token — model wants to stop here)');
      }
      print(
        'Decode with `tokenizer.json` to convert token IDs back to a string.',
      );
    } finally {
      for (final t in outputs.values) {
        t.dispose();
      }
    }
  } finally {
    for (final t in inputs.values) {
      t.dispose();
    }
    session.dispose();
  }

  print('\nDone.');
}

/// Returns the [k] highest-scoring `(tokenId, logit)` pairs in [logits],
/// ordered from highest to lowest score.
///
/// Uses a full sort, which is simple and fast enough for a single ~50k-entry
/// vocabulary. Returns fewer than [k] pairs if [logits] is shorter than [k].
List<(int, double)> _topK(List<double> logits, int k) {
  final ranked = [for (var i = 0; i < logits.length; i++) (i, logits[i])]
    ..sort((a, b) => b.$2.compareTo(a.$2));
  return ranked.take(k).toList();
}
1 like · 160 points · 192 downloads

Documentation

API reference

Publisher

Verified publisher: neevash.com

Weekly Downloads

A cross-platform Dart package for running ONNX models using ONNX Runtime via Dart FFI.

Repository (GitHub)
View/report issues

License

MIT (license)

Dependencies

ffi

More

Packages that depend on dart_onnx