forward method
Forward pass through the decoder.
inputIds: list of token IDs (single sequence)
encoderHiddenStates: (1, srcLen, encoderDim) — encoder output
cache: optional KV cache for auto-regressive decoding
pastLength: number of previously generated tokens (for position offset)
Returns a record of (logits, updatedCache); logits has shape (1, seqLen, vocabSize)
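For orientation, a minimal greedy-decoding sketch built on this signature. The identifiers decoder, encoderOut, decoderStartTokenId, eosTokenId, maxLength, and argmaxLast are illustrative placeholders, not part of this API:

// Hypothetical caller: everything except forward and KVCache stands in for
// whatever the surrounding model and tokenizer provide.
KVCache? cache;
final generated = <int>[decoderStartTokenId];
while (generated.length < maxLength) {
  // Feed only the newest token; earlier keys/values live in the cache.
  final (logits, newCache) = decoder.forward(
    [generated.last],
    encoderHiddenStates: encoderOut,
    cache: cache,
    pastLength: generated.length - 1,
  );
  cache = newCache;
  final nextId = argmaxLast(logits); // highest-scoring id at the last position
  generated.add(nextId);
  if (nextId == eosTokenId) break;
}

Passing pastLength keeps the position embeddings aligned with the tokens already held in the cache.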
Implementation
(Tensor, KVCache) forward(
  List<int> inputIds, {
  required Tensor encoderHiddenStates,
  KVCache? cache,
  int pastLength = 0,
}) {
  cache ??= KVCache(decoderLayers);
  final seqLen = inputIds.length;
  // Token embeddings (scaled)
  var hidden = embedTokens.forward(inputIds).mulScalar(embedScale);
  // Position embeddings (+2 offset as in mBART)
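  // The +2 skips the first two rows of the learned position table, which
  // mBART reserves (a fairseq padding convention), so real positions start at 2.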
  final posIds = List.generate(seqLen, (i) => pastLength + i + 2);
  final posEmbeds = embedPositions.forward(posIds);
  hidden = hidden + posEmbeds;
  // Apply layernorm after embedding (mBART-specific)
  hidden = layernormEmbedding.forward(hidden.reshape([seqLen, embedDim]));
  // Reshape to (1, seqLen, embedDim) for batch processing
  hidden = hidden.unsqueeze(0);
  // Causal mask for self-attention
  Tensor? causalMask;
  if (seqLen > 1) {
    causalMask = Tensor.causalMask(seqLen);
  }
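  // When decoding with a warm cache, inputIds is typically a single token
  // (seqLen == 1), so no mask is needed: that one query may attend to all
  // cached positions.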
  // Apply decoder layers
  for (int i = 0; i < layers.length; i++) {
    final (output, selfCache, crossCache) = layers[i].forward(
      hidden,
      encoderHiddenStates: encoderHiddenStates,
      selfAttnPast: cache.selfAttnCache[i],
      crossAttnPast: cache.crossAttnCache[i],
      causalMask: causalMask,
    );
    hidden = output;
    cache.selfAttnCache[i] = selfCache;
    cache.crossAttnCache[i] = crossCache;
  }
  // Final layer norm
  hidden = layerNorm.forward(hidden);
  // LM head: project to vocabulary
  final logits = lmHead.forward(hidden); // (1, seqLen, vocabSize)
  return (logits, cache);
}
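The cache wiring above assumes one self-attention slot and one cross-attention slot per decoder layer. A minimal KVCache sketch consistent with that usage; the real class may differ, and the (Tensor, Tensor) keys/values entry type is a hypothetical stand-in:

// Sketch only: per-layer slots for the self- and cross-attention states that
// forward reads from and writes back on each step.
class KVCache {
  final List<(Tensor, Tensor)?> selfAttnCache;
  final List<(Tensor, Tensor)?> crossAttnCache;

  KVCache(int numLayers)
      : selfAttnCache = List.filled(numLayers, null),
        crossAttnCache = List.filled(numLayers, null);
}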