computeWeightUpdate method

@override
double computeWeightUpdate(
  N weight,
  N weightLastUpdate,
  num gradient,
  num previousGradient,
  List<num> previousUpdateDeltas,
  List<num> noImprovementCounter,
  int weightIndex,
  N neuronOutput,
)

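Implementation of the weight update. Returns the amount to apply to the weight at `weightIndex`, and stores the adapted step size for that weight back into `previousUpdateDeltas`.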

Implementation

/// Computes the update for the weight at [weightIndex] from the signs of the
/// current and previous gradients, adapting the per-weight step size stored
/// in [previousUpdateDeltas].
///
/// The stored step size doubles as a flag: a negative entry means the
/// previous call detected a gradient sign flip, so this call must neither
/// grow nor shrink the step.
///
/// * Same gradient sign as before: the step grows by a factor of 1.2, capped
///   at `weightMaxStep`, and is applied in the direction of [gradient].
/// * Opposite gradient sign: the step shrinks by a factor of 0.5, floored at
///   `weightMinStep`, and is stored negated. The returned update undoes
///   [weightLastUpdate] when `globalLearnError` is greater than
///   `lastGlobalLearnError`, and is `0.0` otherwise.
/// * Otherwise: the step is kept unchanged and applied in the direction of
///   [gradient].
///
/// [weight], [noImprovementCounter] and [neuronOutput] are not read by this
/// implementation.
@override
double computeWeightUpdate(
  N weight,
  N weightLastUpdate,
  num gradient,
  num previousGradient,
  List<num> previousUpdateDeltas,
  List<num> noImprovementCounter,
  int weightIndex,
  N neuronOutput,
) {
  final storedDelta = previousUpdateDeltas[weightIndex];

  final productSign = (gradient * previousGradient).signWithZeroTolerance();
  final direction = gradient.signWithZeroTolerance();

  // A negative stored delta is the "hold direction" marker left by the
  // previous call: use its magnitude and treat the trend as neutral.
  final holdDirection = storedDelta < 0;
  final stepSize = holdDirection ? -storedDelta : storedDelta;
  final trend = holdDirection ? 0 : productSign;

  if (trend > 0) {
    // Gradient kept its sign: accelerate, bounded above by the maximum step.
    final double grownStep = min(stepSize.toDouble() * 1.2, weightMaxStep);
    final double update = direction * grownStep;
    previousUpdateDeltas[weightIndex] = grownStep;
    return update;
  }

  if (trend < 0) {
    // Gradient flipped sign: slow down, bounded below by the minimum step.
    final double shrunkStep = max(stepSize.toDouble() * 0.50, weightMinStep);

    // Undo the last update only when the global error got worse.
    final double update =
        globalLearnError > lastGlobalLearnError ? weightLastUpdate * -1.0 : 0.0;

    // Stored negated so the next call knows not to change direction.
    previousUpdateDeltas[weightIndex] = -shrunkStep;
    return update;
  }

  // Neutral trend: keep the current step size as is.
  final double keptStep = stepSize.toDouble();
  final double update = direction * keptStep;
  previousUpdateDeltas[weightIndex] = keptStep;
  return update;
}