getAllTensorsInGraph method
Implementation
/// Collects every tensor reachable from [root] by following
/// `creator.inputs` links, starting with [root] itself.
///
/// Tensors are returned in pre-order depth-first order: a tensor appears
/// before the inputs of its creator, and inputs are explored in list order.
/// Each tensor is reported at most once; duplicates are detected by `id`,
/// so two distinct tensor objects sharing an `id` would be treated as one
/// (presumably ids are unique per tensor; verify against GPUTensor).
///
/// A tensor whose `creator` is null contributes no further tensors.
List<GPUTensor> getAllTensorsInGraph(GPUTensor root) {
  final List<GPUTensor> all = <GPUTensor>[];
  final Set<String> visited = <String>{};

  // An explicit work list replaces recursion so that a very deep graph
  // cannot exhaust the call stack. The list is used as a LIFO stack.
  final List<GPUTensor> pending = <GPUTensor>[root];

  while (pending.isNotEmpty) {
    final GPUTensor node = pending.removeLast();

    // Set.add returns false when the id was already present, so this both
    // marks the tensor as seen and skips tensors reached a second time.
    if (!visited.add(node.id)) {
      continue;
    }
    all.add(node);

    // Read creator once into a local so the null check promotes it and no
    // repeated `!` assertions are needed.
    final creator = node.creator;
    if (creator == null) {
      continue;
    }

    // Push inputs in reverse so the first input is popped (and therefore
    // visited) first, matching the order of a recursive pre-order walk.
    final inputs = creator.inputs;
    for (int i = inputs.length - 1; i >= 0; i = i - 1) {
      pending.add(inputs[i]);
    }
  }

  return all;
}