trainMuZero function

void trainMuZero(
  MuZeroGreedyAgent agent,
  List<int> targetSequence,
  int epochs,
  Map<String, int> stoi,
)

Implementation

/// Trains [agent]'s model to predict each next token of [targetSequence].
///
/// A new `Adam` optimizer (learning rate `0.001`) is created over
/// `agent.model.parameters()`. Each epoch zeroes the gradients, walks the
/// sequence once while calling `backward()` on the per-step cross-entropy
/// loss, and then applies a single optimizer step.
///
/// Epochs where `epoch % 3 == 0` run in policy mode: the policy head is
/// evaluated on the current state and the state is then recomputed by
/// [MuZeroGreedyAgent.representation] from the ground-truth prefix. All other
/// epochs run in dynamics mode: the next state is produced by
/// [MuZeroGreedyAgent.dynamics] and handed to the following step detached.
///
/// The epoch counter runs from `0` through [epochs] inclusive. Every 100th
/// epoch the mode and summed loss are printed.
///
/// [stoi] is accepted for interface compatibility but is not read here.
///
/// Throws a [RangeError] if [targetSequence] is empty. Any error raised while
/// an epoch is running is propagated after the tensors tracked for that epoch
/// have been disposed.
void trainMuZero(
  MuZeroGreedyAgent agent,
  List<int> targetSequence,
  int epochs,
  Map<String, int> stoi,
) {
  final optimizer = Adam(agent.model.parameters(), lr: 0.001);

  // NOTE(review): the bound is inclusive, so this performs epochs + 1
  // iterations. Left as-is because it also makes the final epoch eligible for
  // the progress print below; confirm that this is intended.
  for (int epoch = 0; epoch <= epochs; epoch++) {
    double totalEpochLoss = 0.0;
    final List<Tensor> tracker = [];

    // Alternating modes to keep gradients stable on custom CUDA engine.
    // One policy epoch is followed by two dynamics epochs.
    final bool isPolicyMode = epoch % 3 == 0;

    try {
      optimizer.zeroGrad();

      Tensor currentState =
          agent.representation([targetSequence[0]], tracker);

      for (int t = 0; t < targetSequence.length - 1; t++) {
        final int target = targetSequence[t + 1];

        if (isPolicyMode) {
          // Mode A: train representation + policy head.
          final Tensor logits = agent.predictPolicy(currentState, tracker);
          final Tensor pLoss = logits.crossEntropy([target]);
          pLoss.backward();
          totalEpochLoss += pLoss.fetchData()[0];

          // Teacher forcing: rebuild the state from the true prefix, which
          // now includes the token that was just predicted.
          currentState = agent.representation(
            targetSequence.sublist(0, t + 2),
            tracker,
          );
        } else {
          // Mode B: train dynamics (imagination).
          // We use t + 1 as the position index to align with the target
          // word's position.
          final Tensor nextState = agent.dynamics(
            currentState,
            targetSequence[t],
            t + 1,
            tracker,
          );
          final Tensor iLogits = agent.predictPolicy(nextState, tracker);
          final Tensor iLoss = iLogits.crossEntropy([target]);

          iLoss.backward();
          totalEpochLoss += iLoss.fetchData()[0];

          // Recurrent handover (no teacher forcing): the next step continues
          // from the imagined state, detached from this step's graph.
          // NOTE(review): the detached tensor is not added to tracker here;
          // confirm whether detach() returns something that needs disposing.
          currentState = nextState.detach();
        }
      }

      optimizer.step();
    } finally {
      // Release the epoch's tracked tensors even when a step above throws,
      // so a failed epoch does not leak them.
      for (final tensor in tracker) {
        tensor.dispose();
      }
      tracker.clear();
    }

    if (epoch % 100 == 0) {
      String mode = isPolicyMode ? "POLICY  " : "DYNAMICS";
      print(
        "Epoch ${epoch.toString().padLeft(4)} | $mode | Loss: ${totalEpochLoss.toStringAsFixed(6)}",
      );
    }
  }
}