Skip to content
This repository was archived by the owner on Mar 26, 2025. It is now read-only.
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 8 additions & 151 deletions hatchet_sdk/clients/admin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, TypedDict, TypeVar, Union
Expand Down Expand Up @@ -168,9 +169,7 @@ def _prepare_schedule_workflow_request(

class AdminClientAioImpl(AdminClientBase):
def __init__(self, config: ClientConfig):
aio_conn = new_conn(config, True)
self.config = config
self.aio_client = WorkflowServiceStub(aio_conn)
self.token = config.token
self.listener_client = new_listener(config)
self.namespace = config.namespace
Expand All @@ -197,121 +196,15 @@ async def run(
async def run_workflow(
self, workflow_name: str, input: any, options: TriggerWorkflowOptions = None
) -> WorkflowRunRef:
ctx = parse_carrier_from_metadata(
(options or {}).get("additional_metadata", {})
)

with self.otel_tracer.start_as_current_span(
f"hatchet.async_run_workflow.{workflow_name}", context=ctx
) as span:
carrier = create_carrier()

try:
if not self.pooled_workflow_listener:
self.pooled_workflow_listener = PooledWorkflowRunListener(
self.config
)

namespace = self.namespace

if (
options is not None
and "namespace" in options
and options["namespace"] is not None
):
namespace = options.pop("namespace")

if namespace != "" and not workflow_name.startswith(self.namespace):
workflow_name = f"{namespace}{workflow_name}"

if options is not None and "additional_metadata" in options:
options["additional_metadata"] = inject_carrier_into_metadata(
options["additional_metadata"], carrier
)
span.set_attributes(
flatten(
options["additional_metadata"], parent_key="", separator="."
)
)

request = self._prepare_workflow_request(workflow_name, input, options)

span.add_event(
"Triggering workflow", attributes={"workflow_name": workflow_name}
)

resp: TriggerWorkflowResponse = await self.aio_client.TriggerWorkflow(
request,
metadata=get_metadata(self.token),
)

span.add_event(
"Received workflow response",
attributes={"workflow_name": workflow_name},
)

return WorkflowRunRef(
workflow_run_id=resp.workflow_run_id,
workflow_listener=self.pooled_workflow_listener,
workflow_run_event_listener=self.listener_client,
)
except (grpc.RpcError, grpc.aio.AioRpcError) as e:
if e.code() == grpc.StatusCode.ALREADY_EXISTS:
raise DedupeViolationErr(e.details())

raise e
return await asyncio.to_thread(self.run_workflow, workflow_name, input, options)

@tenacity_retry
async def run_workflows(
self,
workflows: list[WorkflowRunDict],
options: TriggerWorkflowOptions | None = None,
) -> List[WorkflowRunRef]:
if len(workflows) == 0:
raise ValueError("No workflows to run")

if not self.pooled_workflow_listener:
self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)

namespace = self.namespace

if (
options is not None
and "namespace" in options
and options["namespace"] is not None
):
namespace = options["namespace"]
del options["namespace"]

workflow_run_requests: TriggerWorkflowRequest = []

for workflow in workflows:
workflow_name = workflow["workflow_name"]
input_data = workflow["input"]
options = workflow["options"]

if namespace != "" and not workflow_name.startswith(self.namespace):
workflow_name = f"{namespace}{workflow_name}"

# Prepare and trigger workflow for each workflow name and input
request = self._prepare_workflow_request(workflow_name, input_data, options)
workflow_run_requests.append(request)

request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests)

resp: BulkTriggerWorkflowResponse = await self.aio_client.BulkTriggerWorkflow(
request,
metadata=get_metadata(self.token),
)

return [
WorkflowRunRef(
workflow_run_id=workflow_run_id,
workflow_listener=self.pooled_workflow_listener,
workflow_run_event_listener=self.listener_client,
)
for workflow_run_id in resp.workflow_run_ids
]
return await asyncio.to_thread(self.run_workflows, workflows, options)

@tenacity_retry
async def put_workflow(
Expand All @@ -320,12 +213,7 @@ async def put_workflow(
workflow: CreateWorkflowVersionOpts | WorkflowMeta,
overrides: CreateWorkflowVersionOpts | None = None,
) -> WorkflowVersion:
opts = self._prepare_put_workflow_request(name, workflow, overrides)

return await self.aio_client.PutWorkflow(
opts,
metadata=get_metadata(self.token),
)
return await asyncio.to_thread(self.put_workflow, name, workflow, overrides)

@tenacity_retry
async def put_rate_limit(
Expand All @@ -334,14 +222,7 @@ async def put_rate_limit(
limit: int,
duration: RateLimitDuration = RateLimitDuration.SECOND,
):
await self.aio_client.PutRateLimit(
PutRateLimitRequest(
key=key,
limit=limit,
duration=duration,
),
metadata=get_metadata(self.token),
)
return await asyncio.to_thread(self.put_rate_limit, key, limit, duration)

@tenacity_retry
async def schedule_workflow(
Expand All @@ -351,33 +232,9 @@ async def schedule_workflow(
input={},
options: ScheduleTriggerWorkflowOptions = None,
) -> WorkflowVersion:
try:
namespace = self.namespace

if (
options is not None
and "namespace" in options
and options["namespace"] is not None
):
namespace = options["namespace"]
del options["namespace"]

if namespace != "" and not name.startswith(self.namespace):
name = f"{namespace}{name}"

request = self._prepare_schedule_workflow_request(
name, schedules, input, options
)

return await self.aio_client.ScheduleWorkflow(
request,
metadata=get_metadata(self.token),
)
except (grpc.aio.AioRpcError, grpc.RpcError) as e:
if e.code() == grpc.StatusCode.ALREADY_EXISTS:
raise DedupeViolationErr(e.details())

raise e
return await asyncio.to_thread(
self.schedule_workflow, name, schedules, input, options
)


class AdminClient(AdminClientBase):
Expand Down
Loading