main function

void main()

Implementation

// dart:math is needed for Random below; Value, ValueVector, and
// SwinTransformer are assumed to come from the surrounding library.
import 'dart:math' as math;

void main() {
  print("--- Swin Transformer Example Usage ---");

  // 1. Define model hyperparameters (Swin-T values; the S/B/L configs differ in dims, depths, and heads)
  const int imageSize = 224; // e.g., 224x224 input image
  const int patchSize = 4;   // e.g., 4x4 initial patches
  const int inChannels = 3;  // RGB image
  const int windowSize = 7;  // Window size for attention (e.g., 7x7)

  // Swin-T (Tiny) configuration for this example.
  // These lists define the embedding dimension, depth (number of blocks), and
  // number of attention heads for each of the four stages.
  const List<int> embedDims = [96, 192, 384, 768]; // C, 2C, 4C, 8C
  const List<int> depths = [2, 2, 6, 2];           // blocks per stage
  const List<int> numHeads = [3, 6, 12, 24];       // attention heads per stage
  const int numClasses = 1000;                    // Example: For ImageNet classification
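
  // A minimal configuration sanity check. The divisibility requirements are
  // assumptions about this implementation: the image must tile evenly into
  // patches, and the stage-1 grid into windows (the reference Swin pads instead).
  assert(embedDims.length == depths.length && depths.length == numHeads.length);
  assert(imageSize % patchSize == 0);
  assert((imageSize ~/ patchSize) % windowSize == 0);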

  // Calculate the initial number of patches (height and width)
  final int initialNumPatchesH = imageSize ~/ patchSize;
  final int initialNumPatchesW = imageSize ~/ patchSize;
  final int totalInitialPatches = initialNumPatchesH * initialNumPatchesW;

  print("Model Configuration:");
  print("  Image Size: $imageSize x $imageSize");
  print("  Patch Size: $patchSize x $patchSize");
  print("  Initial Patch Grid: $initialNumPatchesH x $initialNumPatchesW ($totalInitialPatches patches)");
  print("  Window Size: $windowSize x $windowSize");
  print("  Embed Dims (Stages): $embedDims");
  print("  Depths (Blocks per Stage): $depths");
  print("  Num Heads (per Stage): $numHeads");
  print("  Number of Classes (for classification head): $numClasses");

  // 2. Instantiate the SwinTransformer model
  print("\nInitializing SwinTransformer...");
  final model = SwinTransformer(
    imageSize: imageSize,
    patchSize: patchSize,
    inChannels: inChannels,
    embedDims: embedDims,
    depths: depths,
    numHeads: numHeads,
    windowSize: windowSize,
    numClasses: numClasses,
  );
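  // `parameters()` is assumed to return a flat list of trainable scalar
  // `Value`s (micrograd style), so `.length` below is the raw parameter count.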
  print("SwinTransformer initialized. Total parameters: ${model.parameters().length}");

  // 3. Prepare a dummy image input.
  // Simulate a flattened image in which every value is a Value (float).
  // A 224x224 RGB image has 224*224*3 = 150,528 values, and each 4x4 patch
  // flattens to 4*4*3 = 48 values. PatchEmbedding expects a List<ValueVector>,
  // where each ValueVector is one flattened patch.
  print("\nPreparing dummy image input...");
  final rng = math.Random(); // one shared RNG instead of one per value
  final List<ValueVector> dummyImagePatches = List.generate(
    totalInitialPatches, // number of patches
    (patchIdx) => ValueVector(
      List.generate(
        patchSize * patchSize * inChannels, // size of one flattened patch (48)
        (valIdx) => Value(rng.nextDouble() * 0.1), // small random values
      ),
    ),
  );
  print("Dummy image input created: ${dummyImagePatches.length} patches, each with ${dummyImagePatches[0].values.length} features.");

  // 4. Perform a forward pass
  print("\nPerforming forward pass...");
  try {
    final List<ValueVector> outputFeatures = model.forward(
      dummyImagePatches,
      imageSize,
      imageSize,
    );
    print("Forward pass successful!");

    // The output `outputFeatures` is the feature map from the last stage.
    // Its shape is (H_final_patches * W_final_patches, embedDims.last).
    // For Swin-T at 224x224 with patch size 4, patch merging halves the grid
    // between stages, i.e. depths.length - 1 = 3 times across the 4 stages:
    // H_final_patches = W_final_patches = 224 / 4 / 2^3 = 56 / 8 = 7
    // So we expect 7 * 7 = 49 tokens, each with embedDims.last (768) features.
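    // The token grid at every stage follows from the configuration alone
    // (assuming one 2x patch-merging step between consecutive stages):
    var stageGrid = imageSize ~/ patchSize;
    for (var s = 0; s < depths.length; s++) {
      print("  Stage ${s + 1}: $stageGrid x $stageGrid tokens, dim ${embedDims[s]}");
      if (s < depths.length - 1) stageGrid ~/= 2; // merged before the next stage
    }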
    print("Output Features:");
    print("  Number of output tokens (patches): ${outputFeatures.length}");
    if (outputFeatures.isNotEmpty) {
      print("  Feature dimension per token: ${outputFeatures[0].values.length}");
      print("  Expected final patch grid size: ${imageSize ~/ (patchSize * math.pow(2, depths.length -1))} x ${imageSize ~/ (patchSize * math.pow(2, depths.length -1))} = ${outputFeatures.length}");
      print("  Expected feature dimension: ${embedDims.last}");
    }

    // For classification you would typically add a head on top of
    // `outputFeatures`: global average pooling across the tokens, followed by
    // a linear layer projecting to `numClasses` logits (plus softmax).
    // Detection or segmentation tasks would attach different heads instead.
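    // A minimal sketch of such a head, commented out because `Value`
    // arithmetic operators and a `Linear` layer are assumptions about this
    // library rather than confirmed API:
    //
    //   // Global average pool: mean of each feature across all output tokens.
    //   final pooled = ValueVector(List.generate(embedDims.last, (d) {
    //     var sum = outputFeatures[0].values[d];
    //     for (var t = 1; t < outputFeatures.length; t++) {
    //       sum = sum + outputFeatures[t].values[d];
    //     }
    //     return sum * Value(1.0 / outputFeatures.length);
    //   }));
    //   final head = Linear(embedDims.last, numClasses); // hypothetical layer
    //   final logits = head.forward(pooled);             // then softmax, etc.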

  } catch (e) {
    print("An error occurred during forward pass: $e");
    print("Please ensure all dependencies (Value, ValueVector, Module, Layer, etc.) are correctly implemented and imported.");
  }

  print("\n--- End of Swin Transformer Example ---");
}