generate method

List<int> generate({
  required Tensor encoderOutput,
  required List<int> promptTokens,
  int maxLength = 1536,
  int? eosTokenId,
  int? padTokenId,
  bool greedy = true,
})

Generates a token sequence auto-regressively.

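- `encoderOutput`: encoder hidden states of shape (1, srcLen, encoderDim).
- `promptTokens`: initial token IDs used to seed generation; must be non-empty.
- `maxLength`: maximum length of the returned sequence, counting the prompt tokens (so at most `maxLength - promptTokens.length` new tokens are generated).
- `eosTokenId`: if set, generation stops once this token is produced.
- `padTokenId`: if set, generation also stops once this token is produced.
- `greedy`: if true, the highest-scoring token is chosen at each step; otherwise the next token is sampled from the softmax distribution.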

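Returns a new list containing the prompt tokens followed by the generated token IDs. If generation stopped on an EOS or PAD token, that token is included as the last element.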

Implementation

/// Generates a token sequence auto-regressively.
///
/// All prompt tokens except the last are first run through [forward] in a
/// single "prefill" call to build the KV cache. Generation then feeds one
/// token per step, appending each predicted token until a stop condition
/// is hit.
///
/// * [encoderOutput]: encoder hidden states, expected shape
///   `(1, srcLen, encoderDim)`; passed to every [forward] call as
///   `encoderHiddenStates`.
/// * [promptTokens]: initial token IDs that seed generation. Must not be
///   empty.
/// * [maxLength]: upper bound on the length of the returned sequence,
///   *including* the prompt tokens. If `promptTokens.length >= maxLength`
///   a copy of the prompt is returned without running the decoder.
/// * [eosTokenId]: if non-null, generation stops after this token is
///   emitted (the token is kept in the output).
/// * [padTokenId]: if non-null, generation also stops after this token is
///   emitted (the token is kept in the output).
/// * [greedy]: when `true`, picks the arg-max logit at each step; otherwise
///   samples from the softmax distribution via `_sampleFromDistribution`.
///
/// Returns a new list containing the prompt tokens followed by the
/// generated token IDs.
///
/// Throws an [ArgumentError] if [promptTokens] is empty.
List<int> generate({
  required Tensor encoderOutput,
  required List<int> promptTokens,
  int maxLength = 1536,
  int? eosTokenId,
  int? padTokenId,
  bool greedy = true,
}) {
  if (promptTokens.isEmpty) {
    // The last prompt token seeds the first decoding step, so an empty
    // prompt leaves nothing to condition on.
    throw ArgumentError.value(
        promptTokens, 'promptTokens', 'must contain at least one token');
  }

  final generated = List<int>.of(promptTokens);

  // No room for new tokens: skip the prefill pass, whose result would be
  // discarded anyway.
  if (generated.length >= maxLength) return generated;

  KVCache? cache;
  var pastLength = 0;

  // Prefill: run every prompt token except the last so the cache covers
  // them. The last token is fed in the first decoding step below, which
  // produces the logits for the first new token.
  if (promptTokens.length > 1) {
    final (_, prefillCache) = forward(
      promptTokens.sublist(0, promptTokens.length - 1),
      encoderHiddenStates: encoderOutput,
      cache: cache,
      pastLength: 0,
    );
    cache = prefillCache;
    pastLength = promptTokens.length - 1;
  }

  var currentToken = [promptTokens.last];

  // Each iteration appends exactly one token, so this runs at most
  // `maxLength - promptTokens.length` times.
  while (generated.length < maxLength) {
    final (logits, newCache) = forward(
      currentToken,
      encoderHiddenStates: encoderOutput,
      cache: cache,
      pastLength: pastLength,
    );
    cache = newCache;
    pastLength += 1;

    // Logits at the final sequence position (dimension 1 of `logits`).
    final lastLogits = logits[0][logits.shape[1] - 1];

    final int nextToken = greedy
        ? lastLogits.argmax()
        : _sampleFromDistribution(lastLogits.softmax(0));

    generated.add(nextToken);
    currentToken = [nextToken];

    // Stop after emitting EOS or PAD; the stop token stays in the output.
    // Comparing an int against a null id is simply false.
    if (nextToken == eosTokenId || nextToken == padTokenId) break;
  }

  return generated;
}