SwinTransformer constructor

SwinTransformer({
  required int imageSize,
  required int patchSize,
  required int inChannels,
  required List<int> embedDims,
  required List<int> depths,
  required List<int> numHeads,
  required int windowSize,
  required int numClasses,
})

Implementation

/// Creates a Swin Transformer with one [SwinStage] per entry in [depths].
///
/// [embedDims], [depths] and [numHeads] are parallel lists indexed by stage.
/// They must be non-empty and of equal length; this is verified with
/// assertions, so the check only runs when assertions are enabled.
///
/// The patch embedding is built from [patchSize], [inChannels] and the first
/// entry of [embedDims]; the final [LayerNorm] is built from the last entry.
///
/// [imageSize] and [numClasses] are part of the signature but are not read
/// anywhere in this constructor.
SwinTransformer({
  required int imageSize, // e.g., 224 (currently unused here)
  required this.patchSize, // e.g., 4
  required int inChannels, // e.g., 3
  required List<int> embedDims, // e.g., [96, 192, 384, 768]
  required List<int> depths, // e.g., [2, 2, 6, 2] - num blocks per stage
  required List<int> numHeads, // e.g., [3, 6, 12, 24] - num heads per stage
  required int windowSize, // e.g., 7
  required int numClasses, // Classification head size (currently unused here)
})  : assert(depths.isNotEmpty, 'depths must contain at least one stage'),
      assert(
        embedDims.length == depths.length,
        'embedDims must have exactly one entry per stage in depths',
      ),
      assert(
        numHeads.length == depths.length,
        'numHeads must have exactly one entry per stage in depths',
      ),
      patchEmbedding = PatchEmbedding(
        patchSize: patchSize,
        inChannels: inChannels,
        embedDim: embedDims[0],
      ),
      stages = List.generate(depths.length, (stageIndex) {
        final isLastStage = stageIndex == depths.length - 1;
        // NOTE(review): the first stage gets `inDimForMerging: null` while
        // `doPatchMerging` can be true, and the last stage gets a non-null
        // `inDimForMerging` with `doPatchMerging: false`. Confirm against
        // SwinStage that this pairing is what it expects.
        return SwinStage(
          embedDim: embedDims[stageIndex],
          depth: depths[stageIndex],
          numHeads: numHeads[stageIndex],
          windowSize: windowSize,
          doPatchMerging: !isLastStage, // No merging after the last stage
          // Previous stage's embed dim; null for the first stage.
          inDimForMerging: stageIndex > 0 ? embedDims[stageIndex - 1] : null,
        );
      }),
      finalNorm = LayerNorm(embedDims.last);