backward method

void backward()

Implementation

/// Runs reverse-mode differentiation starting from this tensor.
///
/// First, every element of [grad] is set to `1.0`, seeding this tensor as
/// the output of the backward pass. If this tensor has no [creator] (it was
/// not produced by a recorded operation), nothing else happens.
///
/// Otherwise, the creator [Node]s reachable through
/// `node.inputs[k].creator` are collected into a topological order: each
/// node appears after the creators of all its inputs. Each node's
/// `backwardFn` is then invoked in reverse of that order. As a result, a
/// node runs only after every reachable node that consumes its output.
///
/// The traversal uses an explicit stack rather than recursion, so very deep
/// graphs (long chains of operations) cannot overflow the call stack. It
/// visits nodes in the same order a recursive depth-first search would.
///
/// NOTE(review): gradients of other tensors are not reset here. Whether
/// `backwardFn` accumulates into or overwrites input gradients is not
/// visible from this method; confirm before calling `backward()` twice.
void backward() {
  // Seed d(this)/d(this) = 1 for every element.
  for (var i = 0; i < grad.length; i++) {
    grad[i] = 1.0;
  }

  // Copy the field into a local so null-safety can promote it.
  final root = creator;
  if (root == null) return;

  final topo = <Node>[];
  final visited = <Node>{root};

  // Parallel stacks: the node being expanded, and the index of the next
  // input of that node still to be examined.
  final nodeStack = <Node>[root];
  final cursorStack = <int>[0];

  while (nodeStack.isNotEmpty) {
    final node = nodeStack.last;
    final cursor = cursorStack.last;

    if (cursor < node.inputs.length) {
      cursorStack[cursorStack.length - 1] = cursor + 1;
      final child = node.inputs[cursor].creator;
      // Set.add returns false when the node was already visited.
      if (child != null && visited.add(child)) {
        nodeStack.add(child);
        cursorStack.add(0);
      }
    } else {
      // All inputs handled: emit in post-order, as the recursive DFS did.
      nodeStack.removeLast();
      cursorStack.removeLast();
      topo.add(node);
    }
  }

  // Reverse topological order: consumers before producers.
  for (final node in topo.reversed) {
    node.backwardFn();
  }
}