main function

void main()

Implementation

void main() {
  print("--- Vision Transformer Example ---");

  // Model parameters
  final imageSize = 32; // Example: a small 32x32 image
  final patchSize = 8; // Patches will be 8x8 pixels
  final numChannels = 3; // RGB image
  final embedSize = 64; // Transformer embedding dimension
  final numClasses = 10; // e.g., for the CIFAR-10 dataset
  final numLayers = 2; // Small number of layers for quick execution
  final numHeads = 4; // Number of attention heads
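  // With these settings the image is split into (32 / 8)^2 = 16 patches, and
  // each 8x8 patch covers 8 * 8 * 3 = 192 raw values, which are assumed to be
  // projected down to the 64-dimensional embedding space inside the model.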

  // Instantiate the ViT model
  final vit = VisionTransformer(
    imageSize: imageSize,
    patchSize: patchSize,
    numChannels: numChannels,
    embedSize: embedSize,
    numClasses: numClasses,
    numLayers: numLayers,
    numHeads: numHeads,
  );

  final optimizer = SGD(vit.parameters(), 0.01);
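  // Plain stochastic gradient descent: each step() call is assumed to nudge
  // every parameter by -learningRate * gradient (learning rate 0.01 here).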

  // --- Dummy Image Data ---
  // A flattened list of pixel values (0.0 to 1.0)
  // Size: imageSize * imageSize * numChannels
  final int totalPixels = imageSize * imageSize * numChannels;
  final Random random = Random();

  // Create dummy image data and a random target class for training
  final List<double> dummyImageData =
      List.generate(totalPixels, (i) => random.nextDouble());
  final int dummyTargetClass =
      random.nextInt(numClasses); // A random target class

  print(
      "Dummy Image Data created (first 10 values): ${dummyImageData.sublist(0, 10).map((v) => v.toStringAsFixed(2)).toList()}...");
  print("Target Class: $dummyTargetClass");

  // --- Training Loop (Simplified) ---
  final epochs = 50; // Number of passes over the single dummy example
  print("\nTraining ViT for $epochs epochs...");

  for (int epoch = 0; epoch < epochs; epoch++) {
    // 1. Forward pass
    final logits =
        vit.forward(dummyImageData); // Returns a List<Value>, one logit per class

    // 2. Calculate Cross-Entropy Loss
    // Convert target class to one-hot vector (using Value objects)
    final targetVector = ValueVector(List.generate(
      numClasses,
      (i) => Value(i == dummyTargetClass ? 1.0 : 0.0),
    ));

    // Convert logits (List<Value>) to ValueVector for softmax and crossEntropy
    final logitsVector = ValueVector(logits);

    final loss = logitsVector.softmax().crossEntropy(targetVector);
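    // With a one-hot target, this cross-entropy reduces to -log(p_target),
    // i.e. the negative log of the probability assigned to the true class.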

    // 3. Backward pass and optimization step
    vit.zeroGrad(); // Clear gradients
    loss.backward(); // Compute gradients
    optimizer.step(); // Update parameters

    if (epoch % 10 == 0 || epoch == epochs - 1) {
      print("Epoch $epoch | Loss: ${loss.data.toStringAsFixed(4)}");
    }
  }
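  // Because the model is fit repeatedly to a single example, the loss should
  // fall steadily toward zero.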
  print("✅ ViT training complete.");

  // --- Inference Example ---
  print("\n--- ViT Inference ---");
  final List<double> newDummyImageData = List.generate(
      totalPixels, (i) => random.nextDouble()); // A new random image

  print(
      "New Dummy Image Data created (first 10 values): ${newDummyImageData.sublist(0, 10).map((v) => v.toStringAsFixed(2)).toList()}...");

  final inferenceLogits = vit.forward(newDummyImageData);
  final predictedProbs = ValueVector(inferenceLogits).softmax();
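  // softmax turns the raw logits into a probability distribution
  // (non-negative values summing to 1), so the largest entry is the prediction.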

  // Find the predicted class (index with highest probability)
  double maxProb = -1.0;
  int predictedClass = -1;
  for (int i = 0; i < predictedProbs.values.length; i++) {
    if (predictedProbs.values[i].data > maxProb) {
      maxProb = predictedProbs.values[i].data;
      predictedClass = i;
    }
  }

  print(
      "Inference Logits: ${inferenceLogits.map((v) => v.data.toStringAsFixed(4)).toList()}");
  print(
      "Predicted Probabilities: ${predictedProbs.values.map((v) => v.data.toStringAsFixed(4)).toList()}");
  print(
      "Predicted Class: $predictedClass (with probability ${maxProb.toStringAsFixed(4)})");

  print(
      "\nNote: For real-world usage, the `_createPatchesAndEmbeddings` function within ViT would need robust image processing to handle actual image files and their pixel layouts (e.g., using the `image` package or custom parsing).");
}
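As a rough sketch of the preprocessing that the closing note alludes to, the helper below turns raw 8-bit interleaved RGB pixel data (for instance, bytes obtained by decoding an image with the `image` package) into the flattened, 0.0-1.0 `List<double>` that `vit.forward` consumes above. The helper name `flattenRgbToInput` and the assumed row-major RGB layout are illustrative and not part of the library.

import 'dart:typed_data';

/// Hypothetical helper: converts row-major, interleaved 8-bit RGB bytes
/// into the flat, normalized input used by `vit.forward` in the example.
List<double> flattenRgbToInput(
    Uint8List rgbBytes, int imageSize, int numChannels) {
  final expected = imageSize * imageSize * numChannels;
  if (rgbBytes.length != expected) {
    throw ArgumentError('Expected $expected bytes, got ${rgbBytes.length}.');
  }
  // Scale each 0-255 byte to a 0.0-1.0 double.
  return List<double>.generate(expected, (i) => rgbBytes[i] / 255.0);
}

// Example usage (hypothetical): vit.forward(flattenRgbToInput(bytes, 32, 3));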