selectMatrixFrom3DGPU function
Implementation
/// Records an `OP_SLICE_ROW` command on [tape] that reads from the 3-D tensor
/// [t] and writes into a freshly allocated matrix, then returns that matrix.
///
/// The result is created with shape `[t.shape[1], t.shape[2]]`.
/// NOTE(review): presumably [index] selects along axis 0 of [t]; it is passed
/// through to the tape unvalidated — confirm where bounds are enforced.
///
/// The returned tensor's `creator` is a [GPUNode] with [t] as its only parent.
/// Its backward callback records `OP_SLICE_ROW_BACKWARD` using the `_grad`
/// ids of the output and of [t], followed by the same [index].
GPUTensor<Matrix> selectMatrixFrom3DGPU(
  GPUTensor<Tensor3D> t,
  int index,
  CommandBuffer tape,
) {
  final int rows = t.shape[1];
  final int cols = t.shape[2];
  final GPUTensor<Matrix> out = GPUTensor<Matrix>.empty(<int>[rows, cols]);

  // Forward operands, in order: opcode, source id, destination id, index.
  // No `inCols` operand is written (it was removed to prevent the tape from
  // desynchronising).
  tape
    ..putInt(OP_SLICE_ROW)
    ..putString(t.id)
    ..putString(out.id)
    ..putInt(index);

  out.creator = GPUNode(
    <GPUTensor>[t],
    (CommandBuffer gradTape) {
      // Backward operands mirror the forward layout: opcode, output-gradient
      // id, input-gradient id, index. `inCols` is likewise omitted here.
      gradTape
        ..putInt(OP_SLICE_ROW_BACKWARD)
        ..putString('${out.id}_grad')
        ..putString('${t.id}_grad')
        ..putInt(index);
    },
    opName: 'selectMatrixFrom3DGPU',
  );

  return out;
}