/// Runs batch prediction over a list of token-id sequences.
///
/// Each sequence in [X] is passed, in order, to `_embedAndPool`.
/// NOTE(review): presumably this turns a token sequence into one pooled
/// feature vector — confirm against its definition.
/// The collected results are then handed to `head.predict` as a single
/// batch. Whatever `head.predict` returns is returned unchanged.
List<List<double>> predict(List<List<int>> X) {
  // Collection-for builds the list eagerly, one element per input sequence.
  final pooledBatch = [
    for (final tokenIds in X) _embedAndPool(tokenIds),
  ];
  return head.predict(pooledBatch);
}