verifyMatMul function

void verifyMatMul()

Implementation

/// Runs a small gradient check for `matmul` and prints the outcome.
///
/// Two tensors are built with `Tensor.fromList`, multiplied, and the product
/// is reduced with `sum()`; `backward()` is then called on that result.
/// Afterwards the first two `grad` entries of each operand are compared via
/// `closeEnough` against the values supplied for the *other* operand, and a
/// single `MATMUL: ...` PASS/FAIL line is printed.
void verifyMatMul() {
  // NOTE(review): the first list argument looks like a shape (1x2 and 2x1)
  // and the second like the element data — confirm against `Tensor.fromList`.
  final lhs = Tensor.fromList([1, 2], [2.0, 3.0]);
  final rhs = Tensor.fromList([2, 1], [4.0, 5.0]);

  final product = lhs.matmul(rhs);
  final total = product.sum();
  total.backward();

  final lhsGrad = lhs.grad;
  final rhsGrad = rhs.grad;

  // Each operand's gradient is expected to mirror the other operand's values.
  final lhsMatches =
      closeEnough(lhsGrad[0], 4.0) && closeEnough(lhsGrad[1], 5.0);
  final rhsMatches =
      closeEnough(rhsGrad[0], 2.0) && closeEnough(rhsGrad[1], 3.0);

  final verdict = lhsMatches && rhsMatches ? '✅ PASS' : '❌ FAIL';
  print("MATMUL: $verdict");
}