ViTBackbone constructor
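ViTBackbone({required imageSize, required patchSize, numChannels = 3, required embedSize, int numLayers = 4, int numHeads = 4})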
Implementation
/// Creates a ViT backbone: a patch projection layer, a [CLS] token, positional
/// embeddings, and a [TransformerEncoder].
///
/// The number of patches per side is `imageSize ~/ patchSize`, so the patch
/// grid is square and any remainder from the integer division is not counted.
/// The encoder's `blockSize` and the positional-embedding table both have
/// `patchesPerSide * patchesPerSide + 1` entries (the extra one is the [CLS]
/// slot).
///
/// [numLayers] and [numHeads] are only forwarded to the [TransformerEncoder];
/// this constructor does not store them.
ViTBackbone({
  required this.imageSize,
  required this.patchSize,
  this.numChannels = 3,
  required this.embedSize,
  int numLayers = 4,
  int numHeads = 4,
}) : transformerEncoder = TransformerEncoder(
        vocabSize: 1, // Not used for vision
        embedSize: embedSize,
        // Initializer lists cannot declare locals, so the sequence length is
        // spelled out here and recomputed in the body below.
        blockSize: (imageSize ~/ patchSize) * (imageSize ~/ patchSize) + 1,
        numLayers: numLayers,
        numHeads: numHeads,
      ) {
  final int patchesPerSide = imageSize ~/ patchSize;
  // Every patch plus one leading [CLS] position.
  final int sequenceLength = patchesPerSide * patchesPerSide + 1;
  // Length of one flattened patch: patchSize x patchSize x numChannels.
  final flatPatchLength = patchSize * patchSize * numChannels;

  // Projection from a flattened patch to the embedding width, with GELU
  // disabled (weights are set up inside Layer; the original note says on GPU).
  patchProjection = Layer(flatPatchLength, embedSize, useGelu: false);

  // Learnable [CLS] token of shape [1, embedSize].
  clsToken = Tensor.random([1, embedSize], scale: 0.02);

  // One positional embedding row per sequence position.
  posEmbeddings = Tensor.random([sequenceLength, embedSize], scale: 0.02);
}