forward method
Forward pass.
hiddenStates: (batch, seqLen, embedDim)
encoderHiddenStates: (batch, srcLen, embedDim) — encoder output
selfAttnPast: optional cached self-attention (K, V) pair; may be null
crossAttnPast: optional cached cross-attention (K, V) pair; may be null
causalMask: optional causal attention mask; passed to the self-attention step only (the cross-attention call receives no mask)
Returns (output, selfAttnCache, crossAttnCache)
Implementation
/// Runs this decoder layer on [hiddenStates].
///
/// The layer applies three sub-layers in order: self-attention,
/// cross-attention over [encoderHiddenStates], and a feed-forward step
/// (`fc1` -> `gelu` -> `fc2`). Each sub-layer normalizes its input first and
/// then adds its output back onto the un-normalized input (residual).
///
/// * [hiddenStates] - decoder input, documented as (batch, seqLen, embedDim).
/// * [encoderHiddenStates] - encoder output, documented as
///   (batch, srcLen, embedDim); used as the cross-attention key/value source.
/// * [selfAttnPast] - optional cached (K, V) forwarded to self-attention.
/// * [crossAttnPast] - optional cached (K, V) forwarded to cross-attention.
/// * [causalMask] - optional mask forwarded to self-attention only.
///
/// Returns `(output, selfAttnCache, crossAttnCache)`, where the two caches are
/// whatever the respective attention modules returned alongside their output.
(Tensor, (Tensor, Tensor), (Tensor, Tensor)) forward(
  Tensor hiddenStates, {
  required Tensor encoderHiddenStates,
  (Tensor, Tensor)? selfAttnPast,
  (Tensor, Tensor)? crossAttnPast,
  Tensor? causalMask,
}) {
  // Stage 1: self-attention on the normalized input, then residual add.
  final (selfOut, updatedSelfCache) = selfAttn.forward(
    selfAttnLayerNorm.forward(hiddenStates),
    pastKeyValue: selfAttnPast,
    attentionMask: causalMask,
  );
  final Tensor afterSelfAttn = hiddenStates + selfOut;

  // Stage 2: cross-attention against the encoder output (no mask is passed),
  // then residual add.
  final (crossOut, updatedCrossCache) = encoderAttn.forward(
    encoderAttnLayerNorm.forward(afterSelfAttn),
    keyValueStates: encoderHiddenStates,
    pastKeyValue: crossAttnPast,
  );
  final Tensor afterCrossAttn = afterSelfAttn + crossOut;

  // Stage 3: feed-forward (norm -> fc1 -> gelu -> fc2), then residual add.
  final Tensor ffnOut = fc2.forward(
    fc1.forward(finalLayerNorm.forward(afterCrossAttn)).gelu(),
  );

  return (afterCrossAttn + ffnOut, updatedSelfCache, updatedCrossCache);
}