tflite_next 0.0.1 — add the dependency line "tflite_next: ^0.0.1" to your pubspec.yaml.

A Flutter plugin to run TensorFlow Lite models on Android and iOS.

example/lib/main.dart

import 'dart:io';
import 'dart:math';
import 'dart:typed_data';

import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:image_picker/image_picker.dart';
import 'package:tflite_next/tflite_next.dart';
import 'package:image/image.dart' as img;

/// Entry point: mounts the example [App]. The widget is `const` because
/// [App] has a const constructor and no fields.
void main() => runApp(const App());

/// Display name of the MobileNet image classifier (the default model).
const String mobile = 'MobileNet';

/// Display name of the SSD MobileNet object detector.
const String ssd = 'SSD MobileNet';

/// Display name of the Tiny YOLOv2 object detector.
const String yolo = 'Tiny YOLOv2';

/// Display name of the DeepLab segmentation model.
const String deeplab = 'DeepLab';

/// Display name of the PoseNet keypoint model.
const String posenet = 'PoseNet';

/// Root widget of the example.
///
/// Wraps [MyApp] in a [MaterialApp] with the debug banner hidden. The whole
/// subtree is `const`, so it is never rebuilt needlessly.
class App extends StatelessWidget {
  const App({super.key});

  @override
  Widget build(BuildContext context) {
    return const MaterialApp(debugShowCheckedModeBanner: false, home: MyApp());
  }
}

/// Home page of the example.
///
/// All mutable UI state (selected model, image, results) is kept in
/// [_MyAppState].
class MyApp extends StatefulWidget {
  const MyApp({super.key});

  @override
  State<MyApp> createState() {
    return _MyAppState();
  }
}

class _MyAppState extends State<MyApp> {
  /// Image currently displayed and last used for inference, if any.
  File? _image;

  /// Output of the most recent inference, or `null` while a model switch is
  /// pending.
  ///
  /// For MobileNet, SSD, YOLO and PoseNet this is a list of result maps (read
  /// by [build], [onRenderBoxes] and [onRenderKeypoints]). For DeepLab it is
  /// expected to hold encoded image bytes ([Uint8List]) that are drawn through
  /// [MemoryImage].
  List? _recognitions = [];

  /// Display name of the selected model (one of the top-level constants).
  String _model = mobile;

  /// Pixel size of the last decoded image; `0` until the first image has
  /// been decoded.
  double _imageHeight = 0.0;
  double _imageWidth = 0.0;

  /// Whether a model load or inference is running (shows a modal overlay).
  bool _busy = false;

  final ImagePicker _picker = ImagePicker();

  /// Lets the user pick an image from the gallery and runs the selected model
  /// on it. Does nothing if the picker is dismissed.
  Future<void> onPredictImagePicker() async {
    final XFile? picked = await _picker.pickImage(source: ImageSource.gallery);
    if (picked == null || !mounted) return;
    setState(() => _busy = true);
    await onPredictImage(File(picked.path));
  }

  /// Runs the currently selected model on [image] and then displays it.
  ///
  /// - Returns immediately if [image] is `null`.
  /// - Routes to [onYolov2Tiny], [onSSDMobileNet], [onSegmentMobileNet],
  ///   [onPoseNet], or (for any other model) [onRecognizeImage].
  /// - Always clears [_busy], even if inference throws. The error is still
  ///   propagated to the caller.
  /// - On success, shows [image] and starts resolving its pixel size, which
  ///   is needed to scale boxes and keypoints.
  Future<void> onPredictImage(File? image) async {
    if (image == null) return;
    try {
      switch (_model) {
        case yolo:
          await onYolov2Tiny(image);
          break;
        case ssd:
          await onSSDMobileNet(image);
          break;
        case deeplab:
          await onSegmentMobileNet(image);
          break;
        case posenet:
          await onPoseNet(image);
          break;
        default:
          await onRecognizeImage(image);
      }
    } finally {
      // Without this, a failed inference would leave the modal barrier up and
      // block the UI indefinitely.
      if (mounted) setState(() => _busy = false);
    }

    if (!mounted) return;
    _resolveImageSize(image);
    setState(() => _image = image);
  }

  /// Resolves [image] through the image cache and records its pixel size.
  ///
  /// The listener detaches itself after the first frame or error, so repeated
  /// predictions do not pile up listeners. It also disposes the [ImageInfo]
  /// handle it receives.
  void _resolveImageSize(File image) {
    final ImageStream stream = FileImage(image).resolve(const ImageConfiguration());
    late final ImageStreamListener listener;
    listener = ImageStreamListener(
      (ImageInfo info, bool _) {
        stream.removeListener(listener);
        final double width = info.image.width.toDouble();
        final double height = info.image.height.toDouble();
        info.dispose();
        if (!mounted) return;
        setState(() {
          _imageWidth = width;
          _imageHeight = height;
        });
      },
      onError: (Object _, StackTrace? __) => stream.removeListener(listener),
    );
    stream.addListener(listener);
  }

  /// Marks the page busy and loads the initial model. The busy flag is
  /// cleared once loading finishes, whether it succeeded or not.
  @override
  void initState() {
    super.initState();
    _busy = true;
    onLoadModel().whenComplete(() {
      if (mounted) setState(() => _busy = false);
    });
  }

  /// Closes the previous interpreter and loads the model (plus its label file,
  /// where one is bundled) for the current [_model].
  ///
  /// A [PlatformException] from the plugin is logged rather than rethrown.
  Future<void> onLoadModel() async {
    // NOTE(review): close() is not awaited; if it is asynchronous, confirm the
    // plugin processes it before the subsequent loadModel call.
    TfliteNext.close();
    try {
      final String? result;
      switch (_model) {
        case yolo:
          result = await TfliteNext.loadModel(model: "assets/yolov2_tiny.tflite", labels: "assets/yolov2_tiny.txt");
          break;
        case ssd:
          result = await TfliteNext.loadModel(model: "assets/ssd_mobilenet.tflite", labels: "assets/ssd_mobilenet.txt");
          break;
        case deeplab:
          result = await TfliteNext.loadModel(model: "assets/deeplabv3_257_mv_gpu.tflite", labels: "assets/deeplabv3_257_mv_gpu.txt");
          break;
        case posenet:
          result = await TfliteNext.loadModel(model: "assets/posenet_mv1_075_float_from_checkpoints.tflite");
          break;
        default:
          result = await TfliteNext.loadModel(model: "assets/mobilenet_v1_1.0_224.tflite", labels: "assets/mobilenet_v1_1.0_224.txt");
      }
      debugPrint(result);
    } on PlatformException {
      debugPrint("Failed to load model.");
    }
  }

  /// Flattens the top-left `inputSize × inputSize` pixels of [image] into
  /// row-major RGB float32 values normalized as `(channel - mean) / std`.
  ///
  /// Returns the raw bytes of that float buffer.
  Uint8List onImageToByteListFloat32(img.Image image, int inputSize, double mean, double std) {
    final Float32List floats = Float32List(inputSize * inputSize * 3);
    var offset = 0;
    for (var y = 0; y < inputSize; y++) {
      for (var x = 0; x < inputSize; x++) {
        final img.Pixel pixel = image.getPixel(x, y);
        floats[offset++] = (pixel.r - mean) / std;
        floats[offset++] = (pixel.g - mean) / std;
        floats[offset++] = (pixel.b - mean) / std;
      }
    }
    return floats.buffer.asUint8List();
  }

  /// Flattens the top-left `inputSize × inputSize` pixels of [image] into
  /// row-major RGB bytes (3 bytes per pixel, no normalization).
  Uint8List onImageToByteListUint8(img.Image image, int inputSize) {
    final Uint8List bytes = Uint8List(inputSize * inputSize * 3);
    var offset = 0;
    for (var y = 0; y < inputSize; y++) {
      for (var x = 0; x < inputSize; x++) {
        final img.Pixel pixel = image.getPixel(x, y);
        bytes[offset++] = pixel.r.toInt();
        bytes[offset++] = pixel.g.toInt();
        bytes[offset++] = pixel.b.toInt();
      }
    }
    return bytes;
  }

  /// Stores [recognitions] for rendering. Results that arrive after the
  /// widget has been disposed are ignored.
  void _setRecognitions(List? recognitions) {
    if (!mounted) return;
    setState(() => _recognitions = recognitions);
  }

  /// Classifies [image] (top 6 results, threshold 0.05) and logs the
  /// inference time.
  Future<void> onRecognizeImage(File image) async {
    final Stopwatch stopwatch = Stopwatch()..start();
    final recognitions = await TfliteNext.runModelOnImage(path: image.path, numResults: 6, threshold: 0.05, imageMean: 127.5, imageStd: 127.5);
    _setRecognitions(recognitions);
    debugPrint("Inference took ${stopwatch.elapsedMilliseconds}ms");
  }

  /// Classifies [image] by feeding the model a pre-built input buffer.
  ///
  /// The file is read from disk, decoded (any format supported by
  /// `package:image`), and resized to 224×224. It is then normalized with
  /// mean/std 127.5. If the file cannot be decoded, this logs and returns
  /// without changing results.
  Future<void> onRecognizeImageBinary(File image) async {
    final Stopwatch stopwatch = Stopwatch()..start();
    // A picked image lives on the file system; rootBundle only serves
    // bundled assets, so read the file directly.
    final Uint8List fileBytes = await image.readAsBytes();
    final img.Image? decoded = img.decodeImage(fileBytes);
    if (decoded == null) {
      debugPrint("Could not decode image at ${image.path}");
      return;
    }
    final img.Image resized = img.copyResize(decoded, height: 224, width: 224);
    final recognitions = await TfliteNext.runModelOnBinary(binary: onImageToByteListFloat32(resized, 224, 127.5, 127.5), numResults: 6, threshold: 0.05);
    _setRecognitions(recognitions);
    debugPrint("Inference took ${stopwatch.elapsedMilliseconds}ms");
  }

  /// Detects objects in [image] with Tiny YOLOv2 and logs the inference time.
  Future<void> onYolov2Tiny(File image) async {
    final Stopwatch stopwatch = Stopwatch()..start();
    final recognitions = await TfliteNext.detectObjectOnImage(
      path: image.path,
      model: "YOLO",
      threshold: 0.3,
      imageMean: 0.0,
      imageStd: 255.0,
      numResultsPerClass: 1,
    );
    _setRecognitions(recognitions);
    debugPrint("Inference took ${stopwatch.elapsedMilliseconds}ms");
  }

  /// Detects objects in [image] using the plugin's default detector settings
  /// (used here for SSD MobileNet) and logs the inference time.
  Future<void> onSSDMobileNet(File image) async {
    final Stopwatch stopwatch = Stopwatch()..start();
    final recognitions = await TfliteNext.detectObjectOnImage(path: image.path, numResultsPerClass: 1);
    _setRecognitions(recognitions);
    debugPrint("Inference took ${stopwatch.elapsedMilliseconds}ms");
  }

  /// Runs segmentation on [image] and logs the inference time.
  Future<void> onSegmentMobileNet(File image) async {
    final Stopwatch stopwatch = Stopwatch()..start();
    final recognitions = await TfliteNext.runSegmentationOnImage(path: image.path, imageMean: 127.5, imageStd: 127.5);
    _setRecognitions(recognitions);
    debugPrint("Inference took ${stopwatch.elapsedMilliseconds}ms");
  }

  /// Estimates up to 2 poses in [image], logs the raw result and the
  /// inference time.
  Future<void> onPoseNet(File image) async {
    final Stopwatch stopwatch = Stopwatch()..start();
    final recognitions = await TfliteNext.runPoseNetOnImage(path: image.path, numResults: 2);
    debugPrint("$recognitions");
    _setRecognitions(recognitions);
    debugPrint("Inference took ${stopwatch.elapsedMilliseconds}ms");
  }

  /// Switches to [model]: clears old results, loads the new model, and
  /// re-runs prediction on the current image if there is one.
  Future<void> onSelect(String model) async {
    setState(() {
      _busy = true;
      _model = model;
      _recognitions = null;
    });
    await onLoadModel();
    if (!mounted) return;

    final File? current = _image;
    if (current != null) {
      await onPredictImage(current);
    } else {
      setState(() => _busy = false);
    }
  }

  /// Builds labelled bounding boxes for detection results.
  ///
  /// Box coordinates are treated as fractions of the image and scaled to the
  /// displayed width. Returns nothing until the image's size is known; before
  /// that the vertical scale would be 0/0 (NaN).
  List<Widget> onRenderBoxes(Size screen) {
    final List? recognitions = _recognitions;
    if (recognitions == null || _imageWidth == 0) return [];

    final double factorX = screen.width;
    final double factorY = _imageHeight / _imageWidth * screen.width;
    final Color blue = Color.fromRGBO(37, 213, 253, 1.0);
    return [
      for (final value in recognitions)
        Positioned(
          left: value["rect"]["x"] * factorX,
          top: value["rect"]["y"] * factorY,
          width: value["rect"]["w"] * factorX,
          height: value["rect"]["h"] * factorY,
          child: Container(
            decoration: BoxDecoration(
              borderRadius: const BorderRadius.all(Radius.circular(8.0)),
              border: Border.all(color: blue, width: 2),
            ),
            child: Text(
              "${value["detectedClass"]} ${(value["confidenceInClass"] * 100).toStringAsFixed(0)}%",
              style: TextStyle(background: Paint()..color = blue, color: Colors.white, fontSize: 12.0),
            ),
          ),
        ),
    ];
  }

  /// Returns an opaque color for the pose at [index].
  ///
  /// Seeding [Random] with the index keeps each pose's color stable across
  /// rebuilds instead of flickering on every frame.
  static Color _poseColor(int index) => Color(0xFF000000 | Random(index).nextInt(0x1000000));

  /// Builds a labelled marker for every PoseNet keypoint.
  ///
  /// Keypoint coordinates are treated as fractions of the image and scaled to
  /// the displayed width. Returns nothing until the image's size is known.
  List<Widget> onRenderKeypoints(Size screen) {
    final List? recognitions = _recognitions;
    if (recognitions == null || _imageWidth == 0) return [];

    final double factorX = screen.width;
    final double factorY = _imageHeight / _imageWidth * screen.width;
    return [
      for (var i = 0; i < recognitions.length; i++)
        for (final k in recognitions[i]["keypoints"].values)
          Positioned(
            // Offset by half the marker height so the dot sits on the point.
            left: k["x"] * factorX - 6,
            top: k["y"] * factorY - 6,
            width: 100,
            height: 12,
            child: Text("● ${k["part"]}", style: TextStyle(color: _poseColor(i), fontSize: 12.0)),
          ),
    ];
  }

  @override
  Widget build(BuildContext context) {
    final Size size = MediaQuery.of(context).size;
    final File? image = _image;
    final List? recognitions = _recognitions;
    // Only draw the segmentation overlay when the result really is image
    // bytes; anything else falls back to the plain image instead of throwing.
    final Uint8List? segmentation = _model == deeplab && recognitions is Uint8List ? recognitions : null;

    final List<Widget> stackChildren = <Widget>[
      Positioned(
        top: 0.0,
        left: 0.0,
        width: size.width,
        child: image == null
            ? const Text('No image selected.')
            : segmentation == null
                ? Image.file(image)
                : Container(
                    decoration: BoxDecoration(
                      image: DecorationImage(alignment: Alignment.topCenter, image: MemoryImage(segmentation), fit: BoxFit.fill),
                    ),
                    child: Opacity(opacity: 0.3, child: Image.file(image)),
                  ),
      ),
      if (_model == mobile)
        Center(
          child: Column(
            children: [
              if (recognitions != null)
                for (final res in recognitions)
                  Text(
                    "${res["index"]} - ${res["label"]}: ${res["confidence"].toStringAsFixed(3)}",
                    style: TextStyle(color: Colors.black, fontSize: 20.0, background: Paint()..color = Colors.white),
                  ),
            ],
          ),
        ),
      if (_model == ssd || _model == yolo) ...onRenderBoxes(size),
      if (_model == posenet) ...onRenderKeypoints(size),
      if (_busy) ...const [
        Opacity(opacity: 0.3, child: ModalBarrier(dismissible: false, color: Colors.grey)),
        Center(child: CircularProgressIndicator()),
      ],
    ];

    return Scaffold(
      appBar: AppBar(
        title: const Text("TfLite Next Example"),
        actions: <Widget>[
          PopupMenuButton<String>(
            onSelected: onSelect,
            itemBuilder: (context) => const <PopupMenuEntry<String>>[
              PopupMenuItem<String>(value: mobile, child: Text(mobile)),
              PopupMenuItem<String>(value: ssd, child: Text(ssd)),
              PopupMenuItem<String>(value: yolo, child: Text(yolo)),
              PopupMenuItem<String>(value: deeplab, child: Text(deeplab)),
              PopupMenuItem<String>(value: posenet, child: Text(posenet)),
            ],
          ),
        ],
      ),
      body: Stack(children: stackChildren),
      floatingActionButton: FloatingActionButton(onPressed: onPredictImagePicker, tooltip: 'Pick Image', child: const Icon(Icons.image)),
    );
  }
}
2 likes · 140 points · 13 downloads

Publisher

unverified uploader

Weekly Downloads

A Flutter plugin to run TensorFlow Lite models on Android and iOS.

Repository (GitHub)
View/report issues

Documentation

API reference

License

MIT (license)

Dependencies

flutter, plugin_platform_interface

More

Packages that depend on tflite_next

Packages that implement tflite_next