The `concatenateMatricesByColumn` function
Implementation:
/// Concatenates [matrices] side by side, i.e. along the column axis.
///
/// Row `i` of the result is row `i` of `matrices[0]`, followed by row `i` of
/// `matrices[1]`, and so on. The returned tensor's `creator` is a [Node] over
/// [matrices] whose backward closure copies each input's column band of
/// `out.grad` into that input's `grad`, accumulating with `+=`.
///
/// Throws an [ArgumentError] in three cases:
/// - [matrices] is empty;
/// - the inputs do not all have the same number of rows;
/// - an input has rows of differing lengths. The backward pass requires every
///   input to be rectangular so that its column band is well defined.
Tensor<Matrix> concatenateMatricesByColumn(List<Tensor<Matrix>> matrices) {
  if (matrices.isEmpty) {
    throw ArgumentError.value(
        matrices, 'matrices', 'must contain at least one matrix');
  }
  final int numRows = matrices[0].value.length;

  // Validate the shapes once and remember each input's width. The backward
  // closure then never re-reads `value[0]`, which does not exist when
  // numRows == 0.
  final List<int> widths = [];
  for (int k = 0; k < matrices.length; k++) {
    final Matrix value = matrices[k].value;
    if (value.length != numRows) {
      throw ArgumentError(
          'matrices[$k] has ${value.length} rows; expected $numRows');
    }
    final int width = numRows == 0 ? 0 : value[0].length;
    for (int r = 0; r < numRows; r++) {
      if (value[r].length != width) {
        throw ArgumentError('matrices[$k] row $r has ${value[r].length} '
            'columns; expected $width');
      }
    }
    widths.add(width);
  }

  // Forward pass: build each output row by appending the matching row of
  // every input, in list order.
  final Matrix outValue = [];
  for (int i = 0; i < numRows; i++) {
    final Vector newRow = [];
    for (final Tensor<Matrix> m in matrices) {
      newRow.addAll(m.value[i]);
    }
    outValue.add(newRow);
  }

  final Tensor<Matrix> out = Tensor<Matrix>(outValue);
  out.creator = Node(matrices, () {
    // Input k owns output columns [colOffset, colOffset + widths[k]).
    int colOffset = 0;
    for (int k = 0; k < matrices.length; k++) {
      final Tensor<Matrix> m = matrices[k];
      final int width = widths[k];
      for (int r = 0; r < numRows; r++) {
        for (int c = 0; c < width; c++) {
          m.grad[r][c] += out.grad[r][colOffset + c];
        }
      }
      colOffset += width;
    }
  }, opName: 'concat_matrix_col');
  return out;
}