exampleMultiHeadAttention function
void
exampleMultiHeadAttention()
Implementation
/// Runs a shape check on [MultiHeadAttention] and prints the results.
///
/// Builds a masked attention layer from `numHeads = 4` and `embedSize = 16`,
/// feeds it a sequence of 5 randomly filled [ValueVector]s, and verifies that
/// `forward` returns as many vectors as it was given, each holding
/// `embedSize` values. Finally prints the number of entries returned by
/// `parameters()`.
///
/// Throws a [StateError] when the output shape does not match the input
/// shape. The check is explicit rather than an `assert` so that it still runs
/// when assertions are disabled.
void exampleMultiHeadAttention() {
  print("\n--- Example 2: MultiHeadAttention Shape Verification ---");
  const embedSize = 16;
  const numHeads = 4;
  const sequenceLength = 5;

  final mha = MultiHeadAttention(numHeads, embedSize, masked: true);

  // One generator for the whole input instead of a fresh one per element.
  final rng = Random();
  final x = List.generate(
    sequenceLength,
    (_) => ValueVector.fromDoubleList(
      List.generate(embedSize, (_) => rng.nextDouble()),
    ),
  );
  print("Input sequence length: ${x.length}");
  print("Input embedding size: ${x[0].values.length}");

  final output = mha.forward(x);
  print("Output sequence length: ${output.length}");
  // Validate the length before indexing so an empty result is reported as a
  // shape mismatch rather than as a RangeError from `output[0]`.
  if (output.length != sequenceLength) {
    throw StateError(
      "MultiHeadAttention output sequence length is ${output.length}, "
      "expected $sequenceLength",
    );
  }

  final outputEmbedSize = output[0].values.length;
  print("Output embedding size: $outputEmbedSize");
  if (outputEmbedSize != embedSize) {
    throw StateError(
      "MultiHeadAttention output embedding size is $outputEmbedSize, "
      "expected $embedSize",
    );
  }

  // Only reached once both dimensions have actually been verified.
  print(
    "MultiHeadAttention output shape is correct: ($sequenceLength, $embedSize)",
  );
  print("MultiHeadAttention parameters count: ${mha.parameters().length}");
}