predict method

List<List<double>> predict(List<List<double>> X)

Implementation

/// Runs a forward pass on every row of [X] and returns one output vector per
/// row, in the same order.
///
/// If `_inited` is false, `_initParams()` is called first. For each layer `l`,
/// every output unit `i` computes
/// `biases[l][i] + sum_j weights[l][i][j] * a[j]`. `_relu` is then applied
/// element-wise for every layer except the last, and `_sigmoid` for the last.
///
/// Each returned row is a fresh list, so mutating it never affects [X].
/// An empty [X] yields an empty list.
///
/// Throws an [ArgumentError] if the width of the vector fed into a layer
/// differs from the number of columns of that layer's weight row. This
/// prevents too-wide rows from having their extra features silently ignored.
List<List<double>> predict(List<List<double>> X) {
  if (!_inited) _initParams();
  final lastLayer = weights.length - 1;
  final outputs = <List<double>>[];
  for (var xi = 0; xi < X.length; xi++) {
    // Copy so a returned row never aliases the caller's input (relevant when
    // there are no layers and `a` would otherwise be X[xi] itself).
    var a = List<double>.of(X[xi]);
    for (var l = 0; l < weights.length; l++) {
      final w = weights[l];
      final b = biases[l];
      final z = List<double>.filled(w.length, 0.0);
      for (var i = 0; i < w.length; i++) {
        final row = w[i];
        // Dimension guard: without it a too-wide input is silently truncated
        // and a too-narrow one fails with an opaque RangeError below.
        if (row.length != a.length) {
          throw ArgumentError(
              'predict: layer $l expects input width ${row.length} '
              'but got ${a.length} (sample $xi)');
        }
        var s = b[i];
        for (var j = 0; j < row.length; j++) {
          s += row[j] * a[j];
        }
        z[i] = s;
      }
      if (l == lastLayer) {
        a = z.map(_sigmoid).toList();
      } else {
        a = z.map(_relu).toList();
      }
    }
    outputs.add(a);
  }
  return outputs;
}