The `train` method

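void train(
  List<List<double>> X,  // input rows; X[i] holds the features of sample i
  List<double> y,        // target value for each sample; y[i] pairs with X[i]
  int epochs,            // number of full passes over the dataset
  double lr,             // learning rate (gradient-descent step size)
)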

Implementation

/// Trains the network with full-batch gradient descent on a mean squared
/// error (MSE) loss.
///
/// Each epoch runs `forward` on every sample and sums the squared errors into
/// a single computation graph. It backpropagates once, then takes one
/// gradient-descent step on every entry of `weights` and `biases`.
///
/// - [X]: input rows; `X[i]` holds the features of sample `i`. Every value in
///   a row is passed to `forward`, so rows should have as many entries as the
///   network has inputs.
/// - [y]: target value for each sample; must be the same length as [X].
/// - [epochs]: number of full passes over the dataset.
/// - [lr]: learning rate (step size of each update).
///
/// Throws [ArgumentError] if [X] is empty or if [X] and [y] differ in length.
void train(List<List<double>> X, List<double> y, int epochs, double lr) {
  if (X.length != y.length) {
    throw ArgumentError(
        'X and y must have the same length (got ${X.length} and ${y.length})');
  }
  if (X.isEmpty) {
    throw ArgumentError('X must not be empty: cannot train on an empty dataset');
  }

  final int n = X.length;
  // The graph below builds the *sum* of squared errors. Dividing the step size
  // by n turns each update into a gradient step on the mean (MSE) without
  // relying on Node supporting division.
  final double step = lr / n;

  for (int epoch = 0; epoch < epochs; epoch++) {
    Node sse = Node(0.0);

    for (int i = 0; i < n; i++) {
      // Feed every feature in the row, not a fixed number of columns.
      final List<Node> inputs = X[i].map((v) => Node(v)).toList();
      final Node pred = forward(inputs);
      final Node target = Node(y[i]);

      // Squared error for this sample, accumulated into the summed loss.
      sse = sse + (pred - target).pow(2);
    }

    // Compute gradients of the summed loss with respect to the parameters.
    sse.backward();

    // Gradient-descent update. Each parameter's gradient is reset with
    // zeroGrad() before the next epoch's backward pass.
    for (var w in weights) {
      w.value -= step * w.grad;
      w.zeroGrad();
    }
    for (var b in biases) {
      b.value -= step * b.grad;
      b.zeroGrad();
    }

    if (epoch % 10 == 0 || epoch == epochs - 1) {
      // Mean loss, measured before this epoch's update was applied.
      print('Epoch $epoch - Loss: ${sse.value / n}');
    }
  }
}