diff --git a/conftest.py b/conftest.py index 2aff5cd3..41598425 100644 --- a/conftest.py +++ b/conftest.py @@ -1,35 +1,82 @@ import logging +import os import subprocess import time from io import BytesIO from threading import Thread -from typing import AsyncGenerator, Callable, cast +from typing import AsyncGenerator, Callable, Generator, cast import psutil import pytest import pytest_asyncio -from hatchet_sdk import Hatchet +from hatchet_sdk import ClientConfig, Hatchet +from hatchet_sdk.loader import ClientTLSConfig + + +@pytest.fixture(scope="session", autouse=True) +def token() -> str: + result = subprocess.run( + [ + "docker", + "compose", + "run", + "--no-deps", + "setup-config", + "/hatchet/hatchet-admin", + "token", + "create", + "--config", + "/hatchet/config", + "--tenant-id", + "707d0855-80ab-4e1f-a156-f1c4546cbf52", + ], + capture_output=True, + text=True, + ) + + token = result.stdout.strip() + + os.environ["HATCHET_CLIENT_TOKEN"] = token + + return token @pytest_asyncio.fixture(scope="session") -async def aiohatchet() -> AsyncGenerator[Hatchet, None]: - yield Hatchet(debug=True) +async def aiohatchet(token: str) -> AsyncGenerator[Hatchet, None]: + yield Hatchet( + debug=True, + config=ClientConfig( + token=token, + tls_config=ClientTLSConfig(strategy="none"), + ), + ) @pytest.fixture(scope="session") -def hatchet() -> Hatchet: - return Hatchet(debug=True) +def hatchet(token: str) -> Hatchet: + return Hatchet( + debug=True, + config=ClientConfig( + token=token, + tls_config=ClientTLSConfig(strategy="none"), + ), + ) @pytest.fixture() -def worker(request: pytest.FixtureRequest): +def worker( + request: pytest.FixtureRequest, +) -> Generator[subprocess.Popen[bytes], None, None]: example = cast(str, request.param) command = ["poetry", "run", example] logging.info(f"Starting background worker: {' '.join(command)}") - proc = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + proc = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=os.environ.copy() + ) # Check if the process is still running if proc.poll() is not None: diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/_deprecated/README.md b/examples/_deprecated/README.md deleted file mode 100644 index ee47c61b..00000000 --- a/examples/_deprecated/README.md +++ /dev/null @@ -1 +0,0 @@ -The examples and tests in this directory are deprecated, but we're maintaining them to ensure backwards compatibility. diff --git a/examples/_deprecated/concurrency_limit_rr/event.py b/examples/_deprecated/concurrency_limit_rr/event.py deleted file mode 100644 index 16b2bcd0..00000000 --- a/examples/_deprecated/concurrency_limit_rr/event.py +++ /dev/null @@ -1,15 +0,0 @@ -from dotenv import load_dotenv - -from hatchet_sdk import new_client - -load_dotenv() - -client = new_client() - -for i in range(200): - group = "0" - - if i % 2 == 0: - group = "1" - - client.event.push("concurrency-test", {"group": group}) diff --git a/examples/_deprecated/concurrency_limit_rr/test_dep_concurrency_limit_rr.py b/examples/_deprecated/concurrency_limit_rr/test_dep_concurrency_limit_rr.py deleted file mode 100644 index 978186d1..00000000 --- a/examples/_deprecated/concurrency_limit_rr/test_dep_concurrency_limit_rr.py +++ /dev/null @@ -1,58 +0,0 @@ -import asyncio -import time - -import pytest - -from hatchet_sdk import Hatchet, Worker -from hatchet_sdk.workflow_run import WorkflowRunRef - - -# requires scope module or higher for shared event loop -@pytest.mark.parametrize("worker", ["concurrency_limit_rr"], indirect=True) -@pytest.mark.skip(reason="The timing for this test is not reliable") -@pytest.mark.asyncio(scope="session") -async def test_run(aiohatchet: Hatchet, worker: Worker) -> None: - num_groups = 2 - runs: list[WorkflowRunRef] = [] - - # Start all runs - for i in range(1, num_groups + 1): - run = aiohatchet.admin.run_workflow("ConcurrencyDemoWorkflowRR", {"group": i}) - runs.append(run) - run = aiohatchet.admin.run_workflow("ConcurrencyDemoWorkflowRR", {"group": i}) - runs.append(run) - - # Wait for all results - successful_runs = [] - cancelled_runs = [] - - start_time = time.time() - - # Process each run individually - for i, run in enumerate(runs, start=1): - try: - result = await run.result() - successful_runs.append((i, result)) - except Exception as e: - if "CANCELLED_BY_CONCURRENCY_LIMIT" in str(e): - cancelled_runs.append((i, str(e))) - else: - raise # Re-raise if it's an unexpected error - - end_time = time.time() - total_time = end_time - start_time - - # Check that we have the correct number of successful and cancelled runs - assert ( - len(successful_runs) == 4 - ), f"Expected 4 successful runs, got {len(successful_runs)}" - assert ( - len(cancelled_runs) == 0 - ), f"Expected 0 cancelled run, got {len(cancelled_runs)}" - - # Check that the total time is close to 2 seconds - assert ( - 3.8 <= total_time <= 7 - ), f"Expected runtime to be about 4 seconds, but it took {total_time:.2f} seconds" - - print(f"Total execution time: {total_time:.2f} seconds") diff --git a/examples/_deprecated/concurrency_limit_rr/worker.py b/examples/_deprecated/concurrency_limit_rr/worker.py deleted file mode 100644 index 9678e798..00000000 --- a/examples/_deprecated/concurrency_limit_rr/worker.py +++ /dev/null @@ -1,38 +0,0 @@ -import time - -from dotenv import load_dotenv - -from hatchet_sdk import ConcurrencyLimitStrategy, Context, Hatchet - -load_dotenv() - -hatchet = Hatchet(debug=True) - - -@hatchet.workflow(on_events=["concurrency-test"], schedule_timeout="10m") -class ConcurrencyDemoWorkflowRR: - - # NOTE: We're replacing the concurrency key function with a CEL expression - # to simplify architecture. - # See ../../concurrency_limit_rr/worker.py for the new implementation. - @hatchet.concurrency( - max_runs=1, limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN - ) - def concurrency(self, context: Context) -> str: - input = context.workflow_input() - print(input) - return f'group-{input["group"]}' - - @hatchet.step() - def step1(self, context: Context) -> None: - print("starting step1") - time.sleep(2) - print("finished step1") - pass - - -workflow = ConcurrencyDemoWorkflowRR() -worker = hatchet.worker("concurrency-demo-worker-rr", max_runs=10) -worker.register_workflow(workflow) - -worker.start() diff --git a/examples/_deprecated/test_event_client.py b/examples/_deprecated/test_event_client.py deleted file mode 100644 index 41c6ad65..00000000 --- a/examples/_deprecated/test_event_client.py +++ /dev/null @@ -1,26 +0,0 @@ -import pytest -from dotenv import load_dotenv - -from hatchet_sdk import new_client -from hatchet_sdk.hatchet import Hatchet - -load_dotenv() - - -@pytest.mark.asyncio(scope="session") -async def test_direct_client_event() -> None: - client = new_client() - e = client.event.push("user:create", {"test": "test"}) - - assert e.eventId is not None - - -@pytest.mark.filterwarnings( - "ignore:Direct access to client is deprecated:DeprecationWarning" -) -@pytest.mark.asyncio(scope="session") -async def test_hatchet_client_event() -> None: - hatchet = Hatchet() - e = hatchet.client.event.push("user:create", {"test": "test"}) - - assert e.eventId is not None diff --git a/examples/affinity-workers/event.py b/examples/affinity-workers/event.py index 3d4cae41..83e7024d 100644 --- a/examples/affinity-workers/event.py +++ b/examples/affinity-workers/event.py @@ -1,13 +1,10 @@ -from dotenv import load_dotenv - +from hatchet_sdk.clients.events import PushEventOptions from hatchet_sdk.hatchet import Hatchet -load_dotenv() - hatchet = Hatchet(debug=True) hatchet.event.push( "affinity:run", {"test": "test"}, - options={"additional_metadata": {"hello": "moon"}}, + options=PushEventOptions(additional_metadata={"hello": "moon"}), ) diff --git a/examples/affinity-workers/worker.py b/examples/affinity-workers/worker.py index 5099a804..beb9d0f9 100644 --- a/examples/affinity-workers/worker.py +++ b/examples/affinity-workers/worker.py @@ -1,22 +1,22 @@ -from dotenv import load_dotenv +from hatchet_sdk import BaseWorkflow, Context, Hatchet, WorkerLabelComparator +from hatchet_sdk.labels import DesiredWorkerLabel -from hatchet_sdk import Context, Hatchet, WorkerLabelComparator +hatchet = Hatchet(debug=True) -load_dotenv() +wf = hatchet.declare_workflow(on_events=["affinity:run"]) -hatchet = Hatchet(debug=True) +class AffinityWorkflow(BaseWorkflow): + config = wf.config -@hatchet.workflow(on_events=["affinity:run"]) -class AffinityWorkflow: @hatchet.step( desired_worker_labels={ - "model": {"value": "fancy-ai-model-v2", "weight": 10}, - "memory": { - "value": 256, - "required": True, - "comparator": WorkerLabelComparator.LESS_THAN, - }, + "model": DesiredWorkerLabel(value="fancy-ai-model-v2", weight=10), + "memory": DesiredWorkerLabel( + value=256, + required=True, + comparator=WorkerLabelComparator.LESS_THAN, + ), }, ) async def step(self, context: Context) -> dict[str, str | None]: diff --git a/examples/api/api.py b/examples/api/api.py index b8232846..530232ef 100644 --- a/examples/api/api.py +++ b/examples/api/api.py @@ -1,8 +1,4 @@ -from dotenv import load_dotenv - -from hatchet_sdk import Hatchet, WorkflowList - -load_dotenv() +from hatchet_sdk import Hatchet hatchet = Hatchet(debug=True) diff --git a/examples/api/async_api.py b/examples/api/async_api.py index 3d1b36bd..88181c39 100644 --- a/examples/api/async_api.py +++ b/examples/api/async_api.py @@ -1,11 +1,6 @@ import asyncio -from typing import cast -from dotenv import load_dotenv - -from hatchet_sdk import Hatchet, WorkflowList - -load_dotenv() +from hatchet_sdk import Hatchet hatchet = Hatchet(debug=True) diff --git a/examples/async/event.py b/examples/async/event.py deleted file mode 100644 index a0fc8cf7..00000000 --- a/examples/async/event.py +++ /dev/null @@ -1,8 +0,0 @@ -from dotenv import load_dotenv - -from hatchet_sdk import Hatchet - -load_dotenv() - -hatchet = Hatchet() -hatchet.event.push("async:create", {"test": "test"}) diff --git a/examples/async/test_async.py b/examples/async/test_async.py deleted file mode 100644 index 00b6c805..00000000 --- a/examples/async/test_async.py +++ /dev/null @@ -1,21 +0,0 @@ -import pytest - -from hatchet_sdk import Hatchet, Worker - - -# requires scope module or higher for shared event loop -@pytest.mark.asyncio(scope="session") -@pytest.mark.parametrize("worker", ["async"], indirect=True) -async def test_run(hatchet: Hatchet, worker: Worker) -> None: - run = hatchet.admin.run_workflow("AsyncWorkflow", {}) - result = await run.result() - assert result["step1"]["test"] == "test" - - -@pytest.mark.parametrize("worker", ["async"], indirect=True) -@pytest.mark.skip(reason="Skipping this test until we can dedicate more time to debug") -@pytest.mark.asyncio(scope="session") -async def test_run_async(aiohatchet: Hatchet, worker: Worker) -> None: - run = await aiohatchet.admin.aio.run_workflow("AsyncWorkflow", {}) - result = await run.result() - assert result["step1"]["test"] == "test" diff --git a/examples/async/worker.py b/examples/async/worker.py deleted file mode 100644 index 55cf3d0d..00000000 --- a/examples/async/worker.py +++ /dev/null @@ -1,37 +0,0 @@ -import asyncio - -from dotenv import load_dotenv - -from hatchet_sdk import Context, Hatchet - -load_dotenv() - -hatchet = Hatchet(debug=True) - - -@hatchet.workflow(on_events=["async:create"]) -class AsyncWorkflow: - - @hatchet.step(timeout="10s") - async def step1(self, context: Context) -> dict[str, str]: - print("started step1") - return {"test": "test"} - - @hatchet.step(parents=["step1"], timeout="10s") - async def step2(self, context: Context) -> None: - print("finished step2") - - -async def _main() -> None: - workflow = AsyncWorkflow() - worker = hatchet.worker("async-worker", max_runs=4) - worker.register_workflow(workflow) - await worker.async_start() - - -def main() -> None: - asyncio.run(_main()) - - -if __name__ == "__main__": - main() diff --git a/examples/blocked_async/event.py b/examples/blocked_async/event.py index 116b227d..4d6b5eab 100644 --- a/examples/blocked_async/event.py +++ b/examples/blocked_async/event.py @@ -1,12 +1,9 @@ -from dotenv import load_dotenv - from hatchet_sdk import PushEventOptions, new_client -load_dotenv() - client = new_client() -# client.event.push("user:create", {"test": "test"}) client.event.push( - "user:create", {"test": "test"}, options={"additional_metadata": {"hello": "moon"}} + "user:create", + {"test": "test"}, + options=PushEventOptions(additional_metadata={"hello": "moon"}), ) diff --git a/examples/blocked_async/worker.py b/examples/blocked_async/worker.py index a6aa3e18..e1ffcb28 100644 --- a/examples/blocked_async/worker.py +++ b/examples/blocked_async/worker.py @@ -1,11 +1,7 @@ import hashlib import time -from dotenv import load_dotenv - -from hatchet_sdk import Context, Hatchet - -load_dotenv() +from hatchet_sdk import BaseWorkflow, Context, Hatchet hatchet = Hatchet(debug=True) @@ -15,9 +11,12 @@ # # You do not want to run long sync functions in an async def function +wf = hatchet.declare_workflow(on_events=["user:create"]) + + +class Blocked(BaseWorkflow): + config = wf.config -@hatchet.workflow(on_events=["user:create"]) -class Blocked: @hatchet.step(timeout="11s", retries=3) async def step1(self, context: Context) -> dict[str, str | int | float]: print("Executing step1") @@ -43,9 +42,8 @@ async def step1(self, context: Context) -> dict[str, str | int | float]: def main() -> None: - workflow = Blocked() worker = hatchet.worker("blocked-worker", max_runs=3) - worker.register_workflow(workflow) + worker.register_workflow(Blocked()) worker.start() diff --git a/examples/bulk_fanout/bulk_trigger.py b/examples/bulk_fanout/bulk_trigger.py index d0606673..c596bb98 100644 --- a/examples/bulk_fanout/bulk_trigger.py +++ b/examples/bulk_fanout/bulk_trigger.py @@ -1,40 +1,29 @@ import asyncio -import base64 -import json -import os -from typing import Any - -from dotenv import load_dotenv from hatchet_sdk import new_client -from hatchet_sdk.clients.admin import TriggerWorkflowOptions -from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun -from hatchet_sdk.clients.run_event_listener import StepRunEventType +from hatchet_sdk.clients.admin import TriggerWorkflowOptions, WorkflowRunDict async def main() -> None: - load_dotenv() + hatchet = new_client() - workflowRuns: list[dict[str, Any]] = [] - - # we are going to run the BulkParent workflow 20 which will trigger the Child workflows n times for each n in range(20) - for i in range(20): - workflowRuns.append( - { - "workflow_name": "BulkParent", - "input": {"n": i}, - "options": { - "additional_metadata": { - "bulk-trigger": i, - "hello-{i}": "earth-{i}", - }, - }, - } + workflow_runs = [ + WorkflowRunDict( + workflow_name="BulkParent", + input={"n": i}, + options=TriggerWorkflowOptions( + additional_metadata={ + "bulk-trigger": i, + "hello-{i}": "earth-{i}", + } + ), ) + for i in range(20) + ] workflowRunRefs = hatchet.admin.run_workflows( - workflowRuns, + workflow_runs, ) results = await asyncio.gather( diff --git a/examples/bulk_fanout/stream.py b/examples/bulk_fanout/stream.py index 08d0cb4a..c357d882 100644 --- a/examples/bulk_fanout/stream.py +++ b/examples/bulk_fanout/stream.py @@ -1,19 +1,12 @@ import asyncio -import base64 -import json -import os import random -from dotenv import load_dotenv - -from hatchet_sdk import new_client +from hatchet_sdk import Hatchet from hatchet_sdk.clients.admin import TriggerWorkflowOptions -from hatchet_sdk.clients.run_event_listener import StepRunEventType -from hatchet_sdk.v2.hatchet import Hatchet async def main() -> None: - load_dotenv() + hatchet = Hatchet() # Generate a random stream key to use to track all @@ -28,10 +21,10 @@ async def main() -> None: # This key gets propagated to all child workflows # and can have an arbitrary property name. - workflowRun = hatchet.admin.run_workflow( + hatchet.admin.run_workflow( "Parent", {"n": 2}, - options={"additional_metadata": {streamKey: streamVal}}, + options=TriggerWorkflowOptions(additional_metadata={streamKey: streamVal}), ) # Stream all events for the additional meta key value diff --git a/examples/bulk_fanout/trigger.py b/examples/bulk_fanout/trigger.py index 1a1b3f17..ddcd8524 100644 --- a/examples/bulk_fanout/trigger.py +++ b/examples/bulk_fanout/trigger.py @@ -1,24 +1,17 @@ import asyncio -import base64 -import json -import os - -from dotenv import load_dotenv from hatchet_sdk import new_client -from hatchet_sdk.clients.admin import TriggerWorkflowOptions -from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun -from hatchet_sdk.clients.run_event_listener import StepRunEventType +from hatchet_sdk.clients.events import PushEventOptions async def main() -> None: - load_dotenv() - hatchet = new_client() - workflowRuns: WorkflowRun = [] # type: ignore[assignment] + hatchet = new_client() - event = hatchet.event.push( - "parent:create", {"n": 999}, {"additional_metadata": {"no-dedupe": "world"}} + hatchet.event.push( + "parent:create", + {"n": 999}, + PushEventOptions(additional_metadata={"no-dedupe": "world"}), ) diff --git a/examples/bulk_fanout/worker.py b/examples/bulk_fanout/worker.py index e0ea3c50..9ac89e0a 100644 --- a/examples/bulk_fanout/worker.py +++ b/examples/bulk_fanout/worker.py @@ -1,18 +1,33 @@ import asyncio from typing import Any -from dotenv import load_dotenv +from pydantic import BaseModel -from hatchet_sdk import Context, Hatchet -from hatchet_sdk.clients.admin import ChildWorkflowRunDict - -load_dotenv() +from hatchet_sdk import BaseWorkflow, Context, Hatchet +from hatchet_sdk.clients.admin import ChildTriggerWorkflowOptions hatchet = Hatchet(debug=True) -@hatchet.workflow(on_events=["parent:create"]) -class BulkParent: +class ParentInput(BaseModel): + n: int = 100 + + +class ChildInput(BaseModel): + a: str + + +bulk_parent_wf = hatchet.declare_workflow( + on_events=["parent:create"], input_validator=ParentInput +) +bulk_child_wf = hatchet.declare_workflow( + on_events=["child:create"], input_validator=ChildInput +) + + +class BulkParent(BaseWorkflow): + config = bulk_parent_wf.config + @hatchet.step(timeout="5m") async def spawn(self, context: Context) -> dict[str, list[Any]]: print("spawning child") @@ -20,25 +35,23 @@ async def spawn(self, context: Context) -> dict[str, list[Any]]: context.put_stream("spawning...") results = [] - n = context.workflow_input().get("n", 100) - - child_workflow_runs: list[ChildWorkflowRunDict] = [] + n = bulk_parent_wf.get_workflow_input(context).n - for i in range(n): - - child_workflow_runs.append( - { - "workflow_name": "BulkChild", - "input": {"a": str(i)}, - "key": f"child{i}", - "options": {"additional_metadata": {"hello": "earth"}}, - } + child_workflow_runs = [ + bulk_child_wf.construct_spawn_workflow_input( + input=ChildInput(a=str(i)), + key=f"child{i}", + options=ChildTriggerWorkflowOptions( + additional_metadata={"hello": "earth"} + ), ) + for i in range(n) + ] if len(child_workflow_runs) == 0: return {} - spawn_results = await context.aio.spawn_workflows(child_workflow_runs) + spawn_results = await bulk_child_wf.spawn_many(context, child_workflow_runs) results = await asyncio.gather( *[workflowRunRef.result() for workflowRunRef in spawn_results], @@ -56,11 +69,12 @@ async def spawn(self, context: Context) -> dict[str, list[Any]]: return {"results": results} -@hatchet.workflow(on_events=["child:create"]) -class BulkChild: +class BulkChild(BaseWorkflow): + config = bulk_child_wf.config + @hatchet.step() def process(self, context: Context) -> dict[str, str]: - a = context.workflow_input()["a"] + a = bulk_child_wf.get_workflow_input(context).a print(f"child process {a}") context.put_stream("child 1...") return {"status": "success " + a} diff --git a/examples/cancellation/worker.py b/examples/cancellation/worker.py index 17716602..20de5eea 100644 --- a/examples/cancellation/worker.py +++ b/examples/cancellation/worker.py @@ -1,16 +1,15 @@ import asyncio -from dotenv import load_dotenv +from hatchet_sdk import BaseWorkflow, Context, Hatchet -from hatchet_sdk import Context, Hatchet +hatchet = Hatchet(debug=True) -load_dotenv() +wf = hatchet.declare_workflow(on_events=["user:create"]) -hatchet = Hatchet(debug=True) +class CancelWorkflow(BaseWorkflow): + config = wf.config -@hatchet.workflow(on_events=["user:create"]) -class CancelWorkflow: @hatchet.step(timeout="10s", retries=1) async def step1(self, context: Context) -> None: i = 0 @@ -24,9 +23,8 @@ async def step1(self, context: Context) -> None: def main() -> None: - workflow = CancelWorkflow() worker = hatchet.worker("cancellation-worker", max_runs=4) - worker.register_workflow(workflow) + worker.register_workflow(CancelWorkflow()) worker.start() diff --git a/examples/concurrency_limit/event.py b/examples/concurrency_limit/event.py index 599f48d7..e6662720 100644 --- a/examples/concurrency_limit/event.py +++ b/examples/concurrency_limit/event.py @@ -1,9 +1,5 @@ -from dotenv import load_dotenv - from hatchet_sdk import new_client -load_dotenv() - client = new_client() client.event.push("concurrency-test", {"group": "test"}) diff --git a/examples/concurrency_limit/worker.py b/examples/concurrency_limit/worker.py index 1474b2d2..4c25e376 100644 --- a/examples/concurrency_limit/worker.py +++ b/examples/concurrency_limit/worker.py @@ -1,39 +1,50 @@ import time from typing import Any -from dotenv import load_dotenv +from pydantic import BaseModel + +from hatchet_sdk import ( + BaseWorkflow, + ConcurrencyExpression, + ConcurrencyLimitStrategy, + Context, + Hatchet, +) -from hatchet_sdk import Context, Hatchet -from hatchet_sdk.contracts.workflows_pb2 import ConcurrencyLimitStrategy -from hatchet_sdk.workflow import ConcurrencyExpression +hatchet = Hatchet(debug=True) -load_dotenv() -hatchet = Hatchet(debug=True) +class WorkflowInput(BaseModel): + run: int + group: str -@hatchet.workflow( +wf = hatchet.declare_workflow( on_events=["concurrency-test"], concurrency=ConcurrencyExpression( expression="input.group", max_runs=5, limit_strategy=ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS, ), + input_validator=WorkflowInput, ) -class ConcurrencyDemoWorkflow: + + +class ConcurrencyDemoWorkflow(BaseWorkflow): + + config = wf.config @hatchet.step() def step1(self, context: Context) -> dict[str, Any]: - input = context.workflow_input() + input = wf.get_workflow_input(context) time.sleep(3) print("executed step1") - return {"run": input["run"]} + return {"run": input.run} def main() -> None: - workflow = ConcurrencyDemoWorkflow() worker = hatchet.worker("concurrency-demo-worker", max_runs=10) - worker.register_workflow(workflow) + worker.register_workflow(ConcurrencyDemoWorkflow()) worker.start() diff --git a/examples/concurrency_limit_rr/event.py b/examples/concurrency_limit_rr/event.py index 16b2bcd0..6b58f2cf 100644 --- a/examples/concurrency_limit_rr/event.py +++ b/examples/concurrency_limit_rr/event.py @@ -1,9 +1,5 @@ -from dotenv import load_dotenv - from hatchet_sdk import new_client -load_dotenv() - client = new_client() for i in range(200): diff --git a/examples/concurrency_limit_rr/worker.py b/examples/concurrency_limit_rr/worker.py index 58722682..9645cca0 100644 --- a/examples/concurrency_limit_rr/worker.py +++ b/examples/concurrency_limit_rr/worker.py @@ -1,20 +1,16 @@ import time -from dotenv import load_dotenv - from hatchet_sdk import ( + BaseWorkflow, ConcurrencyExpression, ConcurrencyLimitStrategy, Context, Hatchet, ) -load_dotenv() - hatchet = Hatchet(debug=True) - -@hatchet.workflow( +wf = hatchet.declare_workflow( on_events=["concurrency-test"], schedule_timeout="10m", concurrency=ConcurrencyExpression( @@ -23,7 +19,10 @@ limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, ), ) -class ConcurrencyDemoWorkflowRR: + + +class ConcurrencyDemoWorkflowRR(BaseWorkflow): + config = wf.config @hatchet.step() def step1(self, context: Context) -> None: @@ -34,9 +33,8 @@ def step1(self, context: Context) -> None: def main() -> None: - workflow = ConcurrencyDemoWorkflowRR() worker = hatchet.worker("concurrency-demo-worker-rr", max_runs=10) - worker.register_workflow(workflow) + worker.register_workflow(ConcurrencyDemoWorkflowRR()) worker.start() diff --git a/examples/cron/programatic-async.py b/examples/cron/programatic-async.py index 0108d2d2..287c57c3 100644 --- a/examples/cron/programatic-async.py +++ b/examples/cron/programatic-async.py @@ -1,9 +1,5 @@ -from dotenv import load_dotenv - from hatchet_sdk import Hatchet -load_dotenv() - hatchet = Hatchet() @@ -21,11 +17,11 @@ async def create_cron() -> None: }, ) - id = cron_trigger.metadata.id # the id of the cron trigger + cron_trigger.metadata.id # the id of the cron trigger # !! # ❓ List - cron_triggers = await hatchet.cron.aio.list() + await hatchet.cron.aio.list() # !! # ❓ Get diff --git a/examples/cron/programatic-sync.py b/examples/cron/programatic-sync.py index d5c74d48..cc557714 100644 --- a/examples/cron/programatic-sync.py +++ b/examples/cron/programatic-sync.py @@ -1,9 +1,5 @@ -from dotenv import load_dotenv - from hatchet_sdk import Hatchet -load_dotenv() - hatchet = Hatchet() # ❓ Create diff --git a/examples/cron/workflow-definition.py b/examples/cron/workflow-definition.py index 1e49076f..c506a79e 100644 --- a/examples/cron/workflow-definition.py +++ b/examples/cron/workflow-definition.py @@ -1,10 +1,4 @@ -import time - -from dotenv import load_dotenv - -from hatchet_sdk import Context, Hatchet - -load_dotenv() +from hatchet_sdk import BaseWorkflow, Context, Hatchet hatchet = Hatchet(debug=True) @@ -13,8 +7,13 @@ # Adding a cron trigger to a workflow is as simple # as adding a `cron expression` to the `on_cron` # prop of the workflow definition -@hatchet.workflow(on_crons=["* * * * *"]) -class CronWorkflow: + +wf = hatchet.declare_workflow(on_crons=["* * * * *"]) + + +class CronWorkflow(BaseWorkflow): + config = wf.config + @hatchet.step() def step1(self, context: Context) -> dict[str, str]: @@ -27,9 +26,8 @@ def step1(self, context: Context) -> dict[str, str]: def main() -> None: - workflow = CronWorkflow() worker = hatchet.worker("test-worker", max_runs=1) - worker.register_workflow(workflow) + worker.register_workflow(CronWorkflow()) worker.start() diff --git a/examples/dag/event.py b/examples/dag/event.py index ba6f881e..e86313aa 100644 --- a/examples/dag/event.py +++ b/examples/dag/event.py @@ -1,8 +1,4 @@ -from dotenv import load_dotenv - -from hatchet_sdk import Context, Hatchet - -load_dotenv() +from hatchet_sdk import Hatchet hatchet = Hatchet(debug=True) diff --git a/examples/dag/worker.py b/examples/dag/worker.py index 45ecece8..062df660 100644 --- a/examples/dag/worker.py +++ b/examples/dag/worker.py @@ -2,17 +2,15 @@ import time from typing import Any, cast -from dotenv import load_dotenv - -from hatchet_sdk import Context, Hatchet - -load_dotenv() +from hatchet_sdk import BaseWorkflow, Context, Hatchet hatchet = Hatchet(debug=True) +wf = hatchet.declare_workflow(on_events=["dag:create"], schedule_timeout="10m") + -@hatchet.workflow(on_events=["dag:create"], schedule_timeout="10m") -class DagWorkflow: +class DagWorkflow(BaseWorkflow): + config = wf.config @hatchet.step(timeout="5s") def step1(self, context: Context) -> dict[str, int]: @@ -46,7 +44,7 @@ def step4(self, context: Context) -> dict[str, str]: print( "executed step4", time.strftime("%H:%M:%S", time.localtime()), - context.workflow_input(), + context.workflow_input, context.step_output("step1"), context.step_output("step3"), ) @@ -56,9 +54,8 @@ def step4(self, context: Context) -> dict[str, str]: def main() -> None: - workflow = DagWorkflow() worker = hatchet.worker("dag-worker") - worker.register_workflow(workflow) + worker.register_workflow(DagWorkflow()) worker.start() diff --git a/examples/dedupe/worker.py b/examples/dedupe/worker.py index 2f22f52d..9f6f3381 100644 --- a/examples/dedupe/worker.py +++ b/examples/dedupe/worker.py @@ -1,20 +1,17 @@ import asyncio -import random from typing import Any -from dotenv import load_dotenv - -from hatchet_sdk import Context, Hatchet +from hatchet_sdk import BaseWorkflow, ChildTriggerWorkflowOptions, Context, Hatchet from hatchet_sdk.clients.admin import DedupeViolationErr -from hatchet_sdk.loader import ClientConfig - -load_dotenv() hatchet = Hatchet(debug=True) +dedupe_parent_wf = hatchet.declare_workflow(on_events=["parent:create"]) + + +class DedupeParent(BaseWorkflow): + config = dedupe_parent_wf.config -@hatchet.workflow(on_events=["parent:create"]) -class DedupeParent: @hatchet.step(timeout="1m") async def spawn(self, context: Context) -> dict[str, list[Any]]: print("spawning child") @@ -25,11 +22,13 @@ async def spawn(self, context: Context) -> dict[str, list[Any]]: try: results.append( ( - await context.aio.spawn_workflow( + await context.spawn_workflow( "DedupeChild", {"a": str(i)}, key=f"child{i}", - options={"additional_metadata": {"dedupe": "test"}}, + options=ChildTriggerWorkflowOptions( + additional_metadata={"dedupe": "test"} + ), ) ).result() ) @@ -43,13 +42,17 @@ async def spawn(self, context: Context) -> dict[str, list[Any]]: return {"results": result} -@hatchet.workflow(on_events=["child:create"]) -class DedupeChild: +dedupe_child_wf = hatchet.declare_workflow(on_events=["child:create"]) + + +class DedupeChild(BaseWorkflow): + config = dedupe_child_wf.config + @hatchet.step() async def process(self, context: Context) -> dict[str, str]: await asyncio.sleep(3) - print(f"child process") + print("child process") return {"status": "success"} @hatchet.step() diff --git a/examples/default_priority/worker.py b/examples/default_priority/worker.py deleted file mode 100644 index 070d20f9..00000000 --- a/examples/default_priority/worker.py +++ /dev/null @@ -1,40 +0,0 @@ -import asyncio -from typing import TypedDict - -from dotenv import load_dotenv - -from hatchet_sdk import Context -from hatchet_sdk.v2.hatchet import Hatchet - -load_dotenv() - -hatchet = Hatchet(debug=True) - - -class MyResultType(TypedDict): - return_string: str - - -@hatchet.function(default_priority=2) -async def high_prio_func(context: Context) -> MyResultType: - await asyncio.sleep(5) - return MyResultType(return_string="High Priority Return") - - -@hatchet.function(default_priority=1) -async def low_prio_func(context: Context) -> MyResultType: - await asyncio.sleep(5) - return MyResultType(return_string="Low Priority Return") - - -def main() -> None: - worker = hatchet.worker("example-priority-worker", max_runs=1) - hatchet.admin.run(high_prio_func, {"test": "test"}) - hatchet.admin.run(high_prio_func, {"test": "test"}) - hatchet.admin.run(low_prio_func, {"test": "test"}) - hatchet.admin.run(low_prio_func, {"test": "test"}) - worker.start() - - -if __name__ == "__main__": - main() diff --git a/examples/delayed/event.py b/examples/delayed/event.py index cdc16ac6..dd3a3262 100644 --- a/examples/delayed/event.py +++ b/examples/delayed/event.py @@ -1,9 +1,5 @@ -from dotenv import load_dotenv - from hatchet_sdk import new_client -load_dotenv() - client = new_client() client.event.push("printer:schedule", {"message": "test"}) diff --git a/examples/delayed/worker.py b/examples/delayed/worker.py index cea0b838..346afa1c 100644 --- a/examples/delayed/worker.py +++ b/examples/delayed/worker.py @@ -1,16 +1,16 @@ from datetime import datetime, timedelta +from typing import Any, cast -from dotenv import load_dotenv +from hatchet_sdk import BaseWorkflow, Context, Hatchet -from hatchet_sdk import Context, Hatchet +hatchet = Hatchet(debug=True) -load_dotenv() +print_schedule_wf = hatchet.declare_workflow(on_events=["printer:schedule"]) -hatchet = Hatchet(debug=True) +class PrintSchedule(BaseWorkflow): + config = print_schedule_wf.config -@hatchet.workflow(on_events=["printer:schedule"]) -class PrintSchedule: @hatchet.step() def schedule(self, context: Context) -> None: now = datetime.now() @@ -19,17 +19,21 @@ def schedule(self, context: Context) -> None: print(f"scheduling for \t {future_time.strftime('%H:%M:%S')}") hatchet.admin.schedule_workflow( - "PrintPrinter", [future_time], context.workflow_input() + "PrintPrinter", [future_time], cast(dict[str, Any], context.workflow_input) ) -@hatchet.workflow() -class PrintPrinter: +print_printer_wf = hatchet.declare_workflow() + + +class PrintPrinter(BaseWorkflow): + config = print_printer_wf.config + @hatchet.step() def step1(self, context: Context) -> None: now = datetime.now() print(f"printed at \t {now.strftime('%H:%M:%S')}") - print(f"message \t {context.workflow_input()['message']}") + print(f"message \t {cast(dict[str, Any], context.workflow_input)['message']}") def main() -> None: diff --git a/examples/durable_sticky_with_affinity/worker.py b/examples/durable_sticky_with_affinity/worker.py deleted file mode 100644 index 93505d98..00000000 --- a/examples/durable_sticky_with_affinity/worker.py +++ /dev/null @@ -1,72 +0,0 @@ -import asyncio -from typing import Any - -from dotenv import load_dotenv - -from hatchet_sdk import Context, StickyStrategy, WorkerLabelComparator -from hatchet_sdk.v2.callable import DurableContext -from hatchet_sdk.v2.hatchet import Hatchet - -load_dotenv() - -hatchet = Hatchet(debug=True) - - -@hatchet.durable( - sticky=StickyStrategy.HARD, - desired_worker_labels={ - "running_workflow": { - "value": "True", - "required": True, - "comparator": WorkerLabelComparator.NOT_EQUAL, - }, - }, -) -async def my_durable_func(context: DurableContext) -> dict[str, Any]: - try: - ref = await context.aio.spawn_workflow( - "StickyChildWorkflow", {}, options={"sticky": True} - ) - result = await ref.result() - except Exception as e: - result = str(e) - - await context.worker.async_upsert_labels({"running_workflow": "False"}) - return {"worker_result": result} - - -@hatchet.workflow(on_events=["sticky:child"], sticky=StickyStrategy.HARD) -class StickyChildWorkflow: - @hatchet.step( - desired_worker_labels={ - "running_workflow": { - "value": "True", - "required": True, - "comparator": WorkerLabelComparator.NOT_EQUAL, - }, - }, - ) - async def child(self, context: Context) -> dict[str, str | None]: - await context.worker.async_upsert_labels({"running_workflow": "True"}) - - print(f"Heavy work started on {context.worker.id()}") - await asyncio.sleep(15) - print(f"Finished Heavy work on {context.worker.id()}") - - return {"worker": context.worker.id()} - - -def main() -> None: - worker = hatchet.worker( - "sticky-worker", - max_runs=10, - labels={"running_workflow": "False"}, - ) - - worker.register_workflow(StickyChildWorkflow()) - - worker.start() - - -if __name__ == "__main__": - main() diff --git a/examples/events/event.py b/examples/events/event.py index 21b83cdd..52b0c1a5 100644 --- a/examples/events/event.py +++ b/examples/events/event.py @@ -1,8 +1,4 @@ -from dotenv import load_dotenv - from hatchet_sdk import Hatchet -load_dotenv() - hatchet = Hatchet() hatchet.event.push("user:create", {"test": "test"}) diff --git a/examples/events/test_event.py b/examples/events/test_event.py index a4fca8ae..9131bda8 100644 --- a/examples/events/test_event.py +++ b/examples/events/test_event.py @@ -1,5 +1,3 @@ -from typing import List - import pytest from hatchet_sdk.clients.events import BulkPushEventOptions, BulkPushEventWithMetadata @@ -24,34 +22,34 @@ async def test_async_event_push(aiohatchet: Hatchet) -> None: @pytest.mark.asyncio(scope="session") async def test_async_event_bulk_push(aiohatchet: Hatchet) -> None: - events: List[BulkPushEventWithMetadata] = [ - { - "key": "event1", - "payload": {"message": "This is event 1"}, - "additional_metadata": {"source": "test", "user_id": "user123"}, - }, - { - "key": "event2", - "payload": {"message": "This is event 2"}, - "additional_metadata": {"source": "test", "user_id": "user456"}, - }, - { - "key": "event3", - "payload": {"message": "This is event 3"}, - "additional_metadata": {"source": "test", "user_id": "user789"}, - }, + events = [ + BulkPushEventWithMetadata( + key="event1", + payload={"message": "This is event 1"}, + additional_metadata={"source": "test", "user_id": "user123"}, + ), + BulkPushEventWithMetadata( + key="event2", + payload={"message": "This is event 2"}, + additional_metadata={"source": "test", "user_id": "user456"}, + ), + BulkPushEventWithMetadata( + key="event3", + payload={"message": "This is event 3"}, + additional_metadata={"source": "test", "user_id": "user789"}, + ), ] - opts: BulkPushEventOptions = {"namespace": "bulk-test"} + opts = BulkPushEventOptions(namespace="bulk-test") e = await aiohatchet.event.async_bulk_push(events, opts) assert len(e) == 3 # Sort both lists of events by their key to ensure comparison order - sorted_events = sorted(events, key=lambda x: x["key"]) + sorted_events = sorted(events, key=lambda x: x.key) sorted_returned_events = sorted(e, key=lambda x: x.key) namespace = "bulk-test" # Check that the returned events match the original events for original_event, returned_event in zip(sorted_events, sorted_returned_events): - assert returned_event.key == namespace + original_event["key"] + assert returned_event.key == namespace + original_event.key diff --git a/examples/fanout/stream.py b/examples/fanout/stream.py index 08d0cb4a..c357d882 100644 --- a/examples/fanout/stream.py +++ b/examples/fanout/stream.py @@ -1,19 +1,12 @@ import asyncio -import base64 -import json -import os import random -from dotenv import load_dotenv - -from hatchet_sdk import new_client +from hatchet_sdk import Hatchet from hatchet_sdk.clients.admin import TriggerWorkflowOptions -from hatchet_sdk.clients.run_event_listener import StepRunEventType -from hatchet_sdk.v2.hatchet import Hatchet async def main() -> None: - load_dotenv() + hatchet = Hatchet() # Generate a random stream key to use to track all @@ -28,10 +21,10 @@ async def main() -> None: # This key gets propagated to all child workflows # and can have an arbitrary property name. - workflowRun = hatchet.admin.run_workflow( + hatchet.admin.run_workflow( "Parent", {"n": 2}, - options={"additional_metadata": {streamKey: streamVal}}, + options=TriggerWorkflowOptions(additional_metadata={streamKey: streamVal}), ) # Stream all events for the additional meta key value diff --git a/examples/fanout/sync_stream.py b/examples/fanout/sync_stream.py index 0b1a7140..3b628678 100644 --- a/examples/fanout/sync_stream.py +++ b/examples/fanout/sync_stream.py @@ -1,19 +1,11 @@ -import asyncio -import base64 -import json -import os import random -from dotenv import load_dotenv - -from hatchet_sdk import new_client +from hatchet_sdk import Hatchet from hatchet_sdk.clients.admin import TriggerWorkflowOptions -from hatchet_sdk.clients.run_event_listener import StepRunEventType -from hatchet_sdk.v2.hatchet import Hatchet def main() -> None: - load_dotenv() + hatchet = Hatchet() # Generate a random stream key to use to track all @@ -28,10 +20,10 @@ def main() -> None: # This key gets propagated to all child workflows # and can have an arbitrary property name. - workflowRun = hatchet.admin.run_workflow( + hatchet.admin.run_workflow( "Parent", {"n": 2}, - options={"additional_metadata": {streamKey: streamVal}}, + options=TriggerWorkflowOptions(additional_metadata={streamKey: streamVal}), ) # Stream all events for the additional meta key value diff --git a/examples/fanout/trigger.py b/examples/fanout/trigger.py index c34d01b3..ddeca7a6 100644 --- a/examples/fanout/trigger.py +++ b/examples/fanout/trigger.py @@ -1,23 +1,17 @@ import asyncio -import base64 -import json -import os - -from dotenv import load_dotenv from hatchet_sdk import new_client from hatchet_sdk.clients.admin import TriggerWorkflowOptions -from hatchet_sdk.clients.run_event_listener import StepRunEventType async def main() -> None: - load_dotenv() + hatchet = new_client() hatchet.admin.run_workflow( "Parent", {"test": "test"}, - options={"additional_metadata": {"hello": "moon"}}, + options=TriggerWorkflowOptions(additional_metadata={"hello": "moon"}), ) diff --git a/examples/fanout/worker.py b/examples/fanout/worker.py index c32344c9..cf8034dd 100644 --- a/examples/fanout/worker.py +++ b/examples/fanout/worker.py @@ -1,49 +1,66 @@ import asyncio from typing import Any -from dotenv import load_dotenv +from pydantic import BaseModel -from hatchet_sdk import Context, Hatchet - -load_dotenv() +from hatchet_sdk import BaseWorkflow, ChildTriggerWorkflowOptions, Context, Hatchet hatchet = Hatchet(debug=True) -@hatchet.workflow(on_events=["parent:create"]) -class Parent: +class ParentInput(BaseModel): + n: int = 100 + + +class ChildInput(BaseModel): + a: str + + +parent_wf = hatchet.declare_workflow( + on_events=["parent:create"], input_validator=ParentInput +) +child_wf = hatchet.declare_workflow( + on_events=["child:create"], input_validator=ChildInput +) + + +class Parent(BaseWorkflow): + config = parent_wf.config + @hatchet.step(timeout="5m") async def spawn(self, context: Context) -> dict[str, Any]: print("spawning child") context.put_stream("spawning...") - results = [] - - n = context.workflow_input().get("n", 100) - - for i in range(n): - results.append( - ( - await context.aio.spawn_workflow( - "Child", - {"a": str(i)}, - key=f"child{i}", - options={"additional_metadata": {"hello": "earth"}}, - ) - ).result() - ) - - result = await asyncio.gather(*results) + + n = parent_wf.get_workflow_input(context).n + + children = await asyncio.gather( + *[ + child_wf.spawn_one( + ctx=context, + input=ChildInput(a=str(i)), + key=f"child{i}", + options=ChildTriggerWorkflowOptions( + additional_metadata={"hello": "earth"} + ), + ) + for i in range(n) + ] + ) + + result = await asyncio.gather(*[child.result() for child in children]) print(f"results {result}") return {"results": result} -@hatchet.workflow(on_events=["child:create"]) -class Child: +class Child(BaseWorkflow): + config = child_wf.config + @hatchet.step() def process(self, context: Context) -> dict[str, str]: - a = context.workflow_input()["a"] + a = child_wf.get_workflow_input(context).a print(f"child process {a}") context.put_stream("child 1...") return {"status": "success " + a} diff --git a/examples/logger/client.py b/examples/logger/client.py index 28df738e..afda470e 100644 --- a/examples/logger/client.py +++ b/examples/logger/client.py @@ -1,14 +1,7 @@ -import json import logging -import sys -import time - -from dotenv import load_dotenv from hatchet_sdk import ClientConfig, Hatchet -load_dotenv() - logging.basicConfig(level=logging.INFO) hatchet = Hatchet( diff --git a/examples/logger/event.py b/examples/logger/event.py index 5f7818f6..4d6b5eab 100644 --- a/examples/logger/event.py +++ b/examples/logger/event.py @@ -1,11 +1,9 @@ -from dotenv import load_dotenv - from hatchet_sdk import PushEventOptions, new_client -load_dotenv() - client = new_client() client.event.push( - "user:create", {"test": "test"}, options={"additional_metadata": {"hello": "moon"}} + "user:create", + {"test": "test"}, + options=PushEventOptions(additional_metadata={"hello": "moon"}), ) diff --git a/examples/logger/worker.py b/examples/logger/worker.py index 8053f83d..8cbd0a53 100644 --- a/examples/logger/worker.py +++ b/examples/logger/worker.py @@ -1,7 +1,3 @@ -from logging import getLogger - -from dotenv import load_dotenv - from examples.logger.client import hatchet from examples.logger.workflow import LoggingWorkflow diff --git a/examples/logger/workflow.py b/examples/logger/workflow.py index a298a4c7..53c7b968 100644 --- a/examples/logger/workflow.py +++ b/examples/logger/workflow.py @@ -2,13 +2,12 @@ import time from examples.logger.client import hatchet -from hatchet_sdk import Context +from hatchet_sdk import BaseWorkflow, Context logger = logging.getLogger(__name__) -@hatchet.workflow() -class LoggingWorkflow: +class LoggingWorkflow(BaseWorkflow): @hatchet.step() def step1(self, context: Context) -> dict[str, str]: for i in range(12): diff --git a/examples/manual_trigger/stream.py b/examples/manual_trigger/stream.py index bc4adfab..11fd0726 100644 --- a/examples/manual_trigger/stream.py +++ b/examples/manual_trigger/stream.py @@ -3,21 +3,19 @@ import json import os -from dotenv import load_dotenv - from hatchet_sdk import new_client from hatchet_sdk.clients.admin import TriggerWorkflowOptions from hatchet_sdk.clients.run_event_listener import StepRunEventType async def main() -> None: - load_dotenv() + hatchet = new_client() workflowRun = hatchet.admin.run_workflow( "ManualTriggerWorkflow", {"test": "test"}, - options={"additional_metadata": {"hello": "moon"}}, + options=TriggerWorkflowOptions(additional_metadata={"hello": "moon"}), ) listener = workflowRun.stream() diff --git a/examples/manual_trigger/worker.py b/examples/manual_trigger/worker.py index 9cf1eef5..737d1bed 100644 --- a/examples/manual_trigger/worker.py +++ b/examples/manual_trigger/worker.py @@ -2,17 +2,16 @@ import os import time -from dotenv import load_dotenv +from hatchet_sdk import BaseWorkflow, Context, Hatchet -from hatchet_sdk import Context, Hatchet +hatchet = Hatchet(debug=True) -load_dotenv() +wf = hatchet.declare_workflow(on_events=["man:create"]) -hatchet = Hatchet(debug=True) +class ManualTriggerWorkflow(BaseWorkflow): + config = wf.config -@hatchet.workflow(on_events=["man:create"]) -class ManualTriggerWorkflow: @hatchet.step() def step1(self, context: Context) -> dict[str, str]: res = context.playground("res", "HELLO") @@ -48,9 +47,8 @@ def step2(self, context: Context) -> dict[str, str]: def main() -> None: - workflow = ManualTriggerWorkflow() worker = hatchet.worker("manual-worker", max_runs=4) - worker.register_workflow(workflow) + worker.register_workflow(ManualTriggerWorkflow()) worker.start() diff --git a/examples/on_failure/worker.py b/examples/on_failure/worker.py index ddf12373..ae277583 100644 --- a/examples/on_failure/worker.py +++ b/examples/on_failure/worker.py @@ -1,10 +1,6 @@ import json -from dotenv import load_dotenv - -from hatchet_sdk import Context, Hatchet - -load_dotenv() +from hatchet_sdk import BaseWorkflow, Context, Hatchet hatchet = Hatchet(debug=True) @@ -12,9 +8,12 @@ # This workflow will fail because the step will throw an error # we define an onFailure step to handle this case +on_failure_wf = hatchet.declare_workflow(on_events=["user:create"]) + + +class OnFailureWorkflow(BaseWorkflow): + config = on_failure_wf.config -@hatchet.workflow(on_events=["user:create"]) -class OnFailureWorkflow: @hatchet.step(timeout="1s") def step1(self, context: Context) -> None: # 👀 this step will always raise an exception @@ -27,7 +26,7 @@ def on_failure(self, context: Context) -> dict[str, str]: # or notify a user here # 👀 Fetch the errors from upstream step runs from the context - print(context.step_run_errors()) + print(context.step_run_errors) return {"status": "success"} @@ -38,8 +37,13 @@ def on_failure(self, context: Context) -> dict[str, str]: # ❓ OnFailure With Details # We can access the failure details in the onFailure step # via the context method -@hatchet.workflow(on_events=["user:create"]) -class OnFailureWorkflowWithDetails: + +on_failure_wf_with_details = hatchet.declare_workflow(on_events=["user:create"]) + + +class OnFailureWorkflowWithDetails(BaseWorkflow): + config = on_failure_wf_with_details.config + # ... defined as above @hatchet.step(timeout="1s") def step1(self, context: Context) -> None: @@ -62,11 +66,9 @@ def on_failure(self, context: Context) -> dict[str, str]: def main() -> None: - workflow = OnFailureWorkflow() - workflow2 = OnFailureWorkflowWithDetails() worker = hatchet.worker("on-failure-worker", max_runs=4) - worker.register_workflow(workflow) - worker.register_workflow(workflow2) + worker.register_workflow(OnFailureWorkflow()) + worker.register_workflow(OnFailureWorkflowWithDetails()) worker.start() diff --git a/examples/overrides/worker.py b/examples/overrides/worker.py index 1af30f1e..2185edad 100644 --- a/examples/overrides/worker.py +++ b/examples/overrides/worker.py @@ -1,16 +1,15 @@ import time -from dotenv import load_dotenv +from hatchet_sdk import BaseWorkflow, Context, Hatchet -from hatchet_sdk import Context, Hatchet +hatchet = Hatchet(debug=True) -load_dotenv() +wf = hatchet.declare_workflow(on_events=["overrides:create"], schedule_timeout="10m") -hatchet = Hatchet(debug=True) +class OverridesWorkflow(BaseWorkflow): + config = wf.config -@hatchet.workflow(on_events=["overrides:create"], schedule_timeout="10m") -class OverridesWorkflow: def __init__(self) -> None: self.my_value = "test" @@ -19,7 +18,7 @@ def step1(self, context: Context) -> dict[str, str | None]: print( "starting step1", time.strftime("%H:%M:%S", time.localtime()), - context.workflow_input(), + context.workflow_input, ) overrideValue = context.playground("prompt", "You are an AI assistant...") time.sleep(3) @@ -34,7 +33,7 @@ def step2(self, context: Context) -> dict[str, str]: print( "starting step2", time.strftime("%H:%M:%S", time.localtime()), - context.workflow_input(), + context.workflow_input, ) time.sleep(5) print("executed step2", time.strftime("%H:%M:%S", time.localtime())) @@ -47,7 +46,7 @@ def step3(self, context: Context) -> dict[str, str]: print( "executed step3", time.strftime("%H:%M:%S", time.localtime()), - context.workflow_input(), + context.workflow_input, context.step_output("step1"), context.step_output("step2"), ) @@ -60,7 +59,7 @@ def step4(self, context: Context) -> dict[str, str]: print( "executed step4", time.strftime("%H:%M:%S", time.localtime()), - context.workflow_input(), + context.workflow_input, context.step_output("step1"), context.step_output("step3"), ) diff --git a/examples/programatic_replay/script.py b/examples/programatic_replay/script.py index 28b608ab..4322ca36 100644 --- a/examples/programatic_replay/script.py +++ b/examples/programatic_replay/script.py @@ -1,10 +1,6 @@ -from dotenv import load_dotenv +from hatchet_sdk.hatchet import Hatchet -from hatchet_sdk.hatchet import HatchetRest - -load_dotenv() - -hatchet = HatchetRest() +hatchet = Hatchet() def main() -> None: diff --git a/examples/pydantic/trigger.py b/examples/pydantic/trigger.py index 18d4c60d..5946c7c7 100644 --- a/examples/pydantic/trigger.py +++ b/examples/pydantic/trigger.py @@ -1,12 +1,10 @@ import asyncio -from dotenv import load_dotenv - from hatchet_sdk import new_client async def main() -> None: - load_dotenv() + hatchet = new_client() hatchet.admin.run_workflow( diff --git a/examples/pydantic/worker.py b/examples/pydantic/worker.py index dbbda9ff..5dea3444 100644 --- a/examples/pydantic/worker.py +++ b/examples/pydantic/worker.py @@ -1,11 +1,8 @@ from typing import cast -from dotenv import load_dotenv from pydantic import BaseModel -from hatchet_sdk import Context, Hatchet - -load_dotenv() +from hatchet_sdk import BaseWorkflow, Context, Hatchet hatchet = Hatchet(debug=True) @@ -16,44 +13,47 @@ class ParentInput(BaseModel): x: str -@hatchet.workflow(input_validator=ParentInput) -class Parent: +class ChildInput(BaseModel): + a: int + b: int + + +parent_workflow = hatchet.declare_workflow(input_validator=ParentInput) +child_workflow = hatchet.declare_workflow(input_validator=ChildInput) + + +class Parent(BaseWorkflow): + config = parent_workflow.config + @hatchet.step(timeout="5m") async def spawn(self, context: Context) -> dict[str, str]: ## Use `typing.cast` to cast your `workflow_input` ## to the type of your `input_validator` - input = cast(ParentInput, context.workflow_input()) ## This is a `ParentInput` + parent_workflow.get_workflow_input(context) ## This is a `ParentInput` - child = await context.aio.spawn_workflow( - "Child", - {"a": 1, "b": "10"}, - ) + child = await child_workflow.spawn_one(ctx=context, input=ChildInput(a=1, b=10)) return cast(dict[str, str], await child.result()) -class ChildInput(BaseModel): - a: int - b: int - - class StepResponse(BaseModel): status: str -@hatchet.workflow(input_validator=ChildInput) -class Child: +class Child(BaseWorkflow): + config = child_workflow.config + @hatchet.step() def process(self, context: Context) -> StepResponse: ## This is an instance `ChildInput` - input = cast(ChildInput, context.workflow_input()) + child_workflow.get_workflow_input(context) return StepResponse(status="success") @hatchet.step(parents=["process"]) def process2(self, context: Context) -> StepResponse: ## This is an instance of `StepResponse` - process_output = cast(StepResponse, context.step_output("process")) + cast(StepResponse, context.step_output("process")) return {"status": "step 2 - success"} # type: ignore[return-value] @@ -63,7 +63,7 @@ def process3(self, context: Context) -> StepResponse: ## response of `process2` was a dictionary. Note that ## Hatchet will attempt to parse that dictionary into ## an object of type `StepResponse` - process_2_output = cast(StepResponse, context.step_output("process2")) + cast(StepResponse, context.step_output("process2")) return StepResponse(status="step 3 - success") diff --git a/examples/rate_limit/dynamic.py b/examples/rate_limit/dynamic.py index b07a956c..41e7c9ce 100644 --- a/examples/rate_limit/dynamic.py +++ b/examples/rate_limit/dynamic.py @@ -1,20 +1,18 @@ -from dotenv import load_dotenv - -from hatchet_sdk import Context, Hatchet -from hatchet_sdk.rate_limit import RateLimit, RateLimitDuration - -load_dotenv() +from hatchet_sdk import BaseWorkflow, Context, Hatchet +from hatchet_sdk.rate_limit import RateLimit hatchet = Hatchet(debug=True) +wf = hatchet.declare_workflow(on_events=["rate_limit:create"]) + -@hatchet.workflow(on_events=["rate_limit:create"]) -class RateLimitWorkflow: +class RateLimitWorkflow(BaseWorkflow): + config = wf.config @hatchet.step( rate_limits=[ RateLimit( - dynamic_key=f'"LIMIT:"+input.group', + dynamic_key='"LIMIT:"+input.group', units="input.units", limit="input.limit", ) diff --git a/examples/rate_limit/event.py b/examples/rate_limit/event.py index ed077fea..7d39202d 100644 --- a/examples/rate_limit/event.py +++ b/examples/rate_limit/event.py @@ -1,9 +1,5 @@ -from dotenv import load_dotenv - from hatchet_sdk.hatchet import Hatchet -load_dotenv() - hatchet = Hatchet(debug=True) hatchet.event.push("rate_limit:create", {"test": "1"}) diff --git a/examples/rate_limit/worker.py b/examples/rate_limit/worker.py index 97c19321..999fe477 100644 --- a/examples/rate_limit/worker.py +++ b/examples/rate_limit/worker.py @@ -1,15 +1,12 @@ -from dotenv import load_dotenv - -from hatchet_sdk import Context, Hatchet +from hatchet_sdk import BaseWorkflow, Context, Hatchet from hatchet_sdk.rate_limit import RateLimit, RateLimitDuration -load_dotenv() - hatchet = Hatchet(debug=True) +wf = hatchet.declare_workflow(on_events=["rate_limit:create"]) -@hatchet.workflow(on_events=["rate_limit:create"]) -class RateLimitWorkflow: +class RateLimitWorkflow(BaseWorkflow): + config = wf.config @hatchet.step(rate_limits=[RateLimit(key="test-limit", units=1)]) def step1(self, context: Context) -> None: diff --git a/examples/retries_with_backoff/worker.py b/examples/retries_with_backoff/worker.py index d4d33410..065b947f 100644 --- a/examples/retries_with_backoff/worker.py +++ b/examples/retries_with_backoff/worker.py @@ -1,11 +1,10 @@ -from hatchet_sdk import Context, Hatchet +from hatchet_sdk import BaseWorkflow, Context, Hatchet hatchet = Hatchet(debug=True) # ❓ Backoff -@hatchet.workflow() -class BackoffWorkflow: +class BackoffWorkflow(BaseWorkflow): # 👀 Backoff configuration @hatchet.step( retries=10, @@ -16,7 +15,7 @@ class BackoffWorkflow: backoff_factor=2.0, ) def step1(self, context: Context) -> dict[str, str]: - if context.retry_count() < 3: + if context.retry_count < 3: raise Exception("step1 failed") return {"status": "success"} @@ -26,9 +25,8 @@ def step1(self, context: Context) -> dict[str, str]: def main() -> None: - workflow = BackoffWorkflow() worker = hatchet.worker("backoff-worker", max_runs=4) - worker.register_workflow(workflow) + worker.register_workflow(BackoffWorkflow()) worker.start() diff --git a/examples/scheduled/programatic-async.py b/examples/scheduled/programatic-async.py index 06a469f0..674d270f 100644 --- a/examples/scheduled/programatic-async.py +++ b/examples/scheduled/programatic-async.py @@ -1,11 +1,6 @@ from datetime import datetime, timedelta -from dotenv import load_dotenv - from hatchet_sdk import Hatchet -from hatchet_sdk.clients.rest.models.scheduled_workflows import ScheduledWorkflows - -load_dotenv() hatchet = Hatchet() @@ -23,7 +18,7 @@ async def create_scheduled() -> None: }, ) - id = scheduled_run.metadata.id # the id of the scheduled run trigger + scheduled_run.metadata.id # the id of the scheduled run trigger # !! # ❓ Delete @@ -31,7 +26,7 @@ async def create_scheduled() -> None: # !! # ❓ List - scheduled_runs = await hatchet.scheduled.aio.list() + await hatchet.scheduled.aio.list() # !! # ❓ Get diff --git a/examples/scheduled/programatic-sync.py b/examples/scheduled/programatic-sync.py index 13a52ee8..313becd4 100644 --- a/examples/scheduled/programatic-sync.py +++ b/examples/scheduled/programatic-sync.py @@ -1,11 +1,7 @@ from datetime import datetime, timedelta -from dotenv import load_dotenv - from hatchet_sdk import Hatchet -load_dotenv() - hatchet = Hatchet() # ❓ Create diff --git a/examples/simple/event.py b/examples/simple/event.py index c2d0178a..0d20f668 100644 --- a/examples/simple/event.py +++ b/examples/simple/event.py @@ -1,41 +1,40 @@ -from typing import List - -from dotenv import load_dotenv - from hatchet_sdk import new_client -from hatchet_sdk.clients.events import BulkPushEventWithMetadata - -load_dotenv() +from hatchet_sdk.clients.events import ( + BulkPushEventOptions, + BulkPushEventWithMetadata, + PushEventOptions, +) client = new_client() # client.event.push("user:create", {"test": "test"}) client.event.push( - "user:create", {"test": "test"}, options={"additional_metadata": {"hello": "moon"}} + "user:create", + {"test": "test"}, + options=PushEventOptions(additional_metadata={"hello": "moon"}), ) -events: List[BulkPushEventWithMetadata] = [ - { - "key": "event1", - "payload": {"message": "This is event 1"}, - "additional_metadata": {"source": "test", "user_id": "user123"}, - }, - { - "key": "event2", - "payload": {"message": "This is event 2"}, - "additional_metadata": {"source": "test", "user_id": "user456"}, - }, - { - "key": "event3", - "payload": {"message": "This is event 3"}, - "additional_metadata": {"source": "test", "user_id": "user789"}, - }, +events = [ + BulkPushEventWithMetadata( + key="event1", + payload={"message": "This is event 1"}, + additional_metadata={"source": "test", "user_id": "user123"}, + ), + BulkPushEventWithMetadata( + key="event2", + payload={"message": "This is event 2"}, + additional_metadata={"source": "test", "user_id": "user456"}, + ), + BulkPushEventWithMetadata( + key="event3", + payload={"message": "This is event 3"}, + additional_metadata={"source": "test", "user_id": "user789"}, + ), ] result = client.event.bulk_push( - events, - options={"namespace": "bulk-test"}, + events, options=BulkPushEventOptions(namespace="bulk-test") ) print(result) diff --git a/examples/simple/worker.py b/examples/simple/worker.py index 0cbfb8c2..c6aab0ad 100644 --- a/examples/simple/worker.py +++ b/examples/simple/worker.py @@ -1,30 +1,22 @@ -import time - -from dotenv import load_dotenv - -from hatchet_sdk import Context, Hatchet - -load_dotenv() +from hatchet_sdk import BaseWorkflow, Context, Hatchet hatchet = Hatchet(debug=True) -@hatchet.workflow(on_events=["user:create"]) -class MyWorkflow: +class MyWorkflow(BaseWorkflow): @hatchet.step(timeout="11s", retries=3) def step1(self, context: Context) -> dict[str, str]: print("executed step1") - time.sleep(10) - # raise Exception("test") return { "step1": "step1", } def main() -> None: - workflow = MyWorkflow() + wf = MyWorkflow() + worker = hatchet.worker("test-worker", max_runs=1) - worker.register_workflow(workflow) + worker.register_workflow(wf) worker.start() diff --git a/examples/sticky_workers/event.py b/examples/sticky_workers/event.py index 55ed9b8f..ed0dc883 100644 --- a/examples/sticky_workers/event.py +++ b/examples/sticky_workers/event.py @@ -1,14 +1,11 @@ -from dotenv import load_dotenv - +from hatchet_sdk.clients.events import PushEventOptions from hatchet_sdk.hatchet import Hatchet -load_dotenv() - hatchet = Hatchet(debug=True) # client.event.push("user:create", {"test": "test"}) hatchet.event.push( "sticky:parent", {"test": "test"}, - options={"additional_metadata": {"hello": "moon"}}, + options=PushEventOptions(additional_metadata={"hello": "moon"}), ) diff --git a/examples/sticky_workers/worker.py b/examples/sticky_workers/worker.py index 266a5706..5fb4aab6 100644 --- a/examples/sticky_workers/worker.py +++ b/examples/sticky_workers/worker.py @@ -1,15 +1,21 @@ -from dotenv import load_dotenv +from hatchet_sdk import ( + BaseWorkflow, + ChildTriggerWorkflowOptions, + Context, + Hatchet, + StickyStrategy, +) -from hatchet_sdk import Context, Hatchet, StickyStrategy -from hatchet_sdk.context.context import ContextAioImpl +hatchet = Hatchet(debug=True) -load_dotenv() +sticky_workflow = hatchet.declare_workflow( + on_events=["sticky:parent"], sticky=StickyStrategy.SOFT +) -hatchet = Hatchet(debug=True) +class StickyWorkflow(BaseWorkflow): + config = sticky_workflow.config -@hatchet.workflow(on_events=["sticky:parent"], sticky=StickyStrategy.SOFT) -class StickyWorkflow: @hatchet.step() def step1a(self, context: Context) -> dict[str, str | None]: return {"worker": context.worker.id()} @@ -19,9 +25,9 @@ def step1b(self, context: Context) -> dict[str, str | None]: return {"worker": context.worker.id()} @hatchet.step(parents=["step1a", "step1b"]) - async def step2(self, context: ContextAioImpl) -> dict[str, str | None]: + async def step2(self, context: Context) -> dict[str, str | None]: ref = await context.spawn_workflow( - "StickyChildWorkflow", {}, options={"sticky": True} + "StickyChildWorkflow", {}, options=ChildTriggerWorkflowOptions(sticky=True) ) await ref.result() @@ -29,8 +35,14 @@ async def step2(self, context: ContextAioImpl) -> dict[str, str | None]: return {"worker": context.worker.id()} -@hatchet.workflow(on_events=["sticky:child"], sticky=StickyStrategy.SOFT) -class StickyChildWorkflow: +sticky_child_workflow = hatchet.declare_workflow( + on_events=["sticky:child"], sticky=StickyStrategy.SOFT +) + + +class StickyChildWorkflow(BaseWorkflow): + config = sticky_child_workflow.config + @hatchet.step() def child(self, context: Context) -> dict[str, str | None]: return {"worker": context.worker.id()} diff --git a/examples/sync_to_async/worker.py b/examples/sync_to_async/worker.py deleted file mode 100644 index 5ac3a912..00000000 --- a/examples/sync_to_async/worker.py +++ /dev/null @@ -1,98 +0,0 @@ -import asyncio -import os -import time -from typing import Any - -from dotenv import load_dotenv - -from hatchet_sdk import Context, sync_to_async -from hatchet_sdk.v2.hatchet import Hatchet - -os.environ["PYTHONASYNCIODEBUG"] = "1" -load_dotenv() - -hatchet = Hatchet(debug=True) - - -@hatchet.function() -async def fanout_sync_async(context: Context) -> dict[str, Any]: - print("spawning child") - - context.put_stream("spawning...") - results = [] - - n = context.workflow_input().get("n", 10) - - start_time = time.time() - for i in range(n): - results.append( - ( - await context.aio.spawn_workflow( - "Child", - {"a": str(i)}, - key=f"child{i}", - options={"additional_metadata": {"hello": "earth"}}, - ) - ).result() - ) - - result = await asyncio.gather(*results) - - execution_time = time.time() - start_time - print(f"Completed in {execution_time:.2f} seconds") - - return {"results": result} - - -@hatchet.workflow(on_events=["child:create"]) -class Child: - ###### Example Functions ###### - def sync_blocking_function(self) -> dict[str, str]: - time.sleep(5) - return {"type": "sync_blocking"} - - @sync_to_async # this makes the function async safe! - def decorated_sync_blocking_function(self) -> dict[str, str]: - time.sleep(5) - return {"type": "decorated_sync_blocking"} - - @sync_to_async # this makes the async function loop safe! - async def async_blocking_function(self) -> dict[str, str]: - time.sleep(5) - return {"type": "async_blocking"} - - ###### Hatchet Steps ###### - @hatchet.step() - async def handle_blocking_sync_in_async(self, context: Context) -> dict[str, str]: - wrapped_blocking_function = sync_to_async(self.sync_blocking_function) - - # This will now be async safe! - data = await wrapped_blocking_function() - return {"blocking_status": "success", "data": data} - - @hatchet.step() - async def handle_decorated_blocking_sync_in_async( - self, context: Context - ) -> dict[str, str]: - data = await self.decorated_sync_blocking_function() - return {"blocking_status": "success", "data": data} - - @hatchet.step() - async def handle_blocking_async_in_async(self, context: Context) -> dict[str, str]: - data = await self.async_blocking_function() - return {"blocking_status": "success", "data": data} - - @hatchet.step() - async def non_blocking_async(self, context: Context) -> dict[str, str]: - await asyncio.sleep(5) - return {"nonblocking_status": "success"} - - -def main() -> None: - worker = hatchet.worker("fanout-worker", max_runs=50) - worker.register_workflow(Child()) - worker.start() - - -if __name__ == "__main__": - main() diff --git a/examples/timeout/event.py b/examples/timeout/event.py index 53fe9473..cb42df8f 100644 --- a/examples/timeout/event.py +++ b/examples/timeout/event.py @@ -1,9 +1,5 @@ -from dotenv import load_dotenv - from hatchet_sdk import new_client -load_dotenv() - client = new_client() client.event.push("user:create", {"test": "test"}) diff --git a/examples/timeout/worker.py b/examples/timeout/worker.py index 10aaf136..ee986a0b 100644 --- a/examples/timeout/worker.py +++ b/examples/timeout/worker.py @@ -1,16 +1,14 @@ import time -from dotenv import load_dotenv - -from hatchet_sdk import Context, Hatchet - -load_dotenv() +from hatchet_sdk import BaseWorkflow, Context, Hatchet hatchet = Hatchet(debug=True) +timeout_wf = hatchet.declare_workflow(on_events=["timeout:create"]) + -@hatchet.workflow(on_events=["timeout:create"]) -class TimeoutWorkflow: +class TimeoutWorkflow(BaseWorkflow): + config = timeout_wf.config @hatchet.step(timeout="4s") def step1(self, context: Context) -> dict[str, str]: @@ -18,8 +16,11 @@ def step1(self, context: Context) -> dict[str, str]: return {"status": "success"} -@hatchet.workflow(on_events=["refresh:create"]) -class RefreshTimeoutWorkflow: +refresh_timeout_wf = hatchet.declare_workflow(on_events=["refresh:create"]) + + +class RefreshTimeoutWorkflow(BaseWorkflow): + config = refresh_timeout_wf.config @hatchet.step(timeout="4s") def step1(self, context: Context) -> dict[str, str]: diff --git a/examples/v2/simple/test_v2_worker.py b/examples/v2/simple/test_v2_worker.py deleted file mode 100644 index c06dae1f..00000000 --- a/examples/v2/simple/test_v2_worker.py +++ /dev/null @@ -1,26 +0,0 @@ -import pytest - -from examples.v2.simple.worker import MyResultType, my_durable_func, my_func -from hatchet_sdk import Hatchet, Worker -from hatchet_sdk.workflow_run import RunRef - - -# requires scope module or higher for shared event loop -@pytest.mark.asyncio(scope="session") -@pytest.mark.parametrize("worker", ["v2_simple"], indirect=True) -async def test_durable(hatchet: Hatchet, worker: Worker) -> None: - durable_run: RunRef[dict[str, str]] = hatchet.admin.run( - my_durable_func, {"test": "test"} - ) - result = await durable_run.result() - - assert result == {"my_durable_func": "testing123"} - - -@pytest.mark.asyncio(scope="session") -@pytest.mark.parametrize("worker", ["v2_simple"], indirect=True) -async def test_func(hatchet: Hatchet, worker: Worker) -> None: - durable_run: RunRef[MyResultType] = hatchet.admin.run(my_func, {"test": "test"}) - result = await durable_run.result() - - assert result == {"my_func": "testing123"} diff --git a/examples/v2/simple/worker.py b/examples/v2/simple/worker.py deleted file mode 100644 index 215bdd0e..00000000 --- a/examples/v2/simple/worker.py +++ /dev/null @@ -1,44 +0,0 @@ -import json -import time -from typing import Any, TypedDict, cast - -from dotenv import load_dotenv - -from hatchet_sdk import Context -from hatchet_sdk.v2.callable import DurableContext -from hatchet_sdk.v2.hatchet import Hatchet -from hatchet_sdk.workflow_run import RunRef - -load_dotenv() - -hatchet = Hatchet(debug=True) - - -class MyResultType(TypedDict): - my_func: str - - -@hatchet.function() -def my_func(context: Context) -> MyResultType: - return MyResultType(my_func="testing123") - - -@hatchet.durable() -async def my_durable_func(context: DurableContext) -> dict[str, MyResultType | None]: - result = cast(dict[str, Any], await context.run(my_func, {"test": "test"}).result()) - - context.log(result) - - return {"my_durable_func": result.get("my_func")} - - -def main() -> None: - worker = hatchet.worker("test-worker", max_runs=5) - - hatchet.admin.run(my_durable_func, {"test": "test"}) - - worker.start() - - -if __name__ == "__main__": - main() diff --git a/examples/v2/trigger.py b/examples/v2/trigger.py new file mode 100644 index 00000000..a7d257ee --- /dev/null +++ b/examples/v2/trigger.py @@ -0,0 +1,11 @@ +from examples.v2.workflows import ExampleWorkflowInput, example_workflow + + +def main() -> None: + example_workflow.run( + input=ExampleWorkflowInput(message="Hello, world!"), + ) + + +if __name__ == "__main__": + main() diff --git a/examples/v2/worker.py b/examples/v2/worker.py new file mode 100644 index 00000000..a03781db --- /dev/null +++ b/examples/v2/worker.py @@ -0,0 +1,24 @@ +from examples.v2.workflows import example_workflow, hatchet +from hatchet_sdk import BaseWorkflow, Context + + +class ExampleV2Workflow(BaseWorkflow): + config = example_workflow.config + + @hatchet.step(timeout="11s", retries=3) + def step1(self, context: Context) -> None: + input = example_workflow.get_workflow_input(context) + + print(input.message) + + return None + + +def main() -> None: + worker = hatchet.worker("test-worker", max_runs=1) + worker.register_workflow(ExampleV2Workflow()) + worker.start() + + +if __name__ == "__main__": + main() diff --git a/examples/v2/workflows.py b/examples/v2/workflows.py new file mode 100644 index 00000000..62547af1 --- /dev/null +++ b/examples/v2/workflows.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel + +from hatchet_sdk import Hatchet + +hatchet = Hatchet(debug=True) + + +class ExampleWorkflowInput(BaseModel): + message: str + + +example_workflow = hatchet.declare_workflow( + name="example-workflow", + on_events=["example-event"], + timeout="10m", + input_validator=ExampleWorkflowInput, +) diff --git a/examples/worker_existing_loop/worker.py b/examples/worker_existing_loop/worker.py index 0daa7c83..bea5f74e 100644 --- a/examples/worker_existing_loop/worker.py +++ b/examples/worker_existing_loop/worker.py @@ -1,17 +1,13 @@ import asyncio from contextlib import suppress -from dotenv import load_dotenv - from hatchet_sdk import Context, Hatchet - -load_dotenv() +from hatchet_sdk.workflow import BaseWorkflow hatchet = Hatchet(debug=True) -@hatchet.workflow(name="MyWorkflow") -class MyWorkflow: +class MyWorkflow(BaseWorkflow): @hatchet.step() async def step(self, context: Context) -> dict[str, str]: print("started") diff --git a/hatchet_sdk/__init__.py b/hatchet_sdk/__init__.py index 3162c25c..d798c764 100644 --- a/hatchet_sdk/__init__.py +++ b/hatchet_sdk/__init__.py @@ -118,12 +118,11 @@ ) from hatchet_sdk.clients.rest.models.workflow_version_meta import WorkflowVersionMeta from hatchet_sdk.contracts.workflows_pb2 import ( - ConcurrencyLimitStrategy, CreateWorkflowVersionOpts, RateLimitDuration, - StickyStrategy, WorkerLabelComparator, ) +from hatchet_sdk.hatchet import Hatchet from hatchet_sdk.utils.aio_utils import sync_to_async from .client import new_client @@ -137,9 +136,15 @@ from .clients.run_event_listener import StepRunEventType, WorkflowRunEventType from .context.context import Context from .context.worker_context import WorkerContext -from .hatchet import ClientConfig, Hatchet, concurrency, on_failure_step, step, workflow -from .worker import Worker, WorkerStartOptions, WorkerStatus -from .workflow import ConcurrencyExpression +from .loader import ClientConfig +from .worker.worker import Worker, WorkerStartOptions, WorkerStatus +from .workflow import ( + BaseWorkflow, + ConcurrencyExpression, + ConcurrencyLimitStrategy, + StickyStrategy, + WorkflowConfig, +) __all__ = [ "AcceptInviteRequest", @@ -244,4 +249,6 @@ "WorkerStartOptions", "WorkerStatus", "ConcurrencyExpression", + "BaseWorkflow", + "WorkflowConfig", ] diff --git a/hatchet_sdk/client.py b/hatchet_sdk/client.py index 45dfd394..2d0771bf 100644 --- a/hatchet_sdk/client.py +++ b/hatchet_sdk/client.py @@ -1,5 +1,4 @@ import asyncio -from logging import Logger from typing import Callable import grpc @@ -12,43 +11,34 @@ from .clients.dispatcher.dispatcher import DispatcherClient, new_dispatcher from .clients.events import EventClient, new_event from .clients.rest_client import RestApi -from .loader import ClientConfig, ConfigLoader +from .loader import ClientConfig class Client: - admin: AdminClient - dispatcher: DispatcherClient - event: EventClient - rest: RestApi - workflow_listener: PooledWorkflowRunListener - logInterceptor: Logger - debug: bool = False - @classmethod def from_environment( cls, defaults: ClientConfig = ClientConfig(), debug: bool = False, *opts_functions: Callable[[ClientConfig], None], - ): + ) -> "Client": try: loop = asyncio.get_running_loop() except RuntimeError: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - config: ClientConfig = ConfigLoader(".").load_client_config(defaults) for opt_function in opts_functions: - opt_function(config) + opt_function(defaults) - return cls.from_config(config, debug) + return cls.from_config(defaults, debug) @classmethod def from_config( cls, config: ClientConfig = ClientConfig(), debug: bool = False, - ): + ) -> "Client": try: loop = asyncio.get_running_loop() except RuntimeError: @@ -61,7 +51,7 @@ def from_config( if config.host_port is None: raise ValueError("Host and port are required") - conn: grpc.Channel = new_conn(config) + conn: grpc.Channel = new_conn(config, False) # Instantiate clients event_client = new_event(conn, config) @@ -85,7 +75,7 @@ def __init__( event_client: EventClient, admin_client: AdminClient, dispatcher_client: DispatcherClient, - workflow_listener: PooledWorkflowRunListener, + workflow_listener: PooledWorkflowRunListener | None, rest_client: RestApi, config: ClientConfig, debug: bool = False, @@ -103,17 +93,9 @@ def __init__( self.config = config self.listener = RunEventListenerClient(config) self.workflow_listener = workflow_listener - self.logInterceptor = config.logInterceptor + self.logInterceptor = config.logger self.debug = debug -def with_host_port(host: str, port: int): - def with_host_port_impl(config: ClientConfig): - config.host = host - config.port = port - - return with_host_port_impl - - new_client = Client.from_environment new_client_raw = Client.from_config diff --git a/hatchet_sdk/clients/admin.py b/hatchet_sdk/clients/admin.py index e3d345b1..62628585 100644 --- a/hatchet_sdk/clients/admin.py +++ b/hatchet_sdk/clients/admin.py @@ -1,11 +1,11 @@ import json from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, TypedDict, TypeVar, Union +from typing import Any, Callable, TypeVar, Union, cast import grpc from google.protobuf import timestamp_pb2 +from pydantic import BaseModel, Field -from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry from hatchet_sdk.clients.run_event_listener import new_listener from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener @@ -30,47 +30,47 @@ inject_carrier_into_metadata, parse_carrier_from_metadata, ) +from hatchet_sdk.utils.types import JSONSerializableDict from hatchet_sdk.workflow_run import RunRef, WorkflowRunRef from ..loader import ClientConfig from ..metadata import get_metadata -from ..workflow import WorkflowMeta -def new_admin(config: ClientConfig): +def new_admin(config: ClientConfig) -> "AdminClient": return AdminClient(config) -class ScheduleTriggerWorkflowOptions(TypedDict, total=False): - parent_id: Optional[str] - parent_step_run_id: Optional[str] - child_index: Optional[int] - child_key: Optional[str] - namespace: Optional[str] +class ScheduleTriggerWorkflowOptions(BaseModel): + parent_id: str | None = None + parent_step_run_id: str | None = None + child_index: int | None = None + child_key: str | None = None + namespace: str | None = None -class ChildTriggerWorkflowOptions(TypedDict, total=False): - additional_metadata: Dict[str, str] | None = None +class ChildTriggerWorkflowOptions(BaseModel): + additional_metadata: JSONSerializableDict = Field(default_factory=dict) sticky: bool | None = None -class ChildWorkflowRunDict(TypedDict, total=False): +class ChildWorkflowRunDict(BaseModel): workflow_name: str - input: Any + input: JSONSerializableDict options: ChildTriggerWorkflowOptions key: str | None = None -class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, total=False): - additional_metadata: Dict[str, str] | None = None +class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions): + additional_metadata: JSONSerializableDict = Field(default_factory=dict) desired_worker_id: str | None = None namespace: str | None = None -class WorkflowRunDict(TypedDict, total=False): +class WorkflowRunDict(BaseModel): workflow_name: str - input: Any - options: TriggerWorkflowOptions | None + input: JSONSerializableDict + options: TriggerWorkflowOptions class DedupeViolationErr(Exception): @@ -83,27 +83,26 @@ class AdminClientBase: pooled_workflow_listener: PooledWorkflowRunListener | None = None def _prepare_workflow_request( - self, workflow_name: str, input: any, options: TriggerWorkflowOptions = None - ): + self, workflow_name: str, input: dict[str, Any], options: TriggerWorkflowOptions + ) -> TriggerWorkflowRequest: try: payload_data = json.dumps(input) + _options = options.model_dump() + + _options.pop("namespace") try: - meta = ( - None - if options is None or "additional_metadata" not in options - else options["additional_metadata"] - ) - if meta is not None: - options = { - **options, - "additional_metadata": json.dumps(meta).encode("utf-8"), - } + _options = { + **_options, + "additional_metadata": json.dumps( + options.additional_metadata + ).encode("utf-8"), + } except json.JSONDecodeError as e: raise ValueError(f"Error encoding payload: {e}") return TriggerWorkflowRequest( - name=workflow_name, input=payload_data, **(options or {}) + name=workflow_name, input=payload_data, **_options ) except json.JSONDecodeError as e: raise ValueError(f"Error encoding payload: {e}") @@ -111,9 +110,9 @@ def _prepare_workflow_request( def _prepare_put_workflow_request( self, name: str, - workflow: CreateWorkflowVersionOpts | WorkflowMeta, + workflow: CreateWorkflowVersionOpts, overrides: CreateWorkflowVersionOpts | None = None, - ): + ) -> PutWorkflowRequest: try: opts: CreateWorkflowVersionOpts @@ -136,10 +135,10 @@ def _prepare_put_workflow_request( def _prepare_schedule_workflow_request( self, name: str, - schedules: List[Union[datetime, timestamp_pb2.Timestamp]], - input={}, - options: ScheduleTriggerWorkflowOptions = None, - ): + schedules: list[Union[datetime, timestamp_pb2.Timestamp]], + input: JSONSerializableDict = {}, + options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(), + ) -> ScheduleWorkflowRequest: timestamp_schedules = [] for schedule in schedules: if isinstance(schedule, datetime): @@ -159,7 +158,7 @@ def _prepare_schedule_workflow_request( name=name, schedules=timestamp_schedules, input=json.dumps(input), - **(options or {}), + **options.model_dump(), ) @@ -170,7 +169,7 @@ class AdminClientAioImpl(AdminClientBase): def __init__(self, config: ClientConfig): aio_conn = new_conn(config, True) self.config = config - self.aio_client = WorkflowServiceStub(aio_conn) + self.aio_client = WorkflowServiceStub(aio_conn) # type: ignore[no-untyped-call] self.token = config.token self.listener_client = new_listener(config) self.namespace = config.namespace @@ -179,13 +178,17 @@ def __init__(self, config: ClientConfig): async def run( self, function: Union[str, Callable[[Any], T]], - input: any, - options: TriggerWorkflowOptions = None, + input: JSONSerializableDict, + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> "RunRef[T]": - workflow_name = function - - if not isinstance(function, str): - workflow_name = function.function_name + workflow_name = cast( + str, + ( + function + if isinstance(function, str) + else getattr(function, "function_name") + ), + ) wrr = await self.run_workflow(workflow_name, input, options) @@ -195,11 +198,12 @@ async def run( @tenacity_retry async def run_workflow( - self, workflow_name: str, input: any, options: TriggerWorkflowOptions = None + self, + workflow_name: str, + input: JSONSerializableDict, + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> WorkflowRunRef: - ctx = parse_carrier_from_metadata( - (options or {}).get("additional_metadata", {}) - ) + ctx = parse_carrier_from_metadata(options.additional_metadata) with self.otel_tracer.start_as_current_span( f"hatchet.async_run_workflow.{workflow_name}", context=ctx @@ -212,27 +216,18 @@ async def run_workflow( self.config ) - namespace = self.namespace - - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options.pop("namespace") + namespace = options.namespace or self.namespace if namespace != "" and not workflow_name.startswith(self.namespace): workflow_name = f"{namespace}{workflow_name}" - if options is not None and "additional_metadata" in options: - options["additional_metadata"] = inject_carrier_into_metadata( - options["additional_metadata"], carrier - ) - span.set_attributes( - flatten( - options["additional_metadata"], parent_key="", separator="." - ) - ) + options.additional_metadata = inject_carrier_into_metadata( + options.additional_metadata, carrier + ) + + span.set_attributes( + flatten(options.additional_metadata, parent_key="", separator=".") + ) request = self._prepare_workflow_request(workflow_name, input, options) @@ -265,30 +260,22 @@ async def run_workflow( async def run_workflows( self, workflows: list[WorkflowRunDict], - options: TriggerWorkflowOptions | None = None, - ) -> List[WorkflowRunRef]: + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), + ) -> list[WorkflowRunRef]: if len(workflows) == 0: raise ValueError("No workflows to run") if not self.pooled_workflow_listener: self.pooled_workflow_listener = PooledWorkflowRunListener(self.config) - namespace = self.namespace + namespace = options.namespace or self.namespace - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options["namespace"] - del options["namespace"] - - workflow_run_requests: TriggerWorkflowRequest = [] + workflow_run_requests: list[TriggerWorkflowRequest] = [] for workflow in workflows: - workflow_name = workflow["workflow_name"] - input_data = workflow["input"] - options = workflow["options"] + workflow_name = workflow.workflow_name + input_data = workflow.input + options = workflow.options if namespace != "" and not workflow_name.startswith(self.namespace): workflow_name = f"{namespace}{workflow_name}" @@ -297,10 +284,8 @@ async def run_workflows( request = self._prepare_workflow_request(workflow_name, input_data, options) workflow_run_requests.append(request) - request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests) - resp: BulkTriggerWorkflowResponse = await self.aio_client.BulkTriggerWorkflow( - request, + BulkTriggerWorkflowRequest(workflows=workflow_run_requests), metadata=get_metadata(self.token), ) @@ -317,14 +302,17 @@ async def run_workflows( async def put_workflow( self, name: str, - workflow: CreateWorkflowVersionOpts | WorkflowMeta, + workflow: CreateWorkflowVersionOpts, overrides: CreateWorkflowVersionOpts | None = None, ) -> WorkflowVersion: opts = self._prepare_put_workflow_request(name, workflow, overrides) - return await self.aio_client.PutWorkflow( - opts, - metadata=get_metadata(self.token), + return cast( + WorkflowVersion, + await self.aio_client.PutWorkflow( + opts, + metadata=get_metadata(self.token), + ), ) @tenacity_retry @@ -333,7 +321,7 @@ async def put_rate_limit( key: str, limit: int, duration: RateLimitDuration = RateLimitDuration.SECOND, - ): + ) -> None: await self.aio_client.PutRateLimit( PutRateLimitRequest( key=key, @@ -347,20 +335,12 @@ async def put_rate_limit( async def schedule_workflow( self, name: str, - schedules: List[Union[datetime, timestamp_pb2.Timestamp]], - input={}, - options: ScheduleTriggerWorkflowOptions = None, + schedules: list[Union[datetime, timestamp_pb2.Timestamp]], + input: JSONSerializableDict = {}, + options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(), ) -> WorkflowVersion: try: - namespace = self.namespace - - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options["namespace"] - del options["namespace"] + namespace = options.namespace or self.namespace if namespace != "" and not name.startswith(self.namespace): name = f"{namespace}{name}" @@ -369,9 +349,12 @@ async def schedule_workflow( name, schedules, input, options ) - return await self.aio_client.ScheduleWorkflow( - request, - metadata=get_metadata(self.token), + return cast( + WorkflowVersion, + await self.aio_client.ScheduleWorkflow( + request, + metadata=get_metadata(self.token), + ), ) except (grpc.aio.AioRpcError, grpc.RpcError) as e: if e.code() == grpc.StatusCode.ALREADY_EXISTS: @@ -382,9 +365,9 @@ async def schedule_workflow( class AdminClient(AdminClientBase): def __init__(self, config: ClientConfig): - conn = new_conn(config) + conn = new_conn(config, False) self.config = config - self.client = WorkflowServiceStub(conn) + self.client = WorkflowServiceStub(conn) # type: ignore[no-untyped-call] self.aio = AdminClientAioImpl(config) self.token = config.token self.listener_client = new_listener(config) @@ -395,7 +378,7 @@ def __init__(self, config: ClientConfig): def put_workflow( self, name: str, - workflow: CreateWorkflowVersionOpts | WorkflowMeta, + workflow: CreateWorkflowVersionOpts, overrides: CreateWorkflowVersionOpts | None = None, ) -> WorkflowVersion: opts = self._prepare_put_workflow_request(name, workflow, overrides) @@ -412,8 +395,8 @@ def put_rate_limit( self, key: str, limit: int, - duration: Union[RateLimitDuration.Value, str] = RateLimitDuration.SECOND, - ): + duration: Union[RateLimitDuration, str] = RateLimitDuration.SECOND, + ) -> None: self.client.PutRateLimit( PutRateLimitRequest( key=key, @@ -427,20 +410,12 @@ def put_rate_limit( def schedule_workflow( self, name: str, - schedules: List[Union[datetime, timestamp_pb2.Timestamp]], - input={}, - options: ScheduleTriggerWorkflowOptions = None, + schedules: list[Union[datetime, timestamp_pb2.Timestamp]], + input: JSONSerializableDict = {}, + options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(), ) -> WorkflowVersion: try: - namespace = self.namespace - - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options["namespace"] - del options["namespace"] + namespace = options.namespace or self.namespace if namespace != "" and not name.startswith(self.namespace): name = f"{namespace}{name}" @@ -449,9 +424,12 @@ def schedule_workflow( name, schedules, input, options ) - return self.client.ScheduleWorkflow( - request, - metadata=get_metadata(self.token), + return cast( + WorkflowVersion, + self.client.ScheduleWorkflow( + request, + metadata=get_metadata(self.token), + ), ) except (grpc.RpcError, grpc.aio.AioRpcError) as e: if e.code() == grpc.StatusCode.ALREADY_EXISTS: @@ -463,11 +441,12 @@ def schedule_workflow( ## TODO: `any` type hint should come from `typing` @tenacity_retry def run_workflow( - self, workflow_name: str, input: any, options: TriggerWorkflowOptions = None + self, + workflow_name: str, + input: JSONSerializableDict, + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> WorkflowRunRef: - ctx = parse_carrier_from_metadata( - (options or {}).get("additional_metadata", {}) - ) + ctx = parse_carrier_from_metadata(options.additional_metadata) with self.otel_tracer.start_as_current_span( f"hatchet.run_workflow.{workflow_name}", context=ctx @@ -480,26 +459,15 @@ def run_workflow( self.config ) - namespace = self.namespace - - ## TODO: Factor this out - it's repeated a lot of places - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options.pop("namespace") + namespace = options.namespace or self.namespace - if options is not None and "additional_metadata" in options: - options["additional_metadata"] = inject_carrier_into_metadata( - options["additional_metadata"], carrier - ) + options.additional_metadata = inject_carrier_into_metadata( + options.additional_metadata, carrier + ) - span.set_attributes( - flatten( - options["additional_metadata"], parent_key="", separator="." - ) - ) + span.set_attributes( + flatten(options.additional_metadata, parent_key="", separator=".") + ) if namespace != "" and not workflow_name.startswith(self.namespace): workflow_name = f"{namespace}{workflow_name}" @@ -533,39 +501,33 @@ def run_workflow( @tenacity_retry def run_workflows( - self, workflows: List[WorkflowRunDict], options: TriggerWorkflowOptions = None + self, + workflows: list[WorkflowRunDict], + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> list[WorkflowRunRef]: - workflow_run_requests: TriggerWorkflowRequest = [] + workflow_run_requests: list[TriggerWorkflowRequest] = [] if not self.pooled_workflow_listener: self.pooled_workflow_listener = PooledWorkflowRunListener(self.config) for workflow in workflows: - workflow_name = workflow["workflow_name"] - input_data = workflow["input"] - options = workflow["options"] + workflow_name = workflow.workflow_name + input_data = workflow.input + options = workflow.options - namespace = self.namespace + namespace = options.namespace or self.namespace - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options["namespace"] - del options["namespace"] + if namespace != "" and not workflow_name.startswith(self.namespace): + workflow_name = f"{namespace}{workflow_name}" - if namespace != "" and not workflow_name.startswith(self.namespace): - workflow_name = f"{namespace}{workflow_name}" + # Prepare and trigger workflow for each workflow name and input + request = self._prepare_workflow_request(workflow_name, input_data, options) - # Prepare and trigger workflow for each workflow name and input - request = self._prepare_workflow_request(workflow_name, input_data, options) + workflow_run_requests.append(request) - workflow_run_requests.append(request) - - request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests) + bulk_request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests) resp: BulkTriggerWorkflowResponse = self.client.BulkTriggerWorkflow( - request, + bulk_request, metadata=get_metadata(self.token), ) @@ -581,13 +543,17 @@ def run_workflows( def run( self, function: Union[str, Callable[[Any], T]], - input: any, - options: TriggerWorkflowOptions = None, + input: JSONSerializableDict, + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> "RunRef[T]": - workflow_name = function - - if not isinstance(function, str): - workflow_name = function.function_name + workflow_name = cast( + str, + ( + function + if isinstance(function, str) + else getattr(function, "function_name") + ), + ) wrr = self.run_workflow(workflow_name, input, options) diff --git a/hatchet_sdk/clients/dispatcher/action_listener.py b/hatchet_sdk/clients/dispatcher/action_listener.py index fc2887bd..208981f3 100644 --- a/hatchet_sdk/clients/dispatcher/action_listener.py +++ b/hatchet_sdk/clients/dispatcher/action_listener.py @@ -2,10 +2,11 @@ import json import time from dataclasses import dataclass, field -from typing import Any, AsyncGenerator, List, Optional +from typing import Any, AsyncGenerator, Optional, cast import grpc -from grpc._cython import cygrpc +import grpc.aio +from grpc._cython import cygrpc # type: ignore[attr-defined] from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt from hatchet_sdk.clients.run_event_listener import ( @@ -24,29 +25,27 @@ from hatchet_sdk.logger import logger from hatchet_sdk.utils.backoff import exp_backoff_sleep from hatchet_sdk.utils.serialization import flatten +from hatchet_sdk.utils.types import JSONSerializableDict from ...loader import ClientConfig from ...metadata import get_metadata from ..events import proto_timestamp_now DEFAULT_ACTION_TIMEOUT = 600 # seconds - - -DEFAULT_ACTION_LISTENER_RETRY_INTERVAL = 5 # seconds DEFAULT_ACTION_LISTENER_RETRY_COUNT = 15 @dataclass class GetActionListenerRequest: worker_name: str - services: List[str] - actions: List[str] + services: list[str] + actions: list[str] max_runs: Optional[int] = None _labels: dict[str, str | int] = field(default_factory=dict) labels: dict[str, WorkerLabels] = field(init=False) - def __post_init__(self): + def __post_init__(self) -> None: self.labels = {} for key, value in self._labels.items(): @@ -68,16 +67,16 @@ class Action: step_id: str step_run_id: str action_id: str - action_payload: str action_type: ActionType retry_count: int - additional_metadata: dict[str, str] | None = None + action_payload: JSONSerializableDict = field(default_factory=dict) + additional_metadata: JSONSerializableDict = field(default_factory=dict) child_workflow_index: int | None = None child_workflow_key: str | None = None parent_workflow_run_id: str | None = None - def __post_init__(self): + def __post_init__(self) -> None: if isinstance(self.additional_metadata, str) and self.additional_metadata != "": try: self.additional_metadata = json.loads(self.additional_metadata) @@ -113,11 +112,6 @@ def otel_attributes(self) -> dict[str, Any]: ) -START_STEP_RUN = 0 -CANCEL_STEP_RUN = 1 -START_GET_GROUP_KEY = 2 - - @dataclass class ActionListener: config: ClientConfig @@ -130,22 +124,22 @@ class ActionListener: last_connection_attempt: float = field(default=0, init=False) last_heartbeat_succeeded: bool = field(default=True, init=False) time_last_hb_succeeded: float = field(default=9999999999999, init=False) - heartbeat_task: Optional[asyncio.Task] = field(default=None, init=False) + heartbeat_task: Optional[asyncio.Task[None]] = field(default=None, init=False) run_heartbeat: bool = field(default=True, init=False) listen_strategy: str = field(default="v2", init=False) stop_signal: bool = field(default=False, init=False) missed_heartbeats: int = field(default=0, init=False) - def __post_init__(self): - self.client = DispatcherStub(new_conn(self.config)) - self.aio_client = DispatcherStub(new_conn(self.config, True)) + def __post_init__(self) -> None: + self.client = DispatcherStub(new_conn(self.config, False)) # type: ignore[no-untyped-call] + self.aio_client = DispatcherStub(new_conn(self.config, True)) # type: ignore[no-untyped-call] self.token = self.config.token - def is_healthy(self): + def is_healthy(self) -> bool: return self.last_heartbeat_succeeded - async def heartbeat(self): + async def heartbeat(self) -> None: # send a heartbeat every 4 seconds heartbeat_delay = 4 @@ -205,7 +199,7 @@ async def heartbeat(self): break await asyncio.sleep(heartbeat_delay) - async def start_heartbeater(self): + async def start_heartbeater(self) -> None: if self.heartbeat_task is not None: return @@ -219,10 +213,10 @@ async def start_heartbeater(self): raise e self.heartbeat_task = loop.create_task(self.heartbeat()) - def __aiter__(self): + def __aiter__(self) -> AsyncGenerator[Action | None, None]: return self._generator() - async def _generator(self) -> AsyncGenerator[Action, None]: + async def _generator(self) -> AsyncGenerator[Action | None, None]: listener = None while not self.stop_signal: @@ -238,6 +232,10 @@ async def _generator(self) -> AsyncGenerator[Action, None]: try: while not self.stop_signal: self.interrupt = Event_ts() + + if listener is None: + continue + t = asyncio.create_task( read_with_interrupt(listener, self.interrupt) ) @@ -250,7 +248,10 @@ async def _generator(self) -> AsyncGenerator[Action, None]: ) t.cancel() - listener.cancel() + + if listener: + listener.cancel() + break assigned_action = t.result() @@ -260,20 +261,23 @@ async def _generator(self) -> AsyncGenerator[Action, None]: break self.retries = 0 - assigned_action: AssignedAction # Process the received action - action_type = self.map_action_type(assigned_action.actionType) + action_type = assigned_action.actionType - if ( - assigned_action.actionPayload is None - or assigned_action.actionPayload == "" - ): - action_payload = None - else: - action_payload = self.parse_action_payload( - assigned_action.actionPayload + action_payload = ( + {} + if not assigned_action.actionPayload + else self.parse_action_payload(assigned_action.actionPayload) + ) + + try: + additional_metadata = cast( + dict[str, Any], + json.loads(assigned_action.additional_metadata), ) + except json.JSONDecodeError: + additional_metadata = {} action = Action( tenant_id=assigned_action.tenantId, @@ -289,7 +293,7 @@ async def _generator(self) -> AsyncGenerator[Action, None]: action_payload=action_payload, action_type=action_type, retry_count=assigned_action.retryCount, - additional_metadata=assigned_action.additional_metadata, + additional_metadata=additional_metadata, child_workflow_index=assigned_action.child_workflow_index, child_workflow_key=assigned_action.child_workflow_key, parent_workflow_run_id=assigned_action.parent_workflow_run_id, @@ -323,25 +327,15 @@ async def _generator(self) -> AsyncGenerator[Action, None]: self.retries = self.retries + 1 - def parse_action_payload(self, payload: str): + def parse_action_payload(self, payload: str) -> JSONSerializableDict: try: - payload_data = json.loads(payload) + return cast(JSONSerializableDict, json.loads(payload)) except json.JSONDecodeError as e: raise ValueError(f"Error decoding payload: {e}") - return payload_data - - def map_action_type(self, action_type): - if action_type == ActionType.START_STEP_RUN: - return START_STEP_RUN - elif action_type == ActionType.CANCEL_STEP_RUN: - return CANCEL_STEP_RUN - elif action_type == ActionType.START_GET_GROUP_KEY: - return START_GET_GROUP_KEY - else: - # logger.error(f"Unknown action type: {action_type}") - return None - async def get_listen_client(self): + async def get_listen_client( + self, + ) -> grpc.aio.UnaryStreamCall[WorkerListenRequest, AssignedAction]: current_time = int(time.time()) if ( @@ -369,7 +363,7 @@ async def get_listen_client(self): f"action listener connection interrupted, retrying... ({self.retries}/{DEFAULT_ACTION_LISTENER_RETRY_COUNT})" ) - self.aio_client = DispatcherStub(new_conn(self.config, True)) + self.aio_client = DispatcherStub(new_conn(self.config, True)) # type: ignore[no-untyped-call] if self.listen_strategy == "v2": # we should await for the listener to be established before @@ -390,11 +384,14 @@ async def get_listen_client(self): self.last_connection_attempt = current_time - return listener + return cast( + grpc.aio.UnaryStreamCall[WorkerListenRequest, AssignedAction], listener + ) - def cleanup(self): + def cleanup(self) -> None: self.run_heartbeat = False - self.heartbeat_task.cancel() + if self.heartbeat_task is not None: + self.heartbeat_task.cancel() try: self.unregister() @@ -404,9 +401,11 @@ def cleanup(self): if self.interrupt: self.interrupt.set() - def unregister(self): + def unregister(self) -> WorkerUnsubscribeRequest: self.run_heartbeat = False - self.heartbeat_task.cancel() + + if self.heartbeat_task is not None: + self.heartbeat_task.cancel() try: req = self.aio_client.Unsubscribe( @@ -416,6 +415,6 @@ def unregister(self): ) if self.interrupt is not None: self.interrupt.set() - return req + return cast(WorkerUnsubscribeRequest, req) except grpc.RpcError as e: raise Exception(f"Failed to unsubscribe: {e}") diff --git a/hatchet_sdk/clients/dispatcher/dispatcher.py b/hatchet_sdk/clients/dispatcher/dispatcher.py index 407a80cc..e9769a71 100644 --- a/hatchet_sdk/clients/dispatcher/dispatcher.py +++ b/hatchet_sdk/clients/dispatcher/dispatcher.py @@ -1,5 +1,6 @@ -from typing import Any, cast +from typing import cast +import grpc.aio from google.protobuf.timestamp_pb2 import Timestamp from hatchet_sdk.clients.dispatcher.action_listener import ( @@ -41,7 +42,7 @@ class DispatcherClient: config: ClientConfig def __init__(self, config: ClientConfig): - conn = new_conn(config) + conn = new_conn(config, False) self.client = DispatcherStub(conn) # type: ignore[no-untyped-call] aio_conn = new_conn(config, True) @@ -76,7 +77,7 @@ async def get_action_listener( async def send_step_action_event( self, action: Action, event_type: StepActionEventType, payload: str - ) -> Any: + ) -> grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse] | None: try: return await self._try_send_step_action_event(action, event_type, payload) except Exception as e: @@ -91,12 +92,12 @@ async def send_step_action_event( "Failed to send finished event: " + str(e), ) - return + return None @tenacity_retry async def _try_send_step_action_event( self, action: Action, event_type: StepActionEventType, payload: str - ) -> Any: + ) -> grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse]: eventTimestamp = Timestamp() eventTimestamp.GetCurrentTime() @@ -113,15 +114,17 @@ async def _try_send_step_action_event( retryCount=action.retry_count, ) - ## TODO: What does this return? - return await self.aio_client.SendStepActionEvent( - event, - metadata=get_metadata(self.token), + return cast( + grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse], + await self.aio_client.SendStepActionEvent( + event, + metadata=get_metadata(self.token), + ), ) async def send_group_key_action_event( self, action: Action, event_type: GroupKeyActionEventType, payload: str - ) -> Any: + ) -> grpc.aio.UnaryUnaryCall[GroupKeyActionEvent, ActionEventResponse]: eventTimestamp = Timestamp() eventTimestamp.GetCurrentTime() @@ -136,9 +139,12 @@ async def send_group_key_action_event( ) ## TODO: What does this return? - return await self.aio_client.SendGroupKeyActionEvent( - event, - metadata=get_metadata(self.token), + return cast( + grpc.aio.UnaryUnaryCall[GroupKeyActionEvent, ActionEventResponse], + await self.aio_client.SendGroupKeyActionEvent( + event, + metadata=get_metadata(self.token), + ), ) def put_overrides_data(self, data: OverridesData) -> ActionEventResponse: diff --git a/hatchet_sdk/clients/event_ts.py b/hatchet_sdk/clients/event_ts.py index 1d3c3978..7a85d467 100644 --- a/hatchet_sdk/clients/event_ts.py +++ b/hatchet_sdk/clients/event_ts.py @@ -1,5 +1,8 @@ import asyncio -from typing import Any +from typing import Any, TypeVar, cast + +import grpc.aio +from grpc._cython import cygrpc # type: ignore[attr-defined] class Event_ts(asyncio.Event): @@ -7,22 +10,32 @@ class Event_ts(asyncio.Event): Event_ts is a subclass of asyncio.Event that allows for thread-safe setting and clearing of the event. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - if self._loop is None: + if self._loop is None: # type: ignore[has-type] self._loop = asyncio.get_event_loop() - def set(self): + def set(self) -> None: if not self._loop.is_closed(): self._loop.call_soon_threadsafe(super().set) - def clear(self): + def clear(self) -> None: self._loop.call_soon_threadsafe(super().clear) -async def read_with_interrupt(listener: Any, interrupt: Event_ts): +TRequest = TypeVar("TRequest") +TResponse = TypeVar("TResponse") + + +async def read_with_interrupt( + listener: grpc.aio.UnaryStreamCall[TRequest, TResponse], interrupt: Event_ts +) -> TResponse: try: result = await listener.read() - return result + + if result is cygrpc.EOF: + raise ValueError("Unexpected EOF") + + return cast(TResponse, result) finally: interrupt.set() diff --git a/hatchet_sdk/clients/events.py b/hatchet_sdk/clients/events.py index 160b780e..f1bf6f42 100644 --- a/hatchet_sdk/clients/events.py +++ b/hatchet_sdk/clients/events.py @@ -1,11 +1,12 @@ import asyncio import datetime import json -from typing import Any, Dict, List, Optional, TypedDict +from typing import Any, List, cast from uuid import uuid4 import grpc from google.protobuf import timestamp_pb2 +from pydantic import BaseModel, Field from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry from hatchet_sdk.contracts.events_pb2 import ( @@ -18,24 +19,26 @@ from hatchet_sdk.contracts.events_pb2_grpc import EventsServiceStub from hatchet_sdk.utils.serialization import flatten from hatchet_sdk.utils.tracing import ( + OTEL_CARRIER_KEY, create_carrier, create_tracer, inject_carrier_into_metadata, parse_carrier_from_metadata, ) +from hatchet_sdk.utils.types import JSONSerializableDict from ..loader import ClientConfig from ..metadata import get_metadata -def new_event(conn, config: ClientConfig): +def new_event(conn: grpc.Channel, config: ClientConfig) -> "EventClient": return EventClient( - client=EventsServiceStub(conn), + client=EventsServiceStub(conn), # type: ignore[no-untyped-call] config=config, ) -def proto_timestamp_now(): +def proto_timestamp_now() -> timestamp_pb2.Timestamp: t = datetime.datetime.now().timestamp() seconds = int(t) nanos = int(t % 1 * 1e9) @@ -43,19 +46,20 @@ def proto_timestamp_now(): return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) -class PushEventOptions(TypedDict, total=False): - additional_metadata: Dict[str, str] | None = None +class PushEventOptions(BaseModel): + additional_metadata: JSONSerializableDict = Field(default_factory=dict) namespace: str | None = None -class BulkPushEventOptions(TypedDict, total=False): +class BulkPushEventOptions(BaseModel): namespace: str | None = None + otel_carrier: dict[str, str] = Field(default_factory=dict) -class BulkPushEventWithMetadata(TypedDict, total=False): +class BulkPushEventWithMetadata(BaseModel): key: str payload: Any - additional_metadata: Optional[Dict[str, Any]] # Optional metadata + additional_metadata: JSONSerializableDict = Field(default_factory=dict) class EventClient: @@ -66,7 +70,10 @@ def __init__(self, client: EventsServiceStub, config: ClientConfig): self.otel_tracer = create_tracer(config=config) async def async_push( - self, event_key, payload, options: Optional[PushEventOptions] = None + self, + event_key: str, + payload: dict[str, Any], + options: PushEventOptions = PushEventOptions(), ) -> Event: return await asyncio.to_thread( self.push, event_key=event_key, payload=payload, options=options @@ -74,77 +81,67 @@ async def async_push( async def async_bulk_push( self, - events: List[BulkPushEventWithMetadata], - options: Optional[BulkPushEventOptions] = None, + events: list[BulkPushEventWithMetadata], + options: BulkPushEventOptions = BulkPushEventOptions(), ) -> List[Event]: return await asyncio.to_thread(self.bulk_push, events=events, options=options) @tenacity_retry - def push(self, event_key, payload, options: PushEventOptions = None) -> Event: - ctx = parse_carrier_from_metadata( - (options or {}).get("additional_metadata", {}) - ) + def push( + self, + event_key: str, + payload: dict[str, Any], + options: PushEventOptions = PushEventOptions(), + ) -> Event: + ctx = parse_carrier_from_metadata(options.additional_metadata) with self.otel_tracer.start_as_current_span( "hatchet.push", context=ctx ) as span: carrier = create_carrier() - namespace = self.namespace - - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options.pop("namespace") + namespace = options.namespace or self.namespace namespaced_event_key = namespace + event_key try: meta = inject_carrier_into_metadata( - dict() if options is None else options["additional_metadata"], + options.additional_metadata, carrier, ) - meta_bytes = None if meta is None else json.dumps(meta).encode("utf-8") + meta_bytes = None if meta is None else json.dumps(meta) except Exception as e: raise ValueError(f"Error encoding meta: {e}") span.set_attributes(flatten(meta, parent_key="", separator=".")) try: - payload_bytes = json.dumps(payload).encode("utf-8") - except json.UnicodeEncodeError as e: + payload_str = json.dumps(payload) + except (TypeError, ValueError) as e: raise ValueError(f"Error encoding payload: {e}") request = PushEventRequest( key=namespaced_event_key, - payload=payload_bytes, + payload=payload_str, eventTimestamp=proto_timestamp_now(), additionalMetadata=meta_bytes, ) span.add_event("Pushing event", attributes={"key": namespaced_event_key}) - return self.client.Push(request, metadata=get_metadata(self.token)) + return cast( + Event, self.client.Push(request, metadata=get_metadata(self.token)) + ) @tenacity_retry def bulk_push( self, events: List[BulkPushEventWithMetadata], - options: BulkPushEventOptions = None, + options: BulkPushEventOptions, ) -> List[Event]: - namespace = self.namespace + namespace = options.namespace or self.namespace bulk_push_correlation_id = uuid4() - ctx = parse_carrier_from_metadata( - (options or {}).get("additional_metadata", {}) - ) - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options.pop("namespace") + ctx = parse_carrier_from_metadata({OTEL_CARRIER_KEY: options.otel_carrier}) bulk_events = [] for event in events: @@ -156,29 +153,27 @@ def bulk_push( "bulk_push_correlation_id", str(bulk_push_correlation_id) ) - event_key = namespace + event["key"] - payload = event["payload"] + event_key = namespace + event.key + payload = event.payload + + meta = inject_carrier_into_metadata(event.additional_metadata, carrier) + span.set_attributes(flatten(meta, parent_key="", separator=".")) try: - meta = inject_carrier_into_metadata( - event.get("additional_metadata", {}), carrier - ) - meta_bytes = json.dumps(meta).encode("utf-8") if meta else None + meta_str = json.dumps(meta) except Exception as e: raise ValueError(f"Error encoding meta: {e}") - span.set_attributes(flatten(meta, parent_key="", separator=".")) - try: - payload_bytes = json.dumps(payload).encode("utf-8") - except json.UnicodeEncodeError as e: + payload = json.dumps(payload) + except (TypeError, ValueError) as e: raise ValueError(f"Error encoding payload: {e}") request = PushEventRequest( key=event_key, - payload=payload_bytes, + payload=payload, eventTimestamp=proto_timestamp_now(), - additionalMetadata=meta_bytes, + additionalMetadata=meta_str, ) bulk_events.append(request) @@ -187,9 +182,12 @@ def bulk_push( span.add_event("Pushing bulk events") response = self.client.BulkPush(bulk_request, metadata=get_metadata(self.token)) - return response.events + return cast( + list[Event], + response.events, + ) - def log(self, message: str, step_run_id: str): + def log(self, message: str, step_run_id: str) -> None: try: request = PutLogRequest( stepRunId=step_run_id, @@ -201,7 +199,7 @@ def log(self, message: str, step_run_id: str): except Exception as e: raise ValueError(f"Error logging: {e}") - def stream(self, data: str | bytes, step_run_id: str): + def stream(self, data: str | bytes, step_run_id: str) -> None: try: if isinstance(data, str): data_bytes = data.encode("utf-8") diff --git a/hatchet_sdk/clients/rest/api_client.py b/hatchet_sdk/clients/rest/api_client.py index 76446dda..62bdc472 100644 --- a/hatchet_sdk/clients/rest/api_client.py +++ b/hatchet_sdk/clients/rest/api_client.py @@ -97,7 +97,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_value, traceback): await self.close() - async def close(self): + async def close(self) -> None: await self.rest_client.close() @property diff --git a/hatchet_sdk/clients/rest/tenacity_utils.py b/hatchet_sdk/clients/rest/tenacity_utils.py index 377266a1..c90f7352 100644 --- a/hatchet_sdk/clients/rest/tenacity_utils.py +++ b/hatchet_sdk/clients/rest/tenacity_utils.py @@ -27,7 +27,7 @@ def tenacity_alert_retry(retry_state: tenacity.RetryCallState) -> None: ) -def tenacity_should_retry(ex: Exception) -> bool: +def tenacity_should_retry(ex: BaseException) -> bool: if isinstance(ex, (grpc.aio.AioRpcError, grpc.RpcError)): if ex.code() in [ grpc.StatusCode.UNIMPLEMENTED, diff --git a/hatchet_sdk/clients/rest_client.py b/hatchet_sdk/clients/rest_client.py index f6458e5a..51f3a912 100644 --- a/hatchet_sdk/clients/rest_client.py +++ b/hatchet_sdk/clients/rest_client.py @@ -2,7 +2,7 @@ import atexit import datetime import threading -from typing import Any, Coroutine, List +from typing import Coroutine, TypeVar from pydantic import StrictInt @@ -11,14 +11,13 @@ from hatchet_sdk.clients.rest.api.step_run_api import StepRunApi from hatchet_sdk.clients.rest.api.workflow_api import WorkflowApi from hatchet_sdk.clients.rest.api.workflow_run_api import WorkflowRunApi -from hatchet_sdk.clients.rest.api.workflow_runs_api import WorkflowRunsApi from hatchet_sdk.clients.rest.api_client import ApiClient from hatchet_sdk.clients.rest.configuration import Configuration -from hatchet_sdk.clients.rest.models import TriggerWorkflowRunRequest from hatchet_sdk.clients.rest.models.create_cron_workflow_trigger_request import ( CreateCronWorkflowTriggerRequest, ) from hatchet_sdk.clients.rest.models.cron_workflows import CronWorkflows +from hatchet_sdk.clients.rest.models.cron_workflows_list import CronWorkflowsList from hatchet_sdk.clients.rest.models.cron_workflows_order_by_field import ( CronWorkflowsOrderByField, ) @@ -27,6 +26,9 @@ EventOrderByDirection, ) from hatchet_sdk.clients.rest.models.event_order_by_field import EventOrderByField +from hatchet_sdk.clients.rest.models.event_update_cancel200_response import ( + EventUpdateCancel200Response, +) from hatchet_sdk.clients.rest.models.log_line_level import LogLineLevel from hatchet_sdk.clients.rest.models.log_line_list import LogLineList from hatchet_sdk.clients.rest.models.log_line_order_by_direction import ( @@ -44,16 +46,19 @@ ScheduleWorkflowRunRequest, ) from hatchet_sdk.clients.rest.models.scheduled_workflows import ScheduledWorkflows +from hatchet_sdk.clients.rest.models.scheduled_workflows_list import ( + ScheduledWorkflowsList, +) from hatchet_sdk.clients.rest.models.scheduled_workflows_order_by_field import ( ScheduledWorkflowsOrderByField, ) +from hatchet_sdk.clients.rest.models.trigger_workflow_run_request import ( + TriggerWorkflowRunRequest, +) from hatchet_sdk.clients.rest.models.workflow import Workflow from hatchet_sdk.clients.rest.models.workflow_kind import WorkflowKind from hatchet_sdk.clients.rest.models.workflow_list import WorkflowList from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun -from hatchet_sdk.clients.rest.models.workflow_run_cancel200_response import ( - WorkflowRunCancel200Response, -) from hatchet_sdk.clients.rest.models.workflow_run_list import WorkflowRunList from hatchet_sdk.clients.rest.models.workflow_run_order_by_direction import ( WorkflowRunOrderByDirection, @@ -66,6 +71,18 @@ WorkflowRunsCancelRequest, ) from hatchet_sdk.clients.rest.models.workflow_version import WorkflowVersion +from hatchet_sdk.utils.types import JSONSerializableDict + +## Type variables to use with coroutines. +## See https://stackoverflow.com/questions/73240620/the-right-way-to-type-hint-a-coroutine-function +## Return type +R = TypeVar("R") + +## Yield type +Y = TypeVar("Y") + +## Send type +S = TypeVar("S") class AsyncRestApi: @@ -77,50 +94,50 @@ def __init__(self, host: str, api_key: str, tenant_id: str): access_token=api_key, ) - self._api_client = None - self._workflow_api = None - self._workflow_run_api = None - self._step_run_api = None - self._event_api = None - self._log_api = None + self._api_client: ApiClient | None = None + self._workflow_api: WorkflowApi | None = None + self._workflow_run_api: WorkflowRunApi | None = None + self._step_run_api: StepRunApi | None = None + self._event_api: EventApi | None = None + self._log_api: LogApi | None = None @property - def api_client(self): + def api_client(self) -> ApiClient: if self._api_client is None: self._api_client = ApiClient(configuration=self.config) return self._api_client @property - def workflow_api(self): + def workflow_api(self) -> WorkflowApi: if self._workflow_api is None: self._workflow_api = WorkflowApi(self.api_client) return self._workflow_api @property - def workflow_run_api(self): + def workflow_run_api(self) -> WorkflowRunApi: if self._workflow_run_api is None: self._workflow_run_api = WorkflowRunApi(self.api_client) return self._workflow_run_api @property - def step_run_api(self): + def step_run_api(self) -> StepRunApi: if self._step_run_api is None: self._step_run_api = StepRunApi(self.api_client) return self._step_run_api @property - def event_api(self): + def event_api(self) -> EventApi: if self._event_api is None: self._event_api = EventApi(self.api_client) return self._event_api @property - def log_api(self): + def log_api(self) -> LogApi: if self._log_api is None: self._log_api = LogApi(self.api_client) return self._log_api - async def close(self): + async def close(self) -> None: # Ensure the aiohttp client session is closed if self._api_client is not None: await self._api_client.close() @@ -184,13 +201,13 @@ async def workflow_run_replay( return await self.workflow_run_api.workflow_run_update_replay( tenant=self.tenant_id, replay_workflow_runs_request=ReplayWorkflowRunsRequest( - workflow_run_ids=workflow_run_ids, + workflowRunIds=workflow_run_ids, ), ) async def workflow_run_cancel( self, workflow_run_id: str - ) -> WorkflowRunCancel200Response: + ) -> EventUpdateCancel200Response: return await self.workflow_run_api.workflow_run_cancel( tenant=self.tenant_id, workflow_runs_cancel_request=WorkflowRunsCancelRequest( @@ -200,7 +217,7 @@ async def workflow_run_cancel( async def workflow_run_bulk_cancel( self, workflow_run_ids: list[str] - ) -> WorkflowRunCancel200Response: + ) -> EventUpdateCancel200Response: return await self.workflow_run_api.workflow_run_cancel( tenant=self.tenant_id, workflow_runs_cancel_request=WorkflowRunsCancelRequest( @@ -211,16 +228,16 @@ async def workflow_run_bulk_cancel( async def workflow_run_create( self, workflow_id: str, - input: dict[str, Any], + input: JSONSerializableDict, version: str | None = None, - additional_metadata: list[str] | None = None, + additional_metadata: JSONSerializableDict = {}, ) -> WorkflowRun: return await self.workflow_run_api.workflow_run_create( workflow=workflow_id, version=version, trigger_workflow_run_request=TriggerWorkflowRunRequest( input=input, - additional_metadata=additional_metadata, + additionalMetadata=additional_metadata, ), ) @@ -229,9 +246,9 @@ async def cron_create( workflow_name: str, cron_name: str, expression: str, - input: dict[str, Any], - additional_metadata: dict[str, str], - ): + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, + ) -> CronWorkflows: return await self.workflow_run_api.cron_workflow_trigger_create( tenant=self.tenant_id, workflow=workflow_name, @@ -239,12 +256,12 @@ async def cron_create( cronName=cron_name, cronExpression=expression, input=input, - additional_metadata=additional_metadata, + additionalMetadata=additional_metadata, ), ) - async def cron_delete(self, cron_trigger_id: str): - return await self.workflow_api.workflow_cron_delete( + async def cron_delete(self, cron_trigger_id: str) -> None: + await self.workflow_api.workflow_cron_delete( tenant=self.tenant_id, cron_workflow=cron_trigger_id, ) @@ -257,7 +274,7 @@ async def cron_list( additional_metadata: list[str] | None = None, order_by_field: CronWorkflowsOrderByField | None = None, order_by_direction: WorkflowRunOrderByDirection | None = None, - ): + ) -> CronWorkflowsList: return await self.workflow_api.cron_workflow_list( tenant=self.tenant_id, offset=offset, @@ -268,7 +285,7 @@ async def cron_list( order_by_direction=order_by_direction, ) - async def cron_get(self, cron_trigger_id: str): + async def cron_get(self, cron_trigger_id: str) -> CronWorkflows: return await self.workflow_api.workflow_cron_get( tenant=self.tenant_id, cron_workflow=cron_trigger_id, @@ -278,21 +295,21 @@ async def schedule_create( self, name: str, trigger_at: datetime.datetime, - input: dict[str, Any], - additional_metadata: dict[str, str], - ): + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, + ) -> ScheduledWorkflows: return await self.workflow_run_api.scheduled_workflow_run_create( tenant=self.tenant_id, workflow=name, schedule_workflow_run_request=ScheduleWorkflowRunRequest( triggerAt=trigger_at, input=input, - additional_metadata=additional_metadata, + additionalMetadata=additional_metadata, ), ) - async def schedule_delete(self, scheduled_trigger_id: str): - return await self.workflow_api.workflow_scheduled_delete( + async def schedule_delete(self, scheduled_trigger_id: str) -> None: + await self.workflow_api.workflow_scheduled_delete( tenant=self.tenant_id, scheduled_workflow_run=scheduled_trigger_id, ) @@ -307,7 +324,7 @@ async def schedule_list( parent_step_run_id: str | None = None, order_by_field: ScheduledWorkflowsOrderByField | None = None, order_by_direction: WorkflowRunOrderByDirection | None = None, - ): + ) -> ScheduledWorkflowsList: return await self.workflow_api.workflow_scheduled_list( tenant=self.tenant_id, offset=offset, @@ -320,7 +337,7 @@ async def schedule_list( order_by_direction=order_by_direction, ) - async def schedule_get(self, scheduled_trigger_id: str): + async def schedule_get(self, scheduled_trigger_id: str) -> ScheduledWorkflows: return await self.workflow_api.workflow_scheduled_get( tenant=self.tenant_id, scheduled_workflow_run=scheduled_trigger_id, @@ -373,9 +390,10 @@ async def events_list( async def events_replay(self, event_ids: list[str] | EventList) -> EventList: if isinstance(event_ids, EventList): - event_ids = [r.metadata.id for r in event_ids.rows] + rows = event_ids.rows or [] + event_ids = [r.metadata.id for r in rows] - return self.event_api.event_update_replay( + return await self.event_api.event_update_replay( tenant=self.tenant_id, replay_event_request=ReplayEventRequest(eventIds=event_ids), ) @@ -393,7 +411,7 @@ def __init__(self, host: str, api_key: str, tenant_id: str): # Register the cleanup method to be called on exit atexit.register(self._cleanup) - def _cleanup(self): + def _cleanup(self) -> None: """ Stop the running thread and clean up the event loop. """ @@ -401,14 +419,14 @@ def _cleanup(self): self._loop.call_soon_threadsafe(self._loop.stop) self._thread.join() - def _run_event_loop(self): + def _run_event_loop(self) -> None: """ Run the asyncio event loop in a separate thread. """ asyncio.set_event_loop(self._loop) self._loop.run_forever() - def _run_coroutine(self, coro) -> Any: + def _run_coroutine(self, coro: Coroutine[Y, S, R]) -> R: """ Execute a coroutine in the event loop and return the result. """ @@ -459,20 +477,20 @@ def workflow_run_list( def workflow_run_get(self, workflow_run_id: str) -> WorkflowRun: return self._run_coroutine(self.aio.workflow_run_get(workflow_run_id)) - def workflow_run_cancel(self, workflow_run_id: str) -> WorkflowRunCancel200Response: + def workflow_run_cancel(self, workflow_run_id: str) -> EventUpdateCancel200Response: return self._run_coroutine(self.aio.workflow_run_cancel(workflow_run_id)) def workflow_run_bulk_cancel( self, workflow_run_ids: list[str] - ) -> WorkflowRunCancel200Response: + ) -> EventUpdateCancel200Response: return self._run_coroutine(self.aio.workflow_run_bulk_cancel(workflow_run_ids)) def workflow_run_create( self, workflow_id: str, - input: dict[str, Any], + input: JSONSerializableDict, version: str | None = None, - additional_metadata: list[str] | None = None, + additional_metadata: JSONSerializableDict = {}, ) -> WorkflowRun: return self._run_coroutine( self.aio.workflow_run_create( @@ -485,8 +503,8 @@ def cron_create( workflow_name: str, cron_name: str, expression: str, - input: dict[str, Any], - additional_metadata: dict[str, str], + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, ) -> CronWorkflows: return self._run_coroutine( self.aio.cron_create( @@ -494,8 +512,8 @@ def cron_create( ) ) - def cron_delete(self, cron_trigger_id: str): - return self._run_coroutine(self.aio.cron_delete(cron_trigger_id)) + def cron_delete(self, cron_trigger_id: str) -> None: + self._run_coroutine(self.aio.cron_delete(cron_trigger_id)) def cron_list( self, @@ -505,7 +523,7 @@ def cron_list( additional_metadata: list[str] | None = None, order_by_field: CronWorkflowsOrderByField | None = None, order_by_direction: WorkflowRunOrderByDirection | None = None, - ): + ) -> CronWorkflowsList: return self._run_coroutine( self.aio.cron_list( offset, @@ -517,24 +535,24 @@ def cron_list( ) ) - def cron_get(self, cron_trigger_id: str): + def cron_get(self, cron_trigger_id: str) -> CronWorkflows: return self._run_coroutine(self.aio.cron_get(cron_trigger_id)) def schedule_create( self, workflow_name: str, trigger_at: datetime.datetime, - input: dict[str, Any], - additional_metadata: dict[str, str], - ): + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, + ) -> ScheduledWorkflows: return self._run_coroutine( self.aio.schedule_create( workflow_name, trigger_at, input, additional_metadata ) ) - def schedule_delete(self, scheduled_trigger_id: str): - return self._run_coroutine(self.aio.schedule_delete(scheduled_trigger_id)) + def schedule_delete(self, scheduled_trigger_id: str) -> None: + self._run_coroutine(self.aio.schedule_delete(scheduled_trigger_id)) def schedule_list( self, @@ -544,7 +562,7 @@ def schedule_list( additional_metadata: list[str] | None = None, order_by_field: CronWorkflowsOrderByField | None = None, order_by_direction: WorkflowRunOrderByDirection | None = None, - ): + ) -> ScheduledWorkflowsList: return self._run_coroutine( self.aio.schedule_list( offset, @@ -556,7 +574,7 @@ def schedule_list( ) ) - def schedule_get(self, scheduled_trigger_id: str): + def schedule_get(self, scheduled_trigger_id: str) -> ScheduledWorkflows: return self._run_coroutine(self.aio.schedule_get(scheduled_trigger_id)) def list_logs( diff --git a/hatchet_sdk/clients/run_event_listener.py b/hatchet_sdk/clients/run_event_listener.py index b5db6a74..3495ce14 100644 --- a/hatchet_sdk/clients/run_event_listener.py +++ b/hatchet_sdk/clients/run_event_listener.py @@ -1,6 +1,7 @@ import asyncio import json -from typing import AsyncGenerator +from enum import Enum +from typing import Any, AsyncGenerator, Callable, Generator, cast import grpc @@ -21,7 +22,7 @@ DEFAULT_ACTION_LISTENER_RETRY_COUNT = 5 -class StepRunEventType: +class StepRunEventType(str, Enum): STEP_RUN_EVENT_TYPE_STARTED = "STEP_RUN_EVENT_TYPE_STARTED" STEP_RUN_EVENT_TYPE_COMPLETED = "STEP_RUN_EVENT_TYPE_COMPLETED" STEP_RUN_EVENT_TYPE_FAILED = "STEP_RUN_EVENT_TYPE_FAILED" @@ -30,7 +31,7 @@ class StepRunEventType: STEP_RUN_EVENT_TYPE_STREAM = "STEP_RUN_EVENT_TYPE_STREAM" -class WorkflowRunEventType: +class WorkflowRunEventType(str, Enum): WORKFLOW_RUN_EVENT_TYPE_STARTED = "WORKFLOW_RUN_EVENT_TYPE_STARTED" WORKFLOW_RUN_EVENT_TYPE_COMPLETED = "WORKFLOW_RUN_EVENT_TYPE_COMPLETED" WORKFLOW_RUN_EVENT_TYPE_FAILED = "WORKFLOW_RUN_EVENT_TYPE_FAILED" @@ -62,14 +63,14 @@ def __init__(self, type: StepRunEventType, payload: str): self.payload = payload -def new_listener(config: ClientConfig): +def new_listener(config: ClientConfig) -> "RunEventListenerClient": return RunEventListenerClient(config=config) class RunEventListener: - workflow_run_id: str = None - additional_meta_kv: tuple[str, str] = None + workflow_run_id: str | None = None + additional_meta_kv: tuple[str, str] | None = None def __init__(self, client: DispatcherStub, token: str): self.client = client @@ -77,7 +78,9 @@ def __init__(self, client: DispatcherStub, token: str): self.token = token @classmethod - def for_run_id(cls, workflow_run_id: str, client: DispatcherStub, token: str): + def for_run_id( + cls, workflow_run_id: str, client: DispatcherStub, token: str + ) -> "RunEventListener": listener = RunEventListener(client, token) listener.workflow_run_id = workflow_run_id return listener @@ -85,21 +88,21 @@ def for_run_id(cls, workflow_run_id: str, client: DispatcherStub, token: str): @classmethod def for_additional_meta( cls, key: str, value: str, client: DispatcherStub, token: str - ): + ) -> "RunEventListener": listener = RunEventListener(client, token) listener.additional_meta_kv = (key, value) return listener - def abort(self): + def abort(self) -> None: self.stop_signal = True - def __aiter__(self): + def __aiter__(self) -> AsyncGenerator[StepRunEvent, None]: return self._generator() - async def __anext__(self): + async def __anext__(self) -> StepRunEvent: return await self._generator().__anext__() - def __iter__(self): + def __iter__(self) -> Generator[StepRunEvent, None, None]: try: loop = asyncio.get_event_loop() except RuntimeError as e: @@ -145,15 +148,18 @@ async def _generator(self) -> AsyncGenerator[StepRunEvent, None]: try: if workflow_event.eventPayload: + ## TODO: Should this be `dumps` instead? payload = json.loads(workflow_event.eventPayload) - except Exception as e: + except Exception: payload = workflow_event.eventPayload pass + assert isinstance(payload, str) + yield StepRunEvent(type=eventType, payload=payload) elif workflow_event.resourceType == RESOURCE_TYPE_WORKFLOW_RUN: - if workflow_event.eventType in workflow_run_event_type_mapping: - eventType = workflow_run_event_type_mapping[ + if workflow_event.eventType in step_run_event_type_mapping: + workflowRunEventType = step_run_event_type_mapping[ workflow_event.eventType ] else: @@ -166,10 +172,12 @@ async def _generator(self) -> AsyncGenerator[StepRunEvent, None]: try: if workflow_event.eventPayload: payload = json.loads(workflow_event.eventPayload) - except Exception as e: + except Exception: pass - yield StepRunEvent(type=eventType, payload=payload) + assert isinstance(payload, str) + + yield StepRunEvent(type=workflowRunEventType, payload=payload) if workflow_event.hangup: listener = None @@ -194,7 +202,7 @@ async def _generator(self) -> AsyncGenerator[StepRunEvent, None]: break # Raise StopAsyncIteration to properly end the generator - async def retry_subscribe(self): + async def retry_subscribe(self) -> AsyncGenerator[WorkflowEvent, None]: retries = 0 while retries < DEFAULT_ACTION_LISTENER_RETRY_COUNT: @@ -203,19 +211,25 @@ async def retry_subscribe(self): await asyncio.sleep(DEFAULT_ACTION_LISTENER_RETRY_INTERVAL) if self.workflow_run_id is not None: - return self.client.SubscribeToWorkflowEvents( - SubscribeToWorkflowEventsRequest( - workflowRunId=self.workflow_run_id, + return cast( + AsyncGenerator[WorkflowEvent, None], + self.client.SubscribeToWorkflowEvents( + SubscribeToWorkflowEventsRequest( + workflowRunId=self.workflow_run_id, + ), + metadata=get_metadata(self.token), ), - metadata=get_metadata(self.token), ) elif self.additional_meta_kv is not None: - return self.client.SubscribeToWorkflowEvents( - SubscribeToWorkflowEventsRequest( - additionalMetaKey=self.additional_meta_kv[0], - additionalMetaValue=self.additional_meta_kv[1], + return cast( + AsyncGenerator[WorkflowEvent, None], + self.client.SubscribeToWorkflowEvents( + SubscribeToWorkflowEventsRequest( + additionalMetaKey=self.additional_meta_kv[0], + additionalMetaValue=self.additional_meta_kv[1], + ), + metadata=get_metadata(self.token), ), - metadata=get_metadata(self.token), ) else: raise Exception("no listener method provided") @@ -226,34 +240,38 @@ async def retry_subscribe(self): else: raise ValueError(f"gRPC error: {e}") + raise Exception("Failed to subscribe to workflow events") + class RunEventListenerClient: def __init__(self, config: ClientConfig): self.token = config.token self.config = config - self.client: DispatcherStub = None + self.client: DispatcherStub | None = None - def stream_by_run_id(self, workflow_run_id: str): + def stream_by_run_id(self, workflow_run_id: str) -> RunEventListener: return self.stream(workflow_run_id) - def stream(self, workflow_run_id: str): + def stream(self, workflow_run_id: str) -> RunEventListener: if not isinstance(workflow_run_id, str) and hasattr(workflow_run_id, "__str__"): workflow_run_id = str(workflow_run_id) if not self.client: aio_conn = new_conn(self.config, True) - self.client = DispatcherStub(aio_conn) + self.client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call] return RunEventListener.for_run_id(workflow_run_id, self.client, self.token) - def stream_by_additional_metadata(self, key: str, value: str): + def stream_by_additional_metadata(self, key: str, value: str) -> RunEventListener: if not self.client: aio_conn = new_conn(self.config, True) - self.client = DispatcherStub(aio_conn) + self.client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call] return RunEventListener.for_additional_meta(key, value, self.client, self.token) - async def on(self, workflow_run_id: str, handler: callable = None): + async def on( + self, workflow_run_id: str, handler: Callable[[StepRunEvent], Any] | None = None + ) -> None: async for event in self.stream(workflow_run_id): # call the handler if provided if handler: diff --git a/hatchet_sdk/clients/workflow_listener.py b/hatchet_sdk/clients/workflow_listener.py index b1131587..f8dda4cf 100644 --- a/hatchet_sdk/clients/workflow_listener.py +++ b/hatchet_sdk/clients/workflow_listener.py @@ -1,10 +1,11 @@ import asyncio import json from collections.abc import AsyncIterator -from typing import AsyncGenerator +from typing import Any, cast import grpc -from grpc._cython import cygrpc +import grpc.aio +from grpc._cython import cygrpc # type: ignore[attr-defined] from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt from hatchet_sdk.connection import new_conn @@ -31,10 +32,10 @@ def __init__(self, id: int, workflow_run_id: str): self.workflow_run_id = workflow_run_id self.queue: asyncio.Queue[WorkflowRunEvent | None] = asyncio.Queue() - async def __aiter__(self): + async def __aiter__(self) -> "_Subscription": return self - async def __anext__(self) -> WorkflowRunEvent: + async def __anext__(self) -> WorkflowRunEvent | None: return await self.queue.get() async def get(self) -> WorkflowRunEvent: @@ -45,10 +46,10 @@ async def get(self) -> WorkflowRunEvent: return event - async def put(self, item: WorkflowRunEvent): + async def put(self, item: WorkflowRunEvent) -> None: await self.queue.put(item) - async def close(self): + async def close(self) -> None: await self.queue.put(None) @@ -62,25 +63,28 @@ class PooledWorkflowRunListener: subscription_counter: int = 0 subscription_counter_lock: asyncio.Lock = asyncio.Lock() - requests: asyncio.Queue[SubscribeToWorkflowRunsRequest] = asyncio.Queue() + requests: asyncio.Queue[SubscribeToWorkflowRunsRequest | int] = asyncio.Queue() - listener: AsyncGenerator[WorkflowRunEvent, None] = None - listener_task: asyncio.Task = None + listener: ( + grpc.aio.UnaryStreamCall[SubscribeToWorkflowRunsRequest, WorkflowRunEvent] + | None + ) = None + listener_task: asyncio.Task[None] | None = None curr_requester: int = 0 # events have keys of the format workflow_run_id + subscription_id events: dict[int, _Subscription] = {} - interrupter: asyncio.Task = None + interrupter: asyncio.Task[None] | None = None def __init__(self, config: ClientConfig): conn = new_conn(config, True) - self.client = DispatcherStub(conn) + self.client = DispatcherStub(conn) # type: ignore[no-untyped-call] self.token = config.token self.config = config - async def _interrupter(self): + async def _interrupter(self) -> None: """ _interrupter runs in a separate thread and interrupts the listener according to a configurable duration. """ @@ -89,7 +93,7 @@ async def _interrupter(self): if self.interrupt is not None: self.interrupt.set() - async def _init_producer(self): + async def _init_producer(self) -> None: try: if not self.listener: while True: @@ -106,6 +110,9 @@ async def _init_producer(self): while True: self.interrupt = Event_ts() + if self.listener is None: + continue + t = asyncio.create_task( read_with_interrupt(self.listener, self.interrupt) ) @@ -118,7 +125,8 @@ async def _init_producer(self): ) t.cancel() - self.listener.cancel() + if self.listener: + self.listener.cancel() await asyncio.sleep( DEFAULT_WORKFLOW_LISTENER_RETRY_INTERVAL ) @@ -178,7 +186,7 @@ async def _request(self) -> AsyncIterator[SubscribeToWorkflowRunsRequest]: yield request self.requests.task_done() - def cleanup_subscription(self, subscription_id: int): + def cleanup_subscription(self, subscription_id: int) -> None: workflow_run_id = self.subscriptionsToWorkflows[subscription_id] if workflow_run_id in self.workflowsToSubscriptions: @@ -187,8 +195,7 @@ def cleanup_subscription(self, subscription_id: int): del self.subscriptionsToWorkflows[subscription_id] del self.events[subscription_id] - async def subscribe(self, workflow_run_id: str): - init_producer: asyncio.Task = None + async def subscribe(self, workflow_run_id: str) -> WorkflowRunEvent: try: # create a new subscription id, place a mutex on the counter await self.subscription_counter_lock.acquire() @@ -216,15 +223,13 @@ async def subscribe(self, workflow_run_id: str): if not self.listener_task or self.listener_task.done(): self.listener_task = asyncio.create_task(self._init_producer()) - event = await self.events[subscription_id].get() - - return event + return await self.events[subscription_id].get() except asyncio.CancelledError: raise finally: self.cleanup_subscription(subscription_id) - async def result(self, workflow_run_id: str): + async def result(self, workflow_run_id: str) -> dict[str, Any]: from hatchet_sdk.clients.admin import DedupeViolationErr event = await self.subscribe(workflow_run_id) @@ -248,7 +253,9 @@ async def result(self, workflow_run_id: str): return results - async def _retry_subscribe(self): + async def _retry_subscribe( + self, + ) -> grpc.aio.UnaryStreamCall[SubscribeToWorkflowRunsRequest, WorkflowRunEvent]: retries = 0 while retries < DEFAULT_WORKFLOW_LISTENER_RETRY_COUNT: @@ -260,14 +267,19 @@ async def _retry_subscribe(self): if self.curr_requester != 0: self.requests.put_nowait(self.curr_requester) - listener = self.client.SubscribeToWorkflowRuns( - self._request(), - metadata=get_metadata(self.token), + return cast( + grpc.aio.UnaryStreamCall[ + SubscribeToWorkflowRunsRequest, WorkflowRunEvent + ], + self.client.SubscribeToWorkflowRuns( + self._request(), + metadata=get_metadata(self.token), + ), ) - - return listener except grpc.RpcError as e: if e.code() == grpc.StatusCode.UNAVAILABLE: retries = retries + 1 else: raise ValueError(f"gRPC error: {e}") + + raise ValueError("Failed to connect to workflow run listener") diff --git a/hatchet_sdk/connection.py b/hatchet_sdk/connection.py index 185395e4..f09b625e 100644 --- a/hatchet_sdk/connection.py +++ b/hatchet_sdk/connection.py @@ -1,5 +1,5 @@ import os -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal, cast, overload import grpc @@ -7,20 +7,31 @@ from hatchet_sdk.loader import ClientConfig -def new_conn(config: "ClientConfig", aio=False): +@overload +def new_conn(config: "ClientConfig", aio: Literal[False]) -> grpc.Channel: ... + +@overload +def new_conn(config: "ClientConfig", aio: Literal[True]) -> grpc.aio.Channel: ... + + +def new_conn(config: "ClientConfig", aio: bool) -> grpc.Channel | grpc.aio.Channel: credentials: grpc.ChannelCredentials | None = None # load channel credentials - if config.tls_config.tls_strategy == "tls": + if config.tls_config.strategy == "tls": root: Any | None = None - if config.tls_config.ca_file: - root = open(config.tls_config.ca_file, "rb").read() + if config.tls_config.root_ca_file: + root = open(config.tls_config.root_ca_file, "rb").read() credentials = grpc.ssl_channel_credentials(root_certificates=root) - elif config.tls_config.tls_strategy == "mtls": - root = open(config.tls_config.ca_file, "rb").read() + elif config.tls_config.strategy == "mtls": + assert config.tls_config.root_ca_file + assert config.tls_config.key_file + assert config.tls_config.cert_file + + root = open(config.tls_config.root_ca_file, "rb").read() private_key = open(config.tls_config.key_file, "rb").read() certificate_chain = open(config.tls_config.cert_file, "rb").read() @@ -32,7 +43,7 @@ def new_conn(config: "ClientConfig", aio=False): start = grpc if not aio else grpc.aio - channel_options = [ + channel_options: list[tuple[str, str | int]] = [ ("grpc.max_send_message_length", config.grpc_max_send_message_length), ("grpc.max_receive_message_length", config.grpc_max_recv_message_length), ("grpc.keepalive_time_ms", 10 * 1000), @@ -46,7 +57,7 @@ def new_conn(config: "ClientConfig", aio=False): # When steps execute via os.fork, we see `TSI_DATA_CORRUPTED` errors. os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "False" - if config.tls_config.tls_strategy == "none": + if config.tls_config.strategy == "none": conn = start.insecure_channel( target=config.host_port, options=channel_options, @@ -61,4 +72,8 @@ def new_conn(config: "ClientConfig", aio=False): credentials=credentials, options=channel_options, ) - return conn + + return cast( + grpc.Channel | grpc.aio.Channel, + conn, + ) diff --git a/hatchet_sdk/context/__init__.py b/hatchet_sdk/context/__init__.py index 0cebf2bf..e69de29b 100644 --- a/hatchet_sdk/context/__init__.py +++ b/hatchet_sdk/context/__init__.py @@ -1 +0,0 @@ -from .context import Context diff --git a/hatchet_sdk/context/context.py b/hatchet_sdk/context/context.py index f20acd66..8f8a50f5 100644 --- a/hatchet_sdk/context/context.py +++ b/hatchet_sdk/context/context.py @@ -2,38 +2,32 @@ import json import traceback from concurrent.futures import Future, ThreadPoolExecutor -from typing import Any, Generic, Type, TypeVar, cast, overload +from typing import Any, TypeVar, cast from warnings import warn from pydantic import BaseModel, StrictStr -from hatchet_sdk.clients.events import EventClient -from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry -from hatchet_sdk.clients.rest_client import RestApi -from hatchet_sdk.clients.run_event_listener import RunEventListenerClient -from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener -from hatchet_sdk.context.worker_context import WorkerContext -from hatchet_sdk.contracts.dispatcher_pb2 import OverridesData -from hatchet_sdk.contracts.workflows_pb2 import ( - BulkTriggerWorkflowRequest, - TriggerWorkflowRequest, -) -from hatchet_sdk.utils.types import WorkflowValidator -from hatchet_sdk.utils.typing import is_basemodel_subclass -from hatchet_sdk.workflow_run import WorkflowRunRef - -from ..clients.admin import ( +from hatchet_sdk.clients.admin import ( AdminClient, ChildTriggerWorkflowOptions, ChildWorkflowRunDict, TriggerWorkflowOptions, WorkflowRunDict, ) -from ..clients.dispatcher.dispatcher import ( # type: ignore[attr-defined] +from hatchet_sdk.clients.dispatcher.dispatcher import ( # type: ignore[attr-defined] Action, DispatcherClient, ) -from ..logger import logger +from hatchet_sdk.clients.events import EventClient +from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry +from hatchet_sdk.clients.rest_client import RestApi +from hatchet_sdk.clients.run_event_listener import RunEventListenerClient +from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener +from hatchet_sdk.context.worker_context import WorkerContext +from hatchet_sdk.contracts.dispatcher_pb2 import OverridesData +from hatchet_sdk.logger import logger +from hatchet_sdk.utils.types import JSONSerializableDict, WorkflowValidator +from hatchet_sdk.workflow_run import WorkflowRunRef DEFAULT_WORKFLOW_POLLING_INTERVAL = 5 # Seconds @@ -46,125 +40,9 @@ def get_caller_file_path() -> str: return caller_frame.filename -class BaseContext: - - action: Action - spawn_index: int - - def _prepare_workflow_options( - self, - key: str | None = None, - options: ChildTriggerWorkflowOptions | None = None, - worker_id: str | None = None, - ) -> TriggerWorkflowOptions: - workflow_run_id = self.action.workflow_run_id - step_run_id = self.action.step_run_id - - desired_worker_id = None - if options is not None and "sticky" in options and options["sticky"] == True: - desired_worker_id = worker_id - - meta = None - if options is not None and "additional_metadata" in options: - meta = options["additional_metadata"] - - ## TODO: Pydantic here to simplify this - trigger_options: TriggerWorkflowOptions = { - "parent_id": workflow_run_id, - "parent_step_run_id": step_run_id, - "child_key": key, - "child_index": self.spawn_index, - "additional_metadata": meta, - "desired_worker_id": desired_worker_id, - } - - self.spawn_index += 1 - return trigger_options - - -class ContextAioImpl(BaseContext): - def __init__( - self, - action: Action, - dispatcher_client: DispatcherClient, - admin_client: AdminClient, - event_client: EventClient, - rest_client: RestApi, - workflow_listener: PooledWorkflowRunListener, - workflow_run_event_listener: RunEventListenerClient, - worker: WorkerContext, - namespace: str = "", - ): - self.action = action - self.dispatcher_client = dispatcher_client - self.admin_client = admin_client - self.event_client = event_client - self.rest_client = rest_client - self.workflow_listener = workflow_listener - self.workflow_run_event_listener = workflow_run_event_listener - self.namespace = namespace - self.spawn_index = -1 - self.worker = worker - - @tenacity_retry - async def spawn_workflow( - self, - workflow_name: str, - input: dict[str, Any] = {}, - key: str | None = None, - options: ChildTriggerWorkflowOptions | None = None, - ) -> WorkflowRunRef: - worker_id = self.worker.id() - # if ( - # options is not None - # and "sticky" in options - # and options["sticky"] == True - # and not self.worker.has_workflow(workflow_name) - # ): - # raise Exception( - # f"cannot run with sticky: workflow {workflow_name} is not registered on the worker" - # ) - - trigger_options = self._prepare_workflow_options(key, options, worker_id) - - return await self.admin_client.aio.run_workflow( - workflow_name, input, trigger_options - ) - - @tenacity_retry - async def spawn_workflows( - self, child_workflow_runs: list[ChildWorkflowRunDict] - ) -> list[WorkflowRunRef]: - - if len(child_workflow_runs) == 0: - raise Exception("no child workflows to spawn") - - worker_id = self.worker.id() - - bulk_trigger_workflow_runs: list[WorkflowRunDict] = [] - for child_workflow_run in child_workflow_runs: - workflow_name = child_workflow_run["workflow_name"] - input = child_workflow_run["input"] - - key = child_workflow_run.get("key") - options = child_workflow_run.get("options", {}) - - trigger_options = self._prepare_workflow_options(key, options, worker_id) - - bulk_trigger_workflow_runs.append( - WorkflowRunDict( - workflow_name=workflow_name, input=input, options=trigger_options - ) - ) - - return await self.admin_client.aio.run_workflows(bulk_trigger_workflow_runs) - - -class Context(BaseContext): +class Context: spawn_index = -1 - worker: WorkerContext - def __init__( self, action: Action, @@ -172,7 +50,7 @@ def __init__( admin_client: AdminClient, event_client: EventClient, rest_client: RestApi, - workflow_listener: PooledWorkflowRunListener, + workflow_listener: PooledWorkflowRunListener | None, workflow_run_event_listener: RunEventListenerClient, worker: WorkerContext, namespace: str = "", @@ -181,17 +59,7 @@ def __init__( self.worker = worker self.validator_registry = validator_registry - self.aio = ContextAioImpl( - action, - dispatcher_client, - admin_client, - event_client, - rest_client, - workflow_listener, - workflow_run_event_listener, - worker, - namespace, - ) + self.data: dict[str, Any] # Check the type of action.action_payload before attempting to load it as JSON if isinstance(action.action_payload, (str, bytes, bytearray)): @@ -203,16 +71,14 @@ def __init__( self.data: dict[str, Any] = {} # type: ignore[no-redef] else: # Directly assign the payload to self.data if it's already a dict - self.data = ( - action.action_payload if isinstance(action.action_payload, dict) else {} - ) + self.data = action.action_payload self.action = action # FIXME: stepRunId is a legacy field, we should remove it self.stepRunId = action.step_run_id - self.step_run_id = action.step_run_id + self.step_run_id: str = action.step_run_id self.exit_flag = False self.dispatcher_client = dispatcher_client self.admin_client = admin_client @@ -236,6 +102,27 @@ def __init__( else: self.input = self.data.get("input", {}) + def _prepare_workflow_options( + self, + key: str | None = None, + options: ChildTriggerWorkflowOptions = ChildTriggerWorkflowOptions(), + worker_id: str | None = None, + ) -> TriggerWorkflowOptions: + workflow_run_id = self.action.workflow_run_id + step_run_id = self.action.step_run_id + + trigger_options = TriggerWorkflowOptions( + parent_id=workflow_run_id, + parent_step_run_id=step_run_id, + child_key=key, + child_index=self.spawn_index, + additional_metadata=options.additional_metadata, + desired_worker_id=worker_id if options.sticky else None, + ) + + self.spawn_index += 1 + return trigger_options + def step_output(self, step: str) -> dict[str, Any] | BaseModel: workflow_validator = next( (v for k, v in self.validator_registry.items() if k.split(":")[-1] == step), @@ -252,9 +139,11 @@ def step_output(self, step: str) -> dict[str, Any] | BaseModel: return parent_step_data + @property def triggered_by_event(self) -> bool: return cast(str, self.data.get("triggered_by", "")) == "event" + @property def workflow_input(self) -> dict[str, Any] | T: if (r := self.validator_registry.get(self.action.action_id)) and ( i := r.workflow_input @@ -266,6 +155,7 @@ def workflow_input(self) -> dict[str, Any] | T: return self.input + @property def workflow_run_id(self) -> str: return self.action.workflow_run_id @@ -361,21 +251,27 @@ def refresh_timeout(self, increment_by: str) -> None: except Exception as e: logger.error(f"Error refreshing timeout: {e}") + @property def retry_count(self) -> int: return self.action.retry_count + @property def additional_metadata(self) -> dict[str, Any] | None: return self.action.additional_metadata + @property def child_index(self) -> int | None: return self.action.child_workflow_index + @property def child_key(self) -> str | None: return self.action.child_workflow_key + @property def parent_workflow_run_id(self) -> str | None: return self.action.parent_workflow_run_id + @property def step_run_errors(self) -> dict[str, str]: errors = cast(dict[str, str], self.data.get("step_run_errors", {})) @@ -403,3 +299,42 @@ def fetch_run_failures(self) -> list[dict[str, StrictStr]]: for step_run in job_run.step_runs if step_run.error and step_run.step ] + + @tenacity_retry + async def spawn_workflow( + self, + workflow_name: str, + input: JSONSerializableDict = {}, + key: str | None = None, + options: ChildTriggerWorkflowOptions = ChildTriggerWorkflowOptions(), + ) -> WorkflowRunRef: + worker_id = self.worker.id() + + trigger_options = self._prepare_workflow_options(key, options, worker_id) + + return await self.admin_client.aio.run_workflow( + workflow_name, input, trigger_options + ) + + @tenacity_retry + async def spawn_workflows( + self, child_workflow_runs: list[ChildWorkflowRunDict] + ) -> list[WorkflowRunRef]: + + if len(child_workflow_runs) == 0: + raise Exception("no child workflows to spawn") + + worker_id = self.worker.id() + + bulk_trigger_workflow_runs = [ + WorkflowRunDict( + workflow_name=child_workflow_run.workflow_name, + input=child_workflow_run.input, + options=self._prepare_workflow_options( + child_workflow_run.key, child_workflow_run.options, worker_id + ), + ) + for child_workflow_run in child_workflow_runs + ] + + return await self.admin_client.aio.run_workflows(bulk_trigger_workflow_runs) diff --git a/hatchet_sdk/features/cron.py b/hatchet_sdk/features/cron.py index c54e5b3b..7d3a34ce 100644 --- a/hatchet_sdk/features/cron.py +++ b/hatchet_sdk/features/cron.py @@ -1,6 +1,6 @@ from typing import Union -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, Field, field_validator from hatchet_sdk.client import Client from hatchet_sdk.clients.rest.models.cron_workflows import CronWorkflows @@ -11,9 +11,10 @@ from hatchet_sdk.clients.rest.models.workflow_run_order_by_direction import ( WorkflowRunOrderByDirection, ) +from hatchet_sdk.utils.types import JSONSerializableDict -class CreateCronTriggerInput(BaseModel): +class CreateCronTriggerJSONSerializableDict(BaseModel): """ Schema for creating a workflow run triggered by a cron. @@ -23,12 +24,13 @@ class CreateCronTriggerInput(BaseModel): additional_metadata (dict[str, str]): Additional metadata associated with the cron trigger (e.g. {"key1": "value1", "key2": "value2"}). """ - expression: str = None - input: dict = {} - additional_metadata: dict[str, str] = {} + expression: str + input: JSONSerializableDict = Field(default_factory=dict) + additional_metadata: JSONSerializableDict = Field(default_factory=dict) @field_validator("expression") - def validate_cron_expression(cls, v): + @classmethod + def validate_cron_expression(cls, v: str) -> str: """ Validates the cron expression to ensure it adheres to the expected format. @@ -86,8 +88,8 @@ def create( workflow_name: str, cron_name: str, expression: str, - input: dict, - additional_metadata: dict[str, str], + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, ) -> CronWorkflows: """ Creates a new workflow cron trigger. @@ -102,7 +104,7 @@ def create( Returns: CronWorkflows: The created cron workflow instance. """ - validated_input = CreateCronTriggerInput( + validated_input = CreateCronTriggerJSONSerializableDict( expression=expression, input=input, additional_metadata=additional_metadata ) @@ -121,10 +123,11 @@ def delete(self, cron_trigger: Union[str, CronWorkflows]) -> None: Args: cron_trigger (Union[str, CronWorkflows]): The cron trigger ID or CronWorkflows instance to delete. """ - id_ = cron_trigger - if isinstance(cron_trigger, CronWorkflows): - id_ = cron_trigger.metadata.id - self._client.rest.cron_delete(id_) + self._client.rest.cron_delete( + cron_trigger.metadata.id + if isinstance(cron_trigger, CronWorkflows) + else cron_trigger + ) def list( self, @@ -168,10 +171,11 @@ def get(self, cron_trigger: Union[str, CronWorkflows]) -> CronWorkflows: Returns: CronWorkflows: The requested cron workflow instance. """ - id_ = cron_trigger - if isinstance(cron_trigger, CronWorkflows): - id_ = cron_trigger.metadata.id - return self._client.rest.cron_get(id_) + return self._client.rest.cron_get( + cron_trigger.metadata.id + if isinstance(cron_trigger, CronWorkflows) + else cron_trigger + ) class CronClientAsync: @@ -198,8 +202,8 @@ async def create( workflow_name: str, cron_name: str, expression: str, - input: dict, - additional_metadata: dict[str, str], + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, ) -> CronWorkflows: """ Asynchronously creates a new workflow cron trigger. @@ -214,7 +218,7 @@ async def create( Returns: CronWorkflows: The created cron workflow instance. """ - validated_input = CreateCronTriggerInput( + validated_input = CreateCronTriggerJSONSerializableDict( expression=expression, input=input, additional_metadata=additional_metadata ) @@ -233,10 +237,11 @@ async def delete(self, cron_trigger: Union[str, CronWorkflows]) -> None: Args: cron_trigger (Union[str, CronWorkflows]): The cron trigger ID or CronWorkflows instance to delete. """ - id_ = cron_trigger - if isinstance(cron_trigger, CronWorkflows): - id_ = cron_trigger.metadata.id - await self._client.rest.aio.cron_delete(id_) + await self._client.rest.aio.cron_delete( + cron_trigger.metadata.id + if isinstance(cron_trigger, CronWorkflows) + else cron_trigger + ) async def list( self, @@ -280,7 +285,9 @@ async def get(self, cron_trigger: Union[str, CronWorkflows]) -> CronWorkflows: Returns: CronWorkflows: The requested cron workflow instance. """ - id_ = cron_trigger - if isinstance(cron_trigger, CronWorkflows): - id_ = cron_trigger.metadata.id - return await self._client.rest.aio.cron_get(id_) + + return await self._client.rest.aio.cron_get( + cron_trigger.metadata.id + if isinstance(cron_trigger, CronWorkflows) + else cron_trigger + ) diff --git a/hatchet_sdk/features/scheduled.py b/hatchet_sdk/features/scheduled.py index 45af2609..b27502b6 100644 --- a/hatchet_sdk/features/scheduled.py +++ b/hatchet_sdk/features/scheduled.py @@ -1,10 +1,9 @@ import datetime -from typing import Any, Coroutine, Dict, List, Optional, Union +from typing import List, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field from hatchet_sdk.client import Client -from hatchet_sdk.clients.rest.models.cron_workflows import CronWorkflows from hatchet_sdk.clients.rest.models.cron_workflows_order_by_field import ( CronWorkflowsOrderByField, ) @@ -12,12 +11,16 @@ from hatchet_sdk.clients.rest.models.scheduled_workflows_list import ( ScheduledWorkflowsList, ) +from hatchet_sdk.clients.rest.models.scheduled_workflows_order_by_field import ( + ScheduledWorkflowsOrderByField, +) from hatchet_sdk.clients.rest.models.workflow_run_order_by_direction import ( WorkflowRunOrderByDirection, ) +from hatchet_sdk.utils.types import JSONSerializableDict -class CreateScheduledTriggerInput(BaseModel): +class CreateScheduledTriggerJSONSerializableDict(BaseModel): """ Schema for creating a scheduled workflow run. @@ -27,9 +30,9 @@ class CreateScheduledTriggerInput(BaseModel): trigger_at (Optional[datetime.datetime]): The datetime when the run should be triggered. """ - input: Dict[str, Any] = {} - additional_metadata: Dict[str, str] = {} - trigger_at: Optional[datetime.datetime] = None + input: JSONSerializableDict = Field(default_factory=dict) + additional_metadata: JSONSerializableDict = Field(default_factory=dict) + trigger_at: datetime.datetime class ScheduledClient: @@ -57,8 +60,8 @@ def create( self, workflow_name: str, trigger_at: datetime.datetime, - input: Dict[str, Any], - additional_metadata: Dict[str, str], + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, ) -> ScheduledWorkflows: """ Creates a new scheduled workflow run asynchronously. @@ -73,7 +76,7 @@ def create( ScheduledWorkflows: The created scheduled workflow instance. """ - validated_input = CreateScheduledTriggerInput( + validated_input = CreateScheduledTriggerJSONSerializableDict( trigger_at=trigger_at, input=input, additional_metadata=additional_metadata ) @@ -91,10 +94,11 @@ def delete(self, scheduled: Union[str, ScheduledWorkflows]) -> None: Args: scheduled (Union[str, ScheduledWorkflows]): The scheduled workflow trigger ID or ScheduledWorkflows instance to delete. """ - id_ = scheduled - if isinstance(scheduled, ScheduledWorkflows): - id_ = scheduled.metadata.id - self._client.rest.schedule_delete(id_) + self._client.rest.schedule_delete( + scheduled.metadata.id + if isinstance(scheduled, ScheduledWorkflows) + else scheduled + ) def list( self, @@ -138,10 +142,11 @@ def get(self, scheduled: Union[str, ScheduledWorkflows]) -> ScheduledWorkflows: Returns: ScheduledWorkflows: The requested scheduled workflow instance. """ - id_ = scheduled - if isinstance(scheduled, ScheduledWorkflows): - id_ = scheduled.metadata.id - return self._client.rest.schedule_get(id_) + return self._client.rest.schedule_get( + scheduled.metadata.id + if isinstance(scheduled, ScheduledWorkflows) + else scheduled + ) class ScheduledClientAsync: @@ -167,8 +172,8 @@ async def create( self, workflow_name: str, trigger_at: datetime.datetime, - input: Dict[str, Any], - additional_metadata: Dict[str, str], + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, ) -> ScheduledWorkflows: """ Creates a new scheduled workflow run asynchronously. @@ -193,10 +198,11 @@ async def delete(self, scheduled: Union[str, ScheduledWorkflows]) -> None: Args: scheduled (Union[str, ScheduledWorkflows]): The scheduled workflow trigger ID or ScheduledWorkflows instance to delete. """ - id_ = scheduled - if isinstance(scheduled, ScheduledWorkflows): - id_ = scheduled.metadata.id - await self._client.rest.aio.schedule_delete(id_) + await self._client.rest.aio.schedule_delete( + scheduled.metadata.id + if isinstance(scheduled, ScheduledWorkflows) + else scheduled + ) async def list( self, @@ -204,7 +210,7 @@ async def list( limit: Optional[int] = None, workflow_id: Optional[str] = None, additional_metadata: Optional[List[str]] = None, - order_by_field: Optional[CronWorkflowsOrderByField] = None, + order_by_field: Optional[ScheduledWorkflowsOrderByField] = None, order_by_direction: Optional[WorkflowRunOrderByDirection] = None, ) -> ScheduledWorkflowsList: """ @@ -242,7 +248,8 @@ async def get( Returns: ScheduledWorkflows: The requested scheduled workflow instance. """ - id_ = scheduled - if isinstance(scheduled, ScheduledWorkflows): - id_ = scheduled.metadata.id - return await self._client.rest.aio.schedule_get(id_) + return await self._client.rest.aio.schedule_get( + scheduled.metadata.id + if isinstance(scheduled, ScheduledWorkflows) + else scheduled + ) diff --git a/hatchet_sdk/hatchet.py b/hatchet_sdk/hatchet.py index bf0e9089..d23a8954 100644 --- a/hatchet_sdk/hatchet.py +++ b/hatchet_sdk/hatchet.py @@ -1,197 +1,47 @@ import asyncio import logging -from typing import Any, Callable, Optional, ParamSpec, Type, TypeVar, Union - -from pydantic import BaseModel -from typing_extensions import deprecated +from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, cast +from hatchet_sdk.client import Client, new_client, new_client_raw +from hatchet_sdk.clients.admin import AdminClient +from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient +from hatchet_sdk.clients.events import EventClient from hatchet_sdk.clients.rest_client import RestApi +from hatchet_sdk.clients.run_event_listener import RunEventListenerClient from hatchet_sdk.context.context import Context -from hatchet_sdk.contracts.workflows_pb2 import ( - ConcurrencyLimitStrategy, - CreateStepRateLimit, - DesiredWorkerLabels, - StickyStrategy, -) +from hatchet_sdk.contracts.workflows_pb2 import DesiredWorkerLabels from hatchet_sdk.features.cron import CronClient from hatchet_sdk.features.scheduled import ScheduledClient from hatchet_sdk.labels import DesiredWorkerLabel -from hatchet_sdk.loader import ClientConfig, ConfigLoader +from hatchet_sdk.loader import ClientConfig +from hatchet_sdk.logger import logger from hatchet_sdk.rate_limit import RateLimit -from hatchet_sdk.v2.callable import HatchetCallable - -from .client import Client, new_client, new_client_raw -from .clients.admin import AdminClient -from .clients.dispatcher.dispatcher import DispatcherClient -from .clients.events import EventClient -from .clients.run_event_listener import RunEventListenerClient -from .logger import logger -from .worker.worker import Worker -from .workflow import ( +from hatchet_sdk.workflow import ( ConcurrencyExpression, - WorkflowInterface, - WorkflowMeta, - WorkflowStepProtocol, + EmptyModel, + Step, + StepType, + StickyStrategy, + TWorkflowInput, + WorkflowConfig, + WorkflowDeclaration, ) -T = TypeVar("T", bound=BaseModel) -R = TypeVar("R") -P = ParamSpec("P") - -TWorkflow = TypeVar("TWorkflow", bound=object) - - -def workflow( - name: str = "", - on_events: list[str] | None = None, - on_crons: list[str] | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: Union[StickyStrategy.Value, None] = None, # type: ignore[name-defined] - default_priority: int | None = None, - concurrency: ConcurrencyExpression | None = None, - input_validator: Type[T] | None = None, -) -> Callable[[Type[TWorkflow]], WorkflowMeta]: - on_events = on_events or [] - on_crons = on_crons or [] - - def inner(cls: Type[TWorkflow]) -> WorkflowMeta: - nonlocal name - name = name or str(cls.__name__) - - setattr(cls, "on_events", on_events) - setattr(cls, "on_crons", on_crons) - setattr(cls, "name", name) - setattr(cls, "version", version) - setattr(cls, "timeout", timeout) - setattr(cls, "schedule_timeout", schedule_timeout) - setattr(cls, "sticky", sticky) - setattr(cls, "default_priority", default_priority) - setattr(cls, "concurrency_expression", concurrency) - - # Define a new class with the same name and bases as the original, but - # with WorkflowMeta as its metaclass - - ## TODO: Figure out how to type this metaclass correctly - setattr(cls, "input_validator", input_validator) - - return WorkflowMeta(name, cls.__bases__, dict(cls.__dict__)) - - return inner - - -def step( - name: str = "", - timeout: str = "", - parents: list[str] | None = None, - retries: int = 0, - rate_limits: list[RateLimit] | None = None, - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - backoff_factor: float | None = None, - backoff_max_seconds: int | None = None, -) -> Callable[[Callable[P, R]], Callable[P, R]]: - parents = parents or [] - - def inner(func: Callable[P, R]) -> Callable[P, R]: - limits = None - if rate_limits: - limits = [rate_limit._req for rate_limit in rate_limits or []] - - setattr(func, "_step_name", name.lower() or str(func.__name__).lower()) - setattr(func, "_step_parents", parents) - setattr(func, "_step_timeout", timeout) - setattr(func, "_step_retries", retries) - setattr(func, "_step_rate_limits", limits) - setattr(func, "_step_backoff_factor", backoff_factor) - setattr(func, "_step_backoff_max_seconds", backoff_max_seconds) - - def create_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels: - value = d["value"] if "value" in d else None - return DesiredWorkerLabels( - strValue=str(value) if not isinstance(value, int) else None, - intValue=value if isinstance(value, int) else None, - required=d["required"] if "required" in d else None, # type: ignore[arg-type] - weight=d["weight"] if "weight" in d else None, - comparator=d["comparator"] if "comparator" in d else None, # type: ignore[arg-type] - ) - - setattr( - func, - "_step_desired_worker_labels", - {key: create_label(d) for key, d in desired_worker_labels.items()}, - ) - - return func - - return inner - - -def on_failure_step( - name: str = "", - timeout: str = "", - retries: int = 0, - rate_limits: list[RateLimit] | None = None, - backoff_factor: float | None = None, - backoff_max_seconds: int | None = None, -) -> Callable[..., Any]: - def inner(func: Callable[[Context], Any]) -> Callable[[Context], Any]: - limits = None - if rate_limits: - limits = [ - CreateStepRateLimit(key=rate_limit.static_key, units=rate_limit.units) # type: ignore[arg-type] - for rate_limit in rate_limits or [] - ] - - setattr( - func, "_on_failure_step_name", name.lower() or str(func.__name__).lower() - ) - setattr(func, "_on_failure_step_timeout", timeout) - setattr(func, "_on_failure_step_retries", retries) - setattr(func, "_on_failure_step_rate_limits", limits) - setattr(func, "_on_failure_step_backoff_factor", backoff_factor) - setattr(func, "_on_failure_step_backoff_max_seconds", backoff_max_seconds) - - return func - - return inner - - -def concurrency( - name: str = "", - max_runs: int = 1, - limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS, -) -> Callable[..., Any]: - def inner(func: Callable[[Context], Any]) -> Callable[[Context], Any]: - setattr( - func, - "_concurrency_fn_name", - name.lower() or str(func.__name__).lower(), - ) - setattr(func, "_concurrency_max_runs", max_runs) - setattr(func, "_concurrency_limit_strategy", limit_strategy) - - return func - - return inner - - -class HatchetRest: - """ - Main client for interacting with the Hatchet API. - - This class provides access to various client interfaces and utility methods - for working with Hatchet via the REST API, +if TYPE_CHECKING: + from hatchet_sdk.worker.worker import Worker - Attributes: - rest (RestApi): Interface for REST API operations. - """ +R = TypeVar("R") - rest: RestApi - def __init__(self, config: ClientConfig = ClientConfig()): - _config: ClientConfig = ConfigLoader(".").load_client_config(config) - self.rest = RestApi(_config.server_url, _config.token, _config.tenant_id) +def transform_desired_worker_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels: + value = d.value + return DesiredWorkerLabels( + strValue=value if not isinstance(value, int) else None, + intValue=value if isinstance(value, int) else None, + required=d.required, + weight=d.weight, + comparator=d.comparator, # type: ignore[arg-type] + ) class Hatchet: @@ -249,13 +99,6 @@ def __init__( self.cron = CronClient(self._client) self.scheduled = ScheduledClient(self._client) - @property - @deprecated( - "Direct access to client is deprecated and will be removed in a future version. Use specific client properties (Hatchet.admin, Hatchet.dispatcher, Hatchet.event, Hatchet.rest) instead. [0.32.0]", - ) - def client(self) -> Client: - return self._client - @property def admin(self) -> AdminClient: return self._client.admin @@ -284,17 +127,71 @@ def config(self) -> ClientConfig: def tenant_id(self) -> str: return self._client.config.tenant_id - concurrency = staticmethod(concurrency) + def step( + self, + name: str = "", + timeout: str = "60m", + parents: list[str] = [], + retries: int = 0, + rate_limits: list[RateLimit] = [], + desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, + backoff_factor: float | None = None, + backoff_max_seconds: int | None = None, + ) -> Callable[[Callable[[Any, Context], Any]], Step[R]]: + def inner(func: Callable[[Any, Context], R]) -> Step[R]: + return Step( + fn=func, + type=StepType.DEFAULT, + name=name.lower() or str(func.__name__).lower(), + timeout=timeout, + parents=parents, + retries=retries, + rate_limits=[r for rate_limit in rate_limits if (r := rate_limit._req)], + desired_worker_labels={ + key: transform_desired_worker_label(d) + for key, d in desired_worker_labels.items() + }, + backoff_factor=backoff_factor, + backoff_max_seconds=backoff_max_seconds, + ) - workflow = staticmethod(workflow) + return inner - step = staticmethod(step) + def on_failure_step( + self, + name: str = "", + timeout: str = "60m", + parents: list[str] = [], + retries: int = 0, + rate_limits: list[RateLimit] = [], + desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, + backoff_factor: float | None = None, + backoff_max_seconds: int | None = None, + ) -> Callable[[Callable[[Any, Context], Any]], Step[R]]: + def inner(func: Callable[[Any, Context], R]) -> Step[R]: + return Step( + fn=func, + type=StepType.ON_FAILURE, + name=name.lower() or str(func.__name__).lower(), + timeout=timeout, + parents=parents, + retries=retries, + rate_limits=[r for rate_limit in rate_limits if (r := rate_limit._req)], + desired_worker_labels={ + key: transform_desired_worker_label(d) + for key, d in desired_worker_labels.items() + }, + backoff_factor=backoff_factor, + backoff_max_seconds=backoff_max_seconds, + ) - on_failure_step = staticmethod(on_failure_step) + return inner def worker( self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {} - ) -> Worker: + ) -> "Worker": + from hatchet_sdk.worker.worker import Worker + try: loop = asyncio.get_running_loop() except RuntimeError: @@ -308,3 +205,33 @@ def worker( debug=self._client.debug, owned_loop=loop is None, ) + + def declare_workflow( + self, + name: str = "", + on_events: list[str] = [], + on_crons: list[str] = [], + version: str = "", + timeout: str = "60m", + schedule_timeout: str = "5m", + sticky: StickyStrategy | None = None, + default_priority: int = 1, + concurrency: ConcurrencyExpression | None = None, + input_validator: Type[TWorkflowInput] | None = None, + ) -> WorkflowDeclaration[TWorkflowInput]: + return WorkflowDeclaration[TWorkflowInput]( + WorkflowConfig( + name=name, + on_events=on_events, + on_crons=on_crons, + version=version, + timeout=timeout, + schedule_timeout=schedule_timeout, + sticky=sticky, + default_priority=default_priority, + concurrency=concurrency, + input_validator=input_validator + or cast(Type[TWorkflowInput], EmptyModel), + ), + self, + ) diff --git a/hatchet_sdk/labels.py b/hatchet_sdk/labels.py index 646c666d..55836e31 100644 --- a/hatchet_sdk/labels.py +++ b/hatchet_sdk/labels.py @@ -1,10 +1,10 @@ -from typing import TypedDict +from pydantic import BaseModel, ConfigDict -class DesiredWorkerLabel(TypedDict, total=False): +class DesiredWorkerLabel(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + value: str | int - required: bool | None = None + required: bool = False weight: int | None = None - comparator: int | None = ( - None # _ClassVar[WorkerLabelComparator] TODO figure out type - ) + comparator: int | None = None diff --git a/hatchet_sdk/loader.py b/hatchet_sdk/loader.py index 38b0b2bf..017ac739 100644 --- a/hatchet_sdk/loader.py +++ b/hatchet_sdk/loader.py @@ -1,259 +1,128 @@ import json -import os from logging import Logger, getLogger -from typing import Dict, Optional - -import yaml - -from .token import get_addresses_from_jwt, get_tenant_id_from_jwt - - -class ClientTLSConfig: - def __init__( - self, - tls_strategy: str, - cert_file: str, - key_file: str, - ca_file: str, - server_name: str, - ): - self.tls_strategy = tls_strategy - self.cert_file = cert_file - self.key_file = key_file - self.ca_file = ca_file - self.server_name = server_name - - -class ClientConfig: - logInterceptor: Logger - - def __init__( - self, - tenant_id: str = None, - tls_config: ClientTLSConfig = None, - token: str = None, - host_port: str = "localhost:7070", - server_url: str = "https://app.dev.hatchet-tools.com", - namespace: str = None, - listener_v2_timeout: int = None, - logger: Logger = None, - grpc_max_recv_message_length: int = 4 * 1024 * 1024, # 4MB - grpc_max_send_message_length: int = 4 * 1024 * 1024, # 4MB - otel_exporter_oltp_endpoint: str | None = None, - otel_service_name: str | None = None, - otel_exporter_oltp_headers: dict[str, str] | None = None, - otel_exporter_oltp_protocol: str | None = None, - worker_healthcheck_port: int | None = None, - worker_healthcheck_enabled: bool | None = None, - worker_preset_labels: dict[str, str] = {}, - ): - self.tenant_id = tenant_id - self.tls_config = tls_config - self.host_port = host_port - self.token = token - self.server_url = server_url - self.namespace = "" - self.logInterceptor = logger - self.grpc_max_recv_message_length = grpc_max_recv_message_length - self.grpc_max_send_message_length = grpc_max_send_message_length - self.otel_exporter_oltp_endpoint = otel_exporter_oltp_endpoint - self.otel_service_name = otel_service_name - self.otel_exporter_oltp_headers = otel_exporter_oltp_headers - self.otel_exporter_oltp_protocol = otel_exporter_oltp_protocol - self.worker_healthcheck_port = worker_healthcheck_port - self.worker_healthcheck_enabled = worker_healthcheck_enabled - self.worker_preset_labels = worker_preset_labels - - if not self.logInterceptor: - self.logInterceptor = getLogger() - - # case on whether the namespace already has a trailing underscore - if namespace and not namespace.endswith("_"): - self.namespace = f"{namespace}_" - elif namespace: - self.namespace = namespace - - self.namespace = self.namespace.lower() - - self.listener_v2_timeout = listener_v2_timeout - - -class ConfigLoader: - def __init__(self, directory: str): - self.directory = directory - - def load_client_config(self, defaults: ClientConfig) -> ClientConfig: - config_file_path = os.path.join(self.directory, "client.yaml") - config_data: object = {"tls": {}} - - # determine if client.yaml exists - if os.path.exists(config_file_path): - with open(config_file_path, "r") as file: - config_data = yaml.safe_load(file) - - def get_config_value(key, env_var): - if key in config_data: - return config_data[key] - - if self._get_env_var(env_var) is not None: - return self._get_env_var(env_var) - - return getattr(defaults, key, None) - - namespace = get_config_value("namespace", "HATCHET_CLIENT_NAMESPACE") - - tenant_id = get_config_value("tenantId", "HATCHET_CLIENT_TENANT_ID") - token = get_config_value("token", "HATCHET_CLIENT_TOKEN") - listener_v2_timeout = get_config_value( - "listener_v2_timeout", "HATCHET_CLIENT_LISTENER_V2_TIMEOUT" - ) - listener_v2_timeout = int(listener_v2_timeout) if listener_v2_timeout else None - - if not token: - raise ValueError( - "Token must be set via HATCHET_CLIENT_TOKEN environment variable" - ) - - host_port = get_config_value("hostPort", "HATCHET_CLIENT_HOST_PORT") - server_url: str | None = None - - grpc_max_recv_message_length = get_config_value( - "grpc_max_recv_message_length", - "HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", - ) - grpc_max_send_message_length = get_config_value( - "grpc_max_send_message_length", - "HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH", - ) - - if grpc_max_recv_message_length: - grpc_max_recv_message_length = int(grpc_max_recv_message_length) - - if grpc_max_send_message_length: - grpc_max_send_message_length = int(grpc_max_send_message_length) - - if not host_port: - # extract host and port from token - server_url, grpc_broadcast_address = get_addresses_from_jwt(token) - host_port = grpc_broadcast_address - - if not tenant_id: - tenant_id = get_tenant_id_from_jwt(token) - - tls_config = self._load_tls_config(config_data["tls"], host_port) - - otel_exporter_oltp_endpoint = get_config_value( - "otel_exporter_oltp_endpoint", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT" - ) - - otel_service_name = get_config_value( - "otel_service_name", "HATCHET_CLIENT_OTEL_SERVICE_NAME" - ) - - _oltp_headers = get_config_value( - "otel_exporter_oltp_headers", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS" - ) - - if _oltp_headers: - try: - otel_header_key, api_key = _oltp_headers.split("=", maxsplit=1) - otel_exporter_oltp_headers = {otel_header_key: api_key} - except ValueError: - raise ValueError( - "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS must be in the format `key=value`" - ) + +from pydantic import Field, field_validator, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + +from hatchet_sdk.token import get_addresses_from_jwt, get_tenant_id_from_jwt + + +def create_settings_config(env_prefix: str) -> SettingsConfigDict: + return SettingsConfigDict( + env_prefix=env_prefix, + env_file=(".env", ".env.hatchet", ".env.dev", ".env.local"), + extra="ignore", + ) + + +class ClientTLSConfig(BaseSettings): + model_config = create_settings_config( + env_prefix="HATCHET_CLIENT_TLS_", + ) + + strategy: str = "tls" + cert_file: str | None = None + key_file: str | None = None + root_ca_file: str | None = None + server_name: str = "" + + +class OTELConfig(BaseSettings): + model_config = create_settings_config( + env_prefix="HATCHET_CLIENT_OTEL_", + ) + + service_name: str | None = None + exporter_otlp_endpoint: str | None = None + exporter_otlp_headers: str | None = None + exporter_otlp_protocol: str | None = None + + +class HealthcheckConfig(BaseSettings): + model_config = create_settings_config( + env_prefix="HATCHET_CLIENT_WORKER_HEALTHCHECK_", + ) + + port: int = 8001 + enabled: bool = False + + +DEFAULT_HOST_PORT = "localhost:7070" + + +class ClientConfig(BaseSettings): + model_config = create_settings_config( + env_prefix="HATCHET_CLIENT_", + ) + + token: str = "" + logger: Logger = getLogger() + + tenant_id: str = "" + host_port: str = DEFAULT_HOST_PORT + server_url: str = "https://app.dev.hatchet-tools.com" + namespace: str = "" + + tls_config: ClientTLSConfig = Field(default_factory=lambda: ClientTLSConfig()) + otel: OTELConfig = Field(default_factory=lambda: OTELConfig()) + healthcheck: HealthcheckConfig = Field(default_factory=lambda: HealthcheckConfig()) + + listener_v2_timeout: int | None = None + grpc_max_recv_message_length: int = Field( + default=4 * 1024 * 1024, description="4MB default" + ) + grpc_max_send_message_length: int = Field( + default=4 * 1024 * 1024, description="4MB default" + ) + + worker_preset_labels: dict[str, str] = Field(default_factory=dict) + + @model_validator(mode="after") + def validate_token_and_tenant(self) -> "ClientConfig": + if not self.token: + raise ValueError("Token must be set") + + if not self.tenant_id: + self.tenant_id = get_tenant_id_from_jwt(self.token) + + return self + + @model_validator(mode="after") + def validate_addresses(self) -> "ClientConfig": + if self.host_port == DEFAULT_HOST_PORT: + server_url, grpc_broadcast_address = get_addresses_from_jwt(self.token) + self.host_port = grpc_broadcast_address + self.server_url = server_url else: - otel_exporter_oltp_headers = None - - otel_exporter_oltp_protocol = get_config_value( - "otel_exporter_oltp_protocol", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL" - ) - - worker_healthcheck_port = int( - get_config_value( - "worker_healthcheck_port", "HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT" - ) - or 8001 - ) - - worker_healthcheck_enabled = ( - str( - get_config_value( - "worker_healthcheck_port", - "HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", - ) - ) - == "True" - ) - - # Add preset labels to the worker config - worker_preset_labels: dict[str, str] = defaults.worker_preset_labels - - autoscaling_target = get_config_value( - "autoscaling_target", "HATCHET_CLIENT_AUTOSCALING_TARGET" - ) - - if autoscaling_target: - worker_preset_labels["hatchet-autoscaling-target"] = autoscaling_target - - return ClientConfig( - tenant_id=tenant_id, - tls_config=tls_config, - token=token, - host_port=host_port, - server_url=server_url, - namespace=namespace, - listener_v2_timeout=listener_v2_timeout, - logger=defaults.logInterceptor, - grpc_max_recv_message_length=grpc_max_recv_message_length, - grpc_max_send_message_length=grpc_max_send_message_length, - otel_exporter_oltp_endpoint=otel_exporter_oltp_endpoint, - otel_service_name=otel_service_name, - otel_exporter_oltp_headers=otel_exporter_oltp_headers, - otel_exporter_oltp_protocol=otel_exporter_oltp_protocol, - worker_healthcheck_port=worker_healthcheck_port, - worker_healthcheck_enabled=worker_healthcheck_enabled, - worker_preset_labels=worker_preset_labels, - ) - - def _load_tls_config(self, tls_data: Dict, host_port) -> ClientTLSConfig: - tls_strategy = ( - tls_data["tlsStrategy"] - if "tlsStrategy" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_STRATEGY") - ) - - if not tls_strategy: - tls_strategy = "tls" - - cert_file = ( - tls_data["tlsCertFile"] - if "tlsCertFile" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_CERT_FILE") - ) - key_file = ( - tls_data["tlsKeyFile"] - if "tlsKeyFile" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_KEY_FILE") - ) - ca_file = ( - tls_data["tlsRootCAFile"] - if "tlsRootCAFile" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_ROOT_CA_FILE") - ) - - server_name = ( - tls_data["tlsServerName"] - if "tlsServerName" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_SERVER_NAME") - ) - - # if server_name is not set, use the host from the host_port - if not server_name: - server_name = host_port.split(":")[0] - - return ClientTLSConfig(tls_strategy, cert_file, key_file, ca_file, server_name) - - @staticmethod - def _get_env_var(env_var: str, default: Optional[str] = None) -> str: - return os.environ.get(env_var, default) + self.server_url = self.host_port + + if not self.tls_config.server_name: + self.tls_config.server_name = self.host_port.split(":")[0] + + if not self.tls_config.server_name: + self.tls_config.server_name = "localhost" + + return self + + @field_validator("listener_v2_timeout") + @classmethod + def validate_listener_timeout(cls, value: int | None | str) -> int | None: + if value is None: + return None + + if isinstance(value, int): + return value + + return int(value) + + @field_validator("namespace") + @classmethod + def validate_namespace(cls, namespace: str) -> str: + if not namespace: + return "" + if not namespace.endswith("_"): + namespace = f"{namespace}_" + return namespace.lower() + + def __hash__(self) -> int: + return hash(json.dumps(self.model_dump(), default=str)) diff --git a/hatchet_sdk/metadata.py b/hatchet_sdk/metadata.py index 38a31b8b..d4004c64 100644 --- a/hatchet_sdk/metadata.py +++ b/hatchet_sdk/metadata.py @@ -1,2 +1,2 @@ -def get_metadata(token: str): +def get_metadata(token: str) -> list[tuple[str, str]]: return [("authorization", "bearer " + token)] diff --git a/hatchet_sdk/rate_limit.py b/hatchet_sdk/rate_limit.py index 0d7b9143..f9f574a4 100644 --- a/hatchet_sdk/rate_limit.py +++ b/hatchet_sdk/rate_limit.py @@ -1,7 +1,8 @@ from dataclasses import dataclass +from enum import Enum from typing import Union -from celpy import CELEvalError, Environment +from celpy import CELEvalError, Environment # type: ignore from hatchet_sdk.contracts.workflows_pb2 import CreateStepRateLimit @@ -15,7 +16,7 @@ def validate_cel_expression(expr: str) -> bool: return False -class RateLimitDuration: +class RateLimitDuration(str, Enum): SECOND = "SECOND" MINUTE = "MINUTE" HOUR = "HOUR" @@ -71,9 +72,9 @@ class RateLimit: limit: Union[int, str, None] = None duration: RateLimitDuration = RateLimitDuration.MINUTE - _req: CreateStepRateLimit = None + _req: CreateStepRateLimit | None = None - def __post_init__(self): + def __post_init__(self) -> None: # juggle the key and key_expr fields key = self.static_key key_expression = self.dynamic_key diff --git a/hatchet_sdk/token.py b/hatchet_sdk/token.py index 313a6671..58d34c65 100644 --- a/hatchet_sdk/token.py +++ b/hatchet_sdk/token.py @@ -1,20 +1,25 @@ import base64 -import json +from pydantic import BaseModel -def get_tenant_id_from_jwt(token: str) -> str: - claims = extract_claims_from_jwt(token) - return claims.get("sub") +class Claims(BaseModel): + sub: str + server_url: str + grpc_broadcast_address: str + + +def get_tenant_id_from_jwt(token: str) -> str: + return extract_claims_from_jwt(token).sub -def get_addresses_from_jwt(token: str) -> (str, str): +def get_addresses_from_jwt(token: str) -> tuple[str, str]: claims = extract_claims_from_jwt(token) - return claims.get("server_url"), claims.get("grpc_broadcast_address") + return claims.server_url, claims.grpc_broadcast_address -def extract_claims_from_jwt(token: str): +def extract_claims_from_jwt(token: str) -> Claims: parts = token.split(".") if len(parts) != 3: raise ValueError("Invalid token format") @@ -22,6 +27,5 @@ def extract_claims_from_jwt(token: str): claims_part = parts[1] claims_part += "=" * ((4 - len(claims_part) % 4) % 4) # Padding for base64 decoding claims_data = base64.urlsafe_b64decode(claims_part) - claims = json.loads(claims_data) - return claims + return Claims.model_validate_json(claims_data) diff --git a/hatchet_sdk/utils/tracing.py b/hatchet_sdk/utils/tracing.py index afc398f7..ce93bab6 100644 --- a/hatchet_sdk/utils/tracing.py +++ b/hatchet_sdk/utils/tracing.py @@ -1,4 +1,3 @@ -import json from functools import cache from typing import Any @@ -16,18 +15,28 @@ OTEL_CARRIER_KEY = "__otel_carrier" +def parse_headers(headers: str | None) -> dict[str, str]: + if headers is None: + return {} + + try: + return dict([headers.split("=", maxsplit=1)]) + except ValueError: + raise ValueError("OTLP headers must be in the format `key=value`") + + @cache def create_tracer(config: ClientConfig) -> Tracer: ## TODO: Figure out how to specify protocol here resource = Resource( - attributes={SERVICE_NAME: config.otel_service_name or "hatchet.run"} + attributes={SERVICE_NAME: config.otel.service_name or "hatchet.run"} ) - if config.otel_exporter_oltp_endpoint and config.otel_exporter_oltp_headers: + if config.otel.exporter_otlp_endpoint and config.otel.exporter_otlp_headers: processor = BatchSpanProcessor( OTLPSpanExporter( - endpoint=config.otel_exporter_oltp_endpoint, - headers=config.otel_exporter_oltp_headers, + endpoint=config.otel.exporter_otlp_endpoint, + headers=parse_headers(config.otel.exporter_otlp_headers), ), ) diff --git a/hatchet_sdk/utils/types.py b/hatchet_sdk/utils/types.py index 30e469f7..16ab43f6 100644 --- a/hatchet_sdk/utils/types.py +++ b/hatchet_sdk/utils/types.py @@ -1,4 +1,4 @@ -from typing import Type +from typing import Any, Type from pydantic import BaseModel @@ -6,3 +6,6 @@ class WorkflowValidator(BaseModel): workflow_input: Type[BaseModel] | None = None step_output: Type[BaseModel] | None = None + + +JSONSerializableDict = dict[str, Any] diff --git a/hatchet_sdk/utils/typing.py b/hatchet_sdk/utils/typing.py index db111db5..a61bd7c3 100644 --- a/hatchet_sdk/utils/typing.py +++ b/hatchet_sdk/utils/typing.py @@ -1,4 +1,4 @@ -from typing import Any, Type, TypeGuard, TypeVar +from typing import Any, TypeVar from pydantic import BaseModel diff --git a/hatchet_sdk/v2/callable.py b/hatchet_sdk/v2/callable.py deleted file mode 100644 index 097a7d87..00000000 --- a/hatchet_sdk/v2/callable.py +++ /dev/null @@ -1,202 +0,0 @@ -import asyncio -from typing import ( - Any, - Callable, - Dict, - Generic, - List, - Optional, - TypedDict, - TypeVar, - Union, -) - -from hatchet_sdk.clients.admin import ChildTriggerWorkflowOptions -from hatchet_sdk.context.context import Context -from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined] - CreateStepRateLimit, - CreateWorkflowJobOpts, - CreateWorkflowStepOpts, - CreateWorkflowVersionOpts, - DesiredWorkerLabels, - StickyStrategy, - WorkflowConcurrencyOpts, - WorkflowKind, -) -from hatchet_sdk.labels import DesiredWorkerLabel -from hatchet_sdk.logger import logger -from hatchet_sdk.rate_limit import RateLimit -from hatchet_sdk.v2.concurrency import ConcurrencyFunction -from hatchet_sdk.workflow_run import RunRef - -T = TypeVar("T") - - -class HatchetCallable(Generic[T]): - def __init__( - self, - func: Callable[[Context], T], - durable: bool = False, - name: str = "", - auto_register: bool = True, - on_events: list[str] | None = None, - on_crons: list[str] | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: List[RateLimit] | None = None, - concurrency: ConcurrencyFunction | None = None, - on_failure: Union["HatchetCallable[T]", None] = None, - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - default_priority: int | None = None, - ): - self.func = func - - on_events = on_events or [] - on_crons = on_crons or [] - - limits = None - if rate_limits: - limits = [rate_limit._req for rate_limit in rate_limits or []] - - self.function_desired_worker_labels = {} - - for key, d in desired_worker_labels.items(): - value = d["value"] if "value" in d else None - self.function_desired_worker_labels[key] = DesiredWorkerLabels( - strValue=str(value) if not isinstance(value, int) else None, - intValue=value if isinstance(value, int) else None, - required=d["required"] if "required" in d else None, - weight=d["weight"] if "weight" in d else None, - comparator=d["comparator"] if "comparator" in d else None, - ) - self.sticky = sticky - self.default_priority = default_priority - self.durable = durable - self.function_name = name.lower() or str(func.__name__).lower() - self.function_version = version - self.function_on_events = on_events - self.function_on_crons = on_crons - self.function_timeout = timeout - self.function_schedule_timeout = schedule_timeout - self.function_retries = retries - self.function_rate_limits = limits - self.function_concurrency = concurrency - self.function_on_failure = on_failure - self.function_namespace = "default" - self.function_auto_register = auto_register - - self.is_coroutine = False - - if asyncio.iscoroutinefunction(func): - self.is_coroutine = True - - def __call__(self, context: Context) -> T: - return self.func(context) - - def with_namespace(self, namespace: str) -> None: - if namespace is not None and namespace != "": - self.function_namespace = namespace - self.function_name = namespace + self.function_name - - def to_workflow_opts(self) -> CreateWorkflowVersionOpts: - kind: WorkflowKind = WorkflowKind.FUNCTION - - if self.durable: - kind = WorkflowKind.DURABLE - - on_failure_job: CreateWorkflowJobOpts | None = None - - if self.function_on_failure is not None: - on_failure_job = CreateWorkflowJobOpts( - name=self.function_name + "-on-failure", - steps=[ - self.function_on_failure.to_step(), - ], - ) - - concurrency: WorkflowConcurrencyOpts | None = None - - if self.function_concurrency is not None: - self.function_concurrency.set_namespace(self.function_namespace) - concurrency = WorkflowConcurrencyOpts( - action=self.function_concurrency.get_action_name(), - max_runs=self.function_concurrency.max_runs, - limit_strategy=self.function_concurrency.limit_strategy, - ) - - validated_priority = ( - max(1, min(3, self.default_priority)) if self.default_priority else None - ) - if validated_priority != self.default_priority: - logger.warning( - "Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range." - ) - - return CreateWorkflowVersionOpts( - name=self.function_name, - kind=kind, - version=self.function_version, - event_triggers=self.function_on_events, - cron_triggers=self.function_on_crons, - schedule_timeout=self.function_schedule_timeout, - sticky=self.sticky, - on_failure_job=on_failure_job, - concurrency=concurrency, - jobs=[ - CreateWorkflowJobOpts( - name=self.function_name, - steps=[ - self.to_step(), - ], - ) - ], - default_priority=validated_priority, - ) - - def to_step(self) -> CreateWorkflowStepOpts: - return CreateWorkflowStepOpts( - readable_id=self.function_name, - action=self.get_action_name(), - timeout=self.function_timeout, - inputs="{}", - parents=[], - retries=self.function_retries, - rate_limits=self.function_rate_limits, - worker_labels=self.function_desired_worker_labels, - ) - - def get_action_name(self) -> str: - return self.function_namespace + ":" + self.function_name - - -class DurableContext(Context): - def run( - self, - function: str | Callable[[Context], Any], - input: dict[Any, Any] = {}, - key: str | None = None, - options: ChildTriggerWorkflowOptions | None = None, - ) -> "RunRef[T]": - worker_id = self.worker.id() - - workflow_name = function - - if not isinstance(function, str): - workflow_name = function.function_name - - # if ( - # options is not None - # and "sticky" in options - # and options["sticky"] == True - # and not self.worker.has_workflow(workflow_name) - # ): - # raise Exception( - # f"cannot run with sticky: workflow {workflow_name} is not registered on the worker" - # ) - - trigger_options = self._prepare_workflow_options(key, options, worker_id) - - return self.admin_client.run(function, input, trigger_options) diff --git a/hatchet_sdk/v2/concurrency.py b/hatchet_sdk/v2/concurrency.py deleted file mode 100644 index 73d9e3b4..00000000 --- a/hatchet_sdk/v2/concurrency.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Any, Callable - -from hatchet_sdk.context.context import Context -from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined] - ConcurrencyLimitStrategy, -) - - -class ConcurrencyFunction: - def __init__( - self, - func: Callable[[Context], str], - name: str = "concurrency", - max_runs: int = 1, - limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, - ): - self.func = func - self.name = name - self.max_runs = max_runs - self.limit_strategy = limit_strategy - self.namespace = "default" - - def set_namespace(self, namespace: str) -> None: - self.namespace = namespace - - def get_action_name(self) -> str: - return self.namespace + ":" + self.name - - def __call__(self, *args: Any, **kwargs: Any) -> str: - return self.func(*args, **kwargs) - - def __str__(self) -> str: - return f"{self.name}({self.max_runs})" - - def __repr__(self) -> str: - return f"{self.name}({self.max_runs})" - - -def concurrency( - name: str = "", - max_runs: int = 1, - limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, -) -> Callable[[Callable[[Context], str]], ConcurrencyFunction]: - def inner(func: Callable[[Context], str]) -> ConcurrencyFunction: - return ConcurrencyFunction(func, name, max_runs, limit_strategy) - - return inner diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py deleted file mode 100644 index 4dd3faf0..00000000 --- a/hatchet_sdk/v2/hatchet.py +++ /dev/null @@ -1,224 +0,0 @@ -from typing import Any, Callable, TypeVar, Union - -from hatchet_sdk import Worker -from hatchet_sdk.context.context import Context -from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined] - ConcurrencyLimitStrategy, - StickyStrategy, -) -from hatchet_sdk.hatchet import Hatchet as HatchetV1 -from hatchet_sdk.hatchet import workflow -from hatchet_sdk.labels import DesiredWorkerLabel -from hatchet_sdk.rate_limit import RateLimit -from hatchet_sdk.v2.callable import DurableContext, HatchetCallable -from hatchet_sdk.v2.concurrency import ConcurrencyFunction -from hatchet_sdk.worker.worker import register_on_worker - -T = TypeVar("T") - - -def function( - name: str = "", - auto_register: bool = True, - on_events: list[str] | None = None, - on_crons: list[str] | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: list[RateLimit] | None = None, - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: Union["HatchetCallable[T]", None] = None, - default_priority: int | None = None, -) -> Callable[[Callable[[Context], str]], HatchetCallable[T]]: - def inner(func: Callable[[Context], T]) -> HatchetCallable[T]: - return HatchetCallable( - func=func, - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) - - return inner - - -def durable( - name: str = "", - auto_register: bool = True, - on_events: list[str] | None = None, - on_crons: list[str] | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: list[RateLimit] | None = None, - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: HatchetCallable[T] | None = None, - default_priority: int | None = None, -) -> Callable[[HatchetCallable[T]], HatchetCallable[T]]: - def inner(func: HatchetCallable[T]) -> HatchetCallable[T]: - func.durable = True - - f = function( - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) - - resp = f(func) - - resp.durable = True - - return resp - - return inner - - -def concurrency( - name: str = "concurrency", - max_runs: int = 1, - limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, -) -> Callable[[Callable[[Context], str]], ConcurrencyFunction]: - def inner(func: Callable[[Context], str]) -> ConcurrencyFunction: - return ConcurrencyFunction(func, name, max_runs, limit_strategy) - - return inner - - -class Hatchet(HatchetV1): - dag = staticmethod(workflow) - concurrency = staticmethod(concurrency) - - functions: list[HatchetCallable[T]] = [] - - def function( - self, - name: str = "", - auto_register: bool = True, - on_events: list[str] | None = None, - on_crons: list[str] | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - retries: int = 0, - rate_limits: list[RateLimit] | None = None, - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: Union["HatchetCallable[T]", None] = None, - default_priority: int | None = None, - ) -> Callable[[Callable[[Context], Any]], Callable[[Context], Any]]: - resp = function( - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) - - def wrapper(func: Callable[[Context], str]) -> HatchetCallable[T]: - wrapped_resp = resp(func) - - if wrapped_resp.function_auto_register: - self.functions.append(wrapped_resp) - - wrapped_resp.with_namespace(self._client.config.namespace) - - return wrapped_resp - - return wrapper - - def durable( - self, - name: str = "", - auto_register: bool = True, - on_events: list[str] | None = None, - on_crons: list[str] | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: list[RateLimit] | None = None, - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: Union["HatchetCallable[T]", None] = None, - default_priority: int | None = None, - ) -> Callable[[Callable[[DurableContext], Any]], Callable[[DurableContext], Any]]: - resp = durable( - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) - - def wrapper(func: HatchetCallable[T]) -> HatchetCallable[T]: - wrapped_resp = resp(func) - - if wrapped_resp.function_auto_register: - self.functions.append(wrapped_resp) - - wrapped_resp.with_namespace(self._client.config.namespace) - - return wrapped_resp - - return wrapper - - def worker( - self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {} - ): - worker = Worker( - name=name, - max_runs=max_runs, - labels=labels, - config=self._client.config, - debug=self._client.debug, - ) - - for func in self.functions: - register_on_worker(func, worker) - - return worker diff --git a/hatchet_sdk/worker/__init__.py b/hatchet_sdk/worker/__init__.py index 450f0cac..e69de29b 100644 --- a/hatchet_sdk/worker/__init__.py +++ b/hatchet_sdk/worker/__init__.py @@ -1 +0,0 @@ -from .worker import Worker, WorkerStartOptions, WorkerStatus diff --git a/hatchet_sdk/worker/action_listener_process.py b/hatchet_sdk/worker/action_listener_process.py index 08508607..01e934f6 100644 --- a/hatchet_sdk/worker/action_listener_process.py +++ b/hatchet_sdk/worker/action_listener_process.py @@ -4,16 +4,16 @@ import time from dataclasses import dataclass, field from multiprocessing import Queue -from typing import Any, List, Mapping, Optional +from typing import Any, List, Literal import grpc -from hatchet_sdk.clients.dispatcher.action_listener import Action -from hatchet_sdk.clients.dispatcher.dispatcher import ( +from hatchet_sdk.clients.dispatcher.action_listener import ( + Action, ActionListener, GetActionListenerRequest, - new_dispatcher, ) +from hatchet_sdk.clients.dispatcher.dispatcher import new_dispatcher from hatchet_sdk.contracts.dispatcher_pb2 import ( GROUP_KEY_EVENT_TYPE_STARTED, STEP_EVENT_TYPE_STARTED, @@ -30,10 +30,11 @@ class ActionEvent: action: Action type: Any # TODO type - payload: Optional[str] = None + payload: str -STOP_LOOP = "STOP_LOOP" # Sentinel object to stop the loop +STOP_LOOP_TYPE = Literal["STOP_LOOP"] +STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP" # Sentinel object to stop the loop # TODO link to a block post BLOCKED_THREAD_WARNING = ( @@ -41,7 +42,7 @@ class ActionEvent: ) -def noop_handler(): +def noop_handler() -> None: pass @@ -51,22 +52,22 @@ class WorkerActionListenerProcess: actions: List[str] max_runs: int config: ClientConfig - action_queue: Queue - event_queue: Queue + action_queue: "Queue[Action]" + event_queue: "Queue[ActionEvent | STOP_LOOP_TYPE]" handle_kill: bool = True debug: bool = False - labels: dict = field(default_factory=dict) + labels: dict[str, str | int] = field(default_factory=dict) - listener: ActionListener = field(init=False, default=None) + listener: ActionListener = field(init=False) killing: bool = field(init=False, default=False) - action_loop_task: asyncio.Task = field(init=False, default=None) - event_send_loop_task: asyncio.Task = field(init=False, default=None) + action_loop_task: asyncio.Task[None] | None = field(init=False, default=None) + event_send_loop_task: asyncio.Task[None] | None = field(init=False, default=None) - running_step_runs: Mapping[str, float] = field(init=False, default_factory=dict) + running_step_runs: dict[str, float] = field(init=False, default_factory=dict) - def __post_init__(self): + def __post_init__(self) -> None: if self.debug: logger.setLevel(logging.DEBUG) @@ -77,7 +78,7 @@ def __post_init__(self): signal.SIGQUIT, lambda: asyncio.create_task(self.exit_gracefully()) ) - async def start(self, retry_attempt=0): + async def start(self, retry_attempt: int = 0) -> None: if retry_attempt > 5: logger.error("could not start action listener") return @@ -108,13 +109,13 @@ async def start(self, retry_attempt=0): self.blocked_main_loop = asyncio.create_task(self.start_blocked_main_loop()) # TODO move event methods to separate class - async def _get_event(self): + async def _get_event(self) -> ActionEvent | STOP_LOOP_TYPE: loop = asyncio.get_running_loop() return await loop.run_in_executor(None, self.event_queue.get) - async def start_event_send_loop(self): + async def start_event_send_loop(self) -> None: while True: - event: ActionEvent = await self._get_event() + event = await self._get_event() if event == STOP_LOOP: logger.debug("stopping event send loop...") break @@ -122,11 +123,11 @@ async def start_event_send_loop(self): logger.debug(f"tx: event: {event.action.action_id}/{event.type}") asyncio.create_task(self.send_event(event)) - async def start_blocked_main_loop(self): + async def start_blocked_main_loop(self) -> None: threshold = 1 while not self.killing: count = 0 - for step_run_id, start_time in self.running_step_runs.items(): + for _, start_time in self.running_step_runs.items(): diff = self.now() - start_time if diff > threshold: count += 1 @@ -135,7 +136,7 @@ async def start_blocked_main_loop(self): logger.warning(f"{BLOCKED_THREAD_WARNING}: Waiting Steps {count}") await asyncio.sleep(1) - async def send_event(self, event: ActionEvent, retry_attempt: int = 1): + async def send_event(self, event: ActionEvent, retry_attempt: int = 1) -> None: try: match event.action.action_type: # FIXME: all events sent from an execution of a function are of type ActionType.START_STEP_RUN since @@ -185,10 +186,10 @@ async def send_event(self, event: ActionEvent, retry_attempt: int = 1): await exp_backoff_sleep(retry_attempt, 1) await self.send_event(event, retry_attempt + 1) - def now(self): + def now(self) -> float: return time.time() - async def start_action_loop(self): + async def start_action_loop(self) -> None: try: async for action in self.listener: if action is None: @@ -201,6 +202,7 @@ async def start_action_loop(self): ActionEvent( action=action, type=STEP_EVENT_TYPE_STARTED, # TODO ack type + payload="", ) ) logger.info( @@ -220,6 +222,7 @@ async def start_action_loop(self): ActionEvent( action=action, type=GROUP_KEY_EVENT_TYPE_STARTED, # TODO ack type + payload="", ) ) logger.info( @@ -239,9 +242,9 @@ async def start_action_loop(self): finally: logger.info("action loop closed") if not self.killing: - await self.exit_gracefully(skip_unregister=True) + await self.exit_gracefully() - async def cleanup(self): + async def cleanup(self) -> None: self.killing = True if self.listener is not None: @@ -249,7 +252,7 @@ async def cleanup(self): self.event_queue.put(STOP_LOOP) - async def exit_gracefully(self, skip_unregister=False): + async def exit_gracefully(self) -> None: if self.killing: return @@ -262,13 +265,13 @@ async def exit_gracefully(self, skip_unregister=False): logger.info("action listener closed") - def exit_forcefully(self): + def exit_forcefully(self) -> None: asyncio.run(self.cleanup()) logger.debug("forcefully closing listener...") -def worker_action_listener_process(*args, **kwargs): - async def run(): +def worker_action_listener_process(*args: Any, **kwargs: Any) -> None: + async def run() -> None: process = WorkerActionListenerProcess(*args, **kwargs) await process.start() # Keep the process running diff --git a/hatchet_sdk/worker/runner/run_loop_manager.py b/hatchet_sdk/worker/runner/run_loop_manager.py index 27ed788c..7f14ab2f 100644 --- a/hatchet_sdk/worker/runner/run_loop_manager.py +++ b/hatchet_sdk/worker/runner/run_loop_manager.py @@ -2,18 +2,22 @@ import logging from dataclasses import dataclass, field from multiprocessing import Queue -from typing import Callable, TypeVar +from typing import TYPE_CHECKING, Any, Literal, TypeVar -from hatchet_sdk import Context from hatchet_sdk.client import Client, new_client_raw from hatchet_sdk.clients.dispatcher.action_listener import Action from hatchet_sdk.loader import ClientConfig from hatchet_sdk.logger import logger from hatchet_sdk.utils.types import WorkflowValidator +from hatchet_sdk.worker.action_listener_process import ActionEvent from hatchet_sdk.worker.runner.runner import Runner from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs -STOP_LOOP = "STOP_LOOP" +if TYPE_CHECKING: + from hatchet_sdk.workflow import Step + +STOP_LOOP_TYPE = Literal["STOP_LOOP"] +STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP" T = TypeVar("T") @@ -21,32 +25,32 @@ @dataclass class WorkerActionRunLoopManager: name: str - action_registry: dict[str, Callable[[Context], T]] + action_registry: dict[str, "Step[Any]"] validator_registry: dict[str, WorkflowValidator] max_runs: int | None config: ClientConfig - action_queue: Queue - event_queue: Queue + action_queue: "Queue[Action | STOP_LOOP_TYPE]" + event_queue: "Queue[ActionEvent]" loop: asyncio.AbstractEventLoop handle_kill: bool = True debug: bool = False labels: dict[str, str | int] = field(default_factory=dict) - client: Client = field(init=False, default=None) + client: Client = field(init=False) killing: bool = field(init=False, default=False) - runner: Runner = field(init=False, default=None) + runner: Runner | None = field(init=False, default=None) - def __post_init__(self): + def __post_init__(self) -> None: if self.debug: logger.setLevel(logging.DEBUG) self.client = new_client_raw(self.config, self.debug) self.start() - def start(self, retry_count=1): - k = self.loop.create_task(self.async_start(retry_count)) + def start(self, retry_count: int = 1) -> None: + k = self.loop.create_task(self.async_start(retry_count)) # noqa: F841 - async def async_start(self, retry_count=1): + async def async_start(self, retry_count: int = 1) -> None: await capture_logs( self.client.logInterceptor, self.client.event, @@ -63,6 +67,7 @@ async def _async_start(self, retry_count: int = 1) -> None: def cleanup(self) -> None: self.killing = True + ## TODO: The action queue is a queue of `Action`, so I don't think this will work self.action_queue.put(STOP_LOOP) async def wait_for_tasks(self) -> None: @@ -83,7 +88,8 @@ async def _start_action_loop(self) -> None: logger.debug(f"'{self.name}' waiting for {list(self.action_registry.keys())}") while not self.killing: - action: Action = await self._get_action() + action = await self._get_action() + ## TODO: This is a queue of `Action`, so I don't think this will work if action == STOP_LOOP: logger.debug("stopping action runner loop...") break @@ -91,7 +97,7 @@ async def _start_action_loop(self) -> None: self.runner.run(action) logger.debug("action runner loop stopped") - async def _get_action(self): + async def _get_action(self) -> Action | STOP_LOOP_TYPE: return await self.loop.run_in_executor(None, self.action_queue.get) async def exit_gracefully(self) -> None: diff --git a/hatchet_sdk/worker/runner/runner.py b/hatchet_sdk/worker/runner/runner.py index f72fb04b..3e949842 100644 --- a/hatchet_sdk/worker/runner/runner.py +++ b/hatchet_sdk/worker/runner/runner.py @@ -8,7 +8,7 @@ from enum import Enum from multiprocessing import Queue from threading import Thread, current_thread -from typing import Any, Callable, Dict, Literal, Type, TypeVar, cast, overload +from typing import TYPE_CHECKING, Any, Callable, Dict, TypeVar, cast from opentelemetry.trace import StatusCode from pydantic import BaseModel @@ -19,7 +19,7 @@ from hatchet_sdk.clients.dispatcher.dispatcher import new_dispatcher from hatchet_sdk.clients.run_event_listener import new_listener from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener -from hatchet_sdk.context import Context # type: ignore[attr-defined] +from hatchet_sdk.context.context import Context from hatchet_sdk.context.worker_context import WorkerContext from hatchet_sdk.contracts.dispatcher_pb2 import ( GROUP_KEY_EVENT_TYPE_COMPLETED, @@ -34,10 +34,14 @@ from hatchet_sdk.logger import logger from hatchet_sdk.utils.tracing import create_tracer, parse_carrier_from_metadata from hatchet_sdk.utils.types import WorkflowValidator -from hatchet_sdk.v2.callable import DurableContext from hatchet_sdk.worker.action_listener_process import ActionEvent from hatchet_sdk.worker.runner.utils.capture_logs import copy_context_vars, sr, wr +T = TypeVar("T") + +if TYPE_CHECKING: + from hatchet_sdk.workflow import Step + class WorkerStatus(Enum): INITIALIZED = 1 @@ -53,7 +57,7 @@ def __init__( event_queue: "Queue[Any]", max_runs: int | None = None, handle_kill: bool = True, - action_registry: dict[str, Callable[..., Any]] = {}, + action_registry: dict[str, "Step[T]"] = {}, validator_registry: dict[str, WorkflowValidator] = {}, config: ClientConfig = ClientConfig(), labels: dict[str, str | int] = {}, @@ -65,7 +69,7 @@ def __init__( self.max_runs = max_runs self.tasks: dict[str, asyncio.Task[Any]] = {} # Store run ids and futures self.contexts: dict[str, Context] = {} # Store run ids and contexts - self.action_registry: dict[str, Callable[..., Any]] = action_registry + self.action_registry: dict[str, "Step[T]"] = action_registry self.validator_registry = validator_registry self.event_queue = event_queue @@ -212,8 +216,8 @@ def inner_callback(task: asyncio.Task[Any]) -> None: ## TODO: Stricter type hinting here def thread_action_func( - self, context: Context, action_func: Callable[..., Any], action: Action - ) -> Any: + self, context: Context, step: "Step[T]", action: Action + ) -> T: if action.step_run_id is not None and action.step_run_id != "": self.threads[action.step_run_id] = current_thread() elif ( @@ -222,25 +226,23 @@ def thread_action_func( ): self.threads[action.get_group_key_run_id] = current_thread() - return action_func(context) + return step.call(context) ## TODO: Stricter type hinting here # We wrap all actions in an async func async def async_wrapped_action_func( self, context: Context, - action_func: Callable[..., Any], + step: "Step[T]", action: Action, run_id: str, - ) -> Any: - wr.set(context.workflow_run_id()) + ) -> T: + wr.set(context.workflow_run_id) sr.set(context.step_run_id) try: - if ( - hasattr(action_func, "is_coroutine") and action_func.is_coroutine - ) or asyncio.iscoroutinefunction(action_func): - return await action_func(context) + if step.is_async_function: + return await step.acall(context) else: pfunc = functools.partial( # we must copy the context vars to the new thread, as only asyncio natively supports @@ -249,7 +251,7 @@ async def async_wrapped_action_func( contextvars.copy_context().items(), self.thread_action_func, context, - action_func, + step, action, ) @@ -276,23 +278,7 @@ def cleanup_run_id(self, run_id: str | None) -> None: if run_id in self.contexts: del self.contexts[run_id] - def create_context( - self, action: Action, action_func: Callable[..., Any] | None - ) -> Context | DurableContext: - if hasattr(action_func, "durable") and getattr(action_func, "durable"): - return DurableContext( - action, - self.dispatcher_client, - self.admin_client, - self.client.event, - self.client.rest, - self.client.workflow_listener, - self.workflow_run_event_listener, - self.worker_context, - self.client.config.namespace, - validator_registry=self.validator_registry, - ) - + def create_context(self, action: Action) -> Context: return Context( action, self.dispatcher_client, @@ -318,16 +304,13 @@ async def handle_start_step_run(self, action: Action) -> None: # Find the corresponding action function from the registry action_func = self.action_registry.get(action_name) - context = self.create_context(action, action_func) + context = self.create_context(action) self.contexts[action.step_run_id] = context if action_func: self.event_queue.put( - ActionEvent( - action=action, - type=STEP_EVENT_TYPE_STARTED, - ) + ActionEvent(action=action, type=STEP_EVENT_TYPE_STARTED, payload="") ) loop = asyncio.get_event_loop() @@ -377,8 +360,7 @@ async def handle_start_group_key_run(self, action: Action) -> None: # send an event that the group key run has started self.event_queue.put( ActionEvent( - action=action, - type=GROUP_KEY_EVENT_TYPE_STARTED, + action=action, type=GROUP_KEY_EVENT_TYPE_STARTED, payload="" ) ) diff --git a/hatchet_sdk/worker/runner/utils/capture_logs.py b/hatchet_sdk/worker/runner/utils/capture_logs.py index 08c57de8..c8121ba8 100644 --- a/hatchet_sdk/worker/runner/utils/capture_logs.py +++ b/hatchet_sdk/worker/runner/utils/capture_logs.py @@ -2,11 +2,12 @@ import functools import logging from concurrent.futures import ThreadPoolExecutor +from contextvars import ContextVar from io import StringIO -from typing import Any, Coroutine +from typing import Any, Awaitable, Callable, ItemsView, ParamSpec, TypeVar -from hatchet_sdk import logger from hatchet_sdk.clients.events import EventClient +from hatchet_sdk.logger import logger wr: contextvars.ContextVar[str | None] = contextvars.ContextVar( "workflow_run_id", default=None @@ -16,7 +17,16 @@ ) -def copy_context_vars(ctx_vars, func, *args, **kwargs): +T = TypeVar("T") +P = ParamSpec("P") + + +def copy_context_vars( + ctx_vars: ItemsView[ContextVar[Any], Any], + func: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs, +) -> T: for var, value in ctx_vars: var.set(value) return func(*args, **kwargs) @@ -25,19 +35,20 @@ def copy_context_vars(ctx_vars, func, *args, **kwargs): class InjectingFilter(logging.Filter): # For some reason, only the InjectingFilter has access to the contextvars method sr.get(), # otherwise we would use emit within the CustomLogHandler - def filter(self, record): + def filter(self, record: logging.LogRecord) -> bool: + ## TODO: Change how we do this to not assign to the log record record.workflow_run_id = wr.get() record.step_run_id = sr.get() return True -class CustomLogHandler(logging.StreamHandler): - def __init__(self, event_client: EventClient, stream=None): +class CustomLogHandler(logging.StreamHandler[Any]): + def __init__(self, event_client: EventClient, stream: StringIO | None = None): super().__init__(stream) self.logger_thread_pool = ThreadPoolExecutor(max_workers=1) self.event_client = event_client - def _log(self, line: str, step_run_id: str | None): + def _log(self, line: str, step_run_id: str | None) -> None: try: if not step_run_id: return @@ -46,20 +57,20 @@ def _log(self, line: str, step_run_id: str | None): except Exception as e: logger.error(f"Error logging: {e}") - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: super().emit(record) log_entry = self.format(record) - self.logger_thread_pool.submit(self._log, log_entry, record.step_run_id) + + ## TODO: Change how we do this to not assign to the log record + self.logger_thread_pool.submit(self._log, log_entry, record.step_run_id) # type: ignore def capture_logs( - logger: logging.Logger, - event_client: EventClient, - func: Coroutine[Any, Any, Any], -): + logger: logging.Logger, event_client: "EventClient", func: Callable[P, Awaitable[T]] +) -> Callable[P, Awaitable[T]]: @functools.wraps(func) - async def wrapper(*args, **kwargs): + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: if not logger: raise Exception("No logger configured on client") diff --git a/hatchet_sdk/worker/runner/utils/error_with_traceback.py b/hatchet_sdk/worker/runner/utils/error_with_traceback.py index 9c09602f..6aff1cb6 100644 --- a/hatchet_sdk/worker/runner/utils/error_with_traceback.py +++ b/hatchet_sdk/worker/runner/utils/error_with_traceback.py @@ -1,6 +1,6 @@ import traceback -def errorWithTraceback(message: str, e: Exception): +def errorWithTraceback(message: str, e: Exception) -> str: trace = "".join(traceback.format_exception(type(e), e, e.__traceback__)) return f"{message}\n{trace}" diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py index b6ec1531..9a76e2eb 100644 --- a/hatchet_sdk/worker/worker.py +++ b/hatchet_sdk/worker/worker.py @@ -4,31 +4,37 @@ import os import signal import sys -from concurrent.futures import Future from dataclasses import dataclass, field from enum import Enum from multiprocessing import Queue from multiprocessing.process import BaseProcess from types import FrameType -from typing import Any, Callable, TypeVar, get_type_hints +from typing import TYPE_CHECKING, Any, TypeVar, Union, get_type_hints from aiohttp import web from aiohttp.web_request import Request from aiohttp.web_response import Response -from prometheus_client import CONTENT_TYPE_LATEST, Gauge, generate_latest +from prometheus_client import Gauge, generate_latest -from hatchet_sdk import Context from hatchet_sdk.client import Client, new_client_raw +from hatchet_sdk.clients.dispatcher.action_listener import Action from hatchet_sdk.contracts.workflows_pb2 import CreateWorkflowVersionOpts from hatchet_sdk.loader import ClientConfig from hatchet_sdk.logger import logger from hatchet_sdk.utils.types import WorkflowValidator from hatchet_sdk.utils.typing import is_basemodel_subclass -from hatchet_sdk.v2.callable import HatchetCallable -from hatchet_sdk.v2.concurrency import ConcurrencyFunction -from hatchet_sdk.worker.action_listener_process import worker_action_listener_process -from hatchet_sdk.worker.runner.run_loop_manager import WorkerActionRunLoopManager -from hatchet_sdk.workflow import WorkflowInterface +from hatchet_sdk.worker.action_listener_process import ( + ActionEvent, + worker_action_listener_process, +) +from hatchet_sdk.worker.runner.run_loop_manager import ( + STOP_LOOP_TYPE, + WorkerActionRunLoopManager, +) +from hatchet_sdk.workflow import Step + +if TYPE_CHECKING: + from hatchet_sdk.workflow import BaseWorkflow T = TypeVar("T") @@ -45,9 +51,6 @@ class WorkerStartOptions: loop: asyncio.AbstractEventLoop | None = field(default=None) -TWorkflow = TypeVar("TWorkflow", bound=object) - - class Worker: def __init__( self, @@ -69,20 +72,20 @@ def __init__( self.client: Client - self.action_registry: dict[str, Callable[[Context], Any]] = {} + self.action_registry: dict[str, Step[Any]] = {} self.validator_registry: dict[str, WorkflowValidator] = {} self.killing: bool = False self._status: WorkerStatus self.action_listener_process: BaseProcess - self.action_listener_health_check: asyncio.Task[Any] + self.action_listener_health_check: asyncio.Task[None] self.action_runner: WorkerActionRunLoopManager self.ctx = multiprocessing.get_context("spawn") - self.action_queue: "Queue[Any]" = self.ctx.Queue() - self.event_queue: "Queue[Any]" = self.ctx.Queue() + self.action_queue: "Queue[Action | STOP_LOOP_TYPE]" = self.ctx.Queue() + self.event_queue: "Queue[ActionEvent]" = self.ctx.Queue() self.loop: asyncio.AbstractEventLoop @@ -95,9 +98,6 @@ def __init__( "hatchet_worker_status", "Current status of the Hatchet worker" ) - def register_function(self, action: str, func: Callable[[Context], Any]) -> None: - self.action_registry[action] = func - def register_workflow_from_opts( self, name: str, opts: CreateWorkflowVersionOpts ) -> None: @@ -108,10 +108,7 @@ def register_workflow_from_opts( logger.error(e) sys.exit(1) - def register_workflow(self, workflow: TWorkflow) -> None: - ## Hack for typing - assert isinstance(workflow, WorkflowInterface) - + def register_workflow(self, workflow: Union["BaseWorkflow", Any]) -> None: namespace = self.client.config.namespace try: @@ -123,25 +120,13 @@ def register_workflow(self, workflow: TWorkflow) -> None: logger.error(e) sys.exit(1) - def create_action_function( - action_func: Callable[..., T] - ) -> Callable[[Context], T]: - def action_function(context: Context) -> T: - return action_func(workflow, context) - - if asyncio.iscoroutinefunction(action_func): - setattr(action_function, "is_coroutine", True) - else: - setattr(action_function, "is_coroutine", False) - - return action_function - - for action_name, action_func in workflow.get_actions(namespace): - self.action_registry[action_name] = create_action_function(action_func) - return_type = get_type_hints(action_func).get("return") + for step in workflow.steps: + action_name = workflow.create_action_name(namespace, step) + self.action_registry[action_name] = step + return_type = get_type_hints(step.fn).get("return") self.validator_registry[action_name] = WorkflowValidator( - workflow_input=workflow.input_validator, + workflow_input=workflow.config.input_validator, step_output=return_type if is_basemodel_subclass(return_type) else None, ) @@ -173,7 +158,7 @@ async def metrics_handler(self, request: Request) -> Response: return web.Response(body=generate_latest(), content_type="text/plain") async def start_health_server(self) -> None: - port = self.config.worker_healthcheck_port or 8001 + port = self.config.healthcheck.port app = web.Application() app.add_routes( @@ -195,12 +180,10 @@ async def start_health_server(self) -> None: logger.info(f"healthcheck server running on port {port}") - def start( - self, options: WorkerStartOptions = WorkerStartOptions() - ) -> Future[asyncio.Task[Any] | None]: + def start(self, options: WorkerStartOptions = WorkerStartOptions()) -> None: self.owned_loop = self.setup_loop(options.loop) - f = asyncio.run_coroutine_threadsafe( + asyncio.run_coroutine_threadsafe( self.async_start(options, _from_start=True), self.loop ) @@ -211,14 +194,12 @@ def start( if self.handle_kill: sys.exit(0) - return f - ## Start methods async def async_start( self, options: WorkerStartOptions = WorkerStartOptions(), _from_start: bool = False, - ) -> Any | None: + ) -> None: main_pid = os.getpid() logger.info("------------------------------------------") logger.info("STARTING HATCHET...") @@ -236,7 +217,7 @@ async def async_start( if not _from_start: self.setup_loop(options.loop) - if self.config.worker_healthcheck_enabled: + if self.config.healthcheck.enabled: await self.start_health_server() self.action_listener_process = self._start_listener() @@ -247,7 +228,7 @@ async def async_start( self._check_listener_health() ) - return await self.action_listener_health_check + await self.action_listener_health_check def _run_action_runner(self) -> WorkerActionRunLoopManager: # Retrieve the shared queue @@ -362,7 +343,8 @@ def exit_forcefully(self) -> None: logger.debug(f"forcefully stopping worker: {self.name}") - self.close() + ## TODO: `self.close` needs to be awaited / used + self.close() # type: ignore[unused-coroutine] if self.action_listener_process: self.action_listener_process.kill() # Forcefully kill the process @@ -371,22 +353,3 @@ def exit_forcefully(self) -> None: sys.exit( 1 ) # Exit immediately TODO - should we exit with 1 here, there may be other workers to cleanup - - -def register_on_worker(callable: HatchetCallable[T], worker: Worker) -> None: - worker.register_function(callable.get_action_name(), callable) - - if callable.function_on_failure is not None: - worker.register_function( - callable.function_on_failure.get_action_name(), callable.function_on_failure - ) - - if callable.function_concurrency is not None: - worker.register_function( - callable.function_concurrency.get_action_name(), - callable.function_concurrency, - ) - - opts = callable.to_workflow_opts() - - worker.register_workflow_from_opts(opts.name, opts) diff --git a/hatchet_sdk/workflow.py b/hatchet_sdk/workflow.py index 9c5cef90..911ae8ea 100644 --- a/hatchet_sdk/workflow.py +++ b/hatchet_sdk/workflow.py @@ -1,64 +1,64 @@ -import functools +import asyncio +from dataclasses import dataclass, field +from enum import Enum from typing import ( + TYPE_CHECKING, Any, + Awaitable, Callable, - Protocol, + Generic, + ParamSpec, Type, + TypeGuard, TypeVar, Union, cast, - get_type_hints, - runtime_checkable, ) -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict -from hatchet_sdk import ConcurrencyLimitStrategy +from hatchet_sdk.clients.admin import ChildTriggerWorkflowOptions, ChildWorkflowRunDict +from hatchet_sdk.context.context import Context from hatchet_sdk.contracts.workflows_pb2 import ( + ConcurrencyLimitStrategy as ConcurrencyLimitStrategyProto, +) +from hatchet_sdk.contracts.workflows_pb2 import ( + CreateStepRateLimit, CreateWorkflowJobOpts, CreateWorkflowStepOpts, CreateWorkflowVersionOpts, - StickyStrategy, - WorkflowConcurrencyOpts, - WorkflowKind, + DesiredWorkerLabels, ) +from hatchet_sdk.contracts.workflows_pb2 import StickyStrategy as StickyStrategyProto +from hatchet_sdk.contracts.workflows_pb2 import WorkflowConcurrencyOpts, WorkflowKind from hatchet_sdk.logger import logger -from hatchet_sdk.utils.typing import is_basemodel_subclass - +from hatchet_sdk.workflow_run import WorkflowRunRef -class WorkflowStepProtocol(Protocol): - def __call__(self, *args: Any, **kwargs: Any) -> Any: ... +if TYPE_CHECKING: + from hatchet_sdk import Hatchet - __name__: str +R = TypeVar("R") +P = ParamSpec("P") - _step_name: str - _step_timeout: str | None - _step_parents: list[str] - _step_retries: int | None - _step_rate_limits: list[str] | None - _step_desired_worker_labels: dict[str, str] - _step_backoff_factor: float | None - _step_backoff_max_seconds: int | None - _concurrency_fn_name: str - _concurrency_max_runs: int | None - _concurrency_limit_strategy: str | None +class EmptyModel(BaseModel): + model_config = ConfigDict(extra="allow") - _on_failure_step_name: str - _on_failure_step_timeout: str | None - _on_failure_step_retries: int - _on_failure_step_rate_limits: list[str] | None - _on_failure_step_backoff_factor: float | None - _on_failure_step_backoff_max_seconds: int | None +class StickyStrategy(str, Enum): + SOFT = "SOFT" + HARD = "HARD" -StepsType = list[tuple[str, WorkflowStepProtocol]] -T = TypeVar("T") -TW = TypeVar("TW", bound="WorkflowInterface") +class ConcurrencyLimitStrategy(str, Enum): + CANCEL_IN_PROGRESS = "CANCEL_IN_PROGRESS" + DROP_NEWEST = "DROP_NEWEST" + QUEUE_NEWEST = "QUEUE_NEWEST" + GROUP_ROUND_ROBIN = "GROUP_ROUND_ROBIN" + CANCEL_NEWEST = "CANCEL_NEWEST" -class ConcurrencyExpression: +class ConcurrencyExpression(BaseModel): """ Defines concurrency limits for a workflow using a CEL expression. @@ -71,191 +71,366 @@ class ConcurrencyExpression: ConcurrencyExpression("input.user_id", 5, ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS) """ + expression: str + max_runs: int + limit_strategy: ConcurrencyLimitStrategy + + +TWorkflowInput = TypeVar("TWorkflowInput", bound=BaseModel, default=EmptyModel) + + +class WorkflowConfig(BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + name: str = "" + on_events: list[str] = [] + on_crons: list[str] = [] + version: str = "" + timeout: str = "60m" + schedule_timeout: str = "5m" + sticky: StickyStrategy | None = None + default_priority: int = 1 + concurrency: ConcurrencyExpression | None = None + input_validator: Type[BaseModel] = EmptyModel + + +class StepType(str, Enum): + DEFAULT = "default" + CONCURRENCY = "concurrency" + ON_FAILURE = "on_failure" + + +AsyncFunc = Callable[[Any, Context], Awaitable[R]] +SyncFunc = Callable[[Any, Context], R] +StepFunc = Union[AsyncFunc[R], SyncFunc[R]] + + +def is_async_fn(fn: StepFunc[R]) -> TypeGuard[AsyncFunc[R]]: + return asyncio.iscoroutinefunction(fn) + + +def is_sync_fn(fn: StepFunc[R]) -> TypeGuard[SyncFunc[R]]: + return not asyncio.iscoroutinefunction(fn) + + +class Step(Generic[R]): def __init__( - self, expression: str, max_runs: int, limit_strategy: ConcurrencyLimitStrategy - ): - self.expression = expression - self.max_runs = max_runs - self.limit_strategy = limit_strategy - - -@runtime_checkable -class WorkflowInterface(Protocol): - def get_name(self, namespace: str) -> str: ... - - def get_actions(self, namespace: str) -> list[tuple[str, Callable[..., Any]]]: ... - - def get_create_opts(self, namespace: str) -> Any: ... - - on_events: list[str] | None - on_crons: list[str] | None - name: str - version: str - timeout: str - schedule_timeout: str - sticky: Union[StickyStrategy.Value, None] # type: ignore[name-defined] - default_priority: int | None - concurrency_expression: ConcurrencyExpression | None - input_validator: Type[BaseModel] | None - - -class WorkflowMeta(type): - def __new__( - cls: Type["WorkflowMeta"], - name: str, - bases: tuple[type, ...], - attrs: dict[str, Any], - ) -> "WorkflowMeta": - def _create_steps_actions_list(name: str) -> StepsType: - return [ - (getattr(func, name), attrs.pop(func_name)) - for func_name, func in list(attrs.items()) - if hasattr(func, name) - ] - - concurrencyActions = _create_steps_actions_list("_concurrency_fn_name") - steps = _create_steps_actions_list("_step_name") - - onFailureSteps = _create_steps_actions_list("_on_failure_step_name") - - # Define __init__ and get_step_order methods - original_init = attrs.get("__init__") # Get the original __init__ if it exists - - def __init__(self: TW, *args: Any, **kwargs: Any) -> None: - if original_init: - original_init(self, *args, **kwargs) # Call original __init__ - - def get_service_name(namespace: str) -> str: - return f"{namespace}{name.lower()}" - - @functools.cache - def get_actions(self: TW, namespace: str) -> StepsType: - serviceName = get_service_name(namespace) - - func_actions = [ - (serviceName + ":" + func_name, func) for func_name, func in steps - ] - concurrency_actions = [ - (serviceName + ":" + func_name, func) - for func_name, func in concurrencyActions - ] - onFailure_actions = [ - (serviceName + ":" + func_name, func) - for func_name, func in onFailureSteps - ] - - return func_actions + concurrency_actions + onFailure_actions - - # Add these methods and steps to class attributes - attrs["__init__"] = __init__ - attrs["get_actions"] = get_actions - - for step_name, step_func in steps: - attrs[step_name] = step_func - - def get_name(self: TW, namespace: str) -> str: - return namespace + cast(str, attrs["name"]) - - attrs["get_name"] = get_name - - cron_triggers = attrs["on_crons"] - version = attrs["version"] - schedule_timeout = attrs["schedule_timeout"] - sticky = attrs["sticky"] - default_priority = attrs["default_priority"] - - @functools.cache - def get_create_opts(self: TW, namespace: str) -> CreateWorkflowVersionOpts: - serviceName = get_service_name(namespace) - name = self.get_name(namespace) - event_triggers = [namespace + event for event in attrs["on_events"]] - createStepOpts: list[CreateWorkflowStepOpts] = [ + self, + fn: Callable[[Any, Context], R] | Callable[[Any, Context], Awaitable[R]], + type: StepType, + name: str = "", + timeout: str = "60m", + parents: list[str] = [], + retries: int = 0, + rate_limits: list[CreateStepRateLimit] = [], + desired_worker_labels: dict[str, DesiredWorkerLabels] = {}, + backoff_factor: float | None = None, + backoff_max_seconds: int | None = None, + concurrency__max_runs: int | None = None, + concurrency__limit_strategy: ConcurrencyLimitStrategy | None = None, + ) -> None: + self.fn = fn + self.is_async_function = is_async_fn(fn) + self.workflow: Union["BaseWorkflow", None] = None + + self.type = type + self.timeout = timeout + self.name = name + self.parents = parents + self.retries = retries + self.rate_limits = rate_limits + self.desired_worker_labels = desired_worker_labels + self.backoff_factor = backoff_factor + self.backoff_max_seconds = backoff_max_seconds + self.concurrency__max_runs = concurrency__max_runs + self.concurrency__limit_strategy = concurrency__limit_strategy + + def call(self, ctx: Context) -> R: + if not self.is_registered: + raise ValueError( + "Only steps that have been registered can be called. To register this step, instantiate its corresponding workflow." + ) + + if self.is_async_function: + raise TypeError(f"{self.name} is not a sync function. Use `acall` instead.") + + sync_fn = self.fn + if is_sync_fn(sync_fn): + return sync_fn(self.workflow, ctx) + + raise TypeError(f"{self.name} is not a sync function. Use `acall` instead.") + + async def acall(self, ctx: Context) -> R: + if not self.is_registered: + raise ValueError( + "Only steps that have been registered can be called. To register this step, instantiate its corresponding workflow." + ) + + if not self.is_async_function: + raise TypeError( + f"{self.name} is not an async function. Use `call` instead." + ) + + async_fn = self.fn + + if is_async_fn(async_fn): + return await async_fn(self.workflow, ctx) + + raise TypeError(f"{self.name} is not an async function. Use `call` instead.") + + @property + def is_registered(self) -> bool: + return self.workflow is not None + + +@dataclass +class SpawnWorkflowInput(Generic[TWorkflowInput]): + workflow_name: str + input: TWorkflowInput + key: str | None = None + options: ChildTriggerWorkflowOptions = field( + default_factory=ChildTriggerWorkflowOptions + ) + + +class WorkflowDeclaration(Generic[TWorkflowInput]): + + def __init__(self, config: WorkflowConfig, hatchet: Union["Hatchet", None]): + self.config = config + self.hatchet = hatchet + + def run(self, input: TWorkflowInput | None = None) -> Any: + if not self.hatchet: + raise ValueError("Hatchet client is not initialized.") + + return self.hatchet.admin.run_workflow( + workflow_name=self.config.name, input=input.model_dump() if input else {} + ) + + def get_workflow_input(self, ctx: Context) -> TWorkflowInput: + return cast(TWorkflowInput, ctx.workflow_input) + + def construct_spawn_workflow_input( + self, + input: TWorkflowInput, + key: str | None = None, + options: ChildTriggerWorkflowOptions = ChildTriggerWorkflowOptions(), + ) -> SpawnWorkflowInput[TWorkflowInput]: + return SpawnWorkflowInput[TWorkflowInput]( + workflow_name=self.config.name, input=input, key=key, options=options + ) + + async def spawn_many( + self, ctx: Context, spawn_inputs: list[SpawnWorkflowInput[TWorkflowInput]] + ) -> list[WorkflowRunRef]: + inputs = [ + ChildWorkflowRunDict( + workflow_name=spawn_input.workflow_name, + input=spawn_input.input.model_dump(), + key=spawn_input.key, + options=spawn_input.options, + ) + for spawn_input in spawn_inputs + ] + return await ctx.spawn_workflows(inputs) + + async def spawn_one( + self, + ctx: Context, + input: TWorkflowInput, + key: str | None = None, + options: ChildTriggerWorkflowOptions = ChildTriggerWorkflowOptions(), + ) -> WorkflowRunRef: + return await ctx.spawn_workflow( + workflow_name=self.config.name, + input=input.model_dump(), + key=key, + options=options, + ) + + +class BaseWorkflow: + """ + A Hatchet workflow implementation base. This class should be inherited by all workflow implementations. + + Configuration is passed to the workflow implementation via the `config` attribute. + """ + + config: WorkflowConfig = WorkflowConfig() + + def __init__(self) -> None: + self.config.name = self.config.name or str(self.__class__.__name__) + + for step in self.steps: + step.workflow = self + + def get_service_name(self, namespace: str) -> str: + return f"{namespace}{self.config.name.lower()}" + + def _get_steps_by_type(self, step_type: StepType) -> list[Step[Any]]: + return [ + attr + for _, attr in self.__class__.__dict__.items() + if isinstance(attr, Step) and attr.type == step_type + ] + + @property + def on_failure_steps(self) -> list[Step[Any]]: + return self._get_steps_by_type(StepType.ON_FAILURE) + + @property + def concurrency_actions(self) -> list[Step[Any]]: + return self._get_steps_by_type(StepType.CONCURRENCY) + + @property + def default_steps(self) -> list[Step[Any]]: + return self._get_steps_by_type(StepType.DEFAULT) + + @property + def steps(self) -> list[Step[Any]]: + return self.default_steps + self.concurrency_actions + self.on_failure_steps + + def create_action_name(self, namespace: str, step: Step[Any]) -> str: + return self.get_service_name(namespace) + ":" + step.name + + def get_name(self, namespace: str) -> str: + return namespace + self.config.name + + def validate_concurrency_actions( + self, service_name: str + ) -> WorkflowConcurrencyOpts | None: + if len(self.concurrency_actions) > 0 and self.config.concurrency: + raise ValueError( + "Error: Both concurrencyActions and concurrency_expression are defined. Please use only one concurrency configuration method." + ) + + if len(self.concurrency_actions) > 0: + action = self.concurrency_actions[0] + + return WorkflowConcurrencyOpts( + action=service_name + ":" + action.name, + max_runs=action.concurrency__max_runs, + limit_strategy=cast( + str | None, + self.validate_concurrency(action.concurrency__limit_strategy), + ), + ) + + if self.config.concurrency: + return WorkflowConcurrencyOpts( + expression=self.config.concurrency.expression, + max_runs=self.config.concurrency.max_runs, + limit_strategy=self.config.concurrency.limit_strategy, + ) + + return None + + def validate_on_failure_steps( + self, name: str, service_name: str + ) -> CreateWorkflowJobOpts | None: + if not self.on_failure_steps: + return None + + on_failure_step = next(iter(self.on_failure_steps)) + + return CreateWorkflowJobOpts( + name=name + "-on-failure", + steps=[ CreateWorkflowStepOpts( - readable_id=step_name, - action=serviceName + ":" + step_name, - timeout=func._step_timeout or "60s", + readable_id=on_failure_step.name, + action=service_name + ":" + on_failure_step.name, + timeout=on_failure_step.timeout or "60s", inputs="{}", - parents=[x for x in func._step_parents], - retries=func._step_retries, - rate_limits=func._step_rate_limits, # type: ignore[arg-type] - worker_labels=func._step_desired_worker_labels, # type: ignore[arg-type] - backoff_factor=func._step_backoff_factor, - backoff_max_seconds=func._step_backoff_max_seconds, + parents=[], + retries=on_failure_step.retries, + rate_limits=on_failure_step.rate_limits, + backoff_factor=on_failure_step.backoff_factor, + backoff_max_seconds=on_failure_step.backoff_max_seconds, ) - for step_name, func in steps - ] + ], + ) + + def validate_priority(self, default_priority: int | None) -> int | None: + validated_priority = ( + max(1, min(3, default_priority)) if default_priority else None + ) + if validated_priority != default_priority: + logger.warning( + "Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range." + ) - concurrency: WorkflowConcurrencyOpts | None = None + return validated_priority - if len(concurrencyActions) > 0: - action = concurrencyActions[0] + def validate_concurrency( + self, concurrency: ConcurrencyLimitStrategy | None + ) -> int | None: + if not concurrency: + return None - concurrency = WorkflowConcurrencyOpts( - action=serviceName + ":" + action[0], - max_runs=action[1]._concurrency_max_runs, - limit_strategy=action[1]._concurrency_limit_strategy, - ) + names = [item.name for item in ConcurrencyLimitStrategyProto.DESCRIPTOR.values] - if self.concurrency_expression: - concurrency = WorkflowConcurrencyOpts( - expression=self.concurrency_expression.expression, - max_runs=self.concurrency_expression.max_runs, - limit_strategy=self.concurrency_expression.limit_strategy, - ) + for name in names: + if name == concurrency.name: + return StickyStrategyProto.Value(concurrency.name) - if len(concurrencyActions) > 0 and self.concurrency_expression: - raise ValueError( - "Error: Both concurrencyActions and concurrency_expression are defined. Please use only one concurrency configuration method." - ) + raise ValueError( + f"Concurrency limit strategy must be one of {names}. Got: {concurrency}" + ) - on_failure_job: CreateWorkflowJobOpts | None = None - - if len(onFailureSteps) > 0: - func_name, func = onFailureSteps[0] - on_failure_job = CreateWorkflowJobOpts( - name=name + "-on-failure", - steps=[ - CreateWorkflowStepOpts( - readable_id=func_name, - action=serviceName + ":" + func_name, - timeout=func._on_failure_step_timeout or "60s", - inputs="{}", - parents=[], - retries=func._on_failure_step_retries, - rate_limits=func._on_failure_step_rate_limits, # type: ignore[arg-type] - backoff_factor=func._on_failure_step_backoff_factor, - backoff_max_seconds=func._on_failure_step_backoff_max_seconds, - ) - ], - ) + def validate_sticky(self, sticky: StickyStrategy | None) -> int | None: + if not sticky: + return None - validated_priority = ( - max(1, min(3, default_priority)) if default_priority else None - ) - if validated_priority != default_priority: - logger.warning( - "Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range." - ) + names = [item.name for item in StickyStrategyProto.DESCRIPTOR.values] - return CreateWorkflowVersionOpts( - name=name, - kind=WorkflowKind.DAG, - version=version, - event_triggers=event_triggers, - cron_triggers=cron_triggers, - schedule_timeout=schedule_timeout, - sticky=sticky, - jobs=[ - CreateWorkflowJobOpts( - name=name, - steps=createStepOpts, - ) - ], - on_failure_job=on_failure_job, - concurrency=concurrency, - default_priority=validated_priority, - ) + for name in names: + if name == sticky.name: + return StickyStrategyProto.Value(sticky.name) - attrs["get_create_opts"] = get_create_opts + raise ValueError(f"Sticky strategy must be one of {names}. Got: {sticky}") - return super(WorkflowMeta, cls).__new__(cls, name, bases, attrs) + def get_create_opts(self, namespace: str) -> CreateWorkflowVersionOpts: + service_name = self.get_service_name(namespace) + + name = self.get_name(namespace) + event_triggers = [namespace + event for event in self.config.on_events] + + create_step_opts = [ + CreateWorkflowStepOpts( + readable_id=step.name, + action=service_name + ":" + step.name, + timeout=step.timeout or "60s", + inputs="{}", + parents=[x for x in step.parents], + retries=step.retries, + rate_limits=step.rate_limits, + worker_labels=step.desired_worker_labels, + backoff_factor=step.backoff_factor, + backoff_max_seconds=step.backoff_max_seconds, + ) + for step in self.steps + if step.type == StepType.DEFAULT + ] + + concurrency = self.validate_concurrency_actions(service_name) + on_failure_job = self.validate_on_failure_steps(name, service_name) + validated_priority = self.validate_priority(self.config.default_priority) + + return CreateWorkflowVersionOpts( + name=name, + kind=WorkflowKind.DAG, + version=self.config.version, + event_triggers=event_triggers, + cron_triggers=self.config.on_crons, + schedule_timeout=self.config.schedule_timeout, + sticky=cast(str | None, self.validate_sticky(self.config.sticky)), + jobs=[ + CreateWorkflowJobOpts( + name=name, + steps=create_step_opts, + ) + ], + on_failure_job=on_failure_job, + concurrency=concurrency, + default_priority=validated_priority, + ) diff --git a/hatchet_sdk/workflow_run.py b/hatchet_sdk/workflow_run.py index 51a23821..4c21364b 100644 --- a/hatchet_sdk/workflow_run.py +++ b/hatchet_sdk/workflow_run.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Coroutine, Generic, Optional, TypedDict, TypeVar +from typing import Any, Coroutine, Generic, TypeVar from hatchet_sdk.clients.run_event_listener import ( RunEventListener, @@ -22,19 +22,19 @@ def __init__( self.workflow_listener = workflow_listener self.workflow_run_event_listener = workflow_run_event_listener - def __str__(self): + def __str__(self) -> str: return self.workflow_run_id def stream(self) -> RunEventListener: return self.workflow_run_event_listener.stream(self.workflow_run_id) - def result(self) -> Coroutine: + def result(self) -> Coroutine[None, None, dict[str, Any]]: return self.workflow_listener.result(self.workflow_run_id) - def sync_result(self) -> dict: + def sync_result(self) -> dict[str, Any]: loop = get_active_event_loop() if loop is None: - with EventLoopThread() as loop: + with EventLoopThread() as loop: # type: ignore[call-arg] coro = self.workflow_listener.result(self.workflow_run_id) future = asyncio.run_coroutine_threadsafe(coro, loop) return future.result() @@ -48,7 +48,7 @@ def sync_result(self) -> dict: class RunRef(WorkflowRunRef, Generic[T]): - async def result(self) -> T: + async def result(self) -> Any | dict[str, Any]: res = await self.workflow_listener.result(self.workflow_run_id) if len(res) == 1: diff --git a/lint.sh b/lint.sh index 8b0263f4..ad5206c2 100755 --- a/lint.sh +++ b/lint.sh @@ -1,3 +1,4 @@ poetry run black . --color poetry run isort . poetry run mypy --config-file=pyproject.toml +poetry run ruff . --fix diff --git a/poetry.lock b/poetry.lock index 603caef7..cfe6cbe5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.0.0 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -6,6 +6,7 @@ version = "2.4.4" description = "Happy Eyeballs for asyncio" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "aiohappyeyeballs-2.4.4-py3-none-any.whl", hash = "sha256:a980909d50efcd44795c4afeca523296716d50cd756ddca6af8c65b996e27de8"}, {file = "aiohappyeyeballs-2.4.4.tar.gz", hash = "sha256:5fdd7d87889c63183afc18ce9271f9b0a7d32c2303e394468dd45d514a757745"}, @@ -17,6 +18,7 @@ version = "3.11.11" description = "Async http client/server framework (asyncio)" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "aiohttp-3.11.11-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a60804bff28662cbcf340a4d61598891f12eea3a66af48ecfdc975ceec21e3c8"}, {file = "aiohttp-3.11.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4b4fa1cb5f270fb3eab079536b764ad740bb749ce69a94d4ec30ceee1b5940d5"}, @@ -115,6 +117,7 @@ version = "2.9.1" description = "Simple retry client for aiohttp" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "aiohttp_retry-2.9.1-py3-none-any.whl", hash = "sha256:66d2759d1921838256a05a3f80ad7e724936f083e35be5abb5e16eed6be6dc54"}, {file = "aiohttp_retry-2.9.1.tar.gz", hash = "sha256:8eb75e904ed4ee5c2ec242fefe85bf04240f685391c4879d8f541d6028ff01f1"}, @@ -129,6 +132,7 @@ version = "1.3.2" description = "aiosignal: a list of registered asynchronous callbacks" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5"}, {file = "aiosignal-1.3.2.tar.gz", hash = "sha256:a8c255c66fafb1e499c9351d0bf32ff2d8a0321595ebac3b93713656d2436f54"}, @@ -143,6 +147,7 @@ version = "0.5.2" description = "Generator-based operators for asynchronous iteration" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "aiostream-0.5.2-py3-none-any.whl", hash = "sha256:054660370be9d37f6fe3ece3851009240416bd082e469fd90cc8673d3818cf71"}, {file = "aiostream-0.5.2.tar.gz", hash = "sha256:b71b519a2d66c38f0872403ab86417955b77352f08d9ad02ad46fc3926b389f4"}, @@ -157,6 +162,7 @@ version = "0.7.0" description = "Reusable constraint types to use with typing.Annotated" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, @@ -168,6 +174,8 @@ version = "5.0.1" description = "Timeout context manager for asyncio programs" optional = false python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.11\"" files = [ {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, @@ -179,6 +187,7 @@ version = "24.3.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "attrs-24.3.0-py3-none-any.whl", hash = "sha256:ac96cd038792094f438ad1f6ff80837353805ac950cd2aa0e0625ef19850c308"}, {file = "attrs-24.3.0.tar.gz", hash = "sha256:8f5c07333d543103541ba7be0e2ce16eeee8130cb0b3f9238ab904ce1e85baff"}, @@ -198,6 +207,7 @@ version = "2.16.0" description = "Internationalization utilities" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "babel-2.16.0-py3-none-any.whl", hash = "sha256:368b5b98b37c06b7daf6696391c3240c938b37767d4584413e8438c5c435fa8b"}, {file = "babel-2.16.0.tar.gz", hash = "sha256:d1f3554ca26605fe173f3de0c65f750f5a42f924499bf134de6423582298e316"}, @@ -212,6 +222,7 @@ version = "24.10.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.9" +groups = ["lint"] files = [ {file = "black-24.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e6668650ea4b685440857138e5fe40cde4d652633b1bdffc62933d0db4ed9812"}, {file = "black-24.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1c536fcf674217e87b8cc3657b81809d3c085d7bf3ef262ead700da345bfa6ea"}, @@ -258,6 +269,7 @@ version = "0.1.5" description = "Pure Python CEL Implementation" optional = false python-versions = ">=3.7, <4" +groups = ["main"] files = [ {file = "cel-python-0.1.5.tar.gz", hash = "sha256:d3911bb046bc3ed12792bd88ab453f72d98c66923b72a2fa016bcdffd96e2f98"}, {file = "cel_python-0.1.5-py3-none-any.whl", hash = "sha256:ac81fab8ba08b633700a45d84905be2863529c6a32935c9da7ef53fc06844f1a"}, @@ -278,6 +290,7 @@ version = "2024.12.14" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "certifi-2024.12.14-py3-none-any.whl", hash = "sha256:1275f7a45be9464efc1173084eaa30f866fe2e47d389406136d332ed4967ec56"}, {file = "certifi-2024.12.14.tar.gz", hash = "sha256:b650d30f370c2b724812bee08008be0c4163b163ddaec3f2546c1caf65f191db"}, @@ -289,6 +302,7 @@ version = "3.4.1" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de"}, {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176"}, @@ -390,6 +404,7 @@ version = "8.1.8" description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" +groups = ["lint"] files = [ {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, @@ -404,10 +419,12 @@ version = "0.4.6" description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["main", "dev", "lint", "test"] files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +markers = {main = "sys_platform == \"win32\"", dev = "sys_platform == \"win32\"", lint = "platform_system == \"Windows\"", test = "sys_platform == \"win32\""} [[package]] name = "deprecated" @@ -415,6 +432,7 @@ version = "1.2.15" description = "Python @deprecated decorator to deprecate old python classes, functions or methods." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" +groups = ["main"] files = [ {file = "Deprecated-1.2.15-py2.py3-none-any.whl", hash = "sha256:353bc4a8ac4bfc96800ddab349d89c25dec1079f65fd53acdcc1e0b975b21320"}, {file = "deprecated-1.2.15.tar.gz", hash = "sha256:683e561a90de76239796e6b6feac66b99030d2dd3fcf61ef996330f14bbb9b0d"}, @@ -432,6 +450,8 @@ version = "1.2.2" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" +groups = ["dev", "test"] +markers = "python_version < \"3.11\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, @@ -446,6 +466,7 @@ version = "1.5.0" description = "A list-like structure which implements collections.abc.MutableSequence" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5b6a66c18b5b9dd261ca98dffcb826a525334b2f29e7caa54e182255c5f6a65a"}, {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d1b3eb7b05ea246510b43a7e53ed1653e55c2121019a97e60cad7efb881a97bb"}, @@ -547,6 +568,7 @@ version = "1.66.0" description = "Common protobufs used in Google APIs" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "googleapis_common_protos-1.66.0-py2.py3-none-any.whl", hash = "sha256:d7abcd75fabb2e0ec9f74466401f6c119a0b498e27370e9be4c94cb7e382b8ed"}, {file = "googleapis_common_protos-1.66.0.tar.gz", hash = "sha256:c3e7b33d15fdca5374cc0a7346dd92ffa847425cc4ea941d970f13680052ec8c"}, @@ -558,12 +580,28 @@ protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4 [package.extras] grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] +[[package]] +name = "grpc-stubs" +version = "1.53.0.5" +description = "Mypy stubs for gRPC" +optional = false +python-versions = ">=3.6" +groups = ["dev"] +files = [ + {file = "grpc-stubs-1.53.0.5.tar.gz", hash = "sha256:3e1b642775cbc3e0c6332cfcedfccb022176db87e518757bef3a1241397be406"}, + {file = "grpc_stubs-1.53.0.5-py3-none-any.whl", hash = "sha256:04183fb65a1b166a1febb9627e3d9647d3926ccc2dfe049fe7b6af243428dbe1"}, +] + +[package.dependencies] +grpcio = "*" + [[package]] name = "grpcio" version = "1.69.0" description = "HTTP/2-based RPC framework" optional = false python-versions = ">=3.8" +groups = ["main", "dev"] files = [ {file = "grpcio-1.69.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2060ca95a8db295ae828d0fc1c7f38fb26ccd5edf9aa51a0f44251f5da332e97"}, {file = "grpcio-1.69.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:2e52e107261fd8fa8fa457fe44bfadb904ae869d87c1280bf60f93ecd3e79278"}, @@ -631,6 +669,7 @@ version = "1.69.0" description = "Protobuf code generator for gRPC" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "grpcio_tools-1.69.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:8c210630faa581c3bd08953dac4ad21a7f49862f3b92d69686e9b436d2f1265d"}, {file = "grpcio_tools-1.69.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:09b66ea279fcdaebae4ec34b1baf7577af3b14322738aa980c1c33cfea71f7d7"}, @@ -700,6 +739,7 @@ version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -714,6 +754,7 @@ version = "8.5.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "importlib_metadata-8.5.0-py3-none-any.whl", hash = "sha256:45e54197d28b7a7f1559e60b95e7c567032b602131fbd588f1497f47880aa68b"}, {file = "importlib_metadata-8.5.0.tar.gz", hash = "sha256:71522656f0abace1d072b9e5481a48f07c138e00f079c38c8f883823f9c26bd7"}, @@ -737,6 +778,7 @@ version = "2.0.0" description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.7" +groups = ["dev", "test"] files = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, @@ -748,6 +790,7 @@ version = "5.13.2" description = "A Python utility / library to sort Python imports." optional = false python-versions = ">=3.8.0" +groups = ["lint"] files = [ {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, @@ -762,6 +805,7 @@ version = "1.0.1" description = "JSON Matching Expressions" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, @@ -773,6 +817,7 @@ version = "0.12.0" description = "a modern parsing library" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "lark-parser-0.12.0.tar.gz", hash = "sha256:15967db1f1214013dca65b1180745047b9be457d73da224fcda3d9dd4e96a138"}, {file = "lark_parser-0.12.0-py2.py3-none-any.whl", hash = "sha256:0eaf30cb5ba787fe404d73a7d6e61df97b21d5a63ac26c5008c78a494373c675"}, @@ -789,6 +834,7 @@ version = "0.7.3" description = "Python logging made (stupidly) simple" optional = false python-versions = "<4.0,>=3.5" +groups = ["main"] files = [ {file = "loguru-0.7.3-py3-none-any.whl", hash = "sha256:31a33c10c8e1e10422bfd431aeb5d351c7cf7fa671e3c4df004162264b28220c"}, {file = "loguru-0.7.3.tar.gz", hash = "sha256:19480589e77d47b8d85b2c827ad95d49bf31b0dcde16593892eb51dd18706eb6"}, @@ -807,6 +853,7 @@ version = "6.1.0" description = "multidict implementation" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "multidict-6.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3380252550e372e8511d49481bd836264c009adb826b23fefcc5dd3c69692f60"}, {file = "multidict-6.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:99f826cbf970077383d7de805c0681799491cb939c25450b9b5b3ced03ca99f1"}, @@ -911,6 +958,7 @@ version = "1.14.1" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" +groups = ["lint"] files = [ {file = "mypy-1.14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:52686e37cf13d559f668aa398dd7ddf1f92c5d613e4f8cb262be2fb4fedb0fcb"}, {file = "mypy-1.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1fb545ca340537d4b45d3eecdb3def05e913299ca72c290326be19b3804b39c0"}, @@ -970,6 +1018,7 @@ version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." optional = false python-versions = ">=3.5" +groups = ["lint"] files = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, @@ -981,6 +1030,7 @@ version = "1.6.0" description = "Patch asyncio to allow nested event loops" optional = false python-versions = ">=3.5" +groups = ["main"] files = [ {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"}, {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, @@ -992,6 +1042,7 @@ version = "1.29.0" description = "OpenTelemetry Python API" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_api-1.29.0-py3-none-any.whl", hash = "sha256:5fcd94c4141cc49c736271f3e1efb777bebe9cc535759c54c936cca4f1b312b8"}, {file = "opentelemetry_api-1.29.0.tar.gz", hash = "sha256:d04a6cf78aad09614f52964ecb38021e248f5714dc32c2e0d8fd99517b4d69cf"}, @@ -1007,6 +1058,7 @@ version = "0.50b0" description = "OpenTelemetry Python Distro" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_distro-0.50b0-py3-none-any.whl", hash = "sha256:5fa2e2a99a047ea477fab53e73fb8088b907bda141e8440745b92eb2a84d74aa"}, {file = "opentelemetry_distro-0.50b0.tar.gz", hash = "sha256:3e059e00f53553ebd646d1162d1d3edf5d7c6d3ceafd54a49e74c90dc1c39a7d"}, @@ -1026,6 +1078,7 @@ version = "1.29.0" description = "OpenTelemetry Collector Exporters" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_exporter_otlp-1.29.0-py3-none-any.whl", hash = "sha256:b8da6e20f5b0ffe604154b1e16a407eade17ce310c42fb85bb4e1246fc3688ad"}, {file = "opentelemetry_exporter_otlp-1.29.0.tar.gz", hash = "sha256:ee7dfcccbb5e87ad9b389908452e10b7beeab55f70a83f41ce5b8c4efbde6544"}, @@ -1041,6 +1094,7 @@ version = "1.29.0" description = "OpenTelemetry Protobuf encoding" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_exporter_otlp_proto_common-1.29.0-py3-none-any.whl", hash = "sha256:a9d7376c06b4da9cf350677bcddb9618ed4b8255c3f6476975f5e38274ecd3aa"}, {file = "opentelemetry_exporter_otlp_proto_common-1.29.0.tar.gz", hash = "sha256:e7c39b5dbd1b78fe199e40ddfe477e6983cb61aa74ba836df09c3869a3e3e163"}, @@ -1055,6 +1109,7 @@ version = "1.29.0" description = "OpenTelemetry Collector Protobuf over gRPC Exporter" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_exporter_otlp_proto_grpc-1.29.0-py3-none-any.whl", hash = "sha256:5a2a3a741a2543ed162676cf3eefc2b4150e6f4f0a193187afb0d0e65039c69c"}, {file = "opentelemetry_exporter_otlp_proto_grpc-1.29.0.tar.gz", hash = "sha256:3d324d07d64574d72ed178698de3d717f62a059a93b6b7685ee3e303384e73ea"}, @@ -1075,6 +1130,7 @@ version = "1.29.0" description = "OpenTelemetry Collector Protobuf over HTTP Exporter" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_exporter_otlp_proto_http-1.29.0-py3-none-any.whl", hash = "sha256:b228bdc0f0cfab82eeea834a7f0ffdd2a258b26aa33d89fb426c29e8e934d9d0"}, {file = "opentelemetry_exporter_otlp_proto_http-1.29.0.tar.gz", hash = "sha256:b10d174e3189716f49d386d66361fbcf6f2b9ad81e05404acdee3f65c8214204"}, @@ -1095,6 +1151,7 @@ version = "0.50b0" description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_instrumentation-0.50b0-py3-none-any.whl", hash = "sha256:b8f9fc8812de36e1c6dffa5bfc6224df258841fb387b6dfe5df15099daa10630"}, {file = "opentelemetry_instrumentation-0.50b0.tar.gz", hash = "sha256:7d98af72de8dec5323e5202e46122e5f908592b22c6d24733aad619f07d82979"}, @@ -1112,6 +1169,7 @@ version = "1.29.0" description = "OpenTelemetry Python Proto" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_proto-1.29.0-py3-none-any.whl", hash = "sha256:495069c6f5495cbf732501cdcd3b7f60fda2b9d3d4255706ca99b7ca8dec53ff"}, {file = "opentelemetry_proto-1.29.0.tar.gz", hash = "sha256:3c136aa293782e9b44978c738fff72877a4b78b5d21a64e879898db7b2d93e5d"}, @@ -1126,6 +1184,7 @@ version = "1.29.0" description = "OpenTelemetry Python SDK" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_sdk-1.29.0-py3-none-any.whl", hash = "sha256:173be3b5d3f8f7d671f20ea37056710217959e774e2749d984355d1f9391a30a"}, {file = "opentelemetry_sdk-1.29.0.tar.gz", hash = "sha256:b0787ce6aade6ab84315302e72bd7a7f2f014b0fb1b7c3295b88afe014ed0643"}, @@ -1142,6 +1201,7 @@ version = "0.50b0" description = "OpenTelemetry Semantic Conventions" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_semantic_conventions-0.50b0-py3-none-any.whl", hash = "sha256:e87efba8fdb67fb38113efea6a349531e75ed7ffc01562f65b802fcecb5e115e"}, {file = "opentelemetry_semantic_conventions-0.50b0.tar.gz", hash = "sha256:02dc6dbcb62f082de9b877ff19a3f1ffaa3c306300fa53bfac761c4567c83d38"}, @@ -1157,6 +1217,7 @@ version = "24.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" +groups = ["main", "dev", "lint", "test"] files = [ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, @@ -1168,6 +1229,7 @@ version = "0.12.1" description = "Utility library for gitignore style pattern matching of file paths." optional = false python-versions = ">=3.8" +groups = ["lint"] files = [ {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, @@ -1179,6 +1241,7 @@ version = "4.3.6" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." optional = false python-versions = ">=3.8" +groups = ["lint"] files = [ {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, @@ -1195,6 +1258,7 @@ version = "1.5.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" +groups = ["dev", "test"] files = [ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, @@ -1210,6 +1274,7 @@ version = "0.21.1" description = "Python client for the Prometheus monitoring system." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "prometheus_client-0.21.1-py3-none-any.whl", hash = "sha256:594b45c410d6f4f8888940fe80b5cc2521b305a1fafe1c58609ef715a001f301"}, {file = "prometheus_client-0.21.1.tar.gz", hash = "sha256:252505a722ac04b0456be05c05f75f45d760c2911ffc45f2a06bcaed9f3ae3fb"}, @@ -1224,6 +1289,7 @@ version = "0.2.1" description = "Accelerated property cache" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "propcache-0.2.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:6b3f39a85d671436ee3d12c017f8fdea38509e4f25b28eb25877293c98c243f6"}, {file = "propcache-0.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:39d51fbe4285d5db5d92a929e3e21536ea3dd43732c5b177c7ef03f918dff9f2"}, @@ -1315,6 +1381,7 @@ version = "5.29.2" description = "" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "protobuf-5.29.2-cp310-abi3-win32.whl", hash = "sha256:c12ba8249f5624300cf51c3d0bfe5be71a60c63e4dcf51ffe9a68771d958c851"}, {file = "protobuf-5.29.2-cp310-abi3-win_amd64.whl", hash = "sha256:842de6d9241134a973aab719ab42b008a18a90f9f07f06ba480df268f86432f9"}, @@ -1335,6 +1402,7 @@ version = "6.1.1" description = "Cross-platform lib for process and system monitoring in Python." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +groups = ["dev"] files = [ {file = "psutil-6.1.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:9ccc4316f24409159897799b83004cb1e24f9819b0dcf9c0b68bdcb6cefee6a8"}, {file = "psutil-6.1.1-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ca9609c77ea3b8481ab005da74ed894035936223422dc591d6772b147421f777"}, @@ -1365,6 +1433,7 @@ version = "2.10.4" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pydantic-2.10.4-py3-none-any.whl", hash = "sha256:597e135ea68be3a37552fb524bc7d0d66dcf93d395acd93a00682f1efcb8ee3d"}, {file = "pydantic-2.10.4.tar.gz", hash = "sha256:82f12e9723da6de4fe2ba888b5971157b3be7ad914267dea8f05f82b28254f06"}, @@ -1385,6 +1454,7 @@ version = "2.27.2" description = "Core functionality for Pydantic validation and serialization" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pydantic_core-2.27.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2d367ca20b2f14095a8f4fa1210f5a7b78b8a20009ecced6b12818f455b1e9fa"}, {file = "pydantic_core-2.27.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:491a2b73db93fab69731eaee494f320faa4e093dbed776be1a829c2eb222c34c"}, @@ -1491,12 +1561,34 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pydantic-settings" +version = "2.7.1" +description = "Settings management using Pydantic" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "pydantic_settings-2.7.1-py3-none-any.whl", hash = "sha256:590be9e6e24d06db33a4262829edef682500ef008565a969c73d39d5f8bfb3fd"}, + {file = "pydantic_settings-2.7.1.tar.gz", hash = "sha256:10c9caad35e64bfb3c2fbf70a078c0e25cc92499782e5200747f942a065dec93"}, +] + +[package.dependencies] +pydantic = ">=2.7.0" +python-dotenv = ">=0.21.0" + +[package.extras] +azure-key-vault = ["azure-identity (>=1.16.0)", "azure-keyvault-secrets (>=4.8.0)"] +toml = ["tomli (>=2.0.1)"] +yaml = ["pyyaml (>=6.0.1)"] + [[package]] name = "pytest" version = "8.3.4" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" +groups = ["dev", "test"] files = [ {file = "pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6"}, {file = "pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761"}, @@ -1519,6 +1611,7 @@ version = "0.23.8" description = "Pytest support for asyncio" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pytest_asyncio-0.23.8-py3-none-any.whl", hash = "sha256:50265d892689a5faefb84df80819d1ecef566eb3549cf915dfb33569359d1ce2"}, {file = "pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3"}, @@ -1531,12 +1624,32 @@ pytest = ">=7.0.0,<9" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] +[[package]] +name = "pytest-env" +version = "1.1.5" +description = "pytest plugin that allows you to add environment variables." +optional = false +python-versions = ">=3.8" +groups = ["test"] +files = [ + {file = "pytest_env-1.1.5-py3-none-any.whl", hash = "sha256:ce90cf8772878515c24b31cd97c7fa1f4481cd68d588419fd45f10ecaee6bc30"}, + {file = "pytest_env-1.1.5.tar.gz", hash = "sha256:91209840aa0e43385073ac464a554ad2947cc2fd663a9debf88d03b01e0cc1cf"}, +] + +[package.dependencies] +pytest = ">=8.3.3" +tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "pytest-mock (>=3.14)"] + [[package]] name = "pytest-timeout" version = "2.3.1" description = "pytest plugin to abort hanging tests" optional = false python-versions = ">=3.7" +groups = ["test"] files = [ {file = "pytest-timeout-2.3.1.tar.gz", hash = "sha256:12397729125c6ecbdaca01035b9e5239d4db97352320af155b3f5de1ba5165d9"}, {file = "pytest_timeout-2.3.1-py3-none-any.whl", hash = "sha256:68188cb703edfc6a18fad98dc25a3c61e9f24d644b0b70f33af545219fc7813e"}, @@ -1551,6 +1664,7 @@ version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] files = [ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, @@ -1565,6 +1679,7 @@ version = "1.0.1" description = "Read key-value pairs from a .env file and set them as environment variables" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, @@ -1579,6 +1694,7 @@ version = "6.0.2" description = "YAML parser and emitter for Python" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, @@ -1641,6 +1757,7 @@ version = "2.32.3" description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, @@ -1662,6 +1779,7 @@ version = "75.7.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "setuptools-75.7.0-py3-none-any.whl", hash = "sha256:84fb203f278ebcf5cd08f97d3fb96d3fbed4b629d500b29ad60d11e00769b183"}, {file = "setuptools-75.7.0.tar.gz", hash = "sha256:886ff7b16cd342f1d1defc16fc98c9ce3fde69e087a4e1983d7ab634e5f41f4f"}, @@ -1682,6 +1800,7 @@ version = "1.17.0" description = "Python 2 and 3 compatibility utilities" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -1693,6 +1812,7 @@ version = "9.0.0" description = "Retry code until it succeeds" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"}, {file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"}, @@ -1708,6 +1828,8 @@ version = "2.2.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" +groups = ["dev", "lint", "test"] +markers = "python_version < \"3.11\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, @@ -1749,17 +1871,31 @@ version = "5.29.1.20241207" description = "Typing stubs for protobuf" optional = false python-versions = ">=3.8" +groups = ["lint"] files = [ {file = "types_protobuf-5.29.1.20241207-py3-none-any.whl", hash = "sha256:92893c42083e9b718c678badc0af7a9a1307b92afe1599e5cba5f3d35b668b2f"}, {file = "types_protobuf-5.29.1.20241207.tar.gz", hash = "sha256:2ebcadb8ab3ef2e3e2f067e0882906d64ba0dc65fc5b0fd7a8b692315b4a0be9"}, ] +[[package]] +name = "types-psutil" +version = "6.1.0.20241221" +description = "Typing stubs for psutil" +optional = false +python-versions = ">=3.8" +groups = ["lint"] +files = [ + {file = "types_psutil-6.1.0.20241221-py3-none-any.whl", hash = "sha256:8498dbe13285a9ba7d4b2fa934c569cc380efc74e3dacdb34ae16d2cdf389ec3"}, + {file = "types_psutil-6.1.0.20241221.tar.gz", hash = "sha256:600f5a36bd5e0eb8887f0e3f3ff2cf154d90690ad8123c8a707bba4ab94d3185"}, +] + [[package]] name = "typing-extensions" version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" +groups = ["main", "lint"] files = [ {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, @@ -1771,6 +1907,7 @@ version = "2.3.0" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df"}, {file = "urllib3-2.3.0.tar.gz", hash = "sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d"}, @@ -1788,6 +1925,8 @@ version = "1.2.0" description = "A small Python utility to set file creation time on Windows" optional = false python-versions = ">=3.5" +groups = ["main"] +markers = "sys_platform == \"win32\"" files = [ {file = "win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390"}, {file = "win32_setctime-1.2.0.tar.gz", hash = "sha256:ae1fdf948f5640aae05c511ade119313fb6a30d7eabe25fef9764dca5873c4c0"}, @@ -1802,6 +1941,7 @@ version = "1.17.0" description = "Module for decorators, wrappers and monkey patching." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "wrapt-1.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2a0c23b8319848426f305f9cb0c98a6e32ee68a36264f45948ccf8e7d2b941f8"}, {file = "wrapt-1.17.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1ca5f060e205f72bec57faae5bd817a1560fcfc4af03f414b08fa29106b7e2d"}, @@ -1876,6 +2016,7 @@ version = "1.18.3" description = "Yet another URL library" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "yarl-1.18.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7df647e8edd71f000a5208fe6ff8c382a1de8edfbccdbbfe649d263de07d8c34"}, {file = "yarl-1.18.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c69697d3adff5aa4f874b19c0e4ed65180ceed6318ec856ebc423aa5850d84f7"}, @@ -1972,6 +2113,7 @@ version = "3.21.0" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931"}, {file = "zipp-3.21.0.tar.gz", hash = "sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4"}, @@ -1986,6 +2128,6 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", type = ["pytest-mypy"] [metadata] -lock-version = "2.0" +lock-version = "2.1" python-versions = "^3.10" -content-hash = "414d63b255f80d13260cb3a9ecce29f782af46280bba79395554595a47c42f05" +content-hash = "9ca2d219cdeffdf5e53d2b7e2b0ff3263eace96e98f992c9a10ab2a7baf737c1" diff --git a/pyproject.toml b/pyproject.toml index 27dd1d8c..4943f5fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,12 +9,12 @@ include = ["hatchet_sdk/py.typed"] [tool.poetry.dependencies] python = "^3.10" grpcio = [ - { version = ">=1.64.1, !=1.68.*", markers = "python_version < '3.13'" }, - { version = ">=1.69.0", markers = "python_version >= '3.13'" }, + { version = ">=1.64.1, !=1.68.*", markers = "python_version < '3.13'" }, + { version = ">=1.69.0", markers = "python_version >= '3.13'" }, ] grpcio-tools = [ - { version = ">=1.64.1, !=1.68.*", markers = "python_version < '3.13'" }, - { version = ">=1.69.0", markers = "python_version >= '3.13'" }, + { version = ">=1.64.1, !=1.68.*", markers = "python_version < '3.13'" }, + { version = ">=1.69.0", markers = "python_version >= '3.13'" }, ] python-dotenv = "^1.0.0" protobuf = "^5.29.1" @@ -36,20 +36,24 @@ opentelemetry-distro = ">=0.49b0" opentelemetry-exporter-otlp = "^1.28.0" opentelemetry-exporter-otlp-proto-http = "^1.28.0" prometheus-client = "^0.21.1" +pydantic-settings = "^2.7.1" [tool.poetry.group.dev.dependencies] pytest = "^8.2.2" pytest-asyncio = "^0.23.8" psutil = "^6.0.0" +grpc-stubs = "^1.53.0.5" [tool.poetry.group.lint.dependencies] mypy = "^1.14.0" types-protobuf = "^5.28.3.20241030" black = "^24.10.0" isort = "^5.13.2" +types-psutil = "^6.1.0.20241221" [tool.poetry.group.test.dependencies] pytest-timeout = "^2.3.1" +pytest-env = "^1.1.5" [build-system] requires = ["poetry-core"] @@ -57,19 +61,20 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] log_cli = true +env = ["HATCHET_CLIENT_TLS_STRATEGY=none"] [tool.isort] profile = "black" known_third_party = [ - "grpcio", - "grpcio_tools", - "loguru", - "protobuf", - "pydantic", - "python_dotenv", - "python_dateutil", - "pyyaml", - "urllib3", + "grpcio", + "grpcio_tools", + "loguru", + "protobuf", + "pydantic", + "python_dotenv", + "python_dateutil", + "pyyaml", + "urllib3", ] extend_skip = ["hatchet_sdk/contracts/"] @@ -77,26 +82,28 @@ extend_skip = ["hatchet_sdk/contracts/"] extend_exclude = "hatchet_sdk/contracts/" [tool.mypy] +files = ["."] +exclude = [ + "hatchet_sdk/clients/rest/api/*", + "hatchet_sdk/clients/rest/models/*", + "hatchet_sdk/contracts", + "hatchet_sdk/clients/rest/api_client.py", + "hatchet_sdk/clients/rest/configuration.py", + "hatchet_sdk/clients/rest/exceptions.py", + "hatchet_sdk/clients/rest/rest.py", +] strict = true -files = [ - "hatchet_sdk/hatchet.py", - "hatchet_sdk/worker/worker.py", - "hatchet_sdk/context/context.py", - "hatchet_sdk/worker/runner/runner.py", - "hatchet_sdk/workflow.py", - "hatchet_sdk/utils/serialization.py", - "hatchet_sdk/utils/tracing.py", - "hatchet_sdk/utils/types.py", - "hatchet_sdk/utils/backoff.py", - "examples/**/*.py", - "hatchet_sdk/clients/rest/models/workflow_list.py", - "hatchet_sdk/clients/rest/models/workflow_run.py", - "hatchet_sdk/context/worker_context.py", - "hatchet_sdk/clients/dispatcher/dispatcher.py", + +[tool.ruff] +exclude = [ + "hatchet_sdk/clients/rest/api/*", + "hatchet_sdk/clients/rest/models/*", + "hatchet_sdk/contracts", + "hatchet_sdk/clients/rest/api_client.py", + "hatchet_sdk/clients/rest/configuration.py", + "hatchet_sdk/clients/rest/exceptions.py", + "hatchet_sdk/clients/rest/rest.py", ] -follow_imports = "silent" -disable_error_code = ["unused-coroutine"] -explicit_package_bases = true [tool.poetry.scripts] api = "examples.api.api:main" @@ -121,4 +128,3 @@ existing_loop = "examples.worker_existing_loop.worker:main" bulk_fanout = "examples.bulk_fanout.worker:main" retries_with_backoff = "examples.retries_with_backoff.worker:main" pydantic = "examples.pydantic.worker:main" -v2_simple = "examples.v2.simple.worker:main" diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 00000000..43d7d27a --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,15 @@ +from hatchet_sdk.loader import ClientConfig + + +def test_client_initialization_from_defaults() -> None: + assert isinstance(ClientConfig(), ClientConfig) + + +def test_client_host_port_overrides() -> None: + host_port = "localhost:8080" + with_host_port = ClientConfig(host_port=host_port) + assert with_host_port.host_port == host_port + assert with_host_port.server_url == host_port + + assert ClientConfig().host_port != host_port + assert ClientConfig().server_url != host_port