// main function: `void main()` — implementation follows.
/// GPT (decoder-only Transformer) training demo.
///
/// Builds a tiny word-level vocabulary, constructs next-token-prediction
/// training pairs, trains a [TransformerDecoder] with SGD on a
/// padding-masked cross-entropy loss, then greedily decodes a sequence
/// from the trained model.
void main() {
  print("--- Generative Pretrained Transformer (GPT) Training Example ---");

  // 1. Define GPT Model Hyperparameters
  const int vocabSize = 40; // Increased vocabulary size
  const int embedSize = 32;
  const int blockSize = 15; // Increased block size for longer sequences
  const int numLayers = 3;
  const int numHeads = 4;

  print("GPT Model Configuration:");
  print(" Vocabulary Size: $vocabSize");
  print(" Embedding Size: $embedSize");
  print(" Block Size (Max Context Length): $blockSize");
  print(" Number of Layers: $numLayers");
  print(" Number of Heads: $numHeads");

  // 2. Expanded vocabulary for demonstration (word -> token id).
  final Map<String, int> stoi = {
    "hello": 0,
    "world": 1,
    "this": 2,
    "is": 3,
    "a": 4,
    "test": 5,
    "generation": 6,
    "model": 7,
    "the": 8,
    "quick": 9,
    "brown": 10,
    "fox": 11,
    "jumps": 12,
    "over": 13,
    "lazy": 14,
    "dog": 15,
    ".": 16,
    "<start>": 17,
    "<pad>": 18,
    "dart": 19,
    "programming": 20,
    "language": 21,
    "example": 22,
    "code": 23,
    "learning": 24,
    "machine": 25,
    "deep": 26,
    "neural": 27,
    "networks": 28,
    "great": 29,
    "simple": 30,
    "powerful": 31,
    "today": 32,
    "future": 33,
    "data": 34,
    "science": 35,
    "artificial": 36,
    "intelligence": 37,
    "next": 38,
    "token": 39,
  };
  // Inverse mapping (token id -> word), used to decode model output.
  final Map<int, String> itos = stoi.map((key, value) => MapEntry(value, key));

  // Sanity check (debug builds only — asserts are stripped in release mode).
  assert(stoi.length <= vocabSize,
      "vocabSize is too small for the defined vocabulary.");

  // Special token ids used for sequence framing and loss masking.
  final int startTokenId = stoi["<start>"]!;
  final int padTokenId = stoi["<pad>"]!;
  final int endTokenId = stoi["."]!; // Using '.' as an example end token

  print("\nExample Vocabulary:");
  print(itos);

  // 3. Create a dummy dataset with more varied sequences.
  final List<List<int>> rawSequences = [
    [startTokenId, stoi["hello"]!, stoi["world"]!, stoi["."]!],
    [
      startTokenId,
      stoi["this"]!,
      stoi["is"]!,
      stoi["a"]!,
      stoi["test"]!,
      stoi["."]!
    ],
    [
      startTokenId,
      stoi["the"]!,
      stoi["quick"]!,
      stoi["brown"]!,
      stoi["fox"]!,
      stoi["jumps"]!,
      stoi["over"]!,
      stoi["the"]!,
      stoi["lazy"]!,
      stoi["dog"]!,
      stoi["."]!
    ],
    [
      startTokenId,
      stoi["dart"]!,
      stoi["is"]!,
      stoi["a"]!,
      stoi["great"]!,
      stoi["programming"]!,
      stoi["language"]!,
      stoi["."]!
    ],
    [
      startTokenId,
      stoi["learning"]!,
      stoi["deep"]!,
      stoi["neural"]!,
      stoi["networks"]!,
      stoi["is"]!,
      stoi["powerful"]!,
      stoi["."]!
    ],
    [
      startTokenId,
      stoi["machine"]!,
      stoi["learning"]!,
      stoi["example"]!,
      stoi["code"]!,
      stoi["."]!
    ],
    [
      startTokenId,
      stoi["artificial"]!,
      stoi["intelligence"]!,
      stoi["is"]!,
      stoi["the"]!,
      stoi["future"]!,
      stoi["."]!
    ],
  ];

  // Build (input, target) pairs for next-token prediction: the target is the
  // input shifted left by one. Both are padded/truncated to exactly blockSize
  // so every training example has the same length.
  final List<List<int>> trainInputs = [];
  final List<List<int>> trainTargets = [];
  for (final seq in rawSequences) {
    // Input sequence: all tokens except the last one
    List<int> input = seq.sublist(0, seq.length - 1);
    // Target sequence: all tokens except the first one (what we predict)
    List<int> target = seq.sublist(1);
    // Truncate overlong sequences, then pad short ones with <pad>.
    if (input.length > blockSize) {
      input = input.sublist(0, blockSize);
      target = target.sublist(0, blockSize);
    }
    while (input.length < blockSize) {
      input.add(padTokenId);
      target.add(padTokenId);
    }
    trainInputs.add(input);
    trainTargets.add(target);
  }

  print("\nDummy Training Data:");
  for (int i = 0; i < trainInputs.length; i++) {
    print(" Input: ${trainInputs[i].map((id) => itos[id]).join(' ')}");
    print(" Target: ${trainTargets[i].map((id) => itos[id]).join(' ')}");
  }

  // 4. Instantiate the GPT model (your TransformerDecoder)
  print("\nInitializing GPT (TransformerDecoder) for training...");
  final gptModel = TransformerDecoder(
    vocabSize: vocabSize,
    embedSize: embedSize,
    blockSize: blockSize,
    numLayers: numLayers,
    numHeads: numHeads,
    encoderEmbedSize: embedSize,
  );
  print(
      "GPT (TransformerDecoder) initialized. Total parameters: ${gptModel.parameters().length}");

  // 5. Setup Optimizer
  const double learningRate = 0.01;
  final optimizer = SGD(gptModel.parameters(), learningRate);
  print("Optimizer (SGD) initialized with learning rate: $learningRate");

  // Decoder-only usage: supply a single all-zero "encoder output" vector for
  // the cross-attention layers to consume.
  // FIX: use List.generate for the inner list. List.filled would place ONE
  // shared Value(0.0) instance in every slot, so gradient accumulated on that
  // object during backward() would be cross-contaminated across positions
  // (Value appears to be a mutable autograd node — confirm against its impl).
  final List<ValueVector> dummyEncoderOutput = List.generate(
    1,
    (_) => ValueVector(List.generate(embedSize, (_) => Value(0.0))),
  );

  // 6. Training loop: per-sequence SGD on padding-masked cross-entropy.
  const int numEpochs = 1000; // Increased epochs for more complex data
  print("\n--- Starting Training ---");
  for (int epoch = 0; epoch < numEpochs; epoch++) {
    double totalLoss = 0.0;
    for (int i = 0; i < trainInputs.length; i++) {
      final inputSequence = trainInputs[i];
      final targetSequence = trainTargets[i];

      optimizer.zeroGrad();
      final List<ValueVector> logits =
          gptModel.forward(inputSequence, dummyEncoderOutput);

      // Cross-entropy summed over non-pad positions:
      //   -log p(target) = logSumExp(logits) - logit[target]
      Value batchLoss = Value(0.0);
      int activeTokens = 0;
      for (int t = 0; t < logits.length; t++) {
        if (targetSequence[t] == padTokenId) continue; // mask padding
        final ValueVector tokenLogits = logits[t];
        final Value trueLogit = tokenLogits.values[targetSequence[t]];
        final Value sumExpLogits =
            tokenLogits.values.map((v) => v.exp()).reduce((a, b) => a + b);
        batchLoss += sumExpLogits.log() - trueLogit;
        activeTokens++;
      }

      // A fully-padded sequence contributes zero loss and zero gradient;
      // skip the pointless backward()/step() in that case.
      if (activeTokens == 0) continue;

      batchLoss = batchLoss / Value(activeTokens.toDouble());
      totalLoss += batchLoss.data;
      batchLoss.backward();
      optimizer.step();
    }
    if ((epoch + 1) % 100 == 0 || epoch == 0) {
      // Print less frequently for more epochs
      print(
          "Epoch ${epoch + 1}/$numEpochs, Loss: ${totalLoss / trainInputs.length}");
    }
  }
  print("\n--- Training Complete ---");

  // 7. Greedy (argmax) generation from the trained model.
  print("\n--- Testing Generation After Training ---");
  final List<int> generatedSequence = [startTokenId];
  const int maxTestGenerationLength = 20; // Allow longer generation
  for (int i = 0; i < maxTestGenerationLength; i++) {
    // Crop the running context to the model's maximum block size.
    List<int> currentInput = List.from(generatedSequence);
    if (currentInput.length > blockSize) {
      currentInput = currentInput.sublist(currentInput.length - blockSize);
    }

    final List<ValueVector> logits =
        gptModel.forward(currentInput, dummyEncoderOutput);
    // Only the distribution at the last position predicts the next token.
    final ValueVector probabilities = logits.last.softmax();

    // Argmax over the vocabulary distribution.
    double maxProb = -1.0;
    int predictedNextToken = -1;
    for (int j = 0; j < probabilities.values.length; j++) {
      if (probabilities.values[j].data > maxProb) {
        maxProb = probabilities.values[j].data;
        predictedNextToken = j;
      }
    }

    generatedSequence.add(predictedNextToken);
    if (predictedNextToken == endTokenId) {
      print("End of sequence token detected.");
      break;
    }
    if (generatedSequence.length >= maxTestGenerationLength + 1) {
      print("Maximum generation length reached.");
      break;
    }
  }
  print("Generated Text: ${generatedSequence.map((id) => itos[id]).join(' ')}");
  print("---------------------------------------");
}