forward method
Forward pass for the Swin Transformer.
imageFlatPatches holds the flattened patches of the input image, with shape
(num_patches, patchSize * patchSize * inChannels).
H_img and W_img are the image height and width in pixels.
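To illustrate the expected input shape, here is a minimal sketch of how an image might be cut into flattened patches. It uses plain List<List<double>> rows as a stand-in, because the ValueVector constructor is not shown on this page. The flattenPatches name, the image[y][x][c] layout, and the row-major, channel-last ordering inside each patch are all assumptions of this sketch. The ordering that patchEmbedding actually expects may differ.

```dart
/// Splits a channel-last image into non-overlapping square patches and
/// flattens each patch into one row of length
/// patchSize * patchSize * inChannels.
///
/// Hypothetical helper: `image[y][x][c]` is a layout chosen for this sketch.
List<List<double>> flattenPatches(
    List<List<List<double>>> image, int patchSize) {
  final height = image.length;
  final width = image[0].length;
  if (height % patchSize != 0 || width % patchSize != 0) {
    throw ArgumentError('Image size must be a multiple of patchSize.');
  }

  final patches = <List<double>>[];
  // Walk the patch grid row by row, so that
  // patch index = patchRow * (width ~/ patchSize) + patchCol.
  for (var py = 0; py < height; py += patchSize) {
    for (var px = 0; px < width; px += patchSize) {
      final flat = <double>[];
      for (var y = py; y < py + patchSize; y++) {
        for (var x = px; x < px + patchSize; x++) {
          flat.addAll(image[y][x]); // append all channels of this pixel
        }
      }
      patches.add(flat);
    }
  }
  return patches;
}
```

Each resulting row would then need to be wrapped in a ValueVector before being passed to forward.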
Implementation
List<ValueVector> forward(
    List<ValueVector> imageFlatPatches, int H_img, int W_img) {
  // Initial patch embedding
  var x =
      patchEmbedding.forward(imageFlatPatches); // (num_patches, embedDims[0])

  // Initial token grid height and width, measured in patches
  int currentH = H_img ~/ patchSize;
  int currentW = W_img ~/ patchSize;

  // Pass through all stages
  for (int i = 0; i < stages.length; i++) {
    x = stages[i].forward(x, currentH, currentW);
    // Every stage except the last ends with patch merging, which halves
    // the grid height and width.
    if (i < stages.length - 1) {
      currentH ~/= 2;
      currentW ~/= 2;
    }
  }

  // Apply final layer normalization to each token
  x = x.map((v) => finalNorm.forward(v)).toList();

  // A classification model would typically add global average pooling and a
  // linear head here. As a general backbone, return the feature map.
  return x;
}
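The grid bookkeeping in the loop, and the pooling step that the closing comment mentions, can be sketched in isolation as follows. Both stageResolutions and globalAveragePool are hypothetical helpers written for this illustration and are not part of the class. They operate on plain lists rather than ValueVector, and they assume, as the loop above does, that every stage except the last halves the grid.

```dart
/// Returns the [height, width] token grid seen by each stage, assuming every
/// stage except the last halves both dimensions via patch merging.
List<List<int>> stageResolutions(
    int imageHeight, int imageWidth, int patchSize, int numStages) {
  var h = imageHeight ~/ patchSize;
  var w = imageWidth ~/ patchSize;
  final grids = <List<int>>[];
  for (var i = 0; i < numStages; i++) {
    grids.add([h, w]);
    if (i < numStages - 1) {
      // Patch merging between stages halves the grid.
      h ~/= 2;
      w ~/= 2;
    }
  }
  return grids;
}

/// Averages all token feature vectors into one vector
/// (global average pooling over the token dimension).
List<double> globalAveragePool(List<List<double>> tokens) {
  if (tokens.isEmpty) {
    throw ArgumentError('tokens must not be empty.');
  }
  final sums = List<double>.filled(tokens[0].length, 0.0);
  for (final token in tokens) {
    for (var d = 0; d < sums.length; d++) {
      sums[d] += token[d];
    }
  }
  return [for (final sum in sums) sum / tokens.length];
}
```

For a 224 x 224 image with patchSize 4 and four stages, stageResolutions gives grids of 56 x 56, 28 x 28, 14 x 14, and 7 x 7, so under these assumptions the list returned by forward would hold 49 tokens. A classification head would apply a linear layer to the pooled vector.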