verifyWeightUpdate function

void verifyWeightUpdate()

Implementation

/// Checks that backpropagating through a matrix multiply yields the expected
/// gradient on the weight tensor, and prints a PASS/FAIL line.
///
/// The graph is `loss = (x.matmul(w)).sum()`, so `dL/dOut` is all ones and
/// `dL/dW = x^T * ones`: entry `[i, j]` of the gradient is the sum of column
/// `i` of `x`. Every entry of the gradient is compared, not a sample. A
/// sampled check would also accept a gradient produced in the transposed
/// layout.
void verifyWeightUpdate() {
  // Input x [2, 3] = [[1, 2, 3], [4, 5, 6]], weights w [3, 2] filled with 0.5.
  final x = Tensor.fromList([2, 3], [1, 2, 3, 4, 5, 6]);
  final w = Tensor.fill([3, 2], 0.5);

  final out = x.matmul(w);
  final loss = out.sum();
  loss.backward();

  final gradW = w.grad;

  // dL/dW = x^T * dL/dOut, with dL/dOut all ones.
  // Column sums of x are 1+4 = 5, 2+5 = 7 and 3+6 = 9. Each one fills a row
  // of the [3, 2] gradient, so the flattened gradient is [5, 5, 7, 7, 9, 9].
  // NOTE(review): assumes gradW[i] indexes the flattened row-major buffer,
  // consistent with the original gradW[0] / gradW[5] checks — confirm.
  // Checking only indices 0 and 5 would also accept the transposed layout
  // [5, 7, 9, 5, 7, 9], so every entry is compared.
  const expected = [5.0, 5.0, 7.0, 7.0, 9.0, 9.0];
  var ok = true;
  for (var i = 0; i < expected.length && ok; i++) {
    ok = closeEnough(gradW[i], expected[i]);
  }

  print("WEIGHT UPDATE: ${ok ? '✅ PASS' : '❌ FAIL (Got $gradW)'}");
}