main function

void main()

Implementation

/// Runs a short debug training loop for [AFTAttention] on a fixed random
/// input.
///
/// Every parameter is first overwritten with values drawn uniformly from
/// `[-0.02, 0.02)`. Then [Adam] (lr `0.001`, gradClip `0.1`) performs up to
/// 10 steps minimizing the MSE between the layer output and a constant
/// `0.1` target. After each backward pass `_checkGradients` is called on
/// the parameters as a debug aid.
///
/// Training stops early if the loss becomes NaN or infinite. The per-step
/// tracker tensors, loss and output are disposed on every exit path,
/// including the early-abort `break`. The input and target are disposed
/// once the loop ends.
void main() {
  const dim = 16;
  const seqLen = 10;
  const currentT = 5;
  const lr = 0.001; // Reduced for stability
  const steps = 10;
  const initScale = 0.02;

  // 1. Initialize. The second argument (4) is passed through unchanged;
  // NOTE(review): its meaning is defined by AFTAttention — confirm.
  final aft = AFTAttention(dim, 4, seqLen, masked: false);

  // --- CRITICAL: Stabilize Weights ---
  // AFT is extremely sensitive to large weights.
  // We force everything to small values to prevent the Epoch 0/1 explosion.
  // A single generator is shared by all parameters rather than constructing
  // a fresh Random for every tensor.
  final rand = math.Random();
  for (final p in aft.parameters()) {
    final data = p.fetchData();
    for (var i = 0; i < data.length; i++) {
      // nextDouble() is in [0, 1), so this lands in [-initScale, initScale).
      data[i] = (rand.nextDouble() * 2 - 1) * initScale;
    }
    p.data = data;
  }

  final x = Tensor.random([currentT, dim]);
  final target = Tensor.fill([currentT, dim], 0.1); // Small target for stability
  final tracker = <Tensor>[];

  // Use tighter gradClip in Adam
  final optimizer = Adam(aft.parameters(), lr: lr, gradClip: 0.1);

  // print('--- AFT Debug Training Run ---');

  try {
    for (var step = 0; step < steps; step++) {
      // print('\nSTEP $step');
      optimizer.zeroGrad();

      final output = aft.forward(x, tracker);
      final loss = output.mseLoss(target);

      try {
        final lVal = loss.fetchData()[0];
        // print('Current Loss: ${lVal.toStringAsFixed(6)}');

        if (lVal.isNaN) {
          print("🛑 Training aborted: Loss is NaN.");
          break;
        }
        // A diverged run can also overflow to ±infinity, which isNaN misses.
        if (lVal.isInfinite) {
          print("🛑 Training aborted: Loss is infinite.");
          break;
        }

        loss.backward();

        // DEBUG: Check Gradients
        _checkGradients(aft.parameters());

        optimizer.step();
      } finally {
        // Memory cleanup: runs on every exit path, including the aborts
        // above, so the last step's tensors are never leaked.
        for (final t in tracker) {
          t.dispose();
        }
        tracker.clear();
        loss.dispose();
        output.dispose();
      }
    }
  } finally {
    // Release anything `forward` left in the tracker if it threw mid-step,
    // then the tensors that lived for the whole run.
    for (final t in tracker) {
      t.dispose();
    }
    tracker.clear();
    x.dispose();
    target.dispose();
  }
}