forward method

Tensor forward(
  Tensor x,
  Tensor encoderOutput,
  List<Tensor> tracker,
)

Implementation

/// Runs [x] through this block's three residual sub-layers and returns the
/// result.
///
/// The sub-layers are applied in order:
///
/// 1. `selfAttention` on `norm1` of the input.
/// 2. `crossAttention` on `norm2` of the running value, with
///    [encoderOutput] supplied as the `kv` argument.
/// 3. `ff1` followed by `ff2` on `norm3` of the running value.
///
/// Each sub-layer's output is added to its own un-normalised input.
/// [tracker] is handed unchanged to every sub-module `forward` call.
Tensor forward(Tensor x, Tensor encoderOutput, List<Tensor> tracker) {
  // Self-attention sub-layer (described as "masked" by the original author;
  // any masking presumably happens inside `selfAttention` — TODO confirm).
  final selfNormed = norm1.forward(x, tracker);
  final afterSelf = x + selfAttention.forward(selfNormed, tracker);

  // Cross-attention sub-layer: the query side is the normalised running
  // value, while the key/value side is the encoder output.
  final crossNormed = norm2.forward(afterSelf, tracker);
  final afterCross = afterSelf +
      crossAttention.forward(
        crossNormed,
        tracker,
        kv: encoderOutput,
      );

  // Feed-forward sub-layer: ff1 then ff2, followed by the final residual add.
  final ffNormed = norm3.forward(afterCross, tracker);
  final ffHidden = ff1.forward(ffNormed, tracker);
  return afterCross + ff2.forward(ffHidden, tracker);
}