printParallelGraph method
void
printParallelGraph()
Implementation
/// Prints the computational graph reachable from the enclosing object's
/// `creator`, grouped into dependency levels.
///
/// A node whose inputs have no creator sits on level 0; every other node sits
/// one level above the deepest node that produced one of its inputs. Nodes that
/// share a level are reported together as "parallel ops". For each level the
/// summed `cost` of its nodes is printed, followed by one line per node, and
/// finally the total over all levels (labelled as approximate FLOPs in the
/// output).
///
/// Produces output only through `Logger` and `print`; returns nothing.
void printParallelGraph() {
  Logger.yellow('Parallel Computational Graph:', prefix: '⚡️');

  // Post-order depth-first walk starting at `creator`: a node is appended only
  // after all of its input producers, so `ordered` is a topological order.
  final List<Node> ordered = [];
  final Set<Node> seen = {};
  void visit(Node? current) {
    if (current == null) return;
    // Set.add returns false when the node was already recorded.
    if (!seen.add(current)) return;
    for (final Tensor input in current.inputs) {
      visit(input.creator);
    }
    ordered.add(current);
  }

  visit(creator);

  // Assign each node its dependency level and bucket it in the same pass.
  // Producers always precede their consumers in `ordered`, so their level is
  // already known when a consumer is processed.
  final Map<Node, int> depthOf = {};
  final Map<int, List<Node>> buckets = {};
  for (final Node current in ordered) {
    int deepestInput = -1;
    for (final Tensor input in current.inputs) {
      final Node? producer = input.creator;
      if (producer == null) continue;
      final int producerDepth = depthOf[producer] ?? 0;
      if (producerDepth > deepestInput) {
        deepestInput = producerDepth;
      }
    }
    final int depth = deepestInput + 1;
    depthOf[current] = depth;
    buckets.putIfAbsent(depth, () => <Node>[]).add(current);
  }

  // Report the buckets from the shallowest level to the deepest.
  final List<int> depthsAscending = buckets.keys.toList();
  depthsAscending.sort();

  int grandTotal = 0;
  for (final int depth in depthsAscending) {
    final List<Node> members = buckets[depth]!;

    int subtotal = 0;
    for (final Node member in members) {
      subtotal += member.cost;
    }
    grandTotal += subtotal;

    // Levels are shown 1-based to the reader.
    Logger.cyan(
      '--- Level ${depth + 1} (${members.length} parallel ops, cost: $subtotal) ---',
      prefix: '',
    );
    for (final Node member in members) {
      print(' - op: ${member.opName}, cost: ${member.cost}');
    }
  }

  Logger.yellow('Total Graph Cost: ~$grandTotal FLOPs', prefix: 'Σ');
}