step method

void step()

Performs one Adam optimizer update step over every parameter tensor, aborting with an exception if a parameter becomes NaN.

Implementation

/// Performs one Adam update step over every parameter tensor.
///
/// Increments the step counter [t], then for each index `i` calls
/// `engine.adamStep` with the parameter handle, its moment buffers
/// (`m[i]`, `v[i]`), the step counter and the hyperparameters
/// [lr], [beta1], [beta2], [eps].
///
/// After each update, the parameter's data is fetched and its first element is
/// checked for NaN. If it is NaN, [dispose] is called and an [Exception] naming
/// the offending parameter index and step is thrown.
///
/// Gradient clipping (`engine.clipGradients` with `gradClip`) is currently
/// disabled.
void step() {
  // The step counter feeds Adam's bias correction, which matters most during
  // the first few steps. It must be incremented before any update uses it.
  t++;

  for (int i = 0; i < params.length; i++) {
    // TODO(review): gradient clipping is disabled. Re-enable with
    // engine.clipGradients(params[i].handle, gradClip) if NaNs appear.

    // The update itself runs inside the engine (reportedly a CUDA kernel).
    engine.adamStep(
      params[i].handle,
      m[i].handle,
      v[i].handle,
      t,
      lr,
      beta1,
      beta2,
      eps,
    );

    // Fetch once. fetchData() presumably copies the tensor back to the host,
    // so calling it twice doubled that cost on every parameter, every step.
    final data = params[i].fetchData();
    if (data.isNotEmpty && data[0].isNaN) {
      dispose();
      throw Exception("parameter[$i]: NaN detected after step $t");
    }

    // NOTE(review): m[i]/v[i] are re-added on every call to step(). Confirm that
    // tracker deduplicates (e.g. is a Set); otherwise it grows without bound.
    tracker.addAll([m[i], v[i]]);
  }
}