forward method
The forward pass for the Vision Transformer.
Takes a flattened list of image pixel data. Returns logits for image classification.
Implementation
/// Runs the Vision Transformer forward pass for a single image.
///
/// [imageData] is the flattened pixel data; it is handed to
/// [_createPatchesAndEmbeddings] to obtain the patch embeddings.
///
/// Returns the classification logits as a list of [Value] objects: the
/// `values` of the MLP head output computed from the encoder result at
/// index 0, which is where the CLS token was placed in the sequence.
List<Value> forward(List<double> imageData) {
  // Step 1: turn the raw pixel data into patch embeddings.
  final patches = _createPatchesAndEmbeddings(imageData);

  // Step 2: build a fresh copy of the learnable CLS token for this call.
  // Each copy carries the original value's data and lists the original as
  // its only child, so the copy is a new node in this call's graph.
  // NOTE(review): the intent is for gradients to reach `clsToken` through
  // these copies; confirm that Value propagates gradient through a node
  // constructed this way.
  Value linkedCopy(int index) =>
      Value(clsToken.values[index].data, {clsToken.values[index]}, 'cls_copy');
  final clsRow = ValueVector(List.generate(embedSize, linkedCopy));

  // The CLS copy goes first, followed by every patch embedding.
  final tokens = [clsRow, ...patches];

  // Step 3: add the positional embedding that shares each token's index.
  final positioned = [
    for (var i = 0; i < tokens.length; i++) tokens[i] + positionEmbeddings[i],
  ];

  // Step 4: run the whole sequence through the Transformer encoder.
  final encoded = transformerEncoder.forwardEmbeddings(positioned);

  // Steps 5-6: the first encoder output corresponds to the CLS token; feed
  // it to the MLP head and expose the resulting logits as Value objects.
  return mlpHead.forward(encoded[0]).values;
}