conv2dMultiChannelGPU function
///////////////////////////////// Advanced Layers (800- 999) /// /////////////////////////////////
Implementation
/// Records a multi-channel 2D convolution on [tape] and returns the tensor
/// that will hold its result.
///
/// Nothing is computed here: the function only appends an
/// `OP_CONV2D_MULTI_FORWARD` record to [tape] and attaches a [GPUNode] to the
/// returned tensor whose callback appends the two backward records
/// (`OP_CONV2D_MULTI_BACKWARD_INPUT`, then `OP_CONV2D_MULTI_BACKWARD_WEIGHT`).
///
/// * [input] - a rank-2 shape is read as `[height, width]` with one channel;
///   any other rank is read as `[channels, height, width]` from its first
///   three entries.
/// * [weight] - only `weight.shape[0]` is read here, as the number of output
///   channels. NOTE(review): the remaining weight layout is defined by the
///   kernel that consumes the tape, not by this function - confirm there.
/// * [bias] - forwarded by id only; its shape is not inspected here.
/// * [kH], [kW] - kernel height and width; must be positive.
/// * [padding] - `'same'` pads `(k - 1) ~/ 2` on the top and left and assumes
///   the same amount on the opposite side when sizing the output. Any other
///   value is treated as no padding (`'valid'`).
/// * [strideH], [strideW] - vertical and horizontal step; must be positive.
///
/// Throws [ArgumentError] if a kernel size or stride is not positive, or if
/// the kernel is larger than the (padded) input along either axis.
GPUTensor<Tensor3D> conv2dMultiChannelGPU(
    GPUTensor<dynamic> input,
    GPUTensor<Tensor3D> weight,
    GPUTensor<Vector> bias,
    int kH, int kW,
    CommandBuffer tape, {String padding = 'valid', int strideH = 1, int strideW = 1}) {
  if (kH <= 0 || kW <= 0) {
    throw ArgumentError('Kernel size must be positive, got kH=$kH, kW=$kW.');
  }
  if (strideH <= 0 || strideW <= 0) {
    throw ArgumentError(
        'Strides must be positive, got strideH=$strideH, strideW=$strideW.');
  }

  final bool singleChannel = input.shape.length == 2;
  final int inChannels = singleChannel ? 1 : input.shape[0];
  final int inHeight = singleChannel ? input.shape[0] : input.shape[1];
  final int inWidth = singleChannel ? input.shape[1] : input.shape[2];
  final int outChannels = weight.shape[0];

  // Only 'same' introduces padding; every other value behaves like 'valid'.
  final bool samePadding = padding == 'same';
  final int padT = samePadding ? (kH - 1) ~/ 2 : 0;
  final int padL = samePadding ? (kW - 1) ~/ 2 : 0;

  final int spanH = inHeight + 2 * padT - kH;
  final int spanW = inWidth + 2 * padL - kW;

  // `~/` truncates toward zero, so a small negative span would otherwise
  // round to 0 and yield an output dimension of 1 for a kernel that never
  // fits inside the input. Reject that case explicitly.
  if (spanH < 0 || spanW < 0) {
    throw ArgumentError(
        'Kernel ${kH}x$kW does not fit input ${inHeight}x$inWidth '
        "with padding '$padding'.");
  }

  final int outHeight = spanH ~/ strideH + 1;
  final int outWidth = spanW ~/ strideW + 1;

  // The eight geometry ints are written in the same order after the tensor
  // ids of the forward record and of both backward records.
  void putGeometry(CommandBuffer buffer) {
    buffer.putInt(inChannels);
    buffer.putInt(outChannels);
    buffer.putInt(kH);
    buffer.putInt(kW);
    buffer.putInt(padT);
    buffer.putInt(padL);
    buffer.putInt(strideH);
    buffer.putInt(strideW);
  }

  final GPUTensor<Tensor3D> out =
      GPUTensor<Tensor3D>.empty(<int>[outChannels, outHeight, outWidth]);

  // Forward record: opcode, tensor ids, geometry.
  tape.putInt(OP_CONV2D_MULTI_FORWARD);
  tape.putString(input.id);
  tape.putString(weight.id);
  tape.putString(bias.id);
  tape.putString(out.id);
  putGeometry(tape);

  // Two operations per kernel tap per output element.
  // NOTE(review): presumably one multiply and one add - confirm how the
  // scheduler interprets `cost`.
  final int cost = outHeight * outWidth * outChannels * inChannels * kH * kW * 2;

  out.creator = GPUNode(
    <GPUTensor>[input, weight, bias],
    (CommandBuffer bTape) {
      // Gradient w.r.t. the input.
      bTape.putInt(OP_CONV2D_MULTI_BACKWARD_INPUT);
      bTape.putString('${out.id}_grad');
      bTape.putString(weight.id);
      bTape.putString('${input.id}_grad');
      putGeometry(bTape);

      // Gradient w.r.t. the weights and the bias.
      bTape.putInt(OP_CONV2D_MULTI_BACKWARD_WEIGHT);
      bTape.putString(input.id);
      bTape.putString('${out.id}_grad');
      bTape.putString('${weight.id}_grad');
      bTape.putString('${bias.id}_grad');
      putGeometry(bTape);
    },
    opName: 'conv2dMultiChannelGPU',
    cost: cost,
  );
  return out;
}