verifyWeightUpdate function

void verifyWeightUpdate()

Implementation

/// Checks that backpropagation through a matrix multiply produces the
/// expected gradient for the weight matrix `W`.
///
/// Builds `X` (shape `[2, 3]`, values 1..6) and `W` (shape `[3, 2]`, all
/// 0.5). It computes `loss = X.matmul(W).sum()`, calls `backward()`, and
/// compares `W.grad` against the analytically derived gradient.
///
/// Prints a single PASS/FAIL line; on failure the actual gradient is
/// interpolated into the message.
void verifyWeightUpdate() {
  final X = Tensor.fromList([2, 3], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
  final W = Tensor.fromList([3, 2], List.filled(6, 0.5));

  final out = X.matmul(W);
  final loss = out.sum();
  loss.backward();

  // Every output element enters the sum with coefficient 1, so
  // dL/dW = Xᵀ · ones(2×2). Row i of the 3×2 gradient is therefore the sum
  // of column i of X: [[1+4, 1+4], [2+5, 2+5], [3+6, 3+6]].
  // Flattened assuming row-major indexing (the layout implied by X's data
  // giving 5 at index 0) — NOTE(review): confirm against Tensor's layout.
  //
  // Every entry is checked, not just the first and last: a transposed /
  // column-major result [5, 7, 9, 5, 7, 9] still has 5 at index 0 and 9 at
  // index 5, so an endpoints-only check would miss it.
  const expectedGradW = [5.0, 5.0, 7.0, 7.0, 9.0, 9.0];

  final gradW = W.grad;
  final ok = Iterable<int>.generate(expectedGradW.length)
      .every((i) => closeEnough(gradW[i], expectedGradW[i]));

  print("WEIGHT UPDATE: ${ok ? '✅ PASS' : '❌ FAIL (Got $gradW)'}");
}