loadClassificationModel static method

Future<ClassificationModel> loadClassificationModel(
  String path,
  int imageWidth,
  int imageHeight,
  int numberOfClasses, {
  String? labelPath,
  bool ensureMatchingNumberOfClasses = true,
  ModelLocation modelLocation = ModelLocation.asset,
  LabelsLocation labelsLocation = LabelsLocation.asset,
})

Loads the PyTorch model at `path` (and, optionally, its class labels from `labelPath`) and returns a `ClassificationModel`.

Implementation

/// Loads a PyTorch classification model and, optionally, its class labels.
///
/// When [modelLocation] is [ModelLocation.asset], [path] is first resolved
/// to an absolute path; otherwise it is passed to the loader unchanged.
/// [imageWidth] and [imageHeight] are forwarded to the native loader and
/// stored on the returned [ClassificationModel].
///
/// If [labelPath] is given, the label file is read from [labelsLocation].
/// A file whose name ends in `.txt` (case-insensitive) is parsed as plain
/// text; any other file is parsed as CSV. When
/// [ensureMatchingNumberOfClasses] is true, the number of parsed labels must
/// equal [numberOfClasses]. Without a [labelPath], the model gets an empty
/// label list and no count check is made.
///
/// Throws an [Exception] if the label count does not match
/// [numberOfClasses] while [ensureMatchingNumberOfClasses] is true.
static Future<ClassificationModel> loadClassificationModel(
    String path, int imageWidth, int imageHeight, int numberOfClasses,
    {String? labelPath,
    bool ensureMatchingNumberOfClasses = true,
    ModelLocation modelLocation = ModelLocation.asset,
    LabelsLocation labelsLocation = LabelsLocation.asset}) async {
  // Labels are loaded and validated *before* the model. A bad or missing
  // label file then fails fast, instead of leaving behind a loaded native
  // model whose index is never returned to anyone.
  List<String> labels = [];
  if (labelPath != null) {
    final labelData =
        await _loadLabelsFile(labelPath, labelsLocation: labelsLocation);
    // Compare case-insensitively so e.g. "labels.TXT" is not parsed as CSV.
    labels = labelPath.toLowerCase().endsWith(".txt")
        ? await _getLabelsTxt(labelData)
        : await _getLabelsCsv(labelData);
    if (ensureMatchingNumberOfClasses && labels.length != numberOfClasses) {
      throw Exception(
          "Number of labels does not match number of classes ,labels ${labels.length} classes $numberOfClasses");
    }
  }

  final String modelPath = modelLocation == ModelLocation.asset
      ? await _getAbsolutePath(path)
      : path;
  final int index = await ModelApi()
      .loadModel(modelPath, null, imageWidth, imageHeight, null);

  return ClassificationModel(index, labels, imageWidth, imageHeight);
}