forward method

List<ValueVector> forward(
  List<int> idx,
  List<ValueVector> encoderOutput
)

The forward pass for the Transformer Decoder model.

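Takes idx, a list of integer target token indices (shifted right during training), and encoderOutput, the output of the Transformer Encoder. Returns one logits vector per input position; the vector at position t scores candidates for the token that follows position t. Throws an ArgumentError if idx is longer than blockSize.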
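
As a minimal sketch of what "shifted right" typically means for training inputs: the decoder input is the target sequence with a start-of-sequence token prepended and the last token dropped, so the logits at position t can be scored against the original target token at position t. The helper name shiftRight and the bosId parameter are hypothetical and are not part of this library.

```dart
/// Hypothetical helper (not part of this library): builds the decoder input
/// from a target sequence by prepending a start token and dropping the last
/// target token, so the result has the same length as the target.
List<int> shiftRight(List<int> target, int bosId) {
  if (target.isEmpty) return <int>[];
  return <int>[bosId, ...target.sublist(0, target.length - 1)];
}
```

With a target of [5, 6, 7] and an assumed bosId of 0, the decoder input is [0, 5, 6], and the logits at positions 0, 1 and 2 would be compared against 5, 6 and 7 respectively.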

Implementation

List<ValueVector> forward(List<int> idx, List<ValueVector> encoderOutput) {
  final T = idx.length;
  if (T > blockSize) {
    throw ArgumentError(
        "Input sequence length ($T) exceeds model's block size ($blockSize)");
  }

  // 1. Look up token and position embeddings and sum them
  var x = List.generate(T, (t) {
    final tokEmb = tokenEmbeddings[idx[t]];
    final posEmb = positionEmbeddings[t];
    return tokEmb + posEmb;
  });

  // 2. Pass sequence through all transformer decoder blocks
  // Each block also receives the encoder's output
  for (final block in blocks) {
    x = block.forward(x, encoderOutput);
  }

  // 3. Apply final layer norm
  x = List.generate(T, (t) => finalLayerNorm.forward(x[t]));

  // 4. Language model head to get final logits
  final logits = List.generate(T, (t) => lmHead.forward(x[t]));

  return logits;
}
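
At inference time, the logits returned by forward can be turned into a next-token choice. The following is a minimal sketch of two small helpers for greedy selection, assuming the logits at the last position have already been converted to plain double values (this page does not show how to read numbers out of a ValueVector, so that conversion is left out). The names cropToBlockSize and argmax are hypothetical and are not part of this library.

```dart
/// Hypothetical helper: keeps only the last [blockSize] token indices so the
/// sequence passed to forward() does not trigger its ArgumentError.
List<int> cropToBlockSize(List<int> idx, int blockSize) {
  return idx.length <= blockSize ? idx : idx.sublist(idx.length - blockSize);
}

/// Hypothetical helper: index of the largest value in [logits].
/// Ties resolve to the lowest index.
int argmax(List<double> logits) {
  if (logits.isEmpty) {
    throw ArgumentError('logits must not be empty');
  }
  var best = 0;
  for (var i = 1; i < logits.length; i++) {
    if (logits[i] > logits[best]) best = i;
  }
  return best;
}
```

For greedy decoding, argmax would be applied to the logits at the last position (index T - 1), and the chosen index appended to idx before the next call. Whether cropping the decoder context is appropriate depends on the task, so treat cropToBlockSize as one possible way to stay within blockSize.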