updateFromEpisode method

void updateFromEpisode(
  List<List<double>> states,
  List<int> actions,
  List<double> returns, {
  bool normalize = true,
})

Updates the policy from a single episode, given as parallel lists of states, actions, and returns (entry `i` of each list describes time step `i`). Supports optional baseline subtraction and advantage normalization.

Implementation

/// Performs one policy update from a single recorded episode.
///
/// [states], [actions] and [returns] are parallel lists: entry `i` of each
/// describes time step `i`.
///
/// The advantage for each step is computed as follows:
/// - It starts as that step's return, minus the current `baseline` when
///   `useBaseline` is set.
/// - When [normalize] is true, the advantages are standardized to zero mean
///   and unit (population) variance. A small epsilon keeps the division
///   finite when all advantages are equal.
///
/// Each step becomes one training pair for `policy.fit`:
/// - The input is the state.
/// - The target is a list of length `nActions` that is zero everywhere
///   except at the taken action, which holds the step's advantage.
///
/// When `useBaseline` is set, `baseline` is afterwards moved towards the
/// episode's mean return with an exponential moving average
/// (`baseline * 0.9 + meanReturn * 0.1`).
///
/// An empty episode is a no-op: `policy.fit` is not called and `baseline`
/// is left unchanged.
///
/// Throws an [ArgumentError] if the three lists differ in length.
/// Throws a [RangeError] if any action is outside `0..nActions - 1`.
void updateFromEpisode(
  List<List<double>> states,
  List<int> actions,
  List<double> returns, {
  bool normalize = true,
}) {
  if (states.length != actions.length || actions.length != returns.length) {
    throw ArgumentError('episode lengths must match');
  }
  // Nothing to learn from; avoid handing the policy an empty batch.
  if (states.isEmpty) return;

  // Report a bad action with its position before doing any work, instead of
  // failing with an anonymous index error while building the targets.
  final maxAction = nActions - 1;
  for (var i = 0; i < actions.length; i++) {
    RangeError.checkValueInInterval(actions[i], 0, maxAction, 'actions[$i]');
  }

  final advs = <double>[
    for (final r in returns) useBaseline ? r - baseline : r,
  ];

  if (normalize) {
    final n = advs.length; // > 0: empty episodes returned early above.
    final mean = advs.reduce((a, b) => a + b) / n;
    final variance =
        advs.map((a) => (a - mean) * (a - mean)).reduce((x, y) => x + y) / n;
    // The epsilon keeps std > 0 when every advantage is identical
    // (e.g. a one-step episode), so the division below stays finite.
    final std = sqrt(variance + 1e-8);
    for (var i = 0; i < n; i++) {
      advs[i] = (advs[i] - mean) / std;
    }
  }

  // One-hot-style targets: only the taken action carries the advantage.
  final ys = <List<double>>[
    for (var i = 0; i < states.length; i++)
      List<double>.filled(nActions, 0.0)..[actions[i]] = advs[i],
  ];
  // Pass a fresh outer list (as before) so `fit` never aliases the caller's.
  policy.fit(List<List<double>>.of(states), ys);

  if (useBaseline) {
    // Update the running baseline (simple exponential moving average).
    final meanReturn = returns.reduce((a, b) => a + b) / returns.length;
    baseline = baseline * 0.9 + meanReturn * 0.1;
  }
}