`sampleNucleus` function — temperature-scaled nucleus (top-p) sampling over a row of logits

int sampleNucleus(
  List<double> row, {
  double temp = 1.0,
  double topP = 0.9,
})

Implementation

/// Samples an index from [row] using temperature scaling followed by
/// nucleus (top-p) filtering.
///
/// The logits in [row] are turned into probabilities with a softmax at
/// temperature [temp] (the maximum logit is subtracted first so `exp` cannot
/// overflow). Indices are ranked by probability. The shortest prefix whose
/// cumulative probability reaches [topP] is kept, including the entry that
/// crosses the threshold. One index from that prefix is then drawn in
/// proportion to its probability.
///
/// * [row] must be non-empty; otherwise an [ArgumentError] is thrown.
/// * A [temp] of zero or less selects the first index holding the highest
///   logit (greedy decoding). Dividing by zero would produce NaN
///   probabilities, and a negative temperature would invert the ranking.
/// * A [topP] of zero or less keeps only the most probable index. If the
///   cumulative mass never reaches [topP] (e.g. `topP >= 1`), every index is
///   kept.
/// * [random] may be supplied to make sampling reproducible. A fresh
///   [math.Random] is created when it is omitted.
///
/// Returns the index into [row] of the sampled entry.
int sampleNucleus(
  List<double> row, {
  double temp = 1.0,
  double topP = 0.9,
  math.Random? random,
}) {
  if (row.isEmpty) {
    throw ArgumentError.value(row, 'row', 'must not be empty');
  }

  // Greedy decoding: (v - maxL) / 0 is NaN for the max entry, which would
  // poison the whole softmax, so handle non-positive temperatures directly.
  if (temp <= 0) {
    var best = 0;
    for (var i = 1; i < row.length; i++) {
      if (row[i] > row[best]) best = i;
    }
    return best;
  }

  // Temperature-scaled, max-shifted softmax.
  final maxL = row.reduce(math.max);
  final probs = row.map((v) => math.exp((v - maxL) / temp)).toList();
  final sumExp = probs.reduce((a, b) => a + b);
  for (var i = 0; i < probs.length; i++) {
    probs[i] /= sumExp;
  }

  // Rank (index, probability) pairs, most probable first.
  final indexedProbs = probs.asMap().entries.toList()
    ..sort((a, b) => b.value.compareTo(a.value));

  // Keep the shortest prefix whose mass reaches topP (inclusive of the
  // crossing entry). If rounding keeps the total below topP, keep everything.
  var cumulativeProb = 0.0;
  var cutoffIndex = indexedProbs.length - 1;
  for (var i = 0; i < indexedProbs.length; i++) {
    cumulativeProb += indexedProbs[i].value;
    if (cumulativeProb >= topP) {
      cutoffIndex = i;
      break;
    }
  }
  final candidates = indexedProbs.sublist(0, cutoffIndex + 1);

  // Drawing r uniformly from [0, candidateSum) is equivalent to
  // re-normalising the candidates and drawing from [0, 1).
  final candidateSum =
      candidates.fold<double>(0.0, (sum, entry) => sum + entry.value);
  final r = (random ?? math.Random()).nextDouble() * candidateSum;
  var current = 0.0;
  for (final entry in candidates) {
    current += entry.value;
    if (r <= current) return entry.key;
  }
  // Defensive fallback, e.g. when non-finite logits made the probabilities
  // NaN: return the most probable candidate.
  return candidates.first.key;
}