verifyCausalMasking function

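void verifyCausalMasking()

Runs a small autograd check that a causal (lower-triangular) mask blocks gradient flow to masked positions, and prints a PASS/FAIL line.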

Implementation

/// Checks that a causal (lower-triangular) mask blocks gradient flow.
///
/// Builds a 3x3 attention tensor and multiplies it element-wise by a 0/1
/// mask. It then sums the result and back-propagates. Because
/// `d(sum(attn * mask)) / d(attn) == mask`, every entry of `attn.grad` must
/// equal the matching mask entry: 0 where the mask blocks, 1 where it allows.
/// All nine entries are compared, so gradient leaking into (or vanishing
/// from) any position is caught. Prints a single PASS/FAIL line.
void verifyCausalMasking() {
  // Flat mask values. Read row-major (presumed Tensor.fromList layout —
  // confirm), this is lower-triangular: position (r, c) is kept iff c <= r.
  // The same list doubles as the expected gradient.
  final maskValues = [
    1.0, 0.0, 0.0, //
    1.0, 1.0, 0.0, //
    1.0, 1.0, 1.0, //
  ];

  final attn = Tensor.fromList([3, 3], List.filled(9, 10.0));
  // Pass a copy so the expectations cannot alias the tensor's storage.
  final mask = Tensor.fromList([3, 3], List.of(maskValues));

  final masked = attn * mask;
  final loss = masked.sum();
  loss.backward();

  final grads = attn.grad;

  // Compare every gradient entry with its mask value, not just a sample.
  final ok = Iterable<int>.generate(maskValues.length)
      .every((i) => closeEnough(grads[i], maskValues[i]));

  print("CAUSAL MASK: ${ok ? '✅ PASS' : '❌ FAIL'}");
}