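`concatenateMatricesByColumn` function — concatenates a list of matrices side by side (column-wise). All inputs must have the same number of rows; the result has that row count and the sum of the inputs' column counts.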

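Tensor<Matrix> concatenateMatricesByColumn(List<Tensor<Matrix>> matrices)

Parameters:
  1. `matrices` (`List<Tensor<Matrix>>`) — the matrices to concatenate, in left-to-right order.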

Implementation

/// Concatenates [matrices] side by side (column-wise) into one tensor.
///
/// Row `i` of the result is row `i` of `matrices[0]`, followed by row `i` of
/// `matrices[1]`, and so on. Every input must have the same number of rows;
/// the result has that row count and the sum of the inputs' column counts.
///
/// The returned tensor's `creator` is a `Node` (op name `'concat_matrix_col'`)
/// whose backward closure routes each column slice of `out.grad` back to the
/// input it came from, accumulating (`+=`) into that input's `grad`.
///
/// Throws an [ArgumentError] if [matrices] is empty or if the inputs do not
/// all have the same number of rows.
Tensor<Matrix> concatenateMatricesByColumn(List<Tensor<Matrix>> matrices) {
  if (matrices.isEmpty) {
    throw ArgumentError('matrices must contain at least one matrix.');
  }

  final int numRows = matrices[0].shape[0];

  // Validate row counts up front. Without this check, a shorter input fails
  // mid-loop with an index error, and a taller one is silently truncated.
  // The total column count is accumulated here so the backward pass does not
  // depend on how `out.shape` is derived from the concatenated value.
  int totalCols = 0;
  for (int k = 0; k < matrices.length; k = k + 1) {
    final int rows = matrices[k].shape[0];
    if (rows != numRows) {
      throw ArgumentError(
          'All matrices must have the same number of rows: matrices[0] has '
          '$numRows, matrices[$k] has $rows.');
    }
    totalCols = totalCols + matrices[k].shape[1];
  }

  // Read each input's value once, outside the row loop.
  final List<Matrix> cachedMatrices = [];
  for (final Tensor<Matrix> m in matrices) {
    cachedMatrices.add(m.value);
  }

  Matrix outValue = [];
  for (int i = 0; i < numRows; i = i + 1) {
    Vector newRow = [];
    for (int k = 0; k < matrices.length; k = k + 1) {
      final Matrix mMat = cachedMatrices[k];
      final int mCols = matrices[k].shape[1];
      for (int j = 0; j < mCols; j = j + 1) {
        newRow.add(mMat[i][j]);
      }
    }
    outValue.add(newRow);
  }

  Tensor<Matrix> out = Tensor<Matrix>(outValue);

  out.creator = Node(
      matrices,
      () {
        // Column offset in `out` where the current input's slice begins.
        int currentCol = 0;
        for (final Tensor<Matrix> m in matrices) {
          final int numCols = m.shape[1];

          for (int r = 0; r < numRows; r = r + 1) {
            for (int c = 0; c < numCols; c = c + 1) {
              // NOTE(review): gradients are indexed as flat row-major buffers
              // here, while `value` is indexed 2-D (`[i][j]`) in the forward
              // pass. Confirm that `grad` really is stored flat.
              final int mIdx = r * numCols + c;
              final int outIdx = r * totalCols + (currentCol + c);
              m.grad[mIdx] = m.grad[mIdx] + out.grad[outIdx];
            }
          }
          currentCol = currentCol + numCols;
        }
      },
      opName: 'concat_matrix_col'
  );
  return out;
}