AdamMatrix constructor

AdamMatrix(
  List<Tensor> parameters, {
  required double learningRate,
  double beta1 = 0.9,
  double beta2 = 0.999,
  double epsilon = 1e-8,
})

Implementation

/// Builds the optimizer for [parameters], forwarding [parameters] and
/// [learningRate] to the superclass and storing [beta1], [beta2] and
/// [epsilon] (defaults 0.9, 0.999 and 1e-8).
///
/// For every parameter it allocates two zero-filled matrices, `_m` and `_v`,
/// keyed by that parameter. These are presumably Adam's first- and
/// second-moment estimates; confirm against the update step.
///
/// Each parameter's `value` is cast to [Matrix]. A non-matrix value throws
/// at the cast. The column count is taken from row 0 only, so this assumes
/// rectangular matrices. NOTE(review): ragged rows would get moment rows of
/// the wrong width.
AdamMatrix(
    super.parameters, {
      required super.learningRate,
      this.beta1 = 0.9,
      this.beta2 = 0.999,
      this.epsilon = 1e-8,
    }) {
  _m = {};
  _v = {};
  for (final Tensor tensor in parameters) {
    final Matrix weights = tensor.value as Matrix;
    final int rows = weights.length;
    // An empty matrix has no row 0 to measure, so its width is 0.
    final int cols = rows == 0 ? 0 : weights[0].length;
    // Build m and v independently so they never share row storage.
    final Matrix firstMoment =
        List.generate(rows, (_) => List<double>.filled(cols, 0.0));
    final Matrix secondMoment =
        List.generate(rows, (_) => List<double>.filled(cols, 0.0));
    _m[tensor] = firstMoment;
    _v[tensor] = secondMoment;
  }
}