SwinTransformer constructor
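SwinTransformer({required int imageSize, required patchSize, required int inChannels, required List<int> embedDims, required List<int> depths, required List<int> numHeads, required int windowSize, required int numClasses})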
Implementation
/// Builds a Swin Transformer from per-stage hyperparameters.
///
/// Creates the [PatchEmbedding] from [patchSize], [inChannels] and the first
/// entry of [embedDims]; one [SwinStage] per entry of [depths]; and a final
/// [LayerNorm] sized by the last entry of [embedDims].
///
/// [embedDims], [depths] and [numHeads] are parallel lists indexed by stage.
/// They must all have the same length and [embedDims] must not be empty;
/// both conditions are checked with assertions (debug mode only).
///
/// [imageSize] and [numClasses] are accepted but are not referenced by this
/// constructor.
SwinTransformer({
  required int imageSize, // e.g., 224 (not used by this constructor)
  required this.patchSize, // e.g., 4
  required int inChannels, // e.g., 3
  required List<int> embedDims, // e.g., [96, 192, 384, 768]
  required List<int> depths, // e.g., [2, 2, 6, 2] - num blocks per stage
  required List<int> numHeads, // e.g., [3, 6, 12, 24] - num heads per stage
  required int windowSize, // e.g., 7
  required int numClasses, // For a classification head (not used by this constructor)
})  : assert(embedDims.isNotEmpty, 'embedDims must not be empty'),
      assert(
        depths.length == embedDims.length,
        'depths has ${depths.length} entries but embedDims has '
        '${embedDims.length}; expected one entry per stage in both',
      ),
      assert(
        numHeads.length == depths.length,
        'numHeads has ${numHeads.length} entries but depths has '
        '${depths.length}; expected one entry per stage in both',
      ),
      patchEmbedding = PatchEmbedding(
        patchSize: patchSize,
        inChannels: inChannels,
        embedDim: embedDims[0],
      ),
      stages = List.generate(depths.length, (i) {
        final isLastStage = (i == depths.length - 1);
        return SwinStage(
          embedDim: embedDims[i],
          depth: depths[i],
          numHeads: numHeads[i],
          windowSize: windowSize,
          // No merging after the last stage.
          doPatchMerging: !isLastStage,
          // Previous stage's embed dim; null for the first stage.
          // NOTE(review): stage 0 is built with doPatchMerging == true
          // (when there is more than one stage) but a null
          // inDimForMerging, while the last stage gets a non-null value
          // with merging disabled. Confirm this matches SwinStage's
          // expectations.
          inDimForMerging: i > 0 ? embedDims[i - 1] : null,
        );
      }),
      finalNorm = LayerNorm(embedDims.last);