forward method

(Tensor, KVCache) forward(
  List<int> inputIds, {
  required Tensor encoderHiddenStates,
  KVCache? cache,
  int pastLength = 0,
})

Forward pass through the decoder.

inputIds: list of token IDs (single sequence).
encoderHiddenStates: encoder output of shape (1, srcLen, encoderDim).
cache: optional KV cache for auto-regressive decoding.
pastLength: number of previously generated tokens (used as the position offset).

Returns (logits, updatedCache), where logits has shape (1, seqLen, vocabSize).
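For example, a greedy decoding loop can call forward once per step and reuse the cache. This is only a sketch: decoder, encoderOut, bosId, eosId, maxNewTokens, and the argmaxLastPosition helper are illustrative placeholders, not part of this API.

KVCache? cache;
final tokens = <int>[bosId];
for (var step = 0; step < maxNewTokens; step++) {
  // The first call processes the full prefix; later calls feed only the
  // newest token and offset its position with pastLength.
  final input = step == 0 ? tokens : [tokens.last];
  final past = step == 0 ? 0 : tokens.length - 1;
  final (logits, newCache) = decoder.forward(
    input,
    encoderHiddenStates: encoderOut,
    cache: cache,
    pastLength: past,
  );
  cache = newCache;
  final nextId = argmaxLastPosition(logits); // hypothetical: pick the highest-probability next token
  tokens.add(nextId);
  if (nextId == eosId) break;
}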

Implementation

(Tensor, KVCache) forward(
  List<int> inputIds, {
  required Tensor encoderHiddenStates,
  KVCache? cache,
  int pastLength = 0,
}) {
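  // Lazily create an empty cache on the first call.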
  cache ??= KVCache(decoderLayers);
  final seqLen = inputIds.length;

  // Token embeddings (scaled)
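  // (for mBART checkpoints, embedScale is typically sqrt(embedDim))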
  var hidden = embedTokens.forward(inputIds).mulScalar(embedScale);

  // Position embeddings (+2 offset as in mBART)
  final posIds = List.generate(seqLen, (i) => pastLength + i + 2);
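  // e.g. pastLength = 0, seqLen = 3 -> positions [2, 3, 4]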
  final posEmbeds = embedPositions.forward(posIds);
  hidden = hidden + posEmbeds;

  // Apply layernorm after embedding (mBART-specific)
  hidden = layernormEmbedding.forward(hidden.reshape([seqLen, embedDim]));

  // Reshape to (1, seqLen, embedDim) for batch processing
  hidden = hidden.unsqueeze(0);

  // Causal mask for self-attention
  Tensor? causalMask;
  if (seqLen > 1) {
    causalMask = Tensor.causalMask(seqLen);
  }
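  // With seqLen == 1 (cached incremental decoding) no mask is needed: the new
  // token may attend to every previously cached position.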

  // Apply decoder layers
  for (int i = 0; i < layers.length; i++) {
    final (output, selfCache, crossCache) = layers[i].forward(
      hidden,
      encoderHiddenStates: encoderHiddenStates,
      selfAttnPast: cache.selfAttnCache[i],
      crossAttnPast: cache.crossAttnCache[i],
      causalMask: causalMask,
    );
    hidden = output;
    cache.selfAttnCache[i] = selfCache;
    cache.crossAttnCache[i] = crossCache;
  }

  // Final layer norm
  hidden = layerNorm.forward(hidden);

  // LM head: project to vocabulary
  final logits = lmHead.forward(hidden); // (1, seqLen, vocabSize)

  return (logits, cache);
}
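The cache container itself is only touched through its selfAttnCache and crossAttnCache lists, one slot per decoder layer. A minimal sketch of what such a container could look like, assuming each layer caches a (keys, values) pair of tensors; the real KVCache class may differ:

// Placeholder for whatever a single layer caches; assumed here to be a
// (keys, values) pair of tensors.
typedef LayerKV = (Tensor, Tensor);

// Minimal cache sketch inferred from how forward() indexes and assigns it:
// one self-attention slot and one cross-attention slot per decoder layer.
class KVCache {
  final List<LayerKV?> selfAttnCache;
  final List<LayerKV?> crossAttnCache;

  KVCache(int numLayers)
      : selfAttnCache = List<LayerKV?>.filled(numLayers, null),
        crossAttnCache = List<LayerKV?>.filled(numLayers, null);
}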