forward method

List<Value> forward(List<double> imageData)

The forward pass for the Vision Transformer.

Takes a flattened list of image pixel data and returns the classification logits for that single image as a list of Value objects.
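The private helper _createPatchesAndEmbeddings is not shown on this page, so the exact pixel layout it expects is not documented here. The following is a minimal sketch under stated assumptions: a square, single-channel image flattened in row-major order, and a patch size that divides the image side evenly. The names splitIntoPatches, imageSize, and patchSize are hypothetical, and the sketch only cuts the image into patches; it does not perform the learned projection into embedSize dimensions that a real patch embedding would apply.

```dart
/// Hypothetical helper: splits a row-major, single-channel, square image
/// into non-overlapping patchSize x patchSize patches, scanning patches
/// left to right, top to bottom. Not part of the documented API.
List<List<double>> splitIntoPatches(
    List<double> imageData, int imageSize, int patchSize) {
  if (imageData.length != imageSize * imageSize) {
    throw ArgumentError(
        'expected ${imageSize * imageSize} pixels, got ${imageData.length}');
  }
  if (imageSize % patchSize != 0) {
    throw ArgumentError('patchSize must divide imageSize');
  }
  final patchesPerSide = imageSize ~/ patchSize;
  final patches = <List<double>>[];
  for (var pr = 0; pr < patchesPerSide; pr++) {
    for (var pc = 0; pc < patchesPerSide; pc++) {
      final patch = <double>[];
      // Copy the pixels of this patch, row by row.
      for (var r = 0; r < patchSize; r++) {
        final row = pr * patchSize + r;
        for (var c = 0; c < patchSize; c++) {
          final col = pc * patchSize + c;
          patch.add(imageData[row * imageSize + col]);
        }
      }
      patches.add(patch);
    }
  }
  return patches;
}

void main() {
  // A 4x4 image holding the values 0..15 in row-major order.
  final image = List<double>.generate(16, (i) => i.toDouble());
  final patches = splitIntoPatches(image, 4, 2);
  print(patches.length); // 4
  print(patches[0]); // [0.0, 1.0, 4.0, 5.0]
  // Prepending the [CLS] token adds one more position to the sequence.
  print(patches.length + 1); // 5
}
```

Under these assumptions, an image of side imageSize yields (imageSize / patchSize)^2 patches, so the sequence fed to the encoder has that many positions plus one for the [CLS] token, and positionEmbeddings must cover at least that many positions.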

Implementation

List<Value> forward(List<double> imageData) {
  // 1. Create patch embeddings
  final patchEmbeddings = _createPatchesAndEmbeddings(imageData);

  // 2. Prepend the learnable [CLS] token
  // Wrap each clsToken value in a new Value node (with the original value as its
  // child) so that the [CLS] token is part of the current computation graph;
  // otherwise its gradient would not be calculated when the same object is reused
  // on every call.
  final currentClsToken = ValueVector(List.generate(
      embedSize,
      (i) =>
          Value(clsToken.values[i].data, {clsToken.values[i]}, 'cls_copy')));
  final sequence = [currentClsToken, ...patchEmbeddings];

  // 3. Add positional embeddings
  final sequenceWithPositionalEmbeddings =
      List.generate(sequence.length, (i) {
    return sequence[i] + positionEmbeddings[i];
  });

  // 4. Pass the sequence through the Transformer Encoder
  final encodedFeatures =
      transformerEncoder.forwardEmbeddings(sequenceWithPositionalEmbeddings);

  // 5. Take the output corresponding to the [CLS] token (first element)
  final clsOutput = encodedFeatures[0];

  // 6. Pass through the MLP head for classification logits
  final logits = mlpHead.forward(clsOutput);

  // Return the logits as a list of Value objects for a single prediction.
  return logits.values;
}
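Steps 2 to 5 above can be illustrated without the autograd machinery. The following is a minimal sketch that uses plain doubles instead of Value and ValueVector; buildInputSequence is a hypothetical name, and an identity function stands in for transformerEncoder, so the numbers only show how the [CLS] token, patch embeddings, and positional embeddings line up, not what a trained model would produce.

```dart
/// Hypothetical helper: prepends a [CLS] vector to the patch embeddings
/// (step 2) and adds the matching positional embedding to every position
/// (step 3), using plain doubles instead of autograd Value objects.
List<List<double>> buildInputSequence(
  List<double> clsToken,
  List<List<double>> patchEmbeddings,
  List<List<double>> positionEmbeddings,
) {
  final sequence = [clsToken, ...patchEmbeddings];
  if (positionEmbeddings.length < sequence.length) {
    throw ArgumentError('need one positional embedding per sequence position');
  }
  // Element-wise sum of each token with its positional embedding.
  return List.generate(sequence.length, (i) {
    final token = sequence[i];
    final pos = positionEmbeddings[i];
    return List.generate(token.length, (j) => token[j] + pos[j]);
  });
}

void main() {
  final cls = [0.0, 0.0];
  final patches = [
    [1.0, 2.0],
    [3.0, 4.0],
  ];
  final positions = [
    [0.5, 0.5],
    [0.1, 0.1],
    [0.2, 0.2],
  ];
  final input = buildInputSequence(cls, patches, positions);
  // Placeholder: an identity function stands in for the Transformer encoder.
  final encoded = input;
  // Step 5: the output at index 0 corresponds to the [CLS] token.
  final clsOutput = encoded[0];
  print(input.length); // 3
  print(clsOutput); // [0.5, 0.5]
}
```

In the real method, only the vector at index 0 reaches mlpHead; because self-attention lets the [CLS] position attend to every patch, that single vector can summarize the whole image for classification.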