`forward` method implementation
/// Runs the decoder over [textTokens], conditioned on [encoderOutput].
///
/// The pipeline is:
/// 1. embed the tokens with [Tensor.embeddings] using `wte` and `wpe`
///    (token and positional information);
/// 2. feed the result through every decoder block in `blocks`, in order,
///    passing [encoderOutput] to each;
/// 3. apply `finalLayerNorm`, then `lmHead`.
///
/// Returns the output of `lmHead`, i.e. the logits for each vocabulary
/// token. NOTE(review): the exact tensor shape is not visible here —
/// presumably (seqLen, vocabSize); confirm.
///
/// The embedding tensor is appended to [tracker], and [tracker] is passed to
/// every sub-module's `forward`. NOTE(review): presumably this collects
/// intermediate tensors for backprop or later disposal; confirm.
///
/// Throws an [ArgumentError] if `textTokens.length` exceeds `maxSeqLen`.
Tensor forward(
  List<int> textTokens,
  Tensor encoderOutput,
  List<Tensor> tracker,
) {
  final seqLen = textTokens.length;
  if (seqLen > maxSeqLen) {
    // Include the offending length (and the argument name) in the error so
    // the failure is diagnosable; message text and error type are unchanged.
    throw ArgumentError.value(
      seqLen,
      'textTokens.length',
      "Text sequence length exceeds maxSeqLen",
    );
  }
  // NOTE(review): an empty textTokens list is not rejected here; whether
  // Tensor.embeddings accepts it is not visible from this method — confirm.

  // 1. Embed text tokens and add positional information.
  final tokenEmbeds = Tensor.embeddings(textTokens, wte, wpe);
  tracker.add(tokenEmbeds);

  // 2. Pass through decoder blocks in order. Explicitly typed as Tensor so
  // each block's output can be reassigned regardless of embeddings' type.
  Tensor x = tokenEmbeds;
  for (final block in blocks) {
    x = block.forward(x, encoderOutput, tracker);
  }

  // 3. Final LayerNorm and language-model head; the head's output is the
  // logits for each vocabulary token.
  final normalized = finalLayerNorm.forward(x, tracker);
  return lmHead.forward(normalized, tracker);
}