logisticRegressionFit function

List<double> logisticRegressionFit(
  List<List<double>> X,
  List<int> y, {
  double lr = 0.1,
  int epochs = 500,
})

Implementation

/// Fits a binary logistic-regression model by full-batch gradient descent.
///
/// Each element of [X] is one sample's feature vector; every row must have
/// the same number of features `m` as `X[0]`. [y] holds the matching labels,
/// expected to be `0` or `1`. Starting from all-zero weights, the model takes
/// [epochs] gradient steps of size [lr], each step using the gradient
/// averaged over all `n` samples.
///
/// Returns a list of length `m + 1`: element `0` is the intercept and
/// element `j + 1` is the weight for feature `j`. If [X] is empty, returns
/// `[0.0]` (an intercept-only model) without inspecting [y].
///
/// Throws an [ArgumentError] if [y] and [X] differ in length, or if any row
/// of [X] has a different number of features than `X[0]`.
List<double> logisticRegressionFit(
  List<List<double>> X,
  List<int> y, {
  double lr = 0.1,
  int epochs = 500,
}) {
  final n = X.length;
  if (n == 0) return [0.0];
  final m = X[0].length;
  if (y.length != n) throw ArgumentError('X and y length mismatch');
  // Validate row widths up front. Otherwise a longer row would have its extra
  // features silently ignored, and a shorter one would fail with an opaque
  // RangeError partway through training.
  for (var i = 1; i < n; i++) {
    if (X[i].length != m) {
      throw ArgumentError('X[$i] has ${X[i].length} features; expected $m');
    }
  }

  final w = List<double>.filled(m + 1, 0.0); // [intercept, w_0, ..., w_{m-1}]
  // Gradient accumulator, allocated once and cleared at the start of each epoch.
  final grad = List<double>.filled(m + 1, 0.0);
  for (var epoch = 0; epoch < epochs; epoch++) {
    grad.fillRange(0, m + 1, 0.0);
    for (var i = 0; i < n; i++) {
      final row = X[i];
      var z = w[0];
      for (var j = 0; j < m; j++) {
        z += w[j + 1] * row[j];
      }
      // Residual (prediction - label). Assuming `_sigmoid` is the standard
      // logistic function, this is the log-loss gradient w.r.t. z.
      final diff = _sigmoid(z) - y[i];
      grad[0] += diff;
      for (var j = 0; j < m; j++) {
        grad[j + 1] += diff * row[j];
      }
    }
    // Step along the gradient averaged over the n samples.
    for (var k = 0; k <= m; k++) {
      w[k] -= lr * grad[k] / n;
    }
  }
  return w;
}