forward method
Forward pass on the GPU
Implementation
/// Runs the model over one video's frame embeddings and returns the output
/// of [mlpHead] (the logits).
///
/// The pipeline is: optional [frameProjection], addition of positional
/// embeddings taken from [posEmbeddings], [transformerEncoder], mean
/// pooling, and finally [mlpHead].
///
/// [videoEmbeddings] is indexed by frame along dimension 0. Its size there
/// must be between 1 and [maxVideoSequenceLength] inclusive.
///
/// Every tensor created directly here is appended to [tracker], and
/// [tracker] is also passed to each sub-module. Presumably the caller uses
/// it to dispose intermediates after the pass — confirm against the caller.
///
/// Throws an [ArgumentError] if the video has no frames or has more than
/// [maxVideoSequenceLength] frames.
Tensor forward(Tensor videoEmbeddings, List<Tensor> tracker) {
  final numFrames = videoEmbeddings.shape[0];
  // Mean pooling over zero frames is ill-defined, so reject an empty video
  // up front instead of letting it fail (or yield garbage) further down.
  if (numFrames == 0) {
    throw ArgumentError('Video has no frames');
  }
  if (numFrames > maxVideoSequenceLength) {
    throw ArgumentError(
      'Video too long for maxVideoSequenceLength: '
      '$numFrames > $maxVideoSequenceLength',
    );
  }

  // 1. Projection (optional). Copy the field to a local so it promotes to
  // non-nullable without a `!` assertion.
  final projection = frameProjection;
  final x = projection == null
      ? videoEmbeddings
      : projection.forward(videoEmbeddings, tracker);

  // 2. Slice and add positions.
  // NOTE(review): assumes slice(0, numFrames) selects the first numFrames
  // positions (start/end semantics) — confirm against the Tensor API.
  // The slice is a view and is not tracked; only the sum is a new tensor.
  final pos = posEmbeddings.slice(0, numFrames);
  final xWithPos = x + pos;
  tracker.add(xWithPos);

  // 3. Transformer encoder.
  final encoded = transformerEncoder.forwardEmbeddings(xWithPos, tracker);

  // 4. Global average pooling.
  // NOTE(review): confirm that mean() with no arguments reduces over the
  // frame axis only. If it reduces over all elements, the MLP head receives
  // a scalar instead of a per-feature vector.
  final pooled = encoded.mean();
  tracker.add(pooled);

  // 5. MLP head.
  return mlpHead.forward(pooled, tracker);
}