exampleMultiHeadAttention function

void exampleMultiHeadAttention()

Implementation

/// Runs a shape check on [MultiHeadAttention] and prints the results.
///
/// Builds a sequence of 5 random vectors of size 16, feeds it through a
/// 4-head attention layer constructed with `masked: true`, and verifies that
/// the output has the same sequence length and per-position embedding size as
/// the input. Finally prints the number of parameters reported by the layer.
///
/// Throws a [StateError] if the output shape does not match the input shape.
void exampleMultiHeadAttention() {
  print("\n--- Example 2: MultiHeadAttention Shape Verification ---");

  final embedSize = 16;
  final numHeads = 4;
  final sequenceLength = 5;

  final mha = MultiHeadAttention(numHeads, embedSize, masked: true);

  // One generator for the whole input instead of a fresh Random() per element.
  final random = Random();
  final x = List.generate(
      sequenceLength,
      (i) => ValueVector.fromDoubleList(
          List.generate(embedSize, (j) => random.nextDouble())));

  print("Input sequence length: ${x.length}");
  print("Input embedding size: ${x[0].values.length}");

  final output = mha.forward(x);

  print("Output sequence length: ${output.length}");
  // Explicit checks rather than assert(): asserts are removed unless the
  // program runs with asserts enabled, which would let the success message
  // below print without anything having been verified.
  // The length is validated before output[0] is touched so that an empty
  // result is reported as a shape mismatch rather than an index error.
  if (output.length != sequenceLength) {
    throw StateError("MultiHeadAttention output sequence length is "
        "${output.length}, expected $sequenceLength");
  }

  print("Output embedding size: ${output[0].values.length}");
  for (var i = 0; i < output.length; i++) {
    final rowSize = output[i].values.length;
    if (rowSize != embedSize) {
      throw StateError("MultiHeadAttention output embedding size at "
          "position $i is $rowSize, expected $embedSize");
    }
  }

  print(
      "MultiHeadAttention output shape is correct: ($sequenceLength, $embedSize)");
  print("MultiHeadAttention parameters count: ${mha.parameters().length}");
}