selectRow function
Implementation
/// Differentiable row selection: returns row [rowIndex] of the matrix
/// tensor [m] as a vector tensor.
///
/// Forward: the result's value is `m.value[rowIndex]`. This function makes
/// no copy of that row.
/// NOTE(review): unless the `Tensor` constructor copies its argument, the
/// result's value is the very same object as the matrix row. An in-place
/// write to the result would then also change `m.value`. Confirm this
/// aliasing is intended.
///
/// Backward: the result's gradient is added element-wise into row
/// [rowIndex] of `m.grad`. No other row of `m.grad` is touched.
///
/// The integer [rowIndex] is not a tensor, so it is also recorded in the
/// node's `extraParams` under the key `'rowIndex'`. This keeps the
/// selected row recoverable from the graph node itself.
Tensor<Vector> selectRow(Tensor<Matrix> m, int rowIndex) {
  final Vector row = m.value[rowIndex];
  final Tensor<Vector> result = Tensor<Vector>(row);

  result.creator = Node(
    [m],
    () {
      // Gradients are read when backward runs, not when the node is built.
      final targetRow = m.grad[rowIndex];
      final upstream = result.grad;
      for (var col = 0; col < row.length; col++) {
        targetRow[col] += upstream[col];
      }
    },
    opName: 'selectRow',
    // Non-tensor argument kept on the node alongside its tensor parents.
    extraParams: {'rowIndex': rowIndex},
    cost: 0,
  );

  return result;
}