embed static method

Future<List<double>> embed(
  1. String text
)

Converts text to a 384-dimensional embedding.

Implementation

/// Converts [text] into a single embedding vector.
///
/// Tokenizes [text] with `tokenize`, feeds the token ids and an all-ones
/// attention mask to the ONNX session as `input_ids` / `attention_mask`, and
/// reduces the first model output to one vector (mean pooling over tokens
/// when the output is nested three levels deep). The vector length is taken
/// from the model output itself.
///
/// All native handles created here (input tensors, run options, output
/// values) are released before returning, including when inference or
/// post-processing throws.
///
/// Throws an [Exception] if the session has not been initialized or if
/// inference produces no output.
static Future<List<double>> embed(String text) async {
  final session = _session;
  if (session == null) {
    throw Exception("EmbeddingService not initialized. Call init() first.");
  }

  // 1. Tokenize with the Rust tokenizer.
  final tokenIds = tokenize(text: text);

  if (debugMode) {
    print('[DEBUG] Text: "$text"');
    print('[DEBUG] Token IDs: $tokenIds (length: ${tokenIds.length})');
  }

  // 2. Build the attention mask: one entry per token, all set to 1.
  final seqLen = tokenIds.length;
  final attentionMask = List<int>.filled(seqLen, 1);

  // 3. Prepare the raw int64 buffers for the ONNX inputs.
  // No token_type_ids: passing it to a model that does not declare that
  // input causes an ONNX runtime error.
  final inputIdsData = Int64List.fromList(
    tokenIds.map((e) => e.toInt()).toList(),
  );
  final attentionMaskData = Int64List.fromList(attentionMask);
  final shape = [1, seqLen];

  // Declared outside the try so the finally block can release whichever
  // handles were actually created before a failure.
  OrtValueTensor? inputIdsTensor;
  OrtValueTensor? attentionMaskTensor;
  OrtRunOptions? runOptions;
  try {
    inputIdsTensor = OrtValueTensor.createTensorWithDataList(
      inputIdsData,
      shape,
    );
    attentionMaskTensor = OrtValueTensor.createTensorWithDataList(
      attentionMaskData,
      shape,
    );
    runOptions = OrtRunOptions();

    // 4. Run inference.
    final outputs = await session.runAsync(runOptions, {
      'input_ids': inputIdsTensor,
      'attention_mask': attentionMaskTensor,
    });

    try {
      // 5. Extract the first output and pool it down to one vector.
      if (outputs == null || outputs.isEmpty) {
        throw Exception("ONNX inference returned null output");
      }
      final outputTensor = outputs[0];
      if (outputTensor == null) {
        throw Exception("ONNX inference returned null output");
      }

      final outputData = outputTensor.value as List;

      if (debugMode) {
        print('[DEBUG] Output shape: ${_getShape(outputData)}');
      }

      return _poolModelOutput(outputData, attentionMask);
    } finally {
      // 6a. Release every output value, even if pooling failed.
      if (outputs != null) {
        for (final output in outputs) {
          output?.release();
        }
      }
    }
  } finally {
    // 6b. Release the inputs and run options on every exit path.
    inputIdsTensor?.release();
    attentionMaskTensor?.release();
    runOptions?.release();
  }
}

/// Reduces raw model output to a single vector.
///
/// Handles three nesting depths of [outputData]:
/// * list of lists of lists (`[batch, seq_len, hidden]`): averages the token
///   vectors of the first batch entry whose [attentionMask] value is 1;
/// * list of lists (`[batch, hidden]`): returns the first batch entry;
/// * flat list (`[hidden]`): returns it converted to doubles.
static List<double> _poolModelOutput(
  List outputData,
  List<int> attentionMask,
) {
  if (outputData.isEmpty || outputData[0] is! List) {
    // 1D output: [hidden]
    return outputData.map((e) => (e as num).toDouble()).toList();
  }

  final batchData = outputData[0] as List;
  if (batchData.isEmpty || batchData[0] is! List) {
    // 2D output: [batch, hidden]
    return batchData.map((e) => (e as num).toDouble()).toList();
  }

  // 3D output: [batch, seq_len, hidden] -> mean pooling -> [hidden].
  // All attended tokens are averaged, including CLS and SEP.
  final hiddenSize = (batchData[0] as List).length;
  final embedding = List<double>.filled(hiddenSize, 0.0);

  var count = 0;
  for (var t = 0; t < batchData.length; t++) {
    // Skip positions beyond the mask or with attention_mask != 1.
    if (t >= attentionMask.length || attentionMask[t] != 1) continue;
    final tokenEmb = batchData[t] as List;
    for (var h = 0; h < hiddenSize; h++) {
      embedding[h] += (tokenEmb[h] as num).toDouble();
    }
    count++;
  }

  // Guard against division by zero when no token was attended.
  if (count > 0) {
    for (var h = 0; h < hiddenSize; h++) {
      embedding[h] /= count;
    }
  }

  if (debugMode) {
    print('[DEBUG] Embedding (first 5): ${embedding.take(5).toList()}');
  }

  return embedding;
}