getAllTensorsInGraph method

List<GPUTensor> getAllTensorsInGraph(GPUTensor root)

Implementation

/// Collects every tensor reachable from [root] by following
/// `creator.inputs` links.
///
/// The result starts with [root] and lists tensors in depth-first pre-order:
/// a tensor is followed by everything reachable through its creator's first
/// input, then the second input, and so on. Tensors are de-duplicated by
/// their `id`, so a tensor shared by several branches appears only once, at
/// the position where it is first reached.
///
/// The walk uses an explicit work-list rather than recursion, so the depth
/// of the graph is limited by heap memory and not by the call stack.
List<GPUTensor> getAllTensorsInGraph(GPUTensor root) {
  final List<GPUTensor> all = <GPUTensor>[];
  final Set<String> visited = <String>{};
  // Used as a LIFO stack: the last element is the next tensor to visit.
  final List<GPUTensor> pending = <GPUTensor>[root];

  while (pending.isNotEmpty) {
    final GPUTensor node = pending.removeLast();

    // `Set.add` returns false when the id is already present. Marking a node
    // as visited only when it is popped (not when it is pushed) is what keeps
    // the output order identical to a recursive pre-order walk.
    if (!visited.add(node.id)) {
      continue;
    }
    all.add(node);

    // Read `creator` once so the null check promotes the local.
    final creator = node.creator;
    if (creator == null) {
      continue;
    }

    // Push inputs last-to-first so that inputs[0] ends up on top of the
    // stack and is therefore visited first.
    final inputs = creator.inputs;
    for (int i = inputs.length - 1; i >= 0; i--) {
      pending.add(inputs[i]);
    }
  }

  return all;
}