diff --git a/conftest.py b/conftest.py index 2aff5cd3..e10df408 100644 --- a/conftest.py +++ b/conftest.py @@ -1,9 +1,10 @@ import logging +import os import subprocess import time from io import BytesIO from threading import Thread -from typing import AsyncGenerator, Callable, cast +from typing import AsyncGenerator, Callable, Generator, cast import psutil import pytest @@ -23,7 +24,9 @@ def hatchet() -> Hatchet: @pytest.fixture() -def worker(request: pytest.FixtureRequest): +def worker( + request: pytest.FixtureRequest, +) -> Generator[subprocess.Popen[bytes], None, None]: example = cast(str, request.param) command = ["poetry", "run", example] diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/_deprecated/README.md b/examples/_deprecated/README.md deleted file mode 100644 index ee47c61b..00000000 --- a/examples/_deprecated/README.md +++ /dev/null @@ -1 +0,0 @@ -The examples and tests in this directory are deprecated, but we're maintaining them to ensure backwards compatibility. 
diff --git a/examples/_deprecated/concurrency_limit_rr/event.py b/examples/_deprecated/concurrency_limit_rr/event.py deleted file mode 100644 index 16b2bcd0..00000000 --- a/examples/_deprecated/concurrency_limit_rr/event.py +++ /dev/null @@ -1,15 +0,0 @@ -from dotenv import load_dotenv - -from hatchet_sdk import new_client - -load_dotenv() - -client = new_client() - -for i in range(200): - group = "0" - - if i % 2 == 0: - group = "1" - - client.event.push("concurrency-test", {"group": group}) diff --git a/examples/_deprecated/concurrency_limit_rr/test_dep_concurrency_limit_rr.py b/examples/_deprecated/concurrency_limit_rr/test_dep_concurrency_limit_rr.py deleted file mode 100644 index 978186d1..00000000 --- a/examples/_deprecated/concurrency_limit_rr/test_dep_concurrency_limit_rr.py +++ /dev/null @@ -1,58 +0,0 @@ -import asyncio -import time - -import pytest - -from hatchet_sdk import Hatchet, Worker -from hatchet_sdk.workflow_run import WorkflowRunRef - - -# requires scope module or higher for shared event loop -@pytest.mark.parametrize("worker", ["concurrency_limit_rr"], indirect=True) -@pytest.mark.skip(reason="The timing for this test is not reliable") -@pytest.mark.asyncio(scope="session") -async def test_run(aiohatchet: Hatchet, worker: Worker) -> None: - num_groups = 2 - runs: list[WorkflowRunRef] = [] - - # Start all runs - for i in range(1, num_groups + 1): - run = aiohatchet.admin.run_workflow("ConcurrencyDemoWorkflowRR", {"group": i}) - runs.append(run) - run = aiohatchet.admin.run_workflow("ConcurrencyDemoWorkflowRR", {"group": i}) - runs.append(run) - - # Wait for all results - successful_runs = [] - cancelled_runs = [] - - start_time = time.time() - - # Process each run individually - for i, run in enumerate(runs, start=1): - try: - result = await run.result() - successful_runs.append((i, result)) - except Exception as e: - if "CANCELLED_BY_CONCURRENCY_LIMIT" in str(e): - cancelled_runs.append((i, str(e))) - else: - raise # Re-raise if it's an 
unexpected error - - end_time = time.time() - total_time = end_time - start_time - - # Check that we have the correct number of successful and cancelled runs - assert ( - len(successful_runs) == 4 - ), f"Expected 4 successful runs, got {len(successful_runs)}" - assert ( - len(cancelled_runs) == 0 - ), f"Expected 0 cancelled run, got {len(cancelled_runs)}" - - # Check that the total time is close to 2 seconds - assert ( - 3.8 <= total_time <= 7 - ), f"Expected runtime to be about 4 seconds, but it took {total_time:.2f} seconds" - - print(f"Total execution time: {total_time:.2f} seconds") diff --git a/examples/_deprecated/concurrency_limit_rr/worker.py b/examples/_deprecated/concurrency_limit_rr/worker.py deleted file mode 100644 index 9678e798..00000000 --- a/examples/_deprecated/concurrency_limit_rr/worker.py +++ /dev/null @@ -1,38 +0,0 @@ -import time - -from dotenv import load_dotenv - -from hatchet_sdk import ConcurrencyLimitStrategy, Context, Hatchet - -load_dotenv() - -hatchet = Hatchet(debug=True) - - -@hatchet.workflow(on_events=["concurrency-test"], schedule_timeout="10m") -class ConcurrencyDemoWorkflowRR: - - # NOTE: We're replacing the concurrency key function with a CEL expression - # to simplify architecture. - # See ../../concurrency_limit_rr/worker.py for the new implementation. 
- @hatchet.concurrency( - max_runs=1, limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN - ) - def concurrency(self, context: Context) -> str: - input = context.workflow_input() - print(input) - return f'group-{input["group"]}' - - @hatchet.step() - def step1(self, context: Context) -> None: - print("starting step1") - time.sleep(2) - print("finished step1") - pass - - -workflow = ConcurrencyDemoWorkflowRR() -worker = hatchet.worker("concurrency-demo-worker-rr", max_runs=10) -worker.register_workflow(workflow) - -worker.start() diff --git a/examples/_deprecated/test_event_client.py b/examples/_deprecated/test_event_client.py deleted file mode 100644 index 41c6ad65..00000000 --- a/examples/_deprecated/test_event_client.py +++ /dev/null @@ -1,26 +0,0 @@ -import pytest -from dotenv import load_dotenv - -from hatchet_sdk import new_client -from hatchet_sdk.hatchet import Hatchet - -load_dotenv() - - -@pytest.mark.asyncio(scope="session") -async def test_direct_client_event() -> None: - client = new_client() - e = client.event.push("user:create", {"test": "test"}) - - assert e.eventId is not None - - -@pytest.mark.filterwarnings( - "ignore:Direct access to client is deprecated:DeprecationWarning" -) -@pytest.mark.asyncio(scope="session") -async def test_hatchet_client_event() -> None: - hatchet = Hatchet() - e = hatchet.client.event.push("user:create", {"test": "test"}) - - assert e.eventId is not None diff --git a/examples/affinity-workers/event.py b/examples/affinity-workers/event.py index 3d4cae41..6b01a724 100644 --- a/examples/affinity-workers/event.py +++ b/examples/affinity-workers/event.py @@ -1,5 +1,6 @@ from dotenv import load_dotenv +from hatchet_sdk.clients.events import PushEventOptions from hatchet_sdk.hatchet import Hatchet load_dotenv() @@ -9,5 +10,5 @@ hatchet.event.push( "affinity:run", {"test": "test"}, - options={"additional_metadata": {"hello": "moon"}}, + options=PushEventOptions(additional_metadata={"hello": "moon"}), ) diff --git 
a/examples/affinity-workers/worker.py b/examples/affinity-workers/worker.py index 5099a804..ea721c9d 100644 --- a/examples/affinity-workers/worker.py +++ b/examples/affinity-workers/worker.py @@ -1,6 +1,7 @@ from dotenv import load_dotenv from hatchet_sdk import Context, Hatchet, WorkerLabelComparator +from hatchet_sdk.labels import DesiredWorkerLabel load_dotenv() @@ -11,12 +12,12 @@ class AffinityWorkflow: @hatchet.step( desired_worker_labels={ - "model": {"value": "fancy-ai-model-v2", "weight": 10}, - "memory": { - "value": 256, - "required": True, - "comparator": WorkerLabelComparator.LESS_THAN, - }, + "model": DesiredWorkerLabel(value="fancy-ai-model-v2", weight=10), + "memory": DesiredWorkerLabel( + value=256, + required=True, + comparator=WorkerLabelComparator.LESS_THAN, + ), }, ) async def step(self, context: Context) -> dict[str, str | None]: diff --git a/examples/blocked_async/event.py b/examples/blocked_async/event.py index 116b227d..3dcc4fcc 100644 --- a/examples/blocked_async/event.py +++ b/examples/blocked_async/event.py @@ -6,7 +6,8 @@ client = new_client() -# client.event.push("user:create", {"test": "test"}) client.event.push( - "user:create", {"test": "test"}, options={"additional_metadata": {"hello": "moon"}} + "user:create", + {"test": "test"}, + options=PushEventOptions(additional_metadata={"hello": "moon"}), ) diff --git a/examples/bulk_fanout/bulk_trigger.py b/examples/bulk_fanout/bulk_trigger.py index d0606673..51905063 100644 --- a/examples/bulk_fanout/bulk_trigger.py +++ b/examples/bulk_fanout/bulk_trigger.py @@ -7,7 +7,7 @@ from dotenv import load_dotenv from hatchet_sdk import new_client -from hatchet_sdk.clients.admin import TriggerWorkflowOptions +from hatchet_sdk.clients.admin import TriggerWorkflowOptions, WorkflowRunDict from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun from hatchet_sdk.clients.run_event_listener import StepRunEventType @@ -16,25 +16,22 @@ async def main() -> None: load_dotenv() hatchet = 
new_client() - workflowRuns: list[dict[str, Any]] = [] - - # we are going to run the BulkParent workflow 20 which will trigger the Child workflows n times for each n in range(20) - for i in range(20): - workflowRuns.append( - { - "workflow_name": "BulkParent", - "input": {"n": i}, - "options": { - "additional_metadata": { - "bulk-trigger": i, - "hello-{i}": "earth-{i}", - }, - }, - } + workflow_runs = [ + WorkflowRunDict( + workflow_name="BulkParent", + input={"n": i}, + options=TriggerWorkflowOptions( + additional_metadata={ + "bulk-trigger": i, + "hello-{i}": "earth-{i}", + } + ), ) + for i in range(20) + ] workflowRunRefs = hatchet.admin.run_workflows( - workflowRuns, + workflow_runs, ) results = await asyncio.gather( diff --git a/examples/bulk_fanout/stream.py b/examples/bulk_fanout/stream.py index 08d0cb4a..c0d03388 100644 --- a/examples/bulk_fanout/stream.py +++ b/examples/bulk_fanout/stream.py @@ -6,10 +6,9 @@ from dotenv import load_dotenv -from hatchet_sdk import new_client +from hatchet_sdk import Hatchet, new_client from hatchet_sdk.clients.admin import TriggerWorkflowOptions from hatchet_sdk.clients.run_event_listener import StepRunEventType -from hatchet_sdk.v2.hatchet import Hatchet async def main() -> None: @@ -31,7 +30,7 @@ async def main() -> None: workflowRun = hatchet.admin.run_workflow( "Parent", {"n": 2}, - options={"additional_metadata": {streamKey: streamVal}}, + options=TriggerWorkflowOptions(additional_metadata={streamKey: streamVal}), ) # Stream all events for the additional meta key value diff --git a/examples/bulk_fanout/trigger.py b/examples/bulk_fanout/trigger.py index 1a1b3f17..fe44a627 100644 --- a/examples/bulk_fanout/trigger.py +++ b/examples/bulk_fanout/trigger.py @@ -7,6 +7,7 @@ from hatchet_sdk import new_client from hatchet_sdk.clients.admin import TriggerWorkflowOptions +from hatchet_sdk.clients.events import PushEventOptions from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun from 
hatchet_sdk.clients.run_event_listener import StepRunEventType @@ -15,10 +16,10 @@ async def main() -> None: load_dotenv() hatchet = new_client() - workflowRuns: WorkflowRun = [] # type: ignore[assignment] - - event = hatchet.event.push( - "parent:create", {"n": 999}, {"additional_metadata": {"no-dedupe": "world"}} + hatchet.event.push( + "parent:create", + {"n": 999}, + PushEventOptions(additional_metadata={"no-dedupe": "world"}), ) diff --git a/examples/bulk_fanout/worker.py b/examples/bulk_fanout/worker.py index e0ea3c50..baf9953b 100644 --- a/examples/bulk_fanout/worker.py +++ b/examples/bulk_fanout/worker.py @@ -4,7 +4,7 @@ from dotenv import load_dotenv from hatchet_sdk import Context, Hatchet -from hatchet_sdk.clients.admin import ChildWorkflowRunDict +from hatchet_sdk.clients.admin import ChildTriggerWorkflowOptions, ChildWorkflowRunDict load_dotenv() @@ -22,18 +22,17 @@ async def spawn(self, context: Context) -> dict[str, list[Any]]: n = context.workflow_input().get("n", 100) - child_workflow_runs: list[ChildWorkflowRunDict] = [] - - for i in range(n): - - child_workflow_runs.append( - { - "workflow_name": "BulkChild", - "input": {"a": str(i)}, - "key": f"child{i}", - "options": {"additional_metadata": {"hello": "earth"}}, - } + child_workflow_runs = [ + ChildWorkflowRunDict( + workflow_name="BulkChild", + input={"a": str(i)}, + key=f"child{i}", + options=ChildTriggerWorkflowOptions( + additional_metadata={"hello": "earth"} + ), ) + for i in range(n) + ] if len(child_workflow_runs) == 0: return {} diff --git a/examples/dedupe/worker.py b/examples/dedupe/worker.py index 2f22f52d..6e5c1f02 100644 --- a/examples/dedupe/worker.py +++ b/examples/dedupe/worker.py @@ -4,7 +4,7 @@ from dotenv import load_dotenv -from hatchet_sdk import Context, Hatchet +from hatchet_sdk import ChildTriggerWorkflowOptions, Context, Hatchet from hatchet_sdk.clients.admin import DedupeViolationErr from hatchet_sdk.loader import ClientConfig @@ -29,7 +29,9 @@ async def spawn(self, 
context: Context) -> dict[str, list[Any]]: "DedupeChild", {"a": str(i)}, key=f"child{i}", - options={"additional_metadata": {"dedupe": "test"}}, + options=ChildTriggerWorkflowOptions( + additional_metadata={"dedupe": "test"} + ), ) ).result() ) diff --git a/examples/default_priority/worker.py b/examples/default_priority/worker.py deleted file mode 100644 index 070d20f9..00000000 --- a/examples/default_priority/worker.py +++ /dev/null @@ -1,40 +0,0 @@ -import asyncio -from typing import TypedDict - -from dotenv import load_dotenv - -from hatchet_sdk import Context -from hatchet_sdk.v2.hatchet import Hatchet - -load_dotenv() - -hatchet = Hatchet(debug=True) - - -class MyResultType(TypedDict): - return_string: str - - -@hatchet.function(default_priority=2) -async def high_prio_func(context: Context) -> MyResultType: - await asyncio.sleep(5) - return MyResultType(return_string="High Priority Return") - - -@hatchet.function(default_priority=1) -async def low_prio_func(context: Context) -> MyResultType: - await asyncio.sleep(5) - return MyResultType(return_string="Low Priority Return") - - -def main() -> None: - worker = hatchet.worker("example-priority-worker", max_runs=1) - hatchet.admin.run(high_prio_func, {"test": "test"}) - hatchet.admin.run(high_prio_func, {"test": "test"}) - hatchet.admin.run(low_prio_func, {"test": "test"}) - hatchet.admin.run(low_prio_func, {"test": "test"}) - worker.start() - - -if __name__ == "__main__": - main() diff --git a/examples/durable_sticky_with_affinity/worker.py b/examples/durable_sticky_with_affinity/worker.py deleted file mode 100644 index 93505d98..00000000 --- a/examples/durable_sticky_with_affinity/worker.py +++ /dev/null @@ -1,72 +0,0 @@ -import asyncio -from typing import Any - -from dotenv import load_dotenv - -from hatchet_sdk import Context, StickyStrategy, WorkerLabelComparator -from hatchet_sdk.v2.callable import DurableContext -from hatchet_sdk.v2.hatchet import Hatchet - -load_dotenv() - -hatchet = Hatchet(debug=True) 
- - -@hatchet.durable( - sticky=StickyStrategy.HARD, - desired_worker_labels={ - "running_workflow": { - "value": "True", - "required": True, - "comparator": WorkerLabelComparator.NOT_EQUAL, - }, - }, -) -async def my_durable_func(context: DurableContext) -> dict[str, Any]: - try: - ref = await context.aio.spawn_workflow( - "StickyChildWorkflow", {}, options={"sticky": True} - ) - result = await ref.result() - except Exception as e: - result = str(e) - - await context.worker.async_upsert_labels({"running_workflow": "False"}) - return {"worker_result": result} - - -@hatchet.workflow(on_events=["sticky:child"], sticky=StickyStrategy.HARD) -class StickyChildWorkflow: - @hatchet.step( - desired_worker_labels={ - "running_workflow": { - "value": "True", - "required": True, - "comparator": WorkerLabelComparator.NOT_EQUAL, - }, - }, - ) - async def child(self, context: Context) -> dict[str, str | None]: - await context.worker.async_upsert_labels({"running_workflow": "True"}) - - print(f"Heavy work started on {context.worker.id()}") - await asyncio.sleep(15) - print(f"Finished Heavy work on {context.worker.id()}") - - return {"worker": context.worker.id()} - - -def main() -> None: - worker = hatchet.worker( - "sticky-worker", - max_runs=10, - labels={"running_workflow": "False"}, - ) - - worker.register_workflow(StickyChildWorkflow()) - - worker.start() - - -if __name__ == "__main__": - main() diff --git a/examples/events/test_event.py b/examples/events/test_event.py index a4fca8ae..5fd920e6 100644 --- a/examples/events/test_event.py +++ b/examples/events/test_event.py @@ -24,34 +24,34 @@ async def test_async_event_push(aiohatchet: Hatchet) -> None: @pytest.mark.asyncio(scope="session") async def test_async_event_bulk_push(aiohatchet: Hatchet) -> None: - events: List[BulkPushEventWithMetadata] = [ - { - "key": "event1", - "payload": {"message": "This is event 1"}, - "additional_metadata": {"source": "test", "user_id": "user123"}, - }, - { - "key": "event2", - "payload": 
{"message": "This is event 2"}, - "additional_metadata": {"source": "test", "user_id": "user456"}, - }, - { - "key": "event3", - "payload": {"message": "This is event 3"}, - "additional_metadata": {"source": "test", "user_id": "user789"}, - }, + events = [ + BulkPushEventWithMetadata( + key="event1", + payload={"message": "This is event 1"}, + additional_metadata={"source": "test", "user_id": "user123"}, + ), + BulkPushEventWithMetadata( + key="event2", + payload={"message": "This is event 2"}, + additional_metadata={"source": "test", "user_id": "user456"}, + ), + BulkPushEventWithMetadata( + key="event3", + payload={"message": "This is event 3"}, + additional_metadata={"source": "test", "user_id": "user789"}, + ), ] - opts: BulkPushEventOptions = {"namespace": "bulk-test"} + opts = BulkPushEventOptions(namespace="bulk-test") e = await aiohatchet.event.async_bulk_push(events, opts) assert len(e) == 3 # Sort both lists of events by their key to ensure comparison order - sorted_events = sorted(events, key=lambda x: x["key"]) + sorted_events = sorted(events, key=lambda x: x.key) sorted_returned_events = sorted(e, key=lambda x: x.key) namespace = "bulk-test" # Check that the returned events match the original events for original_event, returned_event in zip(sorted_events, sorted_returned_events): - assert returned_event.key == namespace + original_event["key"] + assert returned_event.key == namespace + original_event.key diff --git a/examples/fanout/stream.py b/examples/fanout/stream.py index 08d0cb4a..c0d03388 100644 --- a/examples/fanout/stream.py +++ b/examples/fanout/stream.py @@ -6,10 +6,9 @@ from dotenv import load_dotenv -from hatchet_sdk import new_client +from hatchet_sdk import Hatchet, new_client from hatchet_sdk.clients.admin import TriggerWorkflowOptions from hatchet_sdk.clients.run_event_listener import StepRunEventType -from hatchet_sdk.v2.hatchet import Hatchet async def main() -> None: @@ -31,7 +30,7 @@ async def main() -> None: workflowRun = 
hatchet.admin.run_workflow( "Parent", {"n": 2}, - options={"additional_metadata": {streamKey: streamVal}}, + options=TriggerWorkflowOptions(additional_metadata={streamKey: streamVal}), ) # Stream all events for the additional meta key value diff --git a/examples/fanout/sync_stream.py b/examples/fanout/sync_stream.py index 0b1a7140..d035ddc3 100644 --- a/examples/fanout/sync_stream.py +++ b/examples/fanout/sync_stream.py @@ -6,10 +6,9 @@ from dotenv import load_dotenv -from hatchet_sdk import new_client +from hatchet_sdk import Hatchet from hatchet_sdk.clients.admin import TriggerWorkflowOptions from hatchet_sdk.clients.run_event_listener import StepRunEventType -from hatchet_sdk.v2.hatchet import Hatchet def main() -> None: @@ -31,7 +30,7 @@ def main() -> None: workflowRun = hatchet.admin.run_workflow( "Parent", {"n": 2}, - options={"additional_metadata": {streamKey: streamVal}}, + options=TriggerWorkflowOptions(additional_metadata={streamKey: streamVal}), ) # Stream all events for the additional meta key value diff --git a/examples/fanout/trigger.py b/examples/fanout/trigger.py index c34d01b3..e156322c 100644 --- a/examples/fanout/trigger.py +++ b/examples/fanout/trigger.py @@ -1,13 +1,9 @@ import asyncio -import base64 -import json -import os from dotenv import load_dotenv from hatchet_sdk import new_client from hatchet_sdk.clients.admin import TriggerWorkflowOptions -from hatchet_sdk.clients.run_event_listener import StepRunEventType async def main() -> None: @@ -17,7 +13,7 @@ async def main() -> None: hatchet.admin.run_workflow( "Parent", {"test": "test"}, - options={"additional_metadata": {"hello": "moon"}}, + options=TriggerWorkflowOptions(additional_metadata={"hello": "moon"}), ) diff --git a/examples/fanout/worker.py b/examples/fanout/worker.py index c32344c9..0a678937 100644 --- a/examples/fanout/worker.py +++ b/examples/fanout/worker.py @@ -3,7 +3,7 @@ from dotenv import load_dotenv -from hatchet_sdk import Context, Hatchet +from hatchet_sdk import 
ChildTriggerWorkflowOptions, Context, Hatchet load_dotenv() @@ -28,7 +28,9 @@ async def spawn(self, context: Context) -> dict[str, Any]: "Child", {"a": str(i)}, key=f"child{i}", - options={"additional_metadata": {"hello": "earth"}}, + options=ChildTriggerWorkflowOptions( + additional_metadata={"hello": "earth"} + ), ) ).result() ) diff --git a/examples/logger/event.py b/examples/logger/event.py index 5f7818f6..3dcc4fcc 100644 --- a/examples/logger/event.py +++ b/examples/logger/event.py @@ -7,5 +7,7 @@ client = new_client() client.event.push( - "user:create", {"test": "test"}, options={"additional_metadata": {"hello": "moon"}} + "user:create", + {"test": "test"}, + options=PushEventOptions(additional_metadata={"hello": "moon"}), ) diff --git a/examples/manual_trigger/stream.py b/examples/manual_trigger/stream.py index bc4adfab..c05a73b6 100644 --- a/examples/manual_trigger/stream.py +++ b/examples/manual_trigger/stream.py @@ -17,7 +17,7 @@ async def main() -> None: workflowRun = hatchet.admin.run_workflow( "ManualTriggerWorkflow", {"test": "test"}, - options={"additional_metadata": {"hello": "moon"}}, + options=TriggerWorkflowOptions(additional_metadata={"hello": "moon"}), ) listener = workflowRun.stream() diff --git a/examples/simple/event.py b/examples/simple/event.py index c2d0178a..68b70a85 100644 --- a/examples/simple/event.py +++ b/examples/simple/event.py @@ -1,9 +1,11 @@ -from typing import List - from dotenv import load_dotenv from hatchet_sdk import new_client -from hatchet_sdk.clients.events import BulkPushEventWithMetadata +from hatchet_sdk.clients.events import ( + BulkPushEventOptions, + BulkPushEventWithMetadata, + PushEventOptions, +) load_dotenv() @@ -11,31 +13,32 @@ # client.event.push("user:create", {"test": "test"}) client.event.push( - "user:create", {"test": "test"}, options={"additional_metadata": {"hello": "moon"}} + "user:create", + {"test": "test"}, + options=PushEventOptions(additional_metadata={"hello": "moon"}), ) -events: 
List[BulkPushEventWithMetadata] = [ - { - "key": "event1", - "payload": {"message": "This is event 1"}, - "additional_metadata": {"source": "test", "user_id": "user123"}, - }, - { - "key": "event2", - "payload": {"message": "This is event 2"}, - "additional_metadata": {"source": "test", "user_id": "user456"}, - }, - { - "key": "event3", - "payload": {"message": "This is event 3"}, - "additional_metadata": {"source": "test", "user_id": "user789"}, - }, +events = [ + BulkPushEventWithMetadata( + key="event1", + payload={"message": "This is event 1"}, + additional_metadata={"source": "test", "user_id": "user123"}, + ), + BulkPushEventWithMetadata( + key="event2", + payload={"message": "This is event 2"}, + additional_metadata={"source": "test", "user_id": "user456"}, + ), + BulkPushEventWithMetadata( + key="event3", + payload={"message": "This is event 3"}, + additional_metadata={"source": "test", "user_id": "user789"}, + ), ] result = client.event.bulk_push( - events, - options={"namespace": "bulk-test"}, + events, options=BulkPushEventOptions(namespace="bulk-test") ) print(result) diff --git a/examples/sticky_workers/event.py b/examples/sticky_workers/event.py index 55ed9b8f..67855b49 100644 --- a/examples/sticky_workers/event.py +++ b/examples/sticky_workers/event.py @@ -1,5 +1,6 @@ from dotenv import load_dotenv +from hatchet_sdk.clients.events import PushEventOptions from hatchet_sdk.hatchet import Hatchet load_dotenv() @@ -10,5 +11,5 @@ hatchet.event.push( "sticky:parent", {"test": "test"}, - options={"additional_metadata": {"hello": "moon"}}, + options=PushEventOptions(additional_metadata={"hello": "moon"}), ) diff --git a/examples/sticky_workers/worker.py b/examples/sticky_workers/worker.py index abedb820..bc681a2a 100644 --- a/examples/sticky_workers/worker.py +++ b/examples/sticky_workers/worker.py @@ -1,6 +1,6 @@ from dotenv import load_dotenv -from hatchet_sdk import Context, Hatchet, StickyStrategy +from hatchet_sdk import ChildTriggerWorkflowOptions, 
Context, Hatchet, StickyStrategy from hatchet_sdk.context.context import ContextAioImpl load_dotenv() @@ -21,7 +21,7 @@ def step1b(self, context: Context) -> dict[str, str | None]: @hatchet.step(parents=["step1a", "step1b"]) async def step2(self, context: ContextAioImpl) -> dict[str, str | None]: ref = await context.spawn_workflow( - "StickyChildWorkflow", {}, options={"sticky": True} + "StickyChildWorkflow", {}, options=ChildTriggerWorkflowOptions(sticky=True) ) await ref.result() diff --git a/examples/sync_to_async/worker.py b/examples/sync_to_async/worker.py deleted file mode 100644 index 5ac3a912..00000000 --- a/examples/sync_to_async/worker.py +++ /dev/null @@ -1,98 +0,0 @@ -import asyncio -import os -import time -from typing import Any - -from dotenv import load_dotenv - -from hatchet_sdk import Context, sync_to_async -from hatchet_sdk.v2.hatchet import Hatchet - -os.environ["PYTHONASYNCIODEBUG"] = "1" -load_dotenv() - -hatchet = Hatchet(debug=True) - - -@hatchet.function() -async def fanout_sync_async(context: Context) -> dict[str, Any]: - print("spawning child") - - context.put_stream("spawning...") - results = [] - - n = context.workflow_input().get("n", 10) - - start_time = time.time() - for i in range(n): - results.append( - ( - await context.aio.spawn_workflow( - "Child", - {"a": str(i)}, - key=f"child{i}", - options={"additional_metadata": {"hello": "earth"}}, - ) - ).result() - ) - - result = await asyncio.gather(*results) - - execution_time = time.time() - start_time - print(f"Completed in {execution_time:.2f} seconds") - - return {"results": result} - - -@hatchet.workflow(on_events=["child:create"]) -class Child: - ###### Example Functions ###### - def sync_blocking_function(self) -> dict[str, str]: - time.sleep(5) - return {"type": "sync_blocking"} - - @sync_to_async # this makes the function async safe! 
- def decorated_sync_blocking_function(self) -> dict[str, str]: - time.sleep(5) - return {"type": "decorated_sync_blocking"} - - @sync_to_async # this makes the async function loop safe! - async def async_blocking_function(self) -> dict[str, str]: - time.sleep(5) - return {"type": "async_blocking"} - - ###### Hatchet Steps ###### - @hatchet.step() - async def handle_blocking_sync_in_async(self, context: Context) -> dict[str, str]: - wrapped_blocking_function = sync_to_async(self.sync_blocking_function) - - # This will now be async safe! - data = await wrapped_blocking_function() - return {"blocking_status": "success", "data": data} - - @hatchet.step() - async def handle_decorated_blocking_sync_in_async( - self, context: Context - ) -> dict[str, str]: - data = await self.decorated_sync_blocking_function() - return {"blocking_status": "success", "data": data} - - @hatchet.step() - async def handle_blocking_async_in_async(self, context: Context) -> dict[str, str]: - data = await self.async_blocking_function() - return {"blocking_status": "success", "data": data} - - @hatchet.step() - async def non_blocking_async(self, context: Context) -> dict[str, str]: - await asyncio.sleep(5) - return {"nonblocking_status": "success"} - - -def main() -> None: - worker = hatchet.worker("fanout-worker", max_runs=50) - worker.register_workflow(Child()) - worker.start() - - -if __name__ == "__main__": - main() diff --git a/examples/v2/simple/test_v2_worker.py b/examples/v2/simple/test_v2_worker.py deleted file mode 100644 index c06dae1f..00000000 --- a/examples/v2/simple/test_v2_worker.py +++ /dev/null @@ -1,26 +0,0 @@ -import pytest - -from examples.v2.simple.worker import MyResultType, my_durable_func, my_func -from hatchet_sdk import Hatchet, Worker -from hatchet_sdk.workflow_run import RunRef - - -# requires scope module or higher for shared event loop -@pytest.mark.asyncio(scope="session") -@pytest.mark.parametrize("worker", ["v2_simple"], indirect=True) -async def 
test_durable(hatchet: Hatchet, worker: Worker) -> None: - durable_run: RunRef[dict[str, str]] = hatchet.admin.run( - my_durable_func, {"test": "test"} - ) - result = await durable_run.result() - - assert result == {"my_durable_func": "testing123"} - - -@pytest.mark.asyncio(scope="session") -@pytest.mark.parametrize("worker", ["v2_simple"], indirect=True) -async def test_func(hatchet: Hatchet, worker: Worker) -> None: - durable_run: RunRef[MyResultType] = hatchet.admin.run(my_func, {"test": "test"}) - result = await durable_run.result() - - assert result == {"my_func": "testing123"} diff --git a/examples/v2/simple/worker.py b/examples/v2/simple/worker.py deleted file mode 100644 index 215bdd0e..00000000 --- a/examples/v2/simple/worker.py +++ /dev/null @@ -1,44 +0,0 @@ -import json -import time -from typing import Any, TypedDict, cast - -from dotenv import load_dotenv - -from hatchet_sdk import Context -from hatchet_sdk.v2.callable import DurableContext -from hatchet_sdk.v2.hatchet import Hatchet -from hatchet_sdk.workflow_run import RunRef - -load_dotenv() - -hatchet = Hatchet(debug=True) - - -class MyResultType(TypedDict): - my_func: str - - -@hatchet.function() -def my_func(context: Context) -> MyResultType: - return MyResultType(my_func="testing123") - - -@hatchet.durable() -async def my_durable_func(context: DurableContext) -> dict[str, MyResultType | None]: - result = cast(dict[str, Any], await context.run(my_func, {"test": "test"}).result()) - - context.log(result) - - return {"my_durable_func": result.get("my_func")} - - -def main() -> None: - worker = hatchet.worker("test-worker", max_runs=5) - - hatchet.admin.run(my_durable_func, {"test": "test"}) - - worker.start() - - -if __name__ == "__main__": - main() diff --git a/hatchet_sdk/__init__.py b/hatchet_sdk/__init__.py index 3162c25c..fc81cced 100644 --- a/hatchet_sdk/__init__.py +++ b/hatchet_sdk/__init__.py @@ -137,8 +137,9 @@ from .clients.run_event_listener import StepRunEventType, WorkflowRunEventType 
from .context.context import Context from .context.worker_context import WorkerContext -from .hatchet import ClientConfig, Hatchet, concurrency, on_failure_step, step, workflow -from .worker import Worker, WorkerStartOptions, WorkerStatus +from .hatchet import Hatchet, concurrency, on_failure_step, step, workflow +from .loader import ClientConfig +from .worker.worker import Worker, WorkerStartOptions, WorkerStatus from .workflow import ConcurrencyExpression __all__ = [ diff --git a/hatchet_sdk/client.py b/hatchet_sdk/client.py index 45dfd394..f1715972 100644 --- a/hatchet_sdk/client.py +++ b/hatchet_sdk/client.py @@ -12,43 +12,34 @@ from .clients.dispatcher.dispatcher import DispatcherClient, new_dispatcher from .clients.events import EventClient, new_event from .clients.rest_client import RestApi -from .loader import ClientConfig, ConfigLoader +from .loader import ClientConfig class Client: - admin: AdminClient - dispatcher: DispatcherClient - event: EventClient - rest: RestApi - workflow_listener: PooledWorkflowRunListener - logInterceptor: Logger - debug: bool = False - @classmethod def from_environment( cls, defaults: ClientConfig = ClientConfig(), debug: bool = False, *opts_functions: Callable[[ClientConfig], None], - ): + ) -> "Client": try: loop = asyncio.get_running_loop() except RuntimeError: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - config: ClientConfig = ConfigLoader(".").load_client_config(defaults) for opt_function in opts_functions: - opt_function(config) + opt_function(defaults) - return cls.from_config(config, debug) + return cls.from_config(defaults, debug) @classmethod def from_config( cls, config: ClientConfig = ClientConfig(), debug: bool = False, - ): + ) -> "Client": try: loop = asyncio.get_running_loop() except RuntimeError: @@ -61,7 +52,7 @@ def from_config( if config.host_port is None: raise ValueError("Host and port are required") - conn: grpc.Channel = new_conn(config) + conn: grpc.Channel = new_conn(config, False) # 
Instantiate clients event_client = new_event(conn, config) @@ -85,7 +76,7 @@ def __init__( event_client: EventClient, admin_client: AdminClient, dispatcher_client: DispatcherClient, - workflow_listener: PooledWorkflowRunListener, + workflow_listener: PooledWorkflowRunListener | None, rest_client: RestApi, config: ClientConfig, debug: bool = False, @@ -103,17 +94,9 @@ def __init__( self.config = config self.listener = RunEventListenerClient(config) self.workflow_listener = workflow_listener - self.logInterceptor = config.logInterceptor + self.logInterceptor = config.logger self.debug = debug -def with_host_port(host: str, port: int): - def with_host_port_impl(config: ClientConfig): - config.host = host - config.port = port - - return with_host_port_impl - - new_client = Client.from_environment new_client_raw = Client.from_config diff --git a/hatchet_sdk/clients/admin.py b/hatchet_sdk/clients/admin.py index 81662d1b..518bb38f 100644 --- a/hatchet_sdk/clients/admin.py +++ b/hatchet_sdk/clients/admin.py @@ -1,9 +1,10 @@ import json from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, TypedDict, TypeVar, Union +from typing import Any, Callable, TypeVar, Union, cast import grpc from google.protobuf import timestamp_pb2 +from pydantic import BaseModel, Field from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry @@ -30,6 +31,7 @@ inject_carrier_into_metadata, parse_carrier_from_metadata, ) +from hatchet_sdk.utils.types import JSONSerializableDict from hatchet_sdk.workflow_run import RunRef, WorkflowRunRef from ..loader import ClientConfig @@ -37,40 +39,40 @@ from ..workflow import WorkflowMeta -def new_admin(config: ClientConfig): +def new_admin(config: ClientConfig) -> "AdminClient": return AdminClient(config) -class ScheduleTriggerWorkflowOptions(TypedDict, total=False): - parent_id: Optional[str] - parent_step_run_id: Optional[str] - child_index: 
Optional[int] - child_key: Optional[str] - namespace: Optional[str] +class ScheduleTriggerWorkflowOptions(BaseModel): + parent_id: str | None = None + parent_step_run_id: str | None = None + child_index: int | None = None + child_key: str | None = None + namespace: str | None = None -class ChildTriggerWorkflowOptions(TypedDict, total=False): - additional_metadata: Dict[str, str] | None = None +class ChildTriggerWorkflowOptions(BaseModel): + additional_metadata: JSONSerializableDict = Field(default_factory=dict) sticky: bool | None = None -class ChildWorkflowRunDict(TypedDict, total=False): +class ChildWorkflowRunDict(BaseModel): workflow_name: str - input: Any + input: JSONSerializableDict options: ChildTriggerWorkflowOptions key: str | None = None -class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, total=False): - additional_metadata: Dict[str, str] | None = None +class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions): + additional_metadata: JSONSerializableDict = Field(default_factory=dict) desired_worker_id: str | None = None namespace: str | None = None -class WorkflowRunDict(TypedDict, total=False): +class WorkflowRunDict(BaseModel): workflow_name: str - input: Any - options: TriggerWorkflowOptions | None + input: JSONSerializableDict + options: TriggerWorkflowOptions class DedupeViolationErr(Exception): @@ -83,24 +85,21 @@ class AdminClientBase: pooled_workflow_listener: PooledWorkflowRunListener | None = None def _prepare_workflow_request( - self, workflow_name: str, input: any, options: TriggerWorkflowOptions = None - ): + self, workflow_name: str, input: dict[str, Any], options: TriggerWorkflowOptions + ) -> TriggerWorkflowRequest: try: payload_data = json.dumps(input) + _options = options.model_dump() try: - meta = ( - None - if options is None or "additional_metadata" not in options - else options["additional_metadata"] - ) - if meta is not None: - options["additional_metadata"] = json.dumps(meta).encode("utf-8") + 
_options["additional_metadata"] = json.dumps( + options.additional_metadata + ).encode("utf-8") except json.JSONDecodeError as e: raise ValueError(f"Error encoding payload: {e}") return TriggerWorkflowRequest( - name=workflow_name, input=payload_data, **(options or {}) + name=workflow_name, input=payload_data, **_options ) except json.JSONDecodeError as e: raise ValueError(f"Error encoding payload: {e}") @@ -110,14 +109,14 @@ def _prepare_put_workflow_request( name: str, workflow: CreateWorkflowVersionOpts | WorkflowMeta, overrides: CreateWorkflowVersionOpts | None = None, - ): + ) -> PutWorkflowRequest: try: opts: CreateWorkflowVersionOpts if isinstance(workflow, CreateWorkflowVersionOpts): opts = workflow else: - opts = workflow.get_create_opts(self.client.config.namespace) + opts = workflow.get_create_opts(self.client.config.namespace) # type: ignore[attr-defined] if overrides is not None: opts.MergeFrom(overrides) @@ -133,10 +132,10 @@ def _prepare_put_workflow_request( def _prepare_schedule_workflow_request( self, name: str, - schedules: List[Union[datetime, timestamp_pb2.Timestamp]], - input={}, - options: ScheduleTriggerWorkflowOptions = None, - ): + schedules: list[Union[datetime, timestamp_pb2.Timestamp]], + input: JSONSerializableDict = {}, + options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(), + ) -> ScheduleWorkflowRequest: timestamp_schedules = [] for schedule in schedules: if isinstance(schedule, datetime): @@ -156,7 +155,7 @@ def _prepare_schedule_workflow_request( name=name, schedules=timestamp_schedules, input=json.dumps(input), - **(options or {}), + **options.model_dump(), ) @@ -167,7 +166,7 @@ class AdminClientAioImpl(AdminClientBase): def __init__(self, config: ClientConfig): aio_conn = new_conn(config, True) self.config = config - self.aio_client = WorkflowServiceStub(aio_conn) + self.aio_client = WorkflowServiceStub(aio_conn) # type: ignore[no-untyped-call] self.token = config.token self.listener_client = 
new_listener(config) self.namespace = config.namespace @@ -176,13 +175,17 @@ def __init__(self, config: ClientConfig): async def run( self, function: Union[str, Callable[[Any], T]], - input: any, - options: TriggerWorkflowOptions = None, + input: JSONSerializableDict, + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> "RunRef[T]": - workflow_name = function - - if not isinstance(function, str): - workflow_name = function.function_name + workflow_name = cast( + str, + ( + function + if isinstance(function, str) + else getattr(function, "function_name") + ), + ) wrr = await self.run_workflow(workflow_name, input, options) @@ -192,11 +195,12 @@ async def run( @tenacity_retry async def run_workflow( - self, workflow_name: str, input: any, options: TriggerWorkflowOptions = None + self, + workflow_name: str, + input: JSONSerializableDict, + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> WorkflowRunRef: - ctx = parse_carrier_from_metadata( - (options or {}).get("additional_metadata", {}) - ) + ctx = parse_carrier_from_metadata(options.additional_metadata) with self.otel_tracer.start_as_current_span( f"hatchet.async_run_workflow.{workflow_name}", context=ctx @@ -209,27 +213,18 @@ async def run_workflow( self.config ) - namespace = self.namespace - - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options.pop("namespace") + namespace = options.namespace or self.namespace if namespace != "" and not workflow_name.startswith(self.namespace): workflow_name = f"{namespace}{workflow_name}" - if options is not None and "additional_metadata" in options: - options["additional_metadata"] = inject_carrier_into_metadata( - options["additional_metadata"], carrier - ) - span.set_attributes( - flatten( - options["additional_metadata"], parent_key="", separator="." 
- ) - ) + options.additional_metadata = inject_carrier_into_metadata( + options.additional_metadata, carrier + ) + + span.set_attributes( + flatten(options.additional_metadata, parent_key="", separator=".") + ) request = self._prepare_workflow_request(workflow_name, input, options) @@ -262,30 +257,22 @@ async def run_workflow( async def run_workflows( self, workflows: list[WorkflowRunDict], - options: TriggerWorkflowOptions | None = None, - ) -> List[WorkflowRunRef]: + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), + ) -> list[WorkflowRunRef]: if len(workflows) == 0: raise ValueError("No workflows to run") if not self.pooled_workflow_listener: self.pooled_workflow_listener = PooledWorkflowRunListener(self.config) - namespace = self.namespace - - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options["namespace"] - del options["namespace"] + namespace = options.namespace or self.namespace - workflow_run_requests: TriggerWorkflowRequest = [] + workflow_run_requests: list[TriggerWorkflowRequest] = [] for workflow in workflows: - workflow_name = workflow["workflow_name"] - input_data = workflow["input"] - options = workflow["options"] + workflow_name = workflow.workflow_name + input_data = workflow.input + options = workflow.options if namespace != "" and not workflow_name.startswith(self.namespace): workflow_name = f"{namespace}{workflow_name}" @@ -294,10 +281,8 @@ async def run_workflows( request = self._prepare_workflow_request(workflow_name, input_data, options) workflow_run_requests.append(request) - request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests) - resp: BulkTriggerWorkflowResponse = await self.aio_client.BulkTriggerWorkflow( - request, + BulkTriggerWorkflowRequest(workflows=workflow_run_requests), metadata=get_metadata(self.token), ) @@ -319,9 +304,12 @@ async def put_workflow( ) -> WorkflowVersion: opts = self._prepare_put_workflow_request(name, workflow, 
overrides) - return await self.aio_client.PutWorkflow( - opts, - metadata=get_metadata(self.token), + return cast( + WorkflowVersion, + await self.aio_client.PutWorkflow( + opts, + metadata=get_metadata(self.token), + ), ) @tenacity_retry @@ -330,7 +318,7 @@ async def put_rate_limit( key: str, limit: int, duration: RateLimitDuration = RateLimitDuration.SECOND, - ): + ) -> None: await self.aio_client.PutRateLimit( PutRateLimitRequest( key=key, @@ -344,20 +332,12 @@ async def put_rate_limit( async def schedule_workflow( self, name: str, - schedules: List[Union[datetime, timestamp_pb2.Timestamp]], - input={}, - options: ScheduleTriggerWorkflowOptions = None, + schedules: list[Union[datetime, timestamp_pb2.Timestamp]], + input: JSONSerializableDict = {}, + options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(), ) -> WorkflowVersion: try: - namespace = self.namespace - - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options["namespace"] - del options["namespace"] + namespace = options.namespace or self.namespace if namespace != "" and not name.startswith(self.namespace): name = f"{namespace}{name}" @@ -366,9 +346,12 @@ async def schedule_workflow( name, schedules, input, options ) - return await self.aio_client.ScheduleWorkflow( - request, - metadata=get_metadata(self.token), + return cast( + WorkflowVersion, + await self.aio_client.ScheduleWorkflow( + request, + metadata=get_metadata(self.token), + ), ) except (grpc.aio.AioRpcError, grpc.RpcError) as e: if e.code() == grpc.StatusCode.ALREADY_EXISTS: @@ -379,9 +362,9 @@ async def schedule_workflow( class AdminClient(AdminClientBase): def __init__(self, config: ClientConfig): - conn = new_conn(config) + conn = new_conn(config, False) self.config = config - self.client = WorkflowServiceStub(conn) + self.client = WorkflowServiceStub(conn) # type: ignore[no-untyped-call] self.aio = AdminClientAioImpl(config) self.token = config.token 
self.listener_client = new_listener(config) @@ -409,8 +392,8 @@ def put_rate_limit( self, key: str, limit: int, - duration: Union[RateLimitDuration.Value, str] = RateLimitDuration.SECOND, - ): + duration: Union[RateLimitDuration, str] = RateLimitDuration.SECOND, + ) -> None: self.client.PutRateLimit( PutRateLimitRequest( key=key, @@ -424,20 +407,12 @@ def put_rate_limit( def schedule_workflow( self, name: str, - schedules: List[Union[datetime, timestamp_pb2.Timestamp]], - input={}, - options: ScheduleTriggerWorkflowOptions = None, + schedules: list[Union[datetime, timestamp_pb2.Timestamp]], + input: JSONSerializableDict = {}, + options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(), ) -> WorkflowVersion: try: - namespace = self.namespace - - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options["namespace"] - del options["namespace"] + namespace = options.namespace or self.namespace if namespace != "" and not name.startswith(self.namespace): name = f"{namespace}{name}" @@ -446,9 +421,12 @@ def schedule_workflow( name, schedules, input, options ) - return self.client.ScheduleWorkflow( - request, - metadata=get_metadata(self.token), + return cast( + WorkflowVersion, + self.client.ScheduleWorkflow( + request, + metadata=get_metadata(self.token), + ), ) except (grpc.RpcError, grpc.aio.AioRpcError) as e: if e.code() == grpc.StatusCode.ALREADY_EXISTS: @@ -460,11 +438,12 @@ def schedule_workflow( ## TODO: `any` type hint should come from `typing` @tenacity_retry def run_workflow( - self, workflow_name: str, input: any, options: TriggerWorkflowOptions = None + self, + workflow_name: str, + input: JSONSerializableDict, + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> WorkflowRunRef: - ctx = parse_carrier_from_metadata( - (options or {}).get("additional_metadata", {}) - ) + ctx = parse_carrier_from_metadata(options.additional_metadata) with 
self.otel_tracer.start_as_current_span( f"hatchet.run_workflow.{workflow_name}", context=ctx @@ -477,26 +456,15 @@ def run_workflow( self.config ) - namespace = self.namespace - - ## TODO: Factor this out - it's repeated a lot of places - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options.pop("namespace") + namespace = options.namespace or self.namespace - if options is not None and "additional_metadata" in options: - options["additional_metadata"] = inject_carrier_into_metadata( - options["additional_metadata"], carrier - ) + options.additional_metadata = inject_carrier_into_metadata( + options.additional_metadata, carrier + ) - span.set_attributes( - flatten( - options["additional_metadata"], parent_key="", separator="." - ) - ) + span.set_attributes( + flatten(options.additional_metadata, parent_key="", separator=".") + ) if namespace != "" and not workflow_name.startswith(self.namespace): workflow_name = f"{namespace}{workflow_name}" @@ -530,39 +498,33 @@ def run_workflow( @tenacity_retry def run_workflows( - self, workflows: List[WorkflowRunDict], options: TriggerWorkflowOptions = None + self, + workflows: list[WorkflowRunDict], + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> list[WorkflowRunRef]: - workflow_run_requests: TriggerWorkflowRequest = [] + workflow_run_requests: list[TriggerWorkflowRequest] = [] if not self.pooled_workflow_listener: self.pooled_workflow_listener = PooledWorkflowRunListener(self.config) for workflow in workflows: - workflow_name = workflow["workflow_name"] - input_data = workflow["input"] - options = workflow["options"] + workflow_name = workflow.workflow_name + input_data = workflow.input + options = workflow.options - namespace = self.namespace + namespace = options.namespace or self.namespace - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options["namespace"] - del 
options["namespace"] + if namespace != "" and not workflow_name.startswith(self.namespace): + workflow_name = f"{namespace}{workflow_name}" - if namespace != "" and not workflow_name.startswith(self.namespace): - workflow_name = f"{namespace}{workflow_name}" + # Prepare and trigger workflow for each workflow name and input + request = self._prepare_workflow_request(workflow_name, input_data, options) - # Prepare and trigger workflow for each workflow name and input - request = self._prepare_workflow_request(workflow_name, input_data, options) + workflow_run_requests.append(request) - workflow_run_requests.append(request) - - request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests) + bulk_request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests) resp: BulkTriggerWorkflowResponse = self.client.BulkTriggerWorkflow( - request, + bulk_request, metadata=get_metadata(self.token), ) @@ -578,13 +540,17 @@ def run_workflows( def run( self, function: Union[str, Callable[[Any], T]], - input: any, - options: TriggerWorkflowOptions = None, + input: JSONSerializableDict, + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> "RunRef[T]": - workflow_name = function - - if not isinstance(function, str): - workflow_name = function.function_name + workflow_name = cast( + str, + ( + function + if isinstance(function, str) + else getattr(function, "function_name") + ), + ) wrr = self.run_workflow(workflow_name, input, options) diff --git a/hatchet_sdk/clients/dispatcher/action_listener.py b/hatchet_sdk/clients/dispatcher/action_listener.py index fc2887bd..763c0e6a 100644 --- a/hatchet_sdk/clients/dispatcher/action_listener.py +++ b/hatchet_sdk/clients/dispatcher/action_listener.py @@ -2,10 +2,11 @@ import json import time from dataclasses import dataclass, field -from typing import Any, AsyncGenerator, List, Optional +from typing import Any, AsyncGenerator, AsyncIterable, AsyncIterator, Optional, cast import grpc -from grpc._cython import cygrpc 
+import grpc.aio +from grpc._cython import cygrpc # type: ignore[attr-defined] from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt from hatchet_sdk.clients.run_event_listener import ( @@ -24,6 +25,7 @@ from hatchet_sdk.logger import logger from hatchet_sdk.utils.backoff import exp_backoff_sleep from hatchet_sdk.utils.serialization import flatten +from hatchet_sdk.utils.types import JSONSerializableDict from ...loader import ClientConfig from ...metadata import get_metadata @@ -39,14 +41,14 @@ @dataclass class GetActionListenerRequest: worker_name: str - services: List[str] - actions: List[str] + services: list[str] + actions: list[str] max_runs: Optional[int] = None _labels: dict[str, str | int] = field(default_factory=dict) labels: dict[str, WorkerLabels] = field(init=False) - def __post_init__(self): + def __post_init__(self) -> None: self.labels = {} for key, value in self._labels.items(): @@ -68,16 +70,16 @@ class Action: step_id: str step_run_id: str action_id: str - action_payload: str action_type: ActionType retry_count: int - additional_metadata: dict[str, str] | None = None + action_payload: JSONSerializableDict = field(default_factory=dict) + additional_metadata: JSONSerializableDict = field(default_factory=dict) child_workflow_index: int | None = None child_workflow_key: str | None = None parent_workflow_run_id: str | None = None - def __post_init__(self): + def __post_init__(self) -> None: if isinstance(self.additional_metadata, str) and self.additional_metadata != "": try: self.additional_metadata = json.loads(self.additional_metadata) @@ -113,11 +115,6 @@ def otel_attributes(self) -> dict[str, Any]: ) -START_STEP_RUN = 0 -CANCEL_STEP_RUN = 1 -START_GET_GROUP_KEY = 2 - - @dataclass class ActionListener: config: ClientConfig @@ -130,22 +127,22 @@ class ActionListener: last_connection_attempt: float = field(default=0, init=False) last_heartbeat_succeeded: bool = field(default=True, init=False) time_last_hb_succeeded: float = 
field(default=9999999999999, init=False) - heartbeat_task: Optional[asyncio.Task] = field(default=None, init=False) + heartbeat_task: Optional[asyncio.Task[None]] = field(default=None, init=False) run_heartbeat: bool = field(default=True, init=False) listen_strategy: str = field(default="v2", init=False) stop_signal: bool = field(default=False, init=False) missed_heartbeats: int = field(default=0, init=False) - def __post_init__(self): - self.client = DispatcherStub(new_conn(self.config)) - self.aio_client = DispatcherStub(new_conn(self.config, True)) + def __post_init__(self) -> None: + self.client = DispatcherStub(new_conn(self.config, False)) # type: ignore[no-untyped-call] + self.aio_client = DispatcherStub(new_conn(self.config, True)) # type: ignore[no-untyped-call] self.token = self.config.token - def is_healthy(self): + def is_healthy(self) -> bool: return self.last_heartbeat_succeeded - async def heartbeat(self): + async def heartbeat(self) -> None: # send a heartbeat every 4 seconds heartbeat_delay = 4 @@ -205,7 +202,7 @@ async def heartbeat(self): break await asyncio.sleep(heartbeat_delay) - async def start_heartbeater(self): + async def start_heartbeater(self) -> None: if self.heartbeat_task is not None: return @@ -219,10 +216,10 @@ async def start_heartbeater(self): raise e self.heartbeat_task = loop.create_task(self.heartbeat()) - def __aiter__(self): + def __aiter__(self) -> AsyncGenerator[Action | None, None]: return self._generator() - async def _generator(self) -> AsyncGenerator[Action, None]: + async def _generator(self) -> AsyncGenerator[Action | None, None]: listener = None while not self.stop_signal: @@ -238,6 +235,10 @@ async def _generator(self) -> AsyncGenerator[Action, None]: try: while not self.stop_signal: self.interrupt = Event_ts() + + if listener is None: + continue + t = asyncio.create_task( read_with_interrupt(listener, self.interrupt) ) @@ -250,7 +251,10 @@ async def _generator(self) -> AsyncGenerator[Action, None]: ) t.cancel() - 
listener.cancel() + + if listener: + listener.cancel() + break assigned_action = t.result() @@ -260,20 +264,23 @@ async def _generator(self) -> AsyncGenerator[Action, None]: break self.retries = 0 - assigned_action: AssignedAction # Process the received action - action_type = self.map_action_type(assigned_action.actionType) + action_type = assigned_action.actionType - if ( - assigned_action.actionPayload is None - or assigned_action.actionPayload == "" - ): - action_payload = None - else: - action_payload = self.parse_action_payload( - assigned_action.actionPayload + action_payload = ( + {} + if not assigned_action.actionPayload + else self.parse_action_payload(assigned_action.actionPayload) + ) + + try: + additional_metadata = cast( + dict[str, Any], + json.loads(assigned_action.additional_metadata), ) + except json.JSONDecodeError: + additional_metadata = {} action = Action( tenant_id=assigned_action.tenantId, @@ -289,7 +296,7 @@ async def _generator(self) -> AsyncGenerator[Action, None]: action_payload=action_payload, action_type=action_type, retry_count=assigned_action.retryCount, - additional_metadata=assigned_action.additional_metadata, + additional_metadata=additional_metadata, child_workflow_index=assigned_action.child_workflow_index, child_workflow_key=assigned_action.child_workflow_key, parent_workflow_run_id=assigned_action.parent_workflow_run_id, @@ -323,25 +330,15 @@ async def _generator(self) -> AsyncGenerator[Action, None]: self.retries = self.retries + 1 - def parse_action_payload(self, payload: str): + def parse_action_payload(self, payload: str) -> JSONSerializableDict: try: - payload_data = json.loads(payload) + return cast(JSONSerializableDict, json.loads(payload)) except json.JSONDecodeError as e: raise ValueError(f"Error decoding payload: {e}") - return payload_data - - def map_action_type(self, action_type): - if action_type == ActionType.START_STEP_RUN: - return START_STEP_RUN - elif action_type == ActionType.CANCEL_STEP_RUN: - return 
CANCEL_STEP_RUN - elif action_type == ActionType.START_GET_GROUP_KEY: - return START_GET_GROUP_KEY - else: - # logger.error(f"Unknown action type: {action_type}") - return None - async def get_listen_client(self): + async def get_listen_client( + self, + ) -> grpc.aio.UnaryStreamCall[WorkerListenRequest, AssignedAction]: current_time = int(time.time()) if ( @@ -369,7 +366,7 @@ async def get_listen_client(self): f"action listener connection interrupted, retrying... ({self.retries}/{DEFAULT_ACTION_LISTENER_RETRY_COUNT})" ) - self.aio_client = DispatcherStub(new_conn(self.config, True)) + self.aio_client = DispatcherStub(new_conn(self.config, True)) # type: ignore[no-untyped-call] if self.listen_strategy == "v2": # we should await for the listener to be established before @@ -390,11 +387,14 @@ async def get_listen_client(self): self.last_connection_attempt = current_time - return listener + return cast( + grpc.aio.UnaryStreamCall[WorkerListenRequest, AssignedAction], listener + ) - def cleanup(self): + def cleanup(self) -> None: self.run_heartbeat = False - self.heartbeat_task.cancel() + if self.heartbeat_task is not None: + self.heartbeat_task.cancel() try: self.unregister() @@ -404,9 +404,11 @@ def cleanup(self): if self.interrupt: self.interrupt.set() - def unregister(self): + def unregister(self) -> WorkerUnsubscribeRequest: self.run_heartbeat = False - self.heartbeat_task.cancel() + + if self.heartbeat_task is not None: + self.heartbeat_task.cancel() try: req = self.aio_client.Unsubscribe( @@ -416,6 +418,6 @@ def unregister(self): ) if self.interrupt is not None: self.interrupt.set() - return req + return cast(WorkerUnsubscribeRequest, req) except grpc.RpcError as e: raise Exception(f"Failed to unsubscribe: {e}") diff --git a/hatchet_sdk/clients/dispatcher/dispatcher.py b/hatchet_sdk/clients/dispatcher/dispatcher.py index e52aca2a..d9cae493 100644 --- a/hatchet_sdk/clients/dispatcher/dispatcher.py +++ b/hatchet_sdk/clients/dispatcher/dispatcher.py @@ -1,5 +1,6 @@ 
from typing import Any, cast +import grpc.aio from google.protobuf.timestamp_pb2 import Timestamp from hatchet_sdk.clients.dispatcher.action_listener import ( @@ -41,7 +42,7 @@ class DispatcherClient: config: ClientConfig def __init__(self, config: ClientConfig): - conn = new_conn(config) + conn = new_conn(config, False) self.client = DispatcherStub(conn) # type: ignore[no-untyped-call] aio_conn = new_conn(config, True) @@ -69,7 +70,7 @@ async def get_action_listener( async def send_step_action_event( self, action: Action, event_type: StepActionEventType, payload: str - ) -> Any: + ) -> grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse] | None: try: return await self._try_send_step_action_event(action, event_type, payload) except Exception as e: @@ -84,12 +85,12 @@ async def send_step_action_event( "Failed to send finished event: " + str(e), ) - return + return None @tenacity_retry async def _try_send_step_action_event( self, action: Action, event_type: StepActionEventType, payload: str - ) -> Any: + ) -> grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse]: eventTimestamp = Timestamp() eventTimestamp.GetCurrentTime() @@ -105,15 +106,17 @@ async def _try_send_step_action_event( eventPayload=payload, ) - ## TODO: What does this return? - return await self.aio_client.SendStepActionEvent( - event, - metadata=get_metadata(self.token), + return cast( + grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse], + await self.aio_client.SendStepActionEvent( + event, + metadata=get_metadata(self.token), + ), ) async def send_group_key_action_event( self, action: Action, event_type: GroupKeyActionEventType, payload: str - ) -> Any: + ) -> grpc.aio.UnaryUnaryCall[GroupKeyActionEvent, ActionEventResponse]: eventTimestamp = Timestamp() eventTimestamp.GetCurrentTime() @@ -128,9 +131,12 @@ async def send_group_key_action_event( ) ## TODO: What does this return? 
- return await self.aio_client.SendGroupKeyActionEvent( - event, - metadata=get_metadata(self.token), + return cast( + grpc.aio.UnaryUnaryCall[GroupKeyActionEvent, ActionEventResponse], + await self.aio_client.SendGroupKeyActionEvent( + event, + metadata=get_metadata(self.token), + ), ) def put_overrides_data(self, data: OverridesData) -> ActionEventResponse: diff --git a/hatchet_sdk/clients/event_ts.py b/hatchet_sdk/clients/event_ts.py index 1d3c3978..7a85d467 100644 --- a/hatchet_sdk/clients/event_ts.py +++ b/hatchet_sdk/clients/event_ts.py @@ -1,5 +1,8 @@ import asyncio -from typing import Any +from typing import Any, TypeVar, cast + +import grpc.aio +from grpc._cython import cygrpc # type: ignore[attr-defined] class Event_ts(asyncio.Event): @@ -7,22 +10,32 @@ class Event_ts(asyncio.Event): Event_ts is a subclass of asyncio.Event that allows for thread-safe setting and clearing of the event. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - if self._loop is None: + if self._loop is None: # type: ignore[has-type] self._loop = asyncio.get_event_loop() - def set(self): + def set(self) -> None: if not self._loop.is_closed(): self._loop.call_soon_threadsafe(super().set) - def clear(self): + def clear(self) -> None: self._loop.call_soon_threadsafe(super().clear) -async def read_with_interrupt(listener: Any, interrupt: Event_ts): +TRequest = TypeVar("TRequest") +TResponse = TypeVar("TResponse") + + +async def read_with_interrupt( + listener: grpc.aio.UnaryStreamCall[TRequest, TResponse], interrupt: Event_ts +) -> TResponse: try: result = await listener.read() - return result + + if result is cygrpc.EOF: + raise ValueError("Unexpected EOF") + + return cast(TResponse, result) finally: interrupt.set() diff --git a/hatchet_sdk/clients/events.py b/hatchet_sdk/clients/events.py index 160b780e..c8d5556e 100644 --- a/hatchet_sdk/clients/events.py +++ b/hatchet_sdk/clients/events.py @@ -1,11 
+1,12 @@ import asyncio import datetime import json -from typing import Any, Dict, List, Optional, TypedDict +from typing import Any, Dict, List, Optional, TypedDict, cast from uuid import uuid4 import grpc from google.protobuf import timestamp_pb2 +from pydantic import BaseModel, Field from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry from hatchet_sdk.contracts.events_pb2 import ( @@ -18,24 +19,26 @@ from hatchet_sdk.contracts.events_pb2_grpc import EventsServiceStub from hatchet_sdk.utils.serialization import flatten from hatchet_sdk.utils.tracing import ( + OTEL_CARRIER_KEY, create_carrier, create_tracer, inject_carrier_into_metadata, parse_carrier_from_metadata, ) +from hatchet_sdk.utils.types import JSONSerializableDict from ..loader import ClientConfig from ..metadata import get_metadata -def new_event(conn, config: ClientConfig): +def new_event(conn: grpc.Channel, config: ClientConfig) -> "EventClient": return EventClient( - client=EventsServiceStub(conn), + client=EventsServiceStub(conn), # type: ignore[no-untyped-call] config=config, ) -def proto_timestamp_now(): +def proto_timestamp_now() -> timestamp_pb2.Timestamp: t = datetime.datetime.now().timestamp() seconds = int(t) nanos = int(t % 1 * 1e9) @@ -43,19 +46,20 @@ def proto_timestamp_now(): return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) -class PushEventOptions(TypedDict, total=False): - additional_metadata: Dict[str, str] | None = None +class PushEventOptions(BaseModel): + additional_metadata: JSONSerializableDict = Field(default_factory=dict) namespace: str | None = None -class BulkPushEventOptions(TypedDict, total=False): +class BulkPushEventOptions(BaseModel): namespace: str | None = None + otel_carrier: dict[str, str] = Field(default_factory=dict) -class BulkPushEventWithMetadata(TypedDict, total=False): +class BulkPushEventWithMetadata(BaseModel): key: str payload: Any - additional_metadata: Optional[Dict[str, Any]] # Optional metadata + additional_metadata: 
JSONSerializableDict = Field(default_factory=dict) class EventClient: @@ -66,7 +70,10 @@ def __init__(self, client: EventsServiceStub, config: ClientConfig): self.otel_tracer = create_tracer(config=config) async def async_push( - self, event_key, payload, options: Optional[PushEventOptions] = None + self, + event_key: str, + payload: dict[str, Any], + options: PushEventOptions = PushEventOptions(), ) -> Event: return await asyncio.to_thread( self.push, event_key=event_key, payload=payload, options=options @@ -74,77 +81,67 @@ async def async_push( async def async_bulk_push( self, - events: List[BulkPushEventWithMetadata], - options: Optional[BulkPushEventOptions] = None, + events: list[BulkPushEventWithMetadata], + options: BulkPushEventOptions = BulkPushEventOptions(), ) -> List[Event]: return await asyncio.to_thread(self.bulk_push, events=events, options=options) @tenacity_retry - def push(self, event_key, payload, options: PushEventOptions = None) -> Event: - ctx = parse_carrier_from_metadata( - (options or {}).get("additional_metadata", {}) - ) + def push( + self, + event_key: str, + payload: dict[str, Any], + options: PushEventOptions = PushEventOptions(), + ) -> Event: + ctx = parse_carrier_from_metadata(options.additional_metadata) with self.otel_tracer.start_as_current_span( "hatchet.push", context=ctx ) as span: carrier = create_carrier() - namespace = self.namespace - - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options.pop("namespace") + namespace = options.namespace or self.namespace namespaced_event_key = namespace + event_key try: meta = inject_carrier_into_metadata( - dict() if options is None else options["additional_metadata"], + options.additional_metadata, carrier, ) - meta_bytes = None if meta is None else json.dumps(meta).encode("utf-8") + meta_bytes = None if meta is None else json.dumps(meta) except Exception as e: raise ValueError(f"Error encoding meta: {e}") 
span.set_attributes(flatten(meta, parent_key="", separator=".")) try: - payload_bytes = json.dumps(payload).encode("utf-8") - except json.UnicodeEncodeError as e: + payload_str = json.dumps(payload) + except (TypeError, ValueError) as e: raise ValueError(f"Error encoding payload: {e}") request = PushEventRequest( key=namespaced_event_key, - payload=payload_bytes, + payload=payload_str, eventTimestamp=proto_timestamp_now(), additionalMetadata=meta_bytes, ) span.add_event("Pushing event", attributes={"key": namespaced_event_key}) - return self.client.Push(request, metadata=get_metadata(self.token)) + return cast( + Event, self.client.Push(request, metadata=get_metadata(self.token)) + ) @tenacity_retry def bulk_push( self, events: List[BulkPushEventWithMetadata], - options: BulkPushEventOptions = None, + options: BulkPushEventOptions, ) -> List[Event]: - namespace = self.namespace + namespace = options.namespace or self.namespace bulk_push_correlation_id = uuid4() - ctx = parse_carrier_from_metadata( - (options or {}).get("additional_metadata", {}) - ) - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options.pop("namespace") + ctx = parse_carrier_from_metadata({OTEL_CARRIER_KEY: options.otel_carrier}) bulk_events = [] for event in events: @@ -156,29 +153,27 @@ def bulk_push( "bulk_push_correlation_id", str(bulk_push_correlation_id) ) - event_key = namespace + event["key"] - payload = event["payload"] + event_key = namespace + event.key + payload = event.payload + + meta = inject_carrier_into_metadata(event.additional_metadata, carrier) + span.set_attributes(flatten(meta, parent_key="", separator=".")) try: - meta = inject_carrier_into_metadata( - event.get("additional_metadata", {}), carrier - ) - meta_bytes = json.dumps(meta).encode("utf-8") if meta else None + meta_str = json.dumps(meta) except Exception as e: raise ValueError(f"Error encoding meta: {e}") - span.set_attributes(flatten(meta, 
parent_key="", separator=".")) - try: - payload_bytes = json.dumps(payload).encode("utf-8") - except json.UnicodeEncodeError as e: + payload = json.dumps(payload) + except (TypeError, ValueError) as e: raise ValueError(f"Error encoding payload: {e}") request = PushEventRequest( key=event_key, - payload=payload_bytes, + payload=payload, eventTimestamp=proto_timestamp_now(), - additionalMetadata=meta_bytes, + additionalMetadata=meta_str, ) bulk_events.append(request) @@ -187,9 +182,12 @@ def bulk_push( span.add_event("Pushing bulk events") response = self.client.BulkPush(bulk_request, metadata=get_metadata(self.token)) - return response.events + return cast( + list[Event], + response.events, + ) - def log(self, message: str, step_run_id: str): + def log(self, message: str, step_run_id: str) -> None: try: request = PutLogRequest( stepRunId=step_run_id, @@ -201,7 +199,7 @@ def log(self, message: str, step_run_id: str): except Exception as e: raise ValueError(f"Error logging: {e}") - def stream(self, data: str | bytes, step_run_id: str): + def stream(self, data: str | bytes, step_run_id: str) -> None: try: if isinstance(data, str): data_bytes = data.encode("utf-8") diff --git a/hatchet_sdk/clients/rest/api_client.py b/hatchet_sdk/clients/rest/api_client.py index 76446dda..62bdc472 100644 --- a/hatchet_sdk/clients/rest/api_client.py +++ b/hatchet_sdk/clients/rest/api_client.py @@ -97,7 +97,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_value, traceback): await self.close() - async def close(self): + async def close(self) -> None: await self.rest_client.close() @property diff --git a/hatchet_sdk/clients/rest/tenacity_utils.py b/hatchet_sdk/clients/rest/tenacity_utils.py index 377266a1..c90f7352 100644 --- a/hatchet_sdk/clients/rest/tenacity_utils.py +++ b/hatchet_sdk/clients/rest/tenacity_utils.py @@ -27,7 +27,7 @@ def tenacity_alert_retry(retry_state: tenacity.RetryCallState) -> None: ) -def tenacity_should_retry(ex: Exception) -> bool: +def 
tenacity_should_retry(ex: BaseException) -> bool: if isinstance(ex, (grpc.aio.AioRpcError, grpc.RpcError)): if ex.code() in [ grpc.StatusCode.UNIMPLEMENTED, diff --git a/hatchet_sdk/clients/rest_client.py b/hatchet_sdk/clients/rest_client.py index f6458e5a..83266a55 100644 --- a/hatchet_sdk/clients/rest_client.py +++ b/hatchet_sdk/clients/rest_client.py @@ -2,7 +2,7 @@ import atexit import datetime import threading -from typing import Any, Coroutine, List +from typing import Coroutine, TypeVar from pydantic import StrictInt @@ -14,11 +14,11 @@ from hatchet_sdk.clients.rest.api.workflow_runs_api import WorkflowRunsApi from hatchet_sdk.clients.rest.api_client import ApiClient from hatchet_sdk.clients.rest.configuration import Configuration -from hatchet_sdk.clients.rest.models import TriggerWorkflowRunRequest from hatchet_sdk.clients.rest.models.create_cron_workflow_trigger_request import ( CreateCronWorkflowTriggerRequest, ) from hatchet_sdk.clients.rest.models.cron_workflows import CronWorkflows +from hatchet_sdk.clients.rest.models.cron_workflows_list import CronWorkflowsList from hatchet_sdk.clients.rest.models.cron_workflows_order_by_field import ( CronWorkflowsOrderByField, ) @@ -27,6 +27,9 @@ EventOrderByDirection, ) from hatchet_sdk.clients.rest.models.event_order_by_field import EventOrderByField +from hatchet_sdk.clients.rest.models.event_update_cancel200_response import ( + EventUpdateCancel200Response, +) from hatchet_sdk.clients.rest.models.log_line_level import LogLineLevel from hatchet_sdk.clients.rest.models.log_line_list import LogLineList from hatchet_sdk.clients.rest.models.log_line_order_by_direction import ( @@ -44,9 +47,15 @@ ScheduleWorkflowRunRequest, ) from hatchet_sdk.clients.rest.models.scheduled_workflows import ScheduledWorkflows +from hatchet_sdk.clients.rest.models.scheduled_workflows_list import ( + ScheduledWorkflowsList, +) from hatchet_sdk.clients.rest.models.scheduled_workflows_order_by_field import ( 
ScheduledWorkflowsOrderByField, ) +from hatchet_sdk.clients.rest.models.trigger_workflow_run_request import ( + TriggerWorkflowRunRequest, +) from hatchet_sdk.clients.rest.models.workflow import Workflow from hatchet_sdk.clients.rest.models.workflow_kind import WorkflowKind from hatchet_sdk.clients.rest.models.workflow_list import WorkflowList @@ -66,6 +75,18 @@ WorkflowRunsCancelRequest, ) from hatchet_sdk.clients.rest.models.workflow_version import WorkflowVersion +from hatchet_sdk.utils.types import JSONSerializableDict + +## Type variables to use with coroutines. +## See https://stackoverflow.com/questions/73240620/the-right-way-to-type-hint-a-coroutine-function +## Return type +R = TypeVar("R") + +## Yield type +Y = TypeVar("Y") + +## Send type +S = TypeVar("S") class AsyncRestApi: @@ -77,50 +98,50 @@ def __init__(self, host: str, api_key: str, tenant_id: str): access_token=api_key, ) - self._api_client = None - self._workflow_api = None - self._workflow_run_api = None - self._step_run_api = None - self._event_api = None - self._log_api = None + self._api_client: ApiClient | None = None + self._workflow_api: WorkflowApi | None = None + self._workflow_run_api: WorkflowRunApi | None = None + self._step_run_api: StepRunApi | None = None + self._event_api: EventApi | None = None + self._log_api: LogApi | None = None @property - def api_client(self): + def api_client(self) -> ApiClient: if self._api_client is None: self._api_client = ApiClient(configuration=self.config) return self._api_client @property - def workflow_api(self): + def workflow_api(self) -> WorkflowApi: if self._workflow_api is None: self._workflow_api = WorkflowApi(self.api_client) return self._workflow_api @property - def workflow_run_api(self): + def workflow_run_api(self) -> WorkflowRunApi: if self._workflow_run_api is None: self._workflow_run_api = WorkflowRunApi(self.api_client) return self._workflow_run_api @property - def step_run_api(self): + def step_run_api(self) -> StepRunApi: if 
self._step_run_api is None: self._step_run_api = StepRunApi(self.api_client) return self._step_run_api @property - def event_api(self): + def event_api(self) -> EventApi: if self._event_api is None: self._event_api = EventApi(self.api_client) return self._event_api @property - def log_api(self): + def log_api(self) -> LogApi: if self._log_api is None: self._log_api = LogApi(self.api_client) return self._log_api - async def close(self): + async def close(self) -> None: # Ensure the aiohttp client session is closed if self._api_client is not None: await self._api_client.close() @@ -184,13 +205,13 @@ async def workflow_run_replay( return await self.workflow_run_api.workflow_run_update_replay( tenant=self.tenant_id, replay_workflow_runs_request=ReplayWorkflowRunsRequest( - workflow_run_ids=workflow_run_ids, + workflowRunIds=workflow_run_ids, ), ) async def workflow_run_cancel( self, workflow_run_id: str - ) -> WorkflowRunCancel200Response: + ) -> EventUpdateCancel200Response: return await self.workflow_run_api.workflow_run_cancel( tenant=self.tenant_id, workflow_runs_cancel_request=WorkflowRunsCancelRequest( @@ -200,7 +221,7 @@ async def workflow_run_cancel( async def workflow_run_bulk_cancel( self, workflow_run_ids: list[str] - ) -> WorkflowRunCancel200Response: + ) -> EventUpdateCancel200Response: return await self.workflow_run_api.workflow_run_cancel( tenant=self.tenant_id, workflow_runs_cancel_request=WorkflowRunsCancelRequest( @@ -211,16 +232,16 @@ async def workflow_run_bulk_cancel( async def workflow_run_create( self, workflow_id: str, - input: dict[str, Any], + input: JSONSerializableDict, version: str | None = None, - additional_metadata: list[str] | None = None, + additional_metadata: JSONSerializableDict = {}, ) -> WorkflowRun: return await self.workflow_run_api.workflow_run_create( workflow=workflow_id, version=version, trigger_workflow_run_request=TriggerWorkflowRunRequest( input=input, - additional_metadata=additional_metadata, + 
additionalMetadata=additional_metadata, ), ) @@ -229,9 +250,9 @@ async def cron_create( workflow_name: str, cron_name: str, expression: str, - input: dict[str, Any], - additional_metadata: dict[str, str], - ): + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, + ) -> CronWorkflows: return await self.workflow_run_api.cron_workflow_trigger_create( tenant=self.tenant_id, workflow=workflow_name, @@ -239,12 +260,12 @@ async def cron_create( cronName=cron_name, cronExpression=expression, input=input, - additional_metadata=additional_metadata, + additionalMetadata=additional_metadata, ), ) - async def cron_delete(self, cron_trigger_id: str): - return await self.workflow_api.workflow_cron_delete( + async def cron_delete(self, cron_trigger_id: str) -> None: + await self.workflow_api.workflow_cron_delete( tenant=self.tenant_id, cron_workflow=cron_trigger_id, ) @@ -257,7 +278,7 @@ async def cron_list( additional_metadata: list[str] | None = None, order_by_field: CronWorkflowsOrderByField | None = None, order_by_direction: WorkflowRunOrderByDirection | None = None, - ): + ) -> CronWorkflowsList: return await self.workflow_api.cron_workflow_list( tenant=self.tenant_id, offset=offset, @@ -268,7 +289,7 @@ async def cron_list( order_by_direction=order_by_direction, ) - async def cron_get(self, cron_trigger_id: str): + async def cron_get(self, cron_trigger_id: str) -> CronWorkflows: return await self.workflow_api.workflow_cron_get( tenant=self.tenant_id, cron_workflow=cron_trigger_id, @@ -278,21 +299,21 @@ async def schedule_create( self, name: str, trigger_at: datetime.datetime, - input: dict[str, Any], - additional_metadata: dict[str, str], - ): + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, + ) -> ScheduledWorkflows: return await self.workflow_run_api.scheduled_workflow_run_create( tenant=self.tenant_id, workflow=name, schedule_workflow_run_request=ScheduleWorkflowRunRequest( triggerAt=trigger_at, input=input, - 
additional_metadata=additional_metadata, + additionalMetadata=additional_metadata, ), ) - async def schedule_delete(self, scheduled_trigger_id: str): - return await self.workflow_api.workflow_scheduled_delete( + async def schedule_delete(self, scheduled_trigger_id: str) -> None: + await self.workflow_api.workflow_scheduled_delete( tenant=self.tenant_id, scheduled_workflow_run=scheduled_trigger_id, ) @@ -307,7 +328,7 @@ async def schedule_list( parent_step_run_id: str | None = None, order_by_field: ScheduledWorkflowsOrderByField | None = None, order_by_direction: WorkflowRunOrderByDirection | None = None, - ): + ) -> ScheduledWorkflowsList: return await self.workflow_api.workflow_scheduled_list( tenant=self.tenant_id, offset=offset, @@ -320,7 +341,7 @@ async def schedule_list( order_by_direction=order_by_direction, ) - async def schedule_get(self, scheduled_trigger_id: str): + async def schedule_get(self, scheduled_trigger_id: str) -> ScheduledWorkflows: return await self.workflow_api.workflow_scheduled_get( tenant=self.tenant_id, scheduled_workflow_run=scheduled_trigger_id, @@ -373,9 +394,10 @@ async def events_list( async def events_replay(self, event_ids: list[str] | EventList) -> EventList: if isinstance(event_ids, EventList): - event_ids = [r.metadata.id for r in event_ids.rows] + rows = event_ids.rows or [] + event_ids = [r.metadata.id for r in rows] - return self.event_api.event_update_replay( + return await self.event_api.event_update_replay( tenant=self.tenant_id, replay_event_request=ReplayEventRequest(eventIds=event_ids), ) @@ -393,7 +415,7 @@ def __init__(self, host: str, api_key: str, tenant_id: str): # Register the cleanup method to be called on exit atexit.register(self._cleanup) - def _cleanup(self): + def _cleanup(self) -> None: """ Stop the running thread and clean up the event loop. 
""" @@ -401,14 +423,14 @@ def _cleanup(self): self._loop.call_soon_threadsafe(self._loop.stop) self._thread.join() - def _run_event_loop(self): + def _run_event_loop(self) -> None: """ Run the asyncio event loop in a separate thread. """ asyncio.set_event_loop(self._loop) self._loop.run_forever() - def _run_coroutine(self, coro) -> Any: + def _run_coroutine(self, coro: Coroutine[Y, S, R]) -> R: """ Execute a coroutine in the event loop and return the result. """ @@ -459,20 +481,20 @@ def workflow_run_list( def workflow_run_get(self, workflow_run_id: str) -> WorkflowRun: return self._run_coroutine(self.aio.workflow_run_get(workflow_run_id)) - def workflow_run_cancel(self, workflow_run_id: str) -> WorkflowRunCancel200Response: + def workflow_run_cancel(self, workflow_run_id: str) -> EventUpdateCancel200Response: return self._run_coroutine(self.aio.workflow_run_cancel(workflow_run_id)) def workflow_run_bulk_cancel( self, workflow_run_ids: list[str] - ) -> WorkflowRunCancel200Response: + ) -> EventUpdateCancel200Response: return self._run_coroutine(self.aio.workflow_run_bulk_cancel(workflow_run_ids)) def workflow_run_create( self, workflow_id: str, - input: dict[str, Any], + input: JSONSerializableDict, version: str | None = None, - additional_metadata: list[str] | None = None, + additional_metadata: JSONSerializableDict = {}, ) -> WorkflowRun: return self._run_coroutine( self.aio.workflow_run_create( @@ -485,8 +507,8 @@ def cron_create( workflow_name: str, cron_name: str, expression: str, - input: dict[str, Any], - additional_metadata: dict[str, str], + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, ) -> CronWorkflows: return self._run_coroutine( self.aio.cron_create( @@ -494,8 +516,8 @@ def cron_create( ) ) - def cron_delete(self, cron_trigger_id: str): - return self._run_coroutine(self.aio.cron_delete(cron_trigger_id)) + def cron_delete(self, cron_trigger_id: str) -> None: + self._run_coroutine(self.aio.cron_delete(cron_trigger_id)) def 
cron_list( self, @@ -505,7 +527,7 @@ def cron_list( additional_metadata: list[str] | None = None, order_by_field: CronWorkflowsOrderByField | None = None, order_by_direction: WorkflowRunOrderByDirection | None = None, - ): + ) -> CronWorkflowsList: return self._run_coroutine( self.aio.cron_list( offset, @@ -517,24 +539,24 @@ def cron_list( ) ) - def cron_get(self, cron_trigger_id: str): + def cron_get(self, cron_trigger_id: str) -> CronWorkflows: return self._run_coroutine(self.aio.cron_get(cron_trigger_id)) def schedule_create( self, workflow_name: str, trigger_at: datetime.datetime, - input: dict[str, Any], - additional_metadata: dict[str, str], - ): + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, + ) -> ScheduledWorkflows: return self._run_coroutine( self.aio.schedule_create( workflow_name, trigger_at, input, additional_metadata ) ) - def schedule_delete(self, scheduled_trigger_id: str): - return self._run_coroutine(self.aio.schedule_delete(scheduled_trigger_id)) + def schedule_delete(self, scheduled_trigger_id: str) -> None: + self._run_coroutine(self.aio.schedule_delete(scheduled_trigger_id)) def schedule_list( self, @@ -544,7 +566,7 @@ def schedule_list( additional_metadata: list[str] | None = None, order_by_field: CronWorkflowsOrderByField | None = None, order_by_direction: WorkflowRunOrderByDirection | None = None, - ): + ) -> ScheduledWorkflowsList: return self._run_coroutine( self.aio.schedule_list( offset, @@ -556,7 +578,7 @@ def schedule_list( ) ) - def schedule_get(self, scheduled_trigger_id: str): + def schedule_get(self, scheduled_trigger_id: str) -> ScheduledWorkflows: return self._run_coroutine(self.aio.schedule_get(scheduled_trigger_id)) def list_logs( diff --git a/hatchet_sdk/clients/run_event_listener.py b/hatchet_sdk/clients/run_event_listener.py index b5db6a74..570d248b 100644 --- a/hatchet_sdk/clients/run_event_listener.py +++ b/hatchet_sdk/clients/run_event_listener.py @@ -1,6 +1,7 @@ import asyncio import json 
-from typing import AsyncGenerator +from enum import Enum +from typing import Any, AsyncGenerator, Callable, Generator, cast import grpc @@ -21,7 +22,7 @@ DEFAULT_ACTION_LISTENER_RETRY_COUNT = 5 -class StepRunEventType: +class StepRunEventType(str, Enum): STEP_RUN_EVENT_TYPE_STARTED = "STEP_RUN_EVENT_TYPE_STARTED" STEP_RUN_EVENT_TYPE_COMPLETED = "STEP_RUN_EVENT_TYPE_COMPLETED" STEP_RUN_EVENT_TYPE_FAILED = "STEP_RUN_EVENT_TYPE_FAILED" @@ -30,7 +31,7 @@ class StepRunEventType: STEP_RUN_EVENT_TYPE_STREAM = "STEP_RUN_EVENT_TYPE_STREAM" -class WorkflowRunEventType: +class WorkflowRunEventType(str, Enum): WORKFLOW_RUN_EVENT_TYPE_STARTED = "WORKFLOW_RUN_EVENT_TYPE_STARTED" WORKFLOW_RUN_EVENT_TYPE_COMPLETED = "WORKFLOW_RUN_EVENT_TYPE_COMPLETED" WORKFLOW_RUN_EVENT_TYPE_FAILED = "WORKFLOW_RUN_EVENT_TYPE_FAILED" @@ -62,14 +63,14 @@ def __init__(self, type: StepRunEventType, payload: str): self.payload = payload -def new_listener(config: ClientConfig): +def new_listener(config: ClientConfig) -> "RunEventListenerClient": return RunEventListenerClient(config=config) class RunEventListener: - workflow_run_id: str = None - additional_meta_kv: tuple[str, str] = None + workflow_run_id: str | None = None + additional_meta_kv: tuple[str, str] | None = None def __init__(self, client: DispatcherStub, token: str): self.client = client @@ -77,7 +78,9 @@ def __init__(self, client: DispatcherStub, token: str): self.token = token @classmethod - def for_run_id(cls, workflow_run_id: str, client: DispatcherStub, token: str): + def for_run_id( + cls, workflow_run_id: str, client: DispatcherStub, token: str + ) -> "RunEventListener": listener = RunEventListener(client, token) listener.workflow_run_id = workflow_run_id return listener @@ -85,21 +88,21 @@ def for_run_id(cls, workflow_run_id: str, client: DispatcherStub, token: str): @classmethod def for_additional_meta( cls, key: str, value: str, client: DispatcherStub, token: str - ): + ) -> "RunEventListener": listener = RunEventListener(client, 
token) listener.additional_meta_kv = (key, value) return listener - def abort(self): + def abort(self) -> None: self.stop_signal = True - def __aiter__(self): + def __aiter__(self) -> AsyncGenerator[StepRunEvent, None]: return self._generator() - async def __anext__(self): + async def __anext__(self) -> StepRunEvent: return await self._generator().__anext__() - def __iter__(self): + def __iter__(self) -> Generator[StepRunEvent, None, None]: try: loop = asyncio.get_event_loop() except RuntimeError as e: @@ -145,15 +148,18 @@ async def _generator(self) -> AsyncGenerator[StepRunEvent, None]: try: if workflow_event.eventPayload: + ## TODO: Should this be `dumps` instead? payload = json.loads(workflow_event.eventPayload) except Exception as e: payload = workflow_event.eventPayload pass + assert isinstance(payload, str) + yield StepRunEvent(type=eventType, payload=payload) elif workflow_event.resourceType == RESOURCE_TYPE_WORKFLOW_RUN: - if workflow_event.eventType in workflow_run_event_type_mapping: - eventType = workflow_run_event_type_mapping[ + if workflow_event.eventType in step_run_event_type_mapping: + workflowRunEventType = step_run_event_type_mapping[ workflow_event.eventType ] else: @@ -169,7 +175,9 @@ async def _generator(self) -> AsyncGenerator[StepRunEvent, None]: except Exception as e: pass - yield StepRunEvent(type=eventType, payload=payload) + assert isinstance(payload, str) + + yield StepRunEvent(type=workflowRunEventType, payload=payload) if workflow_event.hangup: listener = None @@ -194,7 +202,7 @@ async def _generator(self) -> AsyncGenerator[StepRunEvent, None]: break # Raise StopAsyncIteration to properly end the generator - async def retry_subscribe(self): + async def retry_subscribe(self) -> AsyncGenerator[WorkflowEvent, None]: retries = 0 while retries < DEFAULT_ACTION_LISTENER_RETRY_COUNT: @@ -203,19 +211,25 @@ async def retry_subscribe(self): await asyncio.sleep(DEFAULT_ACTION_LISTENER_RETRY_INTERVAL) if self.workflow_run_id is not None: - 
return self.client.SubscribeToWorkflowEvents( - SubscribeToWorkflowEventsRequest( - workflowRunId=self.workflow_run_id, + return cast( + AsyncGenerator[WorkflowEvent, None], + self.client.SubscribeToWorkflowEvents( + SubscribeToWorkflowEventsRequest( + workflowRunId=self.workflow_run_id, + ), + metadata=get_metadata(self.token), ), - metadata=get_metadata(self.token), ) elif self.additional_meta_kv is not None: - return self.client.SubscribeToWorkflowEvents( - SubscribeToWorkflowEventsRequest( - additionalMetaKey=self.additional_meta_kv[0], - additionalMetaValue=self.additional_meta_kv[1], + return cast( + AsyncGenerator[WorkflowEvent, None], + self.client.SubscribeToWorkflowEvents( + SubscribeToWorkflowEventsRequest( + additionalMetaKey=self.additional_meta_kv[0], + additionalMetaValue=self.additional_meta_kv[1], + ), + metadata=get_metadata(self.token), ), - metadata=get_metadata(self.token), ) else: raise Exception("no listener method provided") @@ -226,34 +240,38 @@ async def retry_subscribe(self): else: raise ValueError(f"gRPC error: {e}") + raise Exception("Failed to subscribe to workflow events") + class RunEventListenerClient: def __init__(self, config: ClientConfig): self.token = config.token self.config = config - self.client: DispatcherStub = None + self.client: DispatcherStub | None = None - def stream_by_run_id(self, workflow_run_id: str): + def stream_by_run_id(self, workflow_run_id: str) -> RunEventListener: return self.stream(workflow_run_id) - def stream(self, workflow_run_id: str): + def stream(self, workflow_run_id: str) -> RunEventListener: if not isinstance(workflow_run_id, str) and hasattr(workflow_run_id, "__str__"): workflow_run_id = str(workflow_run_id) if not self.client: aio_conn = new_conn(self.config, True) - self.client = DispatcherStub(aio_conn) + self.client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call] return RunEventListener.for_run_id(workflow_run_id, self.client, self.token) - def stream_by_additional_metadata(self, 
key: str, value: str): + def stream_by_additional_metadata(self, key: str, value: str) -> RunEventListener: if not self.client: aio_conn = new_conn(self.config, True) - self.client = DispatcherStub(aio_conn) + self.client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call] return RunEventListener.for_additional_meta(key, value, self.client, self.token) - async def on(self, workflow_run_id: str, handler: callable = None): + async def on( + self, workflow_run_id: str, handler: Callable[[StepRunEvent], Any] | None = None + ) -> None: async for event in self.stream(workflow_run_id): # call the handler if provided if handler: diff --git a/hatchet_sdk/clients/workflow_listener.py b/hatchet_sdk/clients/workflow_listener.py index b1131587..d38d6be8 100644 --- a/hatchet_sdk/clients/workflow_listener.py +++ b/hatchet_sdk/clients/workflow_listener.py @@ -1,10 +1,11 @@ import asyncio import json from collections.abc import AsyncIterator -from typing import AsyncGenerator +from typing import Any, AsyncGenerator, cast import grpc -from grpc._cython import cygrpc +import grpc.aio +from grpc._cython import cygrpc # type: ignore[attr-defined] from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt from hatchet_sdk.connection import new_conn @@ -31,10 +32,10 @@ def __init__(self, id: int, workflow_run_id: str): self.workflow_run_id = workflow_run_id self.queue: asyncio.Queue[WorkflowRunEvent | None] = asyncio.Queue() - async def __aiter__(self): + async def __aiter__(self) -> "_Subscription": return self - async def __anext__(self) -> WorkflowRunEvent: + async def __anext__(self) -> WorkflowRunEvent | None: return await self.queue.get() async def get(self) -> WorkflowRunEvent: @@ -45,10 +46,10 @@ async def get(self) -> WorkflowRunEvent: return event - async def put(self, item: WorkflowRunEvent): + async def put(self, item: WorkflowRunEvent) -> None: await self.queue.put(item) - async def close(self): + async def close(self) -> None: await self.queue.put(None) 
@@ -62,25 +63,28 @@ class PooledWorkflowRunListener: subscription_counter: int = 0 subscription_counter_lock: asyncio.Lock = asyncio.Lock() - requests: asyncio.Queue[SubscribeToWorkflowRunsRequest] = asyncio.Queue() + requests: asyncio.Queue[SubscribeToWorkflowRunsRequest | int] = asyncio.Queue() - listener: AsyncGenerator[WorkflowRunEvent, None] = None - listener_task: asyncio.Task = None + listener: ( + grpc.aio.UnaryStreamCall[SubscribeToWorkflowRunsRequest, WorkflowRunEvent] + | None + ) = None + listener_task: asyncio.Task[None] | None = None curr_requester: int = 0 # events have keys of the format workflow_run_id + subscription_id events: dict[int, _Subscription] = {} - interrupter: asyncio.Task = None + interrupter: asyncio.Task[None] | None = None def __init__(self, config: ClientConfig): conn = new_conn(config, True) - self.client = DispatcherStub(conn) + self.client = DispatcherStub(conn) # type: ignore[no-untyped-call] self.token = config.token self.config = config - async def _interrupter(self): + async def _interrupter(self) -> None: """ _interrupter runs in a separate thread and interrupts the listener according to a configurable duration. 
""" @@ -89,7 +93,7 @@ async def _interrupter(self): if self.interrupt is not None: self.interrupt.set() - async def _init_producer(self): + async def _init_producer(self) -> None: try: if not self.listener: while True: @@ -106,6 +110,9 @@ async def _init_producer(self): while True: self.interrupt = Event_ts() + if self.listener is None: + continue + t = asyncio.create_task( read_with_interrupt(self.listener, self.interrupt) ) @@ -118,7 +125,8 @@ async def _init_producer(self): ) t.cancel() - self.listener.cancel() + if self.listener: + self.listener.cancel() await asyncio.sleep( DEFAULT_WORKFLOW_LISTENER_RETRY_INTERVAL ) @@ -178,7 +186,7 @@ async def _request(self) -> AsyncIterator[SubscribeToWorkflowRunsRequest]: yield request self.requests.task_done() - def cleanup_subscription(self, subscription_id: int): + def cleanup_subscription(self, subscription_id: int) -> None: workflow_run_id = self.subscriptionsToWorkflows[subscription_id] if workflow_run_id in self.workflowsToSubscriptions: @@ -187,8 +195,7 @@ def cleanup_subscription(self, subscription_id: int): del self.subscriptionsToWorkflows[subscription_id] del self.events[subscription_id] - async def subscribe(self, workflow_run_id: str): - init_producer: asyncio.Task = None + async def subscribe(self, workflow_run_id: str) -> WorkflowRunEvent: try: # create a new subscription id, place a mutex on the counter await self.subscription_counter_lock.acquire() @@ -216,15 +223,13 @@ async def subscribe(self, workflow_run_id: str): if not self.listener_task or self.listener_task.done(): self.listener_task = asyncio.create_task(self._init_producer()) - event = await self.events[subscription_id].get() - - return event + return await self.events[subscription_id].get() except asyncio.CancelledError: raise finally: self.cleanup_subscription(subscription_id) - async def result(self, workflow_run_id: str): + async def result(self, workflow_run_id: str) -> dict[str, Any]: from hatchet_sdk.clients.admin import 
DedupeViolationErr event = await self.subscribe(workflow_run_id) @@ -248,7 +253,9 @@ async def result(self, workflow_run_id: str): return results - async def _retry_subscribe(self): + async def _retry_subscribe( + self, + ) -> grpc.aio.UnaryStreamCall[SubscribeToWorkflowRunsRequest, WorkflowRunEvent]: retries = 0 while retries < DEFAULT_WORKFLOW_LISTENER_RETRY_COUNT: @@ -260,14 +267,19 @@ async def _retry_subscribe(self): if self.curr_requester != 0: self.requests.put_nowait(self.curr_requester) - listener = self.client.SubscribeToWorkflowRuns( - self._request(), - metadata=get_metadata(self.token), + return cast( + grpc.aio.UnaryStreamCall[ + SubscribeToWorkflowRunsRequest, WorkflowRunEvent + ], + self.client.SubscribeToWorkflowRuns( + self._request(), + metadata=get_metadata(self.token), + ), ) - - return listener except grpc.RpcError as e: if e.code() == grpc.StatusCode.UNAVAILABLE: retries = retries + 1 else: raise ValueError(f"gRPC error: {e}") + + raise ValueError("Failed to connect to workflow run listener") diff --git a/hatchet_sdk/connection.py b/hatchet_sdk/connection.py index 185395e4..2373d8dd 100644 --- a/hatchet_sdk/connection.py +++ b/hatchet_sdk/connection.py @@ -1,5 +1,5 @@ import os -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal, cast, overload import grpc @@ -7,8 +7,15 @@ from hatchet_sdk.loader import ClientConfig -def new_conn(config: "ClientConfig", aio=False): +@overload +def new_conn(config: "ClientConfig", aio: Literal[False]) -> grpc.Channel: ... + +@overload +def new_conn(config: "ClientConfig", aio: Literal[True]) -> grpc.aio.Channel: ... 
+ + +def new_conn(config: "ClientConfig", aio: bool) -> grpc.Channel | grpc.aio.Channel: credentials: grpc.ChannelCredentials | None = None # load channel credentials @@ -20,6 +27,10 @@ def new_conn(config: "ClientConfig", aio=False): credentials = grpc.ssl_channel_credentials(root_certificates=root) elif config.tls_config.tls_strategy == "mtls": + assert config.tls_config.ca_file + assert config.tls_config.key_file + assert config.tls_config.cert_file + root = open(config.tls_config.ca_file, "rb").read() private_key = open(config.tls_config.key_file, "rb").read() certificate_chain = open(config.tls_config.cert_file, "rb").read() @@ -32,7 +43,7 @@ def new_conn(config: "ClientConfig", aio=False): start = grpc if not aio else grpc.aio - channel_options = [ + channel_options: list[tuple[str, str | int]] = [ ("grpc.max_send_message_length", config.grpc_max_send_message_length), ("grpc.max_receive_message_length", config.grpc_max_recv_message_length), ("grpc.keepalive_time_ms", 10 * 1000), @@ -61,4 +72,8 @@ def new_conn(config: "ClientConfig", aio=False): credentials=credentials, options=channel_options, ) - return conn + + return cast( + grpc.Channel | grpc.aio.Channel, + conn, + ) diff --git a/hatchet_sdk/context/context.py b/hatchet_sdk/context/context.py index f20acd66..aa72f208 100644 --- a/hatchet_sdk/context/context.py +++ b/hatchet_sdk/context/context.py @@ -18,7 +18,7 @@ BulkTriggerWorkflowRequest, TriggerWorkflowRequest, ) -from hatchet_sdk.utils.types import WorkflowValidator +from hatchet_sdk.utils.types import JSONSerializableDict, WorkflowValidator from hatchet_sdk.utils.typing import is_basemodel_subclass from hatchet_sdk.workflow_run import WorkflowRunRef @@ -54,29 +54,20 @@ class BaseContext: def _prepare_workflow_options( self, key: str | None = None, - options: ChildTriggerWorkflowOptions | None = None, + options: ChildTriggerWorkflowOptions = ChildTriggerWorkflowOptions(), worker_id: str | None = None, ) -> TriggerWorkflowOptions: workflow_run_id = 
self.action.workflow_run_id step_run_id = self.action.step_run_id - desired_worker_id = None - if options is not None and "sticky" in options and options["sticky"] == True: - desired_worker_id = worker_id - - meta = None - if options is not None and "additional_metadata" in options: - meta = options["additional_metadata"] - - ## TODO: Pydantic here to simplify this - trigger_options: TriggerWorkflowOptions = { - "parent_id": workflow_run_id, - "parent_step_run_id": step_run_id, - "child_key": key, - "child_index": self.spawn_index, - "additional_metadata": meta, - "desired_worker_id": desired_worker_id, - } + trigger_options = TriggerWorkflowOptions( + parent_id=workflow_run_id, + parent_step_run_id=step_run_id, + child_key=key, + child_index=self.spawn_index, + additional_metadata=options.additional_metadata, + desired_worker_id=worker_id if options.sticky else None, + ) self.spawn_index += 1 return trigger_options @@ -90,7 +81,7 @@ def __init__( admin_client: AdminClient, event_client: EventClient, rest_client: RestApi, - workflow_listener: PooledWorkflowRunListener, + workflow_listener: PooledWorkflowRunListener | None, workflow_run_event_listener: RunEventListenerClient, worker: WorkerContext, namespace: str = "", @@ -110,20 +101,11 @@ def __init__( async def spawn_workflow( self, workflow_name: str, - input: dict[str, Any] = {}, + input: JSONSerializableDict = {}, key: str | None = None, - options: ChildTriggerWorkflowOptions | None = None, + options: ChildTriggerWorkflowOptions = ChildTriggerWorkflowOptions(), ) -> WorkflowRunRef: worker_id = self.worker.id() - # if ( - # options is not None - # and "sticky" in options - # and options["sticky"] == True - # and not self.worker.has_workflow(workflow_name) - # ): - # raise Exception( - # f"cannot run with sticky: workflow {workflow_name} is not registered on the worker" - # ) trigger_options = self._prepare_workflow_options(key, options, worker_id) @@ -141,21 +123,16 @@ async def spawn_workflows( worker_id = 
self.worker.id() - bulk_trigger_workflow_runs: list[WorkflowRunDict] = [] - for child_workflow_run in child_workflow_runs: - workflow_name = child_workflow_run["workflow_name"] - input = child_workflow_run["input"] - - key = child_workflow_run.get("key") - options = child_workflow_run.get("options", {}) - - trigger_options = self._prepare_workflow_options(key, options, worker_id) - - bulk_trigger_workflow_runs.append( - WorkflowRunDict( - workflow_name=workflow_name, input=input, options=trigger_options - ) + bulk_trigger_workflow_runs = [ + WorkflowRunDict( + workflow_name=child_workflow_run.workflow_name, + input=child_workflow_run.input, + options=self._prepare_workflow_options( + child_workflow_run.key, child_workflow_run.options, worker_id + ), ) + for child_workflow_run in child_workflow_runs + ] return await self.admin_client.aio.run_workflows(bulk_trigger_workflow_runs) @@ -172,7 +149,7 @@ def __init__( admin_client: AdminClient, event_client: EventClient, rest_client: RestApi, - workflow_listener: PooledWorkflowRunListener, + workflow_listener: PooledWorkflowRunListener | None, workflow_run_event_listener: RunEventListenerClient, worker: WorkerContext, namespace: str = "", @@ -193,6 +170,8 @@ def __init__( namespace, ) + self.data: dict[str, Any] + # Check the type of action.action_payload before attempting to load it as JSON if isinstance(action.action_payload, (str, bytes, bytearray)): try: @@ -203,16 +182,14 @@ def __init__( self.data: dict[str, Any] = {} # type: ignore[no-redef] else: # Directly assign the payload to self.data if it's already a dict - self.data = ( - action.action_payload if isinstance(action.action_payload, dict) else {} - ) + self.data = action.action_payload self.action = action # FIXME: stepRunId is a legacy field, we should remove it self.stepRunId = action.step_run_id - self.step_run_id = action.step_run_id + self.step_run_id: str = action.step_run_id self.exit_flag = False self.dispatcher_client = dispatcher_client 
self.admin_client = admin_client diff --git a/hatchet_sdk/features/cron.py b/hatchet_sdk/features/cron.py index c54e5b3b..413251d2 100644 --- a/hatchet_sdk/features/cron.py +++ b/hatchet_sdk/features/cron.py @@ -1,6 +1,6 @@ -from typing import Union +from typing import Any, Union, cast -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, Field, field_validator from hatchet_sdk.client import Client from hatchet_sdk.clients.rest.models.cron_workflows import CronWorkflows @@ -11,9 +11,10 @@ from hatchet_sdk.clients.rest.models.workflow_run_order_by_direction import ( WorkflowRunOrderByDirection, ) +from hatchet_sdk.utils.types import JSONSerializableDict -class CreateCronTriggerInput(BaseModel): +class CreateCronTriggerInput(BaseModel): """ Schema for creating a workflow run triggered by a cron. @@ -23,12 +24,13 @@ class CreateCronTriggerInput(BaseModel): additional_metadata (dict[str, str]): Additional metadata associated with the cron trigger (e.g. {"key1": "value1", "key2": "value2"}). """ - expression: str = None - input: dict = {} - additional_metadata: dict[str, str] = {} + expression: str + input: JSONSerializableDict = Field(default_factory=dict) + additional_metadata: JSONSerializableDict = Field(default_factory=dict) @field_validator("expression") - def validate_cron_expression(cls, v): + @classmethod + def validate_cron_expression(cls, v: str) -> str: """ Validates the cron expression to ensure it adheres to the expected format. @@ -86,8 +88,8 @@ def create( workflow_name: str, cron_name: str, expression: str, - input: dict, - additional_metadata: dict[str, str], + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, ) -> CronWorkflows: """ Creates a new workflow cron trigger. @@ -102,7 +104,7 @@ def create( Returns: CronWorkflows: The created cron workflow instance.
""" - validated_input = CreateCronTriggerInput( + validated_input = CreateCronTriggerInput( expression=expression, input=input, additional_metadata=additional_metadata ) @@ -121,10 +123,11 @@ def delete(self, cron_trigger: Union[str, CronWorkflows]) -> None: Args: cron_trigger (Union[str, CronWorkflows]): The cron trigger ID or CronWorkflows instance to delete. """ - id_ = cron_trigger - if isinstance(cron_trigger, CronWorkflows): - id_ = cron_trigger.metadata.id - self._client.rest.cron_delete(id_) + self._client.rest.cron_delete( + cron_trigger.metadata.id + if isinstance(cron_trigger, CronWorkflows) + else cron_trigger + ) def list( self, @@ -168,10 +171,11 @@ def get(self, cron_trigger: Union[str, CronWorkflows]) -> CronWorkflows: Returns: CronWorkflows: The requested cron workflow instance. """ - id_ = cron_trigger - if isinstance(cron_trigger, CronWorkflows): - id_ = cron_trigger.metadata.id - return self._client.rest.cron_get(id_) + return self._client.rest.cron_get( + cron_trigger.metadata.id + if isinstance(cron_trigger, CronWorkflows) + else cron_trigger + ) class CronClientAsync: @@ -198,8 +202,8 @@ async def create( workflow_name: str, cron_name: str, expression: str, - input: dict, - additional_metadata: dict[str, str], + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, ) -> CronWorkflows: """ Asynchronously creates a new workflow cron trigger. @@ -214,7 +218,7 @@ async def create( Returns: CronWorkflows: The created cron workflow instance. """ - validated_input = CreateCronTriggerInput( + validated_input = CreateCronTriggerInput( expression=expression, input=input, additional_metadata=additional_metadata ) @@ -233,10 +237,11 @@ async def delete(self, cron_trigger: Union[str, CronWorkflows]) -> None: Args: cron_trigger (Union[str, CronWorkflows]): The cron trigger ID or CronWorkflows instance to delete.
""" - id_ = cron_trigger - if isinstance(cron_trigger, CronWorkflows): - id_ = cron_trigger.metadata.id - await self._client.rest.aio.cron_delete(id_) + await self._client.rest.aio.cron_delete( + cron_trigger.metadata.id + if isinstance(cron_trigger, CronWorkflows) + else cron_trigger + ) async def list( self, @@ -280,7 +285,9 @@ async def get(self, cron_trigger: Union[str, CronWorkflows]) -> CronWorkflows: Returns: CronWorkflows: The requested cron workflow instance. """ - id_ = cron_trigger - if isinstance(cron_trigger, CronWorkflows): - id_ = cron_trigger.metadata.id - return await self._client.rest.aio.cron_get(id_) + + return await self._client.rest.aio.cron_get( + cron_trigger.metadata.id + if isinstance(cron_trigger, CronWorkflows) + else cron_trigger + ) diff --git a/hatchet_sdk/features/scheduled.py b/hatchet_sdk/features/scheduled.py index 45af2609..cf8462df 100644 --- a/hatchet_sdk/features/scheduled.py +++ b/hatchet_sdk/features/scheduled.py @@ -1,7 +1,7 @@ import datetime -from typing import Any, Coroutine, Dict, List, Optional, Union +from typing import Any, Coroutine, Dict, List, Optional, Union, cast -from pydantic import BaseModel +from pydantic import BaseModel, Field from hatchet_sdk.client import Client from hatchet_sdk.clients.rest.models.cron_workflows import CronWorkflows @@ -12,12 +12,16 @@ from hatchet_sdk.clients.rest.models.scheduled_workflows_list import ( ScheduledWorkflowsList, ) +from hatchet_sdk.clients.rest.models.scheduled_workflows_order_by_field import ( + ScheduledWorkflowsOrderByField, +) from hatchet_sdk.clients.rest.models.workflow_run_order_by_direction import ( WorkflowRunOrderByDirection, ) +from hatchet_sdk.utils.types import JSONSerializableDict -class CreateScheduledTriggerInput(BaseModel): +class CreateScheduledTriggerJSONSerializableDict(BaseModel): """ Schema for creating a scheduled workflow run. 
@@ -27,9 +31,9 @@ class CreateScheduledTriggerInput(BaseModel): trigger_at (Optional[datetime.datetime]): The datetime when the run should be triggered. """ - input: Dict[str, Any] = {} - additional_metadata: Dict[str, str] = {} - trigger_at: Optional[datetime.datetime] = None + input: JSONSerializableDict = Field(default_factory=dict) + additional_metadata: JSONSerializableDict = Field(default_factory=dict) + trigger_at: datetime.datetime class ScheduledClient: @@ -57,8 +61,8 @@ def create( self, workflow_name: str, trigger_at: datetime.datetime, - input: Dict[str, Any], - additional_metadata: Dict[str, str], + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, ) -> ScheduledWorkflows: """ Creates a new scheduled workflow run asynchronously. @@ -73,7 +77,7 @@ def create( ScheduledWorkflows: The created scheduled workflow instance. """ - validated_input = CreateScheduledTriggerInput( + validated_input = CreateScheduledTriggerJSONSerializableDict( trigger_at=trigger_at, input=input, additional_metadata=additional_metadata ) @@ -91,10 +95,11 @@ def delete(self, scheduled: Union[str, ScheduledWorkflows]) -> None: Args: scheduled (Union[str, ScheduledWorkflows]): The scheduled workflow trigger ID or ScheduledWorkflows instance to delete. """ - id_ = scheduled - if isinstance(scheduled, ScheduledWorkflows): - id_ = scheduled.metadata.id - self._client.rest.schedule_delete(id_) + self._client.rest.schedule_delete( + scheduled.metadata.id + if isinstance(scheduled, ScheduledWorkflows) + else scheduled + ) def list( self, @@ -138,10 +143,11 @@ def get(self, scheduled: Union[str, ScheduledWorkflows]) -> ScheduledWorkflows: Returns: ScheduledWorkflows: The requested scheduled workflow instance. 
""" - id_ = scheduled - if isinstance(scheduled, ScheduledWorkflows): - id_ = scheduled.metadata.id - return self._client.rest.schedule_get(id_) + return self._client.rest.schedule_get( + scheduled.metadata.id + if isinstance(scheduled, ScheduledWorkflows) + else scheduled + ) class ScheduledClientAsync: @@ -167,8 +173,8 @@ async def create( self, workflow_name: str, trigger_at: datetime.datetime, - input: Dict[str, Any], - additional_metadata: Dict[str, str], + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, ) -> ScheduledWorkflows: """ Creates a new scheduled workflow run asynchronously. @@ -193,10 +199,11 @@ async def delete(self, scheduled: Union[str, ScheduledWorkflows]) -> None: Args: scheduled (Union[str, ScheduledWorkflows]): The scheduled workflow trigger ID or ScheduledWorkflows instance to delete. """ - id_ = scheduled - if isinstance(scheduled, ScheduledWorkflows): - id_ = scheduled.metadata.id - await self._client.rest.aio.schedule_delete(id_) + await self._client.rest.aio.schedule_delete( + scheduled.metadata.id + if isinstance(scheduled, ScheduledWorkflows) + else scheduled + ) async def list( self, @@ -204,7 +211,7 @@ async def list( limit: Optional[int] = None, workflow_id: Optional[str] = None, additional_metadata: Optional[List[str]] = None, - order_by_field: Optional[CronWorkflowsOrderByField] = None, + order_by_field: Optional[ScheduledWorkflowsOrderByField] = None, order_by_direction: Optional[WorkflowRunOrderByDirection] = None, ) -> ScheduledWorkflowsList: """ @@ -242,7 +249,8 @@ async def get( Returns: ScheduledWorkflows: The requested scheduled workflow instance. 
""" - id_ = scheduled - if isinstance(scheduled, ScheduledWorkflows): - id_ = scheduled.metadata.id - return await self._client.rest.aio.schedule_get(id_) + return await self._client.rest.aio.schedule_get( + scheduled.metadata.id + if isinstance(scheduled, ScheduledWorkflows) + else scheduled + ) diff --git a/hatchet_sdk/hatchet.py b/hatchet_sdk/hatchet.py index bf0e9089..cdf17225 100644 --- a/hatchet_sdk/hatchet.py +++ b/hatchet_sdk/hatchet.py @@ -12,13 +12,13 @@ CreateStepRateLimit, DesiredWorkerLabels, StickyStrategy, + WorkerLabelComparator, ) from hatchet_sdk.features.cron import CronClient from hatchet_sdk.features.scheduled import ScheduledClient from hatchet_sdk.labels import DesiredWorkerLabel -from hatchet_sdk.loader import ClientConfig, ConfigLoader +from hatchet_sdk.loader import ClientConfig from hatchet_sdk.rate_limit import RateLimit -from hatchet_sdk.v2.callable import HatchetCallable from .client import Client, new_client, new_client_raw from .clients.admin import AdminClient @@ -48,7 +48,7 @@ def workflow( version: str = "", timeout: str = "60m", schedule_timeout: str = "5m", - sticky: Union[StickyStrategy.Value, None] = None, # type: ignore[name-defined] + sticky: Union[StickyStrategy, None] = None, default_priority: int | None = None, concurrency: ConcurrencyExpression | None = None, input_validator: Type[T] | None = None, @@ -107,13 +107,13 @@ def inner(func: Callable[P, R]) -> Callable[P, R]: setattr(func, "_step_backoff_max_seconds", backoff_max_seconds) def create_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels: - value = d["value"] if "value" in d else None + value = d.value return DesiredWorkerLabels( - strValue=str(value) if not isinstance(value, int) else None, + strValue=value if not isinstance(value, int) else None, intValue=value if isinstance(value, int) else None, - required=d["required"] if "required" in d else None, # type: ignore[arg-type] - weight=d["weight"] if "weight" in d else None, - comparator=d["comparator"] if 
"comparator" in d else None, # type: ignore[arg-type] + required=d.required, + weight=d.weight, + comparator=d.comparator, # type: ignore[arg-type] ) setattr( @@ -187,11 +187,8 @@ class HatchetRest: rest (RestApi): Interface for REST API operations. """ - rest: RestApi - def __init__(self, config: ClientConfig = ClientConfig()): - _config: ClientConfig = ConfigLoader(".").load_client_config(config) - self.rest = RestApi(_config.server_url, _config.token, _config.tenant_id) + self.rest = RestApi(config.server_url, config.token, config.tenant_id) class Hatchet: diff --git a/hatchet_sdk/labels.py b/hatchet_sdk/labels.py index 646c666d..55836e31 100644 --- a/hatchet_sdk/labels.py +++ b/hatchet_sdk/labels.py @@ -1,10 +1,10 @@ -from typing import TypedDict +from pydantic import BaseModel, ConfigDict -class DesiredWorkerLabel(TypedDict, total=False): +class DesiredWorkerLabel(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + value: str | int - required: bool | None = None + required: bool = False weight: int | None = None - comparator: int | None = ( - None # _ClassVar[WorkerLabelComparator] TODO figure out type - ) + comparator: int | None = None diff --git a/hatchet_sdk/loader.py b/hatchet_sdk/loader.py index d754c2ae..deda2348 100644 --- a/hatchet_sdk/loader.py +++ b/hatchet_sdk/loader.py @@ -1,246 +1,168 @@ import json import os from logging import Logger, getLogger -from typing import Dict, Optional - -import yaml - -from .token import get_addresses_from_jwt, get_tenant_id_from_jwt - - -class ClientTLSConfig: - def __init__( - self, - tls_strategy: str, - cert_file: str, - key_file: str, - ca_file: str, - server_name: str, - ): - self.tls_strategy = tls_strategy - self.cert_file = cert_file - self.key_file = key_file - self.ca_file = ca_file - self.server_name = server_name - - -class ClientConfig: - logInterceptor: Logger - - def __init__( - self, - tenant_id: str = None, - tls_config: ClientTLSConfig = None, - token: str = None, - host_port: 
str = "localhost:7070", - server_url: str = "https://app.dev.hatchet-tools.com", - namespace: str = None, - listener_v2_timeout: int = None, - logger: Logger = None, - grpc_max_recv_message_length: int = 4 * 1024 * 1024, # 4MB - grpc_max_send_message_length: int = 4 * 1024 * 1024, # 4MB - otel_exporter_oltp_endpoint: str | None = None, - otel_service_name: str | None = None, - otel_exporter_oltp_headers: dict[str, str] | None = None, - otel_exporter_oltp_protocol: str | None = None, - worker_healthcheck_port: int | None = None, - worker_healthcheck_enabled: bool | None = None, - ): - self.tenant_id = tenant_id - self.tls_config = tls_config - self.host_port = host_port - self.token = token - self.server_url = server_url - self.namespace = "" - self.logInterceptor = logger - self.grpc_max_recv_message_length = grpc_max_recv_message_length - self.grpc_max_send_message_length = grpc_max_send_message_length - self.otel_exporter_oltp_endpoint = otel_exporter_oltp_endpoint - self.otel_service_name = otel_service_name - self.otel_exporter_oltp_headers = otel_exporter_oltp_headers - self.otel_exporter_oltp_protocol = otel_exporter_oltp_protocol - self.worker_healthcheck_port = worker_healthcheck_port - self.worker_healthcheck_enabled = worker_healthcheck_enabled - - if not self.logInterceptor: - self.logInterceptor = getLogger() - - # case on whether the namespace already has a trailing underscore - if namespace and not namespace.endswith("_"): - self.namespace = f"{namespace}_" - elif namespace: - self.namespace = namespace - - self.namespace = self.namespace.lower() - - self.listener_v2_timeout = listener_v2_timeout - - -class ConfigLoader: - def __init__(self, directory: str): - self.directory = directory - - def load_client_config(self, defaults: ClientConfig) -> ClientConfig: - config_file_path = os.path.join(self.directory, "client.yaml") - config_data: object = {"tls": {}} - - # determine if client.yaml exists - if os.path.exists(config_file_path): - with 
open(config_file_path, "r") as file: - config_data = yaml.safe_load(file) - - def get_config_value(key, env_var): - if key in config_data: - return config_data[key] - - if self._get_env_var(env_var) is not None: - return self._get_env_var(env_var) - - return getattr(defaults, key, None) - - namespace = get_config_value("namespace", "HATCHET_CLIENT_NAMESPACE") - - tenant_id = get_config_value("tenantId", "HATCHET_CLIENT_TENANT_ID") - token = get_config_value("token", "HATCHET_CLIENT_TOKEN") - listener_v2_timeout = get_config_value( - "listener_v2_timeout", "HATCHET_CLIENT_LISTENER_V2_TIMEOUT" - ) - listener_v2_timeout = int(listener_v2_timeout) if listener_v2_timeout else None - +from typing import cast + +from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator + +from hatchet_sdk.token import get_addresses_from_jwt, get_tenant_id_from_jwt + + +class ClientTLSConfig(BaseModel): + tls_strategy: str + cert_file: str | None + key_file: str | None + ca_file: str | None + server_name: str + + +def _load_tls_config(host_port: str | None = None) -> ClientTLSConfig: + server_name = os.getenv("HATCHET_CLIENT_TLS_SERVER_NAME") + + if not server_name and host_port: + server_name = host_port.split(":")[0] + + if not server_name: + server_name = "localhost" + + return ClientTLSConfig( + tls_strategy=os.getenv("HATCHET_CLIENT_TLS_STRATEGY", "tls"), + cert_file=os.getenv("HATCHET_CLIENT_TLS_CERT_FILE"), + key_file=os.getenv("HATCHET_CLIENT_TLS_KEY_FILE"), + ca_file=os.getenv("HATCHET_CLIENT_TLS_ROOT_CA_FILE"), + server_name=server_name, + ) + + +def parse_listener_timeout(timeout: str | None) -> int | None: + if timeout is None: + return None + + return int(timeout) + + +DEFAULT_HOST_PORT = "localhost:7070" + + +class ClientConfig(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, validate_default=True) + + token: str = os.getenv("HATCHET_CLIENT_TOKEN", "") + logger: Logger = getLogger() + tenant_id: str = 
os.getenv("HATCHET_CLIENT_TENANT_ID", "") + + ## IMPORTANT: Order matters here. The validators run in the order that the + ## fields are defined in the model. So, we need to make sure that the + ## host_port is set before we try to load the tls_config and server_url + host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", DEFAULT_HOST_PORT) + tls_config: ClientTLSConfig = _load_tls_config() + + server_url: str = "https://app.dev.hatchet-tools.com" + namespace: str = os.getenv("HATCHET_CLIENT_NAMESPACE", "") + listener_v2_timeout: int | None = parse_listener_timeout( + os.getenv("HATCHET_CLIENT_LISTENER_V2_TIMEOUT") + ) + grpc_max_recv_message_length: int = int( + os.getenv("HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", 4 * 1024 * 1024) + ) # 4MB + grpc_max_send_message_length: int = int( + os.getenv("HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH", 4 * 1024 * 1024) + ) # 4MB + otel_exporter_oltp_endpoint: str | None = os.getenv( + "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT" + ) + otel_service_name: str | None = os.getenv("HATCHET_CLIENT_OTEL_SERVICE_NAME") + otel_exporter_oltp_headers: str | None = os.getenv( + "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS" + ) + otel_exporter_oltp_protocol: str | None = os.getenv( + "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL" + ) + worker_healthcheck_port: int = int( + os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT", 8001) + ) + worker_healthcheck_enabled: bool = ( + os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", "False") == "True" + ) + + @field_validator("token", mode="after") + @classmethod + def validate_token(cls, token: str) -> str: if not token: - raise ValueError( - "Token must be set via HATCHET_CLIENT_TOKEN environment variable" - ) + raise ValueError("Token must be set") - host_port = get_config_value("hostPort", "HATCHET_CLIENT_HOST_PORT") - server_url: str | None = None + return token - grpc_max_recv_message_length = get_config_value( - "grpc_max_recv_message_length", - 
"HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", - ) - grpc_max_send_message_length = get_config_value( - "grpc_max_send_message_length", - "HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH", - ) + @field_validator("namespace", mode="after") + @classmethod + def validate_namespace(cls, namespace: str) -> str: + if not namespace: + return "" - if grpc_max_recv_message_length: - grpc_max_recv_message_length = int(grpc_max_recv_message_length) + if not namespace.endswith("_"): + namespace = f"{namespace}_" - if grpc_max_send_message_length: - grpc_max_send_message_length = int(grpc_max_send_message_length) + return namespace.lower() - if not host_port: - # extract host and port from token - server_url, grpc_broadcast_address = get_addresses_from_jwt(token) - host_port = grpc_broadcast_address + @field_validator("tenant_id", mode="after") + @classmethod + def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str: + token = cast(str | None, info.data.get("token")) if not tenant_id: - tenant_id = get_tenant_id_from_jwt(token) - - tls_config = self._load_tls_config(config_data["tls"], host_port) - - otel_exporter_oltp_endpoint = get_config_value( - "otel_exporter_oltp_endpoint", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT" - ) - - otel_service_name = get_config_value( - "otel_service_name", "HATCHET_CLIENT_OTEL_SERVICE_NAME" - ) - - _oltp_headers = get_config_value( - "otel_exporter_oltp_headers", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS" - ) - - if _oltp_headers: - try: - otel_header_key, api_key = _oltp_headers.split("=", maxsplit=1) - otel_exporter_oltp_headers = {otel_header_key: api_key} - except ValueError: - raise ValueError( - "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS must be in the format `key=value`" - ) - else: - otel_exporter_oltp_headers = None - - otel_exporter_oltp_protocol = get_config_value( - "otel_exporter_oltp_protocol", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL" - ) - - worker_healthcheck_port = int( - get_config_value( - 
"worker_healthcheck_port", "HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT" - ) - or 8001 - ) - - worker_healthcheck_enabled = ( - str( - get_config_value( - "worker_healthcheck_port", - "HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", - ) - ) - == "True" - ) - - return ClientConfig( - tenant_id=tenant_id, - tls_config=tls_config, - token=token, - host_port=host_port, - server_url=server_url, - namespace=namespace, - listener_v2_timeout=listener_v2_timeout, - logger=defaults.logInterceptor, - grpc_max_recv_message_length=grpc_max_recv_message_length, - grpc_max_send_message_length=grpc_max_send_message_length, - otel_exporter_oltp_endpoint=otel_exporter_oltp_endpoint, - otel_service_name=otel_service_name, - otel_exporter_oltp_headers=otel_exporter_oltp_headers, - otel_exporter_oltp_protocol=otel_exporter_oltp_protocol, - worker_healthcheck_port=worker_healthcheck_port, - worker_healthcheck_enabled=worker_healthcheck_enabled, - ) - - def _load_tls_config(self, tls_data: Dict, host_port) -> ClientTLSConfig: - tls_strategy = ( - tls_data["tlsStrategy"] - if "tlsStrategy" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_STRATEGY") - ) - - if not tls_strategy: - tls_strategy = "tls" - - cert_file = ( - tls_data["tlsCertFile"] - if "tlsCertFile" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_CERT_FILE") - ) - key_file = ( - tls_data["tlsKeyFile"] - if "tlsKeyFile" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_KEY_FILE") - ) - ca_file = ( - tls_data["tlsRootCAFile"] - if "tlsRootCAFile" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_ROOT_CA_FILE") - ) - - server_name = ( - tls_data["tlsServerName"] - if "tlsServerName" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_SERVER_NAME") - ) - - # if server_name is not set, use the host from the host_port - if not server_name: - server_name = host_port.split(":")[0] - - return ClientTLSConfig(tls_strategy, cert_file, key_file, ca_file, server_name) - - @staticmethod - def _get_env_var(env_var: 
str, default: Optional[str] = None) -> str: - return os.environ.get(env_var, default) + if not token: + raise ValueError("Either the token or tenant_id must be set") + + return get_tenant_id_from_jwt(token) + + return tenant_id + + @field_validator("host_port", mode="after") + @classmethod + def validate_host_port(cls, host_port: str, info: ValidationInfo) -> str: + if host_port and host_port != DEFAULT_HOST_PORT: + return host_port + + token = cast(str, info.data.get("token")) + + if not token: + raise ValueError("Token must be set") + + _, grpc_broadcast_address = get_addresses_from_jwt(token) + + return grpc_broadcast_address + + @field_validator("server_url", mode="after") + @classmethod + def validate_server_url(cls, server_url: str, info: ValidationInfo) -> str: + ## IMPORTANT: Order matters here. The validators run in the order that the + ## fields are defined in the model. So, we need to make sure that the + ## host_port is set before we try to load the server_url + host_port = cast(str, info.data.get("host_port")) + + if host_port and host_port != DEFAULT_HOST_PORT: + return host_port + + token = cast(str, info.data.get("token")) + + if not token: + raise ValueError("Token must be set") + + _server_url, _ = get_addresses_from_jwt(token) + + return _server_url + + @field_validator("tls_config", mode="after") + @classmethod + def validate_tls_config( + cls, tls_config: ClientTLSConfig, info: ValidationInfo + ) -> ClientTLSConfig: + ## IMPORTANT: Order matters here. This validator runs in the order + ## that the fields are defined in the model. 
So, we need to make sure + ## that the host_port is set before we try to load the tls_config + host_port = cast(str, info.data.get("host_port")) + + return _load_tls_config(host_port) + + def __hash__(self) -> int: + return hash(json.dumps(self.model_dump(), default=str)) diff --git a/hatchet_sdk/metadata.py b/hatchet_sdk/metadata.py index 38a31b8b..d4004c64 100644 --- a/hatchet_sdk/metadata.py +++ b/hatchet_sdk/metadata.py @@ -1,2 +1,2 @@ -def get_metadata(token: str): +def get_metadata(token: str) -> list[tuple[str, str]]: return [("authorization", "bearer " + token)] diff --git a/hatchet_sdk/rate_limit.py b/hatchet_sdk/rate_limit.py index 0d7b9143..f9f574a4 100644 --- a/hatchet_sdk/rate_limit.py +++ b/hatchet_sdk/rate_limit.py @@ -1,7 +1,8 @@ from dataclasses import dataclass +from enum import Enum from typing import Union -from celpy import CELEvalError, Environment +from celpy import CELEvalError, Environment # type: ignore from hatchet_sdk.contracts.workflows_pb2 import CreateStepRateLimit @@ -15,7 +16,7 @@ def validate_cel_expression(expr: str) -> bool: return False -class RateLimitDuration: +class RateLimitDuration(str, Enum): SECOND = "SECOND" MINUTE = "MINUTE" HOUR = "HOUR" @@ -71,9 +72,9 @@ class RateLimit: limit: Union[int, str, None] = None duration: RateLimitDuration = RateLimitDuration.MINUTE - _req: CreateStepRateLimit = None + _req: CreateStepRateLimit | None = None - def __post_init__(self): + def __post_init__(self) -> None: # juggle the key and key_expr fields key = self.static_key key_expression = self.dynamic_key diff --git a/hatchet_sdk/token.py b/hatchet_sdk/token.py index 313a6671..58d34c65 100644 --- a/hatchet_sdk/token.py +++ b/hatchet_sdk/token.py @@ -1,20 +1,25 @@ import base64 -import json +from pydantic import BaseModel -def get_tenant_id_from_jwt(token: str) -> str: - claims = extract_claims_from_jwt(token) - return claims.get("sub") +class Claims(BaseModel): + sub: str + server_url: str + grpc_broadcast_address: str + + +def 
get_tenant_id_from_jwt(token: str) -> str: + return extract_claims_from_jwt(token).sub -def get_addresses_from_jwt(token: str) -> (str, str): +def get_addresses_from_jwt(token: str) -> tuple[str, str]: claims = extract_claims_from_jwt(token) - return claims.get("server_url"), claims.get("grpc_broadcast_address") + return claims.server_url, claims.grpc_broadcast_address -def extract_claims_from_jwt(token: str): +def extract_claims_from_jwt(token: str) -> Claims: parts = token.split(".") if len(parts) != 3: raise ValueError("Invalid token format") @@ -22,6 +27,5 @@ def extract_claims_from_jwt(token: str): claims_part = parts[1] claims_part += "=" * ((4 - len(claims_part) % 4) % 4) # Padding for base64 decoding claims_data = base64.urlsafe_b64decode(claims_part) - claims = json.loads(claims_data) - return claims + return Claims.model_validate_json(claims_data) diff --git a/hatchet_sdk/utils/tracing.py b/hatchet_sdk/utils/tracing.py index afc398f7..634c3995 100644 --- a/hatchet_sdk/utils/tracing.py +++ b/hatchet_sdk/utils/tracing.py @@ -16,6 +16,16 @@ OTEL_CARRIER_KEY = "__otel_carrier" +def parse_headers(headers: str | None) -> dict[str, str]: + if headers is None: + return {} + + try: + return dict([headers.split("=", maxsplit=1)]) + except ValueError: + raise ValueError("OTLP headers must be in the format `key=value`") + + @cache def create_tracer(config: ClientConfig) -> Tracer: ## TODO: Figure out how to specify protocol here @@ -27,7 +37,7 @@ def create_tracer(config: ClientConfig) -> Tracer: processor = BatchSpanProcessor( OTLPSpanExporter( endpoint=config.otel_exporter_oltp_endpoint, - headers=config.otel_exporter_oltp_headers, + headers=parse_headers(config.otel_exporter_oltp_headers), ), ) diff --git a/hatchet_sdk/utils/types.py b/hatchet_sdk/utils/types.py index 30e469f7..16ab43f6 100644 --- a/hatchet_sdk/utils/types.py +++ b/hatchet_sdk/utils/types.py @@ -1,4 +1,4 @@ -from typing import Type +from typing import Any, Type from pydantic import BaseModel @@ 
-6,3 +6,6 @@ class WorkflowValidator(BaseModel): workflow_input: Type[BaseModel] | None = None step_output: Type[BaseModel] | None = None + + +JSONSerializableDict = dict[str, Any] diff --git a/hatchet_sdk/v2/__init__.py b/hatchet_sdk/v2/__init__.py new file mode 100644 index 00000000..e4d009d2 --- /dev/null +++ b/hatchet_sdk/v2/__init__.py @@ -0,0 +1,3 @@ +from .hatchet import Hatchet as Hatchet +from .workflows import Workflow as Workflow +from .workflows import WorkflowConfig as WorkflowConfig diff --git a/hatchet_sdk/v2/callable.py b/hatchet_sdk/v2/callable.py deleted file mode 100644 index 097a7d87..00000000 --- a/hatchet_sdk/v2/callable.py +++ /dev/null @@ -1,202 +0,0 @@ -import asyncio -from typing import ( - Any, - Callable, - Dict, - Generic, - List, - Optional, - TypedDict, - TypeVar, - Union, -) - -from hatchet_sdk.clients.admin import ChildTriggerWorkflowOptions -from hatchet_sdk.context.context import Context -from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined] - CreateStepRateLimit, - CreateWorkflowJobOpts, - CreateWorkflowStepOpts, - CreateWorkflowVersionOpts, - DesiredWorkerLabels, - StickyStrategy, - WorkflowConcurrencyOpts, - WorkflowKind, -) -from hatchet_sdk.labels import DesiredWorkerLabel -from hatchet_sdk.logger import logger -from hatchet_sdk.rate_limit import RateLimit -from hatchet_sdk.v2.concurrency import ConcurrencyFunction -from hatchet_sdk.workflow_run import RunRef - -T = TypeVar("T") - - -class HatchetCallable(Generic[T]): - def __init__( - self, - func: Callable[[Context], T], - durable: bool = False, - name: str = "", - auto_register: bool = True, - on_events: list[str] | None = None, - on_crons: list[str] | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: List[RateLimit] | None = None, - concurrency: ConcurrencyFunction | None = None, - on_failure: Union["HatchetCallable[T]", None] = None, - 
desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - default_priority: int | None = None, - ): - self.func = func - - on_events = on_events or [] - on_crons = on_crons or [] - - limits = None - if rate_limits: - limits = [rate_limit._req for rate_limit in rate_limits or []] - - self.function_desired_worker_labels = {} - - for key, d in desired_worker_labels.items(): - value = d["value"] if "value" in d else None - self.function_desired_worker_labels[key] = DesiredWorkerLabels( - strValue=str(value) if not isinstance(value, int) else None, - intValue=value if isinstance(value, int) else None, - required=d["required"] if "required" in d else None, - weight=d["weight"] if "weight" in d else None, - comparator=d["comparator"] if "comparator" in d else None, - ) - self.sticky = sticky - self.default_priority = default_priority - self.durable = durable - self.function_name = name.lower() or str(func.__name__).lower() - self.function_version = version - self.function_on_events = on_events - self.function_on_crons = on_crons - self.function_timeout = timeout - self.function_schedule_timeout = schedule_timeout - self.function_retries = retries - self.function_rate_limits = limits - self.function_concurrency = concurrency - self.function_on_failure = on_failure - self.function_namespace = "default" - self.function_auto_register = auto_register - - self.is_coroutine = False - - if asyncio.iscoroutinefunction(func): - self.is_coroutine = True - - def __call__(self, context: Context) -> T: - return self.func(context) - - def with_namespace(self, namespace: str) -> None: - if namespace is not None and namespace != "": - self.function_namespace = namespace - self.function_name = namespace + self.function_name - - def to_workflow_opts(self) -> CreateWorkflowVersionOpts: - kind: WorkflowKind = WorkflowKind.FUNCTION - - if self.durable: - kind = WorkflowKind.DURABLE - - on_failure_job: CreateWorkflowJobOpts | None = None - - if self.function_on_failure is not None: - 
on_failure_job = CreateWorkflowJobOpts( - name=self.function_name + "-on-failure", - steps=[ - self.function_on_failure.to_step(), - ], - ) - - concurrency: WorkflowConcurrencyOpts | None = None - - if self.function_concurrency is not None: - self.function_concurrency.set_namespace(self.function_namespace) - concurrency = WorkflowConcurrencyOpts( - action=self.function_concurrency.get_action_name(), - max_runs=self.function_concurrency.max_runs, - limit_strategy=self.function_concurrency.limit_strategy, - ) - - validated_priority = ( - max(1, min(3, self.default_priority)) if self.default_priority else None - ) - if validated_priority != self.default_priority: - logger.warning( - "Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range." - ) - - return CreateWorkflowVersionOpts( - name=self.function_name, - kind=kind, - version=self.function_version, - event_triggers=self.function_on_events, - cron_triggers=self.function_on_crons, - schedule_timeout=self.function_schedule_timeout, - sticky=self.sticky, - on_failure_job=on_failure_job, - concurrency=concurrency, - jobs=[ - CreateWorkflowJobOpts( - name=self.function_name, - steps=[ - self.to_step(), - ], - ) - ], - default_priority=validated_priority, - ) - - def to_step(self) -> CreateWorkflowStepOpts: - return CreateWorkflowStepOpts( - readable_id=self.function_name, - action=self.get_action_name(), - timeout=self.function_timeout, - inputs="{}", - parents=[], - retries=self.function_retries, - rate_limits=self.function_rate_limits, - worker_labels=self.function_desired_worker_labels, - ) - - def get_action_name(self) -> str: - return self.function_namespace + ":" + self.function_name - - -class DurableContext(Context): - def run( - self, - function: str | Callable[[Context], Any], - input: dict[Any, Any] = {}, - key: str | None = None, - options: ChildTriggerWorkflowOptions | None = None, - ) -> "RunRef[T]": - worker_id = self.worker.id() - - workflow_name = function - - 
if not isinstance(function, str): - workflow_name = function.function_name - - # if ( - # options is not None - # and "sticky" in options - # and options["sticky"] == True - # and not self.worker.has_workflow(workflow_name) - # ): - # raise Exception( - # f"cannot run with sticky: workflow {workflow_name} is not registered on the worker" - # ) - - trigger_options = self._prepare_workflow_options(key, options, worker_id) - - return self.admin_client.run(function, input, trigger_options) diff --git a/hatchet_sdk/v2/concurrency.py b/hatchet_sdk/v2/concurrency.py deleted file mode 100644 index 73d9e3b4..00000000 --- a/hatchet_sdk/v2/concurrency.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Any, Callable - -from hatchet_sdk.context.context import Context -from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined] - ConcurrencyLimitStrategy, -) - - -class ConcurrencyFunction: - def __init__( - self, - func: Callable[[Context], str], - name: str = "concurrency", - max_runs: int = 1, - limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, - ): - self.func = func - self.name = name - self.max_runs = max_runs - self.limit_strategy = limit_strategy - self.namespace = "default" - - def set_namespace(self, namespace: str) -> None: - self.namespace = namespace - - def get_action_name(self) -> str: - return self.namespace + ":" + self.name - - def __call__(self, *args: Any, **kwargs: Any) -> str: - return self.func(*args, **kwargs) - - def __str__(self) -> str: - return f"{self.name}({self.max_runs})" - - def __repr__(self) -> str: - return f"{self.name}({self.max_runs})" - - -def concurrency( - name: str = "", - max_runs: int = 1, - limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, -) -> Callable[[Callable[[Context], str]], ConcurrencyFunction]: - def inner(func: Callable[[Context], str]) -> ConcurrencyFunction: - return ConcurrencyFunction(func, name, max_runs, limit_strategy) - - 
return inner diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py index 4dd3faf0..3e02cdae 100644 --- a/hatchet_sdk/v2/hatchet.py +++ b/hatchet_sdk/v2/hatchet.py @@ -1,224 +1,145 @@ -from typing import Any, Callable, TypeVar, Union - -from hatchet_sdk import Worker -from hatchet_sdk.context.context import Context -from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined] - ConcurrencyLimitStrategy, - StickyStrategy, -) -from hatchet_sdk.hatchet import Hatchet as HatchetV1 -from hatchet_sdk.hatchet import workflow -from hatchet_sdk.labels import DesiredWorkerLabel -from hatchet_sdk.rate_limit import RateLimit -from hatchet_sdk.v2.callable import DurableContext, HatchetCallable -from hatchet_sdk.v2.concurrency import ConcurrencyFunction -from hatchet_sdk.worker.worker import register_on_worker - -T = TypeVar("T") - - -def function( - name: str = "", - auto_register: bool = True, - on_events: list[str] | None = None, - on_crons: list[str] | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: list[RateLimit] | None = None, - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: Union["HatchetCallable[T]", None] = None, - default_priority: int | None = None, -) -> Callable[[Callable[[Context], str]], HatchetCallable[T]]: - def inner(func: Callable[[Context], T]) -> HatchetCallable[T]: - return HatchetCallable( - func=func, - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) - - return inner - - -def durable( - name: str = "", - auto_register: bool = True, 
- on_events: list[str] | None = None, - on_crons: list[str] | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: list[RateLimit] | None = None, - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: HatchetCallable[T] | None = None, - default_priority: int | None = None, -) -> Callable[[HatchetCallable[T]], HatchetCallable[T]]: - def inner(func: HatchetCallable[T]) -> HatchetCallable[T]: - func.durable = True - - f = function( - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) +import asyncio +import logging +from typing import TYPE_CHECKING, Any, Optional - resp = f(func) +from typing_extensions import deprecated - resp.durable = True +from hatchet_sdk.clients.rest_client import RestApi +from hatchet_sdk.features.cron import CronClient +from hatchet_sdk.features.scheduled import ScheduledClient +from hatchet_sdk.loader import ClientConfig +from hatchet_sdk.v2.workflows import StepType, step_factory +from hatchet_sdk.worker import Worker - return resp +from ..client import Client, new_client, new_client_raw +from ..clients.admin import AdminClient +from ..clients.dispatcher.dispatcher import DispatcherClient +from ..clients.events import EventClient +from ..clients.run_event_listener import RunEventListenerClient +from ..logger import logger - return inner +class HatchetRest: + """ + Main client for interacting with the Hatchet API. 
-def concurrency( - name: str = "concurrency", - max_runs: int = 1, - limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, -) -> Callable[[Callable[[Context], str]], ConcurrencyFunction]: - def inner(func: Callable[[Context], str]) -> ConcurrencyFunction: - return ConcurrencyFunction(func, name, max_runs, limit_strategy) + This class provides access to various client interfaces and utility methods + for working with Hatchet via the REST API, - return inner + Attributes: + rest (RestApi): Interface for REST API operations. + """ + def __init__(self, config: ClientConfig = ClientConfig()): + self.rest = RestApi(config.server_url, config.token, config.tenant_id) -class Hatchet(HatchetV1): - dag = staticmethod(workflow) - concurrency = staticmethod(concurrency) - functions: list[HatchetCallable[T]] = [] +class Hatchet: + """ + Main client for interacting with the Hatchet SDK. - def function( - self, - name: str = "", - auto_register: bool = True, - on_events: list[str] | None = None, - on_crons: list[str] | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - retries: int = 0, - rate_limits: list[RateLimit] | None = None, - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: Union["HatchetCallable[T]", None] = None, - default_priority: int | None = None, - ) -> Callable[[Callable[[Context], Any]], Callable[[Context], Any]]: - resp = function( - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) + This class provides access to various client interfaces and utility methods + for working with Hatchet workers, workflows, and steps. 
- def wrapper(func: Callable[[Context], str]) -> HatchetCallable[T]: - wrapped_resp = resp(func) + Attributes: + cron (CronClient): Interface for cron trigger operations. - if wrapped_resp.function_auto_register: - self.functions.append(wrapped_resp) + admin (AdminClient): Interface for administrative operations. + dispatcher (DispatcherClient): Interface for dispatching operations. + event (EventClient): Interface for event-related operations. + rest (RestApi): Interface for REST API operations. + """ - wrapped_resp.with_namespace(self._client.config.namespace) + _client: Client + cron: CronClient + scheduled: ScheduledClient - return wrapped_resp + @classmethod + def from_environment( + cls, defaults: ClientConfig = ClientConfig(), **kwargs: Any + ) -> "Hatchet": + return cls(client=new_client(defaults), **kwargs) - return wrapper + @classmethod + def from_config(cls, config: ClientConfig, **kwargs: Any) -> "Hatchet": + return cls(client=new_client_raw(config), **kwargs) - def durable( + def __init__( self, - name: str = "", - auto_register: bool = True, - on_events: list[str] | None = None, - on_crons: list[str] | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: list[RateLimit] | None = None, - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: Union["HatchetCallable[T]", None] = None, - default_priority: int | None = None, - ) -> Callable[[Callable[[DurableContext], Any]], Callable[[DurableContext], Any]]: - resp = durable( - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) 
- - def wrapper(func: HatchetCallable[T]) -> HatchetCallable[T]: - wrapped_resp = resp(func) - - if wrapped_resp.function_auto_register: - self.functions.append(wrapped_resp) - - wrapped_resp.with_namespace(self._client.config.namespace) - - return wrapped_resp - - return wrapper + debug: bool = False, + client: Optional[Client] = None, + config: ClientConfig = ClientConfig(), + ): + """ + Initialize a new Hatchet instance. + + Args: + debug (bool, optional): Enable debug logging. Defaults to False. + client (Optional[Client], optional): A pre-configured Client instance. Defaults to None. + config (ClientConfig, optional): Configuration for creating a new Client. Defaults to ClientConfig(). + """ + if client is not None: + self._client = client + else: + self._client = new_client(config, debug) + + if debug: + logger.setLevel(logging.DEBUG) + + self.cron = CronClient(self._client) + self.scheduled = ScheduledClient(self._client) + + @property + @deprecated( + "Direct access to client is deprecated and will be removed in a future version. Use specific client properties (Hatchet.admin, Hatchet.dispatcher, Hatchet.event, Hatchet.rest) instead. 
[0.32.0]", + ) + def client(self) -> Client: + return self._client + + @property + def admin(self) -> AdminClient: + return self._client.admin + + @property + def dispatcher(self) -> DispatcherClient: + return self._client.dispatcher + + @property + def event(self) -> EventClient: + return self._client.event + + @property + def rest(self) -> RestApi: + return self._client.rest + + @property + def listener(self) -> RunEventListenerClient: + return self._client.listener + + @property + def config(self) -> ClientConfig: + return self._client.config + + @property + def tenant_id(self) -> str: + return self._client.config.tenant_id + + step = staticmethod(step_factory(type=StepType.DEFAULT)) + on_failure_step = staticmethod(step_factory(type=StepType.ON_FAILURE)) def worker( self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {} - ): - worker = Worker( + ) -> Worker: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + return Worker( name=name, max_runs=max_runs, labels=labels, config=self._client.config, debug=self._client.debug, + owned_loop=loop is None, ) - - for func in self.functions: - register_on_worker(func, worker) - - return worker diff --git a/hatchet_sdk/v2/workflows.py b/hatchet_sdk/v2/workflows.py new file mode 100644 index 00000000..2c068e27 --- /dev/null +++ b/hatchet_sdk/v2/workflows.py @@ -0,0 +1,296 @@ +import asyncio +from enum import Enum +from typing import Any, Callable, ParamSpec, Type, TypeVar, Union + +from pydantic import BaseModel, ConfigDict + +from hatchet_sdk.clients.rest_client import RestApi +from hatchet_sdk.context.context import Context +from hatchet_sdk.contracts.workflows_pb2 import ( + ConcurrencyLimitStrategy, + CreateStepRateLimit, + CreateWorkflowJobOpts, + CreateWorkflowStepOpts, + CreateWorkflowVersionOpts, + DesiredWorkerLabels, + StickyStrategy, + WorkflowConcurrencyOpts, + WorkflowKind, +) +from hatchet_sdk.labels import DesiredWorkerLabel +from hatchet_sdk.rate_limit 
import RateLimit + +from ..logger import logger + +R = TypeVar("R") +P = ParamSpec("P") + + +class ConcurrencyExpression: + """ + Defines concurrency limits for a workflow using a CEL expression. + + Args: + expression (str): CEL expression to determine concurrency grouping. (i.e. "input.user_id") + max_runs (int): Maximum number of concurrent workflow runs. + limit_strategy (ConcurrencyLimitStrategy): Strategy for handling limit violations. + + Example: + ConcurrencyExpression("input.user_id", 5, ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS) + """ + + def __init__( + self, expression: str, max_runs: int, limit_strategy: ConcurrencyLimitStrategy + ): + self.expression = expression + self.max_runs = max_runs + self.limit_strategy = limit_strategy + + +class EmptyModel(BaseModel): + model_config = ConfigDict(extra="allow") + + +class WorkflowConfig(BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + name: str = "" + on_events: list[str] = [] + on_crons: list[str] = [] + version: str = "" + timeout: str = "60m" + schedule_timeout: str = "5m" + sticky: Union[StickyStrategy, None] = None + default_priority: int = 1 + concurrency: ConcurrencyExpression | None = None + input_validator: Type[BaseModel] = EmptyModel + + +class StepType(str, Enum): + DEFAULT = "default" + CONCURRENCY = "concurrency" + ON_FAILURE = "on_failure" + + +class Step: + def __init__( + self, + fn: Callable[[Any, Context], R], + type: StepType, + name: str = "", + timeout: str = "60m", + parents: list[str] = [], + retries: int = 0, + rate_limits: list[CreateStepRateLimit] = [], + desired_worker_labels: dict[str, DesiredWorkerLabels] = {}, + backoff_factor: float | None = None, + backoff_max_seconds: int | None = None, + ) -> None: + self.fn = fn + self.is_async_function = bool(asyncio.iscoroutinefunction(fn)) + + self.type = type + self.timeout = timeout + self.name = name + self.parents = parents + self.retries = retries + self.rate_limits = rate_limits + 
self.desired_worker_labels = desired_worker_labels + self.backoff_factor = backoff_factor + self.backoff_max_seconds = backoff_max_seconds + self.concurrency__max_runs = 1 + self.concurrency__limit_strategy = ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS + + def __call__(self, *args: Any, **kwargs: Any) -> R: + return self.fn(*args, **kwargs) + + +class Workflow: + config: WorkflowConfig = WorkflowConfig() + + def get_service_name(self, namespace: str) -> str: + return f"{namespace}{self.config.name.lower()}" + + def _get_steps_by_type(self, step_type: StepType) -> list[Step]: + return [ + attr + for _, attr in self.__class__.__dict__.items() + if isinstance(attr, Step) and attr.type == step_type + ] + + @property + def on_failure_steps(self) -> list[Step]: + return self._get_steps_by_type(StepType.ON_FAILURE) + + @property + def concurrency_actions(self) -> list[Step]: + return self._get_steps_by_type(StepType.CONCURRENCY) + + @property + def default_steps(self) -> list[Step]: + return self._get_steps_by_type(StepType.DEFAULT) + + @property + def steps(self) -> list[Step]: + return self.default_steps + self.concurrency_actions + self.on_failure_steps + + def create_action_name(self, namespace: str, step: Step) -> str: + return self.get_service_name(namespace) + ":" + step.name + + def __init__(self) -> None: + self.config.name = self.config.name or str(self.__class__.__name__) + + def get_name(self, namespace: str) -> str: + return namespace + self.config.name + + def validate_concurrency_actions( + self, service_name: str + ) -> WorkflowConcurrencyOpts | None: + if len(self.concurrency_actions) > 0 and self.config.concurrency: + raise ValueError( + "Error: Both concurrencyActions and concurrency_expression are defined. Please use only one concurrency configuration method." 
+ ) + + if len(self.concurrency_actions) > 0: + action = self.concurrency_actions[0] + + return WorkflowConcurrencyOpts( + action=service_name + ":" + action.name, + max_runs=action.concurrency__max_runs, + limit_strategy=action.concurrency__limit_strategy, + ) + + if self.config.concurrency: + return WorkflowConcurrencyOpts( + expression=self.config.concurrency.expression, + max_runs=self.config.concurrency.max_runs, + limit_strategy=self.config.concurrency.limit_strategy, + ) + + return None + + def validate_on_failure_steps( + self, name: str, service_name: str + ) -> CreateWorkflowJobOpts | None: + if not self.on_failure_steps: + return None + + on_failure_step = next(iter(self.on_failure_steps)) + + return CreateWorkflowJobOpts( + name=name + "-on-failure", + steps=[ + CreateWorkflowStepOpts( + readable_id=on_failure_step.name, + action=service_name + ":" + on_failure_step.name, + timeout=on_failure_step.timeout or "60s", + inputs="{}", + parents=[], + retries=on_failure_step.retries, + rate_limits=on_failure_step.rate_limits, + backoff_factor=on_failure_step.backoff_factor, + backoff_max_seconds=on_failure_step.backoff_max_seconds, + ) + ], + ) + + def validate_priority(self, default_priority: int | None) -> int | None: + validated_priority = ( + max(1, min(3, default_priority)) if default_priority else None + ) + if validated_priority != default_priority: + logger.warning( + "Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range." 
+ ) + + return validated_priority + + def get_create_opts(self, namespace: str) -> CreateWorkflowVersionOpts: + service_name = self.get_service_name(namespace) + + name = self.get_name(namespace) + event_triggers = [namespace + event for event in self.config.on_events] + + create_step_opts = [ + CreateWorkflowStepOpts( + readable_id=step.name, + action=service_name + ":" + step.name, + timeout=step.timeout or "60s", + inputs="{}", + parents=[x for x in step.parents], + retries=step.retries, + rate_limits=step.rate_limits, + worker_labels=step.desired_worker_labels, + backoff_factor=step.backoff_factor, + backoff_max_seconds=step.backoff_max_seconds, + ) + for step in self.steps + ] + + concurrency = self.validate_concurrency_actions(service_name) + on_failure_job = self.validate_on_failure_steps(name, service_name) + validated_priority = self.validate_priority(self.config.default_priority) + + return CreateWorkflowVersionOpts( + name=name, + kind=WorkflowKind.DAG, + version=self.config.version, + event_triggers=event_triggers, + cron_triggers=self.config.on_crons, + schedule_timeout=self.config.schedule_timeout, + sticky=self.config.sticky, + jobs=[ + CreateWorkflowJobOpts( + name=name, + steps=create_step_opts, + ) + ], + on_failure_job=on_failure_job, + concurrency=concurrency, + default_priority=validated_priority, + ) + + +def transform_desired_worker_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels: + value = d.value + return DesiredWorkerLabels( + strValue=value if not isinstance(value, int) else None, + intValue=value if isinstance(value, int) else None, + required=d.required, + weight=d.weight, + comparator=d.comparator, # type: ignore[arg-type] + ) + + +def step_factory( + type: StepType, +) -> Callable[..., Callable[[Callable[[Any, Context], R]], Step]]: + def _step( + name: str = "", + timeout: str = "60m", + parents: list[str] = [], + retries: int = 0, + rate_limits: list[RateLimit] = [], + desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, + 
backoff_factor: float | None = None, + backoff_max_seconds: int | None = None, + ) -> Callable[[Callable[[Any, Context], R]], Step]: + def inner(func: Callable[[Any, Context], R]) -> Step: + return Step( + fn=func, + type=type, + name=name.lower() or str(func.__name__).lower(), + timeout=timeout, + parents=parents, + retries=retries, + rate_limits=[r for rate_limit in rate_limits if (r := rate_limit._req)], + desired_worker_labels={ + key: transform_desired_worker_label(d) + for key, d in desired_worker_labels.items() + }, + backoff_factor=backoff_factor, + backoff_max_seconds=backoff_max_seconds, + ) + + return inner + + return _step diff --git a/hatchet_sdk/worker/action_listener_process.py b/hatchet_sdk/worker/action_listener_process.py index 08508607..6017ab68 100644 --- a/hatchet_sdk/worker/action_listener_process.py +++ b/hatchet_sdk/worker/action_listener_process.py @@ -4,16 +4,16 @@ import time from dataclasses import dataclass, field from multiprocessing import Queue -from typing import Any, List, Mapping, Optional +from typing import Any, List, Literal, Mapping, Optional import grpc -from hatchet_sdk.clients.dispatcher.action_listener import Action -from hatchet_sdk.clients.dispatcher.dispatcher import ( +from hatchet_sdk.clients.dispatcher.action_listener import ( + Action, ActionListener, GetActionListenerRequest, - new_dispatcher, ) +from hatchet_sdk.clients.dispatcher.dispatcher import new_dispatcher from hatchet_sdk.contracts.dispatcher_pb2 import ( GROUP_KEY_EVENT_TYPE_STARTED, STEP_EVENT_TYPE_STARTED, @@ -30,10 +30,11 @@ class ActionEvent: action: Action type: Any # TODO type - payload: Optional[str] = None + payload: str -STOP_LOOP = "STOP_LOOP" # Sentinel object to stop the loop +STOP_LOOP_TYPE = Literal["STOP_LOOP"] +STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP" # Sentinel object to stop the loop # TODO link to a block post BLOCKED_THREAD_WARNING = ( @@ -41,7 +42,7 @@ class ActionEvent: ) -def noop_handler(): +def noop_handler() -> None: pass @@ 
-51,22 +52,22 @@ class WorkerActionListenerProcess: actions: List[str] max_runs: int config: ClientConfig - action_queue: Queue - event_queue: Queue + action_queue: "Queue[Action]" + event_queue: "Queue[ActionEvent | STOP_LOOP_TYPE]" handle_kill: bool = True debug: bool = False - labels: dict = field(default_factory=dict) + labels: dict[str, str | int] = field(default_factory=dict) - listener: ActionListener = field(init=False, default=None) + listener: ActionListener = field(init=False) killing: bool = field(init=False, default=False) - action_loop_task: asyncio.Task = field(init=False, default=None) - event_send_loop_task: asyncio.Task = field(init=False, default=None) + action_loop_task: asyncio.Task[None] | None = field(init=False, default=None) + event_send_loop_task: asyncio.Task[None] | None = field(init=False, default=None) - running_step_runs: Mapping[str, float] = field(init=False, default_factory=dict) + running_step_runs: dict[str, float] = field(init=False, default_factory=dict) - def __post_init__(self): + def __post_init__(self) -> None: if self.debug: logger.setLevel(logging.DEBUG) @@ -77,7 +78,7 @@ def __post_init__(self): signal.SIGQUIT, lambda: asyncio.create_task(self.exit_gracefully()) ) - async def start(self, retry_attempt=0): + async def start(self, retry_attempt: int = 0) -> None: if retry_attempt > 5: logger.error("could not start action listener") return @@ -108,13 +109,13 @@ async def start(self, retry_attempt=0): self.blocked_main_loop = asyncio.create_task(self.start_blocked_main_loop()) # TODO move event methods to separate class - async def _get_event(self): + async def _get_event(self) -> ActionEvent | STOP_LOOP_TYPE: loop = asyncio.get_running_loop() return await loop.run_in_executor(None, self.event_queue.get) - async def start_event_send_loop(self): + async def start_event_send_loop(self) -> None: while True: - event: ActionEvent = await self._get_event() + event = await self._get_event() if event == STOP_LOOP: 
logger.debug("stopping event send loop...") break @@ -122,11 +123,11 @@ async def start_event_send_loop(self): logger.debug(f"tx: event: {event.action.action_id}/{event.type}") asyncio.create_task(self.send_event(event)) - async def start_blocked_main_loop(self): + async def start_blocked_main_loop(self) -> None: threshold = 1 while not self.killing: count = 0 - for step_run_id, start_time in self.running_step_runs.items(): + for _, start_time in self.running_step_runs.items(): diff = self.now() - start_time if diff > threshold: count += 1 @@ -135,7 +136,7 @@ async def start_blocked_main_loop(self): logger.warning(f"{BLOCKED_THREAD_WARNING}: Waiting Steps {count}") await asyncio.sleep(1) - async def send_event(self, event: ActionEvent, retry_attempt: int = 1): + async def send_event(self, event: ActionEvent, retry_attempt: int = 1) -> None: try: match event.action.action_type: # FIXME: all events sent from an execution of a function are of type ActionType.START_STEP_RUN since @@ -185,10 +186,10 @@ async def send_event(self, event: ActionEvent, retry_attempt: int = 1): await exp_backoff_sleep(retry_attempt, 1) await self.send_event(event, retry_attempt + 1) - def now(self): + def now(self) -> float: return time.time() - async def start_action_loop(self): + async def start_action_loop(self) -> None: try: async for action in self.listener: if action is None: @@ -201,6 +202,7 @@ async def start_action_loop(self): ActionEvent( action=action, type=STEP_EVENT_TYPE_STARTED, # TODO ack type + payload="", ) ) logger.info( @@ -220,6 +222,7 @@ async def start_action_loop(self): ActionEvent( action=action, type=GROUP_KEY_EVENT_TYPE_STARTED, # TODO ack type + payload="", ) ) logger.info( @@ -239,9 +242,9 @@ async def start_action_loop(self): finally: logger.info("action loop closed") if not self.killing: - await self.exit_gracefully(skip_unregister=True) + await self.exit_gracefully() - async def cleanup(self): + async def cleanup(self) -> None: self.killing = True if 
self.listener is not None: @@ -249,7 +252,7 @@ async def cleanup(self): self.event_queue.put(STOP_LOOP) - async def exit_gracefully(self, skip_unregister=False): + async def exit_gracefully(self) -> None: if self.killing: return @@ -262,13 +265,13 @@ async def exit_gracefully(self, skip_unregister=False): logger.info("action listener closed") - def exit_forcefully(self): + def exit_forcefully(self) -> None: asyncio.run(self.cleanup()) logger.debug("forcefully closing listener...") -def worker_action_listener_process(*args, **kwargs): - async def run(): +def worker_action_listener_process(*args: Any, **kwargs: Any) -> None: + async def run() -> None: process = WorkerActionListenerProcess(*args, **kwargs) await process.start() # Keep the process running diff --git a/hatchet_sdk/worker/runner/run_loop_manager.py b/hatchet_sdk/worker/runner/run_loop_manager.py index 27ed788c..972c9cd5 100644 --- a/hatchet_sdk/worker/runner/run_loop_manager.py +++ b/hatchet_sdk/worker/runner/run_loop_manager.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass, field from multiprocessing import Queue -from typing import Callable, TypeVar +from typing import Callable, Literal, TypeVar from hatchet_sdk import Context from hatchet_sdk.client import Client, new_client_raw @@ -10,10 +10,12 @@ from hatchet_sdk.loader import ClientConfig from hatchet_sdk.logger import logger from hatchet_sdk.utils.types import WorkflowValidator +from hatchet_sdk.worker.action_listener_process import ActionEvent from hatchet_sdk.worker.runner.runner import Runner from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs -STOP_LOOP = "STOP_LOOP" +STOP_LOOP_TYPE = Literal["STOP_LOOP"] +STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP" T = TypeVar("T") @@ -25,28 +27,28 @@ class WorkerActionRunLoopManager: validator_registry: dict[str, WorkflowValidator] max_runs: int | None config: ClientConfig - action_queue: Queue - event_queue: Queue + action_queue: "Queue[Action | STOP_LOOP_TYPE]" + 
event_queue: "Queue[ActionEvent]" loop: asyncio.AbstractEventLoop handle_kill: bool = True debug: bool = False labels: dict[str, str | int] = field(default_factory=dict) - client: Client = field(init=False, default=None) + client: Client = field(init=False) killing: bool = field(init=False, default=False) - runner: Runner = field(init=False, default=None) + runner: Runner | None = field(init=False, default=None) - def __post_init__(self): + def __post_init__(self) -> None: if self.debug: logger.setLevel(logging.DEBUG) self.client = new_client_raw(self.config, self.debug) self.start() - def start(self, retry_count=1): + def start(self, retry_count: int = 1) -> None: k = self.loop.create_task(self.async_start(retry_count)) - async def async_start(self, retry_count=1): + async def async_start(self, retry_count: int = 1) -> None: await capture_logs( self.client.logInterceptor, self.client.event, @@ -63,6 +65,7 @@ async def _async_start(self, retry_count: int = 1) -> None: def cleanup(self) -> None: self.killing = True + ## TODO: The action queue is a queue of `Action`, so I don't think this will work self.action_queue.put(STOP_LOOP) async def wait_for_tasks(self) -> None: @@ -83,7 +86,8 @@ async def _start_action_loop(self) -> None: logger.debug(f"'{self.name}' waiting for {list(self.action_registry.keys())}") while not self.killing: - action: Action = await self._get_action() + action = await self._get_action() + ## TODO: This is a queue of `Action`, so I don't think this will work if action == STOP_LOOP: logger.debug("stopping action runner loop...") break @@ -91,7 +95,7 @@ async def _start_action_loop(self) -> None: self.runner.run(action) logger.debug("action runner loop stopped") - async def _get_action(self): + async def _get_action(self) -> Action | STOP_LOOP_TYPE: return await self.loop.run_in_executor(None, self.action_queue.get) async def exit_gracefully(self) -> None: diff --git a/hatchet_sdk/worker/runner/runner.py b/hatchet_sdk/worker/runner/runner.py 
index f72fb04b..a5df9fc5 100644 --- a/hatchet_sdk/worker/runner/runner.py +++ b/hatchet_sdk/worker/runner/runner.py @@ -19,7 +19,7 @@ from hatchet_sdk.clients.dispatcher.dispatcher import new_dispatcher from hatchet_sdk.clients.run_event_listener import new_listener from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener -from hatchet_sdk.context import Context # type: ignore[attr-defined] +from hatchet_sdk.context.context import Context from hatchet_sdk.context.worker_context import WorkerContext from hatchet_sdk.contracts.dispatcher_pb2 import ( GROUP_KEY_EVENT_TYPE_COMPLETED, @@ -34,7 +34,6 @@ from hatchet_sdk.logger import logger from hatchet_sdk.utils.tracing import create_tracer, parse_carrier_from_metadata from hatchet_sdk.utils.types import WorkflowValidator -from hatchet_sdk.v2.callable import DurableContext from hatchet_sdk.worker.action_listener_process import ActionEvent from hatchet_sdk.worker.runner.utils.capture_logs import copy_context_vars, sr, wr @@ -276,23 +275,7 @@ def cleanup_run_id(self, run_id: str | None) -> None: if run_id in self.contexts: del self.contexts[run_id] - def create_context( - self, action: Action, action_func: Callable[..., Any] | None - ) -> Context | DurableContext: - if hasattr(action_func, "durable") and getattr(action_func, "durable"): - return DurableContext( - action, - self.dispatcher_client, - self.admin_client, - self.client.event, - self.client.rest, - self.client.workflow_listener, - self.workflow_run_event_listener, - self.worker_context, - self.client.config.namespace, - validator_registry=self.validator_registry, - ) - + def create_context(self, action: Action) -> Context: return Context( action, self.dispatcher_client, @@ -318,16 +301,13 @@ async def handle_start_step_run(self, action: Action) -> None: # Find the corresponding action function from the registry action_func = self.action_registry.get(action_name) - context = self.create_context(action, action_func) + context = 
self.create_context(action) self.contexts[action.step_run_id] = context if action_func: self.event_queue.put( - ActionEvent( - action=action, - type=STEP_EVENT_TYPE_STARTED, - ) + ActionEvent(action=action, type=STEP_EVENT_TYPE_STARTED, payload="") ) loop = asyncio.get_event_loop() @@ -377,8 +357,7 @@ async def handle_start_group_key_run(self, action: Action) -> None: # send an event that the group key run has started self.event_queue.put( ActionEvent( - action=action, - type=GROUP_KEY_EVENT_TYPE_STARTED, + action=action, type=GROUP_KEY_EVENT_TYPE_STARTED, payload="" ) ) diff --git a/hatchet_sdk/worker/runner/utils/capture_logs.py b/hatchet_sdk/worker/runner/utils/capture_logs.py index 08c57de8..6fec015c 100644 --- a/hatchet_sdk/worker/runner/utils/capture_logs.py +++ b/hatchet_sdk/worker/runner/utils/capture_logs.py @@ -2,11 +2,12 @@ import functools import logging from concurrent.futures import ThreadPoolExecutor +from contextvars import ContextVar from io import StringIO -from typing import Any, Coroutine +from typing import Any, Awaitable, Callable, Coroutine, ItemsView, ParamSpec, TypeVar -from hatchet_sdk import logger from hatchet_sdk.clients.events import EventClient +from hatchet_sdk.logger import logger wr: contextvars.ContextVar[str | None] = contextvars.ContextVar( "workflow_run_id", default=None @@ -16,7 +17,16 @@ ) -def copy_context_vars(ctx_vars, func, *args, **kwargs): +T = TypeVar("T") +P = ParamSpec("P") + + +def copy_context_vars( + ctx_vars: ItemsView[ContextVar[Any], Any], + func: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs, +) -> T: for var, value in ctx_vars: var.set(value) return func(*args, **kwargs) @@ -25,19 +35,20 @@ def copy_context_vars(ctx_vars, func, *args, **kwargs): class InjectingFilter(logging.Filter): # For some reason, only the InjectingFilter has access to the contextvars method sr.get(), # otherwise we would use emit within the CustomLogHandler - def filter(self, record): + def filter(self, record: 
logging.LogRecord) -> bool: + ## TODO: Change how we do this to not assign to the log record record.workflow_run_id = wr.get() record.step_run_id = sr.get() return True -class CustomLogHandler(logging.StreamHandler): - def __init__(self, event_client: EventClient, stream=None): +class CustomLogHandler(logging.StreamHandler[Any]): + def __init__(self, event_client: EventClient, stream: StringIO | None = None): super().__init__(stream) self.logger_thread_pool = ThreadPoolExecutor(max_workers=1) self.event_client = event_client - def _log(self, line: str, step_run_id: str | None): + def _log(self, line: str, step_run_id: str | None) -> None: try: if not step_run_id: return @@ -46,20 +57,20 @@ def _log(self, line: str, step_run_id: str | None): except Exception as e: logger.error(f"Error logging: {e}") - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: super().emit(record) log_entry = self.format(record) - self.logger_thread_pool.submit(self._log, log_entry, record.step_run_id) + + ## TODO: Change how we do this to not assign to the log record + self.logger_thread_pool.submit(self._log, log_entry, record.step_run_id) # type: ignore def capture_logs( - logger: logging.Logger, - event_client: EventClient, - func: Coroutine[Any, Any, Any], -): + logger: logging.Logger, event_client: "EventClient", func: Callable[P, Awaitable[T]] +) -> Callable[P, Awaitable[T]]: @functools.wraps(func) - async def wrapper(*args, **kwargs): + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: if not logger: raise Exception("No logger configured on client") diff --git a/hatchet_sdk/worker/runner/utils/error_with_traceback.py b/hatchet_sdk/worker/runner/utils/error_with_traceback.py index 9c09602f..6aff1cb6 100644 --- a/hatchet_sdk/worker/runner/utils/error_with_traceback.py +++ b/hatchet_sdk/worker/runner/utils/error_with_traceback.py @@ -1,6 +1,6 @@ import traceback -def errorWithTraceback(message: str, e: Exception): +def errorWithTraceback(message: str, 
e: Exception) -> str: trace = "".join(traceback.format_exception(type(e), e, e.__traceback__)) return f"{message}\n{trace}" diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py index b6ec1531..9ed0da7c 100644 --- a/hatchet_sdk/worker/worker.py +++ b/hatchet_sdk/worker/worker.py @@ -10,7 +10,7 @@ from multiprocessing import Queue from multiprocessing.process import BaseProcess from types import FrameType -from typing import Any, Callable, TypeVar, get_type_hints +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, get_type_hints from aiohttp import web from aiohttp.web_request import Request @@ -19,16 +19,24 @@ from hatchet_sdk import Context from hatchet_sdk.client import Client, new_client_raw +from hatchet_sdk.clients.dispatcher.action_listener import Action from hatchet_sdk.contracts.workflows_pb2 import CreateWorkflowVersionOpts from hatchet_sdk.loader import ClientConfig from hatchet_sdk.logger import logger from hatchet_sdk.utils.types import WorkflowValidator from hatchet_sdk.utils.typing import is_basemodel_subclass -from hatchet_sdk.v2.callable import HatchetCallable -from hatchet_sdk.v2.concurrency import ConcurrencyFunction -from hatchet_sdk.worker.action_listener_process import worker_action_listener_process -from hatchet_sdk.worker.runner.run_loop_manager import WorkerActionRunLoopManager -from hatchet_sdk.workflow import WorkflowInterface +from hatchet_sdk.worker.action_listener_process import ( + ActionEvent, + worker_action_listener_process, +) +from hatchet_sdk.worker.runner.run_loop_manager import ( + STOP_LOOP_TYPE, + WorkerActionRunLoopManager, +) + +if TYPE_CHECKING: + from hatchet_sdk.v2 import Workflow + from hatchet_sdk.v2.workflows import Step T = TypeVar("T") @@ -45,9 +53,6 @@ class WorkerStartOptions: loop: asyncio.AbstractEventLoop | None = field(default=None) -TWorkflow = TypeVar("TWorkflow", bound=object) - - class Worker: def __init__( self, @@ -76,13 +81,13 @@ def __init__( self._status: WorkerStatus 
self.action_listener_process: BaseProcess - self.action_listener_health_check: asyncio.Task[Any] + self.action_listener_health_check: asyncio.Task[None] self.action_runner: WorkerActionRunLoopManager self.ctx = multiprocessing.get_context("spawn") - self.action_queue: "Queue[Any]" = self.ctx.Queue() - self.event_queue: "Queue[Any]" = self.ctx.Queue() + self.action_queue: "Queue[Action | STOP_LOOP_TYPE]" = self.ctx.Queue() + self.event_queue: "Queue[ActionEvent]" = self.ctx.Queue() self.loop: asyncio.AbstractEventLoop @@ -108,10 +113,7 @@ def register_workflow_from_opts( logger.error(e) sys.exit(1) - def register_workflow(self, workflow: TWorkflow) -> None: - ## Hack for typing - assert isinstance(workflow, WorkflowInterface) - + def register_workflow(self, workflow: Union["Workflow", Any]) -> None: namespace = self.client.config.namespace try: @@ -124,24 +126,22 @@ def register_workflow(self, workflow: TWorkflow) -> None: sys.exit(1) def create_action_function( - action_func: Callable[..., T] + action_func: "Step" ) -> Callable[[Context], T]: def action_function(context: Context) -> T: return action_func(workflow, context) - if asyncio.iscoroutinefunction(action_func): - setattr(action_function, "is_coroutine", True) - else: - setattr(action_function, "is_coroutine", False) + setattr(action_function, "is_coroutine", action_func.is_async_function) return action_function - for action_name, action_func in workflow.get_actions(namespace): - self.action_registry[action_name] = create_action_function(action_func) - return_type = get_type_hints(action_func).get("return") + for step in workflow.steps: + action_name = workflow.create_action_name(namespace, step) + self.action_registry[action_name] = create_action_function(step) + return_type = get_type_hints(step.fn).get("return") self.validator_registry[action_name] = WorkflowValidator( - workflow_input=workflow.input_validator, + workflow_input=workflow.config.input_validator, step_output=return_type if 
is_basemodel_subclass(return_type) else None, ) @@ -195,12 +195,10 @@ async def start_health_server(self) -> None: logger.info(f"healthcheck server running on port {port}") - def start( - self, options: WorkerStartOptions = WorkerStartOptions() - ) -> Future[asyncio.Task[Any] | None]: + def start(self, options: WorkerStartOptions = WorkerStartOptions()) -> None: self.owned_loop = self.setup_loop(options.loop) - f = asyncio.run_coroutine_threadsafe( + asyncio.run_coroutine_threadsafe( self.async_start(options, _from_start=True), self.loop ) @@ -211,14 +209,12 @@ def start( if self.handle_kill: sys.exit(0) - return f - ## Start methods async def async_start( self, options: WorkerStartOptions = WorkerStartOptions(), _from_start: bool = False, - ) -> Any | None: + ) -> None: main_pid = os.getpid() logger.info("------------------------------------------") logger.info("STARTING HATCHET...") @@ -247,7 +243,7 @@ async def async_start( self._check_listener_health() ) - return await self.action_listener_health_check + await self.action_listener_health_check def _run_action_runner(self) -> WorkerActionRunLoopManager: # Retrieve the shared queue @@ -362,7 +358,8 @@ def exit_forcefully(self) -> None: logger.debug(f"forcefully stopping worker: {self.name}") - self.close() + ## TODO: `self.close` needs to be awaited / used + self.close() # type: ignore[unused-coroutine] if self.action_listener_process: self.action_listener_process.kill() # Forcefully kill the process @@ -371,22 +368,3 @@ def exit_forcefully(self) -> None: sys.exit( 1 ) # Exit immediately TODO - should we exit with 1 here, there may be other workers to cleanup - - -def register_on_worker(callable: HatchetCallable[T], worker: Worker) -> None: - worker.register_function(callable.get_action_name(), callable) - - if callable.function_on_failure is not None: - worker.register_function( - callable.function_on_failure.get_action_name(), callable.function_on_failure - ) - - if callable.function_concurrency is not None: - 
worker.register_function( - callable.function_concurrency.get_action_name(), - callable.function_concurrency, - ) - - opts = callable.to_workflow_opts() - - worker.register_workflow_from_opts(opts.name, opts) diff --git a/hatchet_sdk/workflow.py b/hatchet_sdk/workflow.py index 9c5cef90..4a1e045c 100644 --- a/hatchet_sdk/workflow.py +++ b/hatchet_sdk/workflow.py @@ -93,7 +93,7 @@ def get_create_opts(self, namespace: str) -> Any: ... version: str timeout: str schedule_timeout: str - sticky: Union[StickyStrategy.Value, None] # type: ignore[name-defined] + sticky: Union[StickyStrategy, None] default_priority: int | None concurrency_expression: ConcurrencyExpression | None input_validator: Type[BaseModel] | None diff --git a/hatchet_sdk/workflow_run.py b/hatchet_sdk/workflow_run.py index 51a23821..3bf570db 100644 --- a/hatchet_sdk/workflow_run.py +++ b/hatchet_sdk/workflow_run.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Coroutine, Generic, Optional, TypedDict, TypeVar +from typing import Any, Coroutine, Generic, TypeVar, cast from hatchet_sdk.clients.run_event_listener import ( RunEventListener, @@ -22,19 +22,19 @@ def __init__( self.workflow_listener = workflow_listener self.workflow_run_event_listener = workflow_run_event_listener - def __str__(self): + def __str__(self) -> str: return self.workflow_run_id def stream(self) -> RunEventListener: return self.workflow_run_event_listener.stream(self.workflow_run_id) - def result(self) -> Coroutine: + def result(self) -> Coroutine[None, None, dict[str, Any]]: return self.workflow_listener.result(self.workflow_run_id) - def sync_result(self) -> dict: + def sync_result(self) -> dict[str, Any]: loop = get_active_event_loop() if loop is None: - with EventLoopThread() as loop: + with EventLoopThread() as loop: # type: ignore[call-arg] coro = self.workflow_listener.result(self.workflow_run_id) future = asyncio.run_coroutine_threadsafe(coro, loop) return future.result() @@ -48,7 +48,7 @@ def sync_result(self) -> 
dict: class RunRef(WorkflowRunRef, Generic[T]): - async def result(self) -> T: + async def result(self) -> Any | dict[str, Any]: res = await self.workflow_listener.result(self.workflow_run_id) if len(res) == 1: diff --git a/poetry.lock b/poetry.lock index 603caef7..2462047e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.0.0 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -6,6 +6,7 @@ version = "2.4.4" description = "Happy Eyeballs for asyncio" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "aiohappyeyeballs-2.4.4-py3-none-any.whl", hash = "sha256:a980909d50efcd44795c4afeca523296716d50cd756ddca6af8c65b996e27de8"}, {file = "aiohappyeyeballs-2.4.4.tar.gz", hash = "sha256:5fdd7d87889c63183afc18ce9271f9b0a7d32c2303e394468dd45d514a757745"}, @@ -17,6 +18,7 @@ version = "3.11.11" description = "Async http client/server framework (asyncio)" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "aiohttp-3.11.11-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a60804bff28662cbcf340a4d61598891f12eea3a66af48ecfdc975ceec21e3c8"}, {file = "aiohttp-3.11.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4b4fa1cb5f270fb3eab079536b764ad740bb749ce69a94d4ec30ceee1b5940d5"}, @@ -115,6 +117,7 @@ version = "2.9.1" description = "Simple retry client for aiohttp" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "aiohttp_retry-2.9.1-py3-none-any.whl", hash = "sha256:66d2759d1921838256a05a3f80ad7e724936f083e35be5abb5e16eed6be6dc54"}, {file = "aiohttp_retry-2.9.1.tar.gz", hash = "sha256:8eb75e904ed4ee5c2ec242fefe85bf04240f685391c4879d8f541d6028ff01f1"}, @@ -129,6 +132,7 @@ version = "1.3.2" description = "aiosignal: a list of registered asynchronous callbacks" optional = false python-versions = ">=3.9" +groups = 
["main"] files = [ {file = "aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5"}, {file = "aiosignal-1.3.2.tar.gz", hash = "sha256:a8c255c66fafb1e499c9351d0bf32ff2d8a0321595ebac3b93713656d2436f54"}, @@ -143,6 +147,7 @@ version = "0.5.2" description = "Generator-based operators for asynchronous iteration" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "aiostream-0.5.2-py3-none-any.whl", hash = "sha256:054660370be9d37f6fe3ece3851009240416bd082e469fd90cc8673d3818cf71"}, {file = "aiostream-0.5.2.tar.gz", hash = "sha256:b71b519a2d66c38f0872403ab86417955b77352f08d9ad02ad46fc3926b389f4"}, @@ -157,6 +162,7 @@ version = "0.7.0" description = "Reusable constraint types to use with typing.Annotated" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, @@ -168,6 +174,8 @@ version = "5.0.1" description = "Timeout context manager for asyncio programs" optional = false python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.11\"" files = [ {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, @@ -179,6 +187,7 @@ version = "24.3.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "attrs-24.3.0-py3-none-any.whl", hash = "sha256:ac96cd038792094f438ad1f6ff80837353805ac950cd2aa0e0625ef19850c308"}, {file = "attrs-24.3.0.tar.gz", hash = "sha256:8f5c07333d543103541ba7be0e2ce16eeee8130cb0b3f9238ab904ce1e85baff"}, @@ -198,6 +207,7 
@@ version = "2.16.0" description = "Internationalization utilities" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "babel-2.16.0-py3-none-any.whl", hash = "sha256:368b5b98b37c06b7daf6696391c3240c938b37767d4584413e8438c5c435fa8b"}, {file = "babel-2.16.0.tar.gz", hash = "sha256:d1f3554ca26605fe173f3de0c65f750f5a42f924499bf134de6423582298e316"}, @@ -212,6 +222,7 @@ version = "24.10.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.9" +groups = ["lint"] files = [ {file = "black-24.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e6668650ea4b685440857138e5fe40cde4d652633b1bdffc62933d0db4ed9812"}, {file = "black-24.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1c536fcf674217e87b8cc3657b81809d3c085d7bf3ef262ead700da345bfa6ea"}, @@ -258,6 +269,7 @@ version = "0.1.5" description = "Pure Python CEL Implementation" optional = false python-versions = ">=3.7, <4" +groups = ["main"] files = [ {file = "cel-python-0.1.5.tar.gz", hash = "sha256:d3911bb046bc3ed12792bd88ab453f72d98c66923b72a2fa016bcdffd96e2f98"}, {file = "cel_python-0.1.5-py3-none-any.whl", hash = "sha256:ac81fab8ba08b633700a45d84905be2863529c6a32935c9da7ef53fc06844f1a"}, @@ -278,6 +290,7 @@ version = "2024.12.14" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "certifi-2024.12.14-py3-none-any.whl", hash = "sha256:1275f7a45be9464efc1173084eaa30f866fe2e47d389406136d332ed4967ec56"}, {file = "certifi-2024.12.14.tar.gz", hash = "sha256:b650d30f370c2b724812bee08008be0c4163b163ddaec3f2546c1caf65f191db"}, @@ -289,6 +302,7 @@ version = "3.4.1" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 
optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de"}, {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176"}, @@ -390,6 +404,7 @@ version = "8.1.8" description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" +groups = ["lint"] files = [ {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, @@ -404,10 +419,12 @@ version = "0.4.6" description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["main", "dev", "lint", "test"] files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +markers = {main = "sys_platform == \"win32\"", dev = "sys_platform == \"win32\"", lint = "platform_system == \"Windows\"", test = "sys_platform == \"win32\""} [[package]] name = "deprecated" @@ -415,6 +432,7 @@ version = "1.2.15" description = "Python @deprecated decorator to deprecate old python classes, functions or methods." 
optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" +groups = ["main"] files = [ {file = "Deprecated-1.2.15-py2.py3-none-any.whl", hash = "sha256:353bc4a8ac4bfc96800ddab349d89c25dec1079f65fd53acdcc1e0b975b21320"}, {file = "deprecated-1.2.15.tar.gz", hash = "sha256:683e561a90de76239796e6b6feac66b99030d2dd3fcf61ef996330f14bbb9b0d"}, @@ -432,6 +450,8 @@ version = "1.2.2" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" +groups = ["dev", "test"] +markers = "python_version < \"3.11\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, @@ -446,6 +466,7 @@ version = "1.5.0" description = "A list-like structure which implements collections.abc.MutableSequence" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5b6a66c18b5b9dd261ca98dffcb826a525334b2f29e7caa54e182255c5f6a65a"}, {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d1b3eb7b05ea246510b43a7e53ed1653e55c2121019a97e60cad7efb881a97bb"}, @@ -547,6 +568,7 @@ version = "1.66.0" description = "Common protobufs used in Google APIs" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "googleapis_common_protos-1.66.0-py2.py3-none-any.whl", hash = "sha256:d7abcd75fabb2e0ec9f74466401f6c119a0b498e27370e9be4c94cb7e382b8ed"}, {file = "googleapis_common_protos-1.66.0.tar.gz", hash = "sha256:c3e7b33d15fdca5374cc0a7346dd92ffa847425cc4ea941d970f13680052ec8c"}, @@ -558,12 +580,28 @@ protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4 [package.extras] grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] +[[package]] +name = "grpc-stubs" +version = "1.53.0.5" +description = "Mypy 
stubs for gRPC" +optional = false +python-versions = ">=3.6" +groups = ["dev"] +files = [ + {file = "grpc-stubs-1.53.0.5.tar.gz", hash = "sha256:3e1b642775cbc3e0c6332cfcedfccb022176db87e518757bef3a1241397be406"}, + {file = "grpc_stubs-1.53.0.5-py3-none-any.whl", hash = "sha256:04183fb65a1b166a1febb9627e3d9647d3926ccc2dfe049fe7b6af243428dbe1"}, +] + +[package.dependencies] +grpcio = "*" + [[package]] name = "grpcio" version = "1.69.0" description = "HTTP/2-based RPC framework" optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "grpcio-1.69.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2060ca95a8db295ae828d0fc1c7f38fb26ccd5edf9aa51a0f44251f5da332e97"}, {file = "grpcio-1.69.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:2e52e107261fd8fa8fa457fe44bfadb904ae869d87c1280bf60f93ecd3e79278"}, @@ -631,6 +669,7 @@ version = "1.69.0" description = "Protobuf code generator for gRPC" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "grpcio_tools-1.69.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:8c210630faa581c3bd08953dac4ad21a7f49862f3b92d69686e9b436d2f1265d"}, {file = "grpcio_tools-1.69.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:09b66ea279fcdaebae4ec34b1baf7577af3b14322738aa980c1c33cfea71f7d7"}, @@ -700,6 +739,7 @@ version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -714,6 +754,7 @@ version = "8.5.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "importlib_metadata-8.5.0-py3-none-any.whl", hash = 
"sha256:45e54197d28b7a7f1559e60b95e7c567032b602131fbd588f1497f47880aa68b"}, {file = "importlib_metadata-8.5.0.tar.gz", hash = "sha256:71522656f0abace1d072b9e5481a48f07c138e00f079c38c8f883823f9c26bd7"}, @@ -737,6 +778,7 @@ version = "2.0.0" description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.7" +groups = ["dev", "test"] files = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, @@ -748,6 +790,7 @@ version = "5.13.2" description = "A Python utility / library to sort Python imports." optional = false python-versions = ">=3.8.0" +groups = ["lint"] files = [ {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, @@ -762,6 +805,7 @@ version = "1.0.1" description = "JSON Matching Expressions" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, @@ -773,6 +817,7 @@ version = "0.12.0" description = "a modern parsing library" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "lark-parser-0.12.0.tar.gz", hash = "sha256:15967db1f1214013dca65b1180745047b9be457d73da224fcda3d9dd4e96a138"}, {file = "lark_parser-0.12.0-py2.py3-none-any.whl", hash = "sha256:0eaf30cb5ba787fe404d73a7d6e61df97b21d5a63ac26c5008c78a494373c675"}, @@ -789,6 +834,7 @@ version = "0.7.3" description = "Python logging made (stupidly) simple" optional = false python-versions = "<4.0,>=3.5" +groups = ["main"] files = [ 
{file = "loguru-0.7.3-py3-none-any.whl", hash = "sha256:31a33c10c8e1e10422bfd431aeb5d351c7cf7fa671e3c4df004162264b28220c"}, {file = "loguru-0.7.3.tar.gz", hash = "sha256:19480589e77d47b8d85b2c827ad95d49bf31b0dcde16593892eb51dd18706eb6"}, @@ -807,6 +853,7 @@ version = "6.1.0" description = "multidict implementation" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "multidict-6.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3380252550e372e8511d49481bd836264c009adb826b23fefcc5dd3c69692f60"}, {file = "multidict-6.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:99f826cbf970077383d7de805c0681799491cb939c25450b9b5b3ced03ca99f1"}, @@ -911,6 +958,7 @@ version = "1.14.1" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" +groups = ["lint"] files = [ {file = "mypy-1.14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:52686e37cf13d559f668aa398dd7ddf1f92c5d613e4f8cb262be2fb4fedb0fcb"}, {file = "mypy-1.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1fb545ca340537d4b45d3eecdb3def05e913299ca72c290326be19b3804b39c0"}, @@ -970,6 +1018,7 @@ version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." 
optional = false python-versions = ">=3.5" +groups = ["lint"] files = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, @@ -981,6 +1030,7 @@ version = "1.6.0" description = "Patch asyncio to allow nested event loops" optional = false python-versions = ">=3.5" +groups = ["main"] files = [ {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"}, {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, @@ -992,6 +1042,7 @@ version = "1.29.0" description = "OpenTelemetry Python API" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_api-1.29.0-py3-none-any.whl", hash = "sha256:5fcd94c4141cc49c736271f3e1efb777bebe9cc535759c54c936cca4f1b312b8"}, {file = "opentelemetry_api-1.29.0.tar.gz", hash = "sha256:d04a6cf78aad09614f52964ecb38021e248f5714dc32c2e0d8fd99517b4d69cf"}, @@ -1007,6 +1058,7 @@ version = "0.50b0" description = "OpenTelemetry Python Distro" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_distro-0.50b0-py3-none-any.whl", hash = "sha256:5fa2e2a99a047ea477fab53e73fb8088b907bda141e8440745b92eb2a84d74aa"}, {file = "opentelemetry_distro-0.50b0.tar.gz", hash = "sha256:3e059e00f53553ebd646d1162d1d3edf5d7c6d3ceafd54a49e74c90dc1c39a7d"}, @@ -1026,6 +1078,7 @@ version = "1.29.0" description = "OpenTelemetry Collector Exporters" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_exporter_otlp-1.29.0-py3-none-any.whl", hash = "sha256:b8da6e20f5b0ffe604154b1e16a407eade17ce310c42fb85bb4e1246fc3688ad"}, {file = "opentelemetry_exporter_otlp-1.29.0.tar.gz", hash = 
"sha256:ee7dfcccbb5e87ad9b389908452e10b7beeab55f70a83f41ce5b8c4efbde6544"}, @@ -1041,6 +1094,7 @@ version = "1.29.0" description = "OpenTelemetry Protobuf encoding" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_exporter_otlp_proto_common-1.29.0-py3-none-any.whl", hash = "sha256:a9d7376c06b4da9cf350677bcddb9618ed4b8255c3f6476975f5e38274ecd3aa"}, {file = "opentelemetry_exporter_otlp_proto_common-1.29.0.tar.gz", hash = "sha256:e7c39b5dbd1b78fe199e40ddfe477e6983cb61aa74ba836df09c3869a3e3e163"}, @@ -1055,6 +1109,7 @@ version = "1.29.0" description = "OpenTelemetry Collector Protobuf over gRPC Exporter" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_exporter_otlp_proto_grpc-1.29.0-py3-none-any.whl", hash = "sha256:5a2a3a741a2543ed162676cf3eefc2b4150e6f4f0a193187afb0d0e65039c69c"}, {file = "opentelemetry_exporter_otlp_proto_grpc-1.29.0.tar.gz", hash = "sha256:3d324d07d64574d72ed178698de3d717f62a059a93b6b7685ee3e303384e73ea"}, @@ -1075,6 +1130,7 @@ version = "1.29.0" description = "OpenTelemetry Collector Protobuf over HTTP Exporter" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_exporter_otlp_proto_http-1.29.0-py3-none-any.whl", hash = "sha256:b228bdc0f0cfab82eeea834a7f0ffdd2a258b26aa33d89fb426c29e8e934d9d0"}, {file = "opentelemetry_exporter_otlp_proto_http-1.29.0.tar.gz", hash = "sha256:b10d174e3189716f49d386d66361fbcf6f2b9ad81e05404acdee3f65c8214204"}, @@ -1095,6 +1151,7 @@ version = "0.50b0" description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_instrumentation-0.50b0-py3-none-any.whl", hash = "sha256:b8f9fc8812de36e1c6dffa5bfc6224df258841fb387b6dfe5df15099daa10630"}, {file = "opentelemetry_instrumentation-0.50b0.tar.gz", hash = "sha256:7d98af72de8dec5323e5202e46122e5f908592b22c6d24733aad619f07d82979"}, 
@@ -1112,6 +1169,7 @@ version = "1.29.0" description = "OpenTelemetry Python Proto" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_proto-1.29.0-py3-none-any.whl", hash = "sha256:495069c6f5495cbf732501cdcd3b7f60fda2b9d3d4255706ca99b7ca8dec53ff"}, {file = "opentelemetry_proto-1.29.0.tar.gz", hash = "sha256:3c136aa293782e9b44978c738fff72877a4b78b5d21a64e879898db7b2d93e5d"}, @@ -1126,6 +1184,7 @@ version = "1.29.0" description = "OpenTelemetry Python SDK" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_sdk-1.29.0-py3-none-any.whl", hash = "sha256:173be3b5d3f8f7d671f20ea37056710217959e774e2749d984355d1f9391a30a"}, {file = "opentelemetry_sdk-1.29.0.tar.gz", hash = "sha256:b0787ce6aade6ab84315302e72bd7a7f2f014b0fb1b7c3295b88afe014ed0643"}, @@ -1142,6 +1201,7 @@ version = "0.50b0" description = "OpenTelemetry Semantic Conventions" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_semantic_conventions-0.50b0-py3-none-any.whl", hash = "sha256:e87efba8fdb67fb38113efea6a349531e75ed7ffc01562f65b802fcecb5e115e"}, {file = "opentelemetry_semantic_conventions-0.50b0.tar.gz", hash = "sha256:02dc6dbcb62f082de9b877ff19a3f1ffaa3c306300fa53bfac761c4567c83d38"}, @@ -1157,6 +1217,7 @@ version = "24.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" +groups = ["main", "dev", "lint", "test"] files = [ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, @@ -1168,6 +1229,7 @@ version = "0.12.1" description = "Utility library for gitignore style pattern matching of file paths." 
optional = false python-versions = ">=3.8" +groups = ["lint"] files = [ {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, @@ -1179,6 +1241,7 @@ version = "4.3.6" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." optional = false python-versions = ">=3.8" +groups = ["lint"] files = [ {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, @@ -1195,6 +1258,7 @@ version = "1.5.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" +groups = ["dev", "test"] files = [ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, @@ -1210,6 +1274,7 @@ version = "0.21.1" description = "Python client for the Prometheus monitoring system." 
optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "prometheus_client-0.21.1-py3-none-any.whl", hash = "sha256:594b45c410d6f4f8888940fe80b5cc2521b305a1fafe1c58609ef715a001f301"}, {file = "prometheus_client-0.21.1.tar.gz", hash = "sha256:252505a722ac04b0456be05c05f75f45d760c2911ffc45f2a06bcaed9f3ae3fb"}, @@ -1224,6 +1289,7 @@ version = "0.2.1" description = "Accelerated property cache" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "propcache-0.2.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:6b3f39a85d671436ee3d12c017f8fdea38509e4f25b28eb25877293c98c243f6"}, {file = "propcache-0.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:39d51fbe4285d5db5d92a929e3e21536ea3dd43732c5b177c7ef03f918dff9f2"}, @@ -1315,6 +1381,7 @@ version = "5.29.2" description = "" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "protobuf-5.29.2-cp310-abi3-win32.whl", hash = "sha256:c12ba8249f5624300cf51c3d0bfe5be71a60c63e4dcf51ffe9a68771d958c851"}, {file = "protobuf-5.29.2-cp310-abi3-win_amd64.whl", hash = "sha256:842de6d9241134a973aab719ab42b008a18a90f9f07f06ba480df268f86432f9"}, @@ -1335,6 +1402,7 @@ version = "6.1.1" description = "Cross-platform lib for process and system monitoring in Python." 
optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +groups = ["dev"] files = [ {file = "psutil-6.1.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:9ccc4316f24409159897799b83004cb1e24f9819b0dcf9c0b68bdcb6cefee6a8"}, {file = "psutil-6.1.1-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ca9609c77ea3b8481ab005da74ed894035936223422dc591d6772b147421f777"}, @@ -1365,6 +1433,7 @@ version = "2.10.4" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pydantic-2.10.4-py3-none-any.whl", hash = "sha256:597e135ea68be3a37552fb524bc7d0d66dcf93d395acd93a00682f1efcb8ee3d"}, {file = "pydantic-2.10.4.tar.gz", hash = "sha256:82f12e9723da6de4fe2ba888b5971157b3be7ad914267dea8f05f82b28254f06"}, @@ -1385,6 +1454,7 @@ version = "2.27.2" description = "Core functionality for Pydantic validation and serialization" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pydantic_core-2.27.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2d367ca20b2f14095a8f4fa1210f5a7b78b8a20009ecced6b12818f455b1e9fa"}, {file = "pydantic_core-2.27.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:491a2b73db93fab69731eaee494f320faa4e093dbed776be1a829c2eb222c34c"}, @@ -1497,6 +1567,7 @@ version = "8.3.4" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" +groups = ["dev", "test"] files = [ {file = "pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6"}, {file = "pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761"}, @@ -1519,6 +1590,7 @@ version = "0.23.8" description = "Pytest support for asyncio" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pytest_asyncio-0.23.8-py3-none-any.whl", hash = 
"sha256:50265d892689a5faefb84df80819d1ecef566eb3549cf915dfb33569359d1ce2"}, {file = "pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3"}, @@ -1531,12 +1603,32 @@ pytest = ">=7.0.0,<9" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] +[[package]] +name = "pytest-env" +version = "1.1.5" +description = "pytest plugin that allows you to add environment variables." +optional = false +python-versions = ">=3.8" +groups = ["test"] +files = [ + {file = "pytest_env-1.1.5-py3-none-any.whl", hash = "sha256:ce90cf8772878515c24b31cd97c7fa1f4481cd68d588419fd45f10ecaee6bc30"}, + {file = "pytest_env-1.1.5.tar.gz", hash = "sha256:91209840aa0e43385073ac464a554ad2947cc2fd663a9debf88d03b01e0cc1cf"}, +] + +[package.dependencies] +pytest = ">=8.3.3" +tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "pytest-mock (>=3.14)"] + [[package]] name = "pytest-timeout" version = "2.3.1" description = "pytest plugin to abort hanging tests" optional = false python-versions = ">=3.7" +groups = ["test"] files = [ {file = "pytest-timeout-2.3.1.tar.gz", hash = "sha256:12397729125c6ecbdaca01035b9e5239d4db97352320af155b3f5de1ba5165d9"}, {file = "pytest_timeout-2.3.1-py3-none-any.whl", hash = "sha256:68188cb703edfc6a18fad98dc25a3c61e9f24d644b0b70f33af545219fc7813e"}, @@ -1551,6 +1643,7 @@ version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] files = [ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, @@ -1565,6 +1658,7 @@ version = "1.0.1" description = 
"Read key-value pairs from a .env file and set them as environment variables" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, @@ -1579,6 +1673,7 @@ version = "6.0.2" description = "YAML parser and emitter for Python" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, @@ -1641,6 +1736,7 @@ version = "2.32.3" description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, @@ -1662,6 +1758,7 @@ version = "75.7.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "setuptools-75.7.0-py3-none-any.whl", hash = "sha256:84fb203f278ebcf5cd08f97d3fb96d3fbed4b629d500b29ad60d11e00769b183"}, {file = "setuptools-75.7.0.tar.gz", hash = "sha256:886ff7b16cd342f1d1defc16fc98c9ce3fde69e087a4e1983d7ab634e5f41f4f"}, @@ -1682,6 +1779,7 @@ version = "1.17.0" description = "Python 2 and 3 compatibility utilities" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = 
"sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -1693,6 +1791,7 @@ version = "9.0.0" description = "Retry code until it succeeds" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"}, {file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"}, @@ -1708,6 +1807,8 @@ version = "2.2.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" +groups = ["dev", "lint", "test"] +markers = "python_version < \"3.11\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, @@ -1749,17 +1850,31 @@ version = "5.29.1.20241207" description = "Typing stubs for protobuf" optional = false python-versions = ">=3.8" +groups = ["lint"] files = [ {file = "types_protobuf-5.29.1.20241207-py3-none-any.whl", hash = "sha256:92893c42083e9b718c678badc0af7a9a1307b92afe1599e5cba5f3d35b668b2f"}, {file = "types_protobuf-5.29.1.20241207.tar.gz", hash = "sha256:2ebcadb8ab3ef2e3e2f067e0882906d64ba0dc65fc5b0fd7a8b692315b4a0be9"}, ] +[[package]] +name = "types-psutil" +version = "6.1.0.20241221" +description = "Typing stubs for psutil" +optional = false +python-versions = ">=3.8" +groups = ["lint"] +files = [ + {file = "types_psutil-6.1.0.20241221-py3-none-any.whl", hash = "sha256:8498dbe13285a9ba7d4b2fa934c569cc380efc74e3dacdb34ae16d2cdf389ec3"}, + {file = "types_psutil-6.1.0.20241221.tar.gz", hash = "sha256:600f5a36bd5e0eb8887f0e3f3ff2cf154d90690ad8123c8a707bba4ab94d3185"}, +] + [[package]] name = 
"typing-extensions" version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" +groups = ["main", "lint"] files = [ {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, @@ -1771,6 +1886,7 @@ version = "2.3.0" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df"}, {file = "urllib3-2.3.0.tar.gz", hash = "sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d"}, @@ -1788,6 +1904,8 @@ version = "1.2.0" description = "A small Python utility to set file creation time on Windows" optional = false python-versions = ">=3.5" +groups = ["main"] +markers = "sys_platform == \"win32\"" files = [ {file = "win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390"}, {file = "win32_setctime-1.2.0.tar.gz", hash = "sha256:ae1fdf948f5640aae05c511ade119313fb6a30d7eabe25fef9764dca5873c4c0"}, @@ -1802,6 +1920,7 @@ version = "1.17.0" description = "Module for decorators, wrappers and monkey patching." 
optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "wrapt-1.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2a0c23b8319848426f305f9cb0c98a6e32ee68a36264f45948ccf8e7d2b941f8"}, {file = "wrapt-1.17.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1ca5f060e205f72bec57faae5bd817a1560fcfc4af03f414b08fa29106b7e2d"}, @@ -1876,6 +1995,7 @@ version = "1.18.3" description = "Yet another URL library" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "yarl-1.18.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7df647e8edd71f000a5208fe6ff8c382a1de8edfbccdbbfe649d263de07d8c34"}, {file = "yarl-1.18.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c69697d3adff5aa4f874b19c0e4ed65180ceed6318ec856ebc423aa5850d84f7"}, @@ -1972,6 +2092,7 @@ version = "3.21.0" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931"}, {file = "zipp-3.21.0.tar.gz", hash = "sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4"}, @@ -1986,6 +2107,6 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", type = ["pytest-mypy"] [metadata] -lock-version = "2.0" +lock-version = "2.1" python-versions = "^3.10" -content-hash = "414d63b255f80d13260cb3a9ecce29f782af46280bba79395554595a47c42f05" +content-hash = "e59b746d16c418856dbf00015dfb396a703e13961b53e0196bb511c530920e47" diff --git a/pyproject.toml b/pyproject.toml index 099b3e2c..3a21726f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,11 +10,11 @@ include = ["hatchet_sdk/py.typed"] python = "^3.10" grpcio = [ { version = ">=1.64.1, !=1.68.*", markers = "python_version < '3.13'" }, - { version = ">=1.69.0", markers = "python_version >= '3.13'" } + { version = ">=1.69.0", markers = 
"python_version >= '3.13'" }, ] grpcio-tools = [ { version = ">=1.64.1, !=1.68.*", markers = "python_version < '3.13'" }, - { version = ">=1.69.0", markers = "python_version >= '3.13'" } + { version = ">=1.69.0", markers = "python_version >= '3.13'" }, ] python-dotenv = "^1.0.0" protobuf = "^5.29.1" @@ -41,15 +41,18 @@ prometheus-client = "^0.21.1" pytest = "^8.2.2" pytest-asyncio = "^0.23.8" psutil = "^6.0.0" +grpc-stubs = "^1.53.0.5" [tool.poetry.group.lint.dependencies] mypy = "^1.14.0" types-protobuf = "^5.28.3.20241030" black = "^24.10.0" isort = "^5.13.2" +types-psutil = "^6.1.0.20241221" [tool.poetry.group.test.dependencies] pytest-timeout = "^2.3.1" +pytest-env = "^1.1.5" [build-system] requires = ["poetry-core"] @@ -57,19 +60,22 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] log_cli = true +env = [ + "HATCHET_CLIENT_TOKEN=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJodHRwOi8vbG9jYWxob3N0OjEyMzQiLCJleHAiOjk5OTk5OTk5OTksImdycGNfYnJvYWRjYXN0X2FkZHJlc3MiOiJodHRwOi8vbG9jYWxob3N0OjQ0MyIsImlhdCI6MTIzNDU2Nzg5MSwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDoxMjM0Iiwic2VydmVyX3VybCI6Imh0dHA6Ly9sb2NhbGhvc3Q6MTIzNCIsInN1YiI6IjAwMDAwMDAwLTVmN2QtNGM1NS1iZmEzLWFkZDk0MTc4YjhmNyIsInRva2VuX2lkIjoiMDAwMDAwMDAtZmU5ZS00ZGEyLThmOTgtNTQ5YTgxOWRmZTE5In0.bIly53KfKcXP_7wjySWvbmxG9cVqit-fzVQAF5K7rPc", +] [tool.isort] profile = "black" known_third_party = [ - "grpcio", - "grpcio_tools", - "loguru", - "protobuf", - "pydantic", - "python_dotenv", - "python_dateutil", - "pyyaml", - "urllib3", + "grpcio", + "grpcio_tools", + "loguru", + "protobuf", + "pydantic", + "python_dotenv", + "python_dateutil", + "pyyaml", + "urllib3", ] extend_skip = ["hatchet_sdk/contracts/"] @@ -77,26 +83,17 @@ extend_skip = ["hatchet_sdk/contracts/"] extend_exclude = "hatchet_sdk/contracts/" [tool.mypy] -strict = true -files = [ - "hatchet_sdk/hatchet.py", - "hatchet_sdk/worker/worker.py", - "hatchet_sdk/context/context.py", - "hatchet_sdk/worker/runner/runner.py", - "hatchet_sdk/workflow.py", - 
"hatchet_sdk/utils/serialization.py", - "hatchet_sdk/utils/tracing.py", - "hatchet_sdk/utils/types.py", - "hatchet_sdk/utils/backoff.py", - "examples/**/*.py", - "hatchet_sdk/clients/rest/models/workflow_list.py", - "hatchet_sdk/clients/rest/models/workflow_run.py", - "hatchet_sdk/context/worker_context.py", - "hatchet_sdk/clients/dispatcher/dispatcher.py", +files = ["."] +exclude = [ + "hatchet_sdk/clients/rest/api/*", + "hatchet_sdk/clients/rest/models/*", + "hatchet_sdk/contracts", + "hatchet_sdk/clients/rest/api_client.py", + "hatchet_sdk/clients/rest/configuration.py", + "hatchet_sdk/clients/rest/exceptions.py", + "hatchet_sdk/clients/rest/rest.py", ] -follow_imports = "silent" -disable_error_code = ["unused-coroutine"] -explicit_package_bases = true +strict = true [tool.poetry.scripts] api = "examples.api.api:main" @@ -121,4 +118,3 @@ existing_loop = "examples.worker_existing_loop.worker:main" bulk_fanout = "examples.bulk_fanout.worker:main" retries_with_backoff = "examples.retries_with_backoff.worker:main" pydantic = "examples.pydantic.worker:main" -v2_simple = "examples.v2.simple.worker:main" diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 00000000..f72571f5 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,17 @@ +import os + +from hatchet_sdk.loader import DEFAULT_HOST_PORT, ClientConfig + + +def test_client_initialization_from_defaults() -> None: + assert isinstance(ClientConfig(), ClientConfig) + + +def test_client_host_port_overrides() -> None: + host_port = "localhost:8080" + with_host_port = ClientConfig(host_port=host_port) + assert with_host_port.host_port == host_port + assert with_host_port.server_url == host_port + + assert ClientConfig().host_port != host_port + assert ClientConfig().server_url != host_port