ai_image_classifier 1.0.0
ai_image_classifier: ^1.0.0
A simple and powerful Flutter package for on-device image classification using TensorFlow Lite models. Supports custom models, image resizing, and normalization.
import 'dart:io';
import 'package:flutter/material.dart';
import 'package:image_picker/image_picker.dart';
import 'package:ai_image_classifier/ai_image_classifier.dart';
/// Entry point: starts the Flutter app with [MyApp] as the root widget.
void main() => runApp(const MyApp());
/// Root widget of the example.
///
/// Configures a dark Material 3 theme seeded from deep purple and shows
/// [HomePage] as the only screen.
class MyApp extends StatelessWidget {
  const MyApp({super.key});

  @override
  Widget build(BuildContext context) {
    final scheme = ColorScheme.fromSeed(
      seedColor: Colors.deepPurple,
      brightness: Brightness.dark,
    );
    return MaterialApp(
      title: 'AI Image Classifier',
      theme: ThemeData(colorScheme: scheme, useMaterial3: true),
      home: const HomePage(),
    );
  }
}
/// Screen where the user picks a photo and sees the classifier's labels.
///
/// All mutable state (picked image, results, loading flag) lives in
/// [_HomePageState].
class HomePage extends StatefulWidget {
  const HomePage({super.key});

  @override
  State<HomePage> createState() {
    return _HomePageState();
  }
}
/// State for [HomePage].
///
/// Owns the classifier, the picked image and the most recent classification
/// results.
class _HomePageState extends State<HomePage> {
  final AiImageClassifier _classifier = AiImageClassifier();
  final ImagePicker _picker = ImagePicker();

  /// Completes when [_loadModel] finishes, whether it succeeded or not.
  ///
  /// Awaited before classifying, so an image picked while the model is still
  /// loading is not handed to a classifier that has no model yet.
  late final Future<void> _modelLoading;

  /// The image currently displayed; `null` until the user picks one.
  File? _image;

  /// Classification results for [_image]; `null` while none are available.
  List<Classification>? _results;

  /// Whether a classification is in progress.
  ///
  /// While true, a spinner is shown and the pick buttons are disabled. This
  /// prevents two requests from racing and overwriting each other's results.
  bool _isLoading = false;

  @override
  void initState() {
    super.initState();
    _modelLoading = _loadModel();
  }

  /// Loads the bundled TFLite model and its label file from the app assets.
  ///
  /// Failures are only logged here. A later classification attempt will
  /// presumably fail as well, and [_pickImage] reports that to the user.
  Future<void> _loadModel() async {
    try {
      await _classifier.loadModel(
        modelPath: 'assets/models/mobilenet_v1.tflite',
        labelsPath: 'assets/models/labels.txt',
      );
    } catch (e) {
      debugPrint("Error loading model: $e");
    }
  }

  /// Lets the user pick an image from [source], then classifies it.
  ///
  /// Does nothing if the picker is cancelled. Errors from the picker (for
  /// example, a denied camera permission) or from the classifier are logged
  /// and shown in a [SnackBar] instead of escaping as unhandled async errors.
  Future<void> _pickImage(ImageSource source) async {
    XFile? pickedFile;
    try {
      pickedFile = await _picker.pickImage(source: source);
    } catch (e) {
      debugPrint("Error picking image: $e");
      _showError("Could not open the image picker.");
      return;
    }
    // The page may have been disposed while the picker was open.
    if (pickedFile == null || !mounted) return;
    final path = pickedFile.path;

    setState(() {
      _image = File(path);
      _isLoading = true;
      _results = null;
    });

    try {
      await _modelLoading;
      final results = await _classifier.classifyImagePath(path);
      if (!mounted) return;
      setState(() => _results = results);
    } catch (e) {
      debugPrint("Error classifying image: $e");
      _showError("Could not classify this image.");
    } finally {
      // Calling setState after dispose throws, so guard on mounted.
      if (mounted) setState(() => _isLoading = false);
    }
  }

  /// Shows [message] in a [SnackBar], replacing any snackbar already visible.
  void _showError(String message) {
    if (!mounted) return;
    ScaffoldMessenger.of(context)
      ..hideCurrentSnackBar()
      ..showSnackBar(SnackBar(content: Text(message)));
  }

  @override
  void dispose() {
    _classifier.dispose();
    super.dispose();
  }

  @override
  Widget build(BuildContext context) {
    // Copy the field into a local so it can be null-promoted.
    final results = _results;
    return Scaffold(
      appBar: AppBar(
        title: const Text('AI Image Classifier'),
        centerTitle: true,
        elevation: 2,
      ),
      body: SingleChildScrollView(
        child: Column(
          children: [
            const SizedBox(height: 20),
            _buildImageDisplay(),
            const SizedBox(height: 20),
            _buildActionButtons(),
            const SizedBox(height: 30),
            if (_isLoading)
              const CircularProgressIndicator()
            else if (results != null)
              _buildResultsList(results),
          ],
        ),
      ),
    );
  }

  /// Builds the 300x300 preview of [_image], or a placeholder when none is set.
  Widget _buildImageDisplay() {
    final image = _image;
    return Center(
      child: Container(
        width: 300,
        height: 300,
        decoration: BoxDecoration(
          color: Colors.grey[900],
          borderRadius: BorderRadius.circular(20),
          boxShadow: [
            BoxShadow(
              color: Colors.black.withOpacity(0.5),
              blurRadius: 10,
              offset: const Offset(0, 5),
            ),
          ],
        ),
        child: ClipRRect(
          borderRadius: BorderRadius.circular(20),
          child: image != null
              ? Image.file(image, fit: BoxFit.cover)
              : const Column(
                  mainAxisAlignment: MainAxisAlignment.center,
                  children: [
                    Icon(Icons.image, size: 100, color: Colors.grey),
                    SizedBox(height: 10),
                    Text(
                      "No image selected",
                      style: TextStyle(color: Colors.grey),
                    ),
                  ],
                ),
        ),
      ),
    );
  }

  /// Builds the Gallery and Camera buttons.
  ///
  /// Both are disabled while a classification is running.
  Widget _buildActionButtons() {
    final colors = Theme.of(context).colorScheme;
    return Row(
      mainAxisAlignment: MainAxisAlignment.spaceEvenly,
      children: [
        ElevatedButton.icon(
          onPressed:
              _isLoading ? null : () => _pickImage(ImageSource.gallery),
          icon: const Icon(Icons.photo_library),
          label: const Text("Gallery"),
          style: ElevatedButton.styleFrom(
            padding: const EdgeInsets.symmetric(horizontal: 20, vertical: 12),
          ),
        ),
        ElevatedButton.icon(
          onPressed: _isLoading ? null : () => _pickImage(ImageSource.camera),
          icon: const Icon(Icons.camera_alt),
          label: const Text("Camera"),
          style: ElevatedButton.styleFrom(
            padding: const EdgeInsets.symmetric(horizontal: 20, vertical: 12),
            backgroundColor: colors.primary,
            foregroundColor: colors.onPrimary,
          ),
        ),
      ],
    );
  }

  /// Builds one card per entry in [results], showing its label and confidence.
  ///
  /// Shows a short message instead when [results] is empty.
  Widget _buildResultsList(List<Classification> results) {
    return Padding(
      padding: const EdgeInsets.symmetric(horizontal: 20),
      child: Column(
        crossAxisAlignment: CrossAxisAlignment.start,
        children: [
          const Text(
            "Classification Results:",
            style: TextStyle(fontSize: 18, fontWeight: FontWeight.bold),
          ),
          const SizedBox(height: 10),
          if (results.isEmpty)
            const Text("No labels were returned for this image.")
          else
            ListView.builder(
              // The enclosing SingleChildScrollView does the scrolling. This
              // list sizes itself to its content and never scrolls on its own.
              shrinkWrap: true,
              physics: const NeverScrollableScrollPhysics(),
              itemCount: results.length,
              itemBuilder: (context, index) {
                final res = results[index];
                return Card(
                  margin: const EdgeInsets.only(bottom: 10),
                  child: ListTile(
                    title: Text(
                      res.label,
                      style: const TextStyle(fontWeight: FontWeight.w600),
                    ),
                    trailing: Text(
                      "${(res.confidence * 100).toStringAsFixed(1)}%",
                      style: TextStyle(
                        color: Theme.of(context).colorScheme.primary,
                        fontWeight: FontWeight.bold,
                        fontSize: 16,
                      ),
                    ),
                    subtitle: LinearProgressIndicator(
                      value: res.confidence,
                      backgroundColor: Colors.grey[800],
                      borderRadius: BorderRadius.circular(10),
                    ),
                  ),
                );
              },
            ),
        ],
      ),
    );
  }
}