elementWiseMultiply3D function
Implementation
/// Element-wise (Hadamard) product of two rank-3 tensors.
///
/// Returns a new tensor whose value at `[d][h][w]` is
/// `a.value[d][h][w] * b.value[d][h][w]`. The result is attached to the
/// graph through a [Node] with parents `[a, b]`, op name `'multiply_3d'`
/// and cost `depth * height * width`.
///
/// Backward pass (product rule): each element of the upstream gradient
/// `out.grad` is multiplied by the *other* operand's data and accumulated
/// (`+=`) into the operand's `grad`. Passing the same tensor as both `a`
/// and `b` therefore accumulates both contributions into one gradient.
///
/// Throws an [ArgumentError] if `b`'s shape does not match `a`'s. Without
/// this check, a smaller `b` raised an opaque index error, and a larger
/// `b` was silently truncated.
Tensor<Tensor3D> elementWiseMultiply3D(Tensor<Tensor3D> a, Tensor<Tensor3D> b) {
  int depth = a.shape[0];
  int height = a.shape[1];
  int width = a.shape[2];

  // Element-wise ops need identical extents on every axis. Fail loudly
  // here rather than deep inside the loops (or not at all).
  if (b.shape[0] != depth || b.shape[1] != height || b.shape[2] != width) {
    throw ArgumentError(
        'elementWiseMultiply3D: shape mismatch '
        '[$depth, $height, $width] vs '
        '[${b.shape[0]}, ${b.shape[1]}, ${b.shape[2]}]');
  }

  Tensor3D aVal = a.value;
  Tensor3D bVal = b.value;
  Tensor3D outValue = [];
  for (int d = 0; d < depth; d = d + 1) {
    // Hoist the per-plane lookups out of the inner loops.
    final aPlane = aVal[d];
    final bPlane = bVal[d];
    Matrix matrix = [];
    for (int h = 0; h < height; h = h + 1) {
      final aRow = aPlane[h];
      final bRow = bPlane[h];
      Vector row = [];
      for (int w = 0; w < width; w = w + 1) {
        row.add(aRow[w] * bRow[w]);
      }
      matrix.add(row);
    }
    outValue.add(matrix);
  }

  Tensor<Tensor3D> out = Tensor<Tensor3D>(outValue);
  out.creator = Node(
    [a, b],
    () {
      // NOTE(review): the backward pass indexes `data` and `grad` as flat,
      // equal-length buffers, while the forward pass reads the nested
      // `value`. Presumably `data` is a flattened view of `value` in the
      // same [d][h][w] order — confirm against the Tensor class.
      int length = a.data.length;
      for (int i = 0; i < length; i = i + 1) {
        // d(a*b)/da = b and d(a*b)/db = a, scaled by the upstream gradient.
        a.grad[i] = a.grad[i] + out.grad[i] * b.data[i];
        b.grad[i] = b.grad[i] + out.grad[i] * a.data[i];
      }
    },
    opName: 'multiply_3d',
    cost: depth * height * width,
  );
  return out;
}