matmul method
Matrix Multiplication (Dot Product)
Result shape: [..., M, N] (leading batch dimensions of the left operand are preserved)
Implementation
/// Matrix multiplication (dot product) of this tensor with [o].
///
/// Expects `this` to have shape `[..., M, K]` (rank >= 2) and [o] to have
/// shape `[K, N]` (rank 2). All leading dimensions of `this` are kept and the
/// trailing `K` is replaced by `N`, so the result has shape `[..., M, N]`,
/// e.g. `[Batch, M, K] @ [K, N] -> [Batch, M, N]`.
///
/// Throws [ArgumentError] if `this` has rank < 2, if [o] is not rank 2, or if
/// the inner dimensions (last dim of `this`, first dim of [o]) differ.
Tensor matmul(Tensor o) {
  // 1. Validate ranks before indexing, so bad input fails with a clear
  //    message instead of a bare RangeError or a silently wrong shape.
  if (shape.length < 2) {
    throw ArgumentError(
      "matmul requires the left operand to have rank >= 2 ([..., M, K]), "
      "got shape $shape",
    );
  }
  if (o.shape.length != 2) {
    throw ArgumentError(
      "matmul requires the right operand to have rank 2 ([K, N]), "
      "got shape ${o.shape}",
    );
  }

  // 2. Inner dimensions must agree: A is [..., M, K1], B is [K2, N].
  final int k1 = shape.last;
  final int k2 = o.shape[0];
  final int n = o.shape.last;
  if (k1 != k2) {
    throw ArgumentError(
      "Dimension mismatch: A columns ($k1) must match B rows ($k2). "
      "Full Shapes: $shape @ ${o.shape}",
    );
  }

  // 3. Output shape: copy this tensor's shape (batch dims and M are kept)
  //    and replace the trailing K with N.
  final List<int> outShape = List<int>.from(shape);
  outShape[outShape.length - 1] = n;

  // 4. Dispatch to the native engine.
  // NOTE(review): batched left operands rely on engine.matmulTensors handling
  // the leading dimensions/strides of `this` — confirm on the C++ side.
  return Tensor._raw(engine.matmulTensors(_handle, o._handle), outShape);
}