main function
void
main()
Implementation
/// Trains a `TransformerDecoderBlock` for a fixed number of epochs.
///
/// The block is fed a random 8 x `embedSize` decoder input and a random
/// 12 x `encoderEmbedSize` encoder input, and is fitted with Adam and MSE loss
/// toward a constant 0.5 target of shape 8 x `embedSize`.
///
/// The loss is printed every `logEvery` epochs. Training stops early, before
/// `backward()`, if the loss becomes NaN or infinite. Every epoch's tensors
/// are disposed when the epoch ends, even if it stops early or throws. These
/// are the ones collected in the tracker list passed to `forward`, plus the
/// output and loss. The input and target tensors are disposed once training
/// finishes.
void main() {
  const embedSize = 32;
  const numHeads = 4;
  const encoderEmbedSize = 64;
  const maxSeqLen = 50;
  const epochs = 50;
  const logEvery = 5;
  // Use a very conservative LR for the first few steps of AFT.
  // NOTE(review): "AFT" is not defined in this file — confirm its meaning.
  const lr = 0.0001;
  const gradClip = 0.1;

  final decoderBlock = TransformerDecoderBlock(
    embedSize,
    numHeads,
    encoderEmbedSize,
    maxSeqLen,
  );

  final xDec = Tensor.random([8, embedSize]);
  final xEnc = Tensor.random([12, encoderEmbedSize]);
  final target = Tensor.fill([8, embedSize], 0.5);

  // Handed to `forward`; whatever it collects is disposed after each epoch.
  final tracker = <Tensor>[];
  final optimizer = Adam(
    decoderBlock.parameters(),
    lr: lr,
    gradClip: gradClip,
  );

  print('--- Training Corrected TransformerDecoderBlock ---');
  try {
    for (var epoch = 0; epoch < epochs; epoch++) {
      optimizer.zeroGrad();
      Tensor? output;
      Tensor? loss;
      try {
        output = decoderBlock.forward(xDec, xEnc, tracker);
        loss = output.mseLoss(target);
        final lossValue = loss.data[0];

        if (epoch % logEvery == 0) {
          print(
            'Epoch ${epoch.toString().padLeft(2)} | '
            'Loss: ${lossValue.toStringAsFixed(8)}',
          );
        }

        // Bail out before backward() so a NaN/Inf loss can never be pushed
        // into the parameters by the optimizer step.
        if (!lossValue.isFinite) {
          print(
            '❌ Non-finite loss ($lossValue) detected at Epoch $epoch. '
            'Shutting down.',
          );
          break;
        }

        loss.backward();
        optimizer.step();
      } finally {
        // Runs on normal completion, on `break`, and when an exception
        // propagates, so the early-stop path no longer leaks this epoch's
        // tensors.
        for (final t in tracker) {
          t.dispose();
        }
        tracker.clear();
        loss?.dispose();
        output?.dispose();
      }
    }
  } finally {
    // Long-lived inputs are no longer needed once training ends.
    xDec.dispose();
    xEnc.dispose();
    target.dispose();
  }
}