VisionTransformer class

A Vision Transformer (ViT) model for image classification.

This model splits an image into fixed-size patches, linearly embeds each patch, adds position embeddings, and passes the resulting sequence through a Transformer encoder. The encoder output at the special CLS token position is then used for classification.
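The patch-splitting step can be sketched as follows. This is a minimal illustration, not the library's implementation: `extractPatches` is a hypothetical helper, and the channel-first, row-major pixel layout is an assumption that may not match how `forward` reads `imageData`.

```dart
/// Splits a flat image buffer into flattened square patches.
///
/// Hypothetical helper for illustration only.
/// Assumption: pixels are stored channel-first, then row-major, so that
/// index = c * imageSize * imageSize + y * imageSize + x.
List<List<double>> extractPatches(
  List<double> imageData, {
  required int imageSize,
  required int patchSize,
  int numChannels = 3,
}) {
  if (imageSize % patchSize != 0) {
    throw ArgumentError('imageSize must be divisible by patchSize');
  }
  if (imageData.length != numChannels * imageSize * imageSize) {
    throw ArgumentError('imageData has the wrong length');
  }

  final patchesPerSide = imageSize ~/ patchSize;
  final planeSize = imageSize * imageSize;
  final patches = <List<double>>[];

  // Walk the patch grid top-to-bottom, left-to-right.
  for (var py = 0; py < patchesPerSide; py++) {
    for (var px = 0; px < patchesPerSide; px++) {
      final patch = <double>[];
      // Flatten each patch as channel, then row, then column.
      for (var c = 0; c < numChannels; c++) {
        for (var dy = 0; dy < patchSize; dy++) {
          for (var dx = 0; dx < patchSize; dx++) {
            final y = py * patchSize + dy;
            final x = px * patchSize + dx;
            patch.add(imageData[c * planeSize + y * imageSize + x]);
          }
        }
      }
      patches.add(patch);
    }
  }
  return patches;
}
```

Under these assumptions, each returned patch has `patchSize * patchSize * numChannels` values, which would be the input width of the patch projection.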

Constructors

VisionTransformer({required int imageSize, required int patchSize, int numChannels = 3, required int embedSize, required int numClasses, int numLayers = 2, int numHeads = 4})
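The constructor arguments imply several derived sizes. The sketch below computes them with a hypothetical helper, `deriveShape`, which is not part of this library. The divisibility checks reflect common ViT practice and are assumptions; the source does not say which constraints the constructor enforces.

```dart
/// Sizes implied by a ViT configuration (hypothetical, for illustration).
class VitShape {
  final int numPatches; // patches per image
  final int patchDim; // values in one flattened patch
  final int sequenceLength; // patches plus one CLS token
  final int headDim; // embedding width per attention head

  const VitShape(
      this.numPatches, this.patchDim, this.sequenceLength, this.headDim);
}

VitShape deriveShape({
  required int imageSize,
  required int patchSize,
  int numChannels = 3,
  required int embedSize,
  int numHeads = 4,
}) {
  // Assumed constraint: the image tiles evenly into square patches.
  if (imageSize % patchSize != 0) {
    throw ArgumentError('imageSize must be divisible by patchSize');
  }
  // Assumed constraint: the embedding splits evenly across heads.
  if (embedSize % numHeads != 0) {
    throw ArgumentError('embedSize must be divisible by numHeads');
  }
  final perSide = imageSize ~/ patchSize;
  final numPatches = perSide * perSide;
  return VitShape(
    numPatches,
    patchSize * patchSize * numChannels,
    numPatches + 1,
    embedSize ~/ numHeads,
  );
}
```

For example, `deriveShape(imageSize: 32, patchSize: 8, embedSize: 64)` describes 16 patches of 192 values each, a sequence of 17 tokens including the CLS token, and 16 dimensions per head.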

Properties

clsToken ValueVector
final
embedSize int
final
hashCode int
The hash code for this object.
no setter, inherited
imageSize int
final
mlpHead Layer
final
numChannels int
final
numClasses int
final
numHeads int
final
numLayers int
final
patchProjection Layer
final
patchSize int
final
positionEmbeddings List<ValueVector>
final
runtimeType Type
A representation of the runtime type of the object.
no setter, inherited
transformerEncoder TransformerEncoder
final

Methods

forward(List<double> imageData) → List<Value>
The forward pass for the Vision Transformer.
noSuchMethod(Invocation invocation) → dynamic
Invoked when a nonexistent method or property is accessed.
inherited
parameters() → List<Value>
override
toString() → String
A string representation of this object.
inherited
zeroGrad() → void
inherited
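
As a rough sketch of how the result of `forward` might be turned into a prediction: the helpers below operate on plain `double` scores. They assume that `forward` yields one score per class and that each `Value` exposes its scalar in some way; neither detail is stated here, so `softmax` and `argmax` are hypothetical helpers rather than library functions.

```dart
import 'dart:math' as math;

/// Converts raw scores into probabilities that sum to 1.
/// Subtracting the maximum first keeps exp() from overflowing.
List<double> softmax(List<double> scores) {
  if (scores.isEmpty) throw ArgumentError('scores must not be empty');
  final maxScore = scores.reduce(math.max);
  final exps = [for (final s in scores) math.exp(s - maxScore)];
  final total = exps.reduce((a, b) => a + b);
  return [for (final e in exps) e / total];
}

/// Returns the index of the largest score (the first one on ties).
int argmax(List<double> scores) {
  if (scores.isEmpty) throw ArgumentError('scores must not be empty');
  var best = 0;
  for (var i = 1; i < scores.length; i++) {
    if (scores[i] > scores[best]) best = i;
  }
  return best;
}
```

Given per-class scores as doubles, `argmax(scores)` would give the predicted class index and `softmax(scores)` the corresponding probabilities.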

Operators

operator ==(Object other) → bool
The equality operator.
inherited