verifyCausalMasking function
void
verifyCausalMasking()
Implementation
/// Verifies that a causal (lower-triangular) mask blocks gradient flow.
///
/// Builds a 3x3 tensor filled with 10.0, multiplies it element-wise by a
/// lower-triangular 0/1 mask, sums the product and back-propagates from that
/// sum. Because the loss is `sum(attn * mask)`, the gradient with respect to
/// each `attn` entry equals the corresponding mask entry, so every one of the
/// nine gradient values is checked: 0.0 in the masked upper triangle (flat
/// indices 1, 2, 5) and 1.0 on every unmasked position.
///
/// Prints a single `CAUSAL MASK` PASS/FAIL line and returns nothing.
void verifyCausalMasking() {
  // Expected gradient for each flat index; mirrors the mask values below.
  // NOTE(review): assumes the gradient is indexed in the same flat order as
  // the list handed to Tensor.fromList — confirm against Tensor.
  const expectedGrads = [1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0];

  // 3x3 attention matrix (all 10s).
  final attn = Tensor.fill([3, 3], 10.0);

  // Lower-triangular (causal) mask.
  final mask = Tensor.fromList([3, 3], [1, 0, 0, 1, 1, 0, 1, 1, 1]);

  // Apply the mask by element-wise multiplication, then reduce to a scalar
  // so that backward() can be called on it.
  final masked = attn * mask;
  final loss = masked.sum();
  loss.backward();

  final grads = attn.grad;

  // Compare every entry, not just one active position, so a gradient that is
  // lost or mis-scaled anywhere in the lower triangle is also detected.
  var ok = true;
  for (var i = 0; i < expectedGrads.length; i++) {
    if (!closeEnough(grads[i], expectedGrads[i])) {
      ok = false;
      break;
    }
  }

  print("CAUSAL MASK: ${ok ? '✅ PASS' : '❌ FAIL'}");
}