main function

void main()
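
Demonstrates building and training a small EncoderDecoderTransformer on a toy
sequence-to-sequence task with teacher forcing, then generating a target
sequence for a new source with simplified greedy decoding.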

Implementation

void main() {
  print("--- Encoder-Decoder Transformer Example ---");

  final sourceVocabSize = 10; // e.g., English words
  final targetVocabSize = 10; // e.g., French words
  final embedSize = 32;
  final sourceBlockSize = 8;
  final targetBlockSize = 8;
  final numLayers = 2;
  final numHeads = 4;
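  // Note: embedSize is divisible by numHeads (32 / 4 = 8 dims per head), as
  // multi-head attention implementations typically require; the block sizes
  // set the maximum source/target sequence lengths the model can attend over.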

  // Initialize the Encoder-Decoder Transformer
  final model = EncoderDecoderTransformer(
    sourceVocabSize: sourceVocabSize,
    targetVocabSize: targetVocabSize,
    embedSize: embedSize,
    sourceBlockSize: sourceBlockSize,
    targetBlockSize: targetBlockSize,
    numLayers: numLayers,
    numHeads: numHeads,
  );

  // Plain stochastic gradient descent over all model parameters
  // (learning rate 0.05).
  final optimizer = SGD(model.parameters(), 0.05);

  // --- Sample Data for a simple sequence-to-sequence task ---
  // E.g., translating source tokens [1, 2, 3, 4] to target tokens [5, 6, 7, 8].
  // In real life, you'd use padding tokens and special start/end tokens.

  // Source sequence (e.g., "The dog barks")
  final sampleSourceInputs = [1, 2, 3, 4]; // Example token IDs

  // Target sequence (e.g., "Le chien aboie")
  // For training, the decoder input is the target shifted right (teacher forcing):
  // if the target sequence is [5, 6, 7, 8], the decoder is fed
  // [START_TOKEN, 5, 6, 7] and the loss targets are [5, 6, 7, 8], so the
  // prediction at position t is scored against the t-th true target token.
  final startToken = 0; // Assuming 0 is a special start-of-sequence token
  // Decoder input (shifted right).
  final sampleTargetInputs = [startToken, 5, 6, 7];
  // True next tokens for the loss calculation.
  final sampleTargetOutputs = [5, 6, 7, 8];

  if (sampleTargetInputs.length != sampleTargetOutputs.length) {
    throw ArgumentError(
        "Sample target inputs and outputs must have the same length for this example.");
  }

  final epochs = 100;
  print("\nTraining Encoder-Decoder Transformer for $epochs epochs...");

  for (int epoch = 0; epoch < epochs; epoch++) {
    // Forward pass
    final logits = model.forward(sampleSourceInputs, sampleTargetInputs);

    // Calculate loss over every decoder position: because the decoder input is
    // shifted right, the logits at position t are the prediction for
    // sampleTargetOutputs[t] (position 0 sees START_TOKEN and predicts 5, etc.).
    Value totalLoss = Value(0.0);
    for (int t = 0; t < logits.length; t++) {
      final outputAtT = logits[t]; // Logits for predicting sampleTargetOutputs[t]
      final targetAtT = sampleTargetOutputs[t];

      // One-hot target vector for the cross-entropy loss.
      final targetVector = ValueVector(List.generate(
        targetVocabSize,
        (i) => Value(i == targetAtT ? 1.0 : 0.0),
      ));
      totalLoss += outputAtT.softmax().crossEntropy(targetVector);
    }

    final meanLoss = totalLoss / Value(logits.length.toDouble());

    // Backward pass & optimization
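    // zeroGrad() clears any previously accumulated gradients, backward()
    // backpropagates from the scalar loss, and step() applies the SGD update.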
    model.zeroGrad();
    meanLoss.backward();
    optimizer.step();

    if (epoch % 10 == 0 || epoch == epochs - 1) {
      print("Epoch $epoch | Loss: ${meanLoss.data.toStringAsFixed(4)}");
    }
  }
  print("✅ Encoder-Decoder Transformer training complete.");

  // --- Inference Example (Simplified Greedy Decoding) ---
  print("\n--- Encoder-Decoder Inference ---");
  final inferenceSource = [1, 2, 3]; // New source sequence to translate

  print("Source: $inferenceSource");
  // Start with the start token and grow the sequence as tokens are generated.
  final generatedTargetSequence = <int>[startToken];
  final int maxGenerationLength = 5; // Max tokens to generate

  // Encode the source once; it does not change while decoding.
  final encoderOut = model.encoder.forward(inferenceSource);

  for (int i = 0; i < maxGenerationLength; i++) {
    // The decoder takes the tokens generated so far plus the encoder output.
    final decoderLogits =
        model.decoder.forward(generatedTargetSequence, encoderOut);

    // Softmax the logits at the last position to get the probability
    // distribution over the next token.
    final lastTokenProbs = decoderLogits.last.softmax();

    // Greedy sampling: pick the token with the highest probability.
    double maxProb = -1.0;
    int predictedNextToken = -1;
    for (int j = 0; j < lastTokenProbs.values.length; j++) {
      if (lastTokenProbs.values[j].data > maxProb) {
        maxProb = lastTokenProbs.values[j].data;
        predictedNextToken = j;
      }
    }

    // Add the predicted token to the generated sequence
    generatedTargetSequence.add(predictedNextToken);

    // Stop if an end-of-sequence token is predicted (you'd define one in your vocab)
    // For this example, we don't have an explicit end token, so we'll just generate `maxGenerationLength` tokens.
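    // A minimal sketch of such a check, assuming a hypothetical `eosToken` id
    // that this toy vocabulary does not define:
    //   if (predictedNextToken == eosToken) break;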
  }

  print(
      "Generated Target Sequence: $generatedTargetSequence (first token is START_TOKEN)");
  print(
      "Note: For real-world use, you'd handle padding, special tokens (EOS, PAD), and more advanced decoding strategies like beam search.");
}