forward method

(Tensor, (Tensor, Tensor), (Tensor, Tensor)) forward(
  Tensor hiddenStates, {
  required Tensor encoderHiddenStates,
  (Tensor, Tensor)? selfAttnPast,
  (Tensor, Tensor)? crossAttnPast,
  Tensor? causalMask,
})

Forward pass through the decoder layer: masked self-attention, cross-attention over the encoder output, then a feed-forward network, each wrapped in a pre-layer-norm residual block.

hiddenStates: decoder input, (batch, seqLen, embedDim)
encoderHiddenStates: encoder output, (batch, srcLen, embedDim)
selfAttnPast: cached self-attention (key, value) from previous decoding steps
crossAttnPast: cached cross-attention (key, value)
causalMask: causal attention mask for the self-attention block

Returns (output, selfAttnCache, crossAttnCache); output has the same shape as hiddenStates, and the two caches can be passed back as selfAttnPast and crossAttnPast on the next decoding step.
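A minimal incremental-decoding sketch, assuming a decoder-layer instance `layer` and pre-computed tensors `encoderOut` and `token`; every name outside forward's own signature (including `maxSteps` and `nextToken`) is illustrative, not part of the documented API:

```dart
(Tensor, Tensor)? selfCache;
(Tensor, Tensor)? crossCache;
for (var step = 0; step < maxSteps; step++) {
  // token: (batch, 1, embedDim); only the newest position is passed in,
  // earlier positions are covered by the caches.
  final (out, newSelfCache, newCrossCache) = layer.forward(
    token,
    encoderHiddenStates: encoderOut,
    selfAttnPast: selfCache,
    crossAttnPast: crossCache,
    causalMask: null, // a single query position needs no causal mask
  );
  selfCache = newSelfCache;
  crossCache = newCrossCache;
  token = nextToken(out); // hypothetical: sample and embed the next token
}
```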

Implementation

(Tensor, (Tensor, Tensor), (Tensor, Tensor)) forward(
  Tensor hiddenStates, {
  required Tensor encoderHiddenStates,
  (Tensor, Tensor)? selfAttnPast,
  (Tensor, Tensor)? crossAttnPast,
  Tensor? causalMask,
}) {
  var residual = hiddenStates;

  // 1. Self-attention
  hiddenStates = selfAttnLayerNorm.forward(hiddenStates);
  final (selfAttnOut, selfAttnCache) = selfAttn.forward(
    hiddenStates,
    pastKeyValue: selfAttnPast,
    attentionMask: causalMask,
  );
  hiddenStates = residual + selfAttnOut;

  // 2. Cross-attention
  residual = hiddenStates;
  hiddenStates = encoderAttnLayerNorm.forward(hiddenStates);
  final (crossAttnOut, crossAttnCache) = encoderAttn.forward(
    hiddenStates,
    keyValueStates: encoderHiddenStates,
    pastKeyValue: crossAttnPast,
  );
  hiddenStates = residual + crossAttnOut;

  // 3. FFN
  residual = hiddenStates;
  hiddenStates = finalLayerNorm.forward(hiddenStates);
  hiddenStates = fc1.forward(hiddenStates);
  hiddenStates = hiddenStates.gelu();
  hiddenStates = fc2.forward(hiddenStates);
  hiddenStates = residual + hiddenStates;

  return (hiddenStates, selfAttnCache, crossAttnCache);
}
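All three blocks share the same pre-layer-norm residual shape, x + f(layerNorm(x)). As an illustrative refactoring (the helper name and function-typed parameters are assumptions, not members of the class):

```dart
// Hypothetical sketch: each block is residual + subLayer(norm(x)).
Tensor preLnResidual(
  Tensor x,
  Tensor Function(Tensor) norm,
  Tensor Function(Tensor) subLayer,
) =>
    x + subLayer(norm(x));
```

With it, step 3 could read `hiddenStates = preLnResidual(hiddenStates, finalLayerNorm.forward, (h) => fc2.forward(fc1.forward(h).gelu()));`. The two attention steps don't fit as cleanly because they also return caches, which is presumably why the implementation writes the pattern out longhand.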