/// Generates a token sequence auto-regressively (greedy or sampled).
///
/// [encoderOutput] is the encoder output, shape (1, srcLen, encoderDim).
/// [promptTokens] are the initial token IDs that seed generation; must be
/// non-empty.
/// [maxLength] caps the total sequence length, prompt tokens included.
/// [eosTokenId] and [padTokenId], when non-null, stop generation as soon as
/// the corresponding token is produced.
/// [greedy] selects argmax decoding; when false, tokens are sampled from the
/// softmax distribution.
///
/// Returns the full token ID sequence, prompt tokens included.
List<int> generate({
  required Tensor encoderOutput,
  required List<int> promptTokens,
  int maxLength = 1536,
  int? eosTokenId,
  int? padTokenId,
  bool greedy = true,
}) {
  // The decoder needs at least one seed token; without this check an empty
  // prompt would crash later at `promptTokens.last` with an opaque
  // StateError("No element") instead of a diagnosable argument error.
  if (promptTokens.isEmpty) {
    throw ArgumentError.value(
        promptTokens, 'promptTokens', 'must contain at least one token');
  }

  final generated = List<int>.from(promptTokens);
  KVCache? cache;
  var pastLength = 0;

  // Prefill phase: run all prompt tokens except the last through the decoder
  // in a single pass so their keys/values populate the cache. The prefill
  // logits are discarded on purpose — the last prompt token is fed on the
  // first decode step below, which produces the first new-token logits.
  if (promptTokens.length > 1) {
    final (_, prefillCache) = forward(
      promptTokens.sublist(0, promptTokens.length - 1),
      encoderHiddenStates: encoderOutput,
      cache: cache,
      pastLength: 0,
    );
    cache = prefillCache;
    pastLength = promptTokens.length - 1;
  }

  // Auto-regressive decode: one token per step, reusing the KV cache so each
  // forward pass only processes the single newest token. The step budget
  // (maxLength - promptTokens.length) caps the TOTAL sequence length at
  // maxLength, prompt included.
  var currentToken = [promptTokens.last];
  for (var step = 0; step < maxLength - promptTokens.length; step++) {
    final (logits, newCache) = forward(
      currentToken,
      encoderHiddenStates: encoderOutput,
      cache: cache,
      pastLength: pastLength,
    );
    cache = newCache;
    pastLength += 1;

    // Logits for the newest position — assumed shape (1, seqLen, vocabSize);
    // seqLen is 1 here since we feed a single token per step. TODO(review):
    // confirm Tensor indexing yields a (vocabSize,) slice.
    final lastLogits = logits[0][logits.shape[1] - 1];

    int nextToken;
    if (greedy) {
      nextToken = lastLogits.argmax();
    } else {
      // Sample from the softmax distribution over the vocabulary.
      final probs = lastLogits.softmax(0);
      nextToken = _sampleFromDistribution(probs);
    }
    generated.add(nextToken);
    currentToken = [nextToken];

    // Stop as soon as an end-of-sequence or padding token is emitted. The
    // terminating token is intentionally kept in the returned sequence.
    if (eosTokenId != null && nextToken == eosTokenId) {
      break;
    }
    if (padTokenId != null && nextToken == padTokenId) {
      break;
    }
  }
  return generated;
}