generateMuZeroGreedy function

void generateMuZeroGreedy(
  MuZeroGreedyAgent agent,
  List<int> prompt,
  int maxLength,
  Map<int, String> itos,
)

Implementation

/// Greedily extends [prompt] one token at a time using [agent] and prints
/// the progress and the final sequence.
///
/// Each step re-encodes the entire sequence generated so far with
/// `agent.representation`, scores it with `agent.predictPolicy`, and appends
/// the arg-max token. Generation stops after [maxLength] new tokens, or as
/// soon as [stopToken] is produced. [stopToken] defaults to `2`, which is
/// presumably the "." token; confirm this against the vocabulary.
///
/// [itos] maps token ids to printable strings. Ids missing from it are
/// printed as `<unk:ID>` instead of `null`.
///
/// The caller's [prompt] list is copied and never modified.
void generateMuZeroGreedy(
  MuZeroGreedyAgent agent,
  List<int> prompt,
  int maxLength,
  Map<int, String> itos, {
  int stopToken = 2,
}) {
  String label(int id) => itos[id] ?? '<unk:$id>';

  // Copy so the caller's prompt is left untouched.
  final generated = List<int>.of(prompt);

  for (var i = 0; i < maxLength; i++) {
    final bestToken = _greedyNextToken(agent, generated);

    generated.add(bestToken);
    print("Step $i -> ${label(bestToken)}");

    if (bestToken == stopToken) break;
  }

  print("\nFinal Result: ${generated.map(label).join(" ")}");
}

/// Runs one forward pass of [agent] over [tokens] and returns the arg-max
/// token id of the predicted policy.
///
/// A fresh tracker is used for every call, so no tensors survive between
/// steps. Every tensor in it is disposed in a `finally`, which releases them
/// even if the agent, `fetchData`, or `argMax` throws.
/// NOTE(review): this assumes the agent registers the tensors it creates in
/// `tracker`; confirm against `MuZeroGreedyAgent`.
int _greedyNextToken(MuZeroGreedyAgent agent, List<int> tokens) {
  final tracker = <Tensor>[];
  try {
    final Tensor state = agent.representation(tokens, tracker);
    final Tensor logits = agent.predictPolicy(state, tracker);
    // Read the data out before the finally block disposes the tensors.
    return argMax(logits.fetchData());
  } finally {
    for (final t in tracker) {
      t.dispose();
    }
  }
}