// main function
//
// void main()
//
// Implementation

/// Trains a small [TransformerDecoder] to reproduce the sentence
/// "hello world ." and then greedily generates tokens starting from `<start>`.
///
/// Intermediate tensors created by each forward pass are appended to a
/// shared `tracker` list and disposed explicitly after every step, together
/// with that step's `logits`/`loss` (or `out` during inference).
///
/// Training stops early as soon as the loss becomes non-finite (NaN or
/// ±Infinity); the offending step is not back-propagated or applied.
void main() {
  print("--- Stable Tensor-Engine AFT-GPT Training ---");

  // 1. Hyperparameters
  const vocabSize = 5;
  const embedSize = 32;
  const blockSize = 16;
  const numLayers = 2;
  const numHeads = 4;
  const learningRate = 0.001;
  // The epoch loop bound is inclusive so the final epoch is also logged.
  const lastEpoch = 1000;
  const logEvery = 100;
  const maxNewTokens = 5;

  final stoi = <String, int>{
    "hello": 0,
    "world": 1,
    ".": 2,
    "<start>": 3,
    "<pad>": 4,
  };
  final itos = stoi.map((k, v) => MapEntry(v, k));

  // Data: "hello world ."
  // Input:  <start> hello world
  // Target: hello   world .
  // Ids are derived from `stoi` (rather than hard-coded) so they cannot drift
  // out of sync with the vocabulary; the input is the target shifted right by
  // one position behind the <start> token.
  const sentence = ["hello", "world", "."];
  final targetIds = [for (final word in sentence) stoi[word]!];
  final inputIds = [
    stoi["<start>"]!,
    ...targetIds.take(targetIds.length - 1),
  ];

  final model = TransformerDecoder(
    vocabSize: vocabSize,
    embedSize: embedSize,
    encoderEmbedSize: embedSize,
    blockSize: blockSize,
    numLayers: numLayers,
    numHeads: numHeads,
  );

  final List<Tensor> tracker = [];
  final optimizer = Adam(model.parameters(), lr: learningRate);

  // Dummy encoder context: a single all-zero row of width `embedSize`.
  final dummyEnc = Tensor.zeros([1, embedSize]);

  // Disposes every tensor recorded in `tracker` during the last forward pass
  // and empties the list so the next pass starts clean.
  void releaseTracked() {
    for (final t in tracker) {
      t.dispose();
    }
    tracker.clear();
  }

  print('Starting training...');
  for (var epoch = 0; epoch <= lastEpoch; epoch++) {
    optimizer.zeroGrad();

    // 2. Forward pass — logits presumably [T, vocabSize] (TODO confirm).
    final logits = model.forward(inputIds, dummyEnc, tracker);

    // 3. Loss calculation
    // Assumes crossEntropy handles [T, V] logits vs [T] targets.
    final loss = logits.crossEntropy(targetIds);
    final double lossVal = loss.fetchData()[0];

    // Checked on every epoch (not only on logging epochs) so divergence is
    // caught immediately, and before backward/step so a NaN/Inf loss is never
    // propagated into the parameters or the optimizer state.
    final diverged = !lossVal.isFinite;

    if (!diverged) {
      // 4. Backprop + parameter update
      loss.backward();

      // Optional: gradient clipping for safety.
      // NOTE(review): confirm `clamp` is in-place before enabling; if it
      // returns a new tensor the result below is silently discarded.
      // for (var p in model.parameters()) {
      //   p.grad?.clamp(-1.0, 1.0);
      // }

      optimizer.step();
    }

    if (epoch % logEvery == 0 || diverged) {
      print(
        "Epoch ${epoch.toString().padLeft(4)} | Loss: ${lossVal.toStringAsFixed(10)}",
      );
    }

    // 5. Explicit GPU memory cleanup. This runs on every path, including the
    // early stop below, so nothing allocated in this step is leaked.
    releaseTracked();
    loss.dispose();
    logits.dispose();

    if (diverged) {
      print("Non-finite loss at epoch $epoch; stopping training.");
      break;
    }
  }

  // 6. Inference / Generation (greedy argmax decoding)
  print("\n--- Inference ---");
  final currentSeq = [stoi["<start>"]!];
  final stopId = stoi["."];

  for (var i = 0; i < maxNewTokens; i++) {
    // Feed at most the last `blockSize` tokens: the model was built with
    // `blockSize` and presumably cannot handle longer inputs.
    final context = currentSeq.length <= blockSize
        ? currentSeq
        : currentSeq.sublist(currentSeq.length - blockSize);
    final out = model.forward(context, dummyEnc, tracker);

    // Logits assumed row-major [T, vocabSize]; `base` indexes the row of the
    // last position, whose scores predict the next token.
    final scores = out.fetchData();
    final base = (context.length - 1) * vocabSize;

    var nextId = 0;
    var best = -double.infinity;
    for (var v = 0; v < vocabSize; v++) {
      if (scores[base + v] > best) {
        best = scores[base + v];
        nextId = v;
      }
    }

    currentSeq.add(nextId);

    // Cleanup inference tensors
    releaseTracked();
    out.dispose();

    if (nextId == stopId) break;
  }

  // The encoder context is reused across all steps; release it once done.
  dummyEnc.dispose();

  print("Generated: ${currentSeq.map((id) => itos[id] ?? "??").join(" ")}");
}