forward method

Tensor forward(Tensor input)

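Runs the forward pass: convolves `input` with the layer's weights (adding the bias, if one is set) and returns a tensor of shape `[batch, outChannels, outH, outW]`.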

Implementation

/// Applies the 2D convolution to [input] and returns the result.
///
/// [input] is indexed as a 4-D `[batch, inChannels, inH, inW]` tensor whose
/// `data` is a flat, row-major buffer (NOTE(review): this layout is assumed
/// by the index arithmetic below — confirm against [Tensor]). The weight is
/// indexed as `[outChannels, inChannels, kernelSize, kernelSize]` and the
/// optional bias as `[outChannels]`.
///
/// Zero padding of [padding] pixels on every side is applied implicitly:
/// kernel taps that land outside the input are skipped, not materialized.
///
/// Returns a tensor of shape `[batch, outChannels, outH, outW]`, where
/// `outH = (inH + 2 * padding - kernelSize) ~/ stride + 1` and likewise for
/// `outW`.
///
/// Throws an [ArgumentError] if [input] is not rank 4, if its channel
/// dimension differs from [inChannels], or if the padded spatial size is
/// smaller than [kernelSize].
Tensor forward(Tensor input) {
  final shape = input.shape;
  if (shape.length != 4) {
    throw ArgumentError.value(
      shape,
      'input',
      'Expected a 4-D [batch, channels, height, width] tensor',
    );
  }
  if (shape[1] != inChannels) {
    throw ArgumentError.value(
      shape,
      'input',
      'Expected $inChannels input channels but got ${shape[1]}',
    );
  }

  final batch = shape[0];
  final inH = shape[2];
  final inW = shape[3];

  // `~/` truncates toward zero, so a kernel larger than the padded input
  // would otherwise yield outH/outW == 1 (or a negative buffer length)
  // instead of being rejected.
  if (inH + 2 * padding < kernelSize || inW + 2 * padding < kernelSize) {
    throw ArgumentError.value(
      shape,
      'input',
      'Padded spatial size is smaller than kernelSize ($kernelSize)',
    );
  }

  final outH = (inH + 2 * padding - kernelSize) ~/ stride + 1;
  final outW = (inW + 2 * padding - kernelSize) ~/ stride + 1;

  // Read the backing buffers once; a local copy of the nullable bias also
  // lets the null check promote without `!`.
  final inData = input.data;
  final weightData = weight.data;
  final biasData = bias?.data;

  final output = Float32List(batch * outChannels * outH * outW);

  for (var b = 0; b < batch; b++) {
    for (var oc = 0; oc < outChannels; oc++) {
      for (var oh = 0; oh < outH; oh++) {
        // Top edge of the receptive field in (unpadded) input coordinates.
        final ihOrigin = oh * stride - padding;
        for (var ow = 0; ow < outW; ow++) {
          final iwOrigin = ow * stride - padding;
          var sum = 0.0;
          for (var ic = 0; ic < inChannels; ic++) {
            final inPlane = (b * inChannels + ic) * inH;
            final weightPlane = (oc * inChannels + ic) * kernelSize;
            for (var kh = 0; kh < kernelSize; kh++) {
              final ih = ihOrigin + kh;
              // Whole kernel row lies in the zero padding: contributes 0.
              if (ih < 0 || ih >= inH) continue;
              final inRow = (inPlane + ih) * inW;
              final weightRow = (weightPlane + kh) * kernelSize;
              for (var kw = 0; kw < kernelSize; kw++) {
                final iw = iwOrigin + kw;
                if (iw < 0 || iw >= inW) continue;
                sum += inData[inRow + iw] * weightData[weightRow + kw];
              }
            }
          }
          if (biasData != null) {
            sum += biasData[oc];
          }
          output[((b * outChannels + oc) * outH + oh) * outW + ow] = sum;
        }
      }
    }
  }

  return Tensor(output, [batch, outChannels, outH, outW]);
}