main function

void main()

Implementation

/// Overfits a small character-level [TransformerDecoder] on one sentence and
/// then regenerates that sentence greedily, one token at a time.
///
/// Training: for each of 501 steps, every window of `blockSize` tokens is run
/// through the model, `backward()` is called on each window's cross-entropy
/// loss, and `optimizer.step()` is called once after all windows.
///
/// Generation: the first `blockSize` tokens of the text are used as the
/// prompt; each new token is the arg-max of the logits row for the last
/// position of the most recent `blockSize`-token context.
void main() async {
  const rawText = "the cat sat on the mat";
  final tokenizer = CharTokenizer(rawText);
  final data = tokenizer.encode(rawText);

  // Context window length (in tokens) and embedding width of the model.
  const blockSize = 8;
  const embedSize = 32;

  final gpt = TransformerDecoder(
    vocabSize: tokenizer.vocabSize,
    embedSize: embedSize,
    encoderEmbedSize: embedSize,
    numLayers: 1,
    numHeads: 2,
    blockSize: blockSize,
  );

  final optimizer = Adam(gpt.parameters(), lr: 0.01);

  // All-zero tensor handed to `forward` as its second argument.
  // NOTE(review): presumably a stand-in for encoder output, since this demo
  // has no encoder — confirm against TransformerDecoder.forward.
  final dummyEnc = Tensor.zeros([1, embedSize]);

  // Looked up once: used by both cleanup loops to avoid disposing weights.
  final params = gpt.parameters();

  print("🚀 Training to overfit the full sentence...");

  // 1. TRAINING: slide a blockSize-wide window over the whole sentence.
  for (var step = 0; step < 501; step++) {
    optimizer.zeroGrad();
    var epochLoss = 0.0;
    var count = 0;

    // Each window's target is the same window shifted right by one token.
    for (var i = 0; i < data.length - blockSize; i++) {
      final tracker = <Tensor>[];
      final x = data.sublist(i, i + blockSize);
      final y = data.sublist(i + 1, i + blockSize + 1);

      final logits = gpt.forward(x, dummyEnc, tracker);
      final loss = logits.crossEntropy(y);

      loss.backward();
      epochLoss += loss.fetchData()[0];
      count++;

      // Release the tensors recorded during this forward pass, but never the
      // model's own parameters.
      for (final t in tracker) {
        if (!params.contains(t)) t.dispose();
      }
      loss.dispose();
    }
    optimizer.step();

    if (step % 100 == 0) {
      print("Step $step | Avg Loss: ${(epochLoss / count).toStringAsFixed(4)}");
    }
  }

  // 2. GENERATION: the autoregressive loop.
  print("\n--- Generating Entire Sequence ---");

  // The prompt is the first blockSize tokens of the training text.
  final generatedIds = data.sublist(0, blockSize);
  stdout.write(tokenizer.decode(generatedIds));

  // Generate until the output is as long as the original text.
  for (var i = 0; i < rawText.length - blockSize; i++) {
    final evalTracker = <Tensor>[];

    // Always condition on the most recent blockSize tokens.
    final context = generatedIds.sublist(generatedIds.length - blockSize);

    final logits = gpt.forward(context, dummyEnc, evalTracker);

    // Greedy search: take the best-scoring entry of the last position's row.
    final predId = _argMax(logits.fetchRow(context.length - 1));

    generatedIds.add(predId);
    stdout.write(tokenizer.decode([predId]));

    // Same guard as in training: skip the model's parameters so the weights
    // survive for the next iteration's forward pass.
    for (final t in evalTracker) {
      if (!params.contains(t)) t.dispose();
    }
    // NOTE(review): if `forward` also records `logits` in the tracker, this is
    // a second dispose of the same tensor — confirm dispose() tolerates that.
    logits.dispose();
  }

  print("\n\nFinished!");
}

/// Returns the index of the largest value in [values].
///
/// Ties resolve to the lowest index; an empty list, or one where no element
/// compares greater than negative infinity, yields 0.
int _argMax(List<double> values) {
  var bestIndex = 0;
  var bestValue = -double.infinity;
  for (var v = 0; v < values.length; v++) {
    if (values[v] > bestValue) {
      bestValue = values[v];
      bestIndex = v;
    }
  }
  return bestIndex;
}