tflite_next 0.0.1
A Flutter plugin to run TensorFlow Lite models on Android and iOS.
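To install, add the package to your pubspec.yaml:

dependencies:
  tflite_next: ^0.0.1

The complete example below loads the selected model, lets the user pick an image from the gallery, and draws the model output over the image: classification labels for MobileNet, bounding boxes for SSD MobileNet and Tiny YOLOv2, a segmentation overlay for DeepLab, and pose keypoints for PoseNet. The .tflite model and label files it references are assumed to be bundled as Flutter assets and declared in pubspec.yaml.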
import 'dart:io';
import 'dart:math';
import 'dart:typed_data';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:image_picker/image_picker.dart';
import 'package:tflite_next/tflite_next.dart';
import 'package:image/image.dart' as img;
void main() => runApp(const App());
const String mobile = "MobileNet";
const String ssd = "SSD MobileNet";
const String yolo = "Tiny YOLOv2";
const String deeplab = "DeepLab";
const String posenet = "PoseNet";
class App extends StatelessWidget {
const App({super.key});
@override
Widget build(BuildContext context) {
    return MaterialApp(debugShowCheckedModeBanner: false, home: const MyApp());
}
}
class MyApp extends StatefulWidget {
const MyApp({super.key});
@override
State<MyApp> createState() => _MyAppState();
}
class _MyAppState extends State<MyApp> {
File? _image;
List? _recognitions = [];
String _model = mobile;
  double _imageHeight = 0.0;
  double _imageWidth = 0.0;
  bool _busy = false;
  final ImagePicker _picker = ImagePicker();
  /// Functionality for picking an image from the gallery and triggering prediction.
Future<void> onPredictImagePicker() async {
XFile? image = await _picker.pickImage(source: ImageSource.gallery);
if (image == null) return;
setState(() {
_busy = true;
});
    await onPredictImage(File(image.path));
}
/// Functionality for predicting the given image using the selected ML model.
///
/// - If no image is provided, the function returns immediately.
/// - Depending on the currently selected `_model`, it routes the image to:
/// - YOLOv2 Tiny (`onYolov2Tiny`)
/// - SSD MobileNet (`onSSDMobileNet`)
/// - DeepLab (`onSegmentMobileNet`)
/// - PoseNet (`onPoseNet`)
/// - or the default recognizer (`onRecognizeImage`).
/// - After prediction, it resolves the image to fetch its width and height,
/// and updates the state with the image reference and prediction status.
Future<void> onPredictImage(File? image) async {
if (image == null) return;
switch (_model) {
case yolo:
await onYolov2Tiny(image);
break;
case ssd:
await onSSDMobileNet(image);
break;
case deeplab:
await onSegmentMobileNet(image);
break;
case posenet:
await onPoseNet(image);
break;
default:
await onRecognizeImage(image);
}
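    // Resolve the picked file to read its intrinsic width and height; the
    // render helpers below use them to scale overlays to the on-screen image.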
    FileImage(image)
        .resolve(const ImageConfiguration())
        .addListener(
          ImageStreamListener((info, _) {
            setState(() {
              _imageHeight = info.image.height.toDouble();
              _imageWidth = info.image.width.toDouble();
            });
          }),
        );
setState(() {
_image = image;
_busy = false;
});
}
/// Functionality for initializing the widget state and asynchronously loading the model.
@override
void initState() {
super.initState();
_busy = true;
onLoadModel().then((value) {
setState(() {
_busy = false;
});
});
}
/// Functionality for loading and initializing different TensorFlow Lite models
/// (YOLO, SSD, DeepLab, PoseNet, MobileNet) based on the selected model type.
  Future<void> onLoadModel() async {
    // Release any previously loaded model before loading the newly selected one.
    await TfliteNext.close();
    try {
      String? result;
switch (_model) {
case yolo:
result = await TfliteNext.loadModel(model: "assets/yolov2_tiny.tflite", labels: "assets/yolov2_tiny.txt");
break;
case ssd:
result = await TfliteNext.loadModel(model: "assets/ssd_mobilenet.tflite", labels: "assets/ssd_mobilenet.txt");
break;
case deeplab:
result = await TfliteNext.loadModel(model: "assets/deeplabv3_257_mv_gpu.tflite", labels: "assets/deeplabv3_257_mv_gpu.txt");
break;
case posenet:
result = await TfliteNext.loadModel(model: "assets/posenet_mv1_075_float_from_checkpoints.tflite");
break;
default:
result = await TfliteNext.loadModel(model: "assets/mobilenet_v1_1.0_224.tflite", labels: "assets/mobilenet_v1_1.0_224.txt");
}
debugPrint(result);
} on PlatformException {
debugPrint("Failed to load model.");
}
}
/// Functionality for converting an image into a normalized Float32 byte list
/// (used as input tensor for TFLite models).
  Uint8List onImageToByteListFloat32(img.Image image, int inputSize, double mean, double std) {
    // 1 x inputSize x inputSize x 3 (NHWC) float input tensor.
    Float32List convertedBytes = Float32List(1 * inputSize * inputSize * 3);
    int pixelIndex = 0;
    for (int i = 0; i < inputSize; i++) {
      for (int j = 0; j < inputSize; j++) {
        img.Pixel pixel = image.getPixel(j, i);
        convertedBytes[pixelIndex++] = (pixel.r - mean) / std;
        convertedBytes[pixelIndex++] = (pixel.g - mean) / std;
        convertedBytes[pixelIndex++] = (pixel.b - mean) / std;
      }
    }
    return convertedBytes.buffer.asUint8List();
}
  /// Functionality for converting an image into a raw RGB Uint8 byte buffer
  /// (width × height × 3 channels) suitable as input for quantized TensorFlow Lite models.
  Uint8List onImageToByteListUint8(img.Image image, int inputSize) {
    Uint8List convertedBytes = Uint8List(1 * inputSize * inputSize * 3);
    int pixelIndex = 0;
    for (int i = 0; i < inputSize; i++) {
      for (int j = 0; j < inputSize; j++) {
        img.Pixel pixel = image.getPixel(j, i);
        convertedBytes[pixelIndex++] = pixel.r.toInt();
        convertedBytes[pixelIndex++] = pixel.g.toInt();
        convertedBytes[pixelIndex++] = pixel.b.toInt();
      }
    }
    return convertedBytes;
}
  /// Functionality for running image recognition using a TFLite model,
/// updating the state with recognition results, and logging inference time.
Future<void> onRecognizeImage(File image) async {
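    // An imageMean and imageStd of 127.5 rescale RGB values from [0, 255] to
    // [-1, 1], the input range the float MobileNet model expects.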
    int startTime = DateTime.now().millisecondsSinceEpoch;
    List<dynamic>? recognitions = await TfliteNext.runModelOnImage(path: image.path, numResults: 6, threshold: 0.05, imageMean: 127.5, imageStd: 127.5);
    setState(() {
      _recognitions = recognitions;
    });
    int endTime = DateTime.now().millisecondsSinceEpoch;
    debugPrint("Inference took ${endTime - startTime}ms");
}
/// Functionality for performing image recognition using a TFLite model on binary image data.
Future<void> onRecognizeImageBinary(File image) async {
    int startTime = DateTime.now().millisecondsSinceEpoch;
    // Read the picked file from disk; rootBundle can only load bundled assets.
    Uint8List imageBytes = await image.readAsBytes();
    img.Image? oriImage = img.decodeImage(imageBytes);
    if (oriImage == null) return;
    img.Image resizedImage = img.copyResize(oriImage, height: 224, width: 224);
    var recognitions = await TfliteNext.runModelOnBinary(binary: onImageToByteListFloat32(resizedImage, 224, 127.5, 127.5), numResults: 6, threshold: 0.05);
    setState(() {
      _recognitions = recognitions;
    });
    int endTime = DateTime.now().millisecondsSinceEpoch;
    debugPrint("Inference took ${endTime - startTime}ms");
}
/// Functionality for running object detection on an image using the YOLOv2-Tiny model,
/// updating recognition results, and measuring inference time.
Future<void> onYolov2Tiny(File image) async {
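    // imageMean 0.0 with imageStd 255.0 rescales RGB values from [0, 255] to
    // [0, 1], the input range this YOLOv2-Tiny model expects.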
    int startTime = DateTime.now().millisecondsSinceEpoch;
var recognitions = await TfliteNext.detectObjectOnImage(
path: image.path,
model: "YOLO",
threshold: 0.3,
imageMean: 0.0,
imageStd: 255.0,
numResultsPerClass: 1,
);
setState(() {
_recognitions = recognitions;
});
    int endTime = DateTime.now().millisecondsSinceEpoch;
    debugPrint("Inference took ${endTime - startTime}ms");
}
/// Functionality for performing object detection on an input image
/// using the SSD MobileNet model via TfliteNext, updating recognition
/// results in the UI state, and logging the inference time.
Future<void> onSSDMobileNet(File image) async {
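    // No model, imageMean, or imageStd arguments are passed, so the plugin's
    // defaults apply (assumed here to be the SSD MobileNet detection pipeline).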
    int startTime = DateTime.now().millisecondsSinceEpoch;
var recognitions = await TfliteNext.detectObjectOnImage(path: image.path, numResultsPerClass: 1);
setState(() {
_recognitions = recognitions;
});
    int endTime = DateTime.now().millisecondsSinceEpoch;
    debugPrint("Inference took ${endTime - startTime}ms");
}
/// Functionality for running image segmentation on a given image using MobileNet,
/// updating the recognitions state, and logging inference time.
Future<void> onSegmentMobileNet(File image) async {
    int startTime = DateTime.now().millisecondsSinceEpoch;
var recognitions = await TfliteNext.runSegmentationOnImage(path: image.path, imageMean: 127.5, imageStd: 127.5);
setState(() {
_recognitions = recognitions;
});
    int endTime = DateTime.now().millisecondsSinceEpoch;
    debugPrint("Inference took ${endTime - startTime}ms");
}
/// Functionality for running PoseNet inference on an input image,
/// retrieving keypoint recognitions, updating the state with results,
/// and logging the inference time.
Future<void> onPoseNet(File image) async {
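    // numResults: 2 caps detection at two poses; each result carries a
    // "keypoints" map that onRenderKeypoints draws on screen.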
    int startTime = DateTime.now().millisecondsSinceEpoch;
var recognitions = await TfliteNext.runPoseNetOnImage(path: image.path, numResults: 2);
debugPrint("$recognitions");
setState(() {
_recognitions = recognitions;
});
    int endTime = DateTime.now().millisecondsSinceEpoch;
    debugPrint("Inference took ${endTime - startTime}ms");
}
/// Functionality for handling model selection, loading the chosen model,
/// and running prediction on the current image if available.
  Future<void> onSelect(String model) async {
setState(() {
_busy = true;
_model = model;
_recognitions = null;
});
await onLoadModel();
if (_image != null) {
      await onPredictImage(_image);
} else {
setState(() {
_busy = false;
});
}
}
  /// Functionality for rendering detection bounding boxes with labels and confidence scores.
List<Widget> onRenderBoxes(Size screen) {
if (_recognitions == null) return [];
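    // The image is rendered at full screen width, so its displayed height is
    // the image aspect ratio times the screen width; the detector's rect
    // values are relative (0.0-1.0) and are scaled by these factors.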
double factorX = screen.width;
    double factorY = _imageHeight / _imageWidth * screen.width;
    const Color blue = Color.fromRGBO(37, 213, 253, 1.0);
return _recognitions!.map((value) {
return Positioned(
left: value["rect"]["x"] * factorX,
top: value["rect"]["y"] * factorY,
width: value["rect"]["w"] * factorX,
height: value["rect"]["h"] * factorY,
child: Container(
decoration: BoxDecoration(
borderRadius: BorderRadius.all(Radius.circular(8.0)),
border: Border.all(color: blue, width: 2),
),
child: Text(
"${value["detectedClass"]} ${(value["confidenceInClass"] * 100).toStringAsFixed(0)}%",
style: TextStyle(background: Paint()..color = blue, color: Colors.white, fontSize: 12.0),
),
),
);
}).toList();
}
/// Functionality for rendering detected PoseNet keypoints on the screen
/// by mapping model output coordinates to widget positions and displaying
/// them as labeled markers.
List<Widget> onRenderKeypoints(Size screen) {
if (_recognitions == null) return [];
double factorX = screen.width;
    double factorY = _imageHeight / _imageWidth * screen.width;
var lists = <Widget>[];
for (var value in _recognitions!) {
      // Give each detected pose a random, fully opaque color.
      var color = Color((Random().nextDouble() * 0xFFFFFF).toInt()).withValues(alpha: 1.0);
var list = value["keypoints"].values.map<Widget>((k) {
return Positioned(
left: k["x"] * factorX - 6,
top: k["y"] * factorY - 6,
width: 100,
height: 12,
child: Text("● ${k["part"]}", style: TextStyle(color: color, fontSize: 12.0)),
);
}).toList();
lists.addAll(list);
}
return lists;
}
@override
Widget build(BuildContext context) {
Size size = MediaQuery.of(context).size;
List<Widget> stackChildren = [];
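    // For DeepLab, _recognitions holds the segmentation result as raw image
    // bytes, painted beneath a translucent copy of the original photo.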
if (_model == deeplab && _recognitions != null) {
stackChildren.add(
Positioned(
top: 0.0,
left: 0.0,
width: size.width,
child: _image == null
              ? const Text('No image selected.')
: Container(
decoration: BoxDecoration(
image: DecorationImage(alignment: Alignment.topCenter, image: MemoryImage(_recognitions as Uint8List), fit: BoxFit.fill),
),
child: Opacity(opacity: 0.3, child: Image.file(_image!)),
),
),
);
} else {
      stackChildren.add(Positioned(top: 0.0, left: 0.0, width: size.width, child: _image == null ? const Text('No image selected.') : Image.file(_image!)));
}
if (_model == mobile) {
stackChildren.add(
Center(
child: Column(
children: _recognitions != null
? _recognitions!.map((res) {
return Text(
"${res["index"]} - ${res["label"]}: ${res["confidence"].toStringAsFixed(3)}",
style: TextStyle(color: Colors.black, fontSize: 20.0, background: Paint()..color = Colors.white),
);
}).toList()
: [],
),
),
);
} else if (_model == ssd || _model == yolo) {
stackChildren.addAll(onRenderBoxes(size));
} else if (_model == posenet) {
stackChildren.addAll(onRenderKeypoints(size));
}
if (_busy) {
stackChildren.add(const Opacity(opacity: 0.3, child: ModalBarrier(dismissible: false, color: Colors.grey)));
stackChildren.add(const Center(child: CircularProgressIndicator()));
}
return Scaffold(
appBar: AppBar(
title: const Text("TfLite Next Example"),
actions: <Widget>[
PopupMenuButton<String>(
onSelected: onSelect,
itemBuilder: (context) {
List<PopupMenuEntry<String>> menuEntries = [
const PopupMenuItem<String>(value: mobile, child: Text(mobile)),
const PopupMenuItem<String>(value: ssd, child: Text(ssd)),
const PopupMenuItem<String>(value: yolo, child: Text(yolo)),
const PopupMenuItem<String>(value: deeplab, child: Text(deeplab)),
const PopupMenuItem<String>(value: posenet, child: Text(posenet)),
];
return menuEntries;
},
),
],
),
body: Stack(children: stackChildren),
      floatingActionButton: FloatingActionButton(onPressed: onPredictImagePicker, tooltip: 'Pick Image', child: const Icon(Icons.image)),
);
}
}