main function
void main()
Implementation
/// Trains a small MuZero-style character model on `tiny_shakespeare.txt`
/// and then samples text from it.
///
/// If the data file does not exist in the working directory, an error is
/// written to stderr, the process exit code is set to 1, and nothing is
/// trained.
///
/// Training alternates between two modes, one per epoch:
/// * even epochs (policy): after each prediction the state is rebuilt from
///   the true prefix with `agent.representation` (teacher forcing);
/// * odd epochs (dynamics / "imagination"): the state is advanced with
///   `agent.dynamics` and the policy head is trained on the imagined state.
void main() async {
  const dataPath = 'tiny_shakespeare.txt';
  const blockSize = 64; // Reduced for VRAM
  const embedSize = 64; // Reduced for VRAM
  const numLayers = 2;
  const numHeads = 4;
  const epochs = 200;
  const learningRate = 0.001;
  const logEvery = 10;

  final file = File(dataPath);
  if (!await file.exists()) {
    // Returning silently here would look like a successful (empty) run.
    stderr.writeln('Training data not found: ${file.path}');
    exitCode = 1;
    return;
  }
  final rawText = await file.readAsString();
  final tokenizer = CharTokenizer(rawText);
  final data = tokenizer.encode(rawText);

  final gpt = TransformerDecoder(
    vocabSize: tokenizer.vocabSize,
    embedSize: embedSize,
    encoderEmbedSize: embedSize,
    numLayers: numLayers,
    numHeads: numHeads,
    blockSize: blockSize,
  );
  final agent = MuZeroGreedyAgent(gpt, embedSize);
  final optimizer = Adam(gpt.parameters(), lr: learningRate);

  print("🎠Training MuZero-Shakespeare...");
  for (var epoch = 0; epoch < epochs; epoch++) {
    optimizer.zeroGrad();
    // Intermediate tensors created this epoch; handed to _safeCleanup after
    // the optimizer step.
    final tracker = <Tensor>[];
    double totalLoss = 0;

    // 1. Get Batch
    final (x, y) = getBatch(data, blockSize);

    // 2. Initial Representation
    Tensor currentState = agent.representation([x[0]], tracker);

    // 3. Unrolled Latent Training
    // We alternate between Policy and Dynamics to keep gradients stable.
    final isPolicyStep = epoch.isEven;
    // One loss term is accumulated per step. The bound stops one short of
    // x.length because policy mode reads x.sublist(0, t + 2).
    final steps = x.length - 1;
    for (var t = 0; t < steps; t++) {
      if (isPolicyStep) {
        final logits = agent.predictPolicy(currentState, tracker);
        final loss = logits.crossEntropy([y[t]]);
        loss.backward();
        totalLoss += loss.fetchData()[0];
        // Teacher forcing for Policy mode: rebuild the state from the true
        // prefix. Skipped on the last step, where the new state (the most
        // expensive full-prefix pass) would never be read.
        if (t + 1 < steps) {
          currentState = agent.representation(x.sublist(0, t + 2), tracker);
        }
      } else {
        // Imagination mode: Move state forward using g(s, a).
        // NOTE(review): the action passed is x[t] at position t + 1, while
        // the initial state already encodes x[0]. Confirm against
        // `dynamics` that this should not be x[t + 1] predicting y[t + 1].
        final nextState = agent.dynamics(currentState, x[t], t + 1, tracker);
        final logits = agent.predictPolicy(nextState, tracker);
        final loss = logits.crossEntropy([y[t]]);
        loss.backward();
        totalLoss += loss.fetchData()[0];
        // Presumably detach() cuts the autograd graph, so each dynamics
        // step is trained on its own; confirm against the Tensor API.
        currentState = nextState.detach();
      }
    }
    optimizer.step();
    _safeCleanup(tracker, gpt.parameters());

    // Log one epoch of each mode (epochs 0 and 1, 10 and 11, ...). Logging
    // only at epoch % 10 == 0 would never report the odd-epoch dynamics loss.
    if (epoch % logEvery < 2) {
      final mode = isPolicyStep ? 'policy' : 'dynamics';
      // Average over the loss terms actually accumulated, not blockSize.
      final meanLoss = steps > 0 ? totalLoss / steps : 0.0;
      print("Epoch $epoch ($mode) | Loss: ${meanLoss.toStringAsFixed(4)}");
    }
  }
  print("\n--- Generating Pure Latent Shakespeare ---");
  generateMuZeroShakespeare(agent, tokenizer, "ROMEO: ", 200);
}