cat static method
Concatenates a list of tensors along the given dimension `dim` (default 0; negative values count back from the last dimension).
Implementation
/// Concatenates [tensors] along dimension [dim].
///
/// [dim] defaults to 0; a negative value counts back from the last
/// dimension (e.g. `-1` is the last axis). Every tensor must have the same
/// rank and the same size in every dimension except [dim]. The result's
/// size along [dim] is the sum of the inputs' sizes along it.
///
/// When [tensors] has exactly one element, that same instance is returned
/// (not a copy) and [dim] is not checked.
///
/// Throws an [ArgumentError] if [tensors] is empty or the shapes are
/// incompatible, and a [RangeError] if [dim] is out of range for the rank.
static Tensor cat(List<Tensor> tensors, [int dim = 0]) {
  if (tensors.isEmpty) {
    throw ArgumentError('cat() requires at least one tensor', 'tensors');
  }
  // Fast path: nothing to concatenate. This returns the input itself, so
  // callers must not assume the result is a fresh tensor.
  if (tensors.length == 1) return tensors[0];

  final first = tensors[0];
  final ndim = first.ndim;
  final axis = dim < 0 ? ndim + dim : dim;
  if (axis < 0 || axis >= ndim) {
    throw RangeError.range(dim, -ndim, ndim - 1, 'dim');
  }

  // Output shape: identical to the inputs except along `axis`, where the
  // sizes add up. Validate while accumulating, so a mismatched input fails
  // loudly instead of being copied with the wrong strides.
  final outShape = List<int>.from(first.shape);
  for (var i = 1; i < tensors.length; i++) {
    final shape = tensors[i].shape;
    if (shape.length != ndim) {
      throw ArgumentError(
          'Tensor $i has rank ${shape.length}, expected $ndim', 'tensors');
    }
    for (var k = 0; k < ndim; k++) {
      if (k != axis && shape[k] != outShape[k]) {
        throw ArgumentError(
            'Tensor $i has shape $shape, which differs from ${first.shape} '
            'outside dimension $axis',
            'tensors');
      }
    }
    outShape[axis] += shape[axis];
  }

  final result = Tensor.zeros(outShape);

  // The index arithmetic below treats each tensor's `data` as a flat
  // row-major (C-order) buffer, i.e. laid out as
  // [outerSize][size along axis][innerSize]. For a fixed outer index, one
  // input's slab of `size * innerSize` elements is contiguous in both the
  // source and the destination, so it is copied as a single run.
  final outerSize =
      axis == 0 ? 1 : _productOfShape(outShape.sublist(0, axis));
  final innerSize =
      axis + 1 < ndim ? _productOfShape(outShape.sublist(axis + 1)) : 1;
  final outAxisStride = outShape[axis] * innerSize;

  var offset = 0; // Running position along `axis` in the output.
  for (final t in tensors) {
    final axisSize = t.shape[axis];
    final runLength = axisSize * innerSize;
    final dstBase = offset * innerSize;
    for (var outer = 0; outer < outerSize; outer++) {
      final src = outer * runLength;
      final dst = outer * outAxisStride + dstBase;
      for (var k = 0; k < runLength; k++) {
        result.data[dst + k] = t.data[src + k];
      }
    }
    offset += axisSize;
  }
  return result;
}