trainMuZero function

void trainMuZero(
  MuZeroGreedyAgent agent,
  List<int> targetSequence,
  int epochs,
)

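Trains the agent with unrolled dynamics to prevent "latent collapse": policy epochs use teacher forcing, while dynamics epochs feed each imagined state into the next step.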

Implementation

/// Trains [agent] on [targetSequence] with alternating policy and dynamics
/// epochs.
///
/// Every third epoch (`epoch % 3 == 0`) is a policy epoch: after each step
/// the state is rebuilt by `agent.representation` from the ground-truth
/// prefix, and `agent.predictPolicy` is scored against the next token. All
/// other epochs are dynamics epochs: `agent.dynamics` advances the state, the
/// result is scored through `agent.predictPolicy`, and the detached imagined
/// state (not the ground truth) is carried into the next step.
///
/// `backward` is called once per step and the optimizer is stepped once per
/// epoch. A progress line is printed on every 50th epoch.
///
/// The epoch counter runs from 0 through [epochs] inclusive. [targetSequence]
/// must be non-empty: its first element is read unconditionally, so an empty
/// list throws a [RangeError].
///
/// Tensors registered in the per-epoch tracker are disposed even when a step
/// throws, so a failing epoch does not leak them.
void trainMuZero(
  MuZeroGreedyAgent agent,
  List<int> targetSequence,
  int epochs,
) {
  final optimizer = Adam(agent.model.parameters(), lr: 0.001);

  // NOTE(review): the bound is inclusive, so this makes epochs + 1 passes —
  // confirm that is intended before tightening it.
  for (int epoch = 0; epoch <= epochs; epoch++) {
    optimizer.zeroGrad();
    double totalEpochLoss = 0.0;
    final List<Tensor> tracker = [];

    // Mode switching: one policy epoch followed by two dynamics epochs.
    final bool isPolicyMode = epoch % 3 == 0;

    try {
      // Initial state is built from the first token only.
      Tensor currentState = agent.representation([targetSequence[0]], tracker);

      for (int t = 0; t < targetSequence.length - 1; t++) {
        final int target = targetSequence[t + 1];

        if (isPolicyMode) {
          // Mode A: score the policy head on the current state.
          final Tensor logits = agent.predictPolicy(currentState, tracker);
          final Tensor pLoss = logits.crossEntropy([target]);
          pLoss.backward();
          totalEpochLoss += pLoss.fetchData()[0];

          // Teacher forcing: rebuild the state from the ground-truth prefix
          // that now includes the token just predicted.
          currentState = agent.representation(
            targetSequence.sublist(0, t + 2),
            tracker,
          );
        } else {
          // Mode B: unrolled dynamics. The state produced by dynamics() is
          // scored by the policy head against the next token.
          final Tensor nextState = agent.dynamics(
            currentState,
            targetSequence[t],
            t + 1,
            tracker,
          );
          final Tensor iLogits = agent.predictPolicy(nextState, tracker);
          final Tensor iLoss = iLogits.crossEntropy([target]);

          iLoss.backward();
          totalEpochLoss += iLoss.fetchData()[0];

          // No teacher forcing: the imagined state feeds the next step. It is
          // detached first; per the original author this keeps the gradient
          // chain from growing with every unrolled step.
          currentState = nextState.detach();
        }
      }

      optimizer.step();
    } finally {
      // Release this epoch's tensors on both the success and failure paths.
      for (final tensor in tracker) {
        tensor.dispose();
      }
      tracker.clear();
    }

    if (epoch % 50 == 0) {
      final String mode = isPolicyMode ? "POLICY  " : "DYNAMICS";
      print(
        "Epoch ${epoch.toString().padLeft(3)} | $mode | Loss: ${totalEpochLoss.toStringAsFixed(6)}",
      );
    }
  }
}