forward method
Runs the forward pass of the Transformer Decoder model.
Takes idx, a list of integer target token indices (shifted right during training),
and encoderOutput, the output of the Transformer Encoder.
Returns a list of logits, one entry per input position, for next-token prediction.
Throws an ArgumentError if idx is longer than blockSize.
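As an illustration of what "shifted right" means for idx, here is a minimal sketch that assumes the common convention of prepending a start-of-sequence token and dropping the last target token. The helper `shiftRight` and the token id `0` are hypothetical and not part of this library.

```dart
/// Hypothetical helper: builds the decoder input by prepending [startToken]
/// and dropping the last target token, so the input at position t is the
/// target token at position t - 1.
List<int> shiftRight(List<int> target, int startToken) {
  if (target.isEmpty) return <int>[];
  return [startToken, ...target.sublist(0, target.length - 1)];
}

void main() {
  const startToken = 0; // assumed id of the start-of-sequence token
  final target = [5, 9, 2, 7]; // tokens the decoder should learn to predict
  final decoderInput = shiftRight(target, startToken);
  print(decoderInput); // [0, 5, 9, 2]
}
```

Under this convention, `decoderInput` would be passed as idx, and `target` would serve as the labels when computing the loss.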
Implementation
List<ValueVector> forward(List<int> idx, List<ValueVector> encoderOutput) {
  final T = idx.length;
  if (T > blockSize) {
    throw ArgumentError(
        "Input sequence length ($T) exceeds model's block size ($blockSize)");
  }

  // 1. Get token and position embeddings and sum them
  var x = List.generate(T, (t) {
    final tokEmb = tokenEmbeddings[idx[t]];
    final posEmb = positionEmbeddings[t];
    return tokEmb + posEmb;
  });

  // 2. Pass sequence through all transformer decoder blocks
  // Each block also receives the encoder's output
  for (final block in blocks) {
    x = block.forward(x, encoderOutput); // Pass encoderOutput
  }

  // 3. Apply final layer norm
  x = List.generate(T, (t) => finalLayerNorm.forward(x[t]));

  // 4. Language model head to get final logits
  final logits = List.generate(T, (t) => lmHead.forward(x[t]));
  return logits;
}
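Because forward returns one logits entry per input position, the prediction for the next token is read from the last position. The sketch below shows greedy selection under the assumption that the logits have already been converted to plain `List<double>` values. This page does not show how to read numbers out of a ValueVector, so that conversion is left out. The helper `argmax` and the sample numbers are hypothetical.

```dart
/// Hypothetical helper: returns the index of the largest value in [logits].
int argmax(List<double> logits) {
  if (logits.isEmpty) {
    throw ArgumentError('logits must not be empty');
  }
  var best = 0;
  for (var i = 1; i < logits.length; i++) {
    if (logits[i] > logits[best]) best = i;
  }
  return best;
}

void main() {
  // Stand-in for the output of forward(): one row of logits per position,
  // here for 2 positions and a vocabulary of 4 tokens.
  final logits = <List<double>>[
    [0.1, 2.0, -1.0, 0.3],
    [1.5, 0.2, 0.4, 3.1],
  ];

  // Greedy decoding: take the highest-scoring token at the last position.
  final nextToken = argmax(logits.last);
  print(nextToken); // 3
}
```

Greedy selection is only one option. Sampling from a softmax over the same logits is a common alternative when more varied output is wanted.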