forward method
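Forward pass for Multi-Head Attention: runs every head on the input sequence, concatenates the head outputs at each token position, and applies the final projection layer to each concatenated vector.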
Implementation
/// Runs the multi-head attention forward pass over the sequence [x].
///
/// Every head in `heads` processes the whole sequence. For each position in
/// [x], the vectors the heads produced at that position are joined, in head
/// order, into a single vector, which is then passed through `proj`.
///
/// Returns one projected vector per element of [x]. Each head's output is
/// indexed at every position of [x], so it must hold at least `x.length`
/// entries.
List<ValueVector> forward(List<ValueVector> x) {
  // Stage 1: evaluate the heads one after another, each on the full input.
  final perHead = [for (final head in heads) head.forward(x)];

  // Stage 2: for every token position, flatten that position's values from
  // all heads into one vector.
  final tokenCount = x.length;
  final merged = <ValueVector>[];
  for (var pos = 0; pos < tokenCount; pos++) {
    merged.add(ValueVector([
      for (final headResult in perHead) ...headResult[pos].values,
    ]));
  }

  // Stage 3: apply the output projection to each merged vector.
  return merged.map(proj.forward).toList();
}