step method

  1. @override
void step()
override

Performs a single optimization step (parameter update).

Subclasses must implement this method to define their specific update rule (e.g., the standard gradient descent update, or the more complex Adam update).

Implementation

/// Performs a single optimization step over every tensor in [parameters].
///
/// Increments the step counter `_t`, then for each parameter updates the
/// running first-moment (`_m`) and second-moment (`_v`) estimates from the
/// current gradient, applies bias correction based on `_t`, and subtracts
/// the resulting scaled update from the parameter value in place.
///
/// Each parameter's `value` and `grad` are cast to `Matrix`, and its entries
/// in `_m` and `_v` are null-asserted, so a parameter without pre-existing
/// moment state (or with a non-`Matrix` value/grad) makes this method throw.
@override
void step() {
  _t++;

  // The bias-correction denominators depend only on the step counter, so
  // compute them once per call instead of once per matrix element.
  final double biasCorrection1 = 1 - pow(beta1, _t).toDouble();
  final double biasCorrection2 = 1 - pow(beta2, _t).toDouble();

  for (Tensor param in parameters) {
    Matrix valMat = param.value as Matrix;
    Matrix gradMat = param.grad as Matrix;
    Matrix mMat = _m[param]!;
    Matrix vMat = _v[param]!;

    // Column count is taken from the first row, as in the original loop;
    // an empty matrix is skipped without touching row 0.
    final int rows = valMat.length;
    final int cols = rows == 0 ? 0 : valMat[0].length;

    for (int r = 0; r < rows; r++) {
      for (int c = 0; c < cols; c++) {
        final g = gradMat[r][c];

        // Exponential moving averages of the gradient and squared gradient.
        mMat[r][c] = beta1 * mMat[r][c] + (1 - beta1) * g;
        vMat[r][c] = beta2 * vMat[r][c] + (1 - beta2) * (g * g);

        // Bias-corrected moment estimates.
        final double mHat = mMat[r][c] / biasCorrection1;
        final double vHat = vMat[r][c] / biasCorrection2;

        // epsilon keeps the denominator away from zero.
        valMat[r][c] -= learningRate * mHat / (sqrt(vHat) + epsilon);
      }
    }
  }
}