forward method

Tensor forward(
  Tensor patchifiedImage,
  List<Tensor> tracker
)

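Takes a Tensor of shape [numPatches, patch_pixels]. For a 224x224 RGB image split into 16x16 patches, the input is [196, 768]: (224 / 16)^2 = 196 patches, each flattened to 16 x 16 x 3 = 768 values. Intermediate tensors created during the pass are appended to tracker.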
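For reference, here is a minimal Dart sketch of how an image could be flattened into that [numPatches, patch_pixels] layout. It assumes the image is a flat, row-major list in height-width-channel (HWC) order, that patches are visited left-to-right and top-to-bottom, and that each patch is flattened row by row with a pixel's channels kept together. patchify is a hypothetical helper, not part of this API, and the ordering the real model was trained with may differ.

```dart
/// Hypothetical helper (not part of the documented class): splits an image
/// stored as a flat, row-major HWC list into non-overlapping square patches.
/// Returns [numPatches][patchSize * patchSize * channels], i.e. the
/// [numPatches, patch_pixels] layout described above.
List<List<double>> patchify(
  List<double> image, {
  required int height,
  required int width,
  required int channels,
  required int patchSize,
}) {
  if (image.length != height * width * channels) {
    throw ArgumentError('image length must equal height * width * channels');
  }
  if (height % patchSize != 0 || width % patchSize != 0) {
    throw ArgumentError('height and width must be divisible by patchSize');
  }
  final patches = <List<double>>[];
  // Visit patches in raster order: left-to-right, then top-to-bottom.
  for (var py = 0; py < height; py += patchSize) {
    for (var px = 0; px < width; px += patchSize) {
      final patch = <double>[];
      // Flatten the patch row by row, keeping each pixel's channels together.
      for (var y = py; y < py + patchSize; y++) {
        for (var x = px; x < px + patchSize; x++) {
          final base = (y * width + x) * channels;
          patch.addAll(image.sublist(base, base + channels));
        }
      }
      patches.add(patch);
    }
  }
  return patches;
}
```

With height = width = 224, channels = 3 and patchSize = 16, this yields 196 lists of 768 values each, matching the [196, 768] example.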

Implementation

Tensor forward(Tensor patchifiedImage, List<Tensor> tracker) {
  // 1. Linear Projection of Patches
  // [numPatches, patch_pixels] -> [numPatches, embedSize]
  final xPatches = patchProjection.forward(patchifiedImage, tracker);

  // 2. Prepend CLS Token
  // Result shape: [numPatches + 1, embedSize]
  final xSeq = Tensor.concat([clsToken, xPatches]);
  tracker.add(xSeq);

  // 3. Add Positional Embeddings
  // Element-wise addition on GPU; posEmbeddings is expected to have the
  // same shape as xSeq, [numPatches + 1, embedSize].
  final xPos = xSeq + posEmbeddings;
  tracker.add(xPos);

  // 4. Transformer Encoder Blocks
  // Contextualizes all tokens (CLS + patches) using AFT (Attention Free Transformer) blocks
  final encoded = transformerEncoder.forwardEmbeddings(xPos, tracker);

  return encoded;
}
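To make the shape flow of steps 1 to 3 concrete, the sketch below reproduces them on plain nested Dart lists instead of the library's Tensor type. It is an illustration under stated assumptions, not the class's actual code: it assumes patchProjection is a dense layer with a weight matrix and a bias, and that clsToken and posEmbeddings are plain learned arrays. embedPatches and its parameter names are hypothetical, and step 4 (the encoder) is omitted.

```dart
/// Hypothetical shape-level sketch of steps 1-3 of forward(), using nested
/// lists in place of Tensor. Assumes patchProjection = patch * weight + bias.
List<List<double>> embedPatches(
  List<List<double>> patches, // [numPatches, patch_pixels]
  List<List<double>> weight, // [patch_pixels, embedSize] (assumed dense weights)
  List<double> bias, // [embedSize] (assumed bias)
  List<double> clsToken, // [embedSize]
  List<List<double>> posEmbeddings, // [numPatches + 1, embedSize]
) {
  final embedSize = bias.length;
  // 1. Linear projection: each flattened patch times weight, plus bias.
  final xPatches = [
    for (final patch in patches)
      List<double>.generate(embedSize, (j) {
        var sum = bias[j];
        for (var k = 0; k < patch.length; k++) {
          sum += patch[k] * weight[k][j];
        }
        return sum;
      }),
  ];
  // 2. Prepend the CLS token as row 0: [numPatches + 1, embedSize].
  final xSeq = [clsToken, ...xPatches];
  if (posEmbeddings.length != xSeq.length) {
    throw ArgumentError('posEmbeddings must have numPatches + 1 rows');
  }
  // 3. Element-wise addition of the positional embeddings (same shape).
  return [
    for (var i = 0; i < xSeq.length; i++)
      List<double>.generate(embedSize, (j) => xSeq[i][j] + posEmbeddings[i][j]),
  ];
}
```

Fed 196 patches of 768 values, a 768 x embedSize weight and 197 positional rows, this returns 197 rows of embedSize values, the [numPatches + 1, embedSize] sequence that step 4 hands to the encoder.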