collectStreamToMessage function

Future<Message> collectStreamToMessage(
  Stream<StreamUpdate> stream, {
  void onUpdate(StreamUpdate update)?,
})

Consumes a Stream<StreamUpdate> and assembles the final Message.

Optionally invokes onUpdate for every update as it arrives, including the final completion update (useful for driving progress UI while still awaiting the complete message). Throws an ApiError if the stream emits a StreamError.

Implementation

/// Consumes a [Stream] of [StreamUpdate]s and assembles the final [Message].
///
/// If [onUpdate] is provided, it is invoked with every update as it arrives,
/// before that update is folded into the result. This is useful for driving
/// progress UI while still awaiting the complete message.
///
/// Content is accumulated per block index and emitted in ascending index
/// order. If no content block was produced, the message contains a single
/// empty [TextBlock]. Unrecognized stop-reason strings map to a `null`
/// [Message.stopReason].
///
/// Throws an [ApiError] of type [ApiErrorType.unknown] if the stream emits a
/// [StreamError].
Future<Message> collectStreamToMessage(
  Stream<StreamUpdate> stream, {
  void Function(StreamUpdate update)? onUpdate,
}) async {
  String? messageId;
  String? stopReason;
  var inputTokens = 0;
  var outputTokens = 0;
  int? cacheCreationTokens;
  int? cacheReadTokens;
  final blocks = <int, ContentBlockAccumulator>{};

  await for (final update in stream) {
    onUpdate?.call(update);

    switch (update) {
      case MessageStartUpdate():
        messageId = update.messageId;
        if (update.usage case final usage?) {
          inputTokens = usage.inputTokens;
          outputTokens = usage.outputTokens;
          cacheCreationTokens = usage.cacheCreationInputTokens;
          cacheReadTokens = usage.cacheReadInputTokens;
        }
      case TextDelta(:final text, :final blockIndex):
        final acc = blocks.putIfAbsent(
          blockIndex,
          () => TextAccumulator(blockIndex),
        );
        // A delta whose index already holds a different accumulator kind is
        // ignored rather than corrupting that block.
        if (acc is TextAccumulator) acc.append(text);
      case ThinkingDelta(:final text, :final blockIndex):
        final acc = blocks.putIfAbsent(
          blockIndex,
          () => ThinkingAccumulator(blockIndex),
        );
        if (acc is ThinkingAccumulator) acc.append(text);
      case ToolUseStart(:final toolName, :final toolId, :final blockIndex):
        blocks.putIfAbsent(
          blockIndex,
          () => ToolUseAccumulator(
            blockIndex,
            toolId: toolId,
            toolName: toolName,
          ),
        );
      case ToolUseInputDelta(:final partialJson, :final blockIndex):
        // Input JSON only attaches to a tool-use block that was started.
        final acc = blocks[blockIndex];
        if (acc is ToolUseAccumulator) acc.append(partialJson);
      case UsageUpdate():
        inputTokens = update.inputTokens;
        outputTokens = update.outputTokens;
        // A usage update that omits cache counts must not erase the counts
        // already captured from MessageStartUpdate.
        cacheCreationTokens =
            update.cacheCreationInputTokens ?? cacheCreationTokens;
        cacheReadTokens = update.cacheReadInputTokens ?? cacheReadTokens;
      case MessageComplete():
        stopReason = update.stopReason;
      case StreamError(:final message):
        throw ApiError(type: ApiErrorType.unknown, message: message);
      default:
        // Other update kinds carry nothing needed for the final message.
        break;
    }
  }

  // Emit blocks in ascending block-index order, regardless of arrival order.
  final orderedBlocks = blocks.entries.toList()
    ..sort((a, b) => a.key.compareTo(b.key));
  final content = [
    for (final entry in orderedBlocks) entry.value.toContentBlock(),
  ];

  final stopReasonEnum = switch (stopReason) {
    'end_turn' => StopReason.endTurn,
    'max_tokens' => StopReason.maxTokens,
    'tool_use' => StopReason.toolUse,
    'stop_sequence' => StopReason.stopSequence,
    _ => null,
  };

  return Message(
    id: messageId,
    role: MessageRole.assistant,
    content: content.isEmpty ? [const TextBlock('')] : content,
    stopReason: stopReasonEnum,
    usage: TokenUsage(
      inputTokens: inputTokens,
      outputTokens: outputTokens,
      cacheCreationInputTokens: cacheCreationTokens,
      cacheReadInputTokens: cacheReadTokens,
    ),
  );
}