fit method

void fit(
  List<List<double>> X
)

Implementation

/// Clusters the rows of [X] with k-means++ seeding followed by Lloyd-style
/// assign/update iterations.
///
/// Each element of [X] is one sample; every sample must have the same number
/// of values as `X[0]`. The number of clusters `k`, the iteration cap
/// `maxIter` and the movement tolerance `tol` are read from the enclosing
/// scope (they are declared outside this method).
///
/// Results are written to `labels` (for each sample, the index of the nearest
/// centre found in the last assignment pass) and to `centroids` (the means
/// computed from those labels; a cluster that received no samples keeps its
/// previous centre).
///
/// Seeding uses `Random(0)`, so repeated calls on the same data pick the same
/// initial centres.
///
/// Iteration stops as soon as an assignment pass changes no label, or the
/// largest Euclidean centre movement is at most `tol`, or `maxIter` passes
/// have run.
///
/// Throws an [ArgumentError] if [X] is empty or its rows differ in length, and
/// a [RangeError] if `k` is less than 1.
void fit(List<List<double>> X) {
  final n = X.length;
  if (n == 0) throw ArgumentError('Empty dataset');
  // Fail up front instead of with an index error in the update step after
  // `labels` has already been overwritten.
  if (k < 1) throw RangeError.range(k, 1, null, 'k');
  final m = X[0].length;
  for (var i = 1; i < n; i++) {
    if (X[i].length != m) {
      throw ArgumentError.value(
          X[i].length, 'X[$i].length', 'Every row must have $m values');
    }
  }

  // --- k-means++ seeding -------------------------------------------------
  final rnd = Random(0);
  final centers = <List<double>>[List<double>.of(X[rnd.nextInt(n)])];
  // dist[i] is the squared distance from X[i] to its nearest chosen centre.
  // It is maintained incrementally: only the most recently added centre can
  // lower it, so earlier centres never need to be revisited.
  final dist = List<double>.filled(n, double.infinity);
  while (centers.length < k) {
    final newest = centers.last;
    var total = 0.0;
    for (var i = 0; i < n; i++) {
      final d = _distSq(X[i], newest);
      if (d < dist[i]) dist[i] = d;
      total += dist[i];
    }
    // Draw the next centre with probability proportional to dist[i] by
    // walking the running sum until the random threshold is used up. If
    // rounding leaves a remainder, the last sample is taken.
    var r = rnd.nextDouble() * total;
    var pick = 0;
    for (var i = 0; i < n; i++) {
      r -= dist[i];
      pick = i;
      if (r <= 0) break;
    }
    centers.add(List<double>.of(X[pick]));
  }

  // --- Lloyd iterations --------------------------------------------------
  // Keep a non-nullable local alias so the loop needs no `!` assertions.
  final assignments = List<int>.filled(n, 0);
  labels = assignments;
  for (var iter = 0; iter < maxIter; iter++) {
    // Assignment step: attach every sample to its nearest centre. Ties keep
    // the lowest centre index because the comparison is strict.
    var changed = false;
    for (var i = 0; i < n; i++) {
      var best = 0;
      var bestD = double.infinity;
      for (var j = 0; j < centers.length; j++) {
        final d = _distSq(X[i], centers[j]);
        if (d < bestD) {
          bestD = d;
          best = j;
        }
      }
      if (assignments[i] != best) {
        assignments[i] = best;
        changed = true;
      }
    }

    // Update step: accumulate per-cluster coordinate sums and member counts.
    final sums = List.generate(k, (_) => List<double>.filled(m, 0.0));
    final counts = List<int>.filled(k, 0);
    for (var i = 0; i < n; i++) {
      final c = assignments[i];
      final row = X[i];
      final acc = sums[c];
      counts[c]++;
      for (var j = 0; j < m; j++) {
        acc[j] += row[j];
      }
    }

    var maxMove = 0.0;
    for (var c = 0; c < k; c++) {
      final count = counts[c];
      if (count == 0) continue; // leave empty cluster centroid as-is
      // `sums` is rebuilt every iteration, so its row can be turned into the
      // mean in place and adopted as the new centre.
      final mean = sums[c];
      for (var j = 0; j < m; j++) {
        mean[j] /= count;
      }
      maxMove = max(maxMove, sqrt(_distSq(centers[c], mean)));
      centers[c] = mean;
    }
    if (!changed || maxMove <= tol) break;
  }
  centroids = centers;
}