The `concatenateMatricesByColumn` function
Implementation (forward column-wise concatenation plus its autograd backward pass):
/// Concatenates 2-D tensors side by side, i.e. along the column axis.
///
/// All inputs must share the same row count (`shape[0]`). Row `i` of the
/// result is row `i` of `matrices[0]`, followed by row `i` of `matrices[1]`,
/// and so on. The result has `matrices[0].shape[0]` rows, and its column
/// count is the sum of every input's `shape[1]`.
///
/// The returned tensor's `creator` is a [Node] over [matrices] tagged
/// `'concat_matrix_col'`. Its backward closure adds each column slice of
/// `out.grad` into the `grad` of the input that produced that slice.
///
/// Throws an [ArgumentError] if [matrices] is empty or if the inputs do not
/// all have the same number of rows.
Tensor<Matrix> concatenateMatricesByColumn(List<Tensor<Matrix>> matrices) {
  if (matrices.isEmpty) {
    throw ArgumentError(
        'concatenateMatricesByColumn requires at least one matrix');
  }
  int numRows = matrices[0].shape[0];
  // Validate up front. A shorter input would otherwise throw part-way through
  // the copy below. A taller input would have its extra rows silently dropped,
  // and those rows would get no gradient in the backward pass.
  for (int k = 1; k < matrices.length; k = k + 1) {
    int rows = matrices[k].shape[0];
    if (rows != numRows) {
      throw ArgumentError('concatenateMatricesByColumn: matrix $k has $rows '
          'rows but matrix 0 has $numRows');
    }
  }

  // Read each input's value once, rather than once per output row.
  List<Matrix> cachedMatrices = [];
  for (int k = 0; k < matrices.length; k = k + 1) {
    cachedMatrices.add(matrices[k].value);
  }

  // Forward pass: build each output row by appending the matching row of
  // every input, in list order.
  Matrix outValue = [];
  for (int i = 0; i < numRows; i = i + 1) {
    Vector newRow = [];
    for (int k = 0; k < matrices.length; k = k + 1) {
      Matrix mMat = cachedMatrices[k];
      int mCols = matrices[k].shape[1];
      for (int j = 0; j < mCols; j = j + 1) {
        newRow.add(mMat[i][j]);
      }
    }
    outValue.add(newRow);
  }

  Tensor<Matrix> out = Tensor<Matrix>(outValue);
  out.creator = Node(
    matrices,
    () {
      // Backward pass: input k owns output columns
      // [currentCol, currentCol + numCols). Its gradient is that column slice
      // of out.grad, accumulated (+=) into m.grad.
      //
      // NOTE(review): grad is indexed with a flat row-major offset
      // (r * cols + c), while value is indexed as nested rows in the forward
      // pass. This assumes Tensor.grad is stored flat, row-major. Confirm
      // against the Tensor class.
      int currentCol = 0;
      int outCols = out.shape[1];
      for (int k = 0; k < matrices.length; k = k + 1) {
        Tensor<Matrix> m = matrices[k];
        int numCols = m.shape[1];
        for (int r = 0; r < numRows; r = r + 1) {
          for (int c = 0; c < numCols; c = c + 1) {
            int mIdx = r * numCols + c;
            int outIdx = r * outCols + (currentCol + c);
            m.grad[mIdx] = m.grad[mIdx] + out.grad[outIdx];
          }
        }
        currentCol = currentCol + numCols;
      }
    },
    opName: 'concat_matrix_col'
  );
  return out;
}