forward method
Implementation
/// Runs one pre-norm decoder layer: self-attention, cross-attention over
/// [encoderOutput], then a two-stage feed-forward network.
///
/// Each sublayer consumes a normalized copy of the running activation, and
/// its result is added back onto the un-normalized activation (residual
/// path).
///
/// - [x]: the decoder-side input to this layer.
/// - [encoderOutput]: passed to cross-attention as `kv:`; queries come from
///   the decoder stream.
/// - [tracker]: forwarded unchanged to every sublayer call. Its role is not
///   visible here (presumably it collects intermediate tensors — confirm).
///
/// Returns the activation after all three residual additions.
Tensor forward(Tensor x, Tensor encoderOutput, List<Tensor> tracker) {
  // Sublayer 1: self-attention. No mask is passed from here; presumably
  // causal masking happens inside `selfAttention` (NOTE(review): confirm).
  final selfAttnIn = norm1.forward(x, tracker);
  final afterSelfAttn = x + selfAttention.forward(selfAttnIn, tracker);

  // Sublayer 2: cross-attention. Queries come from the decoder stream;
  // keys/values come from the encoder output.
  final crossAttnIn = norm2.forward(afterSelfAttn, tracker);
  final afterCrossAttn = afterSelfAttn +
      crossAttention.forward(crossAttnIn, tracker, kv: encoderOutput);

  // Sublayer 3: feed-forward. No activation is applied between ff1 and ff2
  // in this method. NOTE(review): confirm ff1 carries its own nonlinearity;
  // otherwise the two stages compose into a single linear map.
  final ffIn = norm3.forward(afterCrossAttn, tracker);
  final hidden = ff1.forward(ffIn, tracker);
  return afterCrossAttn + ff2.forward(hidden, tracker);
}