conv2d function
Implementation
/// Computes a 2-D cross-correlation of [input] with [kernel] (the kernel is
/// not flipped) and attaches a backward closure that accumulates gradients
/// into both operands.
///
/// [padding] selects the border handling:
///  * `'valid'` (default): no padding. The output is
///    `(H - kH + 1) x (W - kW + 1)`, or empty when the kernel is larger than
///    the input.
///  * `'same'`: the input is zero-padded so the output is `H x W`. The total
///    padding along each axis is `k - 1`. When that is odd, the extra
///    row/column goes on the bottom/right. Padding is computed per axis, so
///    non-square and even-sized kernels are handled.
///
/// Throws [ArgumentError] for any other [padding] value.
///
/// The operation cost recorded on the graph node is
/// `outH * outW * 2 * kH * kW` (one multiply and one add per kernel tap).
///
/// NOTE(review): the backward pass indexes `input.grad`, `kernel.grad` and
/// `out.grad` with flat row-major offsets, exactly as the previous version
/// did. Confirm this matches the `Tensor` gradient layout.
Tensor<Matrix> conv2d(
  Tensor<Matrix> input,
  Tensor<Matrix> kernel, {
  String padding = 'valid',
}) {
  if (padding != 'valid' && padding != 'same') {
    throw ArgumentError.value(padding, 'padding', "must be 'valid' or 'same'");
  }

  final Matrix kernelMatrix = kernel.value;
  final int originalInputHeight = input.shape[0];
  final int originalInputWidth = input.shape[1];
  final int kernelHeight = kernel.shape[0];
  final int kernelWidth = kernel.shape[1];

  // Per-side padding. Each axis is handled independently so 'same' stays
  // correct for non-square kernels. Splitting k - 1 into floor/ceil halves
  // keeps it correct for even kernel sizes.
  int padTop = 0, padBottom = 0, padLeft = 0, padRight = 0;
  if (padding == 'same') {
    padTop = (kernelHeight - 1) ~/ 2;
    padBottom = (kernelHeight - 1) - padTop;
    padLeft = (kernelWidth - 1) ~/ 2;
    padRight = (kernelWidth - 1) - padLeft;
  }
  final Matrix inputMatrix = padding == 'same'
      ? _conv2dZeroPad(input.value, padTop, padBottom, padLeft, padRight)
      : input.value;

  final int paddedHeight = inputMatrix.length;
  final int paddedWidth = inputMatrix.isEmpty ? 0 : inputMatrix[0].length;
  // Clamp at zero: a kernel larger than the input yields an empty output
  // (and zero cost) instead of negative dimensions.
  final int rawOutputHeight = paddedHeight - kernelHeight + 1;
  final int rawOutputWidth = paddedWidth - kernelWidth + 1;
  final int outputHeight = rawOutputHeight > 0 ? rawOutputHeight : 0;
  final int outputWidth = rawOutputWidth > 0 ? rawOutputWidth : 0;

  // Forward pass. The summation order (ky outer, kx inner) is unchanged, so
  // floating-point results match the previous implementation.
  final Matrix outputValue = [];
  for (int y = 0; y < outputHeight; y++) {
    final Vector row = [];
    for (int x = 0; x < outputWidth; x++) {
      double sum = 0.0;
      for (int ky = 0; ky < kernelHeight; ky++) {
        final inputRow = inputMatrix[y + ky];
        final kernelRow = kernelMatrix[ky];
        for (int kx = 0; kx < kernelWidth; kx++) {
          sum = sum + inputRow[x + kx] * kernelRow[kx];
        }
      }
      row.add(sum);
    }
    outputValue.add(row);
  }

  final Tensor<Matrix> out = Tensor<Matrix>(outputValue);
  final int cost = outputHeight * outputWidth * 2 * kernelHeight * kernelWidth;
  out.creator = Node(
    [input, kernel],
    () {
      for (int y = 0; y < outputHeight; y++) {
        for (int x = 0; x < outputWidth; x++) {
          // Upstream gradient for this output cell, read once per cell.
          final upstream = out.grad[y * outputWidth + x];
          for (int ky = 0; ky < kernelHeight; ky++) {
            for (int kx = 0; kx < kernelWidth; kx++) {
              // Map the padded coordinate back to the unpadded input.
              final int inputY = y + ky - padTop;
              final int inputX = x + kx - padLeft;
              // Padding cells are constant zeros. They receive no gradient
              // and contribute nothing to the kernel gradient.
              if (inputY < 0 ||
                  inputY >= originalInputHeight ||
                  inputX < 0 ||
                  inputX >= originalInputWidth) {
                continue;
              }
              final int inIdx = inputY * originalInputWidth + inputX;
              final int kIdx = ky * kernelWidth + kx;
              input.grad[inIdx] =
                  input.grad[inIdx] + kernelMatrix[ky][kx] * upstream;
              kernel.grad[kIdx] =
                  kernel.grad[kIdx] + inputMatrix[y + ky][x + kx] * upstream;
            }
          }
        }
      }
    },
    opName: 'conv2d',
    extraParams: {
      'padding': padding,
    },
    cost: cost,
  );
  return out;
}

/// Returns a new matrix equal to [m] surrounded by zeros.
///
/// Adds [top] zero rows above, [bottom] zero rows below, [left] zero
/// columns before each row and [right] zero columns after each row.
/// [m] itself is not modified.
Matrix _conv2dZeroPad(Matrix m, int top, int bottom, int left, int right) {
  final int width = m.isEmpty ? 0 : m[0].length;
  final int paddedWidth = width + left + right;

  Vector zeroRow() {
    final Vector r = [];
    for (int i = 0; i < paddedWidth; i++) {
      r.add(0.0);
    }
    return r;
  }

  final Matrix result = [];
  for (int i = 0; i < top; i++) {
    result.add(zeroRow());
  }
  for (final sourceRow in m) {
    final Vector r = [];
    for (int i = 0; i < left; i++) {
      r.add(0.0);
    }
    r.addAll(sourceRow);
    for (int i = 0; i < right; i++) {
      r.add(0.0);
    }
    result.add(r);
  }
  for (int i = 0; i < bottom; i++) {
    result.add(zeroRow());
  }
  return result;
}