replay method
void replay({int batchSize = 16})
Implementation
/// Runs one training step on a random mini-batch of stored transitions.
///
/// Picks up to [batchSize] distinct transitions by shuffling the stored
/// indices with `_rand` and taking the first ones, builds one regression
/// target per transition, and fits `network` on the resulting batch.
///
/// For each transition the target starts as a copy of `network`'s current
/// prediction for the state; the entry of the taken action is then replaced
/// by `reward + gamma * max(next-state prediction)`. The next-state
/// prediction comes from `_targetNetwork` when one is set, and from `network`
/// otherwise.
///
/// After fitting, the update counter is incremented and, on every
/// `targetUpdateSteps`-th update, `_targetNetwork` copies its parameters
/// from `network`.
///
/// Does nothing when no transitions are stored or when [batchSize] is zero.
/// Throws a [RangeError] if [batchSize] is negative.
void replay({int batchSize = 16}) {
  if (_states.isEmpty) return;

  // Fail before the RNG is advanced; previously a negative size only
  // surfaced as a RangeError from `sublist` after the shuffle had run.
  RangeError.checkNotNegative(batchSize, 'batchSize');

  final sampleCount = min(batchSize, _states.length);
  // An empty batch carries no training signal, so it is not fitted and does
  // not count as an update.
  if (sampleCount == 0) return;

  // Uniform sampling without replacement: shuffle all indices, keep a prefix.
  final indices = List<int>.generate(_states.length, (i) => i)
    ..shuffle(_rand);
  final batchIndices = indices.sublist(0, sampleCount);

  // Copy the field to a local so it can be null-checked without `!`, and
  // resolve the bootstrap network once instead of on every iteration.
  final targetNet = _targetNetwork;
  final bootstrapNet = targetNet ?? network;

  final inputs = <List<double>>[];
  final targets = <List<double>>[];
  for (final i in batchIndices) {
    final state = _states[i];
    final action = _actions[i];
    final reward = _rewards[i];
    final nextState = _nextStates[i];

    final current = network.predict([state])[0];
    final next = bootstrapNet.predict([nextState])[0];

    // Only the taken action's entry changes; the others keep the current
    // prediction so they contribute no error to the fit.
    final target = List<double>.from(current);
    final maxNext = next.reduce((x, y) => x > y ? x : y);
    // NOTE(review): the bootstrap term is always added; no terminal-state
    // flag is visible in this method — confirm episode ends are handled
    // elsewhere.
    target[action] = reward + gamma * maxNext;

    inputs.add(state);
    targets.add(target);
  }

  network.fit(inputs, targets);
  _updates += 1;

  // Periodically sync the target network with the trained network.
  if (targetNet != null && _updates % targetUpdateSteps == 0) {
    targetNet.applyParamsFrom(network);
  }
}