main function

void main()

Implementation

/// Trains a MuZero-style character model on `tiny_shakespeare.txt`, then
/// samples text from it.
///
/// Even epochs train in "policy" mode: each step scores `predictPolicy` on the
/// current state, then re-encodes the true prefix with `representation`
/// (teacher forcing). Odd epochs train in "dynamics" mode: the latent state is
/// rolled forward with `dynamics` and each imagined state is scored with
/// `predictPolicy`. `backward()` is called once per step and `optimizer.step()`
/// once per epoch — this assumes `backward()` accumulates into parameter
/// gradients; confirm against the tensor library.
///
/// If the corpus file is missing, an error is written to stderr and the
/// process exit code is set to 1.
void main() async {
  const corpusPath = 'tiny_shakespeare.txt';
  final file = File(corpusPath);
  if (!await file.exists()) {
    // Fail visibly instead of returning silently with exit status 0.
    stderr.writeln('Training corpus not found: $corpusPath');
    exitCode = 1;
    return;
  }
  final rawText = await file.readAsString();
  final tokenizer = CharTokenizer(rawText);
  final data = tokenizer.encode(rawText);

  const blockSize = 64; // Reduced for VRAM
  const embedSize = 64; // Reduced for VRAM
  const numLayers = 2;
  const numHeads = 4;

  const epochs = 200;
  const learningRate = 0.001;
  const logEvery = 10;
  const sampleLength = 200;

  final gpt = TransformerDecoder(
    vocabSize: tokenizer.vocabSize,
    embedSize: embedSize,
    encoderEmbedSize: embedSize,
    numLayers: numLayers,
    numHeads: numHeads,
    blockSize: blockSize,
  );

  final agent = MuZeroGreedyAgent(gpt, embedSize);
  final optimizer = Adam(gpt.parameters(), lr: learningRate);

  print("🎭 Training MuZero-Shakespeare...");

  for (var epoch = 0; epoch < epochs; epoch++) {
    optimizer.zeroGrad();
    final tracker = <Tensor>[];
    var totalLoss = 0.0;

    // 1. Get Batch
    final (x, y) = getBatch(data, blockSize);

    // The loop stops one short of x.length because the policy branch reads
    // x.sublist(0, t + 2). This is also the divisor for the average loss.
    final steps = x.length - 1;

    // 2. Initial Representation
    Tensor currentState = agent.representation([x[0]], tracker);

    // 3. Unrolled Latent Training
    // We alternate between Policy and Dynamics to keep gradients stable
    final isPolicyStep = epoch.isEven;

    for (var t = 0; t < steps; t++) {
      if (isPolicyStep) {
        final logits = agent.predictPolicy(currentState, tracker);
        final loss = logits.crossEntropy([y[t]]);
        loss.backward();
        totalLoss += loss.fetchData()[0];

        // Teacher forcing for Policy mode: re-encode the true prefix. Skip it
        // on the final step, where the new state would never be read (and it
        // would be the most expensive encode: the full sequence).
        if (t + 1 < steps) {
          currentState = agent.representation(x.sublist(0, t + 2), tracker);
        }
      } else {
        // Imagination mode: Move state forward using g(s, a)
        // NOTE(review): currentState starts as representation([x[0]]), yet
        // the first dynamics call feeds x[0] again at position 1 and is scored
        // against y[0]. If dynamics(s, a, pos) means "append token a at pos",
        // the action/target pairing may be shifted by one — confirm against
        // MuZeroGreedyAgent.dynamics and generateMuZeroShakespeare.
        final Tensor nextState =
            agent.dynamics(currentState, x[t], t + 1, tracker);
        final logits = agent.predictPolicy(nextState, tracker);
        final loss = logits.crossEntropy([y[t]]);
        loss.backward();
        totalLoss += loss.fetchData()[0];

        // Detach (truncated backprop) so the next step's backward() does not
        // flow back through earlier dynamics steps.
        currentState = nextState.detach();
      }
    }

    optimizer.step();
    _safeCleanup(tracker, gpt.parameters());

    // Modes alternate by epoch parity, so logging only every `logEvery`-th
    // epoch would report one mode exclusively. Also log the next epoch so
    // both policy and dynamics losses are visible.
    if (epoch % logEvery < 2) {
      final mode = isPolicyStep ? 'policy' : 'dynamics';
      final avgLoss = steps > 0 ? totalLoss / steps : 0.0;
      print("Epoch $epoch ($mode) | Loss: ${avgLoss.toStringAsFixed(4)}");
    }
  }

  print("\n--- Generating Pure Latent Shakespeare ---");
  generateMuZeroShakespeare(agent, tokenizer, "ROMEO: ", sampleLength);
}