132 changes: 10 additions & 122 deletions src/proxy_app/main.py
@@ -137,6 +137,7 @@
from rotator_library.model_info_service import init_model_info_service
from proxy_app.request_logger import log_request_to_console, redact_sensitive_data
from proxy_app.security_config import get_cors_settings, validate_secret_settings
from proxy_app.stream_usage import StreamUsageTracker
from proxy_app.batch_manager import EmbeddingBatcher
from proxy_app.api_token_auth import ApiActor, get_api_actor, require_admin_api_actor
from proxy_app.detailed_logger import RawIOLogger
@@ -803,12 +804,11 @@ async def streaming_response_wrapper(
Wraps a streaming response to log the full response after completion
and ensures any errors during the stream are sent to the client.
"""
response_chunks = []
full_response = {}
usage_data = None
tracker = StreamUsageTracker(model=request_data.get("model"))
full_response = tracker.build_logging_payload()
status_code = 200
stream_error: Exception | None = None
model = request_data.get("model")
model = tracker.model

try:
async for chunk_str in response_stream:
@@ -821,7 +821,8 @@
if content != "[DONE]":
try:
chunk_data = json.loads(content)
response_chunks.append(chunk_data)
if isinstance(chunk_data, dict):
tracker.ingest_chunk(chunk_data)
if logger:
logger.log_stream_chunk(chunk_data)
except json.JSONDecodeError:
@@ -848,123 +849,10 @@
)
return # Stop further processing
finally:
if response_chunks:
# --- Aggregation Logic ---
final_message = {"role": "assistant"}
aggregated_tool_calls = {}
usage_data = None
finish_reason = None

for chunk in response_chunks:
if "choices" in chunk and chunk["choices"]:
choice = chunk["choices"][0]
delta = choice.get("delta", {})

# Dynamically aggregate all fields from the delta
for key, value in delta.items():
if value is None:
continue

if key == "content":
if "content" not in final_message:
final_message["content"] = ""
if value:
final_message["content"] += value

elif key == "tool_calls":
for tc_chunk in value:
index = tc_chunk["index"]
if index not in aggregated_tool_calls:
aggregated_tool_calls[index] = {
"type": "function",
"function": {"name": "", "arguments": ""},
}
# Ensure 'function' key exists for this index before accessing its sub-keys
if "function" not in aggregated_tool_calls[index]:
aggregated_tool_calls[index]["function"] = {
"name": "",
"arguments": "",
}
if tc_chunk.get("id"):
aggregated_tool_calls[index]["id"] = tc_chunk["id"]
if "function" in tc_chunk:
if "name" in tc_chunk["function"]:
if tc_chunk["function"]["name"] is not None:
aggregated_tool_calls[index]["function"][
"name"
] += tc_chunk["function"]["name"]
if "arguments" in tc_chunk["function"]:
if (
tc_chunk["function"]["arguments"]
is not None
):
aggregated_tool_calls[index]["function"][
"arguments"
] += tc_chunk["function"]["arguments"]

elif key == "function_call":
if "function_call" not in final_message:
final_message["function_call"] = {
"name": "",
"arguments": "",
}
if "name" in value:
if value["name"] is not None:
final_message["function_call"]["name"] += value[
"name"
]
if "arguments" in value:
if value["arguments"] is not None:
final_message["function_call"]["arguments"] += (
value["arguments"]
)

else: # Generic key handling for other data like 'reasoning'
# FIX: Role should always replace, never concatenate
if key == "role":
final_message[key] = value
elif key not in final_message:
final_message[key] = value
elif isinstance(final_message.get(key), str):
final_message[key] += value
else:
final_message[key] = value

if "finish_reason" in choice and choice["finish_reason"]:
finish_reason = choice["finish_reason"]

if "usage" in chunk and chunk["usage"]:
usage_data = chunk["usage"]

# --- Final Response Construction ---
if aggregated_tool_calls:
final_message["tool_calls"] = list(aggregated_tool_calls.values())
# CRITICAL FIX: Override finish_reason when tool_calls exist
# This ensures OpenCode and other agentic systems continue the conversation loop
finish_reason = "tool_calls"

# Ensure standard fields are present for consistent logging
for field in ["content", "tool_calls", "function_call"]:
if field not in final_message:
final_message[field] = None

first_chunk = response_chunks[0]
final_choice = {
"index": 0,
"message": final_message,
"finish_reason": finish_reason,
}

full_response = {
"id": first_chunk.get("id"),
"object": "chat.completion",
"created": first_chunk.get("created"),
"model": first_chunk.get("model"),
"choices": [final_choice],
"usage": usage_data,
}
model = full_response.get("model") or model
request_id = _resolve_request_id(request, full_response.get("id") or request_id)
full_response = tracker.build_logging_payload()
usage_data = tracker.usage
model = tracker.model or model
request_id = _resolve_request_id(request, tracker.response_id or request_id)

if logger:
logger.log_final_response(
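The rewritten wrapper delegates per-chunk bookkeeping to the new tracker. A minimal sketch of that pattern, with simplified SSE parsing standing in for the real wrapper (the function name and stream handling below are illustrative, only the tracker API matches this PR):

# Illustrative sketch: feed parsed SSE chunks into StreamUsageTracker and
# recover a chat.completion-shaped payload for logging after the stream ends.
import json

from proxy_app.stream_usage import StreamUsageTracker


async def log_stream(response_stream, requested_model):
    tracker = StreamUsageTracker(model=requested_model)
    async for chunk_str in response_stream:
        for line in chunk_str.splitlines():
            if not line.startswith("data: "):
                continue
            content = line[len("data: "):].strip()
            if content == "[DONE]":
                continue
            try:
                chunk_data = json.loads(content)
            except json.JSONDecodeError:
                continue
            if isinstance(chunk_data, dict):
                tracker.ingest_chunk(chunk_data)
    # Once the stream completes, the tracker carries id/model/created/usage.
    return tracker.build_logging_payload()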
9 changes: 9 additions & 0 deletions src/proxy_app/routers/ui.py
@@ -29,6 +29,7 @@
from proxy_app.db import hash_password
from proxy_app.db_models import ApiKey, User
from proxy_app.usage_queries import (
fetch_api_key_last_used_map,
fetch_usage_by_day,
fetch_usage_by_model,
fetch_usage_summary,
@@ -96,6 +97,14 @@ async def _load_me_context(
select(ApiKey).where(ApiKey.user_id == user_id).order_by(ApiKey.created_at.desc())
)
api_keys = list(rows)
derived_last_used = await fetch_api_key_last_used_map(
session,
user_id=user_id,
api_key_ids=[key.id for key in api_keys],
)
for key in api_keys:
key.last_used_at = derived_last_used.get(key.id, key.last_used_at)

usage_summary = await fetch_usage_summary(session, user_id=user_id)
usage_by_day = await fetch_usage_by_day(session, user_id=user_id, days=days)
return {
13 changes: 11 additions & 2 deletions src/proxy_app/routers/user_api.py
@@ -10,6 +10,7 @@
from proxy_app.api_token_auth import hash_api_token
from proxy_app.db_models import ApiKey, User
from proxy_app.usage_queries import (
fetch_api_key_last_used_map,
fetch_usage_by_day,
fetch_usage_by_model,
fetch_usage_summary,
@@ -105,19 +106,27 @@ async def list_my_api_keys(
current_user: SessionUser = Depends(require_user),
session: AsyncSession = Depends(get_db_session),
) -> ApiKeyListResponse:
rows = await session.scalars(
rows = list(
await session.scalars(
select(ApiKey)
.where(ApiKey.user_id == current_user.id)
.order_by(ApiKey.created_at.desc())
)
)
derived_last_used = await fetch_api_key_last_used_map(
session,
user_id=current_user.id,
api_key_ids=[row.id for row in rows],
)

return ApiKeyListResponse(
api_keys=[
ApiKeyItem(
id=row.id,
name=row.name,
token_prefix=row.token_prefix,
created_at=row.created_at,
last_used_at=row.last_used_at,
last_used_at=derived_last_used.get(row.id, row.last_used_at),
revoked_at=row.revoked_at,
expires_at=row.expires_at,
)
40 changes: 40 additions & 0 deletions src/proxy_app/stream_usage.py
@@ -0,0 +1,40 @@
from dataclasses import dataclass
from typing import Any


@dataclass
class StreamUsageTracker:
response_id: str | None = None
model: str | None = None
created: int | None = None
usage: dict[str, Any] | None = None

def ingest_chunk(self, chunk_data: dict[str, Any]) -> None:
if self.response_id is None:
response_id = chunk_data.get("id")
if isinstance(response_id, str):
self.response_id = response_id

if self.model is None:
model = chunk_data.get("model")
if isinstance(model, str):
self.model = model

if self.created is None:
created = chunk_data.get("created")
if isinstance(created, int):
self.created = created

usage = chunk_data.get("usage")
if isinstance(usage, dict) and usage:
self.usage = usage

def build_logging_payload(self) -> dict[str, Any]:
return {
"id": self.response_id,
"object": "chat.completion",
"created": self.created,
"model": self.model,
"choices": [],
"usage": self.usage,
}
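A small usage sketch of the tracker's semantics (the sample chunk contents below are made up): id, model, and created keep the first value seen, while usage is overwritten by any later non-empty chunk, which suits providers that send usage only in the final chunk.

# Hypothetical chunks; real providers send many more fields per delta.
from proxy_app.stream_usage import StreamUsageTracker

tracker = StreamUsageTracker(model="openai/gpt-4o-mini")
tracker.ingest_chunk({"id": "chatcmpl-123", "model": "gpt-4o-mini", "created": 1700000000})
tracker.ingest_chunk({"id": "ignored-later-id", "usage": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15}})

payload = tracker.build_logging_payload()
assert payload["id"] == "chatcmpl-123"           # first id wins
assert payload["model"] == "openai/gpt-4o-mini"  # constructor value is kept over the chunk's model
assert payload["usage"]["total_tokens"] == 15    # last non-empty usage wins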
4 changes: 3 additions & 1 deletion src/proxy_app/templates/me.html
@@ -40,6 +40,7 @@ <h2>API Keys</h2>
<th>Name</th>
<th>Prefix</th>
<th>Created</th>
<th>Last used</th>
<th>Status</th>
<th>Action</th>
</tr>
@@ -51,6 +52,7 @@ <h2>API Keys</h2>
<td>{{ key.name }}</td>
<td><code>{{ key.token_prefix }}</code></td>
<td>{{ key.created_at }}</td>
<td>{{ key.last_used_at if key.last_used_at else "-" }}</td>
<td>{% if key.revoked_at %}Revoked{% else %}Active{% endif %}</td>
<td>
{% if not key.revoked_at %}
@@ -65,7 +67,7 @@ <h2>API Keys</h2>
</tr>
{% else %}
<tr>
<td colspan="6" class="muted">No API keys yet.</td>
<td colspan="7" class="muted">No API keys yet.</td>
</tr>
{% endfor %}
</tbody>
23 changes: 23 additions & 0 deletions src/proxy_app/usage_queries.py
@@ -118,3 +118,26 @@ async def fetch_usage_by_model(
}
for row in rows
]


async def fetch_api_key_last_used_map(
session: AsyncSession,
*,
user_id: int,
api_key_ids: list[int],
) -> dict[int, datetime]:
if not api_key_ids:
return {}

rows = await session.execute(
select(UsageEvent.api_key_id, func.max(UsageEvent.timestamp))
.where(UsageEvent.user_id == user_id)
.where(UsageEvent.api_key_id.in_(api_key_ids))
.group_by(UsageEvent.api_key_id)
)

result: dict[int, datetime] = {}
for api_key_id, last_used in rows:
if api_key_id is not None and last_used is not None:
result[int(api_key_id)] = last_used
return result
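For context, a sketch of how callers overlay the derived timestamps onto ApiKey rows, mirroring the ui.py and user_api.py changes above (the helper name here is illustrative, not part of the PR):

# Illustrative helper: prefer UsageEvent-derived last-used timestamps,
# falling back to the stored column when no events exist for a key.
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from proxy_app.db_models import ApiKey
from proxy_app.usage_queries import fetch_api_key_last_used_map


async def keys_with_derived_last_used(session: AsyncSession, user_id: int) -> list[ApiKey]:
    rows = list(
        await session.scalars(
            select(ApiKey).where(ApiKey.user_id == user_id).order_by(ApiKey.created_at.desc())
        )
    )
    derived = await fetch_api_key_last_used_map(
        session,
        user_id=user_id,
        api_key_ids=[row.id for row in rows],
    )
    for row in rows:
        row.last_used_at = derived.get(row.id, row.last_used_at)
    return rows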
46 changes: 45 additions & 1 deletion tests/test_api_keys.py
@@ -3,7 +3,7 @@

from proxy_app.api_token_auth import hash_api_token
from proxy_app.auth import SessionUser
from proxy_app.db_models import ApiKey
from proxy_app.db_models import ApiKey, UsageEvent
from proxy_app.routers.user_api import (
CreateApiKeyRequest,
create_my_api_key,
@@ -65,3 +65,47 @@ async def test_create_list_revoke_api_key_hides_plaintext_at_rest(
assert revoked == {"ok": True}
assert reloaded is not None
assert reloaded.revoked_at is not None


@pytest.mark.asyncio
async def test_list_api_keys_uses_derived_last_used_timestamp(
session_maker,
seeded_user,
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setenv("API_TOKEN_PEPPER", "pepper-for-tests")
monkeypatch.setattr(
"proxy_app.routers.user_api.generate_api_token",
lambda: "pk_plaintext_for_last_used_case",
)

session_user = SessionUser(id=seeded_user.id, username=seeded_user.username, role="user")

async with session_maker() as session:
created = await create_my_api_key(
payload=CreateApiKeyRequest(name="usage key"),
current_user=session_user,
session=session,
)

async with session_maker() as session:
session.add(
UsageEvent(
user_id=seeded_user.id,
api_key_id=created.id,
endpoint="/v1/chat/completions",
provider="openai",
model="openai/gpt-4o-mini",
request_id="req-derived-last-used",
status_code=200,
total_tokens=12,
)
)
await session.commit()

async with session_maker() as session:
listed = await list_my_api_keys(current_user=session_user, session=session)

assert listed.api_keys
listed_item = listed.api_keys[0]
assert listed_item.last_used_at is not None