inspectGraph method
Implementation
/// Builds the computational graph for a single sample and asks the loss
/// tensor to print it.
///
/// [inputData] and [targetData] are each wrapped in a `Tensor<Vector>`.
/// The input is passed through [forward], and its result is cast to
/// `Tensor<Vector>`. The cast throws if `forward` returns something else.
/// The `mse` loss between that output and the target is then computed, so
/// the graph ends at the loss node. Finally, `printGraph()` is called on the
/// loss. Two `Logger.log` banner lines bracket the printout.
void inspectGraph(List<double> inputData, List<double> targetData) {
  Logger.log('\n--- Inspecting Computational Graph ---');

  // Wrap both raw lists up front, input first, then target.
  final Tensor<Vector> x = Tensor<Vector>(inputData);
  final Tensor<Vector> y = Tensor<Vector>(targetData);

  // Forward pass followed by the loss, which terminates the graph.
  final Tensor<Vector> prediction = forward(x) as Tensor<Vector>;
  final Tensor<Scalar> lossNode = mse(prediction, y);

  lossNode.printGraph();

  Logger.log('--------------------------------------\n');
}