verifyMatMul function
void
verifyMatMul()
Implementation
/// Checks the backward pass of `Tensor.matmul` against hand-derived
/// gradients and prints the verdict.
///
/// Builds a 1×2 tensor holding `[2.0, 3.0]` and a 2×1 tensor holding
/// `[4.0, 5.0]`, multiplies them, reduces the 1×1 product with `sum()`,
/// and back-propagates. It then prints `MATMUL: ✅ PASS` if both gradients
/// match, otherwise `MATMUL: ❌ FAIL`.
///
/// Because `loss = a0*b0 + a1*b1`, the expected gradients are
/// `∂loss/∂a = bᵀ = [4.0, 5.0]` and `∂loss/∂b = aᵀ = [2.0, 3.0]`.
void verifyMatMul() {
  // Tensor.fill fixes the shapes; the zero fill is overwritten right away.
  final a = Tensor.fill([1, 2], 0.0)..data = [2.0, 3.0];
  final b = Tensor.fill([2, 1], 0.0)..data = [4.0, 5.0];

  a.matmul(b).sum().backward();

  final gradA = a.grad;
  final gradB = b.grad;

  // `closeEnough` is defined elsewhere; presumably a tolerance-based float
  // comparison (exact `==` on doubles would be fragile here).
  final gradAMatches =
      closeEnough(gradA[0], 4.0) && closeEnough(gradA[1], 5.0);
  final gradBMatches =
      closeEnough(gradB[0], 2.0) && closeEnough(gradB[1], 3.0);

  final passed = gradAMatches && gradBMatches;
  print("MATMUL: ${passed ? '✅ PASS' : '❌ FAIL'}");
}