cat static method

Tensor cat(
  List<Tensor> tensors, [
  int dim = 0,
])

Concatenate a list of tensors along a given dimension.

Implementation

/// Concatenates [tensors] along dimension [dim] into a newly created tensor.
///
/// [dim] may be negative, in which case it is counted back from the rank of
/// the first tensor (`-1` is the last dimension). The output shape equals the
/// first tensor's shape, except that the size along [dim] is the sum of every
/// input's size along [dim].
///
/// When [tensors] contains exactly one tensor, that same instance is returned
/// as-is: nothing is copied and [dim] is not validated.
///
/// Throws an [ArgumentError] if [tensors] is empty, or if the tensors differ
/// in rank or in the size of any dimension other than [dim]. Throws a
/// [RangeError] if [dim] does not address a dimension of the first tensor.
static Tensor cat(List<Tensor> tensors, [int dim = 0]) {
  if (tensors.isEmpty) {
    throw ArgumentError.value(tensors, 'tensors', 'Must not be empty');
  }
  if (tensors.length == 1) return tensors[0];

  final first = tensors[0];
  final ndim = first.ndim;

  // Normalize a negative dimension without reassigning the parameter.
  final axis = dim < 0 ? ndim + dim : dim;
  if (axis < 0 || axis >= ndim) {
    throw RangeError.range(dim, -ndim, ndim - 1, 'dim');
  }

  // Compute the output shape, rejecting inputs that cannot be concatenated.
  // Without these checks a mismatched tensor would either read out of bounds
  // or be silently truncated by the copy loop below.
  final outShape = List<int>.from(first.shape);
  for (var i = 1; i < tensors.length; i++) {
    final t = tensors[i];
    if (t.ndim != ndim) {
      throw ArgumentError(
        'Tensor at index $i has $ndim-incompatible rank ${t.ndim}; '
        'all tensors must have rank $ndim',
      );
    }
    for (var d = 0; d < ndim; d++) {
      // outShape[d] is only ever modified for d == axis, so for every other
      // dimension it still holds the first tensor's size.
      if (d != axis && t.shape[d] != outShape[d]) {
        throw ArgumentError(
          'Tensor at index $i has size ${t.shape[d]} in dimension $d, '
          'expected ${outShape[d]}; sizes may differ only in dimension $axis',
        );
      }
    }
    outShape[axis] += t.shape[axis];
  }

  final result = Tensor.zeros(outShape);

  // Number of index combinations before / after the concatenation axis.
  final outerSize =
      axis == 0 ? 1 : _productOfShape(outShape.sublist(0, axis));
  final innerSize =
      axis + 1 < ndim ? _productOfShape(outShape.sublist(axis + 1)) : 1;

  // Flat distance in the output between consecutive outer indices.
  final outStride = outShape[axis] * innerSize;

  // Flat offset, within one outer slice of the output, at which the current
  // tensor's elements begin.
  var offset = 0;
  for (final t in tensors) {
    // For a fixed outer index, the elements of `t` spanning the axis and all
    // inner dimensions are addressed as one contiguous run of `chunk`
    // entries in both `t.data` and `result.data`.
    final chunk = t.shape[axis] * innerSize;
    for (var outer = 0; outer < outerSize; outer++) {
      final srcStart = outer * chunk;
      final dstStart = outer * outStride + offset;
      for (var k = 0; k < chunk; k++) {
        result.data[dstStart + k] = t.data[srcStart + k];
      }
    }
    offset += chunk;
  }
  return result;
}