/// Entry point: ViT-based face detection and recognition example.
void main() {
  print("--- ViT-based Face Detection and Recognition Example ---");

  // --- Model hyperparameters ---
  final imageSize = 32; // Example: small 32x32 image.
  final patchSize = 8; // Patches will be 8x8 pixels.
  final numChannels = 3; // RGB image.
  final embedSize = 64; // Transformer embedding dimension.
  final numIdentities =
      5; // Number of distinct people/identities to recognize (classes 0 to 4).
  final numLayers = 2; // Transformer encoder depth.
  final numHeads = 4; // Attention heads per encoder layer.
  final numQueries =
      5; // Fixed number of object predictions the model will output.
  final embeddingDim = 128; // Dimension of the face embedding for recognition.

  print("Model Configuration:");
  print(" Image Size: $imageSize x $imageSize");
  print(" Patch Size: $patchSize x $patchSize");
  print(" Embed Size: $embedSize");
  print(" Num Identities (Classes): $numIdentities");
  print(" Num Queries (Max Objects Predicted): $numQueries");
  print(" Embedding Dimension: $embeddingDim");

  // Instantiate the ViTObjectDetector model (handles face detection +
  // recognition; emits bboxes, class logits, and identity embeddings).
  final faceDetectorRecognizer = ViTObjectDetector(
    imageSize: imageSize,
    patchSize: patchSize,
    numChannels: numChannels,
    embedSize: embedSize,
    numLayers: numLayers,
    numHeads: numHeads,
    numClasses: numIdentities, // Pass numIdentities as numClasses.
    numQueries: numQueries,
    embeddingDim: embeddingDim,
  );
  final optimizer = SGD(faceDetectorRecognizer.parameters(), 0.01);

  // --- Dummy image data and ground truth ---
  final int totalPixels = imageSize * imageSize * numChannels;
  final Random random = Random();

  // Dummy image: flat list of random pixel values in [0, 1).
  final List<double> dummyImageData =
      List.generate(totalPixels, (i) => random.nextDouble());

  // Dummy ground truth for MULTIPLE faces. Each map represents one face:
  // {'bbox': [x,y,w,h], 'class_id': int (identity), 'embedding': List<double>}
  final List<Map<String, dynamic>> gtObjects = [
    {
      'bbox': [0.1, 0.1, 0.2, 0.2],
      'class_id': random.nextInt(numIdentities),
      'embedding': List.generate(
          embeddingDim, (i) => random.nextDouble() * 2 - 1) // Random embedding.
    },
    {
      'bbox': [0.5, 0.5, 0.3, 0.3],
      'class_id': random.nextInt(numIdentities),
      'embedding':
          List.generate(embeddingDim, (i) => random.nextDouble() * 2 - 1)
    },
    {
      'bbox': [0.8, 0.2, 0.15, 0.25],
      'class_id': random.nextInt(numIdentities),
      'embedding':
          List.generate(embeddingDim, (i) => random.nextDouble() * 2 - 1)
    },
  ];

  print(
      "Dummy Image Data created (first 10 values): ${dummyImageData.sublist(0, 10).map((v) => v.toStringAsFixed(2)).toList()}...");
  print(
      "Ground Truth Objects: ${gtObjects.map((obj) => 'Bbox: ${obj['bbox'].map((v) => v.toStringAsFixed(2)).toList()}, Class: ${obj['class_id']}, Embedding (first 3): ${obj['embedding'].sublist(0, 3).map((v) => v.toStringAsFixed(2)).toList()}...').toList()}");

  // Computes the pairwise cost between one predicted object and one ground
  // truth face; used to fill the cost matrix for bipartite matching.
  // Cost = mean-L1 bbox distance + NLL of the true class + 0.5 * mean-L1
  // embedding distance. [numIdentities]/[embeddingDim] are captured from the
  // enclosing scope (the former explicit parameters were redundant — one of
  // them was entirely unused).
  Value calculatePairwiseCost(
      ValueVector predBbox,
      ValueVector predLogits,
      ValueVector predEmbedding,
      List<double> gtBbox,
      int gtClassId,
      List<double> gtEmbedding) {
    // Bounding box cost (average L1 distance over [x, y, w, h]).
    Value bboxCost = Value(0.0);
    for (int i = 0; i < 4; i++) {
      bboxCost += (predBbox.values[i] - Value(gtBbox[i])).abs();
    }
    bboxCost = bboxCost / Value(4.0);

    // Classification cost (negative log-likelihood of the true class).
    final List<Value> logProbs =
        predLogits.softmax().values.map((v) => v.log()).toList();
    if (gtClassId >= logProbs.length || gtClassId < 0) {
      // Invalid class ID: an infinite cost should keep this pair from ever
      // being selected by the matcher (assumes _hungarianAlgorithm tolerates
      // infinity — verify).
      return Value(double.infinity);
    }
    final Value classCost = -logProbs[gtClassId];

    // Embedding cost (average L1 distance, for recognition).
    Value embeddingCost = Value(0.0);
    for (int i = 0; i < embeddingDim; i++) {
      embeddingCost += (predEmbedding.values[i] - Value(gtEmbedding[i])).abs();
    }
    embeddingCost = embeddingCost / Value(embeddingDim.toDouble());

    // Total cost: weighted sum. Adjust weights as needed for training balance.
    final Value totalPairCost = bboxCost * Value(1.0) +
        classCost * Value(1.0) +
        embeddingCost * Value(0.5);
    return totalPairCost;
  }

  // --- Training loop with conceptual Hungarian matching (DETR-style) ---
  final epochs = 400; // Increased epochs for more complex task.
  print("\nTraining Face Detector and Recognizer for $epochs epochs...");
  for (int epoch = 0; epoch < epochs; epoch++) {
    // 1. Forward pass: one set of predictions per query slot.
    final Map<String, List<ValueVector>> predictions =
        faceDetectorRecognizer.forward(dummyImageData);
    final List<ValueVector> predictedBboxes = predictions['boxes']!;
    final List<ValueVector> predictedLogits = predictions['logits']!;
    final List<ValueVector> predictedEmbeddings = predictions['embeddings']!;

    // 2. Build the (numQueries x numGroundTruth) cost matrix.
    final List<List<Value>> costMatrix = List.generate(
        numQueries, (_) => List.generate(gtObjects.length, (_) => Value(0.0)));
    for (int pIdx = 0; pIdx < numQueries; pIdx++) {
      for (int gIdx = 0; gIdx < gtObjects.length; gIdx++) {
        costMatrix[pIdx][gIdx] = calculatePairwiseCost(
          predictedBboxes[pIdx],
          predictedLogits[pIdx],
          predictedEmbeddings[pIdx],
          gtObjects[gIdx]['bbox'] as List<double>,
          gtObjects[gIdx]['class_id'] as int,
          gtObjects[gIdx]['embedding'] as List<double>,
        );
      }
    }

    // 3. Bipartite matching (conceptual Hungarian algorithm):
    //    maps predicted index -> ground truth index.
    final Map<int, int> assignments = _hungarianAlgorithm(costMatrix);

    // 4. Loss for matched prediction/ground-truth pairs.
    Value totalLoss = Value(0.0);
    final Set<int> matchedPredIndices = assignments.keys.toSet();
    for (var entry in assignments.entries) {
      final int predIdx = entry.key;
      final int gtIdx = entry.value;
      final ValueVector currentPredictedBbox = predictedBboxes[predIdx];
      final ValueVector currentPredictedLogits = predictedLogits[predIdx];
      final ValueVector currentPredictedEmbedding = predictedEmbeddings[predIdx];
      final List<double> currentGtBboxCoords =
          gtObjects[gtIdx]['bbox'] as List<double>;
      final int currentGtClassId = gtObjects[gtIdx]['class_id'] as int;
      final List<double> currentGtEmbedding =
          gtObjects[gtIdx]['embedding'] as List<double>;

      // Bounding box loss (average L1).
      Value bboxLoss = Value(0.0);
      for (int i = 0; i < 4; i++) {
        bboxLoss +=
            (currentPredictedBbox.values[i] - Value(currentGtBboxCoords[i]))
                .abs();
      }
      bboxLoss = bboxLoss / Value(4.0);

      // Classification loss (cross-entropy against a one-hot target over
      // numIdentities real classes + 1 background class).
      final gtClassVector = ValueVector(List.generate(
        numIdentities + 1,
        (i) => Value(i == currentGtClassId ? 1.0 : 0.0),
      ));
      final classLoss =
          currentPredictedLogits.softmax().crossEntropy(gtClassVector);

      // Embedding loss (average L1, drives recognition).
      Value embeddingLoss = Value(0.0);
      for (int i = 0; i < embeddingDim; i++) {
        embeddingLoss +=
            (currentPredictedEmbedding.values[i] - Value(currentGtEmbedding[i]))
                .abs();
      }
      embeddingLoss = embeddingLoss / Value(embeddingDim.toDouble());

      totalLoss += bboxLoss + classLoss + embeddingLoss;
    }

    // Unmatched predictions should classify as background; no bbox or
    // embedding loss applies to them.
    for (int pIdx = 0; pIdx < numQueries; pIdx++) {
      if (!matchedPredIndices.contains(pIdx)) {
        final ValueVector currentPredictedLogits = predictedLogits[pIdx];
        // Target is the background class (ID == numIdentities).
        final gtBackgroundClassVector = ValueVector(List.generate(
          numIdentities + 1,
          (i) => Value(i == numIdentities ? 1.0 : 0.0),
        ));
        final backgroundClassLoss = currentPredictedLogits
            .softmax()
            .crossEntropy(gtBackgroundClassVector);
        totalLoss += backgroundClassLoss;
      }
    }

    // 5. Backward pass and optimization step.
    faceDetectorRecognizer.zeroGrad(); // Clear gradients.
    totalLoss.backward(); // Compute gradients.
    optimizer.step(); // Update parameters.

    if (epoch % 20 == 0 || epoch == epochs - 1) {
      print("Epoch $epoch | Total Loss: ${totalLoss.data.toStringAsFixed(4)}");
    }
  }
  print("✅ Face Detector and Recognizer training complete.");

  // --- Inference example ---
  print("\n--- Face Detector and Recognizer Inference ---");
  final List<double> newDummyImageData =
      List.generate(totalPixels, (i) => random.nextDouble()); // A new random image.
  print(
      "New Dummy Image Data created (first 10 values): ${newDummyImageData.sublist(0, 10).map((v) => v.toStringAsFixed(2)).toList()}...");
  final Map<String, List<ValueVector>> inferencePredictions =
      faceDetectorRecognizer.forward(newDummyImageData);
  final List<ValueVector> inferredBboxes = inferencePredictions['boxes']!;
  final List<ValueVector> inferredLogits = inferencePredictions['logits']!;
  final List<ValueVector> inferredEmbeddings = inferencePredictions['embeddings']!;

  // --- Simulate a database of known face embeddings ---
  // In a real system, these would be pre-computed embeddings of known
  // individuals; here each identity gets a random embedding for the demo.
  final Map<int, List<double>> knownFaceDatabase = {};
  for (int i = 0; i < numIdentities; i++) {
    knownFaceDatabase[i] =
        List.generate(embeddingDim, (j) => random.nextDouble() * 2 - 1);
  }
  print("\nSimulated Known Face Database (first 3 values of each embedding):");
  knownFaceDatabase.forEach((id, emb) {
    print(
        " Identity $id: ${emb.sublist(0, 3).map((v) => v.toStringAsFixed(2)).toList()}...");
  });

  print("\nInferred Faces:");
  for (int q = 0; q < numQueries; q++) {
    final ValueVector currentInferredBbox = inferredBboxes[q];
    final ValueVector currentInferredLogits = inferredLogits[q];
    final ValueVector currentInferredProbs = currentInferredLogits.softmax();
    final ValueVector currentInferredEmbedding = inferredEmbeddings[q];

    // Argmax over class probabilities (identity or background).
    double maxProb = -1.0;
    int predictedClass = -1;
    for (int i = 0; i < currentInferredProbs.values.length; i++) {
      if (currentInferredProbs.values[i].data > maxProb) {
        maxProb = currentInferredProbs.values[i].data;
        predictedClass = i;
      }
    }

    // A face is "detected" when the argmax is not background and confident.
    if (predictedClass != numIdentities && maxProb > 0.5) {
      print(" Predicted Face ${q + 1}:");
      print(
          " Bbox: ${currentInferredBbox.values.map((v) => v.data.toStringAsFixed(4)).toList()}");
      print(
          " Detection Class: $predictedClass (Prob: ${maxProb.toStringAsFixed(4)})");
      // Take the first 3 values BEFORE formatting so we do not stringify the
      // entire embedding just to print a preview (same output as before).
      print(
          " Embedding (first 3): ${currentInferredEmbedding.values.sublist(0, 3).map((v) => v.data.toStringAsFixed(4)).toList()}...");

      // --- Face recognition: nearest neighbor in the known-face database
      // under average L1 distance. ---
      double minDistance = double.infinity;
      int recognizedIdentity = -1;
      knownFaceDatabase.forEach((identityId, knownEmbedding) {
        double currentDistance = 0.0;
        for (int i = 0; i < embeddingDim; i++) {
          currentDistance +=
              (currentInferredEmbedding.values[i].data - knownEmbedding[i])
                  .abs(); // L1 distance.
        }
        currentDistance /= embeddingDim; // Average distance.
        if (currentDistance < minDistance) {
          minDistance = currentDistance;
          recognizedIdentity = identityId;
        }
      });

      // Accept the nearest identity only below this distance threshold.
      const double recognitionThreshold = 0.5; // Example threshold (needs tuning).
      if (minDistance < recognitionThreshold) {
        print(
            " Recognized as Identity: $recognizedIdentity (Distance: ${minDistance.toStringAsFixed(4)})");
      } else {
        print(
            " Identity: Unknown (Closest: $recognizedIdentity, Distance: ${minDistance.toStringAsFixed(4)})");
      }
    } else if (predictedClass == numIdentities && maxProb > 0.5) {
      print(
          " Predicted Object ${q + 1}: Background (Prob: ${maxProb.toStringAsFixed(4)})");
    } else {
      print(
          " Predicted Object ${q + 1}: Low confidence prediction (Class: $predictedClass, Prob: ${maxProb.toStringAsFixed(4)}) - Likely background or noise");
    }
  }

  print(
      "\nNote: This example demonstrates face detection and recognition conceptually. "
      "Real-world systems use large, diverse datasets, advanced metric learning losses "
      "(e.g., Triplet Loss, ArcFace), more robust matching algorithms (Hungarian), "
      "and sophisticated post-processing (NMS) for accurate results.");
}