getMapWithTensorBuffer method

Map<String, TensorBuffer> getMapWithTensorBuffer()

Gets a map from each label to its corresponding TensorBuffer. Currently, the mapping is only supported on the first axis whose size is greater than 1.
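The mapping can be illustrated without the TensorBuffer class. The following is a minimal pure-Dart sketch that operates on a flat, row-major List<double> instead of a ByteBuffer. The names firstAxisWithSizeGreaterThanOne and splitFlatByLabels are hypothetical and are not part of the library. The idea is that every axis before the labeled axis has size 1, so the flat data consists of shape[axis] consecutive chunks of equal length.

```dart
/// Hypothetical helper mirroring the idea behind
/// getFirstAxisWithSizeGreaterThanOne: returns the index of the first
/// dimension whose size is greater than 1.
int firstAxisWithSizeGreaterThanOne(List<int> shape) {
  for (var axis = 0; axis < shape.length; axis++) {
    if (shape[axis] > 1) return axis;
  }
  throw ArgumentError('Shape $shape has no axis with size greater than 1.');
}

/// Hypothetical sketch: splits a flat, row-major list of values into one
/// sub-list per label, with the labels attached to the first axis whose
/// size is greater than 1.
Map<String, List<double>> splitFlatByLabels(
    List<double> flat, List<int> shape, List<String> labels) {
  final axis = firstAxisWithSizeGreaterThanOne(shape);
  if (labels.length != shape[axis]) {
    throw ArgumentError(
        'Expected ${shape[axis]} labels, got ${labels.length}.');
  }
  // All axes before `axis` have size 1, so the data is shape[axis]
  // consecutive chunks of equal length.
  final subArrayLength = flat.length ~/ shape[axis];
  final result = <String, List<double>>{};
  for (var i = 0; i < labels.length; i++) {
    // sublist copies the elements of chunk i.
    result[labels[i]] =
        flat.sublist(i * subArrayLength, (i + 1) * subArrayLength);
  }
  return result;
}
```

For example, with shape [1, 3, 2], flat data [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], and labels ['cat', 'dog', 'bird'], the label 'dog' maps to [0.3, 0.4]. Each chunk corresponds to a sub-tensor whose shape is the dimensions after the labeled axis, which here is [2].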

Implementation

Map<String, TensorBuffer> getMapWithTensorBuffer() {
  int labeledAxis = getFirstAxisWithSizeGreaterThanOne(_tensorBuffer);

  Map<String, TensorBuffer> labelToTensorMap = {};
  SupportPreconditions.checkArgument(_axisLabels.containsKey(labeledAxis),
      errorMessage:
          "Getting a <String, TensorBuffer> map requires the labels to be set on the first axis with size greater than 1.");
  List<String> labels = _axisLabels[labeledAxis]!;

  TfLiteType dataType = _tensorBuffer.getDataType();
  int typeSize = _tensorBuffer.getTypeSize();
  int flatSize = _tensorBuffer.getFlatSize();

  // Gets the underlying bytes that could be used to generate the sub-array later.
  ByteBuffer byteBuffer = _tensorBuffer.getBuffer();

  // Note: computation below is only correct when labeledAxis is the first axis with size greater
  // than 1.
  int subArrayLength = (flatSize ~/ _shape[labeledAxis]) * typeSize;
  labels.asMap().forEach((i, label) {
    // Copy exactly this label's bytes. A view such as
    // byteBuffer.asByteData(offset) still exposes the whole backing buffer
    // through .buffer, so slice with an explicit length and then copy.
    ByteBuffer labelBytes = Uint8List.fromList(
            byteBuffer.asUint8List(i * subArrayLength, subArrayLength))
        .buffer;
    TensorBuffer labelBuffer = TensorBuffer.createDynamic(dataType);
    labelBuffer.loadBuffer(labelBytes,
        shape: _shape.sublist(labeledAxis + 1));
    labelToTensorMap[label] = labelBuffer;
  });
  return labelToTensorMap;
}
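The per-label slicing in the implementation above depends on how views over a ByteBuffer behave. ByteData.buffer and Uint8List.buffer both return the entire backing ByteBuffer, not just the window that a view covers. For this reason, each slice needs an explicit offset and length, followed by a copy. The following minimal sketch uses only dart:typed_data. The names copyByteSlice and splitIntoSubArrays are hypothetical, and the sketch assumes that the tensor's bytes start at offset 0 of the buffer.

```dart
import 'dart:typed_data';

/// Hypothetical helper: copies bytes [offset, offset + length) of [source]
/// into a new, independent ByteBuffer.
ByteBuffer copyByteSlice(ByteBuffer source, int offset, int length) {
  // asUint8List(offset, length) is a view over exactly those bytes...
  final view = source.asUint8List(offset, length);
  // ...but view.buffer would still be the whole source buffer, so copy.
  return Uint8List.fromList(view).buffer;
}

/// Hypothetical helper: splits a buffer holding [count] equally sized
/// sub-arrays of [subArrayLength] bytes each into per-index copies.
List<ByteBuffer> splitIntoSubArrays(
    ByteBuffer source, int count, int subArrayLength) {
  return [
    for (var i = 0; i < count; i++)
      copyByteSlice(source, i * subArrayLength, subArrayLength),
  ];
}
```

For a float32 tensor with three labels and two values per label, subArrayLength is 2 * 4 = 8 bytes. Calling splitIntoSubArrays(buffer, 3, 8) therefore yields three 8-byte buffers, and each of these can be read back with asFloat32List(). Copying costs an allocation per label, but it keeps each resulting buffer independent of the original tensor.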