fit method

void fit(
  List<List<double>> inputs,
  List<List<double>> targets, {
  int epochs = 100,
  bool averageWeight = false,
  bool debug = true,
})

Implementation

/// Trains the model with per-sample gradient descent.
///
/// For each of [epochs] passes over the data, every `inputs[i]` is run
/// through `forward`, scored against `targets[i]` with `mse`, and the loss is
/// backpropagated; `optimizer.step()` followed by `optimizer.zeroGrad()` is
/// called after every sample (i.e. an effective batch size of 1).
///
/// * [inputs] / [targets]: paired samples; `targets[i]` is the expected
///   output for `inputs[i]`. Both lists must have the same length.
/// * [epochs]: number of full passes over the data.
/// * [averageWeight]: when [debug] is also true, appends the mean absolute
///   value of all Vector and Matrix parameters to the summary line of every
///   `max(1, round(epochs / 10))`-th epoch.
/// * [debug]: prints a progress bar, the per-epoch average loss and the total
///   training time.
///
/// Throws an [ArgumentError] when `inputs.length != targets.length`. The check
/// runs before any parameter is updated, so a mismatched dataset cannot leave
/// the model partially trained. Returns immediately, without touching the
/// parameters, when [inputs] is empty.
void fit(List<List<double>> inputs, List<List<double>> targets,
    {int epochs = 100, bool averageWeight = false, bool debug = true}) {
  if (inputs.length != targets.length) {
    throw ArgumentError('inputs and targets must have the same length '
        '(got ${inputs.length} inputs and ${targets.length} targets)');
  }
  if (inputs.isEmpty) {
    // Nothing to learn from; running epochs here would only report a NaN
    // average loss (0.0 / 0).
    if (debug) {
      Logger.log('--- NO TRAINING DATA, SKIPPING ---');
    }
    return;
  }

  const int barWidth = 20;
  final int sampleCount = inputs.length;
  // Epoch cadence for the optional weight report: roughly ten reports per run,
  // but at least one per epoch for short runs. Loop-invariant, so hoisted.
  final int logInterval = max(1, (epochs / 10).round());

  // Builds the bar body: [completed] '=' cells, a '>' head, then padding so
  // the bar always spans barWidth + 1 characters.
  String renderBar(int completed) =>
      '${'=' * completed}>${' ' * (barWidth - completed)}';

  // Mean absolute value over every Vector / Matrix parameter, formatted as the
  // suffix appended to the epoch summary line; empty when no such weights exist.
  String weightSummary() {
    double sum = 0.0;
    int count = 0;
    for (final Tensor<dynamic> param in parameters) {
      if (param.value is Vector) {
        final Vector v = param.value as Vector;
        for (int w = 0; w < v.length; w++) {
          sum += v[w].abs();
          count++;
        }
      } else if (param.value is Matrix) {
        final Matrix m = param.value as Matrix;
        for (int r = 0; r < m.length; r++) {
          for (int c = 0; c < m[r].length; c++) {
            sum += m[r][c].abs();
            count++;
          }
        }
      }
    }
    return count > 0
        ? ', Avg Weight Mag: ${(sum / count).toStringAsFixed(4)}'
        : '';
  }

  if (debug) {
    Logger.log('--- STARTING TRAINING ---');
  }
  final Stopwatch stopwatch = Stopwatch()..start();

  for (int epoch = 0; epoch < epochs; epoch++) {
    final String epochLabel = 'Epoch ${epoch + 1}/$epochs';
    double epochLoss = 0.0;
    int lastCompleted = -1;
    int lastPercent = -1;

    for (int i = 0; i < sampleCount; i++) {
      final Tensor<Vector> input = Tensor<Vector>(inputs[i]);
      final Tensor<Vector> target = Tensor<Vector>(targets[i]);

      final Tensor<Vector> output = forward(input) as Tensor<Vector>;
      final Tensor<Scalar> loss = mse(output, target);
      epochLoss += loss.value;

      loss.backward();
      optimizer.step();
      // Clear gradients so the next sample's backward() starts from zero.
      optimizer.zeroGrad();

      if (debug) {
        final double progress = (i + 1) / sampleCount;
        final int completed = (progress * barWidth).round();
        final int percent = (progress * 100).round();
        // Redraw only when the visible line changes: writing on every sample
        // floods the terminal and dominates runtime on large datasets.
        if (completed != lastCompleted || percent != lastPercent) {
          lastCompleted = completed;
          lastPercent = percent;
          stdout.write('\r$epochLabel: [${renderBar(completed)}] $percent%');
        }
      }
    }

    if (debug) {
      final double avgLoss = epochLoss / sampleCount;
      // '\r' overwrites the in-progress bar with the completed epoch summary.
      stdout.write('\r$epochLabel: [${renderBar(barWidth)}] 100%, '
          'Avg Loss: ${avgLoss.toStringAsFixed(6)}');
      if (averageWeight && (epoch + 1) % logInterval == 0) {
        stdout.write(weightSummary());
      }
      // Terminates the summary line written via stdout.write above.
      Logger.log('');
    }
  }

  stopwatch.stop();
  if (debug) {
    Logger.log('--- TRAINING FINISHED in ${stopwatch.elapsedMilliseconds}ms ---\n');
  }
}