main function
void
main()
Implementation
/// Runs a short debug training loop for an [AFTAttention] layer.
///
/// Overwrites every parameter with small uniform values in `[-0.02, 0.02)`,
/// then performs up to 10 [Adam] steps minimizing the MSE between the layer
/// output and a constant `0.1` target. Training stops early if the loss
/// becomes NaN.
///
/// Per-step tensors (the tracked intermediates, the loss and the output) are
/// disposed at the end of every iteration, including the iteration that
/// aborts on NaN or throws.
void main() {
  const dim = 16;
  const seqLen = 10;
  const currentT = 5;
  const lr = 0.001; // Reduced for stability

  // 1. Initialize
  // NOTE(review): the positional `4` is passed straight through to
  // AFTAttention; its meaning is not visible here — confirm against the
  // constructor.
  final aft = AFTAttention(dim, 4, seqLen, masked: false);

  // --- CRITICAL: Stabilize Weights ---
  // AFT is extremely sensitive to large weights.
  // We force everything to small values to prevent the Epoch 0/1 explosion.
  // A single generator is shared across all parameters; there is no need to
  // construct a new one per parameter.
  final rand = math.Random();
  for (final p in aft.parameters()) {
    final data = p.fetchData();
    for (var i = 0; i < data.length; i++) {
      // nextDouble() is in [0, 1), so this lands in [-0.02, 0.02).
      data[i] = (rand.nextDouble() * 2 - 1) * 0.02;
    }
    // Write the buffer back explicitly (presumably fetchData() may return a
    // copy rather than a live view — TODO confirm).
    p.data = data;
  }

  final x = Tensor.random([currentT, dim]);
  final target = Tensor.fill([currentT, dim], 0.1); // Small target for stability

  // Collects tensors registered by forward() so they can be released after
  // each step.
  final tracker = <Tensor>[];

  // Use tighter gradClip in Adam
  final optimizer = Adam(aft.parameters(), lr: lr, gradClip: 0.1);

  // print('--- AFT Debug Training Run ---');
  for (var step = 0; step < 10; step++) {
    // print('\nSTEP $step');
    optimizer.zeroGrad();

    final output = aft.forward(x, tracker);
    final loss = output.mseLoss(target);
    try {
      final lVal = loss.fetchData()[0];
      // print('Current Loss: ${lVal.toStringAsFixed(6)}');
      if (lVal.isNaN) {
        print("🛑 Training aborted: Loss is NaN.");
        break; // The finally block below still runs before leaving the loop.
      }

      loss.backward();

      // DEBUG: Check Gradients
      _checkGradients(aft.parameters());

      optimizer.step();
    } finally {
      // Memory cleanup. Done in `finally` so the NaN abort path (and any
      // exception from backward/step) does not leak this step's tensors.
      for (final t in tracker) {
        t.dispose();
      }
      tracker.clear();
      loss.dispose();
      output.dispose();
    }
  }
}