verifyCausalMasking function

void verifyCausalMasking()

Implementation

/// Checks that a causal (lower-triangular) mask gates gradient flow.
///
/// Builds a 3x3 tensor filled with 10.0, multiplies it by a 0/1
/// lower-triangular mask, sums the product and calls `backward()` on the sum.
/// Every entry of the resulting gradient is then compared against the mask
/// pattern: masked-out positions must receive a gradient of 0 and active
/// positions a gradient of 1.
///
/// Prints a single `CAUSAL MASK:` line reporting PASS or FAIL.
void verifyCausalMasking() {
  // 3x3 attention matrix (all 10s).
  final attn = Tensor.fill([3, 3], 10.0);

  // Lower-triangular (causal) mask; flat indices 1, 2 and 5 form the
  // upper triangle and are zeroed out.
  final mask = Tensor.fromList([3, 3], [1, 0, 0, 1, 1, 0, 1, 1, 1]);

  // With an element-wise product followed by a total sum, the derivative of
  // the loss with respect to each attn entry is the matching mask entry, so
  // the expected gradient mirrors the mask values above.
  const expectedGrads = <double>[1, 0, 0, 1, 1, 0, 1, 1, 1];

  // Apply the mask, reduce to a scalar and backpropagate.
  final loss = (attn * mask).sum();
  loss.backward();

  final grads = attn.grad;

  // Check all nine entries, not just a sample: a backward pass that drops or
  // corrupts gradient on any active (lower-triangle) path must fail too.
  var ok = true;
  for (var i = 0; i < expectedGrads.length; i++) {
    if (!closeEnough(grads[i], expectedGrads[i])) {
      ok = false;
      break;
    }
  }

  print("CAUSAL MASK: ${ok ? '✅ PASS' : '❌ FAIL'}");
}