VisionTransformer constructor

VisionTransformer({
  required int imageSize,
  required int patchSize,
  int numChannels = 3,
  required int embedSize,
  required int numClasses,
  int numLayers = 2,
  int numHeads = 4,
})

Implementation

/// Creates a Vision Transformer for square images of side [imageSize],
/// split into square patches of side [patchSize].
///
/// Parameters:
/// - [imageSize]: side length of the input image; must be a positive
///   multiple of [patchSize].
/// - [patchSize]: side length of one patch; must be positive.
/// - [numChannels]: channels per pixel (defaults to 3).
/// - [embedSize]: width of every token embedding; must be a positive
///   multiple of [numHeads].
/// - [numClasses]: output size of the classification head.
/// - [numLayers], [numHeads]: forwarded to the [TransformerEncoder] backbone.
///
/// The argument checks are `assert`s, so they only run when assertions are
/// enabled (e.g. debug builds and tests).
VisionTransformer({
  required this.imageSize,
  required this.patchSize,
  this.numChannels = 3,
  required this.embedSize,
  required this.numClasses,
  this.numLayers = 2, // Reduced for faster example execution
  this.numHeads = 4, // Reduced for faster example execution
})  :
      // Dart's `%` never yields a negative result, so the divisibility checks
      // below would accept zero/negative sizes (e.g. 224 % -16 == 0).
      // Reject those explicitly first.
      assert(imageSize > 0 && patchSize > 0,
          "Image size and patch size must be positive"),
      assert(imageSize % patchSize == 0,
          "Image size must be divisible by patch size"),
      assert(embedSize > 0 && numHeads > 0,
          "Embed size and numHeads must be positive"),
      assert(embedSize % numHeads == 0,
          "Embed size must be divisible by numHeads"),
      // Patch embedding converts (patch_size * patch_size * num_channels) into embedSize
      patchProjection =
          Layer.fromNeurons(patchSize * patchSize * numChannels, embedSize),
      // Initialize CLS token as a learnable vector
      clsToken = _randomEmbedding(embedSize),
      // One position embedding per patch, plus one for the [CLS] token
      positionEmbeddings = List.generate(
          _sequenceLength(imageSize, patchSize),
          (_) => _randomEmbedding(embedSize)),
      // The TransformerEncoder is used as the backbone
      // Its blockSize must match the sequence length (patches + CLS token)
      transformerEncoder = TransformerEncoder(
        vocabSize: 0, // Not used, as embeddings are provided directly
        embedSize: embedSize,
        blockSize: _sequenceLength(imageSize, patchSize),
        numLayers: numLayers,
        numHeads: numHeads,
      ),
      // Classification head
      mlpHead = Layer.fromNeurons(embedSize, numClasses);

/// Number of tokens fed to the encoder: one per patch plus the [CLS] token.
///
/// Example: 224 / 16 = 14 patches per side -> 14 * 14 = 196 patches, so 197.
static int _sequenceLength(int imageSize, int patchSize) {
  final patchesPerSide = imageSize ~/ patchSize;
  return patchesPerSide * patchesPerSide + 1;
}

/// Builds a vector of [size] values drawn uniformly from [-0.01, 0.01),
/// using a single random generator for the whole vector.
static ValueVector _randomEmbedding(int size) {
  final rng = math.Random();
  return ValueVector.fromDoubleList(
      List.generate(size, (_) => rng.nextDouble() * 0.02 - 0.01));
}