-
Notifications
You must be signed in to change notification settings - Fork 11
fix(sdk): support async generator functions in control() decorator #116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -828,12 +828,64 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: | |
| except (AttributeError, TypeError): | ||
| pass | ||
|
|
||
| @functools.wraps(func) | ||
| async def async_gen_wrapper(*args: Any, **kwargs: Any) -> Any: | ||
| agent = _get_current_agent() | ||
| if agent is None: | ||
| logger.warning( | ||
| "No agent initialized. Call agent_control.init() first. " | ||
| "Running without protection." | ||
| ) | ||
| async for chunk in func(*args, **kwargs): | ||
| yield chunk | ||
| return | ||
|
|
||
| controls = _get_server_controls() | ||
|
|
||
| existing_trace_id = get_current_trace_id() | ||
| if existing_trace_id: | ||
| trace_id = existing_trace_id | ||
| span_id = _generate_span_id() | ||
| else: | ||
| trace_id, span_id = get_trace_and_span_ids() | ||
|
|
||
| ctx = ControlContext( | ||
| agent_name=agent.agent_name, | ||
| server_url=_get_server_url(), | ||
| func=func, | ||
| args=args, | ||
| kwargs=kwargs, | ||
| trace_id=trace_id, | ||
| span_id=span_id, | ||
| start_time=time.perf_counter(), | ||
| step_name=step_name, | ||
| ) | ||
| ctx.log_start() | ||
|
|
||
| try: | ||
| # PRE-EXECUTION: Check controls with check_stage="pre" | ||
| await _run_control_check(ctx, "pre", ctx.pre_payload(), controls) | ||
|
|
||
| # Yield chunks while accumulating full output for post-check | ||
| accumulated: list[str] = [] | ||
| async for chunk in func(*args, **kwargs): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unlike the existing non-streaming wrapper, this implementation yields every chunk to the caller before enforcing the post-stage result. A post-stage deny or steer therefore raises only after the full response has already been delivered, so it no longer provides the same fail-closed behavior as control() on normal async functions. If that tradeoff is intentional, it needs to be called out very clearly; otherwise the stream must be buffered or evaluated chunk-by-chunk. |
||
| accumulated.append(str(chunk)) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| yield chunk | ||
|
|
||
| # POST-EXECUTION: Check controls on full accumulated output | ||
| full_output = "".join(accumulated) | ||
| await _run_control_check(ctx, "post", ctx.post_payload(full_output), controls) | ||
| finally: | ||
| ctx.log_end() | ||
|
|
||
| @functools.wraps(func) | ||
| def sync_wrapper(*args: Any, **kwargs: Any) -> Any: | ||
| return asyncio.run( | ||
| _execute_with_control(func, args, kwargs, is_async=False, step_name=step_name) | ||
| ) | ||
|
|
||
| if inspect.isasyncgenfunction(func): | ||
| return async_gen_wrapper # type: ignore | ||
| if inspect.iscoroutinefunction(func): | ||
| return async_wrapper # type: ignore | ||
| return sync_wrapper # type: ignore | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This path only runs the post-stage check after the async for finishes. If the caller breaks early, cancels the task, or the client disconnects mid-stream, Python closes the wrapper and runs finally, but never reaches the post-check. That leaves a real bypass for streaming consumers, which is exactly where partial reads are common. This path needs explicit handling/documentation and tests for break, cancellation, and aclose().