verifyMatMul function

void verifyMatMul()

Implementation

/// Checks the gradients that [Tensor.matmul] backpropagates on a tiny,
/// hand-computed example and prints a single PASS/FAIL line.
///
/// A 1x2 row `[2, 3]` is multiplied by a 2x1 column `[4, 5]` and the product
/// is summed to a scalar loss. Each operand's gradient is then the other
/// operand transposed: `[4, 5]` for the row and `[2, 3]` for the column.
void verifyMatMul() {
  // Row operand (shape [1, 2]) and column operand (shape [2, 1]).
  final lhs = Tensor.fill([1, 2], 0.0);
  final rhs = Tensor.fill([2, 1], 0.0);
  lhs.data = [2.0, 3.0];
  rhs.data = [4.0, 5.0];

  // Reduce the [1, 1] product to a scalar and backpropagate through it.
  lhs.matmul(rhs).sum().backward();

  final lhsGrad = lhs.grad;
  final rhsGrad = rhs.grad;

  // d(sum(AB))/dA = B^T -> [4, 5];  d(sum(AB))/dB = A^T -> [2, 3].
  final lhsOk = closeEnough(lhsGrad[0], 4.0) && closeEnough(lhsGrad[1], 5.0);
  final rhsOk = closeEnough(rhsGrad[0], 2.0) && closeEnough(rhsGrad[1], 3.0);
  final passed = lhsOk && rhsOk;

  print("MATMUL: ${passed ? '✅ PASS' : '❌ FAIL'}");
}