forward method
Forward pass for the object detection head.
Takes a single ValueVector representing the aggregated image feature
(e.g., the CLS token output from the ViT backbone).
Returns a Map containing:
- 'boxes': List of numQueries ValueVectors, each of 4 bounding box coordinates
- 'logits': List of numQueries ValueVectors, each of (numClasses + 1) class logits
Implementation
/// Runs the detection head on a single backbone feature vector.
///
/// Feeds [backboneFeature] through [bboxRegressionHead] and
/// [classPredictionHead], then cuts each flat output into [numQueries]
/// consecutive slices: [numBoxCoords] values per slice for the boxes and
/// `numClasses + 1` values per slice for the class logits.
///
/// Returns a map holding the per-query slices under the keys `'boxes'` and
/// `'logits'`.
Map<String, List<ValueVector>> forward(ValueVector backboneFeature) {
  // Each head emits one flat vector that covers every query.
  final ValueVector flatBoxes = bboxRegressionHead.forward(backboneFeature);
  final ValueVector flatLogits = classPredictionHead.forward(backboneFeature);

  // Query q owns the half-open range [q * width, (q + 1) * width) of the
  // corresponding flat vector.
  final boxesPerQuery = <ValueVector>[
    for (var q = 0; q < numQueries; q++)
      ValueVector(
        flatBoxes.values.sublist(q * numBoxCoords, (q + 1) * numBoxCoords),
      ),
  ];
  final logitsPerQuery = <ValueVector>[
    for (var q = 0; q < numQueries; q++)
      ValueVector(
        flatLogits.values
            .sublist(q * (numClasses + 1), (q + 1) * (numClasses + 1)),
      ),
  ];

  return {'boxes': boxesPerQuery, 'logits': logitsPerQuery};
}