updateFromEpisode method
void updateFromEpisode(List<List<double>> states, List<int> actions, List<double> returns, {bool normalize = true})
Performs one policy-gradient update from a single episode, represented as parallel lists of states, actions, and per-step returns. Optionally subtracts a running baseline from the returns and normalizes the resulting advantages to zero mean and unit variance.
Implementation
void updateFromEpisode(
  List<List<double>> states,
  List<int> actions,
  List<double> returns, {
  bool normalize = true,
}) {
  if (states.length != actions.length || actions.length != returns.length) {
    throw ArgumentError('episode lengths must match');
  }

  // Compute per-step advantages, optionally subtracting the running baseline.
  final advs = <double>[
    for (final r in returns) useBaseline ? r - baseline : r,
  ];

  // Optionally normalize advantages to zero mean and unit variance; the
  // 1e-8 term guards against division by zero. `sqrt` requires a library-level
  // `import 'dart:math';`.
  if (normalize && advs.isNotEmpty) {
    final mean = advs.reduce((a, b) => a + b) / advs.length;
    final variance =
        advs.map((a) => (a - mean) * (a - mean)).reduce((x, y) => x + y) /
            advs.length;
    final std = sqrt(variance + 1e-8);
    for (var i = 0; i < advs.length; i++) {
      advs[i] = (advs[i] - mean) / std;
    }
  }

  // Build training targets: each target vector is zero except at the taken
  // action's index, which receives that step's (normalized) advantage.
  final xs = <List<double>>[];
  final ys = <List<double>>[];
  for (var i = 0; i < states.length; i++) {
    final target = List<double>.filled(nActions, 0.0);
    target[actions[i]] = advs[i];
    xs.add(states[i]);
    ys.add(target);
  }
  policy.fit(xs, ys);

  // Track the baseline as an exponential moving average of the episode's
  // mean return.
  if (useBaseline && returns.isNotEmpty) {
    final meanReturn = returns.reduce((a, b) => a + b) / returns.length;
    baseline = baseline * 0.9 + meanReturn * 0.1;
  }
}
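The `returns` argument is expected to already hold per-step returns rather than raw rewards. The following is a minimal sketch of how an episode might be prepared and passed in; `discountedReturns` and `agent` are illustrative names introduced here for the example, not part of the library.

/// Computes discounted returns G_t = r_t + gamma * G_{t+1} from raw
/// per-step rewards. Illustrative helper, not part of the library.
List<double> discountedReturns(List<double> rewards, double gamma) {
  final returns = List<double>.filled(rewards.length, 0.0);
  var running = 0.0;
  for (var t = rewards.length - 1; t >= 0; t--) {
    running = rewards[t] + gamma * running;
    returns[t] = running;
  }
  return returns;
}

void main() {
  // A toy three-step episode in a 2-feature, 2-action environment.
  final states = [
    [0.0, 1.0],
    [1.0, 0.0],
    [0.5, 0.5],
  ];
  final actions = [0, 1, 0];
  final rewards = [1.0, 0.0, 1.0];

  final returns = discountedReturns(rewards, 0.99);

  // `agent` stands in for whatever object exposes updateFromEpisode;
  // its construction is omitted because it depends on the surrounding class.
  // agent.updateFromEpisode(states, actions, returns);
}

Normalization is typically left enabled, since it keeps gradient magnitudes comparable across episodes with very different return scales; passing `normalize: false` preserves the raw advantage scale instead.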