main function

void main()

Implementation

void main() {
  print("--- Generative Pretrained Transformer (GPT) Example ---");

  // 1. Define GPT Model Hyperparameters
  const int vocabSize =
      20; // Example vocabulary size (19 tokens defined below; ID 19 is left free, e.g. for an <unk> token)
  const int embedSize = 32;
  const int blockSize = 10; // Maximum sequence length the GPT can process
  const int numLayers = 3;
  const int numHeads = 4;

  print("GPT Model Configuration:");
  print("  Vocabulary Size: $vocabSize");
  print("  Embedding Size: $embedSize");
  print("  Block Size (Max Context Length): $blockSize");
  print("  Number of Layers: $numLayers");
  print("  Number of Heads: $numHeads");

  // 2. Simple Vocabulary for demonstration
  final Map<String, int> stoi = {
    "hello": 0,
    "world": 1,
    "this": 2,
    "is": 3,
    "a": 4,
    "test": 5,
    "generation": 6,
    "model": 7,
    "the": 8,
    "quick": 9,
    "brown": 10,
    "fox": 11,
    "jumps": 12,
    "over": 13,
    "lazy": 14,
    "dog": 15,
    ".": 16, // End of sentence token
    "<start>": 17, // Start of sequence token
    "<pad>": 18, // Padding token
    // You might also have an <unk> token for unknown words
  };
  final Map<int, String> itos = stoi.map((key, value) => MapEntry(value, key));

  // Get the ID for the start token
  final int startTokenId = stoi["<start>"]!;
  final int endTokenId = stoi["."]!; // Using '.' as an example end token

  print("\nExample Vocabulary:");
  print(itos);
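
  // To encode raw text with this vocabulary, split on whitespace and look
  // each word up in stoi. A real tokenizer would fall back to an <unk>
  // token for out-of-vocabulary words; here the lookup is illustrative:
  //   final ids = "the quick brown fox".split(' ').map((w) => stoi[w]!).toList();
  //   // -> [8, 9, 10, 11]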

  // 3. Instantiate the GPT model (your TransformerDecoder)
  print("\nInitializing GPT (TransformerDecoder)...");
  final gptModel = TransformerDecoder(
    vocabSize: vocabSize,
    embedSize: embedSize,
    blockSize: blockSize,
    numLayers: numLayers,
    numHeads: numHeads,
    // For a GPT, the cross-attention part of TransformerDecoderBlock is not used.
    // We pass embedSize here just to satisfy the constructor. In a pure GPT,
    // you'd define a decoder-only variant that omits cross-attention entirely.
    encoderEmbedSize: embedSize,
  );
  print(
      "GPT (TransformerDecoder) initialized. Total parameters: ${gptModel.parameters().length}");

  // 4. Text Generation Loop (Greedy Sampling)
  print("\n--- Starting Text Generation ---");
  List<int> generatedSequence = [startTokenId]; // Start with the <start> token
  final int maxGenerationLength = 15; // Max tokens to generate
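
  // Note: the model's weights are freshly initialized and never trained in
  // this example, so the generated tokens will be essentially arbitrary.
  // The loop below demonstrates the autoregressive mechanics only: feed the
  // sequence so far, read off a distribution over the next token, append
  // the chosen token, and repeat.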

  // Create a dummy encoder output for the cross-attention layer in TransformerDecoderBlock.
  // In a true GPT, the cross-attention layer would not exist, or its input would be ignored.
  // Here, we provide a single zero vector to satisfy the method signature,
  // knowing that the masked self-attention is what's truly driving generation.
  final List<ValueVector> simpleDummyEncoderOutput = [
    ValueVector(List.filled(embedSize, Value(0.0)))
  ]; // (1, embedSize)

  for (int i = 0; i < maxGenerationLength; i++) {
    // The model can attend to at most blockSize tokens, so once the generated
    // sequence grows past that limit, keep only the most recent blockSize
    // tokens as context (a sliding window over the sequence).
    List<int> currentInput = List.from(generatedSequence);
    if (currentInput.length > blockSize) {
      currentInput = currentInput.sublist(currentInput.length - blockSize);
    }

    // Forward pass through the GPT (TransformerDecoder)
    // Pass the dummy encoder output to satisfy the method signature.
    final List<ValueVector> logits =
        gptModel.forward(currentInput, simpleDummyEncoderOutput);

    // Get the logits for the *last* token in the sequence (the prediction for the next token)
    final ValueVector lastTokenLogits = logits.last;

    // Apply softmax to get probabilities
    final ValueVector probabilities = lastTokenLogits.softmax();
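
    // softmax turns the raw logits x into a probability distribution:
    //   softmax(x)_i = exp(x_i) / sum_j exp(x_j)
    // (numerically stable implementations subtract max(x) before exp).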

    // Greedy sampling: pick the token with the highest probability
    double maxProb = -1.0;
    int predictedNextToken = -1;
    for (int j = 0; j < probabilities.values.length; j++) {
      if (probabilities.values[j].data > maxProb) {
        maxProb = probabilities.values[j].data;
        predictedNextToken = j;
      }
    }
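
    // Greedy decoding is deterministic: the same context always yields the
    // same next token. See the temperature-sampling sketch after this
    // function for a stochastic alternative.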

    // Add the predicted token to the generated sequence
    generatedSequence.add(predictedNextToken);

    // Print current generation progress (convert IDs back to words)
    print("Generated: ${generatedSequence.map((id) => itos[id]).join(' ')}");

    // Stop if an end-of-sequence token is predicted
    if (predictedNextToken == endTokenId) {
      print("End of sequence token detected.");
      break;
    }
    if (generatedSequence.length >= maxGenerationLength + 1) {
      // +1 accounts for the initial <start> token
      print("Maximum generation length reached.");
      break;
    }
  }

  print("\n--- Final Generated Sequence ---");
  print(generatedSequence.map((id) => itos[id] ?? '<unk>').join(' '));
  print("--------------------------------");
}
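
An alternative to the greedy argmax above is to sample from the softmax
distribution. The sketch below is a minimal, hypothetical helper (it is not
part of the library): it assumes the probabilities have been copied out of
the ValueVector into a plain List<double>, and re-weights them by a
temperature before drawing an index.

import 'dart:math';

/// Samples a token index from [probabilities], re-weighted by [temperature].
/// Temperatures below 1.0 sharpen the distribution (closer to greedy);
/// temperatures above 1.0 flatten it (more diverse, more random).
int sampleWithTemperature(List<double> probabilities, double temperature,
    {Random? rng}) {
  rng ??= Random();
  // Raising each probability to 1/temperature is equivalent to dividing
  // the logits by the temperature before the softmax.
  final weights =
      probabilities.map((p) => pow(p, 1.0 / temperature).toDouble()).toList();
  final total = weights.reduce((a, b) => a + b);
  // Draw an index proportionally to the re-weighted probabilities.
  double threshold = rng.nextDouble() * total;
  for (int i = 0; i < weights.length; i++) {
    threshold -= weights[i];
    if (threshold <= 0) return i;
  }
  return weights.length - 1; // Guard against floating-point round-off.
}

Inside the generation loop, the greedy argmax could then be replaced with:

  final predictedNextToken = sampleWithTemperature(
      probabilities.values.map((v) => v.data).toList(), 0.8);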