main function
void main()
Implementation
void main() {
  print("--- Encoder-Decoder Transformer Example ---");

  final sourceVocabSize = 10; // e.g., English words
  final targetVocabSize = 10; // e.g., French words
  final embedSize = 32;
  final sourceBlockSize = 8;
  final targetBlockSize = 8;
  final numLayers = 2;
  final numHeads = 4;

  // Initialize the Encoder-Decoder Transformer
  final model = EncoderDecoderTransformer(
    sourceVocabSize: sourceVocabSize,
    targetVocabSize: targetVocabSize,
    embedSize: embedSize,
    sourceBlockSize: sourceBlockSize,
    targetBlockSize: targetBlockSize,
    numLayers: numLayers,
    numHeads: numHeads,
  );
  final optimizer = SGD(model.parameters(), 0.05);

  // --- Sample data for a simple sequence-to-sequence task ---
  // E.g., translating [1, 2, 3] to [5, 6, 7].
  // In real life, you'd use padding tokens and special start/end tokens.

  // Source sequence (e.g., "The dog barks")
  final sampleSourceInputs = [1, 2, 3, 4]; // Example token IDs

  // Target sequence (e.g., "Le chien aboie")
  // For training, the decoder input is the target sequence shifted right
  // (teacher forcing): if the target sequence is [5, 6, 7, 8], the decoder
  // receives [START_TOKEN, 5, 6, 7] and the loss targets are [5, 6, 7, 8].
  final startToken = 0; // Assuming 0 is a special start-of-sequence token
  final sampleTargetInputs = [startToken, 5, 6, 7]; // Decoder input (shifted right)
  final sampleTargetOutputs = [5, 6, 7, 8]; // True next tokens for the loss

  if (sampleTargetInputs.length != sampleTargetOutputs.length) {
    throw ArgumentError(
        "Sample target inputs and outputs must have the same length for this example.");
  }

  final epochs = 100;
  print("\nTraining Encoder-Decoder Transformer for $epochs epochs...");

  for (int epoch = 0; epoch < epochs; epoch++) {
    // Forward pass
    final logits = model.forward(sampleSourceInputs, sampleTargetInputs);

    // Calculate the loss. The decoder output at position t (which was fed
    // targetInputs[t]) should predict targetOutputs[t], so every position
    // contributes to the loss.
    Value totalLoss = Value(0.0);
    for (int t = 0; t < logits.length; t++) {
      final outputAtT = logits[t]; // Logits for predicting targetOutputs[t]
      final targetAtT = sampleTargetOutputs[t];
      // One-hot vector for the true token
      final targetVector = ValueVector(List.generate(
        targetVocabSize,
        (i) => Value(i == targetAtT ? 1.0 : 0.0),
      ));
      totalLoss += outputAtT.softmax().crossEntropy(targetVector);
    }
    final meanLoss = totalLoss / Value(logits.length.toDouble());

    // Backward pass & optimization
    model.zeroGrad();
    meanLoss.backward();
    optimizer.step();

    if (epoch % 10 == 0 || epoch == epochs - 1) {
      print("Epoch $epoch | Loss: ${meanLoss.data.toStringAsFixed(4)}");
    }
  }
  print("✅ Encoder-Decoder Transformer training complete.");

  // --- Inference example (simplified greedy decoding) ---
  print("\n--- Encoder-Decoder Inference ---");
  final inferenceSource = [1, 2, 3]; // New source sequence to translate
  print("Source: $inferenceSource");

  List<int> generatedTargetSequence = [startToken]; // Start with the start token
  final int maxGenerationLength = 5; // Max tokens to generate

  // The encoder only depends on the source, so run it once and reuse its
  // output at every decoding step.
  final encoderOut = model.encoder.forward(inferenceSource);

  for (int i = 0; i < maxGenerationLength; i++) {
    // Decoder gets its current generated sequence as input, plus the encoder output
    final decoderLogits =
        model.decoder.forward(generatedTargetSequence, encoderOut);

    // Take the logits at the decoder's *last* position and convert them to probabilities
    final lastTokenProbs = decoderLogits.last.softmax();

    // Greedy decoding: pick the token with the highest probability
    double maxProb = -1.0;
    int predictedNextToken = -1;
    for (int j = 0; j < lastTokenProbs.values.length; j++) {
      if (lastTokenProbs.values[j].data > maxProb) {
        maxProb = lastTokenProbs.values[j].data;
        predictedNextToken = j;
      }
    }

    // Add the predicted token to the generated sequence
    generatedTargetSequence.add(predictedNextToken);

    // A real decoder would stop when an end-of-sequence token is predicted.
    // This example has no explicit end token, so it always generates
    // `maxGenerationLength` tokens.
  }

  print(
      "Generated Target Sequence: $generatedTargetSequence (first token is START_TOKEN)");
  print(
      "Note: For real-world use, you'd handle padding, special tokens (EOS, PAD), and more advanced decoding strategies like beam search.");
}
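The hand-written shifted lists above can also be derived from a raw target sequence. Below is a minimal sketch of that teacher-forcing preparation; the helper name teacherForcingPairs and the Dart 3 record return type are illustrative assumptions, not part of the model code shown in main().

// Illustrative helper (assumption, not part of the library above): given a raw
// target sequence, build the shifted decoder input and the loss targets.
// Assumes Dart 3 record syntax.
({List<int> decoderInputs, List<int> decoderTargets}) teacherForcingPairs(
    List<int> target, int startToken) {
  return (
    // START token followed by all but the last target token
    decoderInputs: [startToken, ...target.sublist(0, target.length - 1)],
    // The loss targets are the raw target sequence itself
    decoderTargets: target,
  );
}

// Usage: teacherForcingPairs([5, 6, 7, 8], 0)
// -> (decoderInputs: [0, 5, 6, 7], decoderTargets: [5, 6, 7, 8])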
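The closing note mentions an EOS token and more advanced decoding. As a minimal sketch of the first of those ideas, the greedy loop from main() can be wrapped in a helper that stops once an end-of-sequence token is produced. It reuses the model.encoder.forward / model.decoder.forward calls shown above; the greedyDecode name and the eosToken id are assumptions for illustration.

// Illustrative sketch: greedy decoding that stops early at an assumed
// end-of-sequence token id. Reuses the encoder/decoder API shown in main();
// eosToken and maxLength are example parameters, not library defaults.
List<int> greedyDecode(
  EncoderDecoderTransformer model,
  List<int> source, {
  int startToken = 0,
  int eosToken = 9, // hypothetical EOS id within the 10-token vocabulary
  int maxLength = 8,
}) {
  // Encode the source once and reuse it at every step.
  final encoderOut = model.encoder.forward(source);
  final generated = <int>[startToken];

  while (generated.length - 1 < maxLength) {
    final decoderLogits = model.decoder.forward(generated, encoderOut);
    final probs = decoderLogits.last.softmax();

    // Argmax over the vocabulary.
    var best = 0;
    for (int j = 1; j < probs.values.length; j++) {
      if (probs.values[j].data > probs.values[best].data) best = j;
    }

    generated.add(best);
    if (best == eosToken) break; // stop at the end-of-sequence token
  }
  return generated;
}

Beam search would instead keep several candidate sequences at each step and expand the most probable ones, rather than committing to the single argmax token.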