update method

void update(
  List<double> state,
  int action,
  double reward,
  List<double> nextState, {
  double actorLr = 0.01,
  double criticLr = 0.01,
})

Implementation

/// Performs one actor-critic update from a single transition.
///
/// Computes the temporal-difference error
/// `td = reward + gamma * V(nextState) - V(state)` using `critic`. It then
/// fits `critic` on [state] toward the one-step target
/// `reward + gamma * V(nextState)`. Finally it fits `actor` on [state] with a
/// target vector of length `nActions` that is zero everywhere except at
/// [action], where it holds `td`.
///
/// [gamma] is the discount factor. It defaults to `0.99`, the value this
/// method previously hard-coded.
///
/// Set [done] to `true` when [nextState] is terminal. The bootstrap term is
/// then dropped (treated as `V(nextState) == 0`) and `critic` is not queried
/// for [nextState].
///
/// NOTE(review): [actorLr] and [criticLr] are accepted but not forwarded to
/// `fit`. Confirm how the models' learning rates are meant to be applied.
void update(
  List<double> state,
  int action,
  double reward,
  List<double> nextState, {
  double actorLr = 0.01,
  double criticLr = 0.01,
  double gamma = 0.99,
  bool done = false,
}) {
  final v = critic.predict([state])[0][0];
  // A terminal transition has no future return to bootstrap from.
  final vNext = done ? 0.0 : critic.predict([nextState])[0][0];
  final tdTarget = reward + gamma * vNext;
  final td = tdTarget - v;

  // Critic: regress V(state) toward the one-step TD target.
  critic.fit(
    [state],
    [
      [tdTarget],
    ],
  );

  // Actor: target is `td` at the taken action and 0.0 elsewhere.
  // NOTE(review): with a cross-entropy loss this acts as a td-weighted
  // policy-gradient step; with an MSE loss on logits it would also pull the
  // other logits toward 0 — confirm the actor's loss.
  final target = List<double>.filled(nActions, 0.0);
  target[action] = td;
  actor.fit([state], [target]);
}