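`concatenateMatricesByColumn` function — concatenates a list of matrix tensors horizontally (column-wise), placing the inputs side by side in list order.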

Tensor<Matrix> concatenateMatricesByColumn(List<Tensor<Matrix>> matrices)

Implementation

/// Concatenates [matrices] horizontally (column-wise).
///
/// Row `r` of the result is row `r` of every input joined in list order, so
/// the result has as many rows as each input and as many columns as the sum
/// of the inputs' column counts.
///
/// The returned tensor's `creator` is a `concat_matrix_col` [Node] over
/// [matrices]. Its backward closure slices `out.grad` back into each input's
/// column range and accumulates (`+=`) that slice into the input's `grad`.
/// Passing the same tensor more than once therefore accumulates both slices.
///
/// Throws [ArgumentError] in three cases:
/// - [matrices] is empty;
/// - the inputs do not all have the same number of rows;
/// - an input has rows of differing lengths (not rectangular).
Tensor<Matrix> concatenateMatricesByColumn(List<Tensor<Matrix>> matrices) {
  if (matrices.isEmpty) {
    throw ArgumentError.value(
        matrices, 'matrices', 'must contain at least one matrix');
  }
  final int numRows = matrices[0].value.length;

  // Validate shapes and record each input's column width once. The backward
  // pass reuses these widths instead of re-reading `value[0]`, which does not
  // exist for 0-row inputs and is wrong for jagged ones.
  final List<int> widths = <int>[];
  for (int k = 0; k < matrices.length; k++) {
    final rows = matrices[k].value;
    if (rows.length != numRows) {
      throw ArgumentError(
          'matrices[$k] has ${rows.length} rows; expected $numRows');
    }
    final int width = rows.isEmpty ? 0 : rows[0].length;
    for (int r = 0; r < rows.length; r++) {
      if (rows[r].length != width) {
        throw ArgumentError('matrices[$k] row $r has ${rows[r].length} '
            'columns; expected $width');
      }
    }
    widths.add(width);
  }

  // Forward: build each output row by appending the matching input rows.
  Matrix outValue = [];
  for (int i = 0; i < numRows; i++) {
    Vector newRow = [];
    for (Tensor<Matrix> m in matrices) {
      newRow.addAll(m.value[i]);
    }
    outValue.add(newRow);
  }

  Tensor<Matrix> out = Tensor<Matrix>(outValue);
  out.creator = Node(matrices, () {
    // Backward: input k owns output columns [offset, offset + widths[k]).
    int offset = 0;
    for (int k = 0; k < matrices.length; k++) {
      final Tensor<Matrix> m = matrices[k];
      final int width = widths[k];
      for (int r = 0; r < numRows; r++) {
        for (int c = 0; c < width; c++) {
          m.grad[r][c] += out.grad[r][offset + c];
        }
      }
      offset += width;
    }
  }, opName: 'concat_matrix_col');
  return out;
}