exampleLargerTransformerTraining function

void exampleLargerTransformerTraining()

Implementation

void exampleLargerTransformerTraining() {
  print(
      "\n--- Example 5: Training a Transformer with a larger vocabulary and sequence ---");

  final vocabSize = 50; // Increased vocabulary
  final embedSize = 32;
  final blockSize = 8; // Longer context
  final numLayers = 3;
  final numHeads = 4;

  final model = Transformer(
    vocabSize: vocabSize,
    embedSize: embedSize,
    blockSize: blockSize,
    numLayers: numLayers,
    numHeads: numHeads,
  );

  final optimizer =
      SGD(model.parameters(), 0.05); // Slightly reduced learning rate

  // Sample data: a longer and more varied sequence than in the earlier examples
  final sampleInputs = [0, 1, 5, 2, 8, 12, 3, 10]; // 8 tokens
  final sampleTargets = [1, 5, 2, 8, 12, 3, 10, 15]; // Next tokens for each

  final epochs = 100;
  print("\nTraining for $epochs epochs with larger data...");

  for (int epoch = 0; epoch < epochs; epoch++) {
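    // Forward pass over the whole sequence: one vector of vocabSize scores per position.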
    final logits = model.forward(sampleInputs);

    Value totalLoss = Value(0.0);
    for (int t = 0; t < logits.length; t++) {
      final outputAtT = logits[t];
      final targetAtT = sampleTargets[t];

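      // One-hot encoding of the true next token at position t.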
      final targetVector = ValueVector(List.generate(
        vocabSize,
        (i) => Value(i == targetAtT ? 1.0 : 0.0),
      ));
      totalLoss += outputAtT.softmax().crossEntropy(targetVector);
    }

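    // Average over positions so the loss scale does not depend on sequence length.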
    final meanLoss = totalLoss / Value(logits.length.toDouble());

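    // Clear stale gradients, backpropagate through the mean loss, then apply the SGD update.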
    model.zeroGrad();
    meanLoss.backward();
    optimizer.step();

    if (epoch % 10 == 0 || epoch == epochs - 1) {
      print("Epoch $epoch | Loss: ${meanLoss.data.toStringAsFixed(4)}");
    }
  }
  print("✅ Larger model training complete.");
}