forward method

List<ValueVector> forward(
  List<ValueVector> imageFlatPatches,
  int H_img,
  int W_img
)

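Runs the forward pass of the Swin Transformer: patch embedding, each stage in order, then the final layer normalization. Returns the resulting feature map.

imageFlatPatches holds the flattened patches of the input image, with shape (num_patches, patchSize * patchSize * inChannels). H_img and W_img are the image height and width in pixels; they are divided by patchSize to obtain the size of the patch grid.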
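As a rough illustration of the expected input shape, the sketch below flattens an image into non-overlapping patches using plain Dart lists. It is not part of this library: flattenPatches is a hypothetical helper, the [height][width][channels] image layout and the row-major patch order are assumptions, and the resulting lists would still need to be converted to ValueVector before being passed to forward.

```dart
/// Hypothetical helper (not part of the library): splits an image stored as
/// [height][width][channels] into non-overlapping square patches and
/// flattens each patch into a single list of
/// patchSize * patchSize * channels values.
List<List<double>> flattenPatches(
    List<List<List<double>>> image, int patchSize) {
  final height = image.length;
  final width = image[0].length;
  if (height % patchSize != 0 || width % patchSize != 0) {
    throw ArgumentError('Image size must be a multiple of patchSize.');
  }
  final patches = <List<double>>[];
  // Walk the patch grid row by row (assumed ordering).
  for (var top = 0; top < height; top += patchSize) {
    for (var left = 0; left < width; left += patchSize) {
      final patch = <double>[];
      // Append every pixel of the patch, channel values last.
      for (var dy = 0; dy < patchSize; dy++) {
        for (var dx = 0; dx < patchSize; dx++) {
          patch.addAll(image[top + dy][left + dx]);
        }
      }
      patches.add(patch);
    }
  }
  return patches;
}
```

For a 4 x 4 image with 3 channels and patchSize = 2, this yields 4 patches of 12 values each, matching (num_patches, patchSize * patchSize * inChannels).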

Implementation

List<ValueVector> forward(
    List<ValueVector> imageFlatPatches, int H_img, int W_img) {
  // Initial patch embedding
  var x =
      patchEmbedding.forward(imageFlatPatches); // (num_patches, embedDims[0])

  // Calculate initial H, W in terms of patches
  int currentH = H_img ~/ patchSize;
  int currentW = W_img ~/ patchSize;

  // Pass through all stages
  for (int i = 0; i < stages.length; i++) {
    x = stages[i].forward(x, currentH, currentW);
    // Every stage except the last is expected to end with patch merging,
    // which halves the patch grid in each dimension, so update H and W
    // before calling the next stage.
    if (i < stages.length - 1) {
      currentH ~/= 2;
      currentW ~/= 2;
    }
  }

  // Apply final layer normalization
  x = x.map((v) => finalNorm.forward(v)).toList();

  // For classification, global average pooling and a linear head would
  // typically follow. As a general-purpose backbone, return the feature map.
  return x;
}
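
The resolution bookkeeping in the loop above can be checked in isolation. The following is a minimal sketch, assuming (as the implementation does) that every stage except the last halves both grid dimensions; stageResolutions is a hypothetical helper and not part of the library.

```dart
/// Hypothetical helper (not part of the library): returns the [height, width]
/// of the patch grid passed to each stage, assuming every stage except the
/// last halves both dimensions through patch merging.
List<List<int>> stageResolutions(
    int imageHeight, int imageWidth, int patchSize, int numStages) {
  // Size of the patch grid after the initial patch embedding.
  var h = imageHeight ~/ patchSize;
  var w = imageWidth ~/ patchSize;
  final resolutions = <List<int>>[];
  for (var i = 0; i < numStages; i++) {
    resolutions.add([h, w]);
    // Mirror the update in forward: no halving after the last stage.
    if (i < numStages - 1) {
      h ~/= 2;
      w ~/= 2;
    }
  }
  return resolutions;
}
```

With example values of a 224 x 224 image, patchSize = 4, and four stages, the grids are 56 x 56, 28 x 28, 14 x 14, and 7 x 7. Because ~/ truncates, image sizes that are not divisible by patchSize * 2^(numStages - 1) silently lose rows or columns, so callers may want to validate the dimensions first.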