matmul method

Tensor matmul(
  1. Tensor o
)

Matrix multiplication (dot product). Result shape: [..., M, N] — the receiver's shape with its last dimension replaced by N.

Implementation

/// Multiplies this tensor by [o] and returns the product as a new tensor.
///
/// The receiver is treated as `[..., M, K]` and [o] is expected to be
/// `[K, N]`. The result shape is the receiver's shape with its last
/// dimension replaced by `N`, e.g. `[Batch, M, K] @ [K, N] -> [Batch, M, N]`.
///
/// Throws an [ArgumentError] if either operand has fewer than two
/// dimensions, or if the receiver's last dimension (`K`) differs from the
/// row dimension of [o].
Tensor matmul(Tensor o) {
  // 1. Validate ranks up front: both operands need a row and a column axis.
  // Without this guard a low-rank operand fails on an index expression with
  // an error that says nothing about matmul.
  if (shape.length < 2 || o.shape.length < 2) {
    throw ArgumentError(
      "matmul requires operands with at least 2 dimensions. "
      "Full Shapes: $shape @ ${o.shape}",
    );
  }

  // 2. Determine K1, K2, N.
  // The row dimension of 'o' is read from its second-to-last axis. For the
  // expected 2-D [K, N] operand this is exactly o.shape[0].
  final int k1 = shape.last;
  final int k2 = o.shape[o.shape.length - 2];
  final int n = o.shape.last;

  // Checked here, before dispatch, so incompatible operands surface as a
  // Dart error instead of being handed to the engine unvalidated.
  if (k1 != k2) {
    throw ArgumentError(
      "Dimension mismatch: A columns ($k1) must match B rows ($k2). "
      "Full Shapes: $shape @ ${o.shape}",
    );
  }

  // 3. Calculate output shape.
  // For [Batch, M, K] @ [K, N] -> [Batch, M, N]: copy the receiver's shape
  // (so the original list is not mutated) and swap the last dimension for N.
  final List<int> outShape = List<int>.from(shape);
  outShape[outShape.length - 1] = n;

  // 4. Dispatch to the engine (CUDA, per the original note).
  // Ensure your C++ engine.matmulTensors handles the stride for batching!
  return Tensor._raw(engine.matmulTensors(_handle, o._handle), outShape);
}