src/rotator_library/anthropic_compat/streaming.py (16 additions, 3 deletions)
@@ -25,6 +25,7 @@ async def anthropic_streaming_wrapper(
request_id: Optional[str] = None,
is_disconnected: Optional[Callable[[], Awaitable[bool]]] = None,
transaction_logger: Optional["TransactionLogger"] = None,
precalculated_input_tokens: Optional[int] = None,
) -> AsyncGenerator[str, None]:
"""
Convert OpenAI streaming format to Anthropic streaming format.
@@ -47,6 +48,10 @@ async def anthropic_streaming_wrapper(
request_id: Optional request ID (auto-generated if not provided)
is_disconnected: Optional async callback that returns True if client disconnected
transaction_logger: Optional TransactionLogger for logging the final Anthropic response
precalculated_input_tokens: Optional pre-calculated input token count for message_start.
When provided, this value is used in message_start to match Anthropic's native
behavior (which provides input_tokens upfront). Without this, message_start will
have input_tokens=0 since OpenAI-format streams provide usage data at the end.

Yields:
SSE format strings in Anthropic's streaming format
@@ -60,7 +65,9 @@
current_block_index = 0
tool_calls_by_index = {} # Track tool calls by their index
tool_block_indices = {} # Track which block index each tool call uses
-input_tokens = 0
+# Use precalculated input tokens if provided, otherwise start at 0
+# This allows message_start to have accurate input_tokens like Anthropic's native API
+input_tokens = precalculated_input_tokens if precalculated_input_tokens is not None else 0
output_tokens = 0
cached_tokens = 0 # Track cached tokens for proper Anthropic format
accumulated_text = "" # Track accumulated text for logging
@@ -128,7 +135,10 @@ async def anthropic_streaming_wrapper(
stop_reason_final = stop_reason

# Build final usage dict with cached tokens
final_usage = {"output_tokens": output_tokens}
final_usage = {
"input_tokens": input_tokens - cached_tokens,
"output_tokens": output_tokens,
}
if cached_tokens > 0:
final_usage["cache_read_input_tokens"] = cached_tokens
final_usage["cache_creation_input_tokens"] = 0
@@ -416,7 +426,10 @@ async def anthropic_streaming_wrapper(
yield f'event: content_block_stop\ndata: {{"type": "content_block_stop", "index": {current_block_index}}}\n\n'

# Build final usage with cached tokens
final_usage = {"output_tokens": 0}
final_usage = {
"input_tokens": input_tokens - cached_tokens,
"output_tokens": 0,
}
if cached_tokens > 0:
final_usage["cache_read_input_tokens"] = cached_tokens
final_usage["cache_creation_input_tokens"] = 0
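For context, here is a minimal sketch of the message_start payload this change affects, assuming the wrapper follows Anthropic's documented streaming shape. The helper name and the example values below are illustrative, not code from this PR:

```python
import json


def build_message_start_event(message_id: str, model: str, input_tokens: int) -> str:
    """Illustrative helper: shape of an Anthropic-style message_start SSE event.

    With precalculated_input_tokens supplied, input_tokens here is a real count;
    without it, the wrapper can only report 0 until the provider's final usage chunk.
    """
    payload = {
        "type": "message_start",
        "message": {
            "id": message_id,
            "type": "message",
            "role": "assistant",
            "model": model,
            "content": [],
            "stop_reason": None,
            "usage": {"input_tokens": input_tokens, "output_tokens": 0},
        },
    }
    return f"event: message_start\ndata: {json.dumps(payload)}\n\n"


# Example only: with the pre-count, the first event already carries input_tokens.
print(build_message_start_event("msg_123", "claude-sonnet", input_tokens=42))
```

Subtracting cached_tokens in the final usage dict keeps input_tokens as the non-cached portion, matching how Anthropic reports cache_read_input_tokens separately.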
src/rotator_library/client/anthropic.py (9 additions, 0 deletions)
@@ -103,6 +103,14 @@ async def messages(
openai_request["_parent_log_dir"] = anthropic_logger.log_dir

if request.stream:
# Pre-calculate input tokens for message_start
# Anthropic's native API provides input_tokens in message_start, but OpenAI-format
# streams only provide usage data at the end. We calculate upfront to match behavior.
precalculated_input_tokens = self._client.token_count(
model=request.model,
messages=openai_request.get("messages", []),
)

# Streaming response
response_generator = await self._client.acompletion(
request=raw_request,
@@ -123,6 +131,7 @@
request_id=request_id,
is_disconnected=is_disconnected,
transaction_logger=anthropic_logger,
precalculated_input_tokens=precalculated_input_tokens,
)
else:
# Non-streaming response
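As a rough illustration of what the pre-count buys, here is a sketch of an input-token estimator using a characters-per-token heuristic. The client's real token_count() presumably uses the provider's tokenizer; the function below is only a hypothetical stand-in:

```python
from typing import Any, Dict, List


def estimate_input_tokens(messages: List[Dict[str, Any]]) -> int:
    """Hypothetical stand-in for the client's token_count(): roughly 4 characters
    per token plus a small per-message overhead for role/formatting tokens."""
    total = 0
    for message in messages:
        content = message.get("content", "")
        if isinstance(content, list):  # multimodal content blocks
            content = " ".join(
                part.get("text", "") for part in content if isinstance(part, dict)
            )
        total += 4 + max(1, len(content) // 4)
    return total


# The count is computed before acompletion() is awaited, so message_start can
# report a non-zero input_tokens instead of waiting for the final usage chunk.
print(estimate_input_tokens([{"role": "user", "content": "Ping the streaming wrapper."}]))
```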
src/rotator_library/transaction_logger.py (20 additions, 3 deletions)
@@ -265,8 +265,12 @@ def _log_metadata(
model = response_data.get("model", self.model)
finish_reason = "N/A"

# Handle OpenAI format (choices[0].finish_reason)
if "choices" in response_data and response_data["choices"]:
finish_reason = response_data["choices"][0].get("finish_reason", "N/A")
# Handle Anthropic format (stop_reason at top level)
elif "stop_reason" in response_data:
finish_reason = response_data.get("stop_reason", "N/A")

# Check for provider subdirectory
has_provider_logs = False
@@ -279,6 +283,19 @@
except OSError:
has_provider_logs = False

# Extract token counts - support both OpenAI and Anthropic formats
# Prefers OpenAI format if available: prompt_tokens, completion_tokens
# Falls back to Anthropic format: input_tokens, output_tokens
prompt_tokens = usage.get("prompt_tokens")
if prompt_tokens is None:
prompt_tokens = usage.get("input_tokens")
completion_tokens = usage.get("completion_tokens")
if completion_tokens is None:
completion_tokens = usage.get("output_tokens")
total_tokens = usage.get("total_tokens")
if total_tokens is None and prompt_tokens is not None and completion_tokens is not None:
total_tokens = prompt_tokens + completion_tokens

metadata = {
"request_id": self.request_id,
"timestamp_utc": datetime.utcnow().isoformat(),
@@ -288,9 +305,9 @@
"model": model,
"streaming": self.streaming,
"usage": {
"prompt_tokens": usage.get("prompt_tokens"),
"completion_tokens": usage.get("completion_tokens"),
"total_tokens": usage.get("total_tokens"),
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
},
"finish_reason": finish_reason,
"has_provider_logs": has_provider_logs,
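A standalone sketch of the dual-format usage handling added above, assuming only the key names visible in the diff (the function name is hypothetical). It shows how OpenAI-shaped and Anthropic-shaped payloads normalize to the same metadata fields:

```python
from typing import Any, Dict, Optional


def normalize_usage(usage: Dict[str, Any]) -> Dict[str, Optional[int]]:
    """Prefer OpenAI keys, fall back to Anthropic keys, derive total_tokens if absent."""
    prompt = usage.get("prompt_tokens")
    if prompt is None:
        prompt = usage.get("input_tokens")
    completion = usage.get("completion_tokens")
    if completion is None:
        completion = usage.get("output_tokens")
    total = usage.get("total_tokens")
    if total is None and prompt is not None and completion is not None:
        total = prompt + completion
    return {"prompt_tokens": prompt, "completion_tokens": completion, "total_tokens": total}


# Both shapes produce the same normalized dict in the logged metadata.
print(normalize_usage({"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}))
print(normalize_usage({"input_tokens": 10, "output_tokens": 5}))
```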