main function

void main()

Implementation

/// Trains a [TransformerDecoderBlock] on a fixed random batch.
///
/// Each epoch runs a forward pass of the decoder over `xDec` (attending to
/// `xEnc`) and computes the MSE loss against a constant target of `0.5`.
/// The loss is printed every [logInterval] epochs. Training stops early if
/// the loss becomes NaN.
///
/// Per-step tensors (everything collected in `tracker`, plus `loss` and
/// `output`) are disposed in a `finally` block, so they are released on every
/// exit path, including the early NaN `break`. The input tensors are disposed
/// once training ends.
void main() {
  const embedSize = 32;
  const numHeads = 4;
  const encoderEmbedSize = 64;
  const maxSeqLen = 50;

  // Deliberately conservative learning rate and tight gradient clip.
  // NOTE: both are constant for every epoch; no warmup/decay schedule is
  // applied here.
  const lr = 0.0001;
  const gradClip = 0.1;

  const epochs = 50;
  const logInterval = 5;

  final decoderBlock = TransformerDecoderBlock(
    embedSize,
    numHeads,
    encoderEmbedSize,
    maxSeqLen,
  );

  final xDec = Tensor.random([8, embedSize]);
  final xEnc = Tensor.random([12, encoderEmbedSize]);
  final target = Tensor.fill([8, embedSize], 0.5);

  // Passed to `forward`; every tensor it holds is disposed after each step.
  final tracker = <Tensor>[];
  final optimizer = Adam(decoderBlock.parameters(), lr: lr, gradClip: gradClip);

  /// Disposes and forgets every tensor currently held in [tracker].
  void disposeTracked() {
    for (final t in tracker) {
      t.dispose();
    }
    tracker.clear();
  }

  print('--- Training Corrected TransformerDecoderBlock ---');

  try {
    for (var epoch = 0; epoch < epochs; epoch++) {
      optimizer.zeroGrad();

      final output = decoderBlock.forward(xDec, xEnc, tracker);
      final loss = output.mseLoss(target);

      try {
        if (epoch % logInterval == 0) {
          print(
            'Epoch ${epoch.toString().padLeft(2)} | Loss: ${loss.data[0].toStringAsFixed(8)}',
          );
        }

        if (loss.data[0].isNaN) {
          print("❌ NaN detected at Epoch $epoch. Shutting down.");
          // The `finally` below still runs, so this step's tensors are freed.
          break;
        }

        loss.backward();
        optimizer.step();
      } finally {
        // Memory cleanup. Disposal happens after `step()` because the
        // optimizer update comes first in the step.
        disposeTracked();
        loss.dispose();
        output.dispose();
      }
    }
  } finally {
    // Covers intermediates left behind if `forward`/`mseLoss` threw before
    // the inner `try` was entered; a no-op when the tracker is already empty.
    disposeTracked();
    xDec.dispose();
    xEnc.dispose();
    target.dispose();
  }
}