forward method

List<ValueVector> forward(List<ValueVector> x)

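Runs the forward pass of multi-head attention over a sequence of token vectors x. Every head in heads processes the whole sequence, the heads' outputs are concatenated at each token position, and each concatenated vector is passed through the output projection proj. Returns one output vector per input token.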
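The same computation can be written per token position. In the notation below, which is introduced here only for illustration, H is the number of heads, h_k(x) is the output sequence of the k-th head applied to the whole input x, T = x.length, and [ a ; b ] denotes concatenation in head order:

$$c_i = \big[\, h_1(x)_i \,;\, h_2(x)_i \,;\, \dots \,;\, h_H(x)_i \,\big], \qquad y_i = \mathrm{proj}(c_i), \qquad i = 1, \dots, T$$

If each head emits d_h features per token, then c_i has H·d_h features, so proj must accept inputs of that size.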

Implementation

List<ValueVector> forward(List<ValueVector> x) {
  // 1. Compute all head outputs in parallel
  final headOutputs = heads.map((h) => h.forward(x)).toList();

  // 2. Concatenate head outputs for each token position
  final seqLen = x.length;
  final concatenated = List.generate(seqLen, (i) {
    final values =
        headOutputs.expand((headOut) => headOut[i].values).toList();
    return ValueVector(values);
  });

  // 3. Apply final projection layer
  final out = concatenated.map((c) => proj.forward(c)).toList();
  return out;
}
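The three steps above can be sketched in a self-contained form. This is a minimal illustration, not the library's code. It assumes plain List<double> vectors in place of ValueVector, and it uses hypothetical Head and Projection function types in place of the heads list and the proj layer. It shows the shape bookkeeping: each head sees the whole sequence, concatenation happens per token in head order, and the projection is applied to each token independently.

```dart
// Hypothetical stand-ins (not the library's API): a head maps a sequence of
// token vectors to a sequence of output vectors; the projection maps one
// vector to one vector.
typedef Head = List<List<double>> Function(List<List<double>> x);
typedef Projection = List<double> Function(List<double> v);

List<List<double>> multiHeadForward(
    List<List<double>> x, List<Head> heads, Projection proj) {
  // 1. Run every head over the whole sequence.
  final headOutputs = heads.map((h) => h(x)).toList();

  // 2. For each token position, concatenate the heads' vectors in head order.
  final seqLen = x.length;
  final concatenated = List.generate(
    seqLen,
    (i) => headOutputs.expand((headOut) => headOut[i]).toList(),
  );

  // 3. Project each concatenated vector independently.
  return concatenated.map(proj).toList();
}

void main() {
  // Toy heads: one doubles each feature, the other negates it.
  List<List<double>> doubler(List<List<double>> x) =>
      x.map((v) => v.map((e) => 2.0 * e).toList()).toList();
  List<List<double>> negator(List<List<double>> x) =>
      x.map((v) => v.map((e) => -e).toList()).toList();
  // Toy projection: sum of all features, giving a 1-dimensional output.
  List<double> sumProj(List<double> v) => [v.fold(0.0, (a, b) => a + b)];

  final x = [
    [1.0, 2.0],
    [3.0, 4.0],
  ];
  final y = multiHeadForward(x, [doubler, negator], sumProj);
  print(y);
}
```

For the first token, the concatenated vector is [2, 4, -1, -2] (the doubler's features first, then the negator's), and the sum projection reduces it to 3. In the real class, proj is presumably a learned layer rather than a sum, but the per-token order of operations is the same.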