conv2dMultiChannelGPU function

GPUTensor<Tensor3D> conv2dMultiChannelGPU(
  1. GPUTensor input,
  2. GPUTensor<Tensor3D> weight,
  3. GPUTensor<Vector> bias,
  4. int kH,
  5. int kW,
  6. CommandBuffer tape, {
  7. String padding = 'valid',
  8. int strideH = 1,
  9. int strideW = 1,
})

///////////////////////////////// Advanced Layers (800-999) /////////////////////////////////

Implementation

/// Records a multi-channel 2-D convolution (forward op) on [tape] and returns
/// the output tensor, shaped `[outChannels, outHeight, outWidth]`.
///
/// * [input] must have shape `[height, width]` (treated as a single channel)
///   or `[channels, height, width]`.
/// * `weight.shape[0]` is used as the number of output channels; the rest of
///   the weight layout is interpreted by the GPU kernel and is not checked
///   here.
/// * [kH] x [kW] is the kernel size; [strideH] / [strideW] are the steps
///   between consecutive output positions.
/// * [padding] is `'valid'` (no padding; the kernel must fit inside the
///   input) or `'same'` (each output extent is `ceil(in / stride)`, with
///   `(k - 1) ~/ 2` rows/columns of padding placed before the input and any
///   remainder after it).
///
/// The returned tensor's `creator` node records the backward ops, which read
/// `<out.id>_grad` and write `<input.id>_grad`, `<weight.id>_grad` and
/// `<bias.id>_grad`.
///
/// Throws [ArgumentError] when the input rank is not 2 or 3, the kernel size
/// or a stride is not positive, [padding] is not recognised, or a `'valid'`
/// convolution's kernel is larger than the input.
GPUTensor<Tensor3D> conv2dMultiChannelGPU(
    GPUTensor<dynamic> input,
    GPUTensor<Tensor3D> weight,
    GPUTensor<Vector> bias,
    int kH, int kW,
    CommandBuffer tape, {String padding = 'valid', int strideH = 1, int strideW = 1}) {
  final int rank = input.shape.length;
  if (rank != 2 && rank != 3) {
    throw ArgumentError.value(
        input.shape, 'input', 'expected shape [height, width] or [channels, height, width]');
  }
  if (kH <= 0 || kW <= 0) {
    throw ArgumentError('kernel size must be positive, got ${kH}x$kW');
  }
  if (strideH <= 0 || strideW <= 0) {
    throw ArgumentError('strides must be positive, got ($strideH, $strideW)');
  }
  if (padding != 'valid' && padding != 'same') {
    throw ArgumentError.value(padding, 'padding', "must be 'valid' or 'same'");
  }

  // A rank-2 input is a single-channel image; rank 3 is [channels, H, W].
  final bool singleChannel = rank == 2;
  final int inChannels = singleChannel ? 1 : input.shape[0];
  final int inHeight = input.shape[singleChannel ? 0 : 1];
  final int inWidth = input.shape[singleChannel ? 1 : 2];

  final int outChannels = weight.shape[0];

  // Only the leading (top/left) padding goes on the tape.
  // NOTE(review): the kernel presumably treats input positions outside
  // [0, inHeight) x [0, inWidth) as zero, which covers the trailing
  // (bottom/right) padding — confirm against the kernel implementation.
  int padT = 0;
  int padL = 0;
  int outHeight;
  int outWidth;
  if (padding == 'same') {
    padT = (kH - 1) ~/ 2;
    padL = (kW - 1) ~/ 2;
    // ceil(in / stride). This equals the previous formula for odd kernels,
    // and also gives a same-size output for even kernels at stride 1.
    outHeight = (inHeight + strideH - 1) ~/ strideH;
    outWidth = (inWidth + strideW - 1) ~/ strideW;
  } else {
    // `~/` truncates toward zero, so `(in - k) ~/ stride + 1` would yield a
    // bogus (positive or negative) size when the kernel exceeds the input.
    if (inHeight < kH || inWidth < kW) {
      throw ArgumentError(
          "kernel ${kH}x$kW is larger than the input ${inHeight}x$inWidth with 'valid' padding");
    }
    outHeight = (inHeight - kH) ~/ strideH + 1;
    outWidth = (inWidth - kW) ~/ strideW + 1;
  }

  GPUTensor<Tensor3D> out = GPUTensor<Tensor3D>.empty(<int>[outChannels, outHeight, outWidth]);

  // Forward op: operand ids, then the geometry the kernel needs.
  tape.putInt(OP_CONV2D_MULTI_FORWARD);
  tape.putString(input.id);
  tape.putString(weight.id);
  tape.putString(bias.id);
  tape.putString(out.id);
  tape.putInt(inChannels);
  tape.putInt(outChannels);
  tape.putInt(kH);
  tape.putInt(kW);
  tape.putInt(padT);
  tape.putInt(padL);
  tape.putInt(strideH);
  tape.putInt(strideW);

  // Multiply-add count (2 ops per MAC) over every output element.
  int cost = outHeight * outWidth * outChannels * inChannels * kH * kW * 2;

  out.creator = GPUNode(
    <GPUTensor>[input, weight, bias],
        (CommandBuffer bTape) {
      // d(input) from the output gradient and the weights.
      bTape.putInt(OP_CONV2D_MULTI_BACKWARD_INPUT);
      bTape.putString('${out.id}_grad');
      bTape.putString(weight.id);
      bTape.putString('${input.id}_grad');
      bTape.putInt(inChannels);
      bTape.putInt(outChannels);
      bTape.putInt(kH);
      bTape.putInt(kW);
      bTape.putInt(padT);
      bTape.putInt(padL);
      bTape.putInt(strideH);
      bTape.putInt(strideW);

      // d(weight) and d(bias) from the forward input and the output gradient.
      bTape.putInt(OP_CONV2D_MULTI_BACKWARD_WEIGHT);
      bTape.putString(input.id);
      bTape.putString('${out.id}_grad');
      bTape.putString('${weight.id}_grad');
      bTape.putString('${bias.id}_grad');
      bTape.putInt(inChannels);
      bTape.putInt(outChannels);
      bTape.putInt(kH);
      bTape.putInt(kW);
      bTape.putInt(padT);
      bTape.putInt(padL);
      bTape.putInt(strideH);
      bTape.putInt(strideW);
    },
    opName: 'conv2dMultiChannelGPU',
    cost: cost,
  );

  return out;
}