exampleSequenceGeneration function

void exampleSequenceGeneration()

Implementation

void exampleSequenceGeneration() {
  print("\n--- Example 6: Simplified Sequence Generation ---");

  // This is a very basic generative example. Full generation samples the next
  // token from the predicted probabilities and feeds it back into the model.
  // Because the model is decoder-only (causal), it supports this autoregressive
  // loop; for simplicity we pick the most likely token (greedy decoding)
  // instead of sampling.

  final vocabSize = 10;
  final embedSize = 16;
  final blockSize = 4;

  // Create the model (here it is only randomly initialized).
  final model = Transformer(
    vocabSize: vocabSize,
    embedSize: embedSize,
    blockSize: blockSize,
    numLayers: 2,
    numHeads: 2,
  );
  // In a real scenario, you'd load trained weights here.
  // For this example, we'll just use the randomly initialized model.

  List<int> prompt = [1, 2]; // Start with tokens 1, 2
  final int maxNewTokens = 5;

  print("Prompt: $prompt");
  print("Generating $maxNewTokens new tokens...");

  List<int> generatedSequence = List.from(prompt);

  for (int i = 0; i < maxNewTokens; i++) {
    // Crop the sequence to the block size if it exceeds it, since the model's
    // context window only covers blockSize tokens.
    final currentInput = generatedSequence.length > blockSize
        ? generatedSequence.sublist(generatedSequence.length - blockSize)
        : generatedSequence;

    // Forward pass to get logits
    final logits = model.forward(currentInput);

    // Get the logits for the last token in the sequence (which is the prediction for the *next* token)
    final lastTokenLogits = logits.last;

    // Apply softmax to get probabilities
    final probabilities = lastTokenLogits.softmax();

    // Find the token with the highest probability (greedy sampling)
    double maxProb = -1.0;
    int predictedNextToken = -1;
    for (int j = 0; j < probabilities.values.length; j++) {
      if (probabilities.values[j].data > maxProb) {
        maxProb = probabilities.values[j].data;
        predictedNextToken = j;
      }
    }

    // Add the predicted token to the sequence
    generatedSequence.add(predictedNextToken);
    print(
        "Step ${i + 1}: Predicted token: $predictedNextToken (Prob: ${(maxProb * 100).toStringAsFixed(2)}%)");
  }

  print("Generated sequence: $generatedSequence");
  print(
      "Note: This is a simplified example. For better generation, consider techniques like top-k or nucleus sampling.");
}
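
The closing note mentions top-k and nucleus sampling as alternatives to the greedy argmax scan above. Below is a minimal, hypothetical sketch of temperature-scaled top-k sampling over a plain List<double> of probabilities; the helper name sampleTopK and its parameters are illustrative and not part of this package.

import 'dart:math';

/// Hypothetical helper (not part of this package): draw a token index from
/// a normalized probability distribution using temperature scaling followed
/// by top-k filtering.
int sampleTopK(List<double> probs,
    {int k = 5, double temperature = 1.0, Random? rng}) {
  rng ??= Random();

  // Raising each probability to 1/temperature is equivalent (up to
  // renormalization) to computing softmax(logits / temperature): higher
  // temperatures flatten the distribution, lower ones sharpen it.
  final weights =
      probs.map((p) => pow(p, 1.0 / temperature).toDouble()).toList();

  // Keep only the k most likely token indices.
  final indices = List<int>.generate(weights.length, (i) => i)
    ..sort((a, b) => weights[b].compareTo(weights[a]));
  final topK = indices.take(k).toList();

  // Renormalize over the kept indices and draw one of them at random.
  final total = topK.fold<double>(0.0, (sum, i) => sum + weights[i]);
  var draw = rng.nextDouble() * total;
  for (final i in topK) {
    draw -= weights[i];
    if (draw <= 0) return i;
  }
  return topK.last; // Guard against floating-point rounding.
}

Inside the generation loop, predictedNextToken could then be drawn with sampleTopK(probabilities.values.map((v) => v.data).toList()) instead of the argmax scan, assuming probabilities.values exposes the per-token scalars as in the listing above.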