predict method

int predict(List<double> X)

This method overrides an inherited member (`@override`).

Implementation of `sklearn.neural_network.MLPClassifier.predict`: returns the predicted class label (an element of `classes`) for a single sample `X`.

Implementation

/// Predicts the class label for a single sample [X].
///
/// Performs a forward pass through the fully-connected network described by
/// `layers`, `coefs` and `intercepts`:
///
/// * each entry of `layers` is the width of one layer after the input; the
///   last entry is the output layer;
/// * unit `j` of layer `i + 1` is computed as
///   `intercepts[i][j] + sum_l input[l] * coefs[i][l][j]`;
/// * every layer except the last is passed through `_activation` (its
///   default activation); the last layer is passed through
///   `_activation(..., outActivation)`.
///
/// If the output layer has a single unit, its value is compared against 0.5
/// to choose between `classes[0]` and `classes[1]` (presumably a logistic
/// output for binary problems — confirm against the model export). Otherwise
/// the index of the largest output is mapped through `classes`.
@override
int predict(List<double> X) {
  // One buffer per layer: the input followed by a zero-filled vector for each
  // entry in `layers`. Sizing this from `layers` (rather than hard-coding
  // three entries) is required for models with more than one hidden layer;
  // writing past the end of a Dart list by index throws a RangeError.
  final List<List<double>> network = <List<double>>[X];
  for (final int size in layers) {
    network.add(List<double>.filled(size, 0.0));
  }

  for (int i = 0; i < network.length - 1; i++) {
    final List<double> input = network[i];
    final List<double> output = network[i + 1];
    for (int j = 0; j < output.length; j++) {
      // Same accumulation order as before: bias first, then each weighted
      // input in index order, so floating-point results are unchanged.
      double sum = intercepts[i][j];
      for (int l = 0; l < input.length; l++) {
        sum += input[l] * coefs[i][l][j];
      }
      output[j] = sum;
    }
    // Hidden layers use the default activation; the output layer is handled
    // separately below with `outActivation`.
    if (i + 1 < network.length - 1) {
      network[i + 1] = _activation(output);
    }
  }

  final List<double> outputs = _activation(network.last, outActivation);

  if (outputs.length == 1) {
    // Single output unit: binary decision at threshold 0.5.
    return outputs[0] > 0.5 ? classes[1] : classes[0];
  }

  return classes[mathutils.argmax(outputs)];
}