generateMuZeroPure function

void generateMuZeroPure(
  MuZeroGreedyAgent agent,
  List<int> prompt,
  int maxLength,
  Map<int, String> itos,
  Map<String, int> stoi,
)

Implementation

/// Greedily generates up to [maxLength] tokens after [prompt] and prints them.
///
/// The prompt is encoded once with `agent.representation`; every later step
/// uses only `agent.predictPolicy` and `agent.dynamics` on the current state.
/// Each chosen token is printed as it is produced, and the whole sequence
/// (prompt included) is printed at the end, mapped through [itos].
///
/// Generation stops early once the chosen token equals `stoi["."]`. If [stoi]
/// has no `"."` entry there is no early stop and the loop simply runs for
/// [maxLength] steps.
///
/// Every tensor registered in a tracker list, and the current state tensor,
/// is disposed on all exit paths, including the early stop and exceptions.
void generateMuZeroPure(
  MuZeroGreedyAgent agent,
  List<int> prompt,
  int maxLength,
  Map<int, String> itos,
  Map<String, int> stoi,
) {
  // Amount subtracted from the score of the most recently emitted token.
  const repetitionPenalty = 2.0;

  final generated = List<int>.of(prompt);

  // Looked up once; null means the vocabulary has no terminator token.
  final endToken = stoi["."];

  // Encode the prompt. The detached result is kept and everything the call
  // registered in the tracker is released, even if encoding throws.
  // NOTE(review): assumes detach() yields a tensor that stays valid after the
  // tracked tensors are disposed, as the original code relied on - confirm.
  final initTracker = <Tensor>[];
  Tensor currentState;
  try {
    currentState = agent.representation(prompt, initTracker).detach();
  } finally {
    for (final tensor in initTracker) {
      tensor.dispose();
    }
  }

  try {
    for (var i = 0; i < maxLength; i++) {
      final stepTracker = <Tensor>[];
      try {
        // 1. Predict scores for the next token from the current state.
        final logits = agent.predictPolicy(currentState, stepTracker);
        final data = logits.fetchData();

        // 2. Repetition penalty: discourage emitting the previous token again
        // immediately.
        if (generated.isNotEmpty) {
          data[generated.last] -= repetitionPenalty;
        }

        final bestToken = argMax(data);
        generated.add(bestToken);
        print("Step ${i.toString().padLeft(2)} -> ${itos[bestToken]}");

        // The finally block below still releases this step's tensors.
        if (bestToken == endToken) break;

        // 3. Dynamics: advance the state using the chosen token. The position
        // passed is i + 1, one past the current step index.
        final nextState = agent
            .dynamics(currentState, bestToken, i + 1, stepTracker)
            .detach();

        currentState.dispose();
        currentState = nextState;
      } finally {
        for (final tensor in stepTracker) {
          tensor.dispose();
        }
      }
    }
  } finally {
    currentState.dispose();
  }

  print("\nFinal Result: ${generated.map((id) => itos[id]).join(" ")}");
}