selectRow function
Implementation
/// Extracts row [rowIndex] of the matrix tensor [m] as a new vector tensor.
///
/// The column count is read from `m.shape[1]`, and the entries
/// `m.value[rowIndex][0 .. numCols - 1]` are copied into a fresh vector.
///
/// The returned tensor's `creator` is a `Node` whose only input is [m],
/// tagged with `opName: 'selectRow'`, `extraParams: {'rowIndex': rowIndex}`
/// and `cost: 0`. Its callback adds each `out.grad[c]` onto
/// `m.grad[rowIndex * numCols + c]`, i.e. `m.grad` is indexed as a flat
/// buffer with `numCols` entries per row.
// NOTE(review): the callback is presumably run during backpropagation;
// confirm against the Node implementation.
Tensor<Vector> selectRow(Tensor<Matrix> m, int rowIndex) {
  final int numCols = m.shape[1];
  final Matrix source = m.value;

  // Copy the requested row element by element. The row is indexed inside the
  // loop, so a matrix with zero columns never touches `source[rowIndex]`.
  final Vector rowValues = [
    for (int c = 0; c < numCols; c++) source[rowIndex][c],
  ];

  final Tensor<Vector> out = Tensor<Vector>(rowValues);

  // Accumulates the output gradient into the slice of `m.grad` that
  // corresponds to the selected row (flat offset `rowIndex * numCols`).
  void accumulateRowGrad() {
    for (int c = 0; c < numCols; c++) {
      final int flatIdx = rowIndex * numCols + c;
      m.grad[flatIdx] += out.grad[c];
    }
  }

  out.creator = Node(
    [m],
    accumulateRowGrad,
    opName: 'selectRow',
    extraParams: {'rowIndex': rowIndex},
    cost: 0,
  );
  return out;
}