From ddd46fa98ab660e1c5d1b2eca36c1d49557f4d2c Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Fri, 7 Feb 2025 21:34:45 -0700 Subject: [PATCH 1/2] fix: replace duped code with `asyncio.to_thread` --- hatchet_sdk/clients/admin.py | 160 ++--------------------------------- 1 file changed, 9 insertions(+), 151 deletions(-) diff --git a/hatchet_sdk/clients/admin.py b/hatchet_sdk/clients/admin.py index e3d345b1..62567634 100644 --- a/hatchet_sdk/clients/admin.py +++ b/hatchet_sdk/clients/admin.py @@ -1,5 +1,7 @@ +import asyncio import json from datetime import datetime +from functools import partial from typing import Any, Callable, Dict, List, Optional, TypedDict, TypeVar, Union import grpc @@ -168,9 +170,7 @@ def _prepare_schedule_workflow_request( class AdminClientAioImpl(AdminClientBase): def __init__(self, config: ClientConfig): - aio_conn = new_conn(config, True) self.config = config - self.aio_client = WorkflowServiceStub(aio_conn) self.token = config.token self.listener_client = new_listener(config) self.namespace = config.namespace @@ -197,69 +197,7 @@ async def run( async def run_workflow( self, workflow_name: str, input: any, options: TriggerWorkflowOptions = None ) -> WorkflowRunRef: - ctx = parse_carrier_from_metadata( - (options or {}).get("additional_metadata", {}) - ) - - with self.otel_tracer.start_as_current_span( - f"hatchet.async_run_workflow.{workflow_name}", context=ctx - ) as span: - carrier = create_carrier() - - try: - if not self.pooled_workflow_listener: - self.pooled_workflow_listener = PooledWorkflowRunListener( - self.config - ) - - namespace = self.namespace - - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options.pop("namespace") - - if namespace != "" and not workflow_name.startswith(self.namespace): - workflow_name = f"{namespace}{workflow_name}" - - if options is not None and "additional_metadata" in options: - options["additional_metadata"] = inject_carrier_into_metadata( - options["additional_metadata"], carrier - ) - span.set_attributes( - flatten( - options["additional_metadata"], parent_key="", separator="." - ) - ) - - request = self._prepare_workflow_request(workflow_name, input, options) - - span.add_event( - "Triggering workflow", attributes={"workflow_name": workflow_name} - ) - - resp: TriggerWorkflowResponse = await self.aio_client.TriggerWorkflow( - request, - metadata=get_metadata(self.token), - ) - - span.add_event( - "Received workflow response", - attributes={"workflow_name": workflow_name}, - ) - - return WorkflowRunRef( - workflow_run_id=resp.workflow_run_id, - workflow_listener=self.pooled_workflow_listener, - workflow_run_event_listener=self.listener_client, - ) - except (grpc.RpcError, grpc.aio.AioRpcError) as e: - if e.code() == grpc.StatusCode.ALREADY_EXISTS: - raise DedupeViolationErr(e.details()) - - raise e + return await asyncio.to_thread(self.run_workflow, workflow_name, input, options) @tenacity_retry async def run_workflows( @@ -267,51 +205,7 @@ async def run_workflows( workflows: list[WorkflowRunDict], options: TriggerWorkflowOptions | None = None, ) -> List[WorkflowRunRef]: - if len(workflows) == 0: - raise ValueError("No workflows to run") - - if not self.pooled_workflow_listener: - self.pooled_workflow_listener = PooledWorkflowRunListener(self.config) - - namespace = self.namespace - - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options["namespace"] - del options["namespace"] - - workflow_run_requests: TriggerWorkflowRequest = [] - - for workflow in workflows: - workflow_name = workflow["workflow_name"] - input_data = workflow["input"] - options = workflow["options"] - - if namespace != "" and not workflow_name.startswith(self.namespace): - workflow_name = f"{namespace}{workflow_name}" - - # Prepare and trigger workflow for each workflow name and input - request = self._prepare_workflow_request(workflow_name, input_data, options) - workflow_run_requests.append(request) - - request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests) - - resp: BulkTriggerWorkflowResponse = await self.aio_client.BulkTriggerWorkflow( - request, - metadata=get_metadata(self.token), - ) - - return [ - WorkflowRunRef( - workflow_run_id=workflow_run_id, - workflow_listener=self.pooled_workflow_listener, - workflow_run_event_listener=self.listener_client, - ) - for workflow_run_id in resp.workflow_run_ids - ] + return await asyncio.to_thread(self.run_workflows, workflows, options) @tenacity_retry async def put_workflow( @@ -320,12 +214,7 @@ async def put_workflow( workflow: CreateWorkflowVersionOpts | WorkflowMeta, overrides: CreateWorkflowVersionOpts | None = None, ) -> WorkflowVersion: - opts = self._prepare_put_workflow_request(name, workflow, overrides) - - return await self.aio_client.PutWorkflow( - opts, - metadata=get_metadata(self.token), - ) + return await asyncio.to_thread(self.put_workflow, name, workflow, overrides) @tenacity_retry async def put_rate_limit( @@ -334,14 +223,7 @@ async def put_rate_limit( limit: int, duration: RateLimitDuration = RateLimitDuration.SECOND, ): - await self.aio_client.PutRateLimit( - PutRateLimitRequest( - key=key, - limit=limit, - duration=duration, - ), - metadata=get_metadata(self.token), - ) + return await asyncio.to_thread(self.put_rate_limit, key, limit, duration) @tenacity_retry async def schedule_workflow( @@ -351,33 +233,9 @@ async def schedule_workflow( input={}, options: ScheduleTriggerWorkflowOptions = None, ) -> WorkflowVersion: - try: - namespace = self.namespace - - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options["namespace"] - del options["namespace"] - - if namespace != "" and not name.startswith(self.namespace): - name = f"{namespace}{name}" - - request = self._prepare_schedule_workflow_request( - name, schedules, input, options - ) - - return await self.aio_client.ScheduleWorkflow( - request, - metadata=get_metadata(self.token), - ) - except (grpc.aio.AioRpcError, grpc.RpcError) as e: - if e.code() == grpc.StatusCode.ALREADY_EXISTS: - raise DedupeViolationErr(e.details()) - - raise e + return await asyncio.to_thread( + self.schedule_workflow, name, schedules, input, options + ) class AdminClient(AdminClientBase): From 07ecaf63d9b57ed8be3c2ba53037c8458b451aa4 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Fri, 7 Feb 2025 21:36:09 -0700 Subject: [PATCH 2/2] fix: rm partial --- hatchet_sdk/clients/admin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/hatchet_sdk/clients/admin.py b/hatchet_sdk/clients/admin.py index 62567634..32b61cac 100644 --- a/hatchet_sdk/clients/admin.py +++ b/hatchet_sdk/clients/admin.py @@ -1,7 +1,6 @@ import asyncio import json from datetime import datetime -from functools import partial from typing import Any, Callable, Dict, List, Optional, TypedDict, TypeVar, Union import grpc