conv2d function

Tensor<Matrix> conv2d(
  Tensor<Matrix> input,
  Tensor<Matrix> kernel, {
  String padding = 'valid',
})

Implementation

/// Stride-1 2-D cross-correlation of the matrix held by [input] with the
/// matrix held by [kernel], recording a backward closure for autograd.
///
/// [padding] selects how borders are handled:
///  * `'valid'` (default): no padding. The output is
///    `(H - kH + 1) x (W - kW + 1)`; extents are clamped at 0, so a kernel
///    larger than the input yields an empty output.
///  * `'same'`: the input is zero-padded so the output is `H x W`. Along each
///    axis the total padding is `k - 1`; when that is odd, the extra row/column
///    goes on the bottom/right. Height and width are padded independently,
///    so non-square and even-sized kernels are handled.
///
/// Throws [ArgumentError] if [padding] is neither `'valid'` nor `'same'`.
///
/// The backward closure reads `out.grad` and accumulates into `input.grad`
/// and `kernel.grad` using flat row-major indices (`row * width + col`).
Tensor<Matrix> conv2d(
    Tensor<Matrix> input,
    Tensor<Matrix> kernel, {
      String padding = 'valid',
    }) {
  if (padding != 'valid' && padding != 'same') {
    throw ArgumentError.value(padding, 'padding', "must be 'valid' or 'same'");
  }

  Matrix inputMatrix = input.value;
  Matrix kernelMatrix = kernel.value;

  int originalInputHeight = input.shape[0];
  int originalInputWidth = input.shape[1];
  int kernelHeight = kernel.shape[0];
  int kernelWidth = kernel.shape[1];

  // Position of the original input inside the padded matrix. Each axis is
  // padded by its own kernel extent; an odd total puts the extra on the
  // bottom/right so the output really matches the input size.
  int padTop = 0;
  int padBottom = 0;
  int padLeft = 0;
  int padRight = 0;
  if (padding == 'same') {
    padTop = (kernelHeight - 1) ~/ 2;
    padBottom = (kernelHeight - 1) - padTop;
    padLeft = (kernelWidth - 1) ~/ 2;
    padRight = (kernelWidth - 1) - padLeft;
    inputMatrix = _zeroPad2d(inputMatrix, originalInputHeight,
        originalInputWidth, padTop, padBottom, padLeft, padRight);
  }

  int inputHeight = originalInputHeight + padTop + padBottom;
  int inputWidth = originalInputWidth + padLeft + padRight;

  // Clamp at zero so an oversized kernel gives an empty output and a zero
  // cost instead of negative extents (whose product could look positive).
  int outputHeight = inputHeight - kernelHeight + 1;
  if (outputHeight < 0) {
    outputHeight = 0;
  }
  int outputWidth = inputWidth - kernelWidth + 1;
  if (outputWidth < 0) {
    outputWidth = 0;
  }

  Matrix outputValue = [];
  for (int y = 0; y < outputHeight; y = y + 1) {
    Vector row = [];
    for (int x = 0; x < outputWidth; x = x + 1) {
      double sum = 0.0;
      for (int ky = 0; ky < kernelHeight; ky = ky + 1) {
        for (int kx = 0; kx < kernelWidth; kx = kx + 1) {
          sum = sum + inputMatrix[y + ky][x + kx] * kernelMatrix[ky][kx];
        }
      }
      row.add(sum);
    }
    outputValue.add(row);
  }

  Tensor<Matrix> out = Tensor<Matrix>(outputValue);
  int cost = outputHeight * outputWidth * 2 * kernelHeight * kernelWidth;

  out.creator = Node(
    [input, kernel],
        () {
      for (int y = 0; y < outputHeight; y = y + 1) {
        for (int x = 0; x < outputWidth; x = x + 1) {
          final upstream = out.grad[y * outputWidth + x];

          for (int ky = 0; ky < kernelHeight; ky = ky + 1) {
            // Row in the un-padded input read by this tap. Taps landing in
            // the zero border contribute nothing to either gradient (the
            // padded value is 0 and there is no input cell to update).
            int inputY = y + ky - padTop;
            if (inputY < 0 || inputY >= originalInputHeight) {
              continue;
            }
            for (int kx = 0; kx < kernelWidth; kx = kx + 1) {
              int inputX = x + kx - padLeft;
              if (inputX < 0 || inputX >= originalInputWidth) {
                continue;
              }

              int inIdx = inputY * originalInputWidth + inputX;
              int kIdx = ky * kernelWidth + kx;

              input.grad[inIdx] = input.grad[inIdx] + kernelMatrix[ky][kx] * upstream;
              kernel.grad[kIdx] = kernel.grad[kIdx] + inputMatrix[y + ky][x + kx] * upstream;
            }
          }
        }
      }
    },
    opName: 'conv2d',
    extraParams: {
      'padding': padding,
    },
    cost: cost,
  );
  return out;
}

/// Returns a new matrix containing [m] (of size [height] x [width]) surrounded
/// by zeros: [top]/[bottom] extra rows and [left]/[right] extra columns.
Matrix _zeroPad2d(Matrix m, int height, int width, int top, int bottom,
    int left, int right) {
  Matrix padded = [];
  int paddedHeight = top + height + bottom;
  int paddedWidth = left + width + right;
  for (int r = 0; r < paddedHeight; r = r + 1) {
    Vector row = [];
    int srcRow = r - top;
    for (int c = 0; c < paddedWidth; c = c + 1) {
      int srcCol = c - left;
      bool inside =
          srcRow >= 0 && srcRow < height && srcCol >= 0 && srcCol < width;
      row.add(inside ? m[srcRow][srcCol] : 0.0);
    }
    padded.add(row);
  }
  return padded;
}