generateMuZeroPure function
void
generateMuZeroPure()
Implementation
/// Generates up to [maxLength] tokens after [prompt] and prints the result.
///
/// The prompt is encoded once with `agent.representation`. Every later step
/// reads policy logits from the current latent state via
/// `agent.predictPolicy`, picks a token with `argMax`, and advances the state
/// with `agent.dynamics` instead of calling `representation` again.
///
/// Generation stops early once the token mapped to `"."` in [stoi] is chosen.
/// If [stoi] has no `"."` entry there is no early stop and the loop runs for
/// [maxLength] steps. Each step and the final space-joined sequence (looked
/// up through [itos]) are written with [print]; nothing is returned.
///
/// Every tensor registered in a tracker list, and the latent state itself, is
/// disposed before this function exits, including when it stops early or a
/// step throws.
void generateMuZeroPure(
  MuZeroGreedyAgent agent,
  List<int> prompt,
  int maxLength,
  Map<int, String> itos,
  Map<String, int> stoi,
) {
  final generated = List<int>.of(prompt);

  // Encode the prompt once. Only the detached state is kept; everything the
  // tracker collected is released even if encoding fails.
  final initTracker = <Tensor>[];
  Tensor currentState;
  try {
    final Tensor rawInitState = agent.representation(prompt, initTracker);
    currentState = rawInitState.detach();
  } finally {
    for (final t in initTracker) {
      t.dispose();
    }
  }

  // Looked up once: the stop token does not change during generation. A null
  // value never equals a predicted token, so a missing entry disables the
  // early stop instead of throwing.
  final stopToken = stoi['.'];

  try {
    for (var i = 0; i < maxLength; i++) {
      final stepTracker = <Tensor>[];
      try {
        // 1. Predict.
        final Tensor logits = agent.predictPolicy(currentState, stepTracker);
        // NOTE(review): assumes fetchData() returns a mutable list indexed by
        // token id — confirm against Tensor.
        final List<double> data = logits.fetchData();

        // 2. Repetition penalty: discourage re-emitting the most recent token
        // (on the first step this is the last prompt token).
        if (generated.isNotEmpty) {
          data[generated.last] -= 2.0;
        }

        final int bestToken = argMax(data);
        generated.add(bestToken);
        print("Step ${i.toString().padLeft(2)} -> ${itos[bestToken]}");

        // Leaving the loop here still runs the `finally` below, so this
        // step's tensors are released on the early-stop path too.
        if (bestToken == stopToken) break;

        // 3. Dynamics: advance the latent state with the chosen token. The
        // step index passed is i + 1 so the position moves forward.
        final Tensor nextStateRaw = agent.dynamics(
          currentState,
          bestToken,
          i + 1,
          stepTracker,
        );
        final Tensor nextState = nextStateRaw.detach();
        currentState.dispose();
        currentState = nextState;
      } finally {
        for (final t in stepTracker) {
          t.dispose();
        }
      }
    }
  } finally {
    currentState.dispose();
  }

  print("\nFinal Result: ${generated.map((id) => itos[id]).join(" ")}");
}