printGraph method

void printGraph()

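Prints a color-coded, visual representation of the computational graph that produced this tensor. Each node is annotated with its estimated computational cost and its dependency level. Nodes on the same level do not depend on one another, so they could in principle be evaluated in parallel. The estimated total cost of the graph, in FLOPs, is printed at the end.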
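The sketch below makes "parallelizable dependency levels" concrete. It is a minimal, self-contained illustration with simplifying assumptions. `GraphNode` is a hypothetical stand-in that links nodes to each other directly, while the real code reaches producers through `Tensor.creator` and skips leaf tensors. The names `topologicalOrder` and `parallelWaves` are illustrative and are not part of the library. A node's level is 0 when none of its inputs has a producer, and otherwise one more than the highest input level. As a result, two nodes on the same level can never depend on each other.

```dart
/// Hypothetical stand-in for a graph node: a name plus the nodes whose
/// outputs it consumes. Leaf tensors (no producer) are simply not listed.
class GraphNode {
  final String name;
  final List<GraphNode> inputs;
  GraphNode(this.name, [this.inputs = const []]);
}

/// Returns every node reachable from [root] in dependency order:
/// each node appears after all the nodes that produce its inputs.
List<GraphNode> topologicalOrder(GraphNode root) {
  final order = <GraphNode>[];
  final visited = <GraphNode>{};
  void visit(GraphNode node) {
    // Set.add returns false if the node was already reached via another path.
    if (!visited.add(node)) return;
    for (final input in node.inputs) {
      visit(input);
    }
    order.add(node);
  }

  visit(root);
  return order;
}

/// Groups node names by dependency level ("waves"). Nodes in the same wave
/// have no dependencies on each other, so they could run concurrently.
Map<int, List<String>> parallelWaves(GraphNode root) {
  final levels = <GraphNode, int>{};
  final waves = <int, List<String>>{};
  for (final node in topologicalOrder(root)) {
    var level = 0;
    for (final input in node.inputs) {
      // Inputs come earlier in topological order, so their level is known.
      final candidate = levels[input]! + 1;
      if (candidate > level) level = candidate;
    }
    levels[node] = level;
    waves.putIfAbsent(level, () => []).add(node.name);
  }
  return waves;
}
```

Consider a graph in which `matmul` and `square` are computed only from leaf tensors, `add` consumes both, `relu` consumes `add`, and `loss` consumes `relu` and `square`. For that graph, `parallelWaves(loss)` returns `{0: [matmul, square], 1: [add], 2: [relu], 3: [loss]}`. Here `matmul` and `square` form one wave that could run concurrently. `loss` has to wait for `relu`, even though `square` was ready much earlier.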

Implementation

void printGraph() {
  Logger.yellow('Computational Graph:', prefix: '📊');

  // 1. Perform a topological sort to get all nodes in dependency order.
  final topo = <Node>[];
  final visited = <Node>{};
  void buildTopo(Node? node) {
    if (node == null || visited.contains(node)) return;
    visited.add(node);
    for (Tensor inputTensor in node.inputs) {
      buildTopo(inputTensor.creator);
    }
    topo.add(node);
  }

  buildTopo(creator);

  // 2. Calculate each node's dependency level: 0 if none of its inputs has a
  //    creator, otherwise 1 + the highest level among its input creators.
  final levels = <Node, int>{};
  for (Node node in topo) {
    int maxInputLevel = -1;
    for (Tensor inputTensor in node.inputs) {
      if (inputTensor.creator != null) {
        int inputLevel = levels[inputTensor.creator!] ?? 0;
        if (inputLevel > maxInputLevel) {
          maxInputLevel = inputLevel;
        }
      }
    }
    levels[node] = maxInputLevel + 1;
  }

  // 3. Recursively print the graph, passing the level information;
  //    `_buildGraphString` also fills `costs` with the per-node cost estimates.
  final costs = <int, int>{};
  _buildGraphString(this, '', true, {}, costs, levels, isRoot: true);

  final totalCost = costs.values.fold<int>(0, (sum, cost) => sum + cost);
  Logger.yellow('Total Graph Cost: ~$totalCost FLOPs', prefix: 'Σ');
}
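The traversal in step 1 and the cost total in step 3 can also be exercised in isolation. The following is a hypothetical reduction of those two steps. `SketchTensor`, `SketchNode`, and the per-node `flops` field are assumptions made for illustration and are not the library's `Tensor`/`Node` classes. In the real method, the per-node costs come from `_buildGraphString`, which is not shown here. This sketch sums over the deduplicated node list, which is an assumption about how the real cost map is keyed. The `visited` set is what prevents a shared node, meaning one whose output feeds several consumers, from being emitted twice.

```dart
/// Hypothetical mirror of a graph node: an op name, its input tensors,
/// and an assumed per-node FLOP estimate.
class SketchNode {
  final String op;
  final List<SketchTensor> inputs;
  final int flops;
  SketchNode(this.op, this.inputs, this.flops);
}

/// Hypothetical mirror of a tensor: [creator] is null for leaf tensors.
class SketchTensor {
  final SketchNode? creator;
  SketchTensor([this.creator]);
}

/// Step 1: collect each node once, inputs before consumers.
List<SketchNode> topoFrom(SketchTensor output) {
  final topo = <SketchNode>[];
  final visited = <SketchNode>{};
  void buildTopo(SketchNode? node) {
    // Stop at leaf tensors (no creator) and at nodes already collected.
    if (node == null || !visited.add(node)) return;
    for (final input in node.inputs) {
      buildTopo(input.creator);
    }
    topo.add(node);
  }

  buildTopo(output.creator);
  return topo;
}

/// Step 3: total the per-node estimates over the deduplicated node list.
int totalFlops(List<SketchNode> topo) =>
    topo.fold<int>(0, (sum, node) => sum + node.flops);
```

As an example, take a `matmul` node with an illustrative cost of 120 whose output feeds both a `relu` node and an `add` node (20 each), where `add` also consumes the `relu` output. For this graph, `topoFrom` returns the nodes in the order `matmul, relu, add`, and `totalFlops` gives 160. `matmul` is counted once even though two nodes consume its output.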