predict method

  1. @override
int predict(
  1. List<double> X
)
override

Implementation of sklearn.naive_bayes.GaussianNB.predict.

Implementation

/// Predicts the class label for a single sample [X].
///
/// For every class `i` a Gaussian naive Bayes joint log-likelihood is
/// computed from the per-class, per-feature parameters:
///
///     log(classPrior[i])
///       - 0.5 * sum_j log(2 * pi * sigma[i][j])
///       - 0.5 * sum_j (X[j] - theta[i][j])^2 / sigma[i][j]
///
/// where `j` ranges over the features of class `i` (the columns of
/// `sigma[i]`). The returned value is `classes[argmax(scores)]`.
///
/// [X] must contain at least as many entries as `sigma[i]` has columns;
/// a shorter list causes an out-of-range access.
///
/// NOTE(review): `argmax` is defined outside this block; it is presumed to
/// return the index of the largest score — confirm against its definition.
@override
int predict(List<double> X) {
  final int classCount = sigma.length;
  final List<double> jointLogLikelihood =
      List<double>.filled(classCount, 0.0);

  for (int i = 0; i < classCount; i++) {
    // Iterate over FEATURES (columns of sigma[i]), not over classes: both
    // the normalisation term and the quadratic term are per-feature sums.
    final int featureCount = sigma[i].length;
    double logNormSum = 0.0;
    double quadraticSum = 0.0;
    for (int j = 0; j < featureCount; j++) {
      final double variance = sigma[i][j];
      final double diff = X[j] - theta[i][j];
      logNormSum += log(2.0 * pi * variance);
      quadraticSum += diff * diff / variance;
    }

    jointLogLikelihood[i] =
        log(classPrior[i]) - 0.5 * logNormSum - 0.5 * quadraticSum;
  }

  return classes[argmax(jointLogLikelihood)];
}