updateBatch method

void updateBatch(
  List<List<double>> states,
  List<int> actions,
  List<double> advantages, {
  int epochs = 4,
  double lr = 0.001,
})

Implementation

/// Runs [epochs] rounds of simplified supervised fitting on the actor and
/// critic for one batch of samples.
///
/// [states], [actions] and [advantages] are parallel lists: entry `i` of each
/// describes the same sample. For each sample, the actor target is a vector of
/// length `nActions` that is `0.0` everywhere except at index `actions[i]`,
/// which holds `advantages[i]`. The actor is fit on `(states, targets)` and
/// the critic is fit toward a constant `[0.0]` for every state.
///
/// An empty batch is a no-op.
///
/// [lr] is accepted for API compatibility but is not forwarded to either
/// `fit` call.
///
/// Throws an [ArgumentError] if the three lists differ in length, and a
/// [RangeError] if any action lies outside `0 .. nActions - 1`.
void updateBatch(
  List<List<double>> states,
  List<int> actions,
  List<double> advantages, {
  int epochs = 4,
  double lr = 0.001,
}) {
  // Misaligned batches would otherwise either crash on an index error or be
  // silently truncated to the shortest list.
  if (actions.length != states.length ||
      advantages.length != states.length) {
    throw ArgumentError(
      'states, actions and advantages must have the same length '
      '(got ${states.length}, ${actions.length}, ${advantages.length}).',
    );
  }
  if (states.isEmpty) return;

  // The targets depend only on the inputs, not on the epoch, so build them
  // once instead of on every pass.
  final xs = List<List<double>>.of(states);
  final ys = <List<double>>[];
  for (var i = 0; i < states.length; i++) {
    final action = actions[i];
    // Report which sample is bad instead of a bare index error.
    RangeError.checkValueInInterval(action, 0, nActions - 1, 'actions[$i]');
    ys.add(List<double>.filled(nActions, 0.0)..[action] = advantages[i]);
  }

  // NOTE(review): placeholder critic target. Every state is regressed toward
  // 0.0, which does not train a meaningful value function; replace with real
  // return/value targets.
  final criticTargets = [for (final _ in states) [0.0]];

  // NOTE(review): `lr` is not passed to `fit`; confirm whether `fit` accepts
  // a learning rate and wire it through.
  for (var epoch = 0; epoch < epochs; epoch++) {
    ac.actor.fit(xs, ys);
    ac.critic.fit(states, criticTargets);
  }
}