// Forward method implementation.
/// Runs one decoder block over the decoder stream [xDec], attending to the
/// encoder output [xEnc].
///
/// The block applies three pre-LN residual sub-layers in order:
///   1. `ln1` then `selfAttention` (described as masked; any masking happens
///      inside `selfAttention` and is not visible here),
///   2. `ln2` then `crossAttention`, which also receives [xEnc],
///   3. `ln3` then `ffn`.
/// Each sub-layer sees the normalized stream, and its output is added back
/// onto the un-normalized stream (the residual connection).
///
/// [tracker] is passed to every sub-module. Each of the three residual sums
/// is also appended to it, in the order they are produced.
/// NOTE(review): presumably this keeps intermediates alive for later
/// cleanup or backprop — confirm against the caller.
///
/// Returns the output of the third residual sum. That output is passed to
/// `_debugStats` before it is returned.
Tensor forward(Tensor xDec, Tensor xEnc, List<Tensor> tracker) {
  // Sub-layer 1: self-attention on the normalized decoder stream.
  final selfAttnInput = ln1.forward(xDec, tracker);
  final afterSelfAttn = xDec + selfAttention.forward(selfAttnInput, tracker);
  tracker.add(afterSelfAttn);

  // Sub-layer 2: cross-attention between the normalized decoder stream and
  // [xEnc]. Presumably queries come from the decoder and keys/values from
  // the encoder — confirm inside `crossAttention`.
  final crossAttnInput = ln2.forward(afterSelfAttn, tracker);
  final afterCrossAttn =
      afterSelfAttn + crossAttention.forward(crossAttnInput, xEnc, tracker);
  tracker.add(afterCrossAttn);

  // Sub-layer 3: feed-forward network on the normalized stream.
  final ffnInput = ln3.forward(afterCrossAttn, tracker);
  final blockOut = afterCrossAttn + ffn.forward(ffnInput, tracker);
  tracker.add(blockOut);

  _debugStats(blockOut); // Monitor for explosions
  return blockOut;
}