matMul function
Performs matrix multiplication of two 2-D tensors (a: M x N, b: N x P), producing an M x P result and registering the backward pass on the autograd graph. NOTE(review): this text previously said the GPU is used for large matrices, but the implementation below contains only a CPU path — confirm whether the GPU branch was removed intentionally.
Implementation
Tensor<Matrix> matMul(Tensor<Matrix> a, Tensor<Matrix> b) {
  // Forward pass: out = a @ b. Also attaches a backward closure that
  // accumulates grad_a = grad_out @ b.T and grad_b = a.T @ grad_out.
  //
  // a: (M x N) tensor, b: (N x P) tensor; returns an (M x P) tensor.
  int M = a.value.length;
  int N = a.value[0].length;
  int P = b.value[0].length;
  // Inner dimensions must agree: a is (M x N), so b must be (N x P).
  // Without this check a mismatch surfaces as an opaque RangeError deep
  // inside the multiply loops.
  assert(b.value.length == N,
      'matMul: shape mismatch ($M x $N) @ (${b.value.length} x $P)');

  // Transpose b once so the inner product below walks two contiguous rows
  // instead of striding down b's columns (better cache behavior).
  Matrix bT = [];
  for (int i = 0; i < P; i++) {
    Vector row = [];
    for (int j = 0; j < N; j++) {
      row.add(b.value[j][i]);
    }
    bT.add(row);
  }

  // out[i][j] = sum over k of a[i][k] * b[k][j], computed as rowA . rowBT.
  Matrix outValue = [];
  for (int i = 0; i < M; i++) {
    Vector rowA = a.value[i];
    Vector outRow = [];
    for (int j = 0; j < P; j++) {
      Vector rowBT = bT[j];
      double sum = 0;
      for (int k = 0; k < N; k++) {
        sum += rowA[k] * rowBT[k];
      }
      outRow.add(sum);
    }
    outValue.add(outRow);
  }

  // --- Backward pass setup ---
  Tensor<Matrix> out = Tensor<Matrix>(outValue);
  // FLOP estimate: M*P dot products of length N (one multiply + one add each).
  int cost = 2 * M * N * P;
  out.creator = Node(
    [a, b],
    () {
      // Rebuild b.T inside the closure rather than capturing the forward
      // pass's copy, so the transpose is not kept alive on the graph
      // between the forward and backward passes.
      Matrix bT = [];
      for (int i = 0; i < P; i++) {
        Vector row = [];
        for (int j = 0; j < N; j++) {
          row.add(b.value[j][i]);
        }
        bT.add(row);
      }
      // grad_a += grad_out @ b.T
      for (int i = 0; i < M; i++) {
        for (int k = 0; k < N; k++) {
          for (int j = 0; j < P; j++) {
            a.grad[i][k] += out.grad[i][j] * bT[j][k];
          }
        }
      }
      // grad_b += a.T @ grad_out
      Matrix aT = [];
      for (int i = 0; i < N; i++) {
        Vector row = [];
        for (int j = 0; j < M; j++) {
          row.add(a.value[j][i]);
        }
        aT.add(row);
      }
      for (int k = 0; k < N; k++) {
        for (int j = 0; j < P; j++) {
          for (int i = 0; i < M; i++) {
            b.grad[k][j] += aT[k][i] * out.grad[i][j];
          }
        }
      }
    },
    opName: 'matMul', // This name is specific and standard
    cost: cost,
  );
  return out;
}