forward method

@override
Tensor<Matrix> forward(Tensor input)

The core logic of the layer's transformation.

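Subclasses must implement this method to define how an input tensor is transformed into an output tensor. This override implements a post-norm Transformer encoder block: multi-head attention (`mha`) followed by a residual add and layer normalization, then a feed-forward network (`ffn`) followed by a second residual add and layer normalization.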
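A minimal sketch of the override contract described above, under simplified assumptions: `Mat`, `SimpleLayer`, and `ReluLayer` are hypothetical names, not this library's `Tensor`/`Matrix` types. That `call` delegates to `forward` is also an assumption, suggested by the `mha.call(...)` usage in the implementation.

```dart
// Hypothetical stand-in for the library's Matrix type (illustration only).
typedef Mat = List<List<double>>;

/// Hypothetical minimal layer contract: each layer maps an input matrix
/// to an output matrix by overriding [forward].
abstract class SimpleLayer {
  Mat forward(Mat input);

  /// Assumption: invoking a layer (e.g. `layer(x)` or `layer.call(x)`)
  /// simply runs its [forward] pass.
  Mat call(Mat input) => forward(input);
}

/// Example subclass: element-wise ReLU, max(0, v) for every entry.
class ReluLayer extends SimpleLayer {
  @override
  Mat forward(Mat input) => [
        for (final row in input) [for (final v in row) v > 0 ? v : 0.0],
      ];
}
```

Invoking a `ReluLayer` on `[[-1.0, 2.0], [3.0, -4.0]]` yields `[[0.0, 2.0], [3.0, 0.0]]`. Only `forward` is overridden; the base class supplies the callable behaviour.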

Implementation

@override
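Tensor<Matrix> forward(Tensor<dynamic> input) {
  // Downcast the generic input; this throws at runtime if `input`
  // is not actually a Tensor<Matrix>.
  Tensor<Matrix> inputMatrix = input as Tensor<Matrix>;

  // Sub-block 1: multi-head attention, residual add, layer normalization.
  Tensor<Matrix> attentionOutput = mha.call(inputMatrix) as Tensor<Matrix>;
  Tensor<Matrix> addAndNorm1 =
      layerNorm1.call(addMatrix(inputMatrix, attentionOutput)) as Tensor<Matrix>;

  // Sub-block 2: feed-forward network, residual add, layer normalization.
  Tensor<Matrix> ffnOutput = ffn.call(addAndNorm1) as Tensor<Matrix>;
  Tensor<Matrix> addAndNorm2 =
      layerNorm2.call(addMatrix(addAndNorm1, ffnOutput)) as Tensor<Matrix>;

  return addAndNorm2;
}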
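Both add-and-norm steps follow the post-norm pattern `LayerNorm(x + Sublayer(x))` used in the original Transformer encoder. The sketch below reproduces that data flow on plain nested lists. `Mat`, `addMatrices`, `layerNorm`, `addAndNorm`, and `encoderForward` are hypothetical stand-ins for the library's `Tensor<Matrix>`, `addMatrix`, `layerNorm1`/`layerNorm2`, and the layer itself. The layer norm here omits the learnable gain and bias that a real implementation usually carries.

```dart
import 'dart:math' as math;

// Hypothetical stand-in for the library's Matrix type (illustration only).
typedef Mat = List<List<double>>;

/// Element-wise sum of two equally shaped matrices (the residual connection).
Mat addMatrices(Mat a, Mat b) => [
      for (var i = 0; i < a.length; i++)
        [for (var j = 0; j < a[i].length; j++) a[i][j] + b[i][j]],
    ];

/// Normalizes each row to zero mean and unit variance.
/// Simplification: no learnable gain or bias.
Mat layerNorm(Mat x, {double eps = 1e-5}) => [
      for (final row in x) _normalizeRow(row, eps),
    ];

List<double> _normalizeRow(List<double> row, double eps) {
  final mean = row.reduce((a, b) => a + b) / row.length;
  final variance =
      row.map((v) => (v - mean) * (v - mean)).reduce((a, b) => a + b) /
          row.length;
  // eps keeps the denominator non-zero for constant rows.
  final denom = math.sqrt(variance + eps);
  return [for (final v in row) (v - mean) / denom];
}

/// One post-norm sub-block: LayerNorm(x + sublayer(x)).
Mat addAndNorm(Mat x, Mat Function(Mat) sublayer) =>
    layerNorm(addMatrices(x, sublayer(x)));

/// Mirrors the forward pass: attention sub-block, then feed-forward sub-block.
Mat encoderForward(
  Mat input,
  Mat Function(Mat) attention,
  Mat Function(Mat) feedForward,
) {
  final h = addAndNorm(input, attention);
  return addAndNorm(h, feedForward);
}
```

Passing the attention and feed-forward steps in as functions keeps the sketch independent of how those sublayers are built. With identity functions for both, the row `[1.0, 3.0]` comes out very close to `[-1.0, 1.0]`. A constant row such as `[2.0, 2.0]` normalizes to zeros because its variance is zero.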