encapsulateGPUGraph<T> function

Tensor<T> encapsulateGPUGraph<T>(
  List<Tensor> cpuDynamicInputs,
  List<GPUTensor> gpuDynamicInputs,
  List<Tensor> cpuStaticParams,
  List<GPUTensor> gpuStaticParams,
  GPUTensor gpuOutput,
  T initialOutputValue,
  Uint8List forwardTape,
  Uint8List backwardTape,
)

Implementation

/// Runs a pre-recorded GPU forward tape and wraps its result in a CPU
/// [Tensor] whose backward pass replays a pre-recorded GPU backward tape.
///
/// Forward pass:
///  * Each tensor in [cpuDynamicInputs] is uploaded (via `CudaEngine.load`)
///    into the GPU buffer named by the entry of [gpuDynamicInputs] at the
///    same index.
///  * [forwardTape] is executed, then the GPU buffer for [gpuOutput] is
///    copied into a new `Tensor<T>` constructed from [initialOutputValue].
///
/// [cpuStaticParams] are not uploaded by this function.
/// NOTE(review): this presumably assumes their counterparts in
/// [gpuStaticParams] are already resident on the GPU — confirm with callers.
///
/// Backward pass (installed as `out.creator`, op name `gpu_graph_block`):
///  1. Zeroes the `<id>_grad` buffer of every GPU input and parameter.
///  2. Uploads the output tensor's gradient (`out.gradPtr`) into
///     `<gpuOutput.id>_grad`.
///  3. Executes [backwardTape].
///  4. Downloads each `<id>_grad` buffer and adds it element-wise into the
///     `grad` of the CPU tensor at the same index.
///
/// Throws [ArgumentError] if [cpuDynamicInputs] / [gpuDynamicInputs] or
/// [cpuStaticParams] / [gpuStaticParams] differ in length, because CPU and
/// GPU tensors are paired purely by index.
Tensor<T> encapsulateGPUGraph<T>(
    List<Tensor> cpuDynamicInputs,
    List<GPUTensor> gpuDynamicInputs,
    List<Tensor> cpuStaticParams,
    List<GPUTensor> gpuStaticParams,
    GPUTensor gpuOutput,
    T initialOutputValue,
    Uint8List forwardTape,
    Uint8List backwardTape
    ) {
  // Validate the index pairing eagerly. Otherwise a mismatch either silently
  // drops GPU inputs or only fails later inside the backward closure.
  if (cpuDynamicInputs.length != gpuDynamicInputs.length) {
    throw ArgumentError(
        'cpuDynamicInputs (${cpuDynamicInputs.length}) and gpuDynamicInputs '
        '(${gpuDynamicInputs.length}) must have the same length');
  }
  if (cpuStaticParams.length != gpuStaticParams.length) {
    throw ArgumentError(
        'cpuStaticParams (${cpuStaticParams.length}) and gpuStaticParams '
        '(${gpuStaticParams.length}) must have the same length');
  }

  // Upload each dynamic CPU input into its paired GPU buffer.
  for (int i = 0; i < cpuDynamicInputs.length; i++) {
    final GPUTensor gpuInput = gpuDynamicInputs[i];
    CudaEngine.load(gpuInput.id, cpuDynamicInputs[i].dataPtr, gpuInput.shape);
  }

  CudaEngine.run(forwardTape);

  final Tensor<T> out = Tensor<T>(initialOutputValue);
  CudaEngine.retrieve(gpuOutput.id, out.dataPtr);

  // Both lists use the same order (dynamic inputs, then static params).
  // The backward closure relies on index i referring to the same tensor in each.
  final List<Tensor> allCpuNodes = [...cpuDynamicInputs, ...cpuStaticParams];
  final List<GPUTensor> allGpuNodes = [...gpuDynamicInputs, ...gpuStaticParams];

  out.creator = Node(
      allCpuNodes,
      () {
        // Step 1: zero every input/parameter gradient buffer on the GPU.
        // NOTE(review): presumably required because the backward tape
        // accumulates into these buffers — confirm against the tape format.
        final CommandBuffer zeroTape = CommandBuffer();
        for (final GPUTensor gpuNode in allGpuNodes) {
          zeroTape.putInt(OP_ZERO_GRAD);
          zeroTape.putString('${gpuNode.id}_grad');
        }
        CudaEngine.run(zeroTape.bytes());

        // Step 2: seed the GPU backward pass with the output's gradient.
        CudaEngine.load('${gpuOutput.id}_grad', out.gradPtr, gpuOutput.shape);

        // Step 3: run the recorded GPU backward tape.
        CudaEngine.run(backwardTape);

        // Step 4: pull each gradient back to the CPU and accumulate it.
        for (int i = 0; i < allCpuNodes.length; i++) {
          final GPUTensor gpuNode = allGpuNodes[i];
          final Tensor cpuNode = allCpuNodes[i];
          final int numElements =
              gpuNode.shape.fold<int>(1, (int acc, int dim) => acc * dim);

          final Pointer<Float> tempGradPtr = calloc<Float>(numElements);
          try {
            CudaEngine.retrieve('${gpuNode.id}_grad', tempGradPtr);
            final Float32List tempGradView =
                tempGradPtr.asTypedList(numElements);
            for (int k = 0; k < numElements; k++) {
              cpuNode.grad[k] = cpuNode.grad[k] + tempGradView[k];
            }
          } finally {
            // Always release the native buffer, even if retrieval throws.
            calloc.free(tempGradPtr);
          }
        }
      },
      opName: 'gpu_graph_block'
  );

  return out;
}