step method
void
step()
Performs one optimizer step: increments the step counter `t` and applies the Adam update to every parameter.
Implementation
/// Performs one optimizer step over every entry in `params`.
///
/// Increments the step counter `t`, then, for each parameter, calls
/// `engine.adamStep` with the parameter handle, the matching `m` / `v` state
/// handles, and the hyperparameters `lr`, `beta1`, `beta2` and `eps`.
/// After each update the parameter's data is fetched once and its first
/// element is checked for NaN. Finally `m[i]` and `v[i]` are added to
/// `tracker`.
///
/// Throws an [Exception] naming the offending parameter index when the first
/// element of an updated parameter is NaN; `dispose()` is called before the
/// exception is thrown.
void step() {
  // Advance the step counter before any update; the original author's note
  // marks this as crucial for bias correction in the first few epochs.
  t++;

  for (int i = 0; i < params.length; i++) {
    // Gradient clipping is currently disabled. To re-enable it, call
    // engine.clipGradients(params[i].handle, gradClip) before the update.

    // Adam update for this parameter; the computation itself is delegated
    // to the engine.
    engine.adamStep(
      params[i].handle,
      m[i].handle,
      v[i].handle,
      t,
      lr,
      beta1,
      beta2,
      eps,
    );

    // Fetch once and reuse the result for both the emptiness and NaN checks.
    // NOTE(review): only element 0 is inspected, so a NaN elsewhere in the
    // parameter is not detected here — confirm this sampling is intentional.
    final data = params[i].fetchData();
    if (data.isNotEmpty && data[0].isNaN) {
      // Clean up before aborting the step.
      dispose();
      throw Exception("parameter[$i] is NaN after Adam step $t");
    }

    // NOTE(review): the same m[i] / v[i] objects are added on every call;
    // presumably tracker tolerates duplicates — verify against its type.
    tracker.addAll([m[i], v[i]]);
  }
}