/// Produces predictions for the batch [X].
///
/// [X] is first converted by `_batchToFeatures`. The converted value is then
/// handed to `head.predict`, and whatever that returns is returned as-is.
///
/// NOTE(review): the four levels of nesting in [X] presumably mean
/// (batch, height, width, channels) or similar; confirm this against
/// `_batchToFeatures`.
List<List<double>> predict(List<List<List<List<double>>>> X) =>
    head.predict(_batchToFeatures(X));