forward method

Tensor forward(Tensor input, List<Tensor> tracker)

Implementation

/// Applies the convolution to [input] using im2col, a matrix multiply and a
/// bias addition.
///
/// [input] is indexed as `[channels, height, width]`: `shape[1]` is taken as
/// the height and `shape[2]` as the width, and `shape[0]` must equal
/// [inChannels] because that is the channel count handed to `im2col` together
/// with the raw input handle.
/// NOTE(review): this implies an unbatched input; confirm with callers.
///
/// Each tensor allocated here (column buffer, matmul result, biased output)
/// is appended to [tracker] immediately after it is created, so it is already
/// tracked if a later step throws.
/// NOTE(review): presumably the caller releases tracked tensors — verify.
///
/// Returns `res + bias`, where `res` is `weight` (expected
/// `[outChannels, patchSize]`) multiplied by the `[patchSize, hOut * wOut]`
/// column buffer. No reshape is performed here; callers that want
/// `[outChannels, hOut, wOut]` must interpret the flattened spatial axis
/// themselves.
///
/// Throws [ArgumentError] if `input.shape[0] != inChannels`, or if
/// `kernelSize` is larger than the padded input in either spatial dimension.
Tensor forward(Tensor input, List<Tensor> tracker) {
  // Reject a channel mismatch before raw memory is handed to im2col, which
  // would otherwise read the input buffer with the wrong layout.
  if (input.shape[0] != inChannels) {
    throw ArgumentError(
      'forward: expected $inChannels input channels, '
      'but input.shape is ${input.shape}',
    );
  }

  final int hIn = input.shape[1];
  final int wIn = input.shape[2];
  final int paddedH = hIn + 2 * padding;
  final int paddedW = wIn + 2 * padding;

  // Dart's `~/` truncates toward zero, so a negative numerator (kernel larger
  // than the padded input) could still produce an output size of 1 (or 0).
  // Reject that case explicitly; afterwards the numerator is >= 0 and `~/`
  // behaves as floor division, matching the usual conv output-size formula.
  if (paddedH < kernelSize || paddedW < kernelSize) {
    throw ArgumentError(
      'forward: kernelSize $kernelSize exceeds padded input '
      '${paddedH}x$paddedW (input ${hIn}x$wIn, padding $padding)',
    );
  }

  final int hOut = ((paddedH - kernelSize) ~/ stride) + 1;
  final int wOut = ((paddedW - kernelSize) ~/ stride) + 1;

  final int patchSize = inChannels * kernelSize * kernelSize;
  final int numOutputs = hOut * wOut;

  // 1. Column buffer of shape [patchSize, numOutputs], zero-filled. It is
  //    tracked right away so it is covered if im2col throws.
  final Tensor colBuffer = Tensor.fill([patchSize, numOutputs], 0.0);
  tracker.add(colBuffer);

  // 2. im2col into colBuffer. The same kernelSize, padding and stride are
  //    passed for both spatial axes, so only square kernels are supported.
  engine.im2col(
    input.handle,
    inChannels,
    hIn,
    wIn,
    kernelSize,
    kernelSize,
    padding,
    padding,
    stride,
    stride,
    colBuffer.handle,
  );

  // 3. [outChannels, patchSize] x [patchSize, numOutputs]
  //    -> [outChannels, numOutputs].
  final Tensor res = weight.matmul(colBuffer);
  tracker.add(res);

  // 4. Bias addition; produces a new tensor.
  // NOTE(review): this relies on `+` broadcasting `bias` across the spatial
  // (numOutputs) axis, e.g. a bias stored as [outChannels, 1]. With a plain
  // [outChannels] bias and trailing-axis (NumPy-style) broadcasting it would
  // be added along the wrong axis — confirm bias's shape and Tensor's rules.
  final Tensor output = res + bias;
  tracker.add(output);

  return output;
}