inspectGraph method

void inspectGraph(
  List<double> inputData,
  List<double> targetData
)

Implementation

/// Builds a forward pass plus MSE loss from raw data and prints the graph.
///
/// [inputData] and [targetData] are each wrapped in a `Tensor<Vector>`.
/// The input tensor is passed through `forward` (declared elsewhere), and the
/// result is downcast to `Tensor<Vector>`. That prediction is compared with the
/// target tensor via `mse`, and `printGraph` is called on the resulting
/// `Tensor<Scalar>` loss node. Two divider lines are written with `Logger.log`,
/// one before and one after.
///
/// NOTE(review): whether `mse` requires both lists to have the same length is
/// not visible here; confirm against its definition.
void inspectGraph(List<double> inputData, List<double> targetData) {
  Logger.log('\n--- Inspecting Computational Graph ---');

  // Wrap the raw lists, input first and then target, as the original order did.
  final x = Tensor<Vector>(inputData);
  final y = Tensor<Vector>(targetData);

  // Throws at runtime if `forward` does not produce a `Tensor<Vector>`.
  final prediction = forward(x) as Tensor<Vector>;

  // The loss is computed so the graph printed from it also reaches the target.
  final Tensor<Scalar> lossNode = mse(prediction, y);
  lossNode.printGraph();

  Logger.log('--------------------------------------\n');
}