fit method

void fit(
  List<List<double>> inputs,
  List<List<double>> targets, {
  int epochs = 100,
  bool averageWeight = false,
  bool debug = true,
})

Implementation

/// Trains the model on paired [inputs]/[targets] samples.
///
/// For each of [epochs] passes over the data, every sample is run through
/// `forward`, scored with `mse`, back-propagated with `loss.backward()`, and
/// applied immediately via `optimizer.step()` followed by
/// `optimizer.zeroGrad()`, i.e. the parameters are updated once per sample.
///
/// * [inputs] / [targets]: parallel lists; `targets[i]` is the expected
///   output for `inputs[i]`. They must have the same length.
/// * [epochs]: number of full passes over the data.
/// * [averageWeight]: when `true` (and [debug] is `true`), on roughly every
///   tenth epoch the mean absolute value of all Vector- and Matrix-valued
///   `parameters` is appended to that epoch's log line.
/// * [debug]: when `true`, prints a progress bar per epoch, the average loss
///   per epoch, and the total training time.
///
/// Throws an [ArgumentError] if [inputs] and [targets] differ in length.
/// If [inputs] is empty, nothing is trained.
void fit(List<List<double>> inputs, List<List<double>> targets,
    {int epochs = 100, bool averageWeight = false, bool debug = true}) {
  // Fail fast with a clear message instead of a RangeError mid-epoch (or
  // silently ignoring surplus targets).
  if (inputs.length != targets.length) {
    throw ArgumentError('inputs (${inputs.length}) and targets '
        '(${targets.length}) must have the same length');
  }
  // With no samples the average loss would be 0/0 = NaN on every epoch.
  if (inputs.isEmpty) {
    if (debug) {
      print('--- NO TRAINING DATA, SKIPPING ---');
    }
    return;
  }

  const int barWidth = 20;

  // Renders the bar body: `completed` '=' characters, a '>' head, then
  // padding, so the bar is always barWidth + 1 characters wide.
  String renderBar(int completed) =>
      '${'=' * completed}>${' ' * (barWidth - completed)}';

  // Appends the mean |w| over every Vector- or Matrix-valued parameter to the
  // current log line; writes nothing when there are no such parameters.
  void writeAverageWeightMagnitude() {
    double totalWeightSum = 0;
    int totalWeightCount = 0;
    for (Tensor param in parameters) {
      if (param.value is Vector) {
        for (double weight in (param.value as Vector)) {
          totalWeightSum += weight.abs();
          totalWeightCount++;
        }
      } else if (param.value is Matrix) {
        for (Vector row in (param.value as Matrix)) {
          for (double weight in row) {
            totalWeightSum += weight.abs();
            totalWeightCount++;
          }
        }
      }
    }
    if (totalWeightCount > 0) {
      double avg = totalWeightSum / totalWeightCount;
      stdout.write(', Avg Weight Mag: ${avg.toStringAsFixed(4)}');
    }
  }

  if (debug) {
    print('--- STARTING TRAINING ---');
  }
  Stopwatch stopwatch = Stopwatch()..start();

  final int sampleCount = inputs.length;
  // Loop-invariant: log weight magnitudes about ten times per run (and at
  // least every epoch when epochs < 10).
  final int logInterval = max(1, (epochs / 10).round());

  for (int epoch = 0; epoch < epochs; epoch++) {
    double epochLoss = 0.0;
    int lastPercent = -1;
    int lastCompleted = -1;

    for (int i = 0; i < sampleCount; i++) {
      Tensor<Vector> input = Tensor<Vector>(inputs[i]);
      Tensor<Vector> target = Tensor<Vector>(targets[i]);

      Tensor<Vector> finalOutput = forward(input) as Tensor<Vector>;
      Tensor<Scalar> loss = mse(finalOutput, target);

      epochLoss += loss.value;

      loss.backward();
      optimizer.step();
      optimizer.zeroGrad();

      if (debug) {
        final double progress = (i + 1) / sampleCount;
        final int completed = (progress * barWidth).round();
        final int percent = (progress * 100).round();

        // Redraw only when the visible bar or percentage changes; writing
        // to stdout on every sample is slow for large datasets. The '\r'
        // returns to the start of the line so the bar updates in place.
        if (percent != lastPercent || completed != lastCompleted) {
          stdout.write(
              '\rEpoch ${epoch + 1}/$epochs: [${renderBar(completed)}] $percent%');
          lastPercent = percent;
          lastCompleted = completed;
        }
      }
    }

    if (debug) {
      double avgLoss = epochLoss / sampleCount;

      // Overwrite the finished progress bar with the epoch's average loss.
      stdout.write('\rEpoch ${epoch + 1}/$epochs: [${renderBar(barWidth)}] '
          '100%, Avg Loss: ${avgLoss.toStringAsFixed(6)}');

      bool isLogInterval = (epoch + 1) % logInterval == 0;
      if (averageWeight && isLogInterval) {
        writeAverageWeightMagnitude();
      }
      // Terminate this epoch's log line.
      print('');
    }
  }

  stopwatch.stop();
  if (debug) {
    print('--- TRAINING FINISHED in ${stopwatch.elapsedMilliseconds}ms ---\n');
  }
}