selectRow function

Tensor<Vector> selectRow(
  Tensor<Matrix> m,
  int rowIndex
)

Implementation

/// Extracts row [rowIndex] of the matrix tensor [m] as a new vector tensor.
///
/// Forward: copies the first `m.shape[1]` entries of `m.value[rowIndex]`
/// into a fresh vector and wraps it in a `Tensor<Vector>`.
///
/// Backward: registered as the result's `creator` node. It adds each entry
/// of `out.grad` into `m.grad` at flat index `rowIndex * numCols + col`.
/// NOTE(review): this addresses `m.grad` as a row-major flattened buffer,
/// while the forward pass indexes `m.value` as nested rows. Confirm against
/// the Tensor class that `m.grad` really is flat.
Tensor<Vector> selectRow(Tensor<Matrix> m, int rowIndex) {
  final int width = m.shape[1];
  final Matrix source = m.value;

  final Vector picked = [];
  for (int col = 0; col < width; col++) {
    picked.add(source[rowIndex][col]);
  }

  final Tensor<Vector> out = Tensor<Vector>(picked);

  // Offset of the selected row inside the flattened gradient of `m`.
  final int rowStart = rowIndex * width;

  // Accumulates (not overwrites) the upstream gradient into the source row.
  void backward() {
    for (int col = 0; col < width; col++) {
      final int flatIdx = rowStart + col;
      m.grad[flatIdx] = m.grad[flatIdx] + out.grad[col];
    }
  }

  out.creator = Node(
    [m],
    backward,
    opName: 'selectRow',
    extraParams: {'rowIndex': rowIndex},
    cost: 0,
  );
  return out;
}