trainMuZero function

void trainMuZero(
  MuZeroGreedyAgent agent,
  List<int> targetSequence,
  int epochs,
)

Implementation

/// Trains [agent]'s policy to predict each next token of [targetSequence]
/// (teacher forcing) for [epochs] epochs.
///
/// Each epoch zeroes gradients, then for every position `t` from `0` to
/// `targetSequence.length - 2`:
///  1. encodes the prefix `targetSequence.sublist(0, t + 1)` with
///     `agent.representation`, building a fresh graph per step (no
///     backpropagation through time),
///  2. computes a cross-entropy loss of `agent.predictPolicy` against the
///     next token `targetSequence[t + 1]` and calls `backward()` on it, so
///     gradients accumulate across the steps of the epoch,
///  3. calls `agent.dynamics` on the same state and target token,
///  4. disposes every tensor registered in the step's tracker, even if an
///     exception was thrown.
/// After all steps, a single `optimizer.step()` applies the accumulated
/// update.
///
/// A sequence with fewer than two tokens has no prediction targets. Each
/// epoch then runs only the optimizer step, with zeroed gradients.
///
/// [learningRate] is forwarded to [Adam] and defaults to `0.001`. Every
/// [logEvery] epochs (starting at epoch 0), the summed per-step loss of the
/// epoch and up to three `lmHead` weights are printed. A non-positive
/// [logEvery] disables logging.
void trainMuZero(
  MuZeroGreedyAgent agent,
  List<int> targetSequence,
  int epochs, {
  double learningRate = 0.001,
  int logEvery = 50,
}) {
  final optimizer = Adam(agent.model.parameters(), lr: learningRate);

  // Tensors created while building one step's graph are registered here so
  // they can be released explicitly as soon as that step is finished.
  final List<Tensor> tracker = [];

  // Disposes and forgets every tensor registered for the current step.
  void releaseTracked() {
    for (final tensor in tracker) {
      tensor.dispose();
    }
    tracker.clear();
  }

  for (var epoch = 0; epoch < epochs; epoch++) {
    optimizer.zeroGrad();
    double totalEpochLoss = 0.0;

    for (var t = 0; t < targetSequence.length - 1; t++) {
      final int target = targetSequence[t + 1];
      try {
        // Hard graph cut: re-encode the whole prefix every step instead of
        // carrying a state tensor over from the previous step's graph.
        final Tensor state = agent.representation(
          targetSequence.sublist(0, t + 1),
          tracker,
        );

        final Tensor logits = agent.predictPolicy(state, tracker);
        final Tensor loss = logits.crossEntropy([target]);

        // Backprop while this step's graph is still alive; gradients keep
        // accumulating until optimizer.step() below.
        loss.backward();
        totalEpochLoss += loss.fetchData()[0];
        // NOTE(review): `loss` is created without the tracker — confirm
        // whether crossEntropy registers it elsewhere; if not, it is never
        // disposed.

        // NOTE(review): the dynamics output is used neither in the loss nor
        // as the next state, so the dynamics network gets no gradient here.
        // The call is kept to preserve existing behavior; confirm whether it
        // should feed the next step or be removed.
        agent.dynamics(state, target, t + 1, tracker);
      } finally {
        // Release the step graph even if the forward/backward pass threw.
        releaseTracked();
      }
    }

    // Apply the gradients accumulated over all steps of this epoch.
    optimizer.step();

    if (logEvery > 0 && epoch % logEvery == 0) {
      // take(3) instead of sublist(0, 3): tolerates weight tensors with
      // fewer than three elements.
      final wSample =
          agent.model.lmHead.parameters()[0].fetchData().take(3).toList();

      print(
        "Epoch ${epoch.toString().padLeft(3)} | "
        "Loss: ${totalEpochLoss.toStringAsFixed(6)} | "
        "Weights: $wSample",
      );
    }
  }
}