verifyCausalMasking function
void
verifyCausalMasking()
Implementation
/// Checks that gradients flow only through the unmasked (causal) positions.
///
/// Builds a 3x3 tensor `attn` filled with `10.0` and a 3x3 0/1 mask whose
/// row-major values are lower-triangular (entry (i, j) is 1 iff j <= i). It
/// then computes `sum(attn * mask)` and calls `backward()` on it.
///
/// The derivative of `sum(attn * mask)` with respect to `attn` is the mask
/// itself. So every entry of `attn.grad` is expected to match the corresponding
/// mask entry: `0.0` at the masked positions and `1.0` everywhere else.
///
/// Prints `CAUSAL MASK: ✅ PASS` when all nine entries match (per
/// `closeEnough`), and `CAUSAL MASK: ❌ FAIL` otherwise.
void verifyCausalMasking() {
  // Row-major lower-triangular mask; doubles as the expected gradient.
  const maskValues = <double>[1, 0, 0, 1, 1, 0, 1, 1, 1];
  final attn = Tensor.fromList([3, 3], List.filled(9, 10.0));
  // Pass a fresh, modifiable copy so the tensor never aliases the
  // unmodifiable const list used as the expected values below.
  final mask = Tensor.fromList([3, 3], [...maskValues]);
  final masked = attn * mask;
  final loss = masked.sum();
  loss.backward();
  final grads = attn.grad;
  // Compare every entry rather than a sample: a broken backward pass that
  // zeroed an unmasked gradient (e.g. index 3, 4, 6, 7 or 8) must also fail.
  final ok = Iterable<int>.generate(maskValues.length)
      .every((i) => closeEnough(grads[i], maskValues[i]));
  print("CAUSAL MASK: ${ok ? '✅ PASS' : '❌ FAIL'}");
}