diff --git a/src/rotator_library/anthropic_compat/streaming.py b/src/rotator_library/anthropic_compat/streaming.py
index ecb074ba..3fa37ae6 100644
--- a/src/rotator_library/anthropic_compat/streaming.py
+++ b/src/rotator_library/anthropic_compat/streaming.py
@@ -25,6 +25,7 @@ async def anthropic_streaming_wrapper(
     request_id: Optional[str] = None,
     is_disconnected: Optional[Callable[[], Awaitable[bool]]] = None,
     transaction_logger: Optional["TransactionLogger"] = None,
+    precalculated_input_tokens: Optional[int] = None,
 ) -> AsyncGenerator[str, None]:
     """
     Convert OpenAI streaming format to Anthropic streaming format.
@@ -47,6 +48,10 @@
         request_id: Optional request ID (auto-generated if not provided)
         is_disconnected: Optional async callback that returns True if client disconnected
         transaction_logger: Optional TransactionLogger for logging the final Anthropic response
+        precalculated_input_tokens: Optional pre-calculated input token count for message_start.
+            When provided, this value is used in message_start to match Anthropic's native
+            behavior (which provides input_tokens upfront). Without this, message_start will
+            have input_tokens=0 since OpenAI-format streams provide usage data at the end.

     Yields:
         SSE format strings in Anthropic's streaming format
@@ -60,7 +65,9 @@
     current_block_index = 0
     tool_calls_by_index = {}  # Track tool calls by their index
     tool_block_indices = {}  # Track which block index each tool call uses
-    input_tokens = 0
+    # Use precalculated input tokens if provided, otherwise start at 0
+    # This allows message_start to have accurate input_tokens like Anthropic's native API
+    input_tokens = precalculated_input_tokens if precalculated_input_tokens is not None else 0
     output_tokens = 0
     cached_tokens = 0  # Track cached tokens for proper Anthropic format
     accumulated_text = ""  # Track accumulated text for logging
@@ -128,7 +135,10 @@
                     stop_reason_final = stop_reason

                 # Build final usage dict with cached tokens
-                final_usage = {"output_tokens": output_tokens}
+                final_usage = {
+                    "input_tokens": input_tokens - cached_tokens,
+                    "output_tokens": output_tokens,
+                }
                 if cached_tokens > 0:
                     final_usage["cache_read_input_tokens"] = cached_tokens
                     final_usage["cache_creation_input_tokens"] = 0
@@ -416,7 +426,10 @@
            yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n'

        # Build final usage with cached tokens
-        final_usage = {"output_tokens": 0}
+        final_usage = {
+            "input_tokens": input_tokens - cached_tokens,
+            "output_tokens": 0,
+        }
        if cached_tokens > 0:
            final_usage["cache_read_input_tokens"] = cached_tokens
            final_usage["cache_creation_input_tokens"] = 0
diff --git a/src/rotator_library/client/anthropic.py b/src/rotator_library/client/anthropic.py
index 507e82fb..e25c2673 100644
--- a/src/rotator_library/client/anthropic.py
+++ b/src/rotator_library/client/anthropic.py
@@ -103,6 +103,14 @@ async def messages(
             openai_request["_parent_log_dir"] = anthropic_logger.log_dir

         if request.stream:
+            # Pre-calculate input tokens for message_start
+            # Anthropic's native API provides input_tokens in message_start, but OpenAI-format
+            # streams only provide usage data at the end. We calculate upfront to match behavior.
+            precalculated_input_tokens = self._client.token_count(
+                model=request.model,
+                messages=openai_request.get("messages", []),
+            )
+
             # Streaming response
             response_generator = await self._client.acompletion(
                 request=raw_request,
@@ -123,6 +131,7 @@
                 request_id=request_id,
                 is_disconnected=is_disconnected,
                 transaction_logger=anthropic_logger,
+                precalculated_input_tokens=precalculated_input_tokens,
             )
         else:
             # Non-streaming response
diff --git a/src/rotator_library/transaction_logger.py b/src/rotator_library/transaction_logger.py
index e1de4d67..61f01e91 100644
--- a/src/rotator_library/transaction_logger.py
+++ b/src/rotator_library/transaction_logger.py
@@ -265,8 +265,12 @@ def _log_metadata(
         model = response_data.get("model", self.model)

         finish_reason = "N/A"
+        # Handle OpenAI format (choices[0].finish_reason)
         if "choices" in response_data and response_data["choices"]:
             finish_reason = response_data["choices"][0].get("finish_reason", "N/A")
+        # Handle Anthropic format (stop_reason at top level)
+        elif "stop_reason" in response_data:
+            finish_reason = response_data.get("stop_reason", "N/A")

         # Check for provider subdirectory
         has_provider_logs = False
@@ -279,6 +283,19 @@
             except OSError:
                 has_provider_logs = False

+        # Extract token counts - support both OpenAI and Anthropic formats
+        # Prefers OpenAI format if available: prompt_tokens, completion_tokens
+        # Falls back to Anthropic format: input_tokens, output_tokens
+        prompt_tokens = usage.get("prompt_tokens")
+        if prompt_tokens is None:
+            prompt_tokens = usage.get("input_tokens")
+        completion_tokens = usage.get("completion_tokens")
+        if completion_tokens is None:
+            completion_tokens = usage.get("output_tokens")
+        total_tokens = usage.get("total_tokens")
+        if total_tokens is None and prompt_tokens is not None and completion_tokens is not None:
+            total_tokens = prompt_tokens + completion_tokens
+
         metadata = {
             "request_id": self.request_id,
             "timestamp_utc": datetime.utcnow().isoformat(),
@@ -288,9 +305,9 @@
             "model": model,
             "streaming": self.streaming,
             "usage": {
-                "prompt_tokens": usage.get("prompt_tokens"),
-                "completion_tokens": usage.get("completion_tokens"),
-                "total_tokens": usage.get("total_tokens"),
+                "prompt_tokens": prompt_tokens,
+                "completion_tokens": completion_tokens,
+                "total_tokens": total_tokens,
             },
             "finish_reason": finish_reason,
             "has_provider_logs": has_provider_logs,
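
For reference, the usage fallback added to transaction_logger._log_metadata can be exercised on its own. The sketch below is illustrative only: the normalize_usage helper name is hypothetical and not part of the library, but the key-precedence and total derivation mirror the patched code exactly.

# Minimal standalone sketch (hypothetical helper, not part of the library):
# reproduces the OpenAI-first / Anthropic-fallback usage extraction added
# to transaction_logger._log_metadata so it can be sanity-checked in isolation.
from typing import Optional, Tuple


def normalize_usage(usage: dict) -> Tuple[Optional[int], Optional[int], Optional[int]]:
    """Return (prompt_tokens, completion_tokens, total_tokens).

    Prefers OpenAI keys (prompt_tokens / completion_tokens / total_tokens),
    falls back to Anthropic keys (input_tokens / output_tokens), and derives
    the total when it is absent - the same order as the patched _log_metadata.
    """
    prompt_tokens = usage.get("prompt_tokens")
    if prompt_tokens is None:
        prompt_tokens = usage.get("input_tokens")
    completion_tokens = usage.get("completion_tokens")
    if completion_tokens is None:
        completion_tokens = usage.get("output_tokens")
    total_tokens = usage.get("total_tokens")
    if total_tokens is None and prompt_tokens is not None and completion_tokens is not None:
        total_tokens = prompt_tokens + completion_tokens
    return prompt_tokens, completion_tokens, total_tokens


# OpenAI-format usage passes through unchanged.
assert normalize_usage({"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}) == (12, 34, 46)
# Anthropic-format usage falls back to input_tokens/output_tokens and derives the total.
assert normalize_usage({"input_tokens": 12, "output_tokens": 34}) == (12, 34, 46)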