forward method

Tensor forward(
  List<int> textTokens,
  Tensor encoderOutput,
  List<Tensor> tracker,
)

Implementation

/// Runs the decoder over [textTokens], attending to [encoderOutput].
///
/// The tokens are embedded through `Tensor.embeddings` (using `wte` and
/// `wpe`), threaded through every entry of `blocks` in order, passed through
/// `finalLayerNorm`, and finally handed to `lmHead`, whose result is
/// returned (the logits for each vocabulary token).
///
/// The embedding tensor is appended to [tracker], and [tracker] is also
/// forwarded to every block, to the final layer norm and to the head.
///
/// Throws an [ArgumentError] when [textTokens] is longer than `maxSeqLen`.
Tensor forward(
  List<int> textTokens,
  Tensor encoderOutput,
  List<Tensor> tracker,
) {
  // Reject sequences longer than the configured maximum.
  if (textTokens.length > maxSeqLen) {
    throw ArgumentError("Text sequence length exceeds maxSeqLen");
  }

  // Token embeddings plus positional information; recorded in the tracker
  // before any decoder block runs.
  final embedded = Tensor.embeddings(textTokens, wte, wpe);
  tracker.add(embedded);

  // Feed the hidden state through the decoder blocks in declaration order.
  final hidden = blocks.fold<Tensor>(
    embedded,
    (state, layer) => layer.forward(state, encoderOutput, tracker),
  );

  // Final LayerNorm followed by the language-model head.
  final logits = lmHead.forward(
    finalLayerNorm.forward(hidden, tracker),
    tracker,
  );
  return logits;
}