`concatenate3D` function: concatenates two 3-D tensors along the depth axis
Implementation:
/// Concatenates two 3-D tensors along the first (depth) axis.
///
/// The result contains every depth slice of [a] followed by every depth slice
/// of [b], so its depth is `a.value.length + b.value.length`.
///
/// Both inputs must have the same height and width (the sizes of their second
/// and third axes). If they differ, an [ArgumentError] is thrown instead of
/// silently producing a ragged tensor whose slices have different shapes.
/// The check is skipped when either input has depth 0.
///
/// Aliasing: the spread literal copies only the outer (depth) list. The depth
/// slices of `out.value` are therefore the same list objects as those in
/// `a.value` and `b.value`, and mutating `out.value` in place also mutates
/// the inputs.
///
/// Backward: the upstream gradient `out.grad` is split along depth. Slices
/// `[0, aDepth)` are accumulated (`+=`) into `a.grad`, and slices
/// `[aDepth, aDepth + bDepth)` into `b.grad`.
Tensor<Tensor3D> concatenate3D(Tensor<Tensor3D> a, Tensor<Tensor3D> b) {
  // Depth concatenation is only well-defined when the spatial dimensions
  // agree. With depth 0 there is no slice to compare against, so skip.
  if (a.value.isNotEmpty && b.value.isNotEmpty) {
    final int aHeight = a.value[0].length;
    final int bHeight = b.value[0].length;
    // Guard the [0][0] access: a zero-height slice has no rows to read.
    final int aWidth = aHeight > 0 ? a.value[0][0].length : 0;
    final int bWidth = bHeight > 0 ? b.value[0][0].length : 0;
    if (aHeight != bHeight || aWidth != bWidth) {
      throw ArgumentError('concatenate3D: spatial shapes differ '
          '(a is ${aHeight}x$aWidth, b is ${bHeight}x$bWidth)');
    }
  }

  Tensor3D outValue = [...a.value, ...b.value];
  Tensor<Tensor3D> out = Tensor<Tensor3D>(outValue);
  out.creator = Node(
    [a, b],
    () {
      // Gradient for a: out.grad slices [0, aDepth) map one-to-one.
      final int aDepth = a.value.length;
      for (int d = 0; d < aDepth; d++) {
        for (int h = 0; h < a.value[d].length; h++) {
          for (int w = 0; w < a.value[d][h].length; w++) {
            a.grad[d][h][w] += out.grad[d][h][w];
          }
        }
      }
      // Gradient for b: its slices start at offset aDepth in out.grad.
      final int bDepth = b.value.length;
      for (int d = 0; d < bDepth; d++) {
        for (int h = 0; h < b.value[d].length; h++) {
          for (int w = 0; w < b.value[d][h].length; w++) {
            b.grad[d][h][w] += out.grad[aDepth + d][h][w];
          }
        }
      }
    },
    opName: 'concat_3d',
    cost: 0,
  );
  return out;
}