forward method

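Tensor forward(Tensor x, List<Tensor> tracker)

  1. x — the input tensor; it is passed unchanged to every head.
  2. tracker — a list that intermediate tensors are appended to. The concatenated head output is added to it, and it is also forwarded to each head and to the final projection.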

Implementation

/// Runs every head in `heads` on [x], joins their outputs with
/// `Tensor.concat` (in head order), and feeds the result through `proj`.
///
/// [tracker] is handed to each head and to `proj`, and the joined tensor is
/// appended to it here. NOTE(review): presumably the caller uses [tracker] to
/// release intermediate tensors after the pass — confirm against the caller.
/// The original comments said this runs on the GPU; that depends on the
/// `Tensor` backend and cannot be verified from this method.
Tensor forward(Tensor x, List<Tensor> tracker) {
  // Evaluate heads in order; each receives the same input and tracker.
  final perHead = <Tensor>[
    for (final head in heads) head.forward(x, tracker),
  ];

  // Join the per-head results and register the joined tensor for tracking.
  final joined = Tensor.concat(perHead);
  tracker.add(joined);

  // Output projection over the joined representation.
  return proj.forward(joined, tracker);
}