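`forward` method — runs token ids through the embedding lookup, every transformer block, the final layer norm, and the LM head.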

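Tensor forward(List<int> idx, Tensor encoderOutput, List<Tensor> tracker)

Parameters:
  1. idx — token ids; its length must not exceed blockSize.
  2. encoderOutput — encoder tensor that is handed to every transformer block.
  3. tracker — list that receives intermediate tensors produced during the pass.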

Implementation

/// Runs one forward pass over the token ids in [idx].
///
/// Pipeline: embedding lookup from `wte`/`wpe`, then every block in
/// `blocks` (each also given [encoderOutput]), then `finalLayerNorm`,
/// then `lmHead`. The result of `lmHead.forward` is returned.
///
/// [tracker] receives the embedding tensor directly and is forwarded to
/// every sub-module. NOTE(review): presumably sub-modules append their own
/// intermediates so the caller can release them; confirm with callers.
///
/// Throws [ArgumentError] if [idx] is empty or longer than `blockSize`.
Tensor forward(List<int> idx, Tensor encoderOutput, List<Tensor> tracker) {
  final int seqLen = idx.length;

  // An empty sequence has nothing to embed; reject it up front instead of
  // handing a zero-length input to the embedding lookup.
  if (seqLen == 0) {
    throw ArgumentError("Sequence must contain at least one token");
  }
  // NOTE(review): presumably wpe has blockSize rows, so longer sequences
  // cannot be position-embedded.
  if (seqLen > blockSize) {
    throw ArgumentError(
      "Sequence length $seqLen exceeds max block size $blockSize",
    );
  }

  // 1. Embedding lookup from wte/wpe (presumably token and position tables).
  //    The sequence length is not passed separately; Tensor.embeddings is
  //    expected to derive it from idx.
  Tensor x = Tensor.embeddings(idx, wte, wpe);
  tracker.add(x);

  // 2. Transformer blocks; each one also receives the encoder output.
  for (final block in blocks) {
    x = block.forward(x, encoderOutput, tracker);
  }

  // 3. Final layer norm, then the LM head produces the returned tensor.
  final xNorm = finalLayerNorm.forward(x, tracker);
  return lmHead.forward(xNorm, tracker);
}