main function

void main()

Implementation

/// Runs a small training demo for [MultiHeadAFT] with masking enabled.
///
/// Builds a random `[currentT, embedSize]` input and a constant 0.5 target,
/// then performs a fixed number of [Adam] updates. Each iteration prints the
/// output shape, the first entry of the loss tensor before and after the
/// update, and a success line when that entry decreased.
///
/// Every tensor collected in the tracker list is disposed before returning,
/// including when a training step throws.
void main() {
  // 1. Hyperparameters
  const numHeads =
      4; // Note: The logic provided processes embedSize directly
  const embedSize = 16; // Total embedding dimension
  const maxSeqLen = 20; // Capacity of the position bias matrix
  const currentT = 8; // Actual sequence length for this batch
  const lr = 0.01; // Learning rate
  const trainSteps = 20; // Number of optimizer updates to run

  // 2. Initialize the module with masking enabled (the usual setting for
  // autoregressive, GPT-style tasks).
  final aft = MultiHeadAFT(numHeads, embedSize, maxSeqLen, masked: true);

  // 3. Create dummy data of shape [currentT, embedSize].
  final input = Tensor.random([currentT, embedSize]);

  // Target: a constant pattern (all 0.5s) for the model to fit.
  final target = Tensor.fill([currentT, embedSize], 0.5);

  // Passed to every forward call; its contents are disposed in the `finally`
  // block below.
  final tracker = <Tensor>[];

  final optimizer = Adam(aft.parameters(), lr: lr);

  print('--- MultiHeadAFT Training Step (Causal) ---');
  try {
    for (var step = 0; step < trainSteps; step++) {
      optimizer.zeroGrad();

      // 4. Forward pass.
      // NOTE(review): presumably the module slices its position bias down to
      // the current sequence length internally - confirm in MultiHeadAFT.
      final output = aft.forward(input, tracker);
      print('Output shape: ${output.shape}');

      // 5. Compute the loss as the squared difference to the target.
      // NOTE(review): no reduction is visible here, so `data[0]` presumably
      // reflects only the first element rather than a mean over all entries;
      // the commented-out `mseLoss` call may be the intended form - confirm.
      final loss = (output - target).pow(2.0);
      // final loss = output.mseLoss(target);
      // Captured once so the value printed is the value compared below.
      final lossBefore = loss.data[0];
      print('Initial Loss: ${lossBefore.toStringAsFixed(6)}');

      // 6. Backward pass through the loss graph.
      loss.backward();

      // 7. Optimizer step (Adam, as constructed above).
      optimizer.step();
      // print('Updated $paramCount parameter tensors.');
      optimizer.zeroGrad();

      // 8. Verify progress by re-running the forward pass with the updated
      // parameters and comparing the same loss entry.
      final nextOutput = aft.forward(input, tracker);
      final nextLoss = (nextOutput - target).pow(2.0);
      // final nextLoss = nextOutput.mseLoss(target);
      final lossAfter = nextLoss.data[0];
      print('Loss after 1 step: ${lossAfter.toStringAsFixed(6)}');

      if (lossAfter < lossBefore) {
        print('Success: Gradients flowed and loss decreased!');
      }
    }
  } finally {
    // Release tracked tensors even when a step above throws.
    for (final tracked in tracker) {
      tracked.dispose();
    }
  }
}