getCategoryList method

List<Category> getCategoryList()

Gets a list of Category objects, one per label, from this TensorLabel.

The labeled axis must effectively be the last axis: every dimension before it must have size 1, so that the sub tensor for each label has a flat size of 1 and can be converted into a single float score. Example: a TensorLabel with shape {1, 1, 3} labeled on axis 2 is valid. A TensorLabel with shape {2, 5, 3} cannot be converted into a Category list, whichever axis is labeled, because each label would then cover more than one value.
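As a rough illustration of this rule, the shape check performed by the implementation below can be sketched on a plain list of dimensions. The names `firstAxisWithSizeGreaterThanOne` and `canConvertToCategoryList` are hypothetical helpers that only mirror that logic; they are not part of the library's API.

```dart
/// Hypothetical helper mirroring getFirstAxisWithSizeGreaterThanOne:
/// returns the index of the first dimension whose size is greater than 1.
/// This sketch throws an ArgumentError when no such dimension exists.
int firstAxisWithSizeGreaterThanOne(List<int> shape) {
  for (var i = 0; i < shape.length; i++) {
    if (shape[i] > 1) return i;
  }
  throw ArgumentError('No axis with size greater than 1 to label.');
}

/// Hypothetical check: a Category list can be built only when the first
/// axis with size > 1 is the last axis, i.e. all leading dimensions are 1.
bool canConvertToCategoryList(List<int> shape) =>
    firstAxisWithSizeGreaterThanOne(shape) == shape.length - 1;
```

For example, `canConvertToCategoryList([1, 1, 3])` is true, while `canConvertToCategoryList([1, 3, 1])` is false because the labeled axis (axis 1) is not the last one.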

TensorLabel.getMapWithFloatValue is an alternative that returns the label-to-score mapping as a Map instead of a list.

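Throws a StateError if the sub tensor for each label does not have a flat size of 1, that is, if the labeled axis is not the last axis or the number of labels does not match the number of values.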

Implementation

List<Category> getCategoryList() {
  // The first axis with size > 1 must be the last axis, i.e. every leading
  // dimension has size 1, so that each label maps to exactly one score.
  int labeledAxis = getFirstAxisWithSizeGreaterThanOne(_tensorBuffer);
  SupportPreconditions.checkState(labeledAxis == _shape.length - 1,
      errorMessage:
          "Getting a Category list is only valid when the only labeled axis is the last one.");
  List<String> labels = _axisLabels[labeledAxis]!;
  List<double> data = _tensorBuffer.getDoubleList();
  // One score per label: the flattened data must line up with the labels.
  SupportPreconditions.checkState(labels.length == data.length);
  // Pair each label with the score at the same index.
  return List<Category>.generate(
      labels.length, (i) => Category(labels[i], data[i]));
}
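The final pairing step can be sketched outside the library as follows. This is a minimal sketch under stated assumptions: `Category` here is a hypothetical stand-in for the library's Category class, and `pairLabelsWithScores` is a hypothetical helper that zips labels with the flattened scores and raises a StateError on a length mismatch, as the second checkState above does.

```dart
/// Hypothetical stand-in for the library's Category class.
class Category {
  final String label;
  final double score;
  Category(this.label, this.score);

  @override
  String toString() => 'Category($label, $score)';
}

/// Hypothetical helper: pairs labels[i] with data[i]. Throws a StateError
/// when the counts differ, i.e. when a label would not map to exactly one
/// value.
List<Category> pairLabelsWithScores(List<String> labels, List<double> data) {
  if (labels.length != data.length) {
    throw StateError('${labels.length} labels but ${data.length} scores; '
        'each label must map to exactly one value.');
  }
  return List<Category>.generate(
      labels.length, (i) => Category(labels[i], data[i]));
}
```

For instance, `pairLabelsWithScores(['cat', 'dog'], [0.8, 0.2])` yields `Category(cat, 0.8)` and `Category(dog, 0.2)`, while mismatched lengths raise a StateError.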