trainMuZero function
Trains the agent with unrolled dynamics to prevent "latent collapse".
Implementation
/// Trains [agent] on [targetSequence], alternating between two loss modes on
/// a per-epoch basis.
///
/// * **Policy epochs** (`epoch % 3 == 0`): at every step the state is rebuilt
///   by `agent.representation` from the ground-truth prefix (teacher forcing)
///   and `agent.predictPolicy` is scored against the next token.
/// * **Dynamics epochs** (all other epochs): the state is advanced only through
///   `agent.dynamics`, and the policy prediction made on that imagined state
///   is scored against the next token. No teacher forcing is applied.
///
/// Within an epoch `backward()` is called once per step, and a single Adam
/// step (learning rate `0.001`) is applied at the end of the epoch. The summed
/// loss is printed every 50th epoch.
///
/// The epoch loop is inclusive, so `epochs + 1` passes are executed
/// (`epoch` runs from `0` through [epochs]). A negative [epochs] runs no
/// passes.
///
/// Every tensor registered in the per-epoch tracker is disposed when the epoch
/// finishes, including when a step throws.
///
/// Throws a [RangeError] if [targetSequence] is empty and at least one epoch
/// runs, because the first token is read unconditionally.
void trainMuZero(
  MuZeroGreedyAgent agent,
  List<int> targetSequence,
  int epochs,
) {
  final optimizer = Adam(agent.model.parameters(), lr: 0.001);

  // Inclusive bound: the final pass is `epoch == epochs`.
  for (int epoch = 0; epoch <= epochs; epoch++) {
    optimizer.zeroGrad();
    double totalEpochLoss = 0.0;

    // Collects the tensors created by the agent calls during this epoch so
    // they can be released together.
    final List<Tensor> tracker = [];

    // Mode switching: one policy epoch followed by two dynamics epochs.
    final bool isPolicyMode = epoch % 3 == 0;

    try {
      // Initial state is built from the first token only.
      Tensor currentState = agent.representation([targetSequence[0]], tracker);

      for (int t = 0; t < targetSequence.length - 1; t++) {
        final int target = targetSequence[t + 1];

        if (isPolicyMode) {
          // Mode A: policy head on top of the representation of the true
          // prefix.
          final Tensor logits = agent.predictPolicy(currentState, tracker);
          final Tensor pLoss = logits.crossEntropy([target]);
          pLoss.backward();
          totalEpochLoss += pLoss.fetchData()[0];

          // Teacher forcing: rebuild the state from the ground-truth prefix
          // that now includes the token just predicted.
          currentState = agent.representation(
            targetSequence.sublist(0, t + 2),
            tracker,
          );
        } else {
          // Mode B: the dynamics function must produce a state from which
          // the policy head predicts the next token.
          final Tensor nextState = agent.dynamics(
            currentState,
            targetSequence[t],
            t + 1,
            tracker,
          );
          final Tensor iLogits = agent.predictPolicy(nextState, tracker);
          final Tensor iLoss = iLogits.crossEntropy([target]);
          iLoss.backward();
          totalEpochLoss += iLoss.fetchData()[0];

          // No teacher forcing: carry the imagined state into the next step.
          // It is detached so the gradient chain does not keep growing with
          // the sequence length.
          // NOTE(review): the detached tensor is not added to `tracker`;
          // confirm whether `detach()` returns something that needs its own
          // `dispose()`.
          currentState = nextState.detach();
        }
      }

      optimizer.step();
    } finally {
      // Release tracked tensors even if a step above threw, so a failed epoch
      // does not leak them.
      for (final tensor in tracker) {
        tensor.dispose();
      }
      tracker.clear();
    }

    if (epoch % 50 == 0) {
      final mode = isPolicyMode ? "POLICY " : "DYNAMICS";
      print(
        "Epoch ${epoch.toString().padLeft(3)} | $mode | Loss: ${totalEpochLoss.toStringAsFixed(6)}",
      );
    }
  }
}