create static method

static Future<LitertEmbeddingModel> create({
  required String modelPath,
  required String tokenizerPath,
  int? inputSequenceLength,
  int? outputDimension,
  VoidCallback? onClose,
})

Load a .tflite embedding model from disk and prepare it for inference on CPU.

modelPath points at a .tflite file (Gecko 64, EmbeddingGemma 256, etc.). tokenizerPath is the matching SentencePiece .model or exported .json. Input sequence length and output dimension are auto-detected from the compiled model's tensor layouts; pass them to override (rare).

Caller owns the returned instance and must call close when done.

Implementation

/// Loads a `.tflite` embedding model and its tokenizer, compiles the model
/// for CPU, and returns a ready-to-use [LitertEmbeddingModel].
///
/// [tokenizerPath] is loaded through [TokenizerJsonLoader.fromJsonFile] when
/// it ends with `.json`, and through [SentencePieceTokenizer.fromModelFile]
/// otherwise.
///
/// When [inputSequenceLength] or [outputDimension] is `null`, the value is
/// read from dimension 1 of the compiled model's input / output tensor
/// layout; a non-null argument is used as-is, without validation.
///
/// [onClose] is forwarded to the returned instance; a no-op is substituted
/// when it is omitted.
///
/// Throws a [StateError] if an auto-detected layout has rank < 2, and
/// whatever `check` throws when a LiteRT call reports a failure status.
static Future<LitertEmbeddingModel> create({
  required String modelPath,
  required String tokenizerPath,
  int? inputSequenceLength,
  int? outputDimension,
  VoidCallback? onClose,
}) async {
  final bindings = LiteRtBindings.open();

  // Load tokenizer first (file IO; no native handles have been created yet,
  // so a failure here leaves nothing to clean up).
  final SentencePieceTokenizer tokenizer;
  if (tokenizerPath.endsWith('.json')) {
    tokenizer = await TokenizerJsonLoader.fromJsonFile(
      tokenizerPath,
      config: const SentencePieceConfig(),
    );
  } else {
    tokenizer = await SentencePieceTokenizer.fromModelFile(
      tokenizerPath,
      config: const SentencePieceConfig(),
    );
  }

  // NOTE(review): the out-pointers below are now released on every path, but
  // the native handles themselves (environment / model / options) are still
  // not destroyed if a *later* step throws. The matching destroy bindings are
  // not visible from this method — confirm and add cleanup if they exist.

  // Environment (CPU; can be extended later for GPU acceleration).
  final envPtr = calloc<LiteRtEnvironment>();
  try {
    bindings
        .createEnvironment(0, nullptr, envPtr)
        .check('LiteRtCreateEnvironment');
  } on Object {
    // Do not leak the out-pointer when the status check throws.
    calloc.free(envPtr);
    rethrow;
  }
  final environment = envPtr.value;
  calloc.free(envPtr);

  // Model from .tflite file.
  final pathC = modelPath.toNativeUtf8();
  final modelPtr = calloc<LiteRtModel>();
  try {
    bindings
        .createModelFromFile(pathC, modelPtr)
        .check('LiteRtCreateModelFromFile($modelPath)');
  } on Object {
    calloc.free(modelPtr);
    rethrow;
  } finally {
    // The native path string is only needed for the duration of the call.
    calloc.free(pathC);
  }
  final model = modelPtr.value;
  calloc.free(modelPtr);

  // Compilation options: CPU only.
  final optsPtr = calloc<LiteRtOptions>();
  try {
    bindings.createOptions(optsPtr).check('LiteRtCreateOptions');
  } on Object {
    calloc.free(optsPtr);
    rethrow;
  }
  final options = optsPtr.value;
  calloc.free(optsPtr);
  bindings
      .setOptionsHardwareAccelerators(options, kLiteRtHwAcceleratorCpu)
      .check('LiteRtSetOptionsHardwareAccelerators');

  // Compile.
  final compiledPtr = calloc<LiteRtCompiledModel>();
  try {
    bindings
        .createCompiledModel(environment, model, options, compiledPtr)
        .check('LiteRtCreateCompiledModel');
  } on Object {
    calloc.free(compiledPtr);
    rethrow;
  }
  final compiled = compiledPtr.value;
  calloc.free(compiledPtr);

  // Auto-detect seqLen + dim from compiled tensor layouts unless the
  // caller pinned them. Embedding models we care about all have:
  //   input  shape [1, seqLen]   element_type=int32
  //   output shape [1, dim]      element_type=float32
  final int seqLen;
  if (inputSequenceLength == null) {
    final inLayout = LiteRtLayoutView.calloc();
    try {
      // NOTE(review): the two zeros are presumably the signature index and
      // the input tensor index — confirm against the binding declaration.
      bindings
          .getInputTensorLayout(compiled, 0, 0, inLayout.pointer)
          .check('LiteRtGetCompiledModelInputTensorLayout');
      if (inLayout.rank < 2) {
        throw StateError(
            'Embedding model input has rank=${inLayout.rank}, expected >=2');
      }
      seqLen = inLayout.dimension(1);
    } finally {
      inLayout.free();
    }
  } else {
    seqLen = inputSequenceLength;
  }

  final int dim;
  if (outputDimension == null) {
    final outLayouts = LiteRtLayoutView.calloc();
    try {
      // NOTE(review): `0, 1` presumably means "signature 0, one output
      // layout"; the trailing `false` flag's meaning is not visible here —
      // confirm against the binding declaration.
      bindings
          .getOutputTensorLayouts(compiled, 0, 1, outLayouts.pointer, false)
          .check('LiteRtGetCompiledModelOutputTensorLayouts');
      if (outLayouts.rank < 2) {
        throw StateError(
            'Embedding model output has rank=${outLayouts.rank}, expected >=2');
      }
      dim = outLayouts.dimension(1);
    } finally {
      outLayouts.free();
    }
  } else {
    dim = outputDimension;
  }

  debugPrint('[LitertEmbeddingModel] loaded: seqLen=$seqLen, dim=$dim');

  return LitertEmbeddingModel._(
    bindings: bindings,
    environment: environment,
    model: model,
    options: options,
    compiledModel: compiled,
    tokenizer: tokenizer,
    inputSequenceLength: seqLen,
    outputDimension: dim,
    onClose: onClose ?? () {},
  );
}