printGraph method
void
printGraph()
Prints a color-coded, visual representation of the computational graph, including computational cost and parallelizable dependency levels.
Implementation
/// Logs a color-coded view of the computational graph that produced this
/// tensor, annotated with per-node dependency levels, followed by the summed
/// cost collected while rendering.
///
/// Steps:
///  1. Walk back from [creator] through each node's input tensors, recording
///     nodes in post-order (a node is recorded only after the creators of its
///     inputs). Nothing is recorded when [creator] is null.
///  2. Assign each recorded node a level: 0 when none of its inputs has a
///     creator, otherwise one more than the highest level among its inputs'
///     creators. For an acyclic graph, a node that depends on another always
///     has a strictly higher level, so nodes sharing a level never depend on
///     one another.
///  3. Delegate rendering to `_buildGraphString` (passing the level map and an
///     empty cost map it is expected to populate), then log the sum of all
///     values in that cost map.
void printGraph() {
  Logger.yellow('Computational Graph:', prefix: '📊');

  // 1. Post-order depth-first walk over producer nodes.
  final List<Node> postOrder = <Node>[];
  final Set<Node> seen = <Node>{};
  void collect(Node? current) {
    // Set.add returns false when the node was already present, which doubles
    // as the "already visited" check.
    if (current == null || !seen.add(current)) return;
    for (final input in current.inputs) {
      collect(input.creator);
    }
    postOrder.add(current);
  }

  collect(creator);

  // 2. Level = 1 + deepest input level. Seeding the max with -1 makes nodes
  //    without any producing input land on level 0. The `?? 0` fallback only
  //    triggers if an input's creator has not been assigned a level yet.
  final Map<Node, int> levelOf = <Node, int>{};
  for (final Node current in postOrder) {
    final int deepestInput = current.inputs
        .map((input) => input.creator)
        .whereType<Node>()
        .map((producer) => levelOf[producer] ?? 0)
        .fold<int>(-1, (best, level) => level > best ? level : best);
    levelOf[current] = deepestInput + 1;
  }

  // 3. Render the tree; the helper receives `costs` and its values are
  //    summed afterwards for the final report line.
  final Map<int, int> costs = <int, int>{};
  _buildGraphString(this, '', true, {}, costs, levelOf, isRoot: true);

  var totalCost = 0;
  for (final int cost in costs.values) {
    totalCost += cost;
  }
  Logger.yellow('Total Graph Cost: ~${totalCost} FLOPs', prefix: 'Σ');
}