concatenate3D function
Implementation
/// Concatenates two 3-D tensors along their first (depth) axis.
///
/// The new tensor is built from a list holding the first `a.shape[0]`
/// entries of `a.value` followed by the first `b.shape[0]` entries of
/// `b.value`. The entries are placed into that list by reference, not copied.
///
/// The attached backward function accumulates gradients by flat position:
/// the first `a.data.length` entries of the output gradient are added into
/// `a.grad`, and the next `b.data.length` entries are added into `b.grad`.
///
/// NOTE(review): the trailing two dimensions of [a] and [b] are not checked
/// for agreement, and the flat-offset gradient routing presumably assumes
/// `data`/`grad` are flat row-major buffers — confirm against `Tensor`.
Tensor<Tensor3D> concatenate3D(Tensor<Tensor3D> a, Tensor<Tensor3D> b) {
  final int depthA = a.shape[0];
  final int depthB = b.shape[0];
  final Tensor3D slicesA = a.value;
  final Tensor3D slicesB = b.value;

  // All slices of `a` come first, then all slices of `b`.
  final Tensor3D stacked = [
    for (int k = 0; k < depthA; k++) slicesA[k],
    for (int k = 0; k < depthB; k++) slicesB[k],
  ];

  final Tensor<Tensor3D> result = Tensor<Tensor3D>(stacked);

  void accumulateGrads() {
    // `a` owns the leading block of the upstream gradient ...
    final int offset = a.data.length;
    for (int j = 0; j < offset; j++) {
      a.grad[j] = a.grad[j] + result.grad[j];
    }
    // ... and `b` owns the block that immediately follows it.
    final int countB = b.data.length;
    for (int j = 0; j < countB; j++) {
      b.grad[j] = b.grad[j] + result.grad[offset + j];
    }
  }

  result.creator = Node(
    [a, b],
    accumulateGrads,
    opName: 'concat_3d',
    cost: 0,
  );
  return result;
}