forward method

Tensor forward(
  Tensor x,
  List<Tensor> tracker, {
  Tensor? kv,
})

Implementation

/// Projects the input into queries, keys and values and applies AFT.
///
/// Queries are always projected from [x]. Keys and values are projected
/// from [kv] when it is provided (cross-attention), and from [x] otherwise
/// (standard self-attention).
///
/// [tracker] is passed to each projection layer, and the resulting output
/// tensor is appended to it before being returned.
// NOTE(review): `posBias` and `masked` presumably were set up for
// self-attention; confirm their shapes/semantics are valid when [kv] has a
// different sequence length than [x].
Tensor forward(Tensor x, List<Tensor> tracker, {Tensor? kv}) {
  // --- DEBUG: Input Check ---
  _debugTensor("Input X", x);

  // Keys and values come from the context (kv) if provided, otherwise they
  // come from x. Without this, a caller-supplied kv would be silently ignored.
  final context = kv ?? x;

  final q = queryLayer.forward(x, tracker);
  final k = keyLayer.forward(context, tracker);
  final v = valueLayer.forward(context, tracker);

  // --- DEBUG: Projection Check ---
  // If these are huge (> 20), AFT will likely explode
  _debugTensor("Query (Q)", q);
  _debugTensor("Key (K)", k);

  final out = Tensor.aft(q, k, v, posBias, masked);

  // --- DEBUG: Output Check ---
  _debugTensor("AFT Output", out);

  tracker.add(out);
  return out;
}