main function

void main()

Implementation

void main() {
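  // Assumes Value, ValueVector, TransformerDecoder, and SGD from this library
  // are in scope; the import lines are omitted here.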
  print("--- Generative Pretrained Transformer (GPT) Training Example ---");

  // 1. Define GPT Model Hyperparameters
  const int vocabSize = 20; // Covers the 19 tokens defined below (id 19 is unused)
  const int embedSize = 32;
  const int blockSize = 10; // Maximum sequence length the GPT can process
  const int numLayers = 3;
  const int numHeads = 4;

  print("GPT Model Configuration:");
  print("  Vocabulary Size: $vocabSize");
  print("  Embedding Size: $embedSize");
  print("  Block Size (Max Context Length): $blockSize");
  print("  Number of Layers: $numLayers");
  print("  Number of Heads: $numHeads");

  // 2. Simple Vocabulary for demonstration
  // This vocabulary must be consistent between training and inference
  final Map<String, int> stoi = {
    "hello": 0,
    "world": 1,
    "this": 2,
    "is": 3,
    "a": 4,
    "test": 5,
    "generation": 6,
    "model": 7,
    "the": 8,
    "quick": 9,
    "brown": 10,
    "fox": 11,
    "jumps": 12,
    "over": 13,
    "lazy": 14,
    "dog": 15,
    ".": 16, // End of sentence token
    "<start>": 17, // Start of sequence token
    "<pad>": 18, // Padding token
  };
  final Map<int, String> itos = stoi.map((key, value) => MapEntry(value, key));

  // Get special token IDs
  final int startTokenId = stoi["<start>"]!;
  final int padTokenId = stoi["<pad>"]!;

  print("\nExample Vocabulary:");
  print(itos);

  // 3. Create a Dummy Dataset
  // In a real scenario, this would be loaded from files, tokenized, and batched.
  // We'll create a few simple sequences for next-token prediction.
  // Each sequence is (input_tokens, target_tokens) where target_tokens are input_tokens shifted by one.
  // E.g., "hello world ." -> input: "<start> hello world", target: "hello world ."

  final List<List<int>> rawSequences = [
    [startTokenId, stoi["hello"]!, stoi["world"]!, stoi["."]!],
    [
      startTokenId,
      stoi["this"]!,
      stoi["is"]!,
      stoi["a"]!,
      stoi["test"]!,
      stoi["."]!
    ],
    [
      startTokenId,
      stoi["the"]!,
      stoi["quick"]!,
      stoi["brown"]!,
      stoi["fox"]!,
      stoi["."]!
    ],
  ];

  List<List<int>> trainInputs = [];
  List<List<int>> trainTargets = [];

  for (var seq in rawSequences) {
    // Input sequence: all tokens except the last one
    List<int> input = seq.sublist(0, seq.length - 1);
    // Target sequence: all tokens except the first one (what we predict)
    List<int> target = seq.sublist(1);

    // Truncate sequences longer than blockSize so they fit the model's context window
    if (input.length > blockSize) {
      input = input.sublist(0, blockSize);
      // Keep the target aligned with the truncated input.
      target = target.sublist(0, blockSize);
    }
    // Pad if shorter than blockSize for consistent input shapes in a batch
    while (input.length < blockSize) {
      input.add(padTokenId);
      target.add(padTokenId);
    }
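    // Note: pad positions are skipped in the loss below, but the model still
    // attends to them; a production implementation would also mask attention.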

    trainInputs.add(input);
    trainTargets.add(target);
  }

  print("\nDummy Training Data:");
  for (int i = 0; i < trainInputs.length; i++) {
    print("  Input:  ${trainInputs[i].map((id) => itos[id]).join(' ')}");
    print("  Target: ${trainTargets[i].map((id) => itos[id]).join(' ')}");
  }

  // 4. Instantiate the GPT model (your TransformerDecoder)
  print("\nInitializing GPT (TransformerDecoder) for training...");
  final gptModel = TransformerDecoder(
    vocabSize: vocabSize,
    embedSize: embedSize,
    blockSize: blockSize,
    numLayers: numLayers,
    numHeads: numHeads,
    encoderEmbedSize:
        embedSize, // Still needed to satisfy constructor for cross-attention
  );
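  // parameters() is assumed to return every trainable scalar Value, so the
  // count below is the number of individual scalars, not layers or tensors.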
  print(
      "GPT (TransformerDecoder) initialized. Total parameters: ${gptModel.parameters().length}");

  // 5. Setup Optimizer
  const double learningRate = 0.01;
  final optimizer = SGD(gptModel.parameters(), learningRate);
  print("Optimizer (SGD) initialized with learning rate: $learningRate");

  // FIX: Provide a non-empty dummy encoder output to satisfy the CrossAttention layer.
  // A pure GPT would either omit the cross-attention layer or ignore it; this
  // dummy output merely lets the code run without a "No element" error, and its
  // values are not functionally meaningful.
  final List<ValueVector> dummyEncoderOutput = List.generate(
    1, // Provide at least one dummy token
    // Each token vector must have length encoderEmbedSize (embedSize here).
    (_) => ValueVector(List.filled(embedSize, Value(0.0))),
  );

  // 6. Training Loop
  const int numEpochs = 500;
  print("\n--- Starting Training ---");

  for (int epoch = 0; epoch < numEpochs; epoch++) {
    double totalLoss = 0.0;

    for (int i = 0; i < trainInputs.length; i++) {
      final inputSequence = trainInputs[i];
      final targetSequence = trainTargets[i];

      // Zero gradients
      optimizer.zeroGrad();

      // Forward pass
      final List<ValueVector> logits =
          gptModel.forward(inputSequence, dummyEncoderOutput);

      // Calculate loss (Cross-Entropy Loss)
      // We are predicting the next token for each position in the input sequence.
      Value batchLoss = Value(0.0);
      int activeTokens = 0; // Count tokens that are not padding

      for (int t = 0; t < logits.length; t++) {
        // Only calculate loss for non-padding tokens
        if (targetSequence[t] != padTokenId) {
          final ValueVector tokenLogits = logits[t];
          final int trueTargetId = targetSequence[t];

          // Softmax then negative log likelihood for true target
          // This is a simplified cross-entropy calculation
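          // For logits z and true token y:
          //   NLL = -log(softmax(z)[y]) = log(sum_k exp(z_k)) - z_y,
          // which is what the lines below compute. (A numerically robust
          // version would subtract max(z) before exponentiating.)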
          final Value trueLogit = tokenLogits.values[trueTargetId];
          final Value sumExpLogits =
              tokenLogits.values.map((v) => v.exp()).reduce((a, b) => a + b);
          final Value logSumExp = sumExpLogits.log();
          final Value negLogProb =
              logSumExp - trueLogit; // Negative log-likelihood

          batchLoss += negLogProb;
          activeTokens++;
        }
      }

      // Average loss over active tokens
      if (activeTokens > 0) {
        batchLoss = batchLoss / Value(activeTokens.toDouble());
      } else {
        batchLoss = Value(0.0); // No active tokens, no loss
      }

      totalLoss += batchLoss.data;

      // Backward pass
      batchLoss.backward();

      // Update parameters
      optimizer.step();
    }

    // Log progress periodically instead of every epoch.
    if ((epoch + 1) % 50 == 0 || epoch == 0) {
      print(
          "Epoch ${epoch + 1}/$numEpochs, Loss: ${totalLoss / trainInputs.length}");
    }
  }

  print("\n--- Training Complete ---");

  // 7. Test Generation after (pseudo) training
  print("\n--- Testing Generation After Training ---");
  List<int> generatedSequence = [startTokenId];
  const int maxTestGenerationLength = 10;

  for (int i = 0; i < maxTestGenerationLength; i++) {
    List<int> currentInput = List.from(generatedSequence);
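    // Keep only the most recent blockSize tokens; the context window slides
    // forward as the generated sequence grows.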
    if (currentInput.length > blockSize) {
      currentInput = currentInput.sublist(currentInput.length - blockSize);
    }

    // Pass dummy encoder output as before
    final List<ValueVector> logits =
        gptModel.forward(currentInput, dummyEncoderOutput);

    // Get the logits for the last token and sample
    final ValueVector lastTokenLogits = logits.last;
    final ValueVector probabilities = lastTokenLogits.softmax();

    // Greedy sampling for simplicity
    double maxProb = -1.0;
    int predictedNextToken = -1;
    for (int j = 0; j < probabilities.values.length; j++) {
      if (probabilities.values[j].data > maxProb) {
        maxProb = probabilities.values[j].data;
        predictedNextToken = j;
      }
    }

    generatedSequence.add(predictedNextToken);

    if (predictedNextToken == stoi["."]) {
      // Stop on sentence end token
      break;
    }
  }

  print("Generated Text: ${generatedSequence.map((id) => itos[id]).join(' ')}");
  print("---------------------------------------");
}
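The loop above decodes greedily, always taking the arg-max token, so the output is deterministic. A common alternative is temperature sampling; the sketch below is a hypothetical helper (sampleToken is not part of the library) that assumes only the ValueVector.values and Value.data accessors already used in the example.

import 'dart:math';

/// Samples a token id from [logits] after temperature scaling.
int sampleToken(ValueVector logits, {double temperature = 1.0, Random? rng}) {
  rng ??= Random();
  // Sampling needs no gradients, so work on plain doubles via .data.
  final scaled = logits.values.map((v) => v.data / temperature).toList();
  // Subtract the max logit before exponentiating for numerical stability.
  final maxLogit = scaled.reduce(max);
  final exps = scaled.map((x) => exp(x - maxLogit)).toList();
  final sumExp = exps.reduce((a, b) => a + b);
  // Draw from the resulting categorical distribution.
  double r = rng.nextDouble() * sumExp;
  for (int i = 0; i < exps.length; i++) {
    r -= exps[i];
    if (r <= 0) return i;
  }
  return exps.length - 1; // Guard against floating-point round-off.
}

Inside the generation loop, generatedSequence.add(sampleToken(logits.last, temperature: 0.8)); would replace the greedy arg-max scan; lower temperatures approach greedy decoding, while higher ones add diversity.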