step method

@override
void step()

Performs a single optimization step (parameter update).

Subclasses must implement this method to define their specific update rule, e.g. the standard gradient descent update (see the sketch below) or the more complex Adam update shown in the implementation that follows.
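
As a sketch of what such a subclass can look like, here is a minimal, hypothetical gradient descent override. It is not part of the library: it assumes the same Optimizer base class, Tensor and Vector types, and the parameters and learningRate members that the Adam implementation below relies on, plus an assumed base-class constructor.

// Hypothetical sketch, not library code: plain gradient descent.
class SGD extends Optimizer {
  SGD(List<Tensor> parameters, double learningRate)
      : super(parameters, learningRate); // assumed base constructor

  @override
  void step() {
    for (Tensor param in parameters) {
      if (param.value is Vector) {
        Vector valVec = param.value as Vector;
        Vector gradVec = param.grad as Vector;
        // theta_i <- theta_i - learningRate * grad_i
        for (int i = 0; i < valVec.length; i++) {
          valVec[i] -= learningRate * gradVec[i];
        }
      }
      // A full implementation would also handle Matrix-valued
      // parameters, as the Adam implementation below does.
    }
  }
}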

Implementation

@override
void step() {
  // Uses pow() and sqrt() from 'dart:math', imported at the library level.
  _t++; // Advance the timestep used for bias correction below.
  for (Tensor param in parameters) {
    // Per-parameter first-moment (m) and second-moment (v) accumulators.
    dynamic m = _m[param]!;
    dynamic v = _v[param]!;
    if (param.value is Vector) {
      Vector valVec = param.value as Vector;
      Vector gradVec = param.grad as Vector;
      Vector mVec = m as Vector;
      Vector vVec = v as Vector;
      for (int i = 0; i < valVec.length; i++) {
        // Exponential moving averages of the gradient and its square.
        mVec[i] = beta1 * mVec[i] + (1 - beta1) * gradVec[i];
        vVec[i] = beta2 * vVec[i] + (1 - beta2) * pow(gradVec[i], 2);
        // Bias-corrected estimates: both averages start at zero, so
        // early steps would otherwise be biased toward zero.
        double mHat = mVec[i] / (1 - pow(beta1, _t));
        double vHat = vVec[i] / (1 - pow(beta2, _t));
        // Adam update: step along mHat, scaled per element by 1/sqrt(vHat).
        valVec[i] -= learningRate * mHat / (sqrt(vHat) + epsilon);
      }
    } else if (param.value is Matrix) {
      // Same update rule, applied elementwise to matrix-valued parameters.
      Matrix valMat = param.value as Matrix;
      Matrix gradMat = param.grad as Matrix;
      Matrix mMat = m as Matrix;
      Matrix vMat = v as Matrix;
      for (int r = 0; r < valMat.length; r++) {
        for (int c = 0; c < valMat[0].length; c++) {
          mMat[r][c] = beta1 * mMat[r][c] + (1 - beta1) * gradMat[r][c];
          vMat[r][c] = beta2 * vMat[r][c] + (1 - beta2) * pow(gradMat[r][c], 2);
          double mHat = mMat[r][c] / (1 - pow(beta1, _t));
          double vHat = vMat[r][c] / (1 - pow(beta2, _t));
          valMat[r][c] -= learningRate * mHat / (sqrt(vHat) + epsilon);
        }
      }
    }
    // Parameters whose values are neither Vector nor Matrix are skipped.
  }
}
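
For reference, each scalar element above follows the standard Adam recurrences. Written in LaTeX, with g_t the gradient, alpha the learningRate, and t the timestep _t:

\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2\, v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^{\,t}}, \qquad
\hat{v}_t = \frac{v_t}{1 - \beta_2^{\,t}} \\
\theta_t &= \theta_{t-1} - \alpha\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}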