sampleNucleus function: temperature-scaled nucleus (top-p) sampling
Implementation:
/// Samples an index from [row] using temperature-scaled nucleus (top-p)
/// sampling.
///
/// [row] holds unnormalized scores (logits), one per candidate index.
///
/// [temp] divides the logits before the softmax. Values below 1 sharpen the
/// distribution and values above 1 flatten it. A value of `0` deterministically
/// selects the argmax (greedy decoding).
///
/// [topP] keeps the smallest prefix of candidates, ordered by descending
/// probability (ties broken by lower index), whose cumulative probability
/// reaches [topP]. The sample is then drawn from that prefix in proportion to
/// the candidates' probabilities.
///
/// [random] is the source of randomness. Pass a seeded [math.Random] for
/// reproducible results; a fresh unseeded generator is used when omitted.
///
/// Returns the index into [row] of the sampled candidate.
///
/// Throws [ArgumentError] if [row] is empty, contains NaN, or has every entry
/// equal to negative infinity, or if [temp] is negative or NaN.
int sampleNucleus(
  List<double> row, {
  double temp = 1.0,
  double topP = 0.9,
  math.Random? random,
}) {
  if (row.isEmpty) {
    throw ArgumentError.value(row, 'row', 'must not be empty');
  }
  if (temp.isNaN || temp < 0) {
    throw ArgumentError.value(temp, 'temp', 'must be >= 0');
  }

  // math.max propagates NaN, so a NaN anywhere in the row surfaces here.
  final double maxL = row.reduce(math.max);
  if (maxL.isNaN) {
    throw ArgumentError.value(row, 'row', 'must not contain NaN');
  }
  if (maxL == double.negativeInfinity) {
    throw ArgumentError.value(
        row, 'row', 'must contain at least one entry above -infinity');
  }

  // Greedy cases: temperature 0, or a +infinity logit, puts all probability
  // mass on the argmax. The softmax below would otherwise compute
  // 0/0 or inf - inf and turn every probability into NaN.
  if (temp == 0 || maxL == double.infinity) {
    return row.indexOf(maxL);
  }

  // Temperature-scaled softmax. Subtracting the max keeps exp() in (0, 1],
  // so the sum cannot overflow and is at least 1.
  final List<double> weights = [
    for (final v in row) math.exp((v - maxL) / temp),
  ];
  final double sumExp = weights.fold(0.0, (a, b) => a + b);

  // Rank candidates by descending probability. List.sort is not guaranteed
  // to be stable, so ties are broken by index to keep results deterministic.
  final List<MapEntry<int, double>> ranked = [
    for (var i = 0; i < weights.length; i++) MapEntry(i, weights[i] / sumExp),
  ]..sort((a, b) {
      final byProb = b.value.compareTo(a.value);
      return byProb != 0 ? byProb : a.key.compareTo(b.key);
    });

  // Nucleus: the shortest ranked prefix whose mass reaches topP. The token
  // that crosses the threshold is included. If rounding keeps the total just
  // below topP, every candidate is kept.
  double cumulativeProb = 0.0;
  int cutoffIndex = 0;
  for (var i = 0; i < ranked.length; i++) {
    cumulativeProb += ranked[i].value;
    cutoffIndex = i;
    if (cumulativeProb >= topP) break;
  }
  final List<MapEntry<int, double>> candidates =
      ranked.sublist(0, cutoffIndex + 1);
  final double candidateSum =
      candidates.fold(0.0, (sum, item) => sum + item.value);

  // Weighted draw over the renormalized nucleus.
  final double r = (random ?? math.Random()).nextDouble() * candidateSum;
  double current = 0.0;
  for (final entry in candidates) {
    current += entry.value;
    if (r <= current) return entry.key;
  }
  // Reachable only through floating-point rounding at the top of the range;
  // that interval belongs to the last candidate, not the first.
  return candidates.last.key;
}