replay method

void replay({
  int batchSize = 16,
})

Implementation

/// Trains [network] on one randomly sampled batch of stored transitions.
///
/// Returns without doing anything when no transitions are stored or when
/// [batchSize] is zero. Otherwise at most [batchSize] distinct transitions
/// are drawn (fewer when fewer are stored) by shuffling all stored indices
/// with `_rand` and taking the leading ones.
///
/// For each sampled transition the network's current prediction for the
/// state is copied, and the entry for the taken action is overwritten with
/// `reward + gamma * max(prediction for the next state)`. The next-state
/// prediction comes from `_targetNetwork` when one is set, otherwise from
/// [network] itself. The network is then fitted on the collected
/// state/target pairs, the update counter is incremented, and every
/// `targetUpdateSteps` updates the target network's parameters are
/// refreshed from [network].
///
/// Throws a [RangeError] if [batchSize] is negative.
void replay({int batchSize = 16}) {
  if (_states.isEmpty) return;

  // Fail with a clear message instead of an opaque index error from the
  // sampling step further down.
  RangeError.checkNotNegative(batchSize, 'batchSize');
  // An empty batch has nothing to learn from; skip fitting and do not count
  // it as an update.
  if (batchSize == 0) return;

  final sampleCount = min(batchSize, _states.length);
  final order = List<int>.generate(_states.length, (i) => i)..shuffle(_rand);

  final inputs = <List<double>>[];
  final targets = <List<double>>[];
  for (final i in order.take(sampleCount)) {
    final state = _states[i];
    final action = _actions[i];
    final reward = _rewards[i];
    final nextState = _nextStates[i];

    final current = network.predict([state])[0];
    final next = (_targetNetwork ?? network).predict([nextState])[0];
    final bestNext = next.reduce((x, y) => x > y ? x : y);

    // Only the taken action's entry changes; all others keep the current
    // prediction.
    // NOTE(review): the bootstrap term is applied to every transition; no
    // terminal/done flag is visible here - confirm that episode ends are
    // handled elsewhere.
    final target = List<double>.from(current);
    target[action] = reward + gamma * bestNext;

    inputs.add(state);
    targets.add(target);
  }

  network.fit(inputs, targets);
  _updates += 1;

  final targetNetwork = _targetNetwork;
  if (targetNetwork != null && _updates % targetUpdateSteps == 0) {
    targetNetwork.applyParamsFrom(network);
  }
}