main function
void
main()
Implementation
/// Trains a small [TransformerDecoder] on the single sequence
/// `<start> hello world` -> `hello world .` and then greedily decodes from
/// `<start>`.
///
/// Training stops as soon as the loss becomes NaN. Every tensor produced in an
/// epoch (those collected in the tracker, plus the loss and the logits) is
/// disposed before the loop moves on or exits, including on the NaN exit.
void main() {
  print("--- Stable Tensor-Engine AFT-GPT Training ---");

  // 1. Hyperparameters
  const vocabSize = 5;
  const embedSize = 32;
  const blockSize = 16;
  const numLayers = 2;
  const learningRate = 0.001;
  // Inclusive upper bound, so that the final epoch (1000) is also logged.
  const lastEpoch = 1000;
  const logEvery = 100;
  const maxNewTokens = 5;

  final stoi = <String, int>{
    "hello": 0,
    "world": 1,
    ".": 2,
    "<start>": 3,
    "<pad>": 4,
  };
  final itos = stoi.map((token, id) => MapEntry(id, token));

  // Data: "hello world ."
  // Input:  <start> hello world
  // Target: hello world .
  final inputIds = <int>[3, 0, 1];
  final targetIds = <int>[0, 1, 2];

  final model = TransformerDecoder(
    vocabSize: vocabSize,
    embedSize: embedSize,
    encoderEmbedSize: embedSize,
    blockSize: blockSize,
    numLayers: numLayers,
    numHeads: 4,
  );
  final optimizer = Adam(model.parameters(), lr: learningRate);

  // Intermediate tensors that `model.forward` registers for later disposal.
  final tracker = <Tensor>[];

  // Dummy encoder context (e.g. a single 'thought' vector or a null state).
  final dummyEnc = Tensor.zeros([1, embedSize]);

  // Disposes every tensor collected during the last forward pass and empties
  // the tracker so the next pass starts clean.
  void releaseTracked() {
    for (final t in tracker) {
      t.dispose();
    }
    tracker.clear();
  }

  print('Starting training...');
  for (var epoch = 0; epoch <= lastEpoch; epoch++) {
    optimizer.zeroGrad();

    // 2. Forward pass. Per the original notes the logits are [T, vocabSize].
    final logits = model.forward(inputIds, dummyEnc, tracker);

    // 3. Loss calculation
    // Assumes crossEntropy handles [T, V] logits vs [T] targets.
    final loss = logits.crossEntropy(targetIds);
    final double lossVal = loss.fetchData()[0];

    // Check for divergence on every epoch, not only on logging epochs, so a
    // NaN loss is never used to update the parameters.
    final diverged = lossVal.isNaN;

    if (!diverged) {
      // 4. Backprop
      loss.backward();
      // Optional: gradient clipping for safety
      // for (var p in model.parameters()) {
      //   p.grad?.clamp(-1.0, 1.0);
      // }
      optimizer.step();
    }

    if (diverged || epoch % logEvery == 0) {
      print(
        "Epoch ${epoch.toString().padLeft(4)} | Loss: ${lossVal.toStringAsFixed(10)}",
      );
    }

    // 5. Explicit memory cleanup. This runs before the divergence exit below,
    // so stopping early no longer leaks this epoch's tensors.
    releaseTracked();
    loss.dispose();
    logits.dispose();

    if (diverged) break;
  }

  // 6. Inference / generation
  print("\n--- Inference ---");
  final startId = stoi["<start>"]!;
  final stopId = stoi["."];
  final currentSeq = <int>[startId];

  for (var step = 0; step < maxNewTokens; step++) {
    // Generate logits for the whole sequence so far.
    final out = model.forward(currentSeq, dummyEnc, tracker);

    // Greedy pick over the row belonging to the most recent position; the
    // flat buffer is indexed as position * vocabSize + tokenId.
    final data = out.fetchData();
    final rowStart = (currentSeq.length - 1) * vocabSize;
    var nextId = 0;
    var best = -double.infinity;
    for (var v = 0; v < vocabSize; v++) {
      final score = data[rowStart + v];
      if (score > best) {
        best = score;
        nextId = v;
      }
    }
    currentSeq.add(nextId);

    // Cleanup inference tensors
    releaseTracked();
    out.dispose();

    if (nextId == stopId) break;
  }

  print("Generated: ${currentSeq.map((id) => itos[id] ?? "??").join(" ")}");
}