forward method

Tensor forward(List<int> indices)

Forward pass: looks up the embedding vectors for the given token IDs.

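indices is a list of integer token IDs, each expected to lie in [0, numEmbeddings). Returns a tensor of shape (indices.length, embeddingDim) whose row i is the embedding vector for indices[i]. Throws a RangeError if any index is out of range.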
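A minimal usage sketch, assuming hypothetical stand-in Tensor and Embedding classes. Only Tensor(data, shape), weight.data, numEmbeddings, embeddingDim and the forward body come from this page; the class layouts and the makeWeight helper are illustrative assumptions, not this library's API.

```dart
import 'dart:typed_data';

/// Hypothetical minimal tensor: a flat row-major buffer plus its shape.
class Tensor {
  final Float32List data;
  final List<int> shape;
  Tensor(this.data, this.shape);
}

/// Hypothetical minimal embedding layer wrapping a
/// (numEmbeddings, embeddingDim) weight tensor.
class Embedding {
  final Tensor weight;
  final int numEmbeddings;
  final int embeddingDim;

  Embedding(this.weight)
      : numEmbeddings = weight.shape[0],
        embeddingDim = weight.shape[1];

  // Same body as the implementation shown on this page.
  Tensor forward(List<int> indices) {
    final n = indices.length;
    final result = Float32List(n * embeddingDim);
    for (int i = 0; i < n; i++) {
      final idx = indices[i];
      if (idx < 0 || idx >= numEmbeddings) {
        throw RangeError(
            'Embedding index $idx out of range [0, $numEmbeddings)');
      }
      final srcOffset = idx * embeddingDim;
      final dstOffset = i * embeddingDim;
      for (int j = 0; j < embeddingDim; j++) {
        result[dstOffset + j] = weight.data[srcOffset + j];
      }
    }
    return Tensor(result, [n, embeddingDim]);
  }
}

/// Hypothetical helper: row r holds [r*10, r*10 + 1, ...] so each
/// looked-up row is easy to recognise in the output.
Tensor makeWeight(int rows, int dim) => Tensor(
      Float32List.fromList([
        for (var r = 0; r < rows; r++)
          for (var c = 0; c < dim; c++) (r * 10 + c).toDouble(),
      ]),
      [rows, dim],
    );
```

With the 3 × 2 table from makeWeight(3, 2), forward([2, 0, 2]) returns a [3, 2] tensor whose data is [20, 21, 0, 1, 20, 21]: a repeated ID simply copies the same row again, and forward([3]) throws a RangeError.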

Implementation

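Tensor forward(List<int> indices) {
  final n = indices.length;
  // Output buffer: one row of embeddingDim values per index, row-major.
  final result = Float32List(n * embeddingDim);
  for (int i = 0; i < n; i++) {
    final idx = indices[i];
    // Reject IDs outside the table with a descriptive error.
    if (idx < 0 || idx >= numEmbeddings) {
      throw RangeError(
          'Embedding index $idx out of range [0, $numEmbeddings)');
    }
    // weight.data is row-major, so row idx starts at idx * embeddingDim.
    final srcOffset = idx * embeddingDim;
    final dstOffset = i * embeddingDim;
    // Copy row idx of the weight matrix into row i of the result.
    for (int j = 0; j < embeddingDim; j++) {
      result[dstOffset + j] = weight.data[srcOffset + j];
    }
  }
  return Tensor(result, [n, embeddingDim]);
}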
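The copy loop above is a row gather. Mathematically it gives the same result as multiplying an n × numEmbeddings one-hot matrix by the numEmbeddings × embeddingDim weight matrix, but it costs O(n · embeddingDim) work instead of O(n · numEmbeddings · embeddingDim). The sketch below checks that equivalence on plain List<double> buffers; gatherRows and oneHotMatmul are hypothetical helper names, not part of this API.

```dart
/// Hypothetical helper: direct row gather from a row-major (rows x dim)
/// matrix, the same access pattern as forward.
List<double> gatherRows(List<double> weight, int dim, List<int> indices) {
  final out = <double>[];
  for (final idx in indices) {
    // Row idx occupies [idx * dim, (idx + 1) * dim) in the flat buffer.
    out.addAll(weight.sublist(idx * dim, (idx + 1) * dim));
  }
  return out;
}

/// Hypothetical helper: the same lookup expressed as
/// oneHot(indices) (n x rows) times weight (rows x dim).
List<double> oneHotMatmul(
    List<double> weight, int rows, int dim, List<int> indices) {
  final n = indices.length;
  final out = List<double>.filled(n * dim, 0.0);
  for (var i = 0; i < n; i++) {
    for (var r = 0; r < rows; r++) {
      // Entry (i, r) of the one-hot matrix: 1 only where r == indices[i].
      final oneHot = r == indices[i] ? 1.0 : 0.0;
      for (var c = 0; c < dim; c++) {
        out[i * dim + c] += oneHot * weight[r * dim + c];
      }
    }
  }
  return out;
}
```

Both functions return [10, 11, 20, 21, 10, 11] for the weights [0, 1, 10, 11, 20, 21] (3 rows, dim 2) and indices [1, 2, 1]. Indexing rows directly, as forward does, avoids building the one-hot matrix and the multiplications by zero.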