copyWithPadding method

Future<void> copyWithPadding(
  Tensor<T> input,
  Tensor<T> output
)

Implementation

/// Copies every element of [input] into [output] at the same
/// multi-dimensional index, using a generated WGSL compute shader.
///
/// Row-major strides for both tensors and the input shape are baked into the
/// shader source as constants. Each shader invocation handles one input
/// element: it decomposes the flat input index into per-axis indices with the
/// input strides, then recomposes a flat output index with the output strides.
/// Positions of [output] that lie outside the extent of [input] are not
/// written by this kernel.
///
/// Throws an [ArgumentError] if [output] does not have the same rank as
/// [input], or if any dimension of [input] is larger than the matching
/// dimension of [output].
///
/// The returned future completes once the dispatch has been awaited; the
/// shader object is destroyed afterwards even if dispatch fails.
Future<void> copyWithPadding(Tensor<T> input, Tensor<T> output) async {
  final rank = input.shape.length;
  final wgslType = getWGSLType(dataType);

  // Both stride arrays in the shader are declared with the input's rank, so a
  // rank mismatch would produce an array constructor of the wrong length.
  if (output.shape.length != rank) {
    throw ArgumentError(
      'copyWithPadding requires tensors of equal rank, but input has rank '
      '$rank and output has rank ${output.shape.length}.',
    );
  }

  // An input axis longer than the output axis would make distinct input
  // elements map onto wrong (or out-of-range) output positions.
  for (var axis = 0; axis < rank; axis++) {
    if (input.shape[axis] > output.shape[axis]) {
      throw ArgumentError(
        'copyWithPadding requires every input dimension to fit in the '
        'output, but axis $axis is ${input.shape[axis]} in the input and '
        '${output.shape[axis]} in the output.',
      );
    }
  }

  // Emits a WGSL constant `<name>_strides` holding the row-major strides of
  // `shape` (the last axis has stride 1).
  String stridesDeclaration(String name, List<int> shape) {
    final strides = List<int>.filled(shape.length, 1);
    for (var axis = shape.length - 2; axis >= 0; axis--) {
      strides[axis] = strides[axis + 1] * shape[axis + 1];
    }
    final values = strides.map((s) => '${s}u').join(', ');
    return 'const ${name}_strides : array<u32, $rank> = '
        'array<u32, $rank>($values);';
  }

  // The constant names must match the identifiers the shader body reads:
  // `source_strides` and `target_strides`.
  final sourceStrides = stridesDeclaration('source', input.shape);
  final targetStrides = stridesDeclaration('target', output.shape);
  final sourceShapeArray =
      'array<u32, $rank>(${input.shape.map((s) => '${s}u').join(', ')})';

  final shaderCode =
      '''
$sourceStrides
$targetStrides
const source_shape : array<u32, $rank> = $sourceShapeArray;
const rank : u32 = ${rank}u;

@group(0) @binding(0) var<storage, read_write> input: array<$wgslType>;
@group(0) @binding(1) var<storage, read_write> output: array<$wgslType>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let source_index: u32 = global_id.x;
if (source_index >= ${input.size}u) { return; }

// Convert flat index to multi-dimensional indices
var indices: array<u32, rank>;
var remainder: u32 = source_index;

for (var i: u32 = 0u; i < rank; i = i + 1u) {
  indices[i] = remainder / source_strides[i];
  remainder = remainder % source_strides[i];
}

// Calculate output flat index using output strides
var target_index: u32 = 0u;
for (var i: u32 = 0u; i < rank; i = i + 1u) {
  target_index = target_index + indices[i] * target_strides[i];
}

output[target_index] = input[source_index];
}
''';

  final shader = gpu.createComputeShader();
  try {
    shader.loadKernelString(shaderCode);
    shader.setBuffer('input', input.buffer);
    shader.setBuffer('output', output.buffer);

    // One invocation per input element, rounded up to whole workgroups of
    // 256 (the shader's @workgroup_size).
    final workgroups = (input.size + 255) ~/ 256;
    await shader.dispatch(workgroups, 1, 1);
  } finally {
    // Release the shader whether or not loading or dispatch succeeded.
    shader.destroy();
  }
}