verifyWeightUpdate function

Signature: void verifyWeightUpdate()

Implementation:
/// Checks that `matmul` back-propagates the correct gradient into its
/// right-hand (weight) operand.
///
/// Builds `loss = sum(x · w)` with `x` of shape `[2, 3]` and `w` of shape
/// `[3, 2]`, calls `loss.backward()`, and compares every entry of `w.grad`
/// against the analytic gradient. Prints a single PASS/FAIL line; on failure
/// the line includes the gradient that was actually produced.
void verifyWeightUpdate() {
  // Input x [2, 3], weights w [3, 2].
  final x = Tensor.fromList([2, 3], [1, 2, 3, 4, 5, 6]);
  final w = Tensor.fill([3, 2], 0.5);
  final out = x.matmul(w);
  final loss = out.sum();
  loss.backward();
  final gradW = w.grad;

  // dL/dW = xᵀ · dL/dOut. Because the loss is a plain sum, dL/dOut is all
  // ones, so both entries in row i of dL/dW equal x[0,i] + x[1,i]:
  //   row 0: 1 + 4 = 5,  row 1: 2 + 5 = 7,  row 2: 3 + 6 = 9.
  // Listed below as flat row-major indices into w.grad. Every entry is
  // checked, not just the corners, so a partially wrong backward pass
  // cannot slip through.
  const expected = [5.0, 5.0, 7.0, 7.0, 9.0, 9.0];
  final ok = Iterable<int>.generate(expected.length)
      .every((i) => closeEnough(gradW[i], expected[i]));
  print("WEIGHT UPDATE: ${ok ? '✅ PASS' : '❌ FAIL (Got $gradW)'}");
}