selectMatrixFrom3DGPU function

GPUTensor<Matrix> selectMatrixFrom3DGPU(
  GPUTensor<Tensor3D> t,
  int index,
  CommandBuffer tape
)

Implementation

/// Records a GPU op that extracts one `[height, width]` matrix from the 3-D
/// tensor [t], where `height = t.shape[1]` and `width = t.shape[2]`.
///
/// The forward op is appended to [tape] as `OP_SLICE_ROW`, followed by the
/// input tensor id, the output tensor id and [index]. The returned tensor's
/// `creator` node, when its backward callback runs, appends the matching
/// `OP_SLICE_ROW_BACKWARD` record: output-grad id, input-grad id, [index].
///
/// Parameters:
/// - [t]: source tensor. The output keeps axes 1 and 2, so [index] selects
///   along axis 0 (`t.shape[0]`).
/// - [index]: position along axis 0; must satisfy `0 <= index < t.shape[0]`.
/// - [tape]: command buffer that receives the forward op.
///
/// Returns a newly created (`GPUTensor.empty`) tensor of shape
/// `[t.shape[1], t.shape[2]]` whose contents are produced by the recorded op.
///
/// Throws a [RangeError] if [index] is outside `[0, t.shape[0])`. The check
/// runs before anything is allocated or written to [tape].
GPUTensor<Matrix> selectMatrixFrom3DGPU(
    GPUTensor<Tensor3D> t,
    int index,
    CommandBuffer tape) {
  final int depth = t.shape[0];
  // Validate up front so an invalid call leaves the tape untouched; without
  // this, an out-of-range index would be passed unchecked to the backend.
  if (index < 0 || index >= depth) {
    throw RangeError.range(index, 0, depth - 1, 'index');
  }

  final int height = t.shape[1];
  final int width = t.shape[2];

  final List<int> outShape = <int>[height, width];
  final GPUTensor<Matrix> out = GPUTensor<Matrix>.empty(outShape);

  // Forward record. The backend appears to decode these fields positionally:
  // a former extra `inCols` field was removed because it desynchronized the
  // tape. Keep the field count and order in sync with the backend decoder.
  tape.putInt(OP_SLICE_ROW);
  tape.putString(t.id);
  tape.putString(out.id);
  tape.putInt(index);

  out.creator = GPUNode(
    <GPUTensor>[t],
    (CommandBuffer backwardTape) {
      // Backward record, with the same layout contract as the forward one
      // (no `inCols` field). Gradient buffers are addressed as
      // `<tensorId>_grad`.
      backwardTape.putInt(OP_SLICE_ROW_BACKWARD);
      backwardTape.putString('${out.id}_grad');
      backwardTape.putString('${t.id}_grad');
      backwardTape.putInt(index);
    },
    opName: 'selectMatrixFrom3DGPU',
  );

  return out;
}