forward method

Tensor forward(
  Tensor videoEmbeddings,
  List<Tensor> tracker
)

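Runs the forward pass on the GPU: applies the optional frame projection to the video embeddings, adds the positional embeddings for the given number of frames, runs the transformer encoder, mean-pools the encoded output, and returns the logits produced by the MLP head. Throws an ArgumentError if the number of frames exceeds maxVideoSequenceLength.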

Implementation

/// Runs the model's forward pass and returns the logits from the MLP head.
///
/// The size of the first dimension of [videoEmbeddings] is taken as the
/// number of frames. The input is passed through [frameProjection] when one
/// is configured, summed with the first `numFrames` entries of
/// [posEmbeddings], encoded by [transformerEncoder], reduced with `mean()`,
/// and finally fed to [mlpHead].
///
/// Intermediate tensors created directly in this method are appended to
/// [tracker], and the same list is handed to every sub-module call.
/// Presumably the caller uses it to dispose of intermediates afterwards;
/// confirm against the call site.
///
/// Throws an [ArgumentError] when the number of frames exceeds
/// [maxVideoSequenceLength].
Tensor forward(Tensor videoEmbeddings, List<Tensor> tracker) {
  final int numFrames = videoEmbeddings.shape[0];

  if (numFrames > maxVideoSequenceLength) {
    throw ArgumentError("Video too long for maxVideoSequenceLength");
  }

  // 1. Projection: optional. Without a projection layer the embeddings are
  // used as-is. Null-aware access avoids a `!` assertion on the field.
  final Tensor x =
      frameProjection?.forward(videoEmbeddings, tracker) ?? videoEmbeddings;

  // 2. Slice and add positions. The slice itself is not added to the
  // tracker; the original author notes it is a view that usually does not
  // need tracking unless auto-dispose is wanted.
  final pos = posEmbeddings.slice(0, numFrames);
  final xWithPos = x + pos;
  tracker.add(xWithPos);

  // 3. Transformer encoder.
  final encoded = transformerEncoder.forwardEmbeddings(xWithPos, tracker);

  // 4. Global average pooling.
  // NOTE(review): mean() is called with no arguments; confirm that it
  // reduces over the intended dimension(s) rather than collapsing everything.
  final pooled = encoded.mean();
  tracker.add(pooled);

  // 5. MLP head.
  return mlpHead.forward(pooled, tracker);
}