trainMuZero function
Implementation
/// Trains [agent]'s policy head to predict each next token of
/// [targetSequence] using teacher forcing.
///
/// For every step `t`, the observed prefix `targetSequence[0..t]` is encoded
/// with `agent.representation`, and the output of `agent.predictPolicy` is
/// scored with cross-entropy against `targetSequence[t + 1]`. Gradients from
/// all steps of an epoch are accumulated and applied with a single
/// `optimizer.step()`.
///
/// There is no backpropagation through time. Each step's graph is disposed
/// right after `backward()`. The next state is re-encoded from the observed
/// prefix rather than rolled forward with `agent.dynamics`. As a result, the
/// dynamics network receives no gradient from this loop.
///
/// Sequences with fewer than two tokens have no prediction targets. Epochs
/// still run (and log a loss of 0) but accumulate no gradients.
///
/// - [epochs]: number of passes over [targetSequence].
/// - [learningRate]: Adam learning rate (defaults to `0.001`).
/// - [logEvery]: prints a progress line every this many epochs, starting at
///   epoch 0. Values `<= 0` disable logging.
void trainMuZero(
  MuZeroGreedyAgent agent,
  List<int> targetSequence,
  int epochs, {
  double learningRate = 0.001,
  int logEvery = 50,
}) {
  final optimizer = Adam(agent.model.parameters(), lr: learningRate);

  // Tensors registered by the agent while building the current step's graph.
  final tracker = <Tensor>[];

  // Disposes every tracked tensor and empties the tracker.
  void disposeTracked() {
    for (final tensor in tracker) {
      tensor.dispose();
    }
    tracker.clear();
  }

  for (var epoch = 0; epoch < epochs; epoch++) {
    optimizer.zeroGrad();
    var totalEpochLoss = 0.0;

    for (var step = 0; step < targetSequence.length - 1; step++) {
      try {
        // Encode the observed prefix [0..step]. This state is not detached,
        // so the policy loss also backpropagates into the representation.
        final state = agent.representation(
          targetSequence.sublist(0, step + 1),
          tracker,
        );
        final logits = agent.predictPolicy(state, tracker);
        final loss = logits.crossEntropy([targetSequence[step + 1]]);
        // Backprop while the step graph still exists. Gradients accumulate
        // until optimizer.step() below.
        loss.backward();
        totalEpochLoss += loss.fetchData()[0];
        // NOTE(review): `loss` is not registered in `tracker`; confirm whether
        // the result of crossEntropy needs an explicit dispose().
      } finally {
        // Hard graph cut: free the whole step graph, even if the step threw.
        disposeTracked();
      }
    }

    // Apply the gradients accumulated over all steps of this epoch.
    optimizer.step();

    if (logEvery > 0 && epoch % logEvery == 0) {
      // take(3) tolerates weight tensors with fewer than three elements.
      final wSample =
          agent.model.lmHead.parameters()[0].fetchData().take(3).toList();
      print(
        "Epoch ${epoch.toString().padLeft(3)} | "
        "Loss: ${totalEpochLoss.toStringAsFixed(6)} | "
        "Weights: $wSample",
      );
    }
  }
}