exampleLargerTransformerTraining function
void exampleLargerTransformerTraining()
Implementation
void exampleLargerTransformerTraining() {
  print(
      "\n--- Example 5: Training a Transformer with a larger vocabulary and sequence ---");
  final vocabSize = 50; // Increased vocabulary
  final embedSize = 32;
  final blockSize = 8; // Longer context
  final numLayers = 3;
  final numHeads = 4;
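  // blockSize of 8 matches the length of the 8-token sample sequences below.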
  final model = Transformer(
    vocabSize: vocabSize,
    embedSize: embedSize,
    blockSize: blockSize,
    numLayers: numLayers,
    numHeads: numHeads,
  );
  final optimizer =
      SGD(model.parameters(), 0.05); // Slightly reduced learning rate
  // More complex sample data
  final sampleInputs = [0, 1, 5, 2, 8, 12, 3, 10]; // 8 tokens
  final sampleTargets = [1, 5, 2, 8, 12, 3, 10, 15]; // Next tokens for each
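  // Targets are the inputs shifted left by one position, with token 15 as the
  // final next-token label, so every position trains next-token prediction.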
  final epochs = 100;
  print("\nTraining for $epochs epochs with larger data...");
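  // Each epoch runs one forward/backward pass over the single 8-token sequence.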
  for (int epoch = 0; epoch < epochs; epoch++) {
    final logits = model.forward(sampleInputs);
    Value totalLoss = Value(0.0);
    for (int t = 0; t < logits.length; t++) {
      final outputAtT = logits[t];
      final targetAtT = sampleTargets[t];
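      // One-hot encode the target token, then accumulate the softmax
      // cross-entropy between the prediction at this position and that target.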
      final targetVector = ValueVector(List.generate(
        vocabSize,
        (i) => Value(i == targetAtT ? 1.0 : 0.0),
      ));
      totalLoss += outputAtT.softmax().crossEntropy(targetVector);
    }
    final meanLoss = totalLoss / Value(logits.length.toDouble());
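    // Standard training step: clear stale gradients, backpropagate the mean
    // loss, then let the optimizer update the parameters.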
    model.zeroGrad();
    meanLoss.backward();
    optimizer.step();
    if (epoch % 10 == 0 || epoch == epochs - 1) {
      print("Epoch $epoch | Loss: ${meanLoss.data.toStringAsFixed(4)}");
    }
  }
  print("✅ Larger model training complete.");
}
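A minimal usage sketch, assuming exampleLargerTransformerTraining and the Transformer, SGD, Value, and ValueVector types are exported by the same library; the import path below is a placeholder, not the package's real name. Each epoch minimizes the mean cross-entropy between the softmaxed logits at every position and the one-hot encoding of the corresponding target token.

import 'package:transformer_example/transformer_example.dart'; // hypothetical import path

void main() {
  // Runs the 100-epoch training loop above and prints the loss every 10 epochs.
  exampleLargerTransformerTraining();
}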