printParallelGraph method

void printParallelGraph()

Implementation

/// Prints the computational graph that produced this tensor, grouped into
/// dependency levels.
///
/// Every node reachable from [creator] (following each input tensor's
/// `creator` link) is assigned a level:
/// - A node none of whose inputs has a creator gets level 0.
/// - Every other node sits one level above the highest-level node that
///   produces one of its inputs.
///
/// Consequently no node depends (directly or transitively) on another node
/// of the same level, so each level is a set of ops that have no ordering
/// constraint between them. Levels are printed in ascending order (displayed
/// 1-based), each with its node count, the sum of its nodes' `cost`, and one
/// line per op. A grand total of all `cost` values is printed last.
///
/// If [creator] is null, only the header and a total cost of 0 are printed.
///
/// NOTE(review): the total is labelled "FLOPs"; whether `Node.cost` is
/// really a FLOP count depends on how `Node` computes it — confirm there.
void printParallelGraph() {
  Logger.yellow('Parallel Computational Graph:', prefix: '⚡️');

  // 1. Post-order DFS from `creator`: every node is appended to `topo` only
  //    after all nodes producing its inputs. An explicit stack (one pending
  //    input iterator per node on the current path) replaces recursion so
  //    that very deep graphs cannot overflow the call stack. Children are
  //    visited in input order and marked visited on entry, which yields the
  //    same ordering a recursive DFS would.
  final List<Node> topo = [];
  final Set<Node> visitedNodes = {};
  final List<Node> nodeStack = [];
  final List<Iterator<Tensor>> inputStack = [];

  void enter(Node node) {
    visitedNodes.add(node);
    nodeStack.add(node);
    inputStack.add(node.inputs.iterator);
  }

  final Node? root = creator;
  if (root != null) {
    enter(root);
  }
  while (nodeStack.isNotEmpty) {
    final Iterator<Tensor> pending = inputStack.last;
    if (pending.moveNext()) {
      final Node? producer = pending.current.creator;
      // Leaf tensors have no producer; already-seen producers are skipped.
      if (producer != null && !visitedNodes.contains(producer)) {
        enter(producer);
      }
    } else {
      // All inputs of the top node are done: emit it in post-order.
      inputStack.removeLast();
      topo.add(nodeStack.removeLast());
    }
  }

  // 2. Assign levels and bucket nodes by level in a single pass. Because of
  //    the `topo` ordering, every producer's level is already recorded when
  //    its consumer is processed. Buckets keep `topo` order internally.
  final Map<Node, int> levels = {};
  final Map<int, List<Node>> parallelGroups = {};
  for (final Node node in topo) {
    int maxInputLevel = -1; // stays -1 when no input has a producer -> level 0
    for (final Tensor inputTensor in node.inputs) {
      final Node? producer = inputTensor.creator;
      if (producer != null) {
        final int inputLevel = levels[producer] ?? 0;
        if (inputLevel > maxInputLevel) {
          maxInputLevel = inputLevel;
        }
      }
    }
    final int level = maxInputLevel + 1;
    levels[node] = level;
    parallelGroups.putIfAbsent(level, () => <Node>[]).add(node);
  }

  // 3. Print the buckets in ascending level order with per-level and total
  //    cost. Levels are shown 1-based.
  final List<int> sortedLevels = parallelGroups.keys.toList()..sort();
  int totalCost = 0;

  for (final int level in sortedLevels) {
    final List<Node> nodesInLevel = parallelGroups[level]!;
    int levelCost = 0;
    for (final Node node in nodesInLevel) {
      levelCost += node.cost;
    }
    totalCost += levelCost;

    Logger.cyan(
      '--- Level ${level + 1} (${nodesInLevel.length} parallel ops, cost: $levelCost) ---',
      prefix: '',
    );
    for (final Node node in nodesInLevel) {
      print('  - op: ${node.opName}, cost: ${node.cost}');
    }
  }

  Logger.yellow('Total Graph Cost: ~$totalCost FLOPs', prefix: 'Σ');
}