generateMuZeroShakespeare function

void generateMuZeroShakespeare(
  MuZeroGreedyAgent agent,
  CharTokenizer tokenizer,
  String prompt,
  int length,
)

Implementation

/// Streams [length] characters generated by stepping [agent]'s latent state
/// forward from [prompt], writing the prompt and each sampled character to
/// [stdout].
///
/// The prompt is encoded once via `agent.representation`. After that, each
/// step samples a character from the policy logits of the current latent
/// state and feeds it back through `agent.dynamics` to obtain the next latent
/// state; the prompt is never re-encoded.
///
/// [temperature] and [topP] are forwarded to [sampleNucleus] as `temp` and
/// `topP`; the defaults reproduce the previously hard-coded values.
///
/// Tensors recorded in each phase's tracker list, and the detached latent
/// state, are disposed even if a step throws.
void generateMuZeroShakespeare(
  MuZeroGreedyAgent agent,
  CharTokenizer tokenizer,
  String prompt,
  int length, {
  double temperature = 0.8,
  double topP = 0.9,
}) {
  // Releases every tensor an agent call recorded in [tracker].
  void disposeAll(List<Tensor> tracker) {
    for (final t in tracker) {
      t.dispose();
    }
  }

  final promptTokens = tokenizer.encode(prompt);
  stdout.write(prompt);

  // Encodes the prompt into the initial latent state. The result is detached
  // so it stays usable after the tracker's tensors are disposed.
  // NOTE(review): assumes representation() records its temporaries in the
  // tracker list — confirm against MuZeroGreedyAgent.
  Tensor encodePrompt() {
    final tracker = <Tensor>[];
    try {
      return agent.representation(promptTokens, tracker).detach();
    } finally {
      disposeAll(tracker);
    }
  }

  // Runs one generation step at [position]: samples and prints a character
  // from [state]'s policy, then returns the detached next latent state.
  // [state] itself is not disposed here; the caller owns it.
  Tensor step(Tensor state, int position) {
    final tracker = <Tensor>[];
    try {
      final row = agent.predictPolicy(state, tracker).fetchData();
      final int nextId = sampleNucleus(row, temp: temperature, topP: topP);
      stdout.write(tokenizer.decode([nextId]));
      // Feed the chosen character back into the dynamics head (imagination).
      return agent.dynamics(state, nextId, position, tracker).detach();
    } finally {
      disposeAll(tracker);
    }
  }

  var currentState = encodePrompt();
  try {
    for (var i = 0; i < length; i++) {
      final nextState = step(currentState, promptTokens.length + i);
      // Memory handover: the previous latent state is no longer needed.
      currentState.dispose();
      currentState = nextState;
    }
  } finally {
    // Disposes the final state on success, or the last valid state if a
    // step threw.
    currentState.dispose();
  }

  print("\n[Latent Generation Complete]");
}