concatenate3D function

Tensor<Tensor3D> concatenate3D(
  Tensor<Tensor3D> a,
  Tensor<Tensor3D> b
)

Implementation

/// Joins [a] and [b] along the outermost axis of their values.
///
/// The output value is a new outer list holding the elements of `a.value`
/// followed by the elements of `b.value`.
///
/// The output's creator is a `Node` over `[a, b]` with op name `'concat_3d'`
/// and cost 0. Its callback adds the leading `a.value.length` slices of
/// `out.grad` into `a.grad`, and the following `b.value.length` slices into
/// `b.grad`. Inner loop bounds are read from each input's first slice
/// (`value[0]` and `value[0][0]`), so every slice of an input is presumably
/// expected to have the same shape (confirm against callers).
Tensor<Tensor3D> concatenate3D(Tensor<Tensor3D> a, Tensor<Tensor3D> b) {
  final Tensor3D joined = [...a.value, ...b.value];
  Tensor<Tensor3D> out = Tensor<Tensor3D>(joined);

  // Adds the slices of out.grad that start at `offset` into `input.grad`,
  // element by element.
  void addGradInto(Tensor<Tensor3D> input, int offset) {
    final int sliceCount = input.value.length;
    for (int s = 0; s < sliceCount; s++) {
      for (int r = 0; r < input.value[0].length; r++) {
        for (int c = 0; c < input.value[0][0].length; c++) {
          input.grad[s][r][c] += out.grad[offset + s][r][c];
        }
      }
    }
  }

  out.creator = Node(
    [a, b],
    () {
      // b's slices sit directly after a's slices in the output.
      final int bOffset = a.value.length;
      addGradInto(a, 0);
      addGradInto(b, bOffset);
    },
    opName: 'concat_3d',
    cost: 0,
  );
  return out;
}