trainStep method

double trainStep(
  List<double> x,
  List<double> y
)

Implementation

/// Runs a single training iteration on the sample ([x], [y]) and returns the
/// resulting loss.
///
/// [x] is pushed into `inputRef` and [y] into `targetRef`. The three recorded
/// tapes are then executed in a fixed order (`fTape`, `bTape`, `oTape`) —
/// presumably forward pass, backward pass and optimizer update; confirm
/// against where the tapes are built.
///
/// Returns `lossRef.value` as read right after `lossRef.toCpu()`.
double trainStep(List<double> x, List<double> y) {
  // Upload the sample as raw floats into the buffers that were allocated in
  // VRAM ahead of time: inputs first, then targets.
  inputRef.pushData(x);
  targetRef.pushData(y);

  // Replay the pre-compiled native instruction tapes. The order matters and
  // must stay fTape -> bTape -> oTape. No separate gradient reset is issued
  // here because bTape performs zero_grad internally.
  for (final tape in [fTape, bTape, oTape]) {
    CudaEngine.run(tape);
  }

  // Only the scalar loss is copied back to the Dart CPU heap; everything else
  // stays on the device.
  final loss = lossRef..toCpu();
  return loss.value;
}