selectRow function

Tensor<Vector> selectRow(
  Tensor<Matrix> m,
  int rowIndex
)

Implementation

/// Returns row [rowIndex] of the matrix tensor [m] wrapped in a new
/// `Tensor<Vector>`, and attaches a `Node` named 'selectRow' to it.
///
/// The gradient closure given to the node adds each element of `out.grad`
/// into the same column of row [rowIndex] of `m.grad`. No other row of
/// `m.grad` is touched.
///
/// NOTE(review): the row is taken directly from `m.value[rowIndex]` without an
/// explicit copy. Whether `out` shares storage with `m` depends on how
/// `Vector`/`Matrix` are defined. Confirm this before mutating `out.value`.
Tensor<Vector> selectRow(Tensor<Matrix> m, int rowIndex) {
  final Vector outValue = m.value[rowIndex];
  final Tensor<Vector> out = Tensor<Vector>(outValue);

  // Scatter-add the output gradient back into the selected row of m.grad.
  void backward() {
    for (var col = 0; col < outValue.length; col += 1) {
      m.grad[rowIndex][col] += out.grad[col];
    }
  }

  out.creator = Node(
    [m],
    backward,
    opName: 'selectRow',
    // rowIndex is a plain int, not a Tensor, so it is not part of the `[m]`
    // list above. It must be stored here explicitly to be kept on the node.
    extraParams: {'rowIndex': rowIndex},
    cost: 0,
  );
  return out;
}