forward method
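Runs the 2D convolution forward pass: convolves the NCHW `input` with `weight` and adds the optional per-output-channel `bias`.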
Implementation
/// Runs the 2D convolution forward pass on [input].
///
/// [input] is laid out as `[batch, channel, row, col]` (NCHW) and must have
/// exactly [inChannels] channels. [weight] is read as
/// `[outChannels, inChannels, kernelSize, kernelSize]`. The kernel is square
/// and is applied without flipping (cross-correlation).
///
/// Each output element is the sum over all input channels and kernel taps of
/// `input * weight`, plus `bias[oc]` when [bias] is non-null. Taps that fall
/// into the `padding` border are treated as zeros.
///
/// Returns a tensor of shape `[batch, outChannels, outH, outW]` where
/// `outH = (inH + 2 * padding - kernelSize) ~/ stride + 1`, and likewise for
/// `outW`.
///
/// Throws an [ArgumentError] in three cases: [input] is not rank 4, its
/// channel count differs from [inChannels], or the padded input is smaller
/// than the kernel in either spatial dimension.
Tensor forward(Tensor input) {
  final shape = input.shape;
  if (shape.length != 4) {
    throw ArgumentError.value(
        shape, 'input', 'expected a rank-4 [batch, channel, H, W] tensor');
  }
  final batch = shape[0];
  final channels = shape[1];
  final inH = shape[2];
  final inW = shape[3];
  // The index math below assumes the input has exactly `inChannels` planes.
  if (channels != inChannels) {
    throw ArgumentError.value(
        channels, 'input', 'expected $inChannels input channels');
  }

  final paddedH = inH + 2 * padding;
  final paddedW = inW + 2 * padding;
  // `~/` truncates toward zero. Without this guard, a kernel larger than
  // the padded input would yield a bogus output size: 0, 1 (from a window
  // that doesn't fit) or negative (whose product can still be positive).
  if (paddedH < kernelSize || paddedW < kernelSize) {
    throw ArgumentError('kernel size $kernelSize exceeds padded input size '
        '${paddedH}x$paddedW');
  }
  final outH = (paddedH - kernelSize) ~/ stride + 1;
  final outW = (paddedW - kernelSize) ~/ stride + 1;

  // Hoist field/getter reads out of the hot loops. Copying `bias` to a
  // local lets it promote and avoids a `!` on every output element.
  final inData = input.data;
  final weightData = weight.data;
  final biasData = bias?.data;

  final output = Float32List(batch * outChannels * outH * outW);
  for (var b = 0; b < batch; b++) {
    for (var oc = 0; oc < outChannels; oc++) {
      for (var oh = 0; oh < outH; oh++) {
        final ihBase = oh * stride - padding;
        for (var ow = 0; ow < outW; ow++) {
          final iwBase = ow * stride - padding;
          var sum = 0.0;
          for (var ic = 0; ic < inChannels; ic++) {
            final inPlane = (b * inChannels + ic) * inH;
            final weightPlane = (oc * inChannels + ic) * kernelSize;
            for (var kh = 0; kh < kernelSize; kh++) {
              final ih = ihBase + kh;
              // The whole kernel row lies in the zero padding.
              if (ih < 0 || ih >= inH) continue;
              final inRow = (inPlane + ih) * inW;
              final weightRow = (weightPlane + kh) * kernelSize;
              for (var kw = 0; kw < kernelSize; kw++) {
                final iw = iwBase + kw;
                if (iw < 0 || iw >= inW) continue;
                sum += inData[inRow + iw] * weightData[weightRow + kw];
              }
            }
          }
          if (biasData != null) {
            sum += biasData[oc];
          }
          output[((b * outChannels + oc) * outH + oh) * outW + ow] = sum;
        }
      }
    }
  }
  return Tensor(output, [batch, outChannels, outH, outW]);
}