permute method

Tensor permute(List<int> order)

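Returns a new tensor with the dimensions reordered by `order`: dimension `i` of the result is dimension `order[i]` of this tensor. For example, `[0, 2, 1]` on a 3-D tensor transposes the last two dimensions.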

Implementation

/// Returns a new tensor with this tensor's dimensions reordered by [order].
///
/// Dimension `i` of the result is dimension `order[i]` of this tensor, so for
/// a 3-D tensor `permute([0, 2, 1])` transposes the last two dimensions. The
/// elements are copied into a tensor created with `Tensor.zeros`.
///
/// Throws an [ArgumentError] if [order] is not a permutation of
/// `0 .. ndim - 1`, that is, if it has the wrong length, contains an
/// out-of-range axis, or repeats an axis.
Tensor permute(List<int> order) {
  // Validate with real errors rather than `assert`, which is stripped in
  // release builds. A repeated axis would otherwise silently produce a tensor
  // of the wrong shape filled with wrong data.
  if (order.length != ndim) {
    throw ArgumentError.value(
        order, 'order', 'must have exactly $ndim entries, one per dimension');
  }
  final seen = List<bool>.filled(ndim, false);
  for (final axis in order) {
    if (axis < 0 || axis >= ndim) {
      throw ArgumentError.value(
          order, 'order', 'axis $axis is out of range for a $ndim-D tensor');
    }
    if (seen[axis]) {
      throw ArgumentError.value(
          order, 'order', 'axis $axis appears more than once');
    }
    seen[axis] = true;
  }

  final newShape = [for (final axis in order) shape[axis]];
  // This tensor's strides, reordered so that srcStrides[d] is the step taken
  // in `data` when result index d advances by one.
  // NOTE(review): assumes this tensor's elements are addressed from data[0]
  // (no storage offset) — confirm against the Tensor class.
  final srcStrides = [for (final axis in order) strides[axis]];
  final result = Tensor.zeros(newShape);
  final dstStrides = _computeStrides(newShape);

  // Visit every multi-index of the result, last index fastest. A permutation
  // preserves the element count, so this tensor's `size` is also the result's.
  final indices = List<int>.filled(ndim, 0);
  for (var flatIdx = 0; flatIdx < size; flatIdx++) {
    var srcOffset = 0;
    var dstOffset = 0;
    for (var d = 0; d < ndim; d++) {
      srcOffset += indices[d] * srcStrides[d];
      dstOffset += indices[d] * dstStrides[d];
    }
    result.data[dstOffset] = data[srcOffset];

    // Advance the multi-index like an odometer, carrying into earlier
    // dimensions whenever one wraps past its extent.
    for (var d = ndim - 1; d >= 0; d--) {
      indices[d]++;
      if (indices[d] < newShape[d]) break;
      indices[d] = 0;
    }
  }
  return result;
}