forward method

@override
Tensor<Matrix> forward(Tensor input)

The core logic of the layer's transformation.

Subclasses must implement this method to define how they process input tensors and return an output tensor.

Implementation

@override
Tensor<Matrix> forward(Tensor<dynamic> input) {
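  // Run every attention head on the same input. Each head produces its own
  // output matrix, so the result of call() is cast to Tensor<Matrix>.
  final List<Tensor<Matrix>> headOutputs = [];
  for (final SingleHeadAttention head in attentionHeads) {
    final Tensor<Matrix> headOutput = head.call(input) as Tensor<Matrix>;
    headOutputs.add(headOutput);
  }

  // Place the head outputs side by side (column-wise) in a single matrix.
  final Tensor<Matrix> concatenatedOutput =
      concatenateMatricesByColumn(headOutputs);

  // Apply the output projection Wo to the concatenated head outputs.
  final Tensor<Matrix> finalOutput = matMul(concatenatedOutput, Wo);

  return finalOutput;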
}
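The implementation follows the usual multi-head attention output step: concatenate the heads, then project. The formula below is a sketch under assumptions that this page does not state: there are h heads (h = attentionHeads.length), each head returns an n × d_h matrix with one row per input position (which is what column-wise concatenation implies), and Wo has shape (h·d_h) × d_out. The symbols H_i, n, d_h, d_out and the row blocks W_o^{(i)} are introduced here only for illustration.

```latex
\mathrm{forward}(X) = \bigl[\, H_1 \;\; H_2 \;\; \cdots \;\; H_h \,\bigr] \, W_o,
\qquad H_i = \mathrm{head}_i(X) \in \mathbb{R}^{n \times d_h},
\qquad W_o \in \mathbb{R}^{(h\,d_h) \times d_{\text{out}}}

% Splitting W_o into h row blocks W_o^{(i)} \in \mathbb{R}^{d_h \times d_{\text{out}}}
% (block i = rows (i-1)d_h + 1 through i\,d_h) gives the equivalent form
\bigl[\, H_1 \;\; \cdots \;\; H_h \,\bigr] W_o
  = \begin{bmatrix} H_1 & \cdots & H_h \end{bmatrix}
    \begin{bmatrix} W_o^{(1)} \\ \vdots \\ W_o^{(h)} \end{bmatrix}
  = \sum_{i=1}^{h} H_i \, W_o^{(i)}
```

The block form shows that each head contributes to the output only through its own slice of Wo, so the concatenate-then-multiply approach used in forward is equivalent to projecting each head separately and summing the results.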