fit method

void fit(
  List<List<double>> X,
  List<List<double>> Y, {
  int? batchSize,
  bool verbose = false,
  String optimizer = 'sgd',
  double momentum = 0.9,
  double beta1 = 0.9,
  double beta2 = 0.999,
  double epsilon = 1e-8,
  double l2 = 0.0,
  String lrSchedule = 'constant',
  int stepSize = 10,
  double stepDecay = 0.5,
  double expDecay = 0.99,
  int epochsOverride = -1,
})

Trains the model on inputs X and matching targets Y using mini-batch gradient descent, with a choice of plain SGD, momentum, or Adam updates, optional L2 weight decay, and 'constant', 'step', or 'exp' learning-rate schedules. (This class intentionally exposes a single training API.)
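
A minimal usage sketch is shown below. The class name MLP and its constructor are assumptions for illustration; only the fit parameters documented above are part of this API.

// Hypothetical model: the MLP class name and constructor are assumptions.
final net = MLP(layerSizes: [2, 8, 1], lr: 0.05, epochs: 200);

// Toy XOR dataset: four samples, two inputs, one target each.
final X = [
  [0.0, 0.0],
  [0.0, 1.0],
  [1.0, 0.0],
  [1.0, 1.0],
];
final Y = [
  [0.0],
  [1.0],
  [1.0],
  [0.0],
];

// Mini-batch Adam with exponential learning-rate decay and light L2 weight decay.
net.fit(
  X,
  Y,
  batchSize: 2,
  optimizer: 'adam',
  lrSchedule: 'exp',
  expDecay: 0.98,
  l2: 1e-4,
  verbose: true,
);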

Implementation

void fit(
  List<List<double>> X,
  List<List<double>> Y, {
  int? batchSize,
  bool verbose = false,
  String optimizer = 'sgd',
  double momentum = 0.9,
  double beta1 = 0.9,
  double beta2 = 0.999,
  double epsilon = 1e-8,
  // l2 regularization (weight decay)
  double l2 = 0.0,
  // learning rate schedule: 'constant', 'step', 'exp'
  String lrSchedule = 'constant',
  int stepSize = 10,
  double stepDecay = 0.5,
  double expDecay = 0.99,
  int epochsOverride = -1,
}) {
  if (X.isEmpty) throw ArgumentError('Empty dataset');
  if (X.length != Y.length) {
    throw ArgumentError('X and Y must have same number of rows');
  }
  _initParams();
  final n = X.length;
  final useBatch = (batchSize == null || batchSize <= 0 || batchSize >= n) ? n : batchSize;
  final int runEpochs = epochsOverride > 0 ? epochsOverride : epochs;

  // optimizer state (for momentum/adam)
  List<List<List<double>>> vW = [];
  List<List<double>> vB = [];
  List<List<List<double>>> mW = [];
  List<List<List<double>>> vAdamW = [];
  List<List<double>> mB = [];
  List<List<double>> vAdamB = [];
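  // Adam timestep counter, used below for bias correction of the moment estimates.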
  var t = 0;

  // initialize optimizer accumulators
  if (optimizer == 'momentum') {
    for (var l = 0; l < weights.length; l++) {
      vW.add(
        List.generate(
          weights[l].length,
          (_) => List<double>.filled(weights[l][0].length, 0.0),
        ),
      );
      vB.add(List<double>.filled(biases[l].length, 0.0));
    }
  } else if (optimizer == 'adam') {
    for (var l = 0; l < weights.length; l++) {
      mW.add(
        List.generate(
          weights[l].length,
          (_) => List<double>.filled(weights[l][0].length, 0.0),
        ),
      );
      vAdamW.add(
        List.generate(
          weights[l].length,
          (_) => List<double>.filled(weights[l][0].length, 0.0),
        ),
      );
      mB.add(List<double>.filled(biases[l].length, 0.0));
      vAdamB.add(List<double>.filled(biases[l].length, 0.0));
    }
  }

  for (var epoch = 0; epoch < runEpochs; epoch++) {
    // compute current learning rate based on schedule
    double currentLr = lr;
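    // 'step' multiplies lr by stepDecay every stepSize epochs;
    // 'exp' decays it by a factor of expDecay every epoch.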
    if (lrSchedule == 'step') {
      currentLr = lr * pow(stepDecay, epoch ~/ stepSize);
    } else if (lrSchedule == 'exp') {
      currentLr = lr * pow(expDecay, epoch);
    }

    // create shuffled indices for mini-batching
    final indices = List<int>.generate(n, (i) => i);
    if (useBatch < n) {
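      // Fisher-Yates shuffle so each epoch visits samples in a different order.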
      for (var i = indices.length - 1; i > 0; i--) {
        final j = _rand.nextInt(i + 1);
        final tmp = indices[i];
        indices[i] = indices[j];
        indices[j] = tmp;
      }
    }

    for (var batchStart = 0; batchStart < n; batchStart += useBatch) {
      final batchEnd = min(batchStart + useBatch, n);
      final bsize = batchEnd - batchStart;
      if (bsize <= 0) continue;

      // prepare batch activations and preacts
      final activations = <List<List<double>>>[];
      final preacts = <List<List<double>>>[];

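      // Forward pass for every sample in the batch, caching each layer's
      // pre-activation (z) and activation so the backward pass can reuse them.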
      for (var bi = batchStart; bi < batchEnd; bi++) {
        final idx = indices[bi];
        var a = X[idx];
        final acts = <List<double>>[a];
        final pres = <List<double>>[];
        for (var l = 0; l < weights.length; l++) {
          final w = weights[l];
          final b = biases[l];
          final z = List<double>.filled(w.length, 0.0);
          for (var ii = 0; ii < w.length; ii++) {
            var s = b[ii];
            for (var jj = 0; jj < w[ii].length; jj++) {
              s += w[ii][jj] * a[jj];
            }
            z[ii] = s;
          }
          pres.add(z);
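          // Hidden layers use ReLU; the output layer uses sigmoid.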
          if (l == weights.length - 1) {
            a = z.map(_sigmoid).toList();
          } else {
            a = z.map(_relu).toList();
          }
          acts.add(a);
        }
        activations.add(acts);
        preacts.add(pres);
      }

      // compute gradients for batch
      final gradW = List.generate(
        weights.length,
        (l) => List.generate(
          weights[l].length,
          (_) => List<double>.filled(weights[l][0].length, 0.0),
        ),
      );
      final gradB = List.generate(
        biases.length,
        (l) => List<double>.filled(biases[l].length, 0.0),
      );

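      // Backward pass: accumulate each sample's gradients into gradW and gradB.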
      for (var bi = 0; bi < activations.length; bi++) {
        final idx = indices[batchStart + bi];
        final y = Y[idx];
        final acts = activations[bi];
        final pres = preacts[bi];
        final out = acts.last;
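        // Output-layer error term: (prediction - target) scaled by the sigmoid derivative.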
        final delta = List<double>.filled(out.length, 0.0);
        for (var k = 0; k < out.length; k++) {
          final err = out[k] - y[k];
          delta[k] = err * _sigmoidDeriv(out[k]);
        }
        var curDelta = delta;
        for (var l = weights.length - 1; l >= 0; l--) {
          final aPrev = acts[l];
          for (var iOut = 0; iOut < weights[l].length; iOut++) {
            gradB[l][iOut] += curDelta[iOut];
            for (var iIn = 0; iIn < weights[l][iOut].length; iIn++) {
              gradW[l][iOut][iIn] += curDelta[iOut] * aPrev[iIn];
            }
          }
          if (l > 0) {
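            // Propagate the error to the previous layer through the weights,
            // then apply the ReLU derivative on that layer's pre-activation.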
            final nextDelta = List<double>.filled(weights[l - 1].length, 0.0);
            for (var iPrev = 0; iPrev < weights[l - 1].length; iPrev++) {
              var s = 0.0;
              for (var iOut = 0; iOut < weights[l].length; iOut++) {
                s += weights[l][iOut][iPrev] * curDelta[iOut];
              }
              nextDelta[iPrev] = s * _reluDeriv(pres[l - 1][iPrev]);
            }
            curDelta = nextDelta;
          }
        }
      }

      // average gradients over batch
      for (var l = 0; l < weights.length; l++) {
        for (var iOut = 0; iOut < weights[l].length; iOut++) {
          for (var iIn = 0; iIn < weights[l][iOut].length; iIn++) {
            gradW[l][iOut][iIn] /= bsize;
          }
          gradB[l][iOut] /= bsize;
        }
      }

      // optimizer updates (use currentLr)
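      // Note: L2 is applied as a multiplicative shrink after each gradient
      // step (decoupled weight decay) rather than being added to the gradient.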
      if (optimizer == 'sgd') {
        for (var l = 0; l < weights.length; l++) {
          for (var iOut = 0; iOut < weights[l].length; iOut++) {
            for (var iIn = 0; iIn < weights[l][iOut].length; iIn++) {
              weights[l][iOut][iIn] -= currentLr * gradW[l][iOut][iIn];
              // weight decay
              if (l2 > 0) weights[l][iOut][iIn] *= (1 - currentLr * l2);
            }
            biases[l][iOut] -= currentLr * gradB[l][iOut];
          }
        }
      } else if (optimizer == 'momentum') {
        for (var l = 0; l < weights.length; l++) {
          for (var iOut = 0; iOut < weights[l].length; iOut++) {
            for (var iIn = 0; iIn < weights[l][iOut].length; iIn++) {
              vW[l][iOut][iIn] =
                  momentum * vW[l][iOut][iIn] +
                  currentLr * gradW[l][iOut][iIn];
              weights[l][iOut][iIn] -= vW[l][iOut][iIn];
              if (l2 > 0) weights[l][iOut][iIn] *= (1 - currentLr * l2);
            }
            vB[l][iOut] = momentum * vB[l][iOut] + currentLr * gradB[l][iOut];
            biases[l][iOut] -= vB[l][iOut];
          }
        }
      } else if (optimizer == 'adam') {
        t += 1;
        for (var l = 0; l < weights.length; l++) {
          for (var iOut = 0; iOut < weights[l].length; iOut++) {
            for (var iIn = 0; iIn < weights[l][iOut].length; iIn++) {
              final g = gradW[l][iOut][iIn];
              mW[l][iOut][iIn] = beta1 * mW[l][iOut][iIn] + (1 - beta1) * g;
              vAdamW[l][iOut][iIn] =
                  beta2 * vAdamW[l][iOut][iIn] + (1 - beta2) * g * g;
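              // Bias-corrected first and second moment estimates.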
              final mHat = mW[l][iOut][iIn] / (1 - pow(beta1, t));
              final vHat = vAdamW[l][iOut][iIn] / (1 - pow(beta2, t));
              weights[l][iOut][iIn] -=
                  currentLr * mHat / (sqrt(vHat) + epsilon);
              if (l2 > 0) weights[l][iOut][iIn] *= (1 - currentLr * l2);
            }
            final gb = gradB[l][iOut];
            mB[l][iOut] = beta1 * mB[l][iOut] + (1 - beta1) * gb;
            vAdamB[l][iOut] = beta2 * vAdamB[l][iOut] + (1 - beta2) * gb * gb;
            final mHatB = mB[l][iOut] / (1 - pow(beta1, t));
            final vHatB = vAdamB[l][iOut] / (1 - pow(beta2, t));
            biases[l][iOut] -= currentLr * mHatB / (sqrt(vHatB) + epsilon);
          }
        }
      } else {
        throw ArgumentError('Unknown optimizer: $optimizer');
      }
    }

    // compute epoch loss and store (mean over samples) using helper
    lastLoss = _mseLoss(predict(X), Y);
    lossHistory.add(lastLoss!);
    if (verbose) print('epoch=$epoch loss=$lastLoss');
  }
}