The `forward` method

Tensor forward(
  Tensor xDec,
  Tensor xEnc,
  List<Tensor> tracker
)

Implementation

/// Runs one Pre-LN decoder block over [xDec], attending to [xEnc].
///
/// The block applies three stages in order. Each stage normalizes the running
/// stream, feeds the normalized tensor to a sub-layer, and adds the sub-layer
/// result back onto the un-normalized stream (a residual connection):
///
/// 1. self-attention (`ln1` then `selfAttention`),
/// 2. cross-attention with [xEnc] (`ln2` then `crossAttention`),
/// 3. feed-forward (`ln3` then `ffn`).
///
/// [xEnc] is passed to cross-attention as-is; this block does not normalize
/// it. No normalization is applied to the returned tensor either.
///
/// [tracker] is passed to every norm/sub-layer call. This method also appends
/// each of the three residual sums to it, so the last tensor it appends is
/// the returned one. NOTE(review): presumably used for intermediate-tensor
/// bookkeeping (disposal or backprop) — confirm with the caller.
///
/// Returns the stream after the third residual connection.
Tensor forward(Tensor xDec, Tensor xEnc, List<Tensor> tracker) {
  // Running residual stream; each stage replaces it with
  // `hidden + subLayer(norm(hidden))`.
  var hidden = xDec;

  // Stage 1: self-attention on the normalized stream. Earlier notes label
  // this "masked"; NOTE(review): any causal mask must be applied inside
  // `selfAttention` — confirm.
  hidden = hidden +
      selfAttention.forward(ln1.forward(hidden, tracker), tracker);
  tracker.add(hidden);

  // Stage 2: cross-attention; receives the normalized stream plus the raw
  // encoder output [xEnc].
  hidden = hidden +
      crossAttention.forward(ln2.forward(hidden, tracker), xEnc, tracker);
  tracker.add(hidden);

  // Stage 3: feed-forward sub-layer, also fed a normalized input.
  hidden = hidden + ffn.forward(ln3.forward(hidden, tracker), tracker);
  tracker.add(hidden);

  // Debug hook on the block output, intended to catch exploding activations.
  _debugStats(hidden);
  return hidden;
}