From ef37b31bd0c188151c6e0d28388096ed483ce8a7 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 13 Jan 2025 13:10:43 -0500 Subject: [PATCH 01/53] feat: client config overhaul --- hatchet_sdk/loader.py | 329 ++++++++++++------------------------------ 1 file changed, 95 insertions(+), 234 deletions(-) diff --git a/hatchet_sdk/loader.py b/hatchet_sdk/loader.py index d754c2ae..7ec30f06 100644 --- a/hatchet_sdk/loader.py +++ b/hatchet_sdk/loader.py @@ -1,246 +1,107 @@ -import json import os from logging import Logger, getLogger -from typing import Dict, Optional - -import yaml - -from .token import get_addresses_from_jwt, get_tenant_id_from_jwt - - -class ClientTLSConfig: - def __init__( - self, - tls_strategy: str, - cert_file: str, - key_file: str, - ca_file: str, - server_name: str, - ): - self.tls_strategy = tls_strategy - self.cert_file = cert_file - self.key_file = key_file - self.ca_file = ca_file - self.server_name = server_name - - -class ClientConfig: - logInterceptor: Logger - - def __init__( - self, - tenant_id: str = None, - tls_config: ClientTLSConfig = None, - token: str = None, - host_port: str = "localhost:7070", - server_url: str = "https://app.dev.hatchet-tools.com", - namespace: str = None, - listener_v2_timeout: int = None, - logger: Logger = None, - grpc_max_recv_message_length: int = 4 * 1024 * 1024, # 4MB - grpc_max_send_message_length: int = 4 * 1024 * 1024, # 4MB - otel_exporter_oltp_endpoint: str | None = None, - otel_service_name: str | None = None, - otel_exporter_oltp_headers: dict[str, str] | None = None, - otel_exporter_oltp_protocol: str | None = None, - worker_healthcheck_port: int | None = None, - worker_healthcheck_enabled: bool | None = None, - ): - self.tenant_id = tenant_id - self.tls_config = tls_config - self.host_port = host_port - self.token = token - self.server_url = server_url - self.namespace = "" - self.logInterceptor = logger - self.grpc_max_recv_message_length = grpc_max_recv_message_length - 
self.grpc_max_send_message_length = grpc_max_send_message_length - self.otel_exporter_oltp_endpoint = otel_exporter_oltp_endpoint - self.otel_service_name = otel_service_name - self.otel_exporter_oltp_headers = otel_exporter_oltp_headers - self.otel_exporter_oltp_protocol = otel_exporter_oltp_protocol - self.worker_healthcheck_port = worker_healthcheck_port - self.worker_healthcheck_enabled = worker_healthcheck_enabled - - if not self.logInterceptor: - self.logInterceptor = getLogger() - - # case on whether the namespace already has a trailing underscore - if namespace and not namespace.endswith("_"): - self.namespace = f"{namespace}_" - elif namespace: - self.namespace = namespace - - self.namespace = self.namespace.lower() - - self.listener_v2_timeout = listener_v2_timeout - - -class ConfigLoader: - def __init__(self, directory: str): - self.directory = directory - - def load_client_config(self, defaults: ClientConfig) -> ClientConfig: - config_file_path = os.path.join(self.directory, "client.yaml") - config_data: object = {"tls": {}} - - # determine if client.yaml exists - if os.path.exists(config_file_path): - with open(config_file_path, "r") as file: - config_data = yaml.safe_load(file) - - def get_config_value(key, env_var): - if key in config_data: - return config_data[key] - - if self._get_env_var(env_var) is not None: - return self._get_env_var(env_var) - - return getattr(defaults, key, None) - - namespace = get_config_value("namespace", "HATCHET_CLIENT_NAMESPACE") - - tenant_id = get_config_value("tenantId", "HATCHET_CLIENT_TENANT_ID") - token = get_config_value("token", "HATCHET_CLIENT_TOKEN") - listener_v2_timeout = get_config_value( - "listener_v2_timeout", "HATCHET_CLIENT_LISTENER_V2_TIMEOUT" - ) - listener_v2_timeout = int(listener_v2_timeout) if listener_v2_timeout else None - +from typing import cast + +from pydantic import BaseModel, ValidationError, ValidationInfo, field_validator + +from .token import get_tenant_id_from_jwt + + +class 
ClientTLSConfig(BaseModel): + tls_strategy: str + cert_file: str | None + key_file: str | None + ca_file: str | None + server_name: str + + +def _load_tls_config(host_port: str) -> ClientTLSConfig: + return ClientTLSConfig( + tls_strategy=os.getenv("HATCHET_CLIENT_TLS_STRATEGY", "tls"), + cert_file=os.getenv("HATCHET_CLIENT_TLS_CERT_FILE"), + key_file=os.getenv("HATCHET_CLIENT_TLS_KEY_FILE"), + ca_file=os.getenv("HATCHET_CLIENT_TLS_ROOT_CA_FILE"), + server_name=os.getenv( + "HATCHET_CLIENT_TLS_SERVER_NAME", host_port.split(":")[0] + ), + ) + + +def parse_listener_timeout(timeout: str | None) -> int | None: + if timeout is None: + return None + + return int(timeout) + + +class ClientConfig(BaseModel): + token: str = os.getenv("HATCHET_CLIENT_TOKEN", "") + logger: Logger = getLogger() + tenant_id: str = os.getenv("HATCHET_CLIENT_TENANT_ID", "") + host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", "localhost:7070") + tls_config: ClientTLSConfig = _load_tls_config(host_port) + server_url: str = "https://app.dev.hatchet-tools.com" + namespace: str = os.getenv("HATCHET_CLIENT_NAMESPACE", "") + listener_v2_timeout: int | None = parse_listener_timeout( + os.getenv("HATCHET_CLIENT_LISTENER_V2_TIMEOUT") + ) + grpc_max_recv_message_length: int = int( + os.getenv("HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", 4 * 1024 * 1024) + ) # 4MB + grpc_max_send_message_length: int = int( + os.getenv("HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH", 4 * 1024 * 1024) + ) # 4MB + otel_exporter_oltp_endpoint: str | None = os.getenv( + "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT" + ) + otel_service_name: str | None = os.getenv("HATCHET_CLIENT_OTEL_SERVICE_NAME") + otel_exporter_oltp_headers: str | None = os.getenv( + "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS" + ) + otel_exporter_oltp_protocol: str | None = os.getenv( + "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL" + ) + worker_healthcheck_port: int = int( + os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT", 8001) + ) + 
worker_healthcheck_enabled: bool = ( + os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", "False") == "True" + ) + + @field_validator("token", mode="after") + @classmethod + def validate_token(cls, token: str) -> str: if not token: - raise ValueError( - "Token must be set via HATCHET_CLIENT_TOKEN environment variable" - ) - - host_port = get_config_value("hostPort", "HATCHET_CLIENT_HOST_PORT") - server_url: str | None = None + raise ValidationError("Token must be set") - grpc_max_recv_message_length = get_config_value( - "grpc_max_recv_message_length", - "HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", - ) - grpc_max_send_message_length = get_config_value( - "grpc_max_send_message_length", - "HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH", - ) + return token - if grpc_max_recv_message_length: - grpc_max_recv_message_length = int(grpc_max_recv_message_length) + @field_validator("namespace", mode="after") + @classmethod + def validate_namespace(cls, namespace: str) -> str: + if not namespace.endswith("_"): + namespace = f"{namespace}_" - if grpc_max_send_message_length: - grpc_max_send_message_length = int(grpc_max_send_message_length) + return namespace.lower() - if not host_port: - # extract host and port from token - server_url, grpc_broadcast_address = get_addresses_from_jwt(token) - host_port = grpc_broadcast_address + @field_validator("tenant_id", mode="after") + @classmethod + def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str: + token = cast(str | None, info.data.get("token")) if not tenant_id: - tenant_id = get_tenant_id_from_jwt(token) - - tls_config = self._load_tls_config(config_data["tls"], host_port) - - otel_exporter_oltp_endpoint = get_config_value( - "otel_exporter_oltp_endpoint", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT" - ) + if not token: + raise ValidationError( + "Token must be set before attempting to infer tenant ID" + ) - otel_service_name = get_config_value( - "otel_service_name", "HATCHET_CLIENT_OTEL_SERVICE_NAME" - 
) + return get_tenant_id_from_jwt(token) - _oltp_headers = get_config_value( - "otel_exporter_oltp_headers", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS" - ) + return tenant_id - if _oltp_headers: - try: - otel_header_key, api_key = _oltp_headers.split("=", maxsplit=1) - otel_exporter_oltp_headers = {otel_header_key: api_key} - except ValueError: - raise ValueError( - "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS must be in the format `key=value`" - ) - else: - otel_exporter_oltp_headers = None - - otel_exporter_oltp_protocol = get_config_value( - "otel_exporter_oltp_protocol", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL" - ) - - worker_healthcheck_port = int( - get_config_value( - "worker_healthcheck_port", "HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT" - ) - or 8001 - ) - - worker_healthcheck_enabled = ( - str( - get_config_value( - "worker_healthcheck_port", - "HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", - ) - ) - == "True" - ) - - return ClientConfig( - tenant_id=tenant_id, - tls_config=tls_config, - token=token, - host_port=host_port, - server_url=server_url, - namespace=namespace, - listener_v2_timeout=listener_v2_timeout, - logger=defaults.logInterceptor, - grpc_max_recv_message_length=grpc_max_recv_message_length, - grpc_max_send_message_length=grpc_max_send_message_length, - otel_exporter_oltp_endpoint=otel_exporter_oltp_endpoint, - otel_service_name=otel_service_name, - otel_exporter_oltp_headers=otel_exporter_oltp_headers, - otel_exporter_oltp_protocol=otel_exporter_oltp_protocol, - worker_healthcheck_port=worker_healthcheck_port, - worker_healthcheck_enabled=worker_healthcheck_enabled, - ) - - def _load_tls_config(self, tls_data: Dict, host_port) -> ClientTLSConfig: - tls_strategy = ( - tls_data["tlsStrategy"] - if "tlsStrategy" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_STRATEGY") - ) - - if not tls_strategy: - tls_strategy = "tls" - - cert_file = ( - tls_data["tlsCertFile"] - if "tlsCertFile" in tls_data - else 
self._get_env_var("HATCHET_CLIENT_TLS_CERT_FILE") - ) - key_file = ( - tls_data["tlsKeyFile"] - if "tlsKeyFile" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_KEY_FILE") - ) - ca_file = ( - tls_data["tlsRootCAFile"] - if "tlsRootCAFile" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_ROOT_CA_FILE") - ) - - server_name = ( - tls_data["tlsServerName"] - if "tlsServerName" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_SERVER_NAME") - ) - - # if server_name is not set, use the host from the host_port - if not server_name: - server_name = host_port.split(":")[0] - - return ClientTLSConfig(tls_strategy, cert_file, key_file, ca_file, server_name) - - @staticmethod - def _get_env_var(env_var: str, default: Optional[str] = None) -> str: - return os.environ.get(env_var, default) + ## TODO: Fix host port overrides here + ## Old code: + ## if not host_port: + ## ## extract host and port from token + ## server_url, grpc_broadcast_address = get_addresses_from_jwt(token) + ## host_port = grpc_broadcast_address \ No newline at end of file From 179b17862eeee1702800c0143bca41003034a894 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 13 Jan 2025 13:11:09 -0500 Subject: [PATCH 02/53] fix: tracing headers --- hatchet_sdk/utils/tracing.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/hatchet_sdk/utils/tracing.py b/hatchet_sdk/utils/tracing.py index afc398f7..72509f6f 100644 --- a/hatchet_sdk/utils/tracing.py +++ b/hatchet_sdk/utils/tracing.py @@ -16,6 +16,18 @@ OTEL_CARRIER_KEY = "__otel_carrier" +def parse_headers(headers: str | None) -> dict[str, str]: + if headers is None: + return {} + + try: + otel_header_key, api_key = headers.split("=", maxsplit=1) + + return {otel_header_key: api_key} + except ValueError: + raise ValueError("OTLP headers must be in the format `key=value`") + + @cache def create_tracer(config: ClientConfig) -> Tracer: ## TODO: Figure out how to specify protocol here @@ -27,7 +39,7 @@ def 
create_tracer(config: ClientConfig) -> Tracer: processor = BatchSpanProcessor( OTLPSpanExporter( endpoint=config.otel_exporter_oltp_endpoint, - headers=config.otel_exporter_oltp_headers, + headers=parse_headers(config.otel_exporter_oltp_headers), ), ) @@ -67,4 +79,4 @@ def parse_carrier_from_metadata(metadata: dict[str, Any] | None) -> Context | No TraceContextTextMapPropagator().extract(_ctx) if (_ctx := metadata.get(OTEL_CARRIER_KEY)) else None - ) + ) \ No newline at end of file From 98237216427e43488b8d0584b4756d797bc44643 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 13 Jan 2025 13:12:33 -0500 Subject: [PATCH 03/53] fix: refs to loader --- hatchet_sdk/client.py | 5 ++--- hatchet_sdk/hatchet.py | 8 ++------ pyproject.toml | 1 + 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/hatchet_sdk/client.py b/hatchet_sdk/client.py index 45dfd394..4e340388 100644 --- a/hatchet_sdk/client.py +++ b/hatchet_sdk/client.py @@ -37,11 +37,10 @@ def from_environment( loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - config: ClientConfig = ConfigLoader(".").load_client_config(defaults) for opt_function in opts_functions: - opt_function(config) + opt_function(defaults) - return cls.from_config(config, debug) + return cls.from_config(defaults, debug) @classmethod def from_config( diff --git a/hatchet_sdk/hatchet.py b/hatchet_sdk/hatchet.py index bf0e9089..e73dab64 100644 --- a/hatchet_sdk/hatchet.py +++ b/hatchet_sdk/hatchet.py @@ -16,7 +16,7 @@ from hatchet_sdk.features.cron import CronClient from hatchet_sdk.features.scheduled import ScheduledClient from hatchet_sdk.labels import DesiredWorkerLabel -from hatchet_sdk.loader import ClientConfig, ConfigLoader +from hatchet_sdk.loader import ClientConfig from hatchet_sdk.rate_limit import RateLimit from hatchet_sdk.v2.callable import HatchetCallable @@ -187,12 +187,8 @@ class HatchetRest: rest (RestApi): Interface for REST API operations. 
""" - rest: RestApi - def __init__(self, config: ClientConfig = ClientConfig()): - _config: ClientConfig = ConfigLoader(".").load_client_config(config) - self.rest = RestApi(_config.server_url, _config.token, _config.tenant_id) - + self.rest = RestApi(config.server_url, config.token, config.tenant_id) class Hatchet: """ diff --git a/pyproject.toml b/pyproject.toml index 69380b67..7207c430 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ files = [ "hatchet_sdk/clients/rest/models/workflow_run.py", "hatchet_sdk/context/worker_context.py", "hatchet_sdk/clients/dispatcher/dispatcher.py", + "hatchet_sdk/loader.py", ] follow_imports = "silent" disable_error_code = ["unused-coroutine"] From 898353612a29aaaa4478d42e0be6522aa8a79c18 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 13 Jan 2025 13:13:43 -0500 Subject: [PATCH 04/53] fix: lint --- .gitignore | 1 + hatchet_sdk/hatchet.py | 1 + hatchet_sdk/loader.py | 2 +- hatchet_sdk/utils/tracing.py | 2 +- 4 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 6b15a2af..a8fca96c 100644 --- a/.gitignore +++ b/.gitignore @@ -162,3 +162,4 @@ cython_debug/ #.idea/ openapitools.json +.python-version diff --git a/hatchet_sdk/hatchet.py b/hatchet_sdk/hatchet.py index e73dab64..f71e208d 100644 --- a/hatchet_sdk/hatchet.py +++ b/hatchet_sdk/hatchet.py @@ -190,6 +190,7 @@ class HatchetRest: def __init__(self, config: ClientConfig = ClientConfig()): self.rest = RestApi(config.server_url, config.token, config.tenant_id) + class Hatchet: """ Main client for interacting with the Hatchet SDK. 
diff --git a/hatchet_sdk/loader.py b/hatchet_sdk/loader.py index 7ec30f06..265aa4f0 100644 --- a/hatchet_sdk/loader.py +++ b/hatchet_sdk/loader.py @@ -104,4 +104,4 @@ def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str: ## if not host_port: ## ## extract host and port from token ## server_url, grpc_broadcast_address = get_addresses_from_jwt(token) - ## host_port = grpc_broadcast_address \ No newline at end of file + ## host_port = grpc_broadcast_address diff --git a/hatchet_sdk/utils/tracing.py b/hatchet_sdk/utils/tracing.py index 72509f6f..19341780 100644 --- a/hatchet_sdk/utils/tracing.py +++ b/hatchet_sdk/utils/tracing.py @@ -79,4 +79,4 @@ def parse_carrier_from_metadata(metadata: dict[str, Any] | None) -> Context | No TraceContextTextMapPropagator().extract(_ctx) if (_ctx := metadata.get(OTEL_CARRIER_KEY)) else None - ) \ No newline at end of file + ) From baa85671f71b75263e74731c709aa0ebcc75220c Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 13 Jan 2025 14:14:03 -0500 Subject: [PATCH 05/53] debug: host port and server validation --- hatchet_sdk/client.py | 4 +-- hatchet_sdk/loader.py | 67 +++++++++++++++++++++++++++++++++++-------- hatchet_sdk/token.py | 7 +++-- 3 files changed, 61 insertions(+), 17 deletions(-) diff --git a/hatchet_sdk/client.py b/hatchet_sdk/client.py index 4e340388..a956a1d5 100644 --- a/hatchet_sdk/client.py +++ b/hatchet_sdk/client.py @@ -12,7 +12,7 @@ from .clients.dispatcher.dispatcher import DispatcherClient, new_dispatcher from .clients.events import EventClient, new_event from .clients.rest_client import RestApi -from .loader import ClientConfig, ConfigLoader +from .loader import ClientConfig class Client: @@ -102,7 +102,7 @@ def __init__( self.config = config self.listener = RunEventListenerClient(config) self.workflow_listener = workflow_listener - self.logInterceptor = config.logInterceptor + self.logInterceptor = config.logger self.debug = debug diff --git a/hatchet_sdk/loader.py b/hatchet_sdk/loader.py 
index 265aa4f0..dc2453ab 100644 --- a/hatchet_sdk/loader.py +++ b/hatchet_sdk/loader.py @@ -1,10 +1,11 @@ +import json import os from logging import Logger, getLogger -from typing import cast +from typing import Any, cast -from pydantic import BaseModel, ValidationError, ValidationInfo, field_validator +from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator -from .token import get_tenant_id_from_jwt +from .token import get_addresses_from_jwt, get_tenant_id_from_jwt class ClientTLSConfig(BaseModel): @@ -15,15 +16,21 @@ class ClientTLSConfig(BaseModel): server_name: str -def _load_tls_config(host_port: str) -> ClientTLSConfig: +def _load_tls_config(host_port: str | None = None) -> ClientTLSConfig: + server_name = os.getenv("HATCHET_CLIENT_TLS_SERVER_NAME") + + if not server_name and host_port: + server_name = host_port.split(":")[0] + + if not server_name: + server_name = "localhost" + return ClientTLSConfig( tls_strategy=os.getenv("HATCHET_CLIENT_TLS_STRATEGY", "tls"), cert_file=os.getenv("HATCHET_CLIENT_TLS_CERT_FILE"), key_file=os.getenv("HATCHET_CLIENT_TLS_KEY_FILE"), ca_file=os.getenv("HATCHET_CLIENT_TLS_ROOT_CA_FILE"), - server_name=os.getenv( - "HATCHET_CLIENT_TLS_SERVER_NAME", host_port.split(":")[0] - ), + server_name=server_name, ) @@ -35,11 +42,13 @@ def parse_listener_timeout(timeout: str | None) -> int | None: class ClientConfig(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True, validate_default=True) + token: str = os.getenv("HATCHET_CLIENT_TOKEN", "") logger: Logger = getLogger() tenant_id: str = os.getenv("HATCHET_CLIENT_TENANT_ID", "") host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", "localhost:7070") - tls_config: ClientTLSConfig = _load_tls_config(host_port) + tls_config: ClientTLSConfig = _load_tls_config() server_url: str = "https://app.dev.hatchet-tools.com" namespace: str = os.getenv("HATCHET_CLIENT_NAMESPACE", "") listener_v2_timeout: int | None = parse_listener_timeout( @@ -72,7 +81,7 @@ class 
ClientConfig(BaseModel): @classmethod def validate_token(cls, token: str) -> str: if not token: - raise ValidationError("Token must be set") + return "" return token @@ -91,14 +100,48 @@ def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str: if not tenant_id: if not token: - raise ValidationError( - "Token must be set before attempting to infer tenant ID" - ) + return "" return get_tenant_id_from_jwt(token) return tenant_id + @field_validator("host_port", mode="after") + @classmethod + def validate_host_port(cls, host_port: str, info: ValidationInfo) -> str: + token = cast(str | None, info.data.get("token")) + + if not token: + return host_port + + _, grpc_broadcast_address = get_addresses_from_jwt(token) + + return grpc_broadcast_address + + @field_validator("server_url", mode="after") + @classmethod + def validate_server_url(cls, server_url: str, info: ValidationInfo) -> str: + token = cast(str | None, info.data.get("token")) + + if not token: + return server_url + + _server_url, _ = get_addresses_from_jwt(token) + + return _server_url + + @field_validator("tls_config", mode="after") + @classmethod + def validate_tls_config( + cls, tls_config: ClientTLSConfig, info: ValidationInfo + ) -> ClientTLSConfig: + host_port = cast(str, info.data.get("host_port")) + + return _load_tls_config(host_port) + + def __hash__(self) -> int: + return hash(json.dumps(self.model_dump(), default=str)) + ## TODO: Fix host port overrides here ## Old code: ## if not host_port: diff --git a/hatchet_sdk/token.py b/hatchet_sdk/token.py index 313a6671..0b539d76 100644 --- a/hatchet_sdk/token.py +++ b/hatchet_sdk/token.py @@ -1,5 +1,6 @@ import base64 import json +from typing import Any def get_tenant_id_from_jwt(token: str) -> str: @@ -8,13 +9,13 @@ def get_tenant_id_from_jwt(token: str) -> str: return claims.get("sub") -def get_addresses_from_jwt(token: str) -> (str, str): +def get_addresses_from_jwt(token: str) -> tuple[str, str]: claims = extract_claims_from_jwt(token) 
- return claims.get("server_url"), claims.get("grpc_broadcast_address") + return claims["server_url"], claims["grpc_broadcast_address"] -def extract_claims_from_jwt(token: str): +def extract_claims_from_jwt(token: str) -> dict[str, Any]: parts = token.split(".") if len(parts) != 3: raise ValueError("Invalid token format") From 618671a4e685bfa293bb03c5a51e0e3f156eddfd Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 13 Jan 2025 14:37:18 -0500 Subject: [PATCH 06/53] fix: raise errors if no token set --- hatchet_sdk/loader.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/hatchet_sdk/loader.py b/hatchet_sdk/loader.py index dc2453ab..374be36f 100644 --- a/hatchet_sdk/loader.py +++ b/hatchet_sdk/loader.py @@ -81,7 +81,7 @@ class ClientConfig(BaseModel): @classmethod def validate_token(cls, token: str) -> str: if not token: - return "" + raise ValueError("Token must be set") return token @@ -100,7 +100,7 @@ def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str: if not tenant_id: if not token: - return "" + raise ValueError("Either the token or tenant_id must be set") return get_tenant_id_from_jwt(token) @@ -109,10 +109,7 @@ def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str: @field_validator("host_port", mode="after") @classmethod def validate_host_port(cls, host_port: str, info: ValidationInfo) -> str: - token = cast(str | None, info.data.get("token")) - - if not token: - return host_port + token = cast(str, info.data.get("token")) _, grpc_broadcast_address = get_addresses_from_jwt(token) @@ -121,10 +118,7 @@ def validate_host_port(cls, host_port: str, info: ValidationInfo) -> str: @field_validator("server_url", mode="after") @classmethod def validate_server_url(cls, server_url: str, info: ValidationInfo) -> str: - token = cast(str | None, info.data.get("token")) - - if not token: - return server_url + token = cast(str, info.data.get("token")) _server_url, _ = get_addresses_from_jwt(token) 
From ed35943aa19cc0baabec7636a25efda7b3a4c31a Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 13 Jan 2025 16:38:32 -0500 Subject: [PATCH 07/53] fix: namespace prefixing --- hatchet_sdk/loader.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/hatchet_sdk/loader.py b/hatchet_sdk/loader.py index 374be36f..461b9214 100644 --- a/hatchet_sdk/loader.py +++ b/hatchet_sdk/loader.py @@ -88,6 +88,9 @@ def validate_token(cls, token: str) -> str: @field_validator("namespace", mode="after") @classmethod def validate_namespace(cls, namespace: str) -> str: + if not namespace: + return "" + if not namespace.endswith("_"): namespace = f"{namespace}_" From a66bcac40fdd8d39a469d3be4bec18899256ab2d Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 13 Jan 2025 16:50:50 -0500 Subject: [PATCH 08/53] cleanup: pydantic for parsing claims --- hatchet_sdk/token.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/hatchet_sdk/token.py b/hatchet_sdk/token.py index 0b539d76..58d34c65 100644 --- a/hatchet_sdk/token.py +++ b/hatchet_sdk/token.py @@ -1,21 +1,25 @@ import base64 -import json -from typing import Any +from pydantic import BaseModel -def get_tenant_id_from_jwt(token: str) -> str: - claims = extract_claims_from_jwt(token) - return claims.get("sub") +class Claims(BaseModel): + sub: str + server_url: str + grpc_broadcast_address: str + + +def get_tenant_id_from_jwt(token: str) -> str: + return extract_claims_from_jwt(token).sub def get_addresses_from_jwt(token: str) -> tuple[str, str]: claims = extract_claims_from_jwt(token) - return claims["server_url"], claims["grpc_broadcast_address"] + return claims.server_url, claims.grpc_broadcast_address -def extract_claims_from_jwt(token: str) -> dict[str, Any]: +def extract_claims_from_jwt(token: str) -> Claims: parts = token.split(".") if len(parts) != 3: raise ValueError("Invalid token format") @@ -23,6 +27,5 @@ def extract_claims_from_jwt(token: str) -> dict[str, Any]: claims_part = parts[1] 
claims_part += "=" * ((4 - len(claims_part) % 4) % 4) # Padding for base64 decoding claims_data = base64.urlsafe_b64decode(claims_part) - claims = json.loads(claims_data) - return claims + return Claims.model_validate_json(claims_data) From eada02205481464ba0ba5f319a0999d9f707fa0a Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 13 Jan 2025 16:51:19 -0500 Subject: [PATCH 09/53] feat: add token to mypy --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 7207c430..e061ac4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,6 +94,7 @@ files = [ "hatchet_sdk/context/worker_context.py", "hatchet_sdk/clients/dispatcher/dispatcher.py", "hatchet_sdk/loader.py", + "hatchet_sdk/token.py" ] follow_imports = "silent" disable_error_code = ["unused-coroutine"] From 085da4ae54ee7b858ad3c2cd7d23d490b7b41ba0 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 13 Jan 2025 17:02:07 -0500 Subject: [PATCH 10/53] fix: pythonic dict construction for otel headers --- hatchet_sdk/utils/tracing.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/hatchet_sdk/utils/tracing.py b/hatchet_sdk/utils/tracing.py index 19341780..634c3995 100644 --- a/hatchet_sdk/utils/tracing.py +++ b/hatchet_sdk/utils/tracing.py @@ -21,9 +21,7 @@ def parse_headers(headers: str | None) -> dict[str, str]: return {} try: - otel_header_key, api_key = headers.split("=", maxsplit=1) - - return {otel_header_key: api_key} + return dict([headers.split("=", maxsplit=1)]) except ValueError: raise ValueError("OTLP headers must be in the format `key=value`") From b30b40f1784c5d7137c2d84c84d120f5f23aea26 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Tue, 14 Jan 2025 10:37:40 -0500 Subject: [PATCH 11/53] cruft: any --- hatchet_sdk/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hatchet_sdk/loader.py b/hatchet_sdk/loader.py index 461b9214..658d71d1 100644 --- a/hatchet_sdk/loader.py +++ b/hatchet_sdk/loader.py @@ -1,7 +1,7 
@@ import json import os from logging import Logger, getLogger -from typing import Any, cast +from typing import cast from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator From e7740ef1d3732ba9d46947ae9d80a7e5469dc4f8 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Tue, 14 Jan 2025 10:41:29 -0500 Subject: [PATCH 12/53] feat: allow host_port overrides --- hatchet_sdk/loader.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/hatchet_sdk/loader.py b/hatchet_sdk/loader.py index 658d71d1..7fc4eb63 100644 --- a/hatchet_sdk/loader.py +++ b/hatchet_sdk/loader.py @@ -41,14 +41,22 @@ def parse_listener_timeout(timeout: str | None) -> int | None: return int(timeout) +DEFAULT_HOST_PORT = "localhost:7070" + + class ClientConfig(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, validate_default=True) token: str = os.getenv("HATCHET_CLIENT_TOKEN", "") logger: Logger = getLogger() tenant_id: str = os.getenv("HATCHET_CLIENT_TENANT_ID", "") - host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", "localhost:7070") + + ## IMPORTANT: Order matters here. The validators run in the order that the + ## fields are defined in the model. 
So, we need to make sure that the + ## host_port is set before we try to load the tls_config and server_url + host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", DEFAULT_HOST_PORT) tls_config: ClientTLSConfig = _load_tls_config() + server_url: str = "https://app.dev.hatchet-tools.com" namespace: str = os.getenv("HATCHET_CLIENT_NAMESPACE", "") listener_v2_timeout: int | None = parse_listener_timeout( @@ -112,6 +120,9 @@ def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str: @field_validator("host_port", mode="after") @classmethod def validate_host_port(cls, host_port: str, info: ValidationInfo) -> str: + if host_port and host_port != DEFAULT_HOST_PORT: + return host_port + token = cast(str, info.data.get("token")) _, grpc_broadcast_address = get_addresses_from_jwt(token) @@ -121,6 +132,14 @@ def validate_host_port(cls, host_port: str, info: ValidationInfo) -> str: @field_validator("server_url", mode="after") @classmethod def validate_server_url(cls, server_url: str, info: ValidationInfo) -> str: + ## IMPORTANT: Order matters here. The validators run in the order that the + ## fields are defined in the model. So, we need to make sure that the + ## host_port is set before we try to load the server_url + host_port = cast(str, info.data.get("host_port")) + + if host_port and host_port != DEFAULT_HOST_PORT: + return host_port + token = cast(str, info.data.get("token")) _server_url, _ = get_addresses_from_jwt(token) @@ -132,16 +151,12 @@ def validate_server_url(cls, server_url: str, info: ValidationInfo) -> str: def validate_tls_config( cls, tls_config: ClientTLSConfig, info: ValidationInfo ) -> ClientTLSConfig: + ## IMPORTANT: Order matters here. This validator runs in the order + ## that the fields are defined in the model. 
So, we need to make sure + ## that the host_port is set before we try to load the tls_config host_port = cast(str, info.data.get("host_port")) return _load_tls_config(host_port) def __hash__(self) -> int: return hash(json.dumps(self.model_dump(), default=str)) - - ## TODO: Fix host port overrides here - ## Old code: - ## if not host_port: - ## ## extract host and port from token - ## server_url, grpc_broadcast_address = get_addresses_from_jwt(token) - ## host_port = grpc_broadcast_address From a6d6a45619977ecedcc6aa27c5d09284f3b4bdb0 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Tue, 14 Jan 2025 11:30:29 -0500 Subject: [PATCH 13/53] feat: tests --- conftest.py | 1 + hatchet_sdk/loader.py | 8 +++- poetry.lock | 100 ++++++++++++++++++++++++++++++++++++++++-- pyproject.toml | 4 ++ tests/test_client.py | 17 +++++++ 5 files changed, 126 insertions(+), 4 deletions(-) create mode 100644 tests/test_client.py diff --git a/conftest.py b/conftest.py index 2aff5cd3..acd22dd8 100644 --- a/conftest.py +++ b/conftest.py @@ -1,4 +1,5 @@ import logging +import os import subprocess import time from io import BytesIO diff --git a/hatchet_sdk/loader.py b/hatchet_sdk/loader.py index 7fc4eb63..deda2348 100644 --- a/hatchet_sdk/loader.py +++ b/hatchet_sdk/loader.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator -from .token import get_addresses_from_jwt, get_tenant_id_from_jwt +from hatchet_sdk.token import get_addresses_from_jwt, get_tenant_id_from_jwt class ClientTLSConfig(BaseModel): @@ -125,6 +125,9 @@ def validate_host_port(cls, host_port: str, info: ValidationInfo) -> str: token = cast(str, info.data.get("token")) + if not token: + raise ValueError("Token must be set") + _, grpc_broadcast_address = get_addresses_from_jwt(token) return grpc_broadcast_address @@ -142,6 +145,9 @@ def validate_server_url(cls, server_url: str, info: ValidationInfo) -> str: token = cast(str, info.data.get("token")) + if not token: + raise ValueError("Token 
must be set") + _server_url, _ = get_addresses_from_jwt(token) return _server_url diff --git a/poetry.lock b/poetry.lock index 603caef7..7e63063b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.0.0 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -6,6 +6,7 @@ version = "2.4.4" description = "Happy Eyeballs for asyncio" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "aiohappyeyeballs-2.4.4-py3-none-any.whl", hash = "sha256:a980909d50efcd44795c4afeca523296716d50cd756ddca6af8c65b996e27de8"}, {file = "aiohappyeyeballs-2.4.4.tar.gz", hash = "sha256:5fdd7d87889c63183afc18ce9271f9b0a7d32c2303e394468dd45d514a757745"}, @@ -17,6 +18,7 @@ version = "3.11.11" description = "Async http client/server framework (asyncio)" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "aiohttp-3.11.11-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a60804bff28662cbcf340a4d61598891f12eea3a66af48ecfdc975ceec21e3c8"}, {file = "aiohttp-3.11.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4b4fa1cb5f270fb3eab079536b764ad740bb749ce69a94d4ec30ceee1b5940d5"}, @@ -115,6 +117,7 @@ version = "2.9.1" description = "Simple retry client for aiohttp" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "aiohttp_retry-2.9.1-py3-none-any.whl", hash = "sha256:66d2759d1921838256a05a3f80ad7e724936f083e35be5abb5e16eed6be6dc54"}, {file = "aiohttp_retry-2.9.1.tar.gz", hash = "sha256:8eb75e904ed4ee5c2ec242fefe85bf04240f685391c4879d8f541d6028ff01f1"}, @@ -129,6 +132,7 @@ version = "1.3.2" description = "aiosignal: a list of registered asynchronous callbacks" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "aiosignal-1.3.2-py2.py3-none-any.whl", hash = 
"sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5"}, {file = "aiosignal-1.3.2.tar.gz", hash = "sha256:a8c255c66fafb1e499c9351d0bf32ff2d8a0321595ebac3b93713656d2436f54"}, @@ -143,6 +147,7 @@ version = "0.5.2" description = "Generator-based operators for asynchronous iteration" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "aiostream-0.5.2-py3-none-any.whl", hash = "sha256:054660370be9d37f6fe3ece3851009240416bd082e469fd90cc8673d3818cf71"}, {file = "aiostream-0.5.2.tar.gz", hash = "sha256:b71b519a2d66c38f0872403ab86417955b77352f08d9ad02ad46fc3926b389f4"}, @@ -157,6 +162,7 @@ version = "0.7.0" description = "Reusable constraint types to use with typing.Annotated" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, @@ -168,6 +174,8 @@ version = "5.0.1" description = "Timeout context manager for asyncio programs" optional = false python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.11\"" files = [ {file = "async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c"}, {file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"}, @@ -179,6 +187,7 @@ version = "24.3.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "attrs-24.3.0-py3-none-any.whl", hash = "sha256:ac96cd038792094f438ad1f6ff80837353805ac950cd2aa0e0625ef19850c308"}, {file = "attrs-24.3.0.tar.gz", hash = "sha256:8f5c07333d543103541ba7be0e2ce16eeee8130cb0b3f9238ab904ce1e85baff"}, @@ -198,6 +207,7 @@ version = "2.16.0" description = "Internationalization utilities" 
optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "babel-2.16.0-py3-none-any.whl", hash = "sha256:368b5b98b37c06b7daf6696391c3240c938b37767d4584413e8438c5c435fa8b"}, {file = "babel-2.16.0.tar.gz", hash = "sha256:d1f3554ca26605fe173f3de0c65f750f5a42f924499bf134de6423582298e316"}, @@ -212,6 +222,7 @@ version = "24.10.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.9" +groups = ["lint"] files = [ {file = "black-24.10.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e6668650ea4b685440857138e5fe40cde4d652633b1bdffc62933d0db4ed9812"}, {file = "black-24.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1c536fcf674217e87b8cc3657b81809d3c085d7bf3ef262ead700da345bfa6ea"}, @@ -258,6 +269,7 @@ version = "0.1.5" description = "Pure Python CEL Implementation" optional = false python-versions = ">=3.7, <4" +groups = ["main"] files = [ {file = "cel-python-0.1.5.tar.gz", hash = "sha256:d3911bb046bc3ed12792bd88ab453f72d98c66923b72a2fa016bcdffd96e2f98"}, {file = "cel_python-0.1.5-py3-none-any.whl", hash = "sha256:ac81fab8ba08b633700a45d84905be2863529c6a32935c9da7ef53fc06844f1a"}, @@ -278,6 +290,7 @@ version = "2024.12.14" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "certifi-2024.12.14-py3-none-any.whl", hash = "sha256:1275f7a45be9464efc1173084eaa30f866fe2e47d389406136d332ed4967ec56"}, {file = "certifi-2024.12.14.tar.gz", hash = "sha256:b650d30f370c2b724812bee08008be0c4163b163ddaec3f2546c1caf65f191db"}, @@ -289,6 +302,7 @@ version = "3.4.1" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 
optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de"}, {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176"}, @@ -390,6 +404,7 @@ version = "8.1.8" description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" +groups = ["lint"] files = [ {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, @@ -404,10 +419,12 @@ version = "0.4.6" description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["main", "dev", "lint", "test"] files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +markers = {main = "sys_platform == \"win32\"", dev = "sys_platform == \"win32\"", lint = "platform_system == \"Windows\"", test = "sys_platform == \"win32\""} [[package]] name = "deprecated" @@ -415,6 +432,7 @@ version = "1.2.15" description = "Python @deprecated decorator to deprecate old python classes, functions or methods." 
optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" +groups = ["main"] files = [ {file = "Deprecated-1.2.15-py2.py3-none-any.whl", hash = "sha256:353bc4a8ac4bfc96800ddab349d89c25dec1079f65fd53acdcc1e0b975b21320"}, {file = "deprecated-1.2.15.tar.gz", hash = "sha256:683e561a90de76239796e6b6feac66b99030d2dd3fcf61ef996330f14bbb9b0d"}, @@ -432,6 +450,8 @@ version = "1.2.2" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" +groups = ["dev", "test"] +markers = "python_version < \"3.11\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, @@ -446,6 +466,7 @@ version = "1.5.0" description = "A list-like structure which implements collections.abc.MutableSequence" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5b6a66c18b5b9dd261ca98dffcb826a525334b2f29e7caa54e182255c5f6a65a"}, {file = "frozenlist-1.5.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d1b3eb7b05ea246510b43a7e53ed1653e55c2121019a97e60cad7efb881a97bb"}, @@ -547,6 +568,7 @@ version = "1.66.0" description = "Common protobufs used in Google APIs" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "googleapis_common_protos-1.66.0-py2.py3-none-any.whl", hash = "sha256:d7abcd75fabb2e0ec9f74466401f6c119a0b498e27370e9be4c94cb7e382b8ed"}, {file = "googleapis_common_protos-1.66.0.tar.gz", hash = "sha256:c3e7b33d15fdca5374cc0a7346dd92ffa847425cc4ea941d970f13680052ec8c"}, @@ -564,6 +586,7 @@ version = "1.69.0" description = "HTTP/2-based RPC framework" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "grpcio-1.69.0-cp310-cp310-linux_armv7l.whl", hash = 
"sha256:2060ca95a8db295ae828d0fc1c7f38fb26ccd5edf9aa51a0f44251f5da332e97"}, {file = "grpcio-1.69.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:2e52e107261fd8fa8fa457fe44bfadb904ae869d87c1280bf60f93ecd3e79278"}, @@ -631,6 +654,7 @@ version = "1.69.0" description = "Protobuf code generator for gRPC" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "grpcio_tools-1.69.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:8c210630faa581c3bd08953dac4ad21a7f49862f3b92d69686e9b436d2f1265d"}, {file = "grpcio_tools-1.69.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:09b66ea279fcdaebae4ec34b1baf7577af3b14322738aa980c1c33cfea71f7d7"}, @@ -700,6 +724,7 @@ version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -714,6 +739,7 @@ version = "8.5.0" description = "Read metadata from Python packages" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "importlib_metadata-8.5.0-py3-none-any.whl", hash = "sha256:45e54197d28b7a7f1559e60b95e7c567032b602131fbd588f1497f47880aa68b"}, {file = "importlib_metadata-8.5.0.tar.gz", hash = "sha256:71522656f0abace1d072b9e5481a48f07c138e00f079c38c8f883823f9c26bd7"}, @@ -737,6 +763,7 @@ version = "2.0.0" description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.7" +groups = ["dev", "test"] files = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, @@ -748,6 +775,7 @@ version = "5.13.2" description = "A Python utility / 
library to sort Python imports." optional = false python-versions = ">=3.8.0" +groups = ["lint"] files = [ {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, @@ -762,6 +790,7 @@ version = "1.0.1" description = "JSON Matching Expressions" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, @@ -773,6 +802,7 @@ version = "0.12.0" description = "a modern parsing library" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "lark-parser-0.12.0.tar.gz", hash = "sha256:15967db1f1214013dca65b1180745047b9be457d73da224fcda3d9dd4e96a138"}, {file = "lark_parser-0.12.0-py2.py3-none-any.whl", hash = "sha256:0eaf30cb5ba787fe404d73a7d6e61df97b21d5a63ac26c5008c78a494373c675"}, @@ -789,6 +819,7 @@ version = "0.7.3" description = "Python logging made (stupidly) simple" optional = false python-versions = "<4.0,>=3.5" +groups = ["main"] files = [ {file = "loguru-0.7.3-py3-none-any.whl", hash = "sha256:31a33c10c8e1e10422bfd431aeb5d351c7cf7fa671e3c4df004162264b28220c"}, {file = "loguru-0.7.3.tar.gz", hash = "sha256:19480589e77d47b8d85b2c827ad95d49bf31b0dcde16593892eb51dd18706eb6"}, @@ -807,6 +838,7 @@ version = "6.1.0" description = "multidict implementation" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "multidict-6.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3380252550e372e8511d49481bd836264c009adb826b23fefcc5dd3c69692f60"}, {file = "multidict-6.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:99f826cbf970077383d7de805c0681799491cb939c25450b9b5b3ced03ca99f1"}, @@ 
-911,6 +943,7 @@ version = "1.14.1" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" +groups = ["lint"] files = [ {file = "mypy-1.14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:52686e37cf13d559f668aa398dd7ddf1f92c5d613e4f8cb262be2fb4fedb0fcb"}, {file = "mypy-1.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1fb545ca340537d4b45d3eecdb3def05e913299ca72c290326be19b3804b39c0"}, @@ -970,6 +1003,7 @@ version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." optional = false python-versions = ">=3.5" +groups = ["lint"] files = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, @@ -981,6 +1015,7 @@ version = "1.6.0" description = "Patch asyncio to allow nested event loops" optional = false python-versions = ">=3.5" +groups = ["main"] files = [ {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"}, {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, @@ -992,6 +1027,7 @@ version = "1.29.0" description = "OpenTelemetry Python API" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_api-1.29.0-py3-none-any.whl", hash = "sha256:5fcd94c4141cc49c736271f3e1efb777bebe9cc535759c54c936cca4f1b312b8"}, {file = "opentelemetry_api-1.29.0.tar.gz", hash = "sha256:d04a6cf78aad09614f52964ecb38021e248f5714dc32c2e0d8fd99517b4d69cf"}, @@ -1007,6 +1043,7 @@ version = "0.50b0" description = "OpenTelemetry Python Distro" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_distro-0.50b0-py3-none-any.whl", hash = 
"sha256:5fa2e2a99a047ea477fab53e73fb8088b907bda141e8440745b92eb2a84d74aa"}, {file = "opentelemetry_distro-0.50b0.tar.gz", hash = "sha256:3e059e00f53553ebd646d1162d1d3edf5d7c6d3ceafd54a49e74c90dc1c39a7d"}, @@ -1026,6 +1063,7 @@ version = "1.29.0" description = "OpenTelemetry Collector Exporters" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_exporter_otlp-1.29.0-py3-none-any.whl", hash = "sha256:b8da6e20f5b0ffe604154b1e16a407eade17ce310c42fb85bb4e1246fc3688ad"}, {file = "opentelemetry_exporter_otlp-1.29.0.tar.gz", hash = "sha256:ee7dfcccbb5e87ad9b389908452e10b7beeab55f70a83f41ce5b8c4efbde6544"}, @@ -1041,6 +1079,7 @@ version = "1.29.0" description = "OpenTelemetry Protobuf encoding" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_exporter_otlp_proto_common-1.29.0-py3-none-any.whl", hash = "sha256:a9d7376c06b4da9cf350677bcddb9618ed4b8255c3f6476975f5e38274ecd3aa"}, {file = "opentelemetry_exporter_otlp_proto_common-1.29.0.tar.gz", hash = "sha256:e7c39b5dbd1b78fe199e40ddfe477e6983cb61aa74ba836df09c3869a3e3e163"}, @@ -1055,6 +1094,7 @@ version = "1.29.0" description = "OpenTelemetry Collector Protobuf over gRPC Exporter" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_exporter_otlp_proto_grpc-1.29.0-py3-none-any.whl", hash = "sha256:5a2a3a741a2543ed162676cf3eefc2b4150e6f4f0a193187afb0d0e65039c69c"}, {file = "opentelemetry_exporter_otlp_proto_grpc-1.29.0.tar.gz", hash = "sha256:3d324d07d64574d72ed178698de3d717f62a059a93b6b7685ee3e303384e73ea"}, @@ -1075,6 +1115,7 @@ version = "1.29.0" description = "OpenTelemetry Collector Protobuf over HTTP Exporter" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_exporter_otlp_proto_http-1.29.0-py3-none-any.whl", hash = "sha256:b228bdc0f0cfab82eeea834a7f0ffdd2a258b26aa33d89fb426c29e8e934d9d0"}, {file = 
"opentelemetry_exporter_otlp_proto_http-1.29.0.tar.gz", hash = "sha256:b10d174e3189716f49d386d66361fbcf6f2b9ad81e05404acdee3f65c8214204"}, @@ -1095,6 +1136,7 @@ version = "0.50b0" description = "Instrumentation Tools & Auto Instrumentation for OpenTelemetry Python" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_instrumentation-0.50b0-py3-none-any.whl", hash = "sha256:b8f9fc8812de36e1c6dffa5bfc6224df258841fb387b6dfe5df15099daa10630"}, {file = "opentelemetry_instrumentation-0.50b0.tar.gz", hash = "sha256:7d98af72de8dec5323e5202e46122e5f908592b22c6d24733aad619f07d82979"}, @@ -1112,6 +1154,7 @@ version = "1.29.0" description = "OpenTelemetry Python Proto" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_proto-1.29.0-py3-none-any.whl", hash = "sha256:495069c6f5495cbf732501cdcd3b7f60fda2b9d3d4255706ca99b7ca8dec53ff"}, {file = "opentelemetry_proto-1.29.0.tar.gz", hash = "sha256:3c136aa293782e9b44978c738fff72877a4b78b5d21a64e879898db7b2d93e5d"}, @@ -1126,6 +1169,7 @@ version = "1.29.0" description = "OpenTelemetry Python SDK" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_sdk-1.29.0-py3-none-any.whl", hash = "sha256:173be3b5d3f8f7d671f20ea37056710217959e774e2749d984355d1f9391a30a"}, {file = "opentelemetry_sdk-1.29.0.tar.gz", hash = "sha256:b0787ce6aade6ab84315302e72bd7a7f2f014b0fb1b7c3295b88afe014ed0643"}, @@ -1142,6 +1186,7 @@ version = "0.50b0" description = "OpenTelemetry Semantic Conventions" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "opentelemetry_semantic_conventions-0.50b0-py3-none-any.whl", hash = "sha256:e87efba8fdb67fb38113efea6a349531e75ed7ffc01562f65b802fcecb5e115e"}, {file = "opentelemetry_semantic_conventions-0.50b0.tar.gz", hash = "sha256:02dc6dbcb62f082de9b877ff19a3f1ffaa3c306300fa53bfac761c4567c83d38"}, @@ -1157,6 +1202,7 @@ version = "24.2" description = "Core utilities for 
Python packages" optional = false python-versions = ">=3.8" +groups = ["main", "dev", "lint", "test"] files = [ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, @@ -1168,6 +1214,7 @@ version = "0.12.1" description = "Utility library for gitignore style pattern matching of file paths." optional = false python-versions = ">=3.8" +groups = ["lint"] files = [ {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, @@ -1179,6 +1226,7 @@ version = "4.3.6" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." optional = false python-versions = ">=3.8" +groups = ["lint"] files = [ {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, @@ -1195,6 +1243,7 @@ version = "1.5.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" +groups = ["dev", "test"] files = [ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, @@ -1210,6 +1259,7 @@ version = "0.21.1" description = "Python client for the Prometheus monitoring system." 
optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "prometheus_client-0.21.1-py3-none-any.whl", hash = "sha256:594b45c410d6f4f8888940fe80b5cc2521b305a1fafe1c58609ef715a001f301"}, {file = "prometheus_client-0.21.1.tar.gz", hash = "sha256:252505a722ac04b0456be05c05f75f45d760c2911ffc45f2a06bcaed9f3ae3fb"}, @@ -1224,6 +1274,7 @@ version = "0.2.1" description = "Accelerated property cache" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "propcache-0.2.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:6b3f39a85d671436ee3d12c017f8fdea38509e4f25b28eb25877293c98c243f6"}, {file = "propcache-0.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:39d51fbe4285d5db5d92a929e3e21536ea3dd43732c5b177c7ef03f918dff9f2"}, @@ -1315,6 +1366,7 @@ version = "5.29.2" description = "" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "protobuf-5.29.2-cp310-abi3-win32.whl", hash = "sha256:c12ba8249f5624300cf51c3d0bfe5be71a60c63e4dcf51ffe9a68771d958c851"}, {file = "protobuf-5.29.2-cp310-abi3-win_amd64.whl", hash = "sha256:842de6d9241134a973aab719ab42b008a18a90f9f07f06ba480df268f86432f9"}, @@ -1335,6 +1387,7 @@ version = "6.1.1" description = "Cross-platform lib for process and system monitoring in Python." 
optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +groups = ["dev"] files = [ {file = "psutil-6.1.1-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:9ccc4316f24409159897799b83004cb1e24f9819b0dcf9c0b68bdcb6cefee6a8"}, {file = "psutil-6.1.1-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:ca9609c77ea3b8481ab005da74ed894035936223422dc591d6772b147421f777"}, @@ -1365,6 +1418,7 @@ version = "2.10.4" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pydantic-2.10.4-py3-none-any.whl", hash = "sha256:597e135ea68be3a37552fb524bc7d0d66dcf93d395acd93a00682f1efcb8ee3d"}, {file = "pydantic-2.10.4.tar.gz", hash = "sha256:82f12e9723da6de4fe2ba888b5971157b3be7ad914267dea8f05f82b28254f06"}, @@ -1385,6 +1439,7 @@ version = "2.27.2" description = "Core functionality for Pydantic validation and serialization" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pydantic_core-2.27.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2d367ca20b2f14095a8f4fa1210f5a7b78b8a20009ecced6b12818f455b1e9fa"}, {file = "pydantic_core-2.27.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:491a2b73db93fab69731eaee494f320faa4e093dbed776be1a829c2eb222c34c"}, @@ -1497,6 +1552,7 @@ version = "8.3.4" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" +groups = ["dev", "test"] files = [ {file = "pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6"}, {file = "pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761"}, @@ -1519,6 +1575,7 @@ version = "0.23.8" description = "Pytest support for asyncio" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pytest_asyncio-0.23.8-py3-none-any.whl", hash = 
"sha256:50265d892689a5faefb84df80819d1ecef566eb3549cf915dfb33569359d1ce2"}, {file = "pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3"}, @@ -1531,12 +1588,32 @@ pytest = ">=7.0.0,<9" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] +[[package]] +name = "pytest-env" +version = "1.1.5" +description = "pytest plugin that allows you to add environment variables." +optional = false +python-versions = ">=3.8" +groups = ["test"] +files = [ + {file = "pytest_env-1.1.5-py3-none-any.whl", hash = "sha256:ce90cf8772878515c24b31cd97c7fa1f4481cd68d588419fd45f10ecaee6bc30"}, + {file = "pytest_env-1.1.5.tar.gz", hash = "sha256:91209840aa0e43385073ac464a554ad2947cc2fd663a9debf88d03b01e0cc1cf"}, +] + +[package.dependencies] +pytest = ">=8.3.3" +tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "pytest-mock (>=3.14)"] + [[package]] name = "pytest-timeout" version = "2.3.1" description = "pytest plugin to abort hanging tests" optional = false python-versions = ">=3.7" +groups = ["test"] files = [ {file = "pytest-timeout-2.3.1.tar.gz", hash = "sha256:12397729125c6ecbdaca01035b9e5239d4db97352320af155b3f5de1ba5165d9"}, {file = "pytest_timeout-2.3.1-py3-none-any.whl", hash = "sha256:68188cb703edfc6a18fad98dc25a3c61e9f24d644b0b70f33af545219fc7813e"}, @@ -1551,6 +1628,7 @@ version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] files = [ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, @@ -1565,6 +1643,7 @@ version = "1.0.1" description = 
"Read key-value pairs from a .env file and set them as environment variables" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, @@ -1579,6 +1658,7 @@ version = "6.0.2" description = "YAML parser and emitter for Python" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, @@ -1641,6 +1721,7 @@ version = "2.32.3" description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, @@ -1662,6 +1743,7 @@ version = "75.7.0" description = "Easily download, build, install, upgrade, and uninstall Python packages" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "setuptools-75.7.0-py3-none-any.whl", hash = "sha256:84fb203f278ebcf5cd08f97d3fb96d3fbed4b629d500b29ad60d11e00769b183"}, {file = "setuptools-75.7.0.tar.gz", hash = "sha256:886ff7b16cd342f1d1defc16fc98c9ce3fde69e087a4e1983d7ab634e5f41f4f"}, @@ -1682,6 +1764,7 @@ version = "1.17.0" description = "Python 2 and 3 compatibility utilities" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = 
"sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -1693,6 +1776,7 @@ version = "9.0.0" description = "Retry code until it succeeds" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"}, {file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"}, @@ -1708,6 +1792,8 @@ version = "2.2.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" +groups = ["dev", "lint", "test"] +markers = "python_version < \"3.11\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, @@ -1749,6 +1835,7 @@ version = "5.29.1.20241207" description = "Typing stubs for protobuf" optional = false python-versions = ">=3.8" +groups = ["lint"] files = [ {file = "types_protobuf-5.29.1.20241207-py3-none-any.whl", hash = "sha256:92893c42083e9b718c678badc0af7a9a1307b92afe1599e5cba5f3d35b668b2f"}, {file = "types_protobuf-5.29.1.20241207.tar.gz", hash = "sha256:2ebcadb8ab3ef2e3e2f067e0882906d64ba0dc65fc5b0fd7a8b692315b4a0be9"}, @@ -1760,6 +1847,7 @@ version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" +groups = ["main", "lint"] files = [ {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, @@ -1771,6 +1859,7 @@ version = 
"2.3.0" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df"}, {file = "urllib3-2.3.0.tar.gz", hash = "sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d"}, @@ -1788,6 +1877,8 @@ version = "1.2.0" description = "A small Python utility to set file creation time on Windows" optional = false python-versions = ">=3.5" +groups = ["main"] +markers = "sys_platform == \"win32\"" files = [ {file = "win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390"}, {file = "win32_setctime-1.2.0.tar.gz", hash = "sha256:ae1fdf948f5640aae05c511ade119313fb6a30d7eabe25fef9764dca5873c4c0"}, @@ -1802,6 +1893,7 @@ version = "1.17.0" description = "Module for decorators, wrappers and monkey patching." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "wrapt-1.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2a0c23b8319848426f305f9cb0c98a6e32ee68a36264f45948ccf8e7d2b941f8"}, {file = "wrapt-1.17.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1ca5f060e205f72bec57faae5bd817a1560fcfc4af03f414b08fa29106b7e2d"}, @@ -1876,6 +1968,7 @@ version = "1.18.3" description = "Yet another URL library" optional = false python-versions = ">=3.9" +groups = ["main"] files = [ {file = "yarl-1.18.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7df647e8edd71f000a5208fe6ff8c382a1de8edfbccdbbfe649d263de07d8c34"}, {file = "yarl-1.18.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c69697d3adff5aa4f874b19c0e4ed65180ceed6318ec856ebc423aa5850d84f7"}, @@ -1972,6 +2065,7 @@ version = "3.21.0" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.9" +groups = ["main"] 
files = [ {file = "zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931"}, {file = "zipp-3.21.0.tar.gz", hash = "sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4"}, @@ -1986,6 +2080,6 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", type = ["pytest-mypy"] [metadata] -lock-version = "2.0" +lock-version = "2.1" python-versions = "^3.10" -content-hash = "414d63b255f80d13260cb3a9ecce29f782af46280bba79395554595a47c42f05" +content-hash = "a51a43e75624789a2044790c6543dacf55f3b8ce2140c9dda1d1300d643df877" diff --git a/pyproject.toml b/pyproject.toml index e061ac4d..60b29c09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ isort = "^5.13.2" [tool.poetry.group.test.dependencies] pytest-timeout = "^2.3.1" +pytest-env = "^1.1.5" [build-system] requires = ["poetry-core"] @@ -57,6 +58,9 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] log_cli = true +env = [ + "HATCHET_CLIENT_TOKEN=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJodHRwOi8vbG9jYWxob3N0OjEyMzQiLCJleHAiOjk5OTk5OTk5OTksImdycGNfYnJvYWRjYXN0X2FkZHJlc3MiOiJodHRwOi8vbG9jYWxob3N0OjQ0MyIsImlhdCI6MTIzNDU2Nzg5MSwiaXNzIjoiaHR0cDovL2xvY2FsaG9zdDoxMjM0Iiwic2VydmVyX3VybCI6Imh0dHA6Ly9sb2NhbGhvc3Q6MTIzNCIsInN1YiI6IjAwMDAwMDAwLTVmN2QtNGM1NS1iZmEzLWFkZDk0MTc4YjhmNyIsInRva2VuX2lkIjoiMDAwMDAwMDAtZmU5ZS00ZGEyLThmOTgtNTQ5YTgxOWRmZTE5In0.bIly53KfKcXP_7wjySWvbmxG9cVqit-fzVQAF5K7rPc", +] [tool.isort] profile = "black" diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 00000000..f72571f5 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,17 @@ +import os + +from hatchet_sdk.loader import DEFAULT_HOST_PORT, ClientConfig + + +def test_client_initialization_from_defaults() -> None: + assert isinstance(ClientConfig(), ClientConfig) + + +def test_client_host_port_overrides() -> None: + host_port = "localhost:8080" + with_host_port = 
ClientConfig(host_port=host_port) + assert with_host_port.host_port == host_port + assert with_host_port.server_url == host_port + + assert ClientConfig().host_port != host_port + assert ClientConfig().server_url != host_port From e0e52b965aec3a08a1a55fbd44f73f99edcf2620 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Tue, 14 Jan 2025 14:31:43 -0500 Subject: [PATCH 14/53] feat: rm typed dicts in a bunch of places in favor of pydantic --- examples/affinity-workers/event.py | 3 +- examples/affinity-workers/worker.py | 13 ++-- examples/bulk_fanout/stream.py | 2 +- examples/bulk_fanout/worker.py | 23 ++++--- examples/dedupe/worker.py | 6 +- .../durable_sticky_with_affinity/worker.py | 30 +++++---- examples/events/test_event.py | 38 ++++++------ examples/fanout/stream.py | 2 +- examples/fanout/sync_stream.py | 2 +- examples/fanout/worker.py | 6 +- examples/simple/event.py | 32 +++++----- examples/sticky_workers/event.py | 3 +- examples/sticky_workers/worker.py | 4 +- examples/sync_to_async/worker.py | 5 +- hatchet_sdk/clients/admin.py | 40 ++++++------ .../clients/dispatcher/action_listener.py | 3 +- hatchet_sdk/clients/events.py | 12 ++-- hatchet_sdk/clients/rest_client.py | 9 +-- hatchet_sdk/context/context.py | 61 ++++++------------- hatchet_sdk/features/cron.py | 7 ++- hatchet_sdk/features/scheduled.py | 7 ++- hatchet_sdk/hatchet.py | 11 ++-- hatchet_sdk/labels.py | 12 ++-- hatchet_sdk/utils/types.py | 3 + hatchet_sdk/workflow_run.py | 2 +- 25 files changed, 170 insertions(+), 166 deletions(-) diff --git a/examples/affinity-workers/event.py b/examples/affinity-workers/event.py index 3d4cae41..6b01a724 100644 --- a/examples/affinity-workers/event.py +++ b/examples/affinity-workers/event.py @@ -1,5 +1,6 @@ from dotenv import load_dotenv +from hatchet_sdk.clients.events import PushEventOptions from hatchet_sdk.hatchet import Hatchet load_dotenv() @@ -9,5 +10,5 @@ hatchet.event.push( "affinity:run", {"test": "test"}, - options={"additional_metadata": {"hello": "moon"}}, + 
options=PushEventOptions(additional_metadata={"hello": "moon"}), ) diff --git a/examples/affinity-workers/worker.py b/examples/affinity-workers/worker.py index 5099a804..ea721c9d 100644 --- a/examples/affinity-workers/worker.py +++ b/examples/affinity-workers/worker.py @@ -1,6 +1,7 @@ from dotenv import load_dotenv from hatchet_sdk import Context, Hatchet, WorkerLabelComparator +from hatchet_sdk.labels import DesiredWorkerLabel load_dotenv() @@ -11,12 +12,12 @@ class AffinityWorkflow: @hatchet.step( desired_worker_labels={ - "model": {"value": "fancy-ai-model-v2", "weight": 10}, - "memory": { - "value": 256, - "required": True, - "comparator": WorkerLabelComparator.LESS_THAN, - }, + "model": DesiredWorkerLabel(value="fancy-ai-model-v2", weight=10), + "memory": DesiredWorkerLabel( + value=256, + required=True, + comparator=WorkerLabelComparator.LESS_THAN, + ), }, ) async def step(self, context: Context) -> dict[str, str | None]: diff --git a/examples/bulk_fanout/stream.py b/examples/bulk_fanout/stream.py index 08d0cb4a..2eb03648 100644 --- a/examples/bulk_fanout/stream.py +++ b/examples/bulk_fanout/stream.py @@ -31,7 +31,7 @@ async def main() -> None: workflowRun = hatchet.admin.run_workflow( "Parent", {"n": 2}, - options={"additional_metadata": {streamKey: streamVal}}, + options=TriggerWorkflowOptions(additional_metadata={streamKey: streamVal}), ) # Stream all events for the additional meta key value diff --git a/examples/bulk_fanout/worker.py b/examples/bulk_fanout/worker.py index e0ea3c50..baf9953b 100644 --- a/examples/bulk_fanout/worker.py +++ b/examples/bulk_fanout/worker.py @@ -4,7 +4,7 @@ from dotenv import load_dotenv from hatchet_sdk import Context, Hatchet -from hatchet_sdk.clients.admin import ChildWorkflowRunDict +from hatchet_sdk.clients.admin import ChildTriggerWorkflowOptions, ChildWorkflowRunDict load_dotenv() @@ -22,18 +22,17 @@ async def spawn(self, context: Context) -> dict[str, list[Any]]: n = context.workflow_input().get("n", 100) - 
child_workflow_runs: list[ChildWorkflowRunDict] = [] - - for i in range(n): - - child_workflow_runs.append( - { - "workflow_name": "BulkChild", - "input": {"a": str(i)}, - "key": f"child{i}", - "options": {"additional_metadata": {"hello": "earth"}}, - } + child_workflow_runs = [ + ChildWorkflowRunDict( + workflow_name="BulkChild", + input={"a": str(i)}, + key=f"child{i}", + options=ChildTriggerWorkflowOptions( + additional_metadata={"hello": "earth"} + ), ) + for i in range(n) + ] if len(child_workflow_runs) == 0: return {} diff --git a/examples/dedupe/worker.py b/examples/dedupe/worker.py index 2f22f52d..6e5c1f02 100644 --- a/examples/dedupe/worker.py +++ b/examples/dedupe/worker.py @@ -4,7 +4,7 @@ from dotenv import load_dotenv -from hatchet_sdk import Context, Hatchet +from hatchet_sdk import ChildTriggerWorkflowOptions, Context, Hatchet from hatchet_sdk.clients.admin import DedupeViolationErr from hatchet_sdk.loader import ClientConfig @@ -29,7 +29,9 @@ async def spawn(self, context: Context) -> dict[str, list[Any]]: "DedupeChild", {"a": str(i)}, key=f"child{i}", - options={"additional_metadata": {"dedupe": "test"}}, + options=ChildTriggerWorkflowOptions( + additional_metadata={"dedupe": "test"} + ), ) ).result() ) diff --git a/examples/durable_sticky_with_affinity/worker.py b/examples/durable_sticky_with_affinity/worker.py index 93505d98..0e6036c2 100644 --- a/examples/durable_sticky_with_affinity/worker.py +++ b/examples/durable_sticky_with_affinity/worker.py @@ -3,7 +3,13 @@ from dotenv import load_dotenv -from hatchet_sdk import Context, StickyStrategy, WorkerLabelComparator +from hatchet_sdk import ( + ChildTriggerWorkflowOptions, + Context, + StickyStrategy, + WorkerLabelComparator, +) +from hatchet_sdk.labels import DesiredWorkerLabel from hatchet_sdk.v2.callable import DurableContext from hatchet_sdk.v2.hatchet import Hatchet @@ -15,17 +21,17 @@ @hatchet.durable( sticky=StickyStrategy.HARD, desired_worker_labels={ - "running_workflow": { - "value": 
"True", - "required": True, - "comparator": WorkerLabelComparator.NOT_EQUAL, - }, + "running_workflow": DesiredWorkerLabel( + value="True", + required=True, + comparator=WorkerLabelComparator.NOT_EQUAL, + ), }, ) async def my_durable_func(context: DurableContext) -> dict[str, Any]: try: ref = await context.aio.spawn_workflow( - "StickyChildWorkflow", {}, options={"sticky": True} + "StickyChildWorkflow", {}, options=ChildTriggerWorkflowOptions(sticky=True) ) result = await ref.result() except Exception as e: @@ -39,11 +45,11 @@ async def my_durable_func(context: DurableContext) -> dict[str, Any]: class StickyChildWorkflow: @hatchet.step( desired_worker_labels={ - "running_workflow": { - "value": "True", - "required": True, - "comparator": WorkerLabelComparator.NOT_EQUAL, - }, + "running_workflow": DesiredWorkerLabel( + value="True", + required=True, + comparator=WorkerLabelComparator.NOT_EQUAL, + ), }, ) async def child(self, context: Context) -> dict[str, str | None]: diff --git a/examples/events/test_event.py b/examples/events/test_event.py index a4fca8ae..5fd920e6 100644 --- a/examples/events/test_event.py +++ b/examples/events/test_event.py @@ -24,34 +24,34 @@ async def test_async_event_push(aiohatchet: Hatchet) -> None: @pytest.mark.asyncio(scope="session") async def test_async_event_bulk_push(aiohatchet: Hatchet) -> None: - events: List[BulkPushEventWithMetadata] = [ - { - "key": "event1", - "payload": {"message": "This is event 1"}, - "additional_metadata": {"source": "test", "user_id": "user123"}, - }, - { - "key": "event2", - "payload": {"message": "This is event 2"}, - "additional_metadata": {"source": "test", "user_id": "user456"}, - }, - { - "key": "event3", - "payload": {"message": "This is event 3"}, - "additional_metadata": {"source": "test", "user_id": "user789"}, - }, + events = [ + BulkPushEventWithMetadata( + key="event1", + payload={"message": "This is event 1"}, + additional_metadata={"source": "test", "user_id": "user123"}, + ), + 
BulkPushEventWithMetadata( + key="event2", + payload={"message": "This is event 2"}, + additional_metadata={"source": "test", "user_id": "user456"}, + ), + BulkPushEventWithMetadata( + key="event3", + payload={"message": "This is event 3"}, + additional_metadata={"source": "test", "user_id": "user789"}, + ), ] - opts: BulkPushEventOptions = {"namespace": "bulk-test"} + opts = BulkPushEventOptions(namespace="bulk-test") e = await aiohatchet.event.async_bulk_push(events, opts) assert len(e) == 3 # Sort both lists of events by their key to ensure comparison order - sorted_events = sorted(events, key=lambda x: x["key"]) + sorted_events = sorted(events, key=lambda x: x.key) sorted_returned_events = sorted(e, key=lambda x: x.key) namespace = "bulk-test" # Check that the returned events match the original events for original_event, returned_event in zip(sorted_events, sorted_returned_events): - assert returned_event.key == namespace + original_event["key"] + assert returned_event.key == namespace + original_event.key diff --git a/examples/fanout/stream.py b/examples/fanout/stream.py index 08d0cb4a..2eb03648 100644 --- a/examples/fanout/stream.py +++ b/examples/fanout/stream.py @@ -31,7 +31,7 @@ async def main() -> None: workflowRun = hatchet.admin.run_workflow( "Parent", {"n": 2}, - options={"additional_metadata": {streamKey: streamVal}}, + options=TriggerWorkflowOptions(additional_metadata={streamKey: streamVal}), ) # Stream all events for the additional meta key value diff --git a/examples/fanout/sync_stream.py b/examples/fanout/sync_stream.py index 0b1a7140..e05b510c 100644 --- a/examples/fanout/sync_stream.py +++ b/examples/fanout/sync_stream.py @@ -31,7 +31,7 @@ def main() -> None: workflowRun = hatchet.admin.run_workflow( "Parent", {"n": 2}, - options={"additional_metadata": {streamKey: streamVal}}, + options=TriggerWorkflowOptions(additional_metadata={streamKey: streamVal}), ) # Stream all events for the additional meta key value diff --git 
a/examples/fanout/worker.py b/examples/fanout/worker.py index c32344c9..0a678937 100644 --- a/examples/fanout/worker.py +++ b/examples/fanout/worker.py @@ -3,7 +3,7 @@ from dotenv import load_dotenv -from hatchet_sdk import Context, Hatchet +from hatchet_sdk import ChildTriggerWorkflowOptions, Context, Hatchet load_dotenv() @@ -28,7 +28,9 @@ async def spawn(self, context: Context) -> dict[str, Any]: "Child", {"a": str(i)}, key=f"child{i}", - options={"additional_metadata": {"hello": "earth"}}, + options=ChildTriggerWorkflowOptions( + additional_metadata={"hello": "earth"} + ), ) ).result() ) diff --git a/examples/simple/event.py b/examples/simple/event.py index c2d0178a..4f31de73 100644 --- a/examples/simple/event.py +++ b/examples/simple/event.py @@ -14,22 +14,22 @@ "user:create", {"test": "test"}, options={"additional_metadata": {"hello": "moon"}} ) -events: List[BulkPushEventWithMetadata] = [ - { - "key": "event1", - "payload": {"message": "This is event 1"}, - "additional_metadata": {"source": "test", "user_id": "user123"}, - }, - { - "key": "event2", - "payload": {"message": "This is event 2"}, - "additional_metadata": {"source": "test", "user_id": "user456"}, - }, - { - "key": "event3", - "payload": {"message": "This is event 3"}, - "additional_metadata": {"source": "test", "user_id": "user789"}, - }, +events = [ + BulkPushEventWithMetadata( + key="event1", + payload={"message": "This is event 1"}, + additional_metadata={"source": "test", "user_id": "user123"}, + ), + BulkPushEventWithMetadata( + key="event2", + payload={"message": "This is event 2"}, + additional_metadata={"source": "test", "user_id": "user456"}, + ), + BulkPushEventWithMetadata( + key="event3", + payload={"message": "This is event 3"}, + additional_metadata={"source": "test", "user_id": "user789"}, + ), ] diff --git a/examples/sticky_workers/event.py b/examples/sticky_workers/event.py index 55ed9b8f..67855b49 100644 --- a/examples/sticky_workers/event.py +++ 
b/examples/sticky_workers/event.py @@ -1,5 +1,6 @@ from dotenv import load_dotenv +from hatchet_sdk.clients.events import PushEventOptions from hatchet_sdk.hatchet import Hatchet load_dotenv() @@ -10,5 +11,5 @@ hatchet.event.push( "sticky:parent", {"test": "test"}, - options={"additional_metadata": {"hello": "moon"}}, + options=PushEventOptions(additional_metadata={"hello": "moon"}), ) diff --git a/examples/sticky_workers/worker.py b/examples/sticky_workers/worker.py index abedb820..bc681a2a 100644 --- a/examples/sticky_workers/worker.py +++ b/examples/sticky_workers/worker.py @@ -1,6 +1,6 @@ from dotenv import load_dotenv -from hatchet_sdk import Context, Hatchet, StickyStrategy +from hatchet_sdk import ChildTriggerWorkflowOptions, Context, Hatchet, StickyStrategy from hatchet_sdk.context.context import ContextAioImpl load_dotenv() @@ -21,7 +21,7 @@ def step1b(self, context: Context) -> dict[str, str | None]: @hatchet.step(parents=["step1a", "step1b"]) async def step2(self, context: ContextAioImpl) -> dict[str, str | None]: ref = await context.spawn_workflow( - "StickyChildWorkflow", {}, options={"sticky": True} + "StickyChildWorkflow", {}, options=ChildTriggerWorkflowOptions(sticky=True) ) await ref.result() diff --git a/examples/sync_to_async/worker.py b/examples/sync_to_async/worker.py index 5ac3a912..6ee1636a 100644 --- a/examples/sync_to_async/worker.py +++ b/examples/sync_to_async/worker.py @@ -6,6 +6,7 @@ from dotenv import load_dotenv from hatchet_sdk import Context, sync_to_async +from hatchet_sdk.clients.admin import ChildTriggerWorkflowOptions from hatchet_sdk.v2.hatchet import Hatchet os.environ["PYTHONASYNCIODEBUG"] = "1" @@ -31,7 +32,9 @@ async def fanout_sync_async(context: Context) -> dict[str, Any]: "Child", {"a": str(i)}, key=f"child{i}", - options={"additional_metadata": {"hello": "earth"}}, + options=ChildTriggerWorkflowOptions( + additional_metadata={"hello": "earth"} + ), ) ).result() ) diff --git a/hatchet_sdk/clients/admin.py 
b/hatchet_sdk/clients/admin.py index 02fdeb56..bcb65a9f 100644 --- a/hatchet_sdk/clients/admin.py +++ b/hatchet_sdk/clients/admin.py @@ -1,9 +1,10 @@ import json from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, TypedDict, TypeVar, Union +from typing import Any, Callable, TypeVar, Union import grpc from google.protobuf import timestamp_pb2 +from pydantic import BaseModel from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry @@ -30,6 +31,7 @@ inject_carrier_into_metadata, parse_carrier_from_metadata, ) +from hatchet_sdk.utils.types import AdditionalMetadata from hatchet_sdk.workflow_run import RunRef, WorkflowRunRef from ..loader import ClientConfig @@ -41,36 +43,36 @@ def new_admin(config: ClientConfig): return AdminClient(config) -class ScheduleTriggerWorkflowOptions(TypedDict, total=False): - parent_id: Optional[str] - parent_step_run_id: Optional[str] - child_index: Optional[int] - child_key: Optional[str] - namespace: Optional[str] +class ScheduleTriggerWorkflowOptions(BaseModel): + parent_id: str | None = None + parent_step_run_id: str | None = None + child_index: int | None = None + child_key: str | None = None + namespace: str | None = None -class ChildTriggerWorkflowOptions(TypedDict, total=False): - additional_metadata: Dict[str, str] | None = None +class ChildTriggerWorkflowOptions(BaseModel): + additional_metadata: AdditionalMetadata = {} sticky: bool | None = None -class ChildWorkflowRunDict(TypedDict, total=False): +class ChildWorkflowRunDict(BaseModel): workflow_name: str input: Any options: ChildTriggerWorkflowOptions key: str | None = None -class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, total=False): - additional_metadata: Dict[str, str] | None = None +class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions): + additional_metadata: AdditionalMetadata = {} desired_worker_id: str | None = None namespace: str | None = 
None -class WorkflowRunDict(TypedDict, total=False): +class WorkflowRunDict(BaseModel): workflow_name: str input: Any - options: TriggerWorkflowOptions | None + options: TriggerWorkflowOptions class DedupeViolationErr(Exception): @@ -133,7 +135,7 @@ def _prepare_put_workflow_request( def _prepare_schedule_workflow_request( self, name: str, - schedules: List[Union[datetime, timestamp_pb2.Timestamp]], + schedules: list[Union[datetime, timestamp_pb2.Timestamp]], input={}, options: ScheduleTriggerWorkflowOptions = None, ): @@ -263,7 +265,7 @@ async def run_workflows( self, workflows: list[WorkflowRunDict], options: TriggerWorkflowOptions | None = None, - ) -> List[WorkflowRunRef]: + ) -> list[WorkflowRunRef]: if len(workflows) == 0: raise ValueError("No workflows to run") try: @@ -357,7 +359,7 @@ async def put_rate_limit( async def schedule_workflow( self, name: str, - schedules: List[Union[datetime, timestamp_pb2.Timestamp]], + schedules: list[Union[datetime, timestamp_pb2.Timestamp]], input={}, options: ScheduleTriggerWorkflowOptions = None, ) -> WorkflowVersion: @@ -443,7 +445,7 @@ def put_rate_limit( def schedule_workflow( self, name: str, - schedules: List[Union[datetime, timestamp_pb2.Timestamp]], + schedules: list[Union[datetime, timestamp_pb2.Timestamp]], input={}, options: ScheduleTriggerWorkflowOptions = None, ) -> WorkflowVersion: @@ -549,7 +551,7 @@ def run_workflow( @tenacity_retry def run_workflows( - self, workflows: List[WorkflowRunDict], options: TriggerWorkflowOptions = None + self, workflows: list[WorkflowRunDict], options: TriggerWorkflowOptions = None ) -> list[WorkflowRunRef]: workflow_run_requests: TriggerWorkflowRequest = [] try: diff --git a/hatchet_sdk/clients/dispatcher/action_listener.py b/hatchet_sdk/clients/dispatcher/action_listener.py index fc2887bd..8f388634 100644 --- a/hatchet_sdk/clients/dispatcher/action_listener.py +++ b/hatchet_sdk/clients/dispatcher/action_listener.py @@ -24,6 +24,7 @@ from hatchet_sdk.logger import logger from 
hatchet_sdk.utils.backoff import exp_backoff_sleep from hatchet_sdk.utils.serialization import flatten +from hatchet_sdk.utils.types import AdditionalMetadata from ...loader import ClientConfig from ...metadata import get_metadata @@ -71,7 +72,7 @@ class Action: action_payload: str action_type: ActionType retry_count: int - additional_metadata: dict[str, str] | None = None + additional_metadata: AdditionalMetadata = {} child_workflow_index: int | None = None child_workflow_key: str | None = None diff --git a/hatchet_sdk/clients/events.py b/hatchet_sdk/clients/events.py index e188d386..ca9a654a 100644 --- a/hatchet_sdk/clients/events.py +++ b/hatchet_sdk/clients/events.py @@ -6,6 +6,7 @@ import grpc from google.protobuf import timestamp_pb2 +from pydantic import BaseModel from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry from hatchet_sdk.contracts.events_pb2 import ( @@ -23,6 +24,7 @@ inject_carrier_into_metadata, parse_carrier_from_metadata, ) +from hatchet_sdk.utils.types import AdditionalMetadata from ..loader import ClientConfig from ..metadata import get_metadata @@ -43,19 +45,19 @@ def proto_timestamp_now(): return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) -class PushEventOptions(TypedDict, total=False): - additional_metadata: Dict[str, str] | None = None +class PushEventOptions(BaseModel): + additional_metadata: AdditionalMetadata = {} namespace: str | None = None -class BulkPushEventOptions(TypedDict, total=False): +class BulkPushEventOptions(BaseModel): namespace: str | None = None -class BulkPushEventWithMetadata(TypedDict, total=False): +class BulkPushEventWithMetadata(BaseModel): key: str payload: Any - additional_metadata: Optional[Dict[str, Any]] # Optional metadata + additional_metadata: AdditionalMetadata = {} class EventClient: diff --git a/hatchet_sdk/clients/rest_client.py b/hatchet_sdk/clients/rest_client.py index dbfa5c6c..faa5672b 100644 --- a/hatchet_sdk/clients/rest_client.py +++ 
b/hatchet_sdk/clients/rest_client.py @@ -66,6 +66,7 @@ WorkflowRunsCancelRequest, ) from hatchet_sdk.clients.rest.models.workflow_version import WorkflowVersion +from hatchet_sdk.utils.types import AdditionalMetadata class AsyncRestApi: @@ -230,7 +231,7 @@ async def cron_create( cron_name: str, expression: str, input: dict[str, Any], - additional_metadata: dict[str, str], + additional_metadata: AdditionalMetadata, ): return await self.workflow_run_api.cron_workflow_trigger_create( tenant=self.tenant_id, @@ -279,7 +280,7 @@ async def schedule_create( name: str, trigger_at: datetime.datetime, input: dict[str, Any], - additional_metadata: dict[str, str], + additional_metadata: AdditionalMetadata, ): return await self.workflow_run_api.scheduled_workflow_run_create( tenant=self.tenant_id, @@ -486,7 +487,7 @@ def cron_create( cron_name: str, expression: str, input: dict[str, Any], - additional_metadata: dict[str, str], + additional_metadata: AdditionalMetadata, ) -> CronWorkflows: return self._run_coroutine( self.aio.cron_create( @@ -525,7 +526,7 @@ def schedule_create( workflow_name: str, trigger_at: datetime.datetime, input: dict[str, Any], - additional_metadata: dict[str, str], + additional_metadata: AdditionalMetadata, ): return self._run_coroutine( self.aio.schedule_create( diff --git a/hatchet_sdk/context/context.py b/hatchet_sdk/context/context.py index 02838c50..f4bcdb7a 100644 --- a/hatchet_sdk/context/context.py +++ b/hatchet_sdk/context/context.py @@ -54,29 +54,20 @@ class BaseContext: def _prepare_workflow_options( self, key: str | None = None, - options: ChildTriggerWorkflowOptions | None = None, + options: ChildTriggerWorkflowOptions = ChildTriggerWorkflowOptions(), worker_id: str | None = None, ) -> TriggerWorkflowOptions: workflow_run_id = self.action.workflow_run_id step_run_id = self.action.step_run_id - desired_worker_id = None - if options is not None and "sticky" in options and options["sticky"] == True: - desired_worker_id = worker_id - - meta = 
None - if options is not None and "additional_metadata" in options: - meta = options["additional_metadata"] - - ## TODO: Pydantic here to simplify this - trigger_options: TriggerWorkflowOptions = { - "parent_id": workflow_run_id, - "parent_step_run_id": step_run_id, - "child_key": key, - "child_index": self.spawn_index, - "additional_metadata": meta, - "desired_worker_id": desired_worker_id, - } + trigger_options = TriggerWorkflowOptions( + parent_id=workflow_run_id, + parent_step_run_id=step_run_id, + child_key=key, + child_index=self.spawn_index, + additional_metadata=options.additional_metadata, + desired_worker_id=worker_id if options.sticky else None, + ) self.spawn_index += 1 return trigger_options @@ -112,18 +103,9 @@ async def spawn_workflow( workflow_name: str, input: dict[str, Any] = {}, key: str | None = None, - options: ChildTriggerWorkflowOptions | None = None, + options: ChildTriggerWorkflowOptions = ChildTriggerWorkflowOptions(), ) -> WorkflowRunRef: worker_id = self.worker.id() - # if ( - # options is not None - # and "sticky" in options - # and options["sticky"] == True - # and not self.worker.has_workflow(workflow_name) - # ): - # raise Exception( - # f"cannot run with sticky: workflow {workflow_name} is not registered on the worker" - # ) trigger_options = self._prepare_workflow_options(key, options, worker_id) @@ -141,21 +123,16 @@ async def spawn_workflows( worker_id = self.worker.id() - bulk_trigger_workflow_runs: list[WorkflowRunDict] = [] - for child_workflow_run in child_workflow_runs: - workflow_name = child_workflow_run["workflow_name"] - input = child_workflow_run["input"] - - key = child_workflow_run.get("key") - options = child_workflow_run.get("options", {}) - - trigger_options = self._prepare_workflow_options(key, options, worker_id) - - bulk_trigger_workflow_runs.append( - WorkflowRunDict( - workflow_name=workflow_name, input=input, options=trigger_options - ) + bulk_trigger_workflow_runs = [ + WorkflowRunDict( + 
workflow_name=child_workflow_run.workflow_name, + input=input, + options=self._prepare_workflow_options( + child_workflow_run.key, child_workflow_run.options, worker_id + ), ) + for child_workflow_run in child_workflow_runs + ] return await self.admin_client.aio.run_workflows(bulk_trigger_workflow_runs) diff --git a/hatchet_sdk/features/cron.py b/hatchet_sdk/features/cron.py index c54e5b3b..46a0742e 100644 --- a/hatchet_sdk/features/cron.py +++ b/hatchet_sdk/features/cron.py @@ -11,6 +11,7 @@ from hatchet_sdk.clients.rest.models.workflow_run_order_by_direction import ( WorkflowRunOrderByDirection, ) +from hatchet_sdk.utils.types import AdditionalMetadata class CreateCronTriggerInput(BaseModel): @@ -25,7 +26,7 @@ class CreateCronTriggerInput(BaseModel): expression: str = None input: dict = {} - additional_metadata: dict[str, str] = {} + additional_metadata: AdditionalMetadata = {} @field_validator("expression") def validate_cron_expression(cls, v): @@ -87,7 +88,7 @@ def create( cron_name: str, expression: str, input: dict, - additional_metadata: dict[str, str], + additional_metadata: AdditionalMetadata, ) -> CronWorkflows: """ Creates a new workflow cron trigger. @@ -199,7 +200,7 @@ async def create( cron_name: str, expression: str, input: dict, - additional_metadata: dict[str, str], + additional_metadata: AdditionalMetadata, ) -> CronWorkflows: """ Asynchronously creates a new workflow cron trigger. 
diff --git a/hatchet_sdk/features/scheduled.py b/hatchet_sdk/features/scheduled.py index 45af2609..27495359 100644 --- a/hatchet_sdk/features/scheduled.py +++ b/hatchet_sdk/features/scheduled.py @@ -15,6 +15,7 @@ from hatchet_sdk.clients.rest.models.workflow_run_order_by_direction import ( WorkflowRunOrderByDirection, ) +from hatchet_sdk.utils.types import AdditionalMetadata class CreateScheduledTriggerInput(BaseModel): @@ -28,7 +29,7 @@ class CreateScheduledTriggerInput(BaseModel): """ input: Dict[str, Any] = {} - additional_metadata: Dict[str, str] = {} + additional_metadata: AdditionalMetadata = {} trigger_at: Optional[datetime.datetime] = None @@ -58,7 +59,7 @@ def create( self, workflow_name: str, trigger_at: datetime.datetime, input: Dict[str, Any], - additional_metadata: Dict[str, str], + additional_metadata: AdditionalMetadata, ) -> ScheduledWorkflows: """ Creates a new scheduled workflow run asynchronously. @@ -168,7 +169,7 @@ async def create( self, workflow_name: str, trigger_at: datetime.datetime, input: Dict[str, Any], - additional_metadata: Dict[str, str], + additional_metadata: AdditionalMetadata, ) -> ScheduledWorkflows: """ Creates a new scheduled workflow run asynchronously. 
diff --git a/hatchet_sdk/hatchet.py b/hatchet_sdk/hatchet.py index f71e208d..801007e5 100644 --- a/hatchet_sdk/hatchet.py +++ b/hatchet_sdk/hatchet.py @@ -12,6 +12,7 @@ CreateStepRateLimit, DesiredWorkerLabels, StickyStrategy, + WorkerLabelComparator, ) from hatchet_sdk.features.cron import CronClient from hatchet_sdk.features.scheduled import ScheduledClient @@ -107,13 +108,13 @@ def inner(func: Callable[P, R]) -> Callable[P, R]: setattr(func, "_step_backoff_max_seconds", backoff_max_seconds) def create_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels: - value = d["value"] if "value" in d else None + value = d.value return DesiredWorkerLabels( - strValue=str(value) if not isinstance(value, int) else None, + strValue=value if not isinstance(value, int) else None, intValue=value if isinstance(value, int) else None, - required=d["required"] if "required" in d else None, # type: ignore[arg-type] - weight=d["weight"] if "weight" in d else None, - comparator=d["comparator"] if "comparator" in d else None, # type: ignore[arg-type] + required=d.required, + weight=d.weight, + comparator=d.comparator, ) setattr( diff --git a/hatchet_sdk/labels.py b/hatchet_sdk/labels.py index 646c666d..ca808024 100644 --- a/hatchet_sdk/labels.py +++ b/hatchet_sdk/labels.py @@ -1,10 +1,10 @@ -from typing import TypedDict +from pydantic import BaseModel +from hatchet_sdk.contracts.workflows_pb2 import WorkerLabelComparator -class DesiredWorkerLabel(TypedDict, total=False): + +class DesiredWorkerLabel(BaseModel): value: str | int - required: bool | None = None + required: bool = False weight: int | None = None - comparator: int | None = ( - None # _ClassVar[WorkerLabelComparator] TODO figure out type - ) + comparator: WorkerLabelComparator | None = None diff --git a/hatchet_sdk/utils/types.py b/hatchet_sdk/utils/types.py index 30e469f7..1c48d0f4 100644 --- a/hatchet_sdk/utils/types.py +++ b/hatchet_sdk/utils/types.py @@ -6,3 +6,6 @@ class WorkflowValidator(BaseModel): workflow_input: 
Type[BaseModel] | None = None step_output: Type[BaseModel] | None = None + + +AdditionalMetadata = dict[str, str] diff --git a/hatchet_sdk/workflow_run.py b/hatchet_sdk/workflow_run.py index 51a23821..f29b47aa 100644 --- a/hatchet_sdk/workflow_run.py +++ b/hatchet_sdk/workflow_run.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Coroutine, Generic, Optional, TypedDict, TypeVar +from typing import Any, Coroutine, Generic, TypeVar from hatchet_sdk.clients.run_event_listener import ( RunEventListener, From 4cdc75050136191d5f1325149523f35eafec0628 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Tue, 14 Jan 2025 15:14:12 -0500 Subject: [PATCH 15/53] feat: couple more bits of cleanup using `Field` --- hatchet_sdk/clients/admin.py | 6 +++--- hatchet_sdk/clients/dispatcher/action_listener.py | 2 +- hatchet_sdk/clients/events.py | 6 +++--- hatchet_sdk/features/cron.py | 8 ++++---- hatchet_sdk/features/scheduled.py | 12 ++++++------ hatchet_sdk/hatchet.py | 2 +- hatchet_sdk/labels.py | 8 ++++---- hatchet_sdk/utils/types.py | 2 +- 8 files changed, 23 insertions(+), 23 deletions(-) diff --git a/hatchet_sdk/clients/admin.py b/hatchet_sdk/clients/admin.py index bcb65a9f..377d6acc 100644 --- a/hatchet_sdk/clients/admin.py +++ b/hatchet_sdk/clients/admin.py @@ -4,7 +4,7 @@ import grpc from google.protobuf import timestamp_pb2 -from pydantic import BaseModel +from pydantic import BaseModel, Field from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry @@ -52,7 +52,7 @@ class ScheduleTriggerWorkflowOptions(BaseModel): class ChildTriggerWorkflowOptions(BaseModel): - additional_metadata: AdditionalMetadata = {} + additional_metadata: AdditionalMetadata = Field(default_factory=dict) sticky: bool | None = None @@ -64,7 +64,7 @@ class ChildWorkflowRunDict(BaseModel): class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions): - additional_metadata: AdditionalMetadata = {} + additional_metadata: 
AdditionalMetadata = Field(default_factory=dict) desired_worker_id: str | None = None namespace: str | None = None diff --git a/hatchet_sdk/clients/dispatcher/action_listener.py b/hatchet_sdk/clients/dispatcher/action_listener.py index 8f388634..35560869 100644 --- a/hatchet_sdk/clients/dispatcher/action_listener.py +++ b/hatchet_sdk/clients/dispatcher/action_listener.py @@ -72,7 +72,7 @@ class Action: action_payload: str action_type: ActionType retry_count: int - additional_metadata: AdditionalMetadata = {} + additional_metadata: AdditionalMetadata = field(default_factory=dict) child_workflow_index: int | None = None child_workflow_key: str | None = None diff --git a/hatchet_sdk/clients/events.py b/hatchet_sdk/clients/events.py index ca9a654a..7f6fdd05 100644 --- a/hatchet_sdk/clients/events.py +++ b/hatchet_sdk/clients/events.py @@ -6,7 +6,7 @@ import grpc from google.protobuf import timestamp_pb2 -from pydantic import BaseModel +from pydantic import BaseModel, Field from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry from hatchet_sdk.contracts.events_pb2 import ( @@ -46,7 +46,7 @@ def proto_timestamp_now(): class PushEventOptions(BaseModel): - additional_metadata: AdditionalMetadata = {} + additional_metadata: AdditionalMetadata = Field(default_factory=dict) namespace: str | None = None @@ -57,7 +57,7 @@ class BulkPushEventOptions(BaseModel): class BulkPushEventWithMetadata(BaseModel): key: str payload: Any - additional_metadata: AdditionalMetadata = {} + additional_metadata: AdditionalMetadata = Field(default_factory=dict) class EventClient: diff --git a/hatchet_sdk/features/cron.py b/hatchet_sdk/features/cron.py index 46a0742e..288e4441 100644 --- a/hatchet_sdk/features/cron.py +++ b/hatchet_sdk/features/cron.py @@ -1,6 +1,6 @@ -from typing import Union +from typing import Any, Union -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, Field, field_validator from hatchet_sdk.client import Client from 
hatchet_sdk.clients.rest.models.cron_workflows import CronWorkflows @@ -25,8 +25,8 @@ class CreateCronTriggerInput(BaseModel): """ expression: str = None - input: dict = {} - additional_metadata: AdditionalMetadata = {} + input: dict[str, Any] = Field(default_factory=dict) + additional_metadata: AdditionalMetadata = Field(default_factory=dict) @field_validator("expression") def validate_cron_expression(cls, v): diff --git a/hatchet_sdk/features/scheduled.py b/hatchet_sdk/features/scheduled.py index 27495359..e215763b 100644 --- a/hatchet_sdk/features/scheduled.py +++ b/hatchet_sdk/features/scheduled.py @@ -1,7 +1,7 @@ import datetime from typing import Any, Coroutine, Dict, List, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field from hatchet_sdk.client import Client from hatchet_sdk.clients.rest.models.cron_workflows import CronWorkflows @@ -28,9 +28,9 @@ class CreateScheduledTriggerInput(BaseModel): trigger_at (Optional[datetime.datetime]): The datetime when the run should be triggered. """ - input: Dict[str, Any] = {} - additional_metadata: AdditionalMetadata = {} - trigger_at: Optional[datetime.datetime] = None + input: dict[str, Any] = Field(default_factory=dict) + additional_metadata: AdditionalMetadata = Field(default_factory=dict) + trigger_at: datetime.datetime | None = None class ScheduledClient: @@ -168,8 +168,8 @@ async def create( self, workflow_name: str, trigger_at: datetime.datetime, - input: Dict[str, Any], - additional_metadata: AdditionalMetadatan, + input: dict[str, Any], + additional_metadata: AdditionalMetadata, ) -> ScheduledWorkflows: """ Creates a new scheduled workflow run asynchronously. 
diff --git a/hatchet_sdk/hatchet.py b/hatchet_sdk/hatchet.py index 801007e5..396209c5 100644 --- a/hatchet_sdk/hatchet.py +++ b/hatchet_sdk/hatchet.py @@ -114,7 +114,7 @@ def create_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels: intValue=value if isinstance(value, int) else None, required=d.required, weight=d.weight, - comparator=d.comparator, + comparator=d.comparator, # type: ignore[arg-type] ) setattr( diff --git a/hatchet_sdk/labels.py b/hatchet_sdk/labels.py index ca808024..55836e31 100644 --- a/hatchet_sdk/labels.py +++ b/hatchet_sdk/labels.py @@ -1,10 +1,10 @@ -from pydantic import BaseModel - -from hatchet_sdk.contracts.workflows_pb2 import WorkerLabelComparator +from pydantic import BaseModel, ConfigDict class DesiredWorkerLabel(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + value: str | int required: bool = False weight: int | None = None - comparator: WorkerLabelComparator | None = None + comparator: int | None = None diff --git a/hatchet_sdk/utils/types.py b/hatchet_sdk/utils/types.py index 1c48d0f4..f02ab1f0 100644 --- a/hatchet_sdk/utils/types.py +++ b/hatchet_sdk/utils/types.py @@ -1,4 +1,4 @@ -from typing import Type +from typing import Any, Type from pydantic import BaseModel From 57b8751858cddf1b74a4d484767566cfd0c622c7 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Tue, 14 Jan 2025 15:52:47 -0500 Subject: [PATCH 16/53] fix: clean up a bunch of type issues in the admin client --- hatchet_sdk/clients/admin.py | 254 +++++++++++++---------------- hatchet_sdk/clients/rest_client.py | 14 +- hatchet_sdk/context/context.py | 6 +- hatchet_sdk/features/cron.py | 8 +- hatchet_sdk/features/scheduled.py | 8 +- hatchet_sdk/utils/types.py | 3 +- hatchet_sdk/v2/callable.py | 3 +- poetry.lock | 19 ++- pyproject.toml | 4 +- 9 files changed, 152 insertions(+), 167 deletions(-) diff --git a/hatchet_sdk/clients/admin.py b/hatchet_sdk/clients/admin.py index 377d6acc..4978d012 100644 --- a/hatchet_sdk/clients/admin.py +++ 
b/hatchet_sdk/clients/admin.py @@ -1,6 +1,6 @@ import json from datetime import datetime -from typing import Any, Callable, TypeVar, Union +from typing import Any, Callable, TypeVar, Union, cast import grpc from google.protobuf import timestamp_pb2 @@ -31,7 +31,7 @@ inject_carrier_into_metadata, parse_carrier_from_metadata, ) -from hatchet_sdk.utils.types import AdditionalMetadata +from hatchet_sdk.utils.types import AdditionalMetadata, Input from hatchet_sdk.workflow_run import RunRef, WorkflowRunRef from ..loader import ClientConfig @@ -39,7 +39,7 @@ from ..workflow import WorkflowMeta -def new_admin(config: ClientConfig): +def new_admin(config: ClientConfig) -> "AdminClient": return AdminClient(config) @@ -58,7 +58,7 @@ class ChildTriggerWorkflowOptions(BaseModel): class ChildWorkflowRunDict(BaseModel): workflow_name: str - input: Any + input: Input options: ChildTriggerWorkflowOptions key: str | None = None @@ -71,7 +71,7 @@ class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions): class WorkflowRunDict(BaseModel): workflow_name: str - input: Any + input: Input options: TriggerWorkflowOptions @@ -85,24 +85,21 @@ class AdminClientBase: pooled_workflow_listener: PooledWorkflowRunListener | None = None def _prepare_workflow_request( - self, workflow_name: str, input: any, options: TriggerWorkflowOptions = None - ): + self, workflow_name: str, input: dict[str, Any], options: TriggerWorkflowOptions + ) -> TriggerWorkflowRequest: try: payload_data = json.dumps(input) + _options = options.model_dump() try: - meta = ( - None - if options is None or "additional_metadata" not in options - else options["additional_metadata"] - ) - if meta is not None: - options["additional_metadata"] = json.dumps(meta).encode("utf-8") + _options["additional_metadata"] = json.dumps( + options.additional_metadata + ).encode("utf-8") except json.JSONDecodeError as e: raise ValueError(f"Error encoding payload: {e}") return TriggerWorkflowRequest( - name=workflow_name, input=payload_data, 
**(options or {}) + name=workflow_name, input=payload_data, **_options ) except json.JSONDecodeError as e: raise ValueError(f"Error encoding payload: {e}") @@ -112,14 +109,14 @@ def _prepare_put_workflow_request( name: str, workflow: CreateWorkflowVersionOpts | WorkflowMeta, overrides: CreateWorkflowVersionOpts | None = None, - ): + ) -> PutWorkflowRequest: try: opts: CreateWorkflowVersionOpts if isinstance(workflow, CreateWorkflowVersionOpts): opts = workflow else: - opts = workflow.get_create_opts(self.client.config.namespace) + opts = workflow.get_create_opts(self.client.config.namespace) # type: ignore[attr-defined] if overrides is not None: opts.MergeFrom(overrides) @@ -136,9 +133,9 @@ def _prepare_schedule_workflow_request( self, name: str, schedules: list[Union[datetime, timestamp_pb2.Timestamp]], - input={}, - options: ScheduleTriggerWorkflowOptions = None, - ): + input: Input = {}, + options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(), + ) -> ScheduleWorkflowRequest: timestamp_schedules = [] for schedule in schedules: if isinstance(schedule, datetime): @@ -158,7 +155,7 @@ def _prepare_schedule_workflow_request( name=name, schedules=timestamp_schedules, input=json.dumps(input), - **(options or {}), + **options.model_dump(), ) @@ -169,7 +166,7 @@ class AdminClientAioImpl(AdminClientBase): def __init__(self, config: ClientConfig): aio_conn = new_conn(config, True) self.config = config - self.aio_client = WorkflowServiceStub(aio_conn) + self.aio_client = WorkflowServiceStub(aio_conn) # type: ignore[no-untyped-call] self.token = config.token self.listener_client = new_listener(config) self.namespace = config.namespace @@ -178,13 +175,17 @@ def __init__(self, config: ClientConfig): async def run( self, function: Union[str, Callable[[Any], T]], - input: any, - options: TriggerWorkflowOptions = None, + input: Input, + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> "RunRef[T]": - workflow_name = function - - if not 
isinstance(function, str): - workflow_name = function.function_name + workflow_name = cast( + str, + ( + function + if isinstance(function, str) + else getattr(function, "function_name") + ), + ) wrr = await self.run_workflow(workflow_name, input, options) @@ -194,11 +195,12 @@ async def run( @tenacity_retry async def run_workflow( - self, workflow_name: str, input: any, options: TriggerWorkflowOptions = None + self, + workflow_name: str, + input: Input, + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> WorkflowRunRef: - ctx = parse_carrier_from_metadata( - (options or {}).get("additional_metadata", {}) - ) + ctx = parse_carrier_from_metadata(options.additional_metadata) with self.otel_tracer.start_as_current_span( f"hatchet.async_run_workflow.{workflow_name}", context=ctx @@ -211,27 +213,18 @@ async def run_workflow( self.config ) - namespace = self.namespace - - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options.pop("namespace") + namespace = options.namespace or self.namespace if namespace != "" and not workflow_name.startswith(self.namespace): workflow_name = f"{namespace}{workflow_name}" - if options is not None and "additional_metadata" in options: - options["additional_metadata"] = inject_carrier_into_metadata( - options["additional_metadata"], carrier - ) - span.set_attributes( - flatten( - options["additional_metadata"], parent_key="", separator="." 
- ) - ) + options.additional_metadata = inject_carrier_into_metadata( + options.additional_metadata, carrier + ) + + span.set_attributes( + flatten(options.additional_metadata, parent_key="", separator=".") + ) request = self._prepare_workflow_request(workflow_name, input, options) @@ -264,7 +257,7 @@ async def run_workflow( async def run_workflows( self, workflows: list[WorkflowRunDict], - options: TriggerWorkflowOptions | None = None, + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> list[WorkflowRunRef]: if len(workflows) == 0: raise ValueError("No workflows to run") @@ -272,22 +265,14 @@ async def run_workflows( if not self.pooled_workflow_listener: self.pooled_workflow_listener = PooledWorkflowRunListener(self.config) - namespace = self.namespace - - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options["namespace"] - del options["namespace"] + namespace = options.namespace or self.namespace - workflow_run_requests: TriggerWorkflowRequest = [] + workflow_run_requests: list[TriggerWorkflowRequest] = [] for workflow in workflows: - workflow_name = workflow["workflow_name"] - input_data = workflow["input"] - options = workflow["options"] + workflow_name = workflow.workflow_name + input_data = workflow.input + options = workflow.options if namespace != "" and not workflow_name.startswith(self.namespace): workflow_name = f"{namespace}{workflow_name}" @@ -298,11 +283,11 @@ async def run_workflows( ) workflow_run_requests.append(request) - request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests) + bulk_request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests) resp: BulkTriggerWorkflowResponse = ( await self.aio_client.BulkTriggerWorkflow( - request, + bulk_request, metadata=get_metadata(self.token), ) ) @@ -329,9 +314,12 @@ async def put_workflow( try: opts = self._prepare_put_workflow_request(name, workflow, overrides) - return await 
self.aio_client.PutWorkflow( - opts, - metadata=get_metadata(self.token), + return cast( + WorkflowVersion, + await self.aio_client.PutWorkflow( + opts, + metadata=get_metadata(self.token), + ), ) except grpc.RpcError as e: raise ValueError(f"Could not put workflow: {e}") @@ -342,7 +330,7 @@ async def put_rate_limit( key: str, limit: int, duration: RateLimitDuration = RateLimitDuration.SECOND, - ): + ) -> None: try: await self.aio_client.PutRateLimit( PutRateLimitRequest( @@ -360,19 +348,11 @@ async def schedule_workflow( self, name: str, schedules: list[Union[datetime, timestamp_pb2.Timestamp]], - input={}, - options: ScheduleTriggerWorkflowOptions = None, + input: Input = {}, + options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(), ) -> WorkflowVersion: try: - namespace = self.namespace - - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options["namespace"] - del options["namespace"] + namespace = options.namespace or self.namespace if namespace != "" and not name.startswith(self.namespace): name = f"{namespace}{name}" @@ -381,9 +361,12 @@ async def schedule_workflow( name, schedules, input, options ) - return await self.aio_client.ScheduleWorkflow( - request, - metadata=get_metadata(self.token), + return cast( + WorkflowVersion, + await self.aio_client.ScheduleWorkflow( + request, + metadata=get_metadata(self.token), + ), ) except grpc.RpcError as e: if e.code() == grpc.StatusCode.ALREADY_EXISTS: @@ -396,7 +379,7 @@ class AdminClient(AdminClientBase): def __init__(self, config: ClientConfig): conn = new_conn(config) self.config = config - self.client = WorkflowServiceStub(conn) + self.client = WorkflowServiceStub(conn) # type: ignore[no-untyped-call] self.aio = AdminClientAioImpl(config) self.token = config.token self.listener_client = new_listener(config) @@ -427,8 +410,8 @@ def put_rate_limit( self, key: str, limit: int, - duration: Union[RateLimitDuration.Value, str] = 
RateLimitDuration.SECOND, - ): + duration: Union[RateLimitDuration, str] = RateLimitDuration.SECOND, + ) -> None: try: self.client.PutRateLimit( PutRateLimitRequest( @@ -446,19 +429,11 @@ def schedule_workflow( self, name: str, schedules: list[Union[datetime, timestamp_pb2.Timestamp]], - input={}, - options: ScheduleTriggerWorkflowOptions = None, + input: Input = {}, + options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(), ) -> WorkflowVersion: try: - namespace = self.namespace - - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options["namespace"] - del options["namespace"] + namespace = options.namespace or self.namespace if namespace != "" and not name.startswith(self.namespace): name = f"{namespace}{name}" @@ -467,9 +442,12 @@ def schedule_workflow( name, schedules, input, options ) - return self.client.ScheduleWorkflow( - request, - metadata=get_metadata(self.token), + return cast( + WorkflowVersion, + self.client.ScheduleWorkflow( + request, + metadata=get_metadata(self.token), + ), ) except grpc.RpcError as e: if e.code() == grpc.StatusCode.ALREADY_EXISTS: @@ -481,11 +459,12 @@ def schedule_workflow( ## TODO: `any` type hint should come from `typing` @tenacity_retry def run_workflow( - self, workflow_name: str, input: any, options: TriggerWorkflowOptions = None + self, + workflow_name: str, + input: Input, + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> WorkflowRunRef: - ctx = parse_carrier_from_metadata( - (options or {}).get("additional_metadata", {}) - ) + ctx = parse_carrier_from_metadata(options.additional_metadata) with self.otel_tracer.start_as_current_span( f"hatchet.run_workflow.{workflow_name}", context=ctx @@ -498,26 +477,15 @@ def run_workflow( self.config ) - namespace = self.namespace - - ## TODO: Factor this out - it's repeated a lot of places - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None 
- ): - namespace = options.pop("namespace") + namespace = options.namespace or self.namespace - if options is not None and "additional_metadata" in options: - options["additional_metadata"] = inject_carrier_into_metadata( - options["additional_metadata"], carrier - ) + options.additional_metadata = inject_carrier_into_metadata( + options.additional_metadata, carrier + ) - span.set_attributes( - flatten( - options["additional_metadata"], parent_key="", separator="." - ) - ) + span.set_attributes( + flatten(options.additional_metadata, parent_key="", separator=".") + ) if namespace != "" and not workflow_name.startswith(self.namespace): workflow_name = f"{namespace}{workflow_name}" @@ -551,27 +519,21 @@ def run_workflow( @tenacity_retry def run_workflows( - self, workflows: list[WorkflowRunDict], options: TriggerWorkflowOptions = None + self, + workflows: list[WorkflowRunDict], + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> list[WorkflowRunRef]: - workflow_run_requests: TriggerWorkflowRequest = [] + workflow_run_requests: list[TriggerWorkflowRequest] = [] try: if not self.pooled_workflow_listener: self.pooled_workflow_listener = PooledWorkflowRunListener(self.config) for workflow in workflows: - workflow_name = workflow["workflow_name"] - input_data = workflow["input"] - options = workflow["options"] - - namespace = self.namespace + workflow_name = workflow.workflow_name + input_data = workflow.input + options = workflow.options - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options["namespace"] - del options["namespace"] + namespace = options.namespace or self.namespace if namespace != "" and not workflow_name.startswith(self.namespace): workflow_name = f"{namespace}{workflow_name}" @@ -583,10 +545,10 @@ def run_workflows( workflow_run_requests.append(request) - request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests) + bulk_request = 
BulkTriggerWorkflowRequest(workflows=workflow_run_requests) resp: BulkTriggerWorkflowResponse = self.client.BulkTriggerWorkflow( - request, + bulk_request, metadata=get_metadata(self.token), ) @@ -605,13 +567,17 @@ def run_workflows( def run( self, function: Union[str, Callable[[Any], T]], - input: any, - options: TriggerWorkflowOptions = None, + input: Input, + options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> "RunRef[T]": - workflow_name = function - - if not isinstance(function, str): - workflow_name = function.function_name + workflow_name = cast( + str, + ( + function + if isinstance(function, str) + else getattr(function, "function_name") + ), + ) wrr = self.run_workflow(workflow_name, input, options) diff --git a/hatchet_sdk/clients/rest_client.py b/hatchet_sdk/clients/rest_client.py index faa5672b..c685a8e0 100644 --- a/hatchet_sdk/clients/rest_client.py +++ b/hatchet_sdk/clients/rest_client.py @@ -66,7 +66,7 @@ WorkflowRunsCancelRequest, ) from hatchet_sdk.clients.rest.models.workflow_version import WorkflowVersion -from hatchet_sdk.utils.types import AdditionalMetadata +from hatchet_sdk.utils.types import AdditionalMetadata, Input class AsyncRestApi: @@ -212,7 +212,7 @@ async def workflow_run_bulk_cancel( async def workflow_run_create( self, workflow_id: str, - input: dict[str, Any], + input: Input, version: str | None = None, additional_metadata: list[str] | None = None, ) -> WorkflowRun: @@ -230,7 +230,7 @@ async def cron_create( workflow_name: str, cron_name: str, expression: str, - input: dict[str, Any], + input: Input, additional_metadata: AdditionalMetadata, ): return await self.workflow_run_api.cron_workflow_trigger_create( @@ -279,7 +279,7 @@ async def schedule_create( self, name: str, trigger_at: datetime.datetime, - input: dict[str, Any], + input: Input, additional_metadata: AdditionalMetadata, ): return await self.workflow_run_api.scheduled_workflow_run_create( @@ -471,7 +471,7 @@ def workflow_run_bulk_cancel( def 
workflow_run_create( self, workflow_id: str, - input: dict[str, Any], + input: Input, version: str | None = None, additional_metadata: list[str] | None = None, ) -> WorkflowRun: @@ -486,7 +486,7 @@ def cron_create( workflow_name: str, cron_name: str, expression: str, - input: dict[str, Any], + input: Input, additional_metadata: AdditionalMetadata, ) -> CronWorkflows: return self._run_coroutine( @@ -525,7 +525,7 @@ def schedule_create( self, workflow_name: str, trigger_at: datetime.datetime, - input: dict[str, Any], + input: Input, additional_metadata: AdditionalMetadata, ): return self._run_coroutine( diff --git a/hatchet_sdk/context/context.py b/hatchet_sdk/context/context.py index f4bcdb7a..d6586a0c 100644 --- a/hatchet_sdk/context/context.py +++ b/hatchet_sdk/context/context.py @@ -18,7 +18,7 @@ BulkTriggerWorkflowRequest, TriggerWorkflowRequest, ) -from hatchet_sdk.utils.types import WorkflowValidator +from hatchet_sdk.utils.types import Input, WorkflowValidator from hatchet_sdk.utils.typing import is_basemodel_subclass from hatchet_sdk.workflow_run import WorkflowRunRef @@ -101,7 +101,7 @@ def __init__( async def spawn_workflow( self, workflow_name: str, - input: dict[str, Any] = {}, + input: Input = {}, key: str | None = None, options: ChildTriggerWorkflowOptions = ChildTriggerWorkflowOptions(), ) -> WorkflowRunRef: @@ -126,7 +126,7 @@ async def spawn_workflows( bulk_trigger_workflow_runs = [ WorkflowRunDict( workflow_name=child_workflow_run.workflow_name, - input=input, + input=child_workflow_run.input, options=self._prepare_workflow_options( child_workflow_run.key, child_workflow_run.options, worker_id ), diff --git a/hatchet_sdk/features/cron.py b/hatchet_sdk/features/cron.py index 288e4441..ccad8ab3 100644 --- a/hatchet_sdk/features/cron.py +++ b/hatchet_sdk/features/cron.py @@ -11,7 +11,7 @@ from hatchet_sdk.clients.rest.models.workflow_run_order_by_direction import ( WorkflowRunOrderByDirection, ) -from hatchet_sdk.utils.types import AdditionalMetadata 
+from hatchet_sdk.utils.types import AdditionalMetadata, Input class CreateCronTriggerInput(BaseModel): @@ -25,7 +25,7 @@ class CreateCronTriggerInput(BaseModel): """ expression: str = None - input: dict[str, Any] = Field(default_factory=dict) + input: Input = Field(default_factory=dict) additional_metadata: AdditionalMetadata = Field(default_factory=dict) @field_validator("expression") @@ -87,7 +87,7 @@ def create( workflow_name: str, cron_name: str, expression: str, - input: dict, + input: Input, additional_metadata: AdditionalMetadata, ) -> CronWorkflows: """ @@ -199,7 +199,7 @@ async def create( workflow_name: str, cron_name: str, expression: str, - input: dict, + input: Input, additional_metadata: AdditionalMetadata, ) -> CronWorkflows: """ diff --git a/hatchet_sdk/features/scheduled.py b/hatchet_sdk/features/scheduled.py index e215763b..58380948 100644 --- a/hatchet_sdk/features/scheduled.py +++ b/hatchet_sdk/features/scheduled.py @@ -15,7 +15,7 @@ from hatchet_sdk.clients.rest.models.workflow_run_order_by_direction import ( WorkflowRunOrderByDirection, ) -from hatchet_sdk.utils.types import AdditionalMetadata +from hatchet_sdk.utils.types import AdditionalMetadata, Input class CreateScheduledTriggerInput(BaseModel): @@ -28,7 +28,7 @@ class CreateScheduledTriggerInput(BaseModel): trigger_at (Optional[datetime.datetime]): The datetime when the run should be triggered. 
""" - input: dict[str, Any] = Field(default_factory=dict) + input: Input = Field(default_factory=dict) additional_metadata: AdditionalMetadata = Field(default_factory=dict) trigger_at: datetime.datetime | None = None @@ -58,7 +58,7 @@ def create( self, workflow_name: str, trigger_at: datetime.datetime, - input: Dict[str, Any], + input: Input, additional_metadata: AdditionalMetadata, ) -> ScheduledWorkflows: """ @@ -168,7 +168,7 @@ async def create( self, workflow_name: str, trigger_at: datetime.datetime, - input: dict[str, Any], + input: Input, additional_metadata: AdditionalMetadata, ) -> ScheduledWorkflows: """ diff --git a/hatchet_sdk/utils/types.py b/hatchet_sdk/utils/types.py index f02ab1f0..c3af00ea 100644 --- a/hatchet_sdk/utils/types.py +++ b/hatchet_sdk/utils/types.py @@ -8,4 +8,5 @@ class WorkflowValidator(BaseModel): step_output: Type[BaseModel] | None = None -AdditionalMetadata = dict[str, str] +AdditionalMetadata = dict[str, Any] +Input = dict[str, Any] diff --git a/hatchet_sdk/v2/callable.py b/hatchet_sdk/v2/callable.py index 097a7d87..fa4aae99 100644 --- a/hatchet_sdk/v2/callable.py +++ b/hatchet_sdk/v2/callable.py @@ -26,6 +26,7 @@ from hatchet_sdk.labels import DesiredWorkerLabel from hatchet_sdk.logger import logger from hatchet_sdk.rate_limit import RateLimit +from hatchet_sdk.utils.types import Input from hatchet_sdk.v2.concurrency import ConcurrencyFunction from hatchet_sdk.workflow_run import RunRef @@ -176,7 +177,7 @@ class DurableContext(Context): def run( self, function: str | Callable[[Context], Any], - input: dict[Any, Any] = {}, + input: Input = {}, key: str | None = None, options: ChildTriggerWorkflowOptions | None = None, ) -> "RunRef[T]": diff --git a/poetry.lock b/poetry.lock index 7e63063b..498f39b4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -580,13 +580,28 @@ protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4 [package.extras] grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] +[[package]] +name = "grpc-stubs" 
+version = "1.53.0.5" +description = "Mypy stubs for gRPC" +optional = false +python-versions = ">=3.6" +groups = ["dev"] +files = [ + {file = "grpc-stubs-1.53.0.5.tar.gz", hash = "sha256:3e1b642775cbc3e0c6332cfcedfccb022176db87e518757bef3a1241397be406"}, + {file = "grpc_stubs-1.53.0.5-py3-none-any.whl", hash = "sha256:04183fb65a1b166a1febb9627e3d9647d3926ccc2dfe049fe7b6af243428dbe1"}, +] + +[package.dependencies] +grpcio = "*" + [[package]] name = "grpcio" version = "1.69.0" description = "HTTP/2-based RPC framework" optional = false python-versions = ">=3.8" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "grpcio-1.69.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:2060ca95a8db295ae828d0fc1c7f38fb26ccd5edf9aa51a0f44251f5da332e97"}, {file = "grpcio-1.69.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:2e52e107261fd8fa8fa457fe44bfadb904ae869d87c1280bf60f93ecd3e79278"}, @@ -2082,4 +2097,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = "^3.10" -content-hash = "a51a43e75624789a2044790c6543dacf55f3b8ce2140c9dda1d1300d643df877" +content-hash = "59a1e9a4aafe7da78bfd9b85af64531167192e56dd0a46dc4c2e40e147cad40d" diff --git a/pyproject.toml b/pyproject.toml index 60b29c09..23ad379f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ prometheus-client = "^0.21.1" pytest = "^8.2.2" pytest-asyncio = "^0.23.8" psutil = "^6.0.0" +grpc-stubs = "^1.53.0.5" [tool.poetry.group.lint.dependencies] mypy = "^1.14.0" @@ -98,7 +99,8 @@ files = [ "hatchet_sdk/context/worker_context.py", "hatchet_sdk/clients/dispatcher/dispatcher.py", "hatchet_sdk/loader.py", - "hatchet_sdk/token.py" + "hatchet_sdk/token.py", + "hatchet_sdk/clients/admin.py", ] follow_imports = "silent" disable_error_code = ["unused-coroutine"] From 6c46ea3b151b49789065315f352fb3cb024b6c34 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Tue, 14 Jan 2025 16:31:31 -0500 Subject: [PATCH 17/53] feat: fix a bunch more type issues w/ pydantic --- 
examples/simple/event.py | 5 +- hatchet_sdk/clients/events.py | 98 +++++++++++++++++------------------ pyproject.toml | 1 + 3 files changed, 50 insertions(+), 54 deletions(-) diff --git a/examples/simple/event.py b/examples/simple/event.py index 4f31de73..bc8a9068 100644 --- a/examples/simple/event.py +++ b/examples/simple/event.py @@ -3,7 +3,7 @@ from dotenv import load_dotenv from hatchet_sdk import new_client -from hatchet_sdk.clients.events import BulkPushEventWithMetadata +from hatchet_sdk.clients.events import BulkPushEventOptions, BulkPushEventWithMetadata load_dotenv() @@ -34,8 +34,7 @@ result = client.event.bulk_push( - events, - options={"namespace": "bulk-test"}, + events, options=BulkPushEventOptions(namespace="bulk-test") ) print(result) diff --git a/hatchet_sdk/clients/events.py b/hatchet_sdk/clients/events.py index 7f6fdd05..95a5a949 100644 --- a/hatchet_sdk/clients/events.py +++ b/hatchet_sdk/clients/events.py @@ -1,7 +1,7 @@ import asyncio import datetime import json -from typing import Any, Dict, List, Optional, TypedDict +from typing import Any, Dict, List, Optional, TypedDict, cast from uuid import uuid4 import grpc @@ -19,6 +19,7 @@ from hatchet_sdk.contracts.events_pb2_grpc import EventsServiceStub from hatchet_sdk.utils.serialization import flatten from hatchet_sdk.utils.tracing import ( + OTEL_CARRIER_KEY, create_carrier, create_tracer, inject_carrier_into_metadata, @@ -30,14 +31,14 @@ from ..metadata import get_metadata -def new_event(conn, config: ClientConfig): +def new_event(conn: grpc.Channel, config: ClientConfig) -> "EventClient": return EventClient( - client=EventsServiceStub(conn), + client=EventsServiceStub(conn), # type: ignore[no-untyped-call] config=config, ) -def proto_timestamp_now(): +def proto_timestamp_now() -> timestamp_pb2.Timestamp: t = datetime.datetime.now().timestamp() seconds = int(t) nanos = int(t % 1 * 1e9) @@ -52,6 +53,7 @@ class PushEventOptions(BaseModel): class BulkPushEventOptions(BaseModel): namespace: str | 
None = None + otel_carrier: dict[str, str] = Field(default_factory=dict) class BulkPushEventWithMetadata(BaseModel): @@ -68,7 +70,10 @@ def __init__(self, client: EventsServiceStub, config: ClientConfig): self.otel_tracer = create_tracer(config=config) async def async_push( - self, event_key, payload, options: Optional[PushEventOptions] = None + self, + event_key: str, + payload: dict[str, Any], + options: PushEventOptions = PushEventOptions(), ) -> Event: return await asyncio.to_thread( self.push, event_key=event_key, payload=payload, options=options @@ -76,51 +81,47 @@ async def async_push( async def async_bulk_push( self, - events: List[BulkPushEventWithMetadata], - options: Optional[BulkPushEventOptions] = None, + events: list[BulkPushEventWithMetadata], + options: BulkPushEventOptions = BulkPushEventOptions(), ) -> List[Event]: return await asyncio.to_thread(self.bulk_push, events=events, options=options) @tenacity_retry - def push(self, event_key, payload, options: PushEventOptions = None) -> Event: - ctx = parse_carrier_from_metadata( - (options or {}).get("additional_metadata", {}) - ) + def push( + self, + event_key: str, + payload: dict[str, Any], + options: PushEventOptions = PushEventOptions(), + ) -> Event: + ctx = parse_carrier_from_metadata(options.additional_metadata) with self.otel_tracer.start_as_current_span( "hatchet.push", context=ctx ) as span: carrier = create_carrier() - namespace = self.namespace - - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options.pop("namespace") + namespace = options.namespace or self.namespace namespaced_event_key = namespace + event_key try: meta = inject_carrier_into_metadata( - dict() if options is None else options["additional_metadata"], + options.additional_metadata, carrier, ) - meta_bytes = None if meta is None else json.dumps(meta).encode("utf-8") + meta_bytes = None if meta is None else json.dumps(meta) except Exception as e: raise 
ValueError(f"Error encoding meta: {e}") span.set_attributes(flatten(meta, parent_key="", separator=".")) try: - payload_bytes = json.dumps(payload).encode("utf-8") - except json.UnicodeEncodeError as e: + payload_str = json.dumps(payload) + except (TypeError, ValueError) as e: raise ValueError(f"Error encoding payload: {e}") request = PushEventRequest( key=namespaced_event_key, - payload=payload_bytes, + payload=payload_str, eventTimestamp=proto_timestamp_now(), additionalMetadata=meta_bytes, ) @@ -128,7 +129,9 @@ def push(self, event_key, payload, options: PushEventOptions = None) -> Event: span.add_event("Pushing event", attributes={"key": namespaced_event_key}) try: - return self.client.Push(request, metadata=get_metadata(self.token)) + return cast( + Event, self.client.Push(request, metadata=get_metadata(self.token)) + ) except grpc.RpcError as e: raise ValueError(f"gRPC error: {e}") @@ -136,20 +139,12 @@ def push(self, event_key, payload, options: PushEventOptions = None) -> Event: def bulk_push( self, events: List[BulkPushEventWithMetadata], - options: BulkPushEventOptions = None, + options: BulkPushEventOptions, ) -> List[Event]: - namespace = self.namespace + namespace = options.namespace or self.namespace bulk_push_correlation_id = uuid4() - ctx = parse_carrier_from_metadata( - (options or {}).get("additional_metadata", {}) - ) - if ( - options is not None - and "namespace" in options - and options["namespace"] is not None - ): - namespace = options.pop("namespace") + ctx = parse_carrier_from_metadata({OTEL_CARRIER_KEY: options.otel_carrier}) bulk_events = [] for event in events: @@ -161,29 +156,27 @@ def bulk_push( "bulk_push_correlation_id", str(bulk_push_correlation_id) ) - event_key = namespace + event["key"] - payload = event["payload"] + event_key = namespace + event.key + payload = event.payload + + meta = inject_carrier_into_metadata(event.additional_metadata, carrier) + span.set_attributes(flatten(meta, parent_key="", separator=".")) try: - meta = 
inject_carrier_into_metadata( - event.get("additional_metadata", {}), carrier - ) - meta_bytes = json.dumps(meta).encode("utf-8") if meta else None + meta_str = json.dumps(meta) except Exception as e: raise ValueError(f"Error encoding meta: {e}") - span.set_attributes(flatten(meta, parent_key="", separator=".")) - try: - payload_bytes = json.dumps(payload).encode("utf-8") - except json.UnicodeEncodeError as e: + payload = json.dumps(payload) + except (TypeError, ValueError) as e: raise ValueError(f"Error encoding payload: {e}") request = PushEventRequest( key=event_key, - payload=payload_bytes, + payload=payload, eventTimestamp=proto_timestamp_now(), - additionalMetadata=meta_bytes, + additionalMetadata=meta_str, ) bulk_events.append(request) @@ -194,11 +187,14 @@ def bulk_push( response = self.client.BulkPush( bulk_request, metadata=get_metadata(self.token) ) - return response.events + return cast( + list[Event], + response.events, + ) except grpc.RpcError as e: raise ValueError(f"gRPC error: {e}") - def log(self, message: str, step_run_id: str): + def log(self, message: str, step_run_id: str) -> None: try: request = PutLogRequest( stepRunId=step_run_id, @@ -210,7 +206,7 @@ def log(self, message: str, step_run_id: str): except Exception as e: raise ValueError(f"Error logging: {e}") - def stream(self, data: str | bytes, step_run_id: str): + def stream(self, data: str | bytes, step_run_id: str) -> None: try: if isinstance(data, str): data_bytes = data.encode("utf-8") diff --git a/pyproject.toml b/pyproject.toml index 23ad379f..44a32f7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,7 @@ files = [ "hatchet_sdk/loader.py", "hatchet_sdk/token.py", "hatchet_sdk/clients/admin.py", + "hatchet_sdk/clients/events.py", ] follow_imports = "silent" disable_error_code = ["unused-coroutine"] From ec33c9b7e7ee3c54dc7f433725f8fd6f9b24c597 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Tue, 14 Jan 2025 16:34:13 -0500 Subject: [PATCH 18/53] fix: simple event trigger 
--- examples/simple/event.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/simple/event.py b/examples/simple/event.py index bc8a9068..17fa8e70 100644 --- a/examples/simple/event.py +++ b/examples/simple/event.py @@ -1,9 +1,7 @@ -from typing import List - from dotenv import load_dotenv from hatchet_sdk import new_client -from hatchet_sdk.clients.events import BulkPushEventOptions, BulkPushEventWithMetadata +from hatchet_sdk.clients.events import BulkPushEventOptions, BulkPushEventWithMetadata, PushEventOptions load_dotenv() @@ -11,7 +9,7 @@ # client.event.push("user:create", {"test": "test"}) client.event.push( - "user:create", {"test": "test"}, options={"additional_metadata": {"hello": "moon"}} + "user:create", {"test": "test"}, options=PushEventOptions(additional_metadata={"hello": "moon"}) ) events = [ From d9ba2c83927d83ba5f640d32a177143a213ceaf3 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Tue, 14 Jan 2025 16:45:40 -0500 Subject: [PATCH 19/53] fix: lint --- examples/simple/event.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/simple/event.py b/examples/simple/event.py index 17fa8e70..68b70a85 100644 --- a/examples/simple/event.py +++ b/examples/simple/event.py @@ -1,7 +1,11 @@ from dotenv import load_dotenv from hatchet_sdk import new_client -from hatchet_sdk.clients.events import BulkPushEventOptions, BulkPushEventWithMetadata, PushEventOptions +from hatchet_sdk.clients.events import ( + BulkPushEventOptions, + BulkPushEventWithMetadata, + PushEventOptions, +) load_dotenv() @@ -9,7 +13,9 @@ # client.event.push("user:create", {"test": "test"}) client.event.push( - "user:create", {"test": "test"}, options=PushEventOptions(additional_metadata={"hello": "moon"}) + "user:create", + {"test": "test"}, + options=PushEventOptions(additional_metadata={"hello": "moon"}), ) events = [ From 9000d51d6e6b988c84547fa6ee6b24455b25c414 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 
15:14:46 -0500 Subject: [PATCH 20/53] fix: lint --- hatchet_sdk/clients/events.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/hatchet_sdk/clients/events.py b/hatchet_sdk/clients/events.py index 438a0fe0..cb44f9e4 100644 --- a/hatchet_sdk/clients/events.py +++ b/hatchet_sdk/clients/events.py @@ -129,8 +129,8 @@ def push( span.add_event("Pushing event", attributes={"key": namespaced_event_key}) return cast( - Event, self.client.Push(request, metadata=get_metadata(self.token)) - ) + Event, self.client.Push(request, metadata=get_metadata(self.token)) + ) @tenacity_retry def bulk_push( @@ -183,9 +183,9 @@ def bulk_push( response = self.client.BulkPush(bulk_request, metadata=get_metadata(self.token)) return cast( - list[Event], - response.events, - ) + list[Event], + response.events, + ) def log(self, message: str, step_run_id: str) -> None: try: From 292fd4d95232fb35235561746545765a7d83a5b9 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 15:27:03 -0500 Subject: [PATCH 21/53] more helpful type hint --- hatchet_sdk/clients/admin.py | 24 +++++++++---------- .../clients/dispatcher/action_listener.py | 4 ++-- hatchet_sdk/clients/events.py | 6 ++--- hatchet_sdk/clients/rest_client.py | 22 ++++++++--------- hatchet_sdk/context/context.py | 4 ++-- hatchet_sdk/features/cron.py | 20 ++++++++-------- hatchet_sdk/features/scheduled.py | 18 +++++++------- hatchet_sdk/utils/types.py | 3 +-- hatchet_sdk/v2/callable.py | 4 ++-- 9 files changed, 52 insertions(+), 53 deletions(-) diff --git a/hatchet_sdk/clients/admin.py b/hatchet_sdk/clients/admin.py index 1c43ce3f..5c80fc7a 100644 --- a/hatchet_sdk/clients/admin.py +++ b/hatchet_sdk/clients/admin.py @@ -31,7 +31,7 @@ inject_carrier_into_metadata, parse_carrier_from_metadata, ) -from hatchet_sdk.utils.types import AdditionalMetadata, Input +from hatchet_sdk.utils.types import JSONSerializableDict from hatchet_sdk.workflow_run import RunRef, WorkflowRunRef from ..loader import ClientConfig 
@@ -52,26 +52,26 @@ class ScheduleTriggerWorkflowOptions(BaseModel): class ChildTriggerWorkflowOptions(BaseModel): - additional_metadata: AdditionalMetadata = Field(default_factory=dict) + additional_metadata: JSONSerializableDict = Field(default_factory=dict) sticky: bool | None = None class ChildWorkflowRunDict(BaseModel): workflow_name: str - input: Input + input: JSONSerializableDict options: ChildTriggerWorkflowOptions key: str | None = None class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions): - additional_metadata: AdditionalMetadata = Field(default_factory=dict) + additional_metadata: JSONSerializableDict = Field(default_factory=dict) desired_worker_id: str | None = None namespace: str | None = None class WorkflowRunDict(BaseModel): workflow_name: str - input: Input + input: JSONSerializableDict options: TriggerWorkflowOptions @@ -133,7 +133,7 @@ def _prepare_schedule_workflow_request( self, name: str, schedules: list[Union[datetime, timestamp_pb2.Timestamp]], - input: Input = {}, + input: JSONSerializableDict = {}, options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(), ) -> ScheduleWorkflowRequest: timestamp_schedules = [] @@ -175,7 +175,7 @@ def __init__(self, config: ClientConfig): async def run( self, function: Union[str, Callable[[Any], T]], - input: Input, + input: JSONSerializableDict, options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> "RunRef[T]": workflow_name = cast( @@ -197,7 +197,7 @@ async def run( async def run_workflow( self, workflow_name: str, - input: Input, + input: JSONSerializableDict, options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> WorkflowRunRef: ctx = parse_carrier_from_metadata(options.additional_metadata) @@ -335,7 +335,7 @@ async def schedule_workflow( self, name: str, schedules: list[Union[datetime, timestamp_pb2.Timestamp]], - input: Input = {}, + input: JSONSerializableDict = {}, options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(), ) -> 
WorkflowVersion: try: @@ -410,7 +410,7 @@ def schedule_workflow( self, name: str, schedules: list[Union[datetime, timestamp_pb2.Timestamp]], - input: Input = {}, + input: JSONSerializableDict = {}, options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(), ) -> WorkflowVersion: try: @@ -442,7 +442,7 @@ def schedule_workflow( def run_workflow( self, workflow_name: str, - input: Input, + input: JSONSerializableDict, options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> WorkflowRunRef: ctx = parse_carrier_from_metadata(options.additional_metadata) @@ -542,7 +542,7 @@ def run_workflows( def run( self, function: Union[str, Callable[[Any], T]], - input: Input, + input: JSONSerializableDict, options: TriggerWorkflowOptions = TriggerWorkflowOptions(), ) -> "RunRef[T]": workflow_name = cast( diff --git a/hatchet_sdk/clients/dispatcher/action_listener.py b/hatchet_sdk/clients/dispatcher/action_listener.py index 35560869..d38087a4 100644 --- a/hatchet_sdk/clients/dispatcher/action_listener.py +++ b/hatchet_sdk/clients/dispatcher/action_listener.py @@ -24,7 +24,7 @@ from hatchet_sdk.logger import logger from hatchet_sdk.utils.backoff import exp_backoff_sleep from hatchet_sdk.utils.serialization import flatten -from hatchet_sdk.utils.types import AdditionalMetadata +from hatchet_sdk.utils.types import JSONSerializableDict from ...loader import ClientConfig from ...metadata import get_metadata @@ -72,7 +72,7 @@ class Action: action_payload: str action_type: ActionType retry_count: int - additional_metadata: AdditionalMetadata = field(default_factory=dict) + additional_metadata: JSONSerializableDict = field(default_factory=dict) child_workflow_index: int | None = None child_workflow_key: str | None = None diff --git a/hatchet_sdk/clients/events.py b/hatchet_sdk/clients/events.py index cb44f9e4..c8d5556e 100644 --- a/hatchet_sdk/clients/events.py +++ b/hatchet_sdk/clients/events.py @@ -25,7 +25,7 @@ inject_carrier_into_metadata, 
parse_carrier_from_metadata, ) -from hatchet_sdk.utils.types import AdditionalMetadata +from hatchet_sdk.utils.types import JSONSerializableDict from ..loader import ClientConfig from ..metadata import get_metadata @@ -47,7 +47,7 @@ def proto_timestamp_now() -> timestamp_pb2.Timestamp: class PushEventOptions(BaseModel): - additional_metadata: AdditionalMetadata = Field(default_factory=dict) + additional_metadata: JSONSerializableDict = Field(default_factory=dict) namespace: str | None = None @@ -59,7 +59,7 @@ class BulkPushEventOptions(BaseModel): class BulkPushEventWithMetadata(BaseModel): key: str payload: Any - additional_metadata: AdditionalMetadata = Field(default_factory=dict) + additional_metadata: JSONSerializableDict = Field(default_factory=dict) class EventClient: diff --git a/hatchet_sdk/clients/rest_client.py b/hatchet_sdk/clients/rest_client.py index 9ab09192..9676ef09 100644 --- a/hatchet_sdk/clients/rest_client.py +++ b/hatchet_sdk/clients/rest_client.py @@ -66,7 +66,7 @@ WorkflowRunsCancelRequest, ) from hatchet_sdk.clients.rest.models.workflow_version import WorkflowVersion -from hatchet_sdk.utils.types import AdditionalMetadata, Input +from hatchet_sdk.utils.types import JSONSerializableDict class AsyncRestApi: @@ -212,7 +212,7 @@ async def workflow_run_bulk_cancel( async def workflow_run_create( self, workflow_id: str, - input: Input, + input: JSONSerializableDict, version: str | None = None, additional_metadata: list[str] | None = None, ) -> WorkflowRun: @@ -230,8 +230,8 @@ async def cron_create( workflow_name: str, cron_name: str, expression: str, - input: Input, - additional_metadata: AdditionalMetadata, + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, ): return await self.workflow_run_api.cron_workflow_trigger_create( tenant=self.tenant_id, @@ -279,8 +279,8 @@ async def schedule_create( self, name: str, trigger_at: datetime.datetime, - input: Input, - additional_metadata: AdditionalMetadata, + input: 
JSONSerializableDict, + additional_metadata: JSONSerializableDict, ): return await self.workflow_run_api.scheduled_workflow_run_create( tenant=self.tenant_id, @@ -471,7 +471,7 @@ def workflow_run_bulk_cancel( def workflow_run_create( self, workflow_id: str, - input: Input, + input: JSONSerializableDict, version: str | None = None, additional_metadata: list[str] | None = None, ) -> WorkflowRun: @@ -486,8 +486,8 @@ def cron_create( workflow_name: str, cron_name: str, expression: str, - input: Input, - additional_metadata: AdditionalMetadata, + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, ) -> CronWorkflows: return self._run_coroutine( self.aio.cron_create( @@ -525,8 +525,8 @@ def schedule_create( self, workflow_name: str, trigger_at: datetime.datetime, - input: Input, - additional_metadata: AdditionalMetadata, + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, ): return self._run_coroutine( self.aio.schedule_create( diff --git a/hatchet_sdk/context/context.py b/hatchet_sdk/context/context.py index 3a01b7b3..fa52219f 100644 --- a/hatchet_sdk/context/context.py +++ b/hatchet_sdk/context/context.py @@ -18,7 +18,7 @@ BulkTriggerWorkflowRequest, TriggerWorkflowRequest, ) -from hatchet_sdk.utils.types import Input, WorkflowValidator +from hatchet_sdk.utils.types import JSONSerializableDict, WorkflowValidator from hatchet_sdk.utils.typing import is_basemodel_subclass from hatchet_sdk.workflow_run import WorkflowRunRef @@ -101,7 +101,7 @@ def __init__( async def spawn_workflow( self, workflow_name: str, - input: Input = {}, + input: JSONSerializableDict = {}, key: str | None = None, options: ChildTriggerWorkflowOptions = ChildTriggerWorkflowOptions(), ) -> WorkflowRunRef: diff --git a/hatchet_sdk/features/cron.py b/hatchet_sdk/features/cron.py index ccad8ab3..e03da751 100644 --- a/hatchet_sdk/features/cron.py +++ b/hatchet_sdk/features/cron.py @@ -11,10 +11,10 @@ from 
hatchet_sdk.clients.rest.models.workflow_run_order_by_direction import ( WorkflowRunOrderByDirection, ) -from hatchet_sdk.utils.types import AdditionalMetadata, Input +from hatchet_sdk.utils.types import JSONSerializableDict -class CreateCronTriggerInput(BaseModel): +class CreateCronTriggerJSONSerializableDict(BaseModel): """ Schema for creating a workflow run triggered by a cron. @@ -25,8 +25,8 @@ class CreateCronTriggerInput(BaseModel): """ expression: str = None - input: Input = Field(default_factory=dict) - additional_metadata: AdditionalMetadata = Field(default_factory=dict) + input: JSONSerializableDict = Field(default_factory=dict) + additional_metadata: JSONSerializableDict = Field(default_factory=dict) @field_validator("expression") def validate_cron_expression(cls, v): @@ -87,8 +87,8 @@ def create( workflow_name: str, cron_name: str, expression: str, - input: Input, - additional_metadata: AdditionalMetadata, + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, ) -> CronWorkflows: """ Creates a new workflow cron trigger. @@ -103,7 +103,7 @@ def create( Returns: CronWorkflows: The created cron workflow instance. """ - validated_input = CreateCronTriggerInput( + validated_input = CreateCronTriggerJSONSerializableDict( expression=expression, input=input, additional_metadata=additional_metadata ) @@ -199,8 +199,8 @@ async def create( workflow_name: str, cron_name: str, expression: str, - input: Input, - additional_metadata: AdditionalMetadata, + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, ) -> CronWorkflows: """ Asynchronously creates a new workflow cron trigger. @@ -215,7 +215,7 @@ async def create( Returns: CronWorkflows: The created cron workflow instance. 
""" - validated_input = CreateCronTriggerInput( + validated_input = CreateCronTriggerJSONSerializableDict( expression=expression, input=input, additional_metadata=additional_metadata ) diff --git a/hatchet_sdk/features/scheduled.py b/hatchet_sdk/features/scheduled.py index 58380948..6c232f53 100644 --- a/hatchet_sdk/features/scheduled.py +++ b/hatchet_sdk/features/scheduled.py @@ -15,10 +15,10 @@ from hatchet_sdk.clients.rest.models.workflow_run_order_by_direction import ( WorkflowRunOrderByDirection, ) -from hatchet_sdk.utils.types import AdditionalMetadata, Input +from hatchet_sdk.utils.types import JSONSerializableDict -class CreateScheduledTriggerInput(BaseModel): +class CreateScheduledTriggerJSONSerializableDict(BaseModel): """ Schema for creating a scheduled workflow run. @@ -28,8 +28,8 @@ class CreateScheduledTriggerInput(BaseModel): trigger_at (Optional[datetime.datetime]): The datetime when the run should be triggered. """ - input: Input = Field(default_factory=dict) - additional_metadata: AdditionalMetadata = Field(default_factory=dict) + input: JSONSerializableDict = Field(default_factory=dict) + additional_metadata: JSONSerializableDict = Field(default_factory=dict) trigger_at: datetime.datetime | None = None @@ -58,8 +58,8 @@ def create( self, workflow_name: str, trigger_at: datetime.datetime, - input: Input, - additional_metadata: AdditionalMetadata, + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, ) -> ScheduledWorkflows: """ Creates a new scheduled workflow run asynchronously. @@ -74,7 +74,7 @@ def create( ScheduledWorkflows: The created scheduled workflow instance. 
""" - validated_input = CreateScheduledTriggerInput( + validated_input = CreateScheduledTriggerJSONSerializableDict( trigger_at=trigger_at, input=input, additional_metadata=additional_metadata ) @@ -168,8 +168,8 @@ async def create( self, workflow_name: str, trigger_at: datetime.datetime, - input: Input, - additional_metadata: AdditionalMetadata, + input: JSONSerializableDict, + additional_metadata: JSONSerializableDict, ) -> ScheduledWorkflows: """ Creates a new scheduled workflow run asynchronously. diff --git a/hatchet_sdk/utils/types.py b/hatchet_sdk/utils/types.py index c3af00ea..16ab43f6 100644 --- a/hatchet_sdk/utils/types.py +++ b/hatchet_sdk/utils/types.py @@ -8,5 +8,4 @@ class WorkflowValidator(BaseModel): step_output: Type[BaseModel] | None = None -AdditionalMetadata = dict[str, Any] -Input = dict[str, Any] +JSONSerializableDict = dict[str, Any] diff --git a/hatchet_sdk/v2/callable.py b/hatchet_sdk/v2/callable.py index fa4aae99..53c88dfd 100644 --- a/hatchet_sdk/v2/callable.py +++ b/hatchet_sdk/v2/callable.py @@ -26,7 +26,7 @@ from hatchet_sdk.labels import DesiredWorkerLabel from hatchet_sdk.logger import logger from hatchet_sdk.rate_limit import RateLimit -from hatchet_sdk.utils.types import Input +from hatchet_sdk.utils.types import JSONSerializableDict from hatchet_sdk.v2.concurrency import ConcurrencyFunction from hatchet_sdk.workflow_run import RunRef @@ -177,7 +177,7 @@ class DurableContext(Context): def run( self, function: str | Callable[[Context], Any], - input: Input = {}, + input: JSONSerializableDict = {}, key: str | None = None, options: ChildTriggerWorkflowOptions | None = None, ) -> "RunRef[T]": From 6648800538ced77c2608f8ee3b4344389ed70c91 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 15:33:08 -0500 Subject: [PATCH 22/53] feat: type checking everywhere + remove v2 --- examples/__init__.py | 0 examples/_deprecated/README.md | 1 - .../_deprecated/concurrency_limit_rr/event.py | 15 -- .../test_dep_concurrency_limit_rr.py | 
58 ----- .../concurrency_limit_rr/worker.py | 38 --- examples/_deprecated/test_event_client.py | 26 -- hatchet_sdk/v2/callable.py | 203 ---------------- hatchet_sdk/v2/concurrency.py | 47 ---- hatchet_sdk/v2/hatchet.py | 224 ------------------ pyproject.toml | 25 +- 10 files changed, 6 insertions(+), 631 deletions(-) create mode 100644 examples/__init__.py delete mode 100644 examples/_deprecated/README.md delete mode 100644 examples/_deprecated/concurrency_limit_rr/event.py delete mode 100644 examples/_deprecated/concurrency_limit_rr/test_dep_concurrency_limit_rr.py delete mode 100644 examples/_deprecated/concurrency_limit_rr/worker.py delete mode 100644 examples/_deprecated/test_event_client.py delete mode 100644 hatchet_sdk/v2/callable.py delete mode 100644 hatchet_sdk/v2/concurrency.py delete mode 100644 hatchet_sdk/v2/hatchet.py diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/_deprecated/README.md b/examples/_deprecated/README.md deleted file mode 100644 index ee47c61b..00000000 --- a/examples/_deprecated/README.md +++ /dev/null @@ -1 +0,0 @@ -The examples and tests in this directory are deprecated, but we're maintaining them to ensure backwards compatibility. 
diff --git a/examples/_deprecated/concurrency_limit_rr/event.py b/examples/_deprecated/concurrency_limit_rr/event.py deleted file mode 100644 index 16b2bcd0..00000000 --- a/examples/_deprecated/concurrency_limit_rr/event.py +++ /dev/null @@ -1,15 +0,0 @@ -from dotenv import load_dotenv - -from hatchet_sdk import new_client - -load_dotenv() - -client = new_client() - -for i in range(200): - group = "0" - - if i % 2 == 0: - group = "1" - - client.event.push("concurrency-test", {"group": group}) diff --git a/examples/_deprecated/concurrency_limit_rr/test_dep_concurrency_limit_rr.py b/examples/_deprecated/concurrency_limit_rr/test_dep_concurrency_limit_rr.py deleted file mode 100644 index 978186d1..00000000 --- a/examples/_deprecated/concurrency_limit_rr/test_dep_concurrency_limit_rr.py +++ /dev/null @@ -1,58 +0,0 @@ -import asyncio -import time - -import pytest - -from hatchet_sdk import Hatchet, Worker -from hatchet_sdk.workflow_run import WorkflowRunRef - - -# requires scope module or higher for shared event loop -@pytest.mark.parametrize("worker", ["concurrency_limit_rr"], indirect=True) -@pytest.mark.skip(reason="The timing for this test is not reliable") -@pytest.mark.asyncio(scope="session") -async def test_run(aiohatchet: Hatchet, worker: Worker) -> None: - num_groups = 2 - runs: list[WorkflowRunRef] = [] - - # Start all runs - for i in range(1, num_groups + 1): - run = aiohatchet.admin.run_workflow("ConcurrencyDemoWorkflowRR", {"group": i}) - runs.append(run) - run = aiohatchet.admin.run_workflow("ConcurrencyDemoWorkflowRR", {"group": i}) - runs.append(run) - - # Wait for all results - successful_runs = [] - cancelled_runs = [] - - start_time = time.time() - - # Process each run individually - for i, run in enumerate(runs, start=1): - try: - result = await run.result() - successful_runs.append((i, result)) - except Exception as e: - if "CANCELLED_BY_CONCURRENCY_LIMIT" in str(e): - cancelled_runs.append((i, str(e))) - else: - raise # Re-raise if it's an 
unexpected error - - end_time = time.time() - total_time = end_time - start_time - - # Check that we have the correct number of successful and cancelled runs - assert ( - len(successful_runs) == 4 - ), f"Expected 4 successful runs, got {len(successful_runs)}" - assert ( - len(cancelled_runs) == 0 - ), f"Expected 0 cancelled run, got {len(cancelled_runs)}" - - # Check that the total time is close to 2 seconds - assert ( - 3.8 <= total_time <= 7 - ), f"Expected runtime to be about 4 seconds, but it took {total_time:.2f} seconds" - - print(f"Total execution time: {total_time:.2f} seconds") diff --git a/examples/_deprecated/concurrency_limit_rr/worker.py b/examples/_deprecated/concurrency_limit_rr/worker.py deleted file mode 100644 index 9678e798..00000000 --- a/examples/_deprecated/concurrency_limit_rr/worker.py +++ /dev/null @@ -1,38 +0,0 @@ -import time - -from dotenv import load_dotenv - -from hatchet_sdk import ConcurrencyLimitStrategy, Context, Hatchet - -load_dotenv() - -hatchet = Hatchet(debug=True) - - -@hatchet.workflow(on_events=["concurrency-test"], schedule_timeout="10m") -class ConcurrencyDemoWorkflowRR: - - # NOTE: We're replacing the concurrency key function with a CEL expression - # to simplify architecture. - # See ../../concurrency_limit_rr/worker.py for the new implementation. 
- @hatchet.concurrency( - max_runs=1, limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN - ) - def concurrency(self, context: Context) -> str: - input = context.workflow_input() - print(input) - return f'group-{input["group"]}' - - @hatchet.step() - def step1(self, context: Context) -> None: - print("starting step1") - time.sleep(2) - print("finished step1") - pass - - -workflow = ConcurrencyDemoWorkflowRR() -worker = hatchet.worker("concurrency-demo-worker-rr", max_runs=10) -worker.register_workflow(workflow) - -worker.start() diff --git a/examples/_deprecated/test_event_client.py b/examples/_deprecated/test_event_client.py deleted file mode 100644 index 41c6ad65..00000000 --- a/examples/_deprecated/test_event_client.py +++ /dev/null @@ -1,26 +0,0 @@ -import pytest -from dotenv import load_dotenv - -from hatchet_sdk import new_client -from hatchet_sdk.hatchet import Hatchet - -load_dotenv() - - -@pytest.mark.asyncio(scope="session") -async def test_direct_client_event() -> None: - client = new_client() - e = client.event.push("user:create", {"test": "test"}) - - assert e.eventId is not None - - -@pytest.mark.filterwarnings( - "ignore:Direct access to client is deprecated:DeprecationWarning" -) -@pytest.mark.asyncio(scope="session") -async def test_hatchet_client_event() -> None: - hatchet = Hatchet() - e = hatchet.client.event.push("user:create", {"test": "test"}) - - assert e.eventId is not None diff --git a/hatchet_sdk/v2/callable.py b/hatchet_sdk/v2/callable.py deleted file mode 100644 index 53c88dfd..00000000 --- a/hatchet_sdk/v2/callable.py +++ /dev/null @@ -1,203 +0,0 @@ -import asyncio -from typing import ( - Any, - Callable, - Dict, - Generic, - List, - Optional, - TypedDict, - TypeVar, - Union, -) - -from hatchet_sdk.clients.admin import ChildTriggerWorkflowOptions -from hatchet_sdk.context.context import Context -from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined] - CreateStepRateLimit, - CreateWorkflowJobOpts, - 
CreateWorkflowStepOpts, - CreateWorkflowVersionOpts, - DesiredWorkerLabels, - StickyStrategy, - WorkflowConcurrencyOpts, - WorkflowKind, -) -from hatchet_sdk.labels import DesiredWorkerLabel -from hatchet_sdk.logger import logger -from hatchet_sdk.rate_limit import RateLimit -from hatchet_sdk.utils.types import JSONSerializableDict -from hatchet_sdk.v2.concurrency import ConcurrencyFunction -from hatchet_sdk.workflow_run import RunRef - -T = TypeVar("T") - - -class HatchetCallable(Generic[T]): - def __init__( - self, - func: Callable[[Context], T], - durable: bool = False, - name: str = "", - auto_register: bool = True, - on_events: list[str] | None = None, - on_crons: list[str] | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: List[RateLimit] | None = None, - concurrency: ConcurrencyFunction | None = None, - on_failure: Union["HatchetCallable[T]", None] = None, - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - default_priority: int | None = None, - ): - self.func = func - - on_events = on_events or [] - on_crons = on_crons or [] - - limits = None - if rate_limits: - limits = [rate_limit._req for rate_limit in rate_limits or []] - - self.function_desired_worker_labels = {} - - for key, d in desired_worker_labels.items(): - value = d["value"] if "value" in d else None - self.function_desired_worker_labels[key] = DesiredWorkerLabels( - strValue=str(value) if not isinstance(value, int) else None, - intValue=value if isinstance(value, int) else None, - required=d["required"] if "required" in d else None, - weight=d["weight"] if "weight" in d else None, - comparator=d["comparator"] if "comparator" in d else None, - ) - self.sticky = sticky - self.default_priority = default_priority - self.durable = durable - self.function_name = name.lower() or str(func.__name__).lower() - self.function_version = version - self.function_on_events = on_events - 
self.function_on_crons = on_crons - self.function_timeout = timeout - self.function_schedule_timeout = schedule_timeout - self.function_retries = retries - self.function_rate_limits = limits - self.function_concurrency = concurrency - self.function_on_failure = on_failure - self.function_namespace = "default" - self.function_auto_register = auto_register - - self.is_coroutine = False - - if asyncio.iscoroutinefunction(func): - self.is_coroutine = True - - def __call__(self, context: Context) -> T: - return self.func(context) - - def with_namespace(self, namespace: str) -> None: - if namespace is not None and namespace != "": - self.function_namespace = namespace - self.function_name = namespace + self.function_name - - def to_workflow_opts(self) -> CreateWorkflowVersionOpts: - kind: WorkflowKind = WorkflowKind.FUNCTION - - if self.durable: - kind = WorkflowKind.DURABLE - - on_failure_job: CreateWorkflowJobOpts | None = None - - if self.function_on_failure is not None: - on_failure_job = CreateWorkflowJobOpts( - name=self.function_name + "-on-failure", - steps=[ - self.function_on_failure.to_step(), - ], - ) - - concurrency: WorkflowConcurrencyOpts | None = None - - if self.function_concurrency is not None: - self.function_concurrency.set_namespace(self.function_namespace) - concurrency = WorkflowConcurrencyOpts( - action=self.function_concurrency.get_action_name(), - max_runs=self.function_concurrency.max_runs, - limit_strategy=self.function_concurrency.limit_strategy, - ) - - validated_priority = ( - max(1, min(3, self.default_priority)) if self.default_priority else None - ) - if validated_priority != self.default_priority: - logger.warning( - "Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range." 
- ) - - return CreateWorkflowVersionOpts( - name=self.function_name, - kind=kind, - version=self.function_version, - event_triggers=self.function_on_events, - cron_triggers=self.function_on_crons, - schedule_timeout=self.function_schedule_timeout, - sticky=self.sticky, - on_failure_job=on_failure_job, - concurrency=concurrency, - jobs=[ - CreateWorkflowJobOpts( - name=self.function_name, - steps=[ - self.to_step(), - ], - ) - ], - default_priority=validated_priority, - ) - - def to_step(self) -> CreateWorkflowStepOpts: - return CreateWorkflowStepOpts( - readable_id=self.function_name, - action=self.get_action_name(), - timeout=self.function_timeout, - inputs="{}", - parents=[], - retries=self.function_retries, - rate_limits=self.function_rate_limits, - worker_labels=self.function_desired_worker_labels, - ) - - def get_action_name(self) -> str: - return self.function_namespace + ":" + self.function_name - - -class DurableContext(Context): - def run( - self, - function: str | Callable[[Context], Any], - input: JSONSerializableDict = {}, - key: str | None = None, - options: ChildTriggerWorkflowOptions | None = None, - ) -> "RunRef[T]": - worker_id = self.worker.id() - - workflow_name = function - - if not isinstance(function, str): - workflow_name = function.function_name - - # if ( - # options is not None - # and "sticky" in options - # and options["sticky"] == True - # and not self.worker.has_workflow(workflow_name) - # ): - # raise Exception( - # f"cannot run with sticky: workflow {workflow_name} is not registered on the worker" - # ) - - trigger_options = self._prepare_workflow_options(key, options, worker_id) - - return self.admin_client.run(function, input, trigger_options) diff --git a/hatchet_sdk/v2/concurrency.py b/hatchet_sdk/v2/concurrency.py deleted file mode 100644 index 73d9e3b4..00000000 --- a/hatchet_sdk/v2/concurrency.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Any, Callable - -from hatchet_sdk.context.context import Context -from 
hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined] - ConcurrencyLimitStrategy, -) - - -class ConcurrencyFunction: - def __init__( - self, - func: Callable[[Context], str], - name: str = "concurrency", - max_runs: int = 1, - limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, - ): - self.func = func - self.name = name - self.max_runs = max_runs - self.limit_strategy = limit_strategy - self.namespace = "default" - - def set_namespace(self, namespace: str) -> None: - self.namespace = namespace - - def get_action_name(self) -> str: - return self.namespace + ":" + self.name - - def __call__(self, *args: Any, **kwargs: Any) -> str: - return self.func(*args, **kwargs) - - def __str__(self) -> str: - return f"{self.name}({self.max_runs})" - - def __repr__(self) -> str: - return f"{self.name}({self.max_runs})" - - -def concurrency( - name: str = "", - max_runs: int = 1, - limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, -) -> Callable[[Callable[[Context], str]], ConcurrencyFunction]: - def inner(func: Callable[[Context], str]) -> ConcurrencyFunction: - return ConcurrencyFunction(func, name, max_runs, limit_strategy) - - return inner diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py deleted file mode 100644 index 4dd3faf0..00000000 --- a/hatchet_sdk/v2/hatchet.py +++ /dev/null @@ -1,224 +0,0 @@ -from typing import Any, Callable, TypeVar, Union - -from hatchet_sdk import Worker -from hatchet_sdk.context.context import Context -from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined] - ConcurrencyLimitStrategy, - StickyStrategy, -) -from hatchet_sdk.hatchet import Hatchet as HatchetV1 -from hatchet_sdk.hatchet import workflow -from hatchet_sdk.labels import DesiredWorkerLabel -from hatchet_sdk.rate_limit import RateLimit -from hatchet_sdk.v2.callable import DurableContext, HatchetCallable -from hatchet_sdk.v2.concurrency import 
ConcurrencyFunction -from hatchet_sdk.worker.worker import register_on_worker - -T = TypeVar("T") - - -def function( - name: str = "", - auto_register: bool = True, - on_events: list[str] | None = None, - on_crons: list[str] | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: list[RateLimit] | None = None, - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: Union["HatchetCallable[T]", None] = None, - default_priority: int | None = None, -) -> Callable[[Callable[[Context], str]], HatchetCallable[T]]: - def inner(func: Callable[[Context], T]) -> HatchetCallable[T]: - return HatchetCallable( - func=func, - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) - - return inner - - -def durable( - name: str = "", - auto_register: bool = True, - on_events: list[str] | None = None, - on_crons: list[str] | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: list[RateLimit] | None = None, - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: HatchetCallable[T] | None = None, - default_priority: int | None = None, -) -> Callable[[HatchetCallable[T]], HatchetCallable[T]]: - def inner(func: HatchetCallable[T]) -> HatchetCallable[T]: - func.durable = True - - f = function( - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - 
schedule_timeout=schedule_timeout, - sticky=sticky, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) - - resp = f(func) - - resp.durable = True - - return resp - - return inner - - -def concurrency( - name: str = "concurrency", - max_runs: int = 1, - limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN, -) -> Callable[[Callable[[Context], str]], ConcurrencyFunction]: - def inner(func: Callable[[Context], str]) -> ConcurrencyFunction: - return ConcurrencyFunction(func, name, max_runs, limit_strategy) - - return inner - - -class Hatchet(HatchetV1): - dag = staticmethod(workflow) - concurrency = staticmethod(concurrency) - - functions: list[HatchetCallable[T]] = [] - - def function( - self, - name: str = "", - auto_register: bool = True, - on_events: list[str] | None = None, - on_crons: list[str] | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - retries: int = 0, - rate_limits: list[RateLimit] | None = None, - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: Union["HatchetCallable[T]", None] = None, - default_priority: int | None = None, - ) -> Callable[[Callable[[Context], Any]], Callable[[Context], Any]]: - resp = function( - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) - - def wrapper(func: Callable[[Context], str]) -> HatchetCallable[T]: - wrapped_resp = resp(func) - - if wrapped_resp.function_auto_register: - self.functions.append(wrapped_resp) - - 
wrapped_resp.with_namespace(self._client.config.namespace) - - return wrapped_resp - - return wrapper - - def durable( - self, - name: str = "", - auto_register: bool = True, - on_events: list[str] | None = None, - on_crons: list[str] | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: StickyStrategy = None, - retries: int = 0, - rate_limits: list[RateLimit] | None = None, - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - concurrency: ConcurrencyFunction | None = None, - on_failure: Union["HatchetCallable[T]", None] = None, - default_priority: int | None = None, - ) -> Callable[[Callable[[DurableContext], Any]], Callable[[DurableContext], Any]]: - resp = durable( - name=name, - auto_register=auto_register, - on_events=on_events, - on_crons=on_crons, - version=version, - timeout=timeout, - schedule_timeout=schedule_timeout, - sticky=sticky, - retries=retries, - rate_limits=rate_limits, - desired_worker_labels=desired_worker_labels, - concurrency=concurrency, - on_failure=on_failure, - default_priority=default_priority, - ) - - def wrapper(func: HatchetCallable[T]) -> HatchetCallable[T]: - wrapped_resp = resp(func) - - if wrapped_resp.function_auto_register: - self.functions.append(wrapped_resp) - - wrapped_resp.with_namespace(self._client.config.namespace) - - return wrapped_resp - - return wrapper - - def worker( - self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {} - ): - worker = Worker( - name=name, - max_runs=max_runs, - labels=labels, - config=self._client.config, - debug=self._client.debug, - ) - - for func in self.functions: - register_on_worker(func, worker) - - return worker diff --git a/pyproject.toml b/pyproject.toml index 523685ed..3c6026c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,27 +84,14 @@ extend_exclude = "hatchet_sdk/contracts/" [tool.mypy] strict = true files = [ - "hatchet_sdk/hatchet.py", - "hatchet_sdk/worker/worker.py", - 
"hatchet_sdk/context/context.py", - "hatchet_sdk/worker/runner/runner.py", - "hatchet_sdk/workflow.py", - "hatchet_sdk/utils/serialization.py", - "hatchet_sdk/utils/tracing.py", - "hatchet_sdk/utils/types.py", - "hatchet_sdk/utils/backoff.py", - "examples/**/*.py", - "hatchet_sdk/clients/rest/models/workflow_list.py", - "hatchet_sdk/clients/rest/models/workflow_run.py", - "hatchet_sdk/context/worker_context.py", - "hatchet_sdk/clients/dispatcher/dispatcher.py", - "hatchet_sdk/loader.py", - "hatchet_sdk/token.py", - "hatchet_sdk/clients/admin.py", - "hatchet_sdk/clients/events.py", + "." +] +exclude = [ + "hatchet_sdk/clients/rest", + "hatchet_sdk/clients/dispatcher", + "hatchet_sdk/contracts", ] follow_imports = "silent" -disable_error_code = ["unused-coroutine"] explicit_package_bases = true [tool.poetry.scripts] From db9af5a34f3f224f9aae77a83667878b953ce6ef Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 15:37:53 -0500 Subject: [PATCH 23/53] feat: rm more v2 stuff --- examples/bulk_fanout/stream.py | 3 +- examples/default_priority/worker.py | 3 +- .../durable_sticky_with_affinity/worker.py | 78 ------------------- examples/fanout/stream.py | 3 +- examples/fanout/sync_stream.py | 3 +- examples/sync_to_async/worker.py | 3 +- examples/v2/simple/test_v2_worker.py | 26 ------- examples/v2/simple/worker.py | 44 ----------- hatchet_sdk/clients/admin.py | 8 +- hatchet_sdk/hatchet.py | 1 - hatchet_sdk/metadata.py | 2 +- hatchet_sdk/worker/runner/runner.py | 19 +---- .../runner/utils/error_with_traceback.py | 2 +- hatchet_sdk/worker/worker.py | 21 ----- poetry.lock | 14 +++- pyproject.toml | 3 +- 16 files changed, 27 insertions(+), 206 deletions(-) delete mode 100644 examples/durable_sticky_with_affinity/worker.py delete mode 100644 examples/v2/simple/test_v2_worker.py delete mode 100644 examples/v2/simple/worker.py diff --git a/examples/bulk_fanout/stream.py b/examples/bulk_fanout/stream.py index 2eb03648..c0d03388 100644 --- 
a/examples/bulk_fanout/stream.py +++ b/examples/bulk_fanout/stream.py @@ -6,10 +6,9 @@ from dotenv import load_dotenv -from hatchet_sdk import new_client +from hatchet_sdk import Hatchet, new_client from hatchet_sdk.clients.admin import TriggerWorkflowOptions from hatchet_sdk.clients.run_event_listener import StepRunEventType -from hatchet_sdk.v2.hatchet import Hatchet async def main() -> None: diff --git a/examples/default_priority/worker.py b/examples/default_priority/worker.py index 070d20f9..13d0cda0 100644 --- a/examples/default_priority/worker.py +++ b/examples/default_priority/worker.py @@ -3,8 +3,7 @@ from dotenv import load_dotenv -from hatchet_sdk import Context -from hatchet_sdk.v2.hatchet import Hatchet +from hatchet_sdk import Context, Hatchet load_dotenv() diff --git a/examples/durable_sticky_with_affinity/worker.py b/examples/durable_sticky_with_affinity/worker.py deleted file mode 100644 index 0e6036c2..00000000 --- a/examples/durable_sticky_with_affinity/worker.py +++ /dev/null @@ -1,78 +0,0 @@ -import asyncio -from typing import Any - -from dotenv import load_dotenv - -from hatchet_sdk import ( - ChildTriggerWorkflowOptions, - Context, - StickyStrategy, - WorkerLabelComparator, -) -from hatchet_sdk.labels import DesiredWorkerLabel -from hatchet_sdk.v2.callable import DurableContext -from hatchet_sdk.v2.hatchet import Hatchet - -load_dotenv() - -hatchet = Hatchet(debug=True) - - -@hatchet.durable( - sticky=StickyStrategy.HARD, - desired_worker_labels={ - "running_workflow": DesiredWorkerLabel( - value="True", - required=True, - comparator=WorkerLabelComparator.NOT_EQUAL, - ), - }, -) -async def my_durable_func(context: DurableContext) -> dict[str, Any]: - try: - ref = await context.aio.spawn_workflow( - "StickyChildWorkflow", {}, options=ChildTriggerWorkflowOptions(sticky=True) - ) - result = await ref.result() - except Exception as e: - result = str(e) - - await context.worker.async_upsert_labels({"running_workflow": "False"}) - return 
{"worker_result": result} - - -@hatchet.workflow(on_events=["sticky:child"], sticky=StickyStrategy.HARD) -class StickyChildWorkflow: - @hatchet.step( - desired_worker_labels={ - "running_workflow": DesiredWorkerLabel( - value="True", - required=True, - comparator=WorkerLabelComparator.NOT_EQUAL, - ), - }, - ) - async def child(self, context: Context) -> dict[str, str | None]: - await context.worker.async_upsert_labels({"running_workflow": "True"}) - - print(f"Heavy work started on {context.worker.id()}") - await asyncio.sleep(15) - print(f"Finished Heavy work on {context.worker.id()}") - - return {"worker": context.worker.id()} - - -def main() -> None: - worker = hatchet.worker( - "sticky-worker", - max_runs=10, - labels={"running_workflow": "False"}, - ) - - worker.register_workflow(StickyChildWorkflow()) - - worker.start() - - -if __name__ == "__main__": - main() diff --git a/examples/fanout/stream.py b/examples/fanout/stream.py index 2eb03648..c0d03388 100644 --- a/examples/fanout/stream.py +++ b/examples/fanout/stream.py @@ -6,10 +6,9 @@ from dotenv import load_dotenv -from hatchet_sdk import new_client +from hatchet_sdk import Hatchet, new_client from hatchet_sdk.clients.admin import TriggerWorkflowOptions from hatchet_sdk.clients.run_event_listener import StepRunEventType -from hatchet_sdk.v2.hatchet import Hatchet async def main() -> None: diff --git a/examples/fanout/sync_stream.py b/examples/fanout/sync_stream.py index e05b510c..d035ddc3 100644 --- a/examples/fanout/sync_stream.py +++ b/examples/fanout/sync_stream.py @@ -6,10 +6,9 @@ from dotenv import load_dotenv -from hatchet_sdk import new_client +from hatchet_sdk import Hatchet from hatchet_sdk.clients.admin import TriggerWorkflowOptions from hatchet_sdk.clients.run_event_listener import StepRunEventType -from hatchet_sdk.v2.hatchet import Hatchet def main() -> None: diff --git a/examples/sync_to_async/worker.py b/examples/sync_to_async/worker.py index 6ee1636a..600b454d 100644 --- 
a/examples/sync_to_async/worker.py +++ b/examples/sync_to_async/worker.py @@ -5,9 +5,8 @@ from dotenv import load_dotenv -from hatchet_sdk import Context, sync_to_async +from hatchet_sdk import Context, Hatchet, sync_to_async from hatchet_sdk.clients.admin import ChildTriggerWorkflowOptions -from hatchet_sdk.v2.hatchet import Hatchet os.environ["PYTHONASYNCIODEBUG"] = "1" load_dotenv() diff --git a/examples/v2/simple/test_v2_worker.py b/examples/v2/simple/test_v2_worker.py deleted file mode 100644 index c06dae1f..00000000 --- a/examples/v2/simple/test_v2_worker.py +++ /dev/null @@ -1,26 +0,0 @@ -import pytest - -from examples.v2.simple.worker import MyResultType, my_durable_func, my_func -from hatchet_sdk import Hatchet, Worker -from hatchet_sdk.workflow_run import RunRef - - -# requires scope module or higher for shared event loop -@pytest.mark.asyncio(scope="session") -@pytest.mark.parametrize("worker", ["v2_simple"], indirect=True) -async def test_durable(hatchet: Hatchet, worker: Worker) -> None: - durable_run: RunRef[dict[str, str]] = hatchet.admin.run( - my_durable_func, {"test": "test"} - ) - result = await durable_run.result() - - assert result == {"my_durable_func": "testing123"} - - -@pytest.mark.asyncio(scope="session") -@pytest.mark.parametrize("worker", ["v2_simple"], indirect=True) -async def test_func(hatchet: Hatchet, worker: Worker) -> None: - durable_run: RunRef[MyResultType] = hatchet.admin.run(my_func, {"test": "test"}) - result = await durable_run.result() - - assert result == {"my_func": "testing123"} diff --git a/examples/v2/simple/worker.py b/examples/v2/simple/worker.py deleted file mode 100644 index 215bdd0e..00000000 --- a/examples/v2/simple/worker.py +++ /dev/null @@ -1,44 +0,0 @@ -import json -import time -from typing import Any, TypedDict, cast - -from dotenv import load_dotenv - -from hatchet_sdk import Context -from hatchet_sdk.v2.callable import DurableContext -from hatchet_sdk.v2.hatchet import Hatchet -from 
hatchet_sdk.workflow_run import RunRef - -load_dotenv() - -hatchet = Hatchet(debug=True) - - -class MyResultType(TypedDict): - my_func: str - - -@hatchet.function() -def my_func(context: Context) -> MyResultType: - return MyResultType(my_func="testing123") - - -@hatchet.durable() -async def my_durable_func(context: DurableContext) -> dict[str, MyResultType | None]: - result = cast(dict[str, Any], await context.run(my_func, {"test": "test"}).result()) - - context.log(result) - - return {"my_durable_func": result.get("my_func")} - - -def main() -> None: - worker = hatchet.worker("test-worker", max_runs=5) - - hatchet.admin.run(my_durable_func, {"test": "test"}) - - worker.start() - - -if __name__ == "__main__": - main() diff --git a/hatchet_sdk/clients/admin.py b/hatchet_sdk/clients/admin.py index 5c80fc7a..9d71c2ff 100644 --- a/hatchet_sdk/clients/admin.py +++ b/hatchet_sdk/clients/admin.py @@ -267,12 +267,12 @@ async def run_workflows( namespace = options.namespace or self.namespace - workflow_run_requests: TriggerWorkflowRequest = [] + workflow_run_requests: list[TriggerWorkflowRequest] = [] for workflow in workflows: - workflow_name = workflow["workflow_name"] - input_data = workflow["input"] - options = workflow["options"] + workflow_name = workflow.workflow_name + input_data = workflow.input + options = workflow.options if namespace != "" and not workflow_name.startswith(self.namespace): workflow_name = f"{namespace}{workflow_name}" diff --git a/hatchet_sdk/hatchet.py b/hatchet_sdk/hatchet.py index 396209c5..60e733ae 100644 --- a/hatchet_sdk/hatchet.py +++ b/hatchet_sdk/hatchet.py @@ -19,7 +19,6 @@ from hatchet_sdk.labels import DesiredWorkerLabel from hatchet_sdk.loader import ClientConfig from hatchet_sdk.rate_limit import RateLimit -from hatchet_sdk.v2.callable import HatchetCallable from .client import Client, new_client, new_client_raw from .clients.admin import AdminClient diff --git a/hatchet_sdk/metadata.py b/hatchet_sdk/metadata.py index 
38a31b8b..d4004c64 100644 --- a/hatchet_sdk/metadata.py +++ b/hatchet_sdk/metadata.py @@ -1,2 +1,2 @@ -def get_metadata(token: str): +def get_metadata(token: str) -> list[tuple[str, str]]: return [("authorization", "bearer " + token)] diff --git a/hatchet_sdk/worker/runner/runner.py b/hatchet_sdk/worker/runner/runner.py index f72fb04b..aa15b8d2 100644 --- a/hatchet_sdk/worker/runner/runner.py +++ b/hatchet_sdk/worker/runner/runner.py @@ -34,7 +34,6 @@ from hatchet_sdk.logger import logger from hatchet_sdk.utils.tracing import create_tracer, parse_carrier_from_metadata from hatchet_sdk.utils.types import WorkflowValidator -from hatchet_sdk.v2.callable import DurableContext from hatchet_sdk.worker.action_listener_process import ActionEvent from hatchet_sdk.worker.runner.utils.capture_logs import copy_context_vars, sr, wr @@ -276,23 +275,7 @@ def cleanup_run_id(self, run_id: str | None) -> None: if run_id in self.contexts: del self.contexts[run_id] - def create_context( - self, action: Action, action_func: Callable[..., Any] | None - ) -> Context | DurableContext: - if hasattr(action_func, "durable") and getattr(action_func, "durable"): - return DurableContext( - action, - self.dispatcher_client, - self.admin_client, - self.client.event, - self.client.rest, - self.client.workflow_listener, - self.workflow_run_event_listener, - self.worker_context, - self.client.config.namespace, - validator_registry=self.validator_registry, - ) - + def create_context(self, action: Action) -> Context: return Context( action, self.dispatcher_client, diff --git a/hatchet_sdk/worker/runner/utils/error_with_traceback.py b/hatchet_sdk/worker/runner/utils/error_with_traceback.py index 9c09602f..6aff1cb6 100644 --- a/hatchet_sdk/worker/runner/utils/error_with_traceback.py +++ b/hatchet_sdk/worker/runner/utils/error_with_traceback.py @@ -1,6 +1,6 @@ import traceback -def errorWithTraceback(message: str, e: Exception): +def errorWithTraceback(message: str, e: Exception) -> str: trace = 
"".join(traceback.format_exception(type(e), e, e.__traceback__)) return f"{message}\n{trace}" diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py index b6ec1531..e7622f53 100644 --- a/hatchet_sdk/worker/worker.py +++ b/hatchet_sdk/worker/worker.py @@ -24,8 +24,6 @@ from hatchet_sdk.logger import logger from hatchet_sdk.utils.types import WorkflowValidator from hatchet_sdk.utils.typing import is_basemodel_subclass -from hatchet_sdk.v2.callable import HatchetCallable -from hatchet_sdk.v2.concurrency import ConcurrencyFunction from hatchet_sdk.worker.action_listener_process import worker_action_listener_process from hatchet_sdk.worker.runner.run_loop_manager import WorkerActionRunLoopManager from hatchet_sdk.workflow import WorkflowInterface @@ -371,22 +369,3 @@ def exit_forcefully(self) -> None: sys.exit( 1 ) # Exit immediately TODO - should we exit with 1 here, there may be other workers to cleanup - - -def register_on_worker(callable: HatchetCallable[T], worker: Worker) -> None: - worker.register_function(callable.get_action_name(), callable) - - if callable.function_on_failure is not None: - worker.register_function( - callable.function_on_failure.get_action_name(), callable.function_on_failure - ) - - if callable.function_concurrency is not None: - worker.register_function( - callable.function_concurrency.get_action_name(), - callable.function_concurrency, - ) - - opts = callable.to_workflow_opts() - - worker.register_workflow_from_opts(opts.name, opts) diff --git a/poetry.lock b/poetry.lock index 498f39b4..2462047e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1856,6 +1856,18 @@ files = [ {file = "types_protobuf-5.29.1.20241207.tar.gz", hash = "sha256:2ebcadb8ab3ef2e3e2f067e0882906d64ba0dc65fc5b0fd7a8b692315b4a0be9"}, ] +[[package]] +name = "types-psutil" +version = "6.1.0.20241221" +description = "Typing stubs for psutil" +optional = false +python-versions = ">=3.8" +groups = ["lint"] +files = [ + {file = 
"types_psutil-6.1.0.20241221-py3-none-any.whl", hash = "sha256:8498dbe13285a9ba7d4b2fa934c569cc380efc74e3dacdb34ae16d2cdf389ec3"}, + {file = "types_psutil-6.1.0.20241221.tar.gz", hash = "sha256:600f5a36bd5e0eb8887f0e3f3ff2cf154d90690ad8123c8a707bba4ab94d3185"}, +] + [[package]] name = "typing-extensions" version = "4.12.2" @@ -2097,4 +2109,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = "^3.10" -content-hash = "59a1e9a4aafe7da78bfd9b85af64531167192e56dd0a46dc4c2e40e147cad40d" +content-hash = "e59b746d16c418856dbf00015dfb396a703e13961b53e0196bb511c530920e47" diff --git a/pyproject.toml b/pyproject.toml index 3c6026c0..30426c65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ mypy = "^1.14.0" types-protobuf = "^5.28.3.20241030" black = "^24.10.0" isort = "^5.13.2" +types-psutil = "^6.1.0.20241221" [tool.poetry.group.test.dependencies] pytest-timeout = "^2.3.1" @@ -117,4 +118,4 @@ existing_loop = "examples.worker_existing_loop.worker:main" bulk_fanout = "examples.bulk_fanout.worker:main" retries_with_backoff = "examples.retries_with_backoff.worker:main" pydantic = "examples.pydantic.worker:main" -v2_simple = "examples.v2.simple.worker:main" + From c79c79be588fc75f6bae6c88adc8e223d87d24e0 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 16:02:52 -0500 Subject: [PATCH 24/53] fix: lots of type hints --- conftest.py | 6 +- examples/default_priority/worker.py | 39 ------- examples/sync_to_async/worker.py | 100 ------------------ hatchet_sdk/__init__.py | 5 +- hatchet_sdk/client.py | 14 +-- hatchet_sdk/clients/admin.py | 6 +- .../clients/dispatcher/action_listener.py | 2 +- hatchet_sdk/clients/dispatcher/dispatcher.py | 2 +- hatchet_sdk/clients/event_ts.py | 6 +- hatchet_sdk/connection.py | 10 +- hatchet_sdk/features/cron.py | 99 ++++++++++------- hatchet_sdk/features/scheduled.py | 100 +++++++++++------- hatchet_sdk/rate_limit.py | 7 +- hatchet_sdk/worker/action_listener_process.py | 22 ++-- 
hatchet_sdk/worker/runner/run_loop_manager.py | 13 +-- hatchet_sdk/worker/runner/runner.py | 2 +- hatchet_sdk/worker/worker.py | 1 + hatchet_sdk/workflow_run.py | 6 +- 18 files changed, 173 insertions(+), 267 deletions(-) delete mode 100644 examples/default_priority/worker.py delete mode 100644 examples/sync_to_async/worker.py diff --git a/conftest.py b/conftest.py index acd22dd8..e10df408 100644 --- a/conftest.py +++ b/conftest.py @@ -4,7 +4,7 @@ import time from io import BytesIO from threading import Thread -from typing import AsyncGenerator, Callable, cast +from typing import AsyncGenerator, Callable, Generator, cast import psutil import pytest @@ -24,7 +24,9 @@ def hatchet() -> Hatchet: @pytest.fixture() -def worker(request: pytest.FixtureRequest): +def worker( + request: pytest.FixtureRequest, +) -> Generator[subprocess.Popen[bytes], None, None]: example = cast(str, request.param) command = ["poetry", "run", example] diff --git a/examples/default_priority/worker.py b/examples/default_priority/worker.py deleted file mode 100644 index 13d0cda0..00000000 --- a/examples/default_priority/worker.py +++ /dev/null @@ -1,39 +0,0 @@ -import asyncio -from typing import TypedDict - -from dotenv import load_dotenv - -from hatchet_sdk import Context, Hatchet - -load_dotenv() - -hatchet = Hatchet(debug=True) - - -class MyResultType(TypedDict): - return_string: str - - -@hatchet.function(default_priority=2) -async def high_prio_func(context: Context) -> MyResultType: - await asyncio.sleep(5) - return MyResultType(return_string="High Priority Return") - - -@hatchet.function(default_priority=1) -async def low_prio_func(context: Context) -> MyResultType: - await asyncio.sleep(5) - return MyResultType(return_string="Low Priority Return") - - -def main() -> None: - worker = hatchet.worker("example-priority-worker", max_runs=1) - hatchet.admin.run(high_prio_func, {"test": "test"}) - hatchet.admin.run(high_prio_func, {"test": "test"}) - hatchet.admin.run(low_prio_func, {"test": 
"test"}) - hatchet.admin.run(low_prio_func, {"test": "test"}) - worker.start() - - -if __name__ == "__main__": - main() diff --git a/examples/sync_to_async/worker.py b/examples/sync_to_async/worker.py deleted file mode 100644 index 600b454d..00000000 --- a/examples/sync_to_async/worker.py +++ /dev/null @@ -1,100 +0,0 @@ -import asyncio -import os -import time -from typing import Any - -from dotenv import load_dotenv - -from hatchet_sdk import Context, Hatchet, sync_to_async -from hatchet_sdk.clients.admin import ChildTriggerWorkflowOptions - -os.environ["PYTHONASYNCIODEBUG"] = "1" -load_dotenv() - -hatchet = Hatchet(debug=True) - - -@hatchet.function() -async def fanout_sync_async(context: Context) -> dict[str, Any]: - print("spawning child") - - context.put_stream("spawning...") - results = [] - - n = context.workflow_input().get("n", 10) - - start_time = time.time() - for i in range(n): - results.append( - ( - await context.aio.spawn_workflow( - "Child", - {"a": str(i)}, - key=f"child{i}", - options=ChildTriggerWorkflowOptions( - additional_metadata={"hello": "earth"} - ), - ) - ).result() - ) - - result = await asyncio.gather(*results) - - execution_time = time.time() - start_time - print(f"Completed in {execution_time:.2f} seconds") - - return {"results": result} - - -@hatchet.workflow(on_events=["child:create"]) -class Child: - ###### Example Functions ###### - def sync_blocking_function(self) -> dict[str, str]: - time.sleep(5) - return {"type": "sync_blocking"} - - @sync_to_async # this makes the function async safe! - def decorated_sync_blocking_function(self) -> dict[str, str]: - time.sleep(5) - return {"type": "decorated_sync_blocking"} - - @sync_to_async # this makes the async function loop safe! 
- async def async_blocking_function(self) -> dict[str, str]: - time.sleep(5) - return {"type": "async_blocking"} - - ###### Hatchet Steps ###### - @hatchet.step() - async def handle_blocking_sync_in_async(self, context: Context) -> dict[str, str]: - wrapped_blocking_function = sync_to_async(self.sync_blocking_function) - - # This will now be async safe! - data = await wrapped_blocking_function() - return {"blocking_status": "success", "data": data} - - @hatchet.step() - async def handle_decorated_blocking_sync_in_async( - self, context: Context - ) -> dict[str, str]: - data = await self.decorated_sync_blocking_function() - return {"blocking_status": "success", "data": data} - - @hatchet.step() - async def handle_blocking_async_in_async(self, context: Context) -> dict[str, str]: - data = await self.async_blocking_function() - return {"blocking_status": "success", "data": data} - - @hatchet.step() - async def non_blocking_async(self, context: Context) -> dict[str, str]: - await asyncio.sleep(5) - return {"nonblocking_status": "success"} - - -def main() -> None: - worker = hatchet.worker("fanout-worker", max_runs=50) - worker.register_workflow(Child()) - worker.start() - - -if __name__ == "__main__": - main() diff --git a/hatchet_sdk/__init__.py b/hatchet_sdk/__init__.py index 3162c25c..fc81cced 100644 --- a/hatchet_sdk/__init__.py +++ b/hatchet_sdk/__init__.py @@ -137,8 +137,9 @@ from .clients.run_event_listener import StepRunEventType, WorkflowRunEventType from .context.context import Context from .context.worker_context import WorkerContext -from .hatchet import ClientConfig, Hatchet, concurrency, on_failure_step, step, workflow -from .worker import Worker, WorkerStartOptions, WorkerStatus +from .hatchet import Hatchet, concurrency, on_failure_step, step, workflow +from .loader import ClientConfig +from .worker.worker import Worker, WorkerStartOptions, WorkerStatus from .workflow import ConcurrencyExpression __all__ = [ diff --git a/hatchet_sdk/client.py 
b/hatchet_sdk/client.py index a956a1d5..4baf2702 100644 --- a/hatchet_sdk/client.py +++ b/hatchet_sdk/client.py @@ -30,7 +30,7 @@ def from_environment( defaults: ClientConfig = ClientConfig(), debug: bool = False, *opts_functions: Callable[[ClientConfig], None], - ): + ) -> "Client": try: loop = asyncio.get_running_loop() except RuntimeError: @@ -47,7 +47,7 @@ def from_config( cls, config: ClientConfig = ClientConfig(), debug: bool = False, - ): + ) -> "Client": try: loop = asyncio.get_running_loop() except RuntimeError: @@ -60,7 +60,7 @@ def from_config( if config.host_port is None: raise ValueError("Host and port are required") - conn: grpc.Channel = new_conn(config) + conn: grpc.Channel = new_conn(config, False) # Instantiate clients event_client = new_event(conn, config) @@ -106,13 +106,5 @@ def __init__( self.debug = debug -def with_host_port(host: str, port: int): - def with_host_port_impl(config: ClientConfig): - config.host = host - config.port = port - - return with_host_port_impl - - new_client = Client.from_environment new_client_raw = Client.from_config diff --git a/hatchet_sdk/clients/admin.py b/hatchet_sdk/clients/admin.py index 9d71c2ff..518bb38f 100644 --- a/hatchet_sdk/clients/admin.py +++ b/hatchet_sdk/clients/admin.py @@ -281,10 +281,8 @@ async def run_workflows( request = self._prepare_workflow_request(workflow_name, input_data, options) workflow_run_requests.append(request) - request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests) - resp: BulkTriggerWorkflowResponse = await self.aio_client.BulkTriggerWorkflow( - request, + BulkTriggerWorkflowRequest(workflows=workflow_run_requests), metadata=get_metadata(self.token), ) @@ -364,7 +362,7 @@ async def schedule_workflow( class AdminClient(AdminClientBase): def __init__(self, config: ClientConfig): - conn = new_conn(config) + conn = new_conn(config, False) self.config = config self.client = WorkflowServiceStub(conn) # type: ignore[no-untyped-call] self.aio = AdminClientAioImpl(config) 
diff --git a/hatchet_sdk/clients/dispatcher/action_listener.py b/hatchet_sdk/clients/dispatcher/action_listener.py index d38087a4..aec9869e 100644 --- a/hatchet_sdk/clients/dispatcher/action_listener.py +++ b/hatchet_sdk/clients/dispatcher/action_listener.py @@ -139,7 +139,7 @@ class ActionListener: missed_heartbeats: int = field(default=0, init=False) def __post_init__(self): - self.client = DispatcherStub(new_conn(self.config)) + self.client = DispatcherStub(new_conn(self.config, False)) self.aio_client = DispatcherStub(new_conn(self.config, True)) self.token = self.config.token diff --git a/hatchet_sdk/clients/dispatcher/dispatcher.py b/hatchet_sdk/clients/dispatcher/dispatcher.py index e52aca2a..e956557e 100644 --- a/hatchet_sdk/clients/dispatcher/dispatcher.py +++ b/hatchet_sdk/clients/dispatcher/dispatcher.py @@ -41,7 +41,7 @@ class DispatcherClient: config: ClientConfig def __init__(self, config: ClientConfig): - conn = new_conn(config) + conn = new_conn(config, False) self.client = DispatcherStub(conn) # type: ignore[no-untyped-call] aio_conn = new_conn(config, True) diff --git a/hatchet_sdk/clients/event_ts.py b/hatchet_sdk/clients/event_ts.py index 1d3c3978..c32ea8c0 100644 --- a/hatchet_sdk/clients/event_ts.py +++ b/hatchet_sdk/clients/event_ts.py @@ -7,16 +7,16 @@ class Event_ts(asyncio.Event): Event_ts is a subclass of asyncio.Event that allows for thread-safe setting and clearing of the event. 
""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) if self._loop is None: self._loop = asyncio.get_event_loop() - def set(self): + def set(self) -> None: if not self._loop.is_closed(): self._loop.call_soon_threadsafe(super().set) - def clear(self): + def clear(self) -> None: self._loop.call_soon_threadsafe(super().clear) diff --git a/hatchet_sdk/connection.py b/hatchet_sdk/connection.py index 185395e4..787ad7c1 100644 --- a/hatchet_sdk/connection.py +++ b/hatchet_sdk/connection.py @@ -1,5 +1,5 @@ import os -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal, overload import grpc @@ -7,7 +7,13 @@ from hatchet_sdk.loader import ClientConfig -def new_conn(config: "ClientConfig", aio=False): +def new_conn(config: "ClientConfig", aio: Literal[False]) -> grpc.Channel: ... + + +def new_conn(config: "ClientConfig", aio: Literal[True]) -> grpc.aio.Channel: ... + + +def new_conn(config: "ClientConfig", aio: bool) -> grpc.Channel | grpc.aio.Channel: credentials: grpc.ChannelCredentials | None = None diff --git a/hatchet_sdk/features/cron.py b/hatchet_sdk/features/cron.py index e03da751..a03a1c19 100644 --- a/hatchet_sdk/features/cron.py +++ b/hatchet_sdk/features/cron.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Any, Union, cast from pydantic import BaseModel, Field, field_validator @@ -24,12 +24,13 @@ class CreateCronTriggerJSONSerializableDict(BaseModel): additional_metadata (dict[str, str]): Additional metadata associated with the cron trigger (e.g. {"key1": "value1", "key2": "value2"}). 
""" - expression: str = None + expression: str input: JSONSerializableDict = Field(default_factory=dict) additional_metadata: JSONSerializableDict = Field(default_factory=dict) @field_validator("expression") - def validate_cron_expression(cls, v): + @classmethod + def validate_cron_expression(cls, v: str) -> str: """ Validates the cron expression to ensure it adheres to the expected format. @@ -122,10 +123,11 @@ def delete(self, cron_trigger: Union[str, CronWorkflows]) -> None: Args: cron_trigger (Union[str, CronWorkflows]): The cron trigger ID or CronWorkflows instance to delete. """ - id_ = cron_trigger - if isinstance(cron_trigger, CronWorkflows): - id_ = cron_trigger.metadata.id - self._client.rest.cron_delete(id_) + self._client.rest.cron_delete( + cron_trigger.metadata.id + if isinstance(cron_trigger, CronWorkflows) + else cron_trigger + ) def list( self, @@ -150,13 +152,16 @@ def list( Returns: CronWorkflowsList: A list of cron workflows. """ - return self._client.rest.cron_list( - offset=offset, - limit=limit, - workflow_id=workflow_id, - additional_metadata=additional_metadata, - order_by_field=order_by_field, - order_by_direction=order_by_direction, + return cast( + CronWorkflowsList, + self._client.rest.cron_list( + offset=offset, + limit=limit, + workflow_id=workflow_id, + additional_metadata=additional_metadata, + order_by_field=order_by_field, + order_by_direction=order_by_direction, + ), ) def get(self, cron_trigger: Union[str, CronWorkflows]) -> CronWorkflows: @@ -169,10 +174,14 @@ def get(self, cron_trigger: Union[str, CronWorkflows]) -> CronWorkflows: Returns: CronWorkflows: The requested cron workflow instance. 
""" - id_ = cron_trigger - if isinstance(cron_trigger, CronWorkflows): - id_ = cron_trigger.metadata.id - return self._client.rest.cron_get(id_) + return cast( + CronWorkflows, + self._client.rest.cron_get( + cron_trigger.metadata.id + if isinstance(cron_trigger, CronWorkflows) + else cron_trigger + ), + ) class CronClientAsync: @@ -219,12 +228,15 @@ async def create( expression=expression, input=input, additional_metadata=additional_metadata ) - return await self._client.rest.aio.cron_create( - workflow_name=workflow_name, - cron_name=cron_name, - expression=validated_input.expression, - input=validated_input.input, - additional_metadata=validated_input.additional_metadata, + return cast( + CronWorkflows, + await self._client.rest.aio.cron_create( + workflow_name=workflow_name, + cron_name=cron_name, + expression=validated_input.expression, + input=validated_input.input, + additional_metadata=validated_input.additional_metadata, + ), ) async def delete(self, cron_trigger: Union[str, CronWorkflows]) -> None: @@ -234,10 +246,11 @@ async def delete(self, cron_trigger: Union[str, CronWorkflows]) -> None: Args: cron_trigger (Union[str, CronWorkflows]): The cron trigger ID or CronWorkflows instance to delete. """ - id_ = cron_trigger - if isinstance(cron_trigger, CronWorkflows): - id_ = cron_trigger.metadata.id - await self._client.rest.aio.cron_delete(id_) + await self._client.rest.aio.cron_delete( + cron_trigger.metadata.id + if isinstance(cron_trigger, CronWorkflows) + else cron_trigger + ) async def list( self, @@ -262,13 +275,16 @@ async def list( Returns: CronWorkflowsList: A list of cron workflows. 
""" - return await self._client.rest.aio.cron_list( - offset=offset, - limit=limit, - workflow_id=workflow_id, - additional_metadata=additional_metadata, - order_by_field=order_by_field, - order_by_direction=order_by_direction, + return cast( + CronWorkflowsList, + await self._client.rest.aio.cron_list( + offset=offset, + limit=limit, + workflow_id=workflow_id, + additional_metadata=additional_metadata, + order_by_field=order_by_field, + order_by_direction=order_by_direction, + ), ) async def get(self, cron_trigger: Union[str, CronWorkflows]) -> CronWorkflows: @@ -281,7 +297,12 @@ async def get(self, cron_trigger: Union[str, CronWorkflows]) -> CronWorkflows: Returns: CronWorkflows: The requested cron workflow instance. """ - id_ = cron_trigger - if isinstance(cron_trigger, CronWorkflows): - id_ = cron_trigger.metadata.id - return await self._client.rest.aio.cron_get(id_) + + return cast( + CronWorkflows, + await self._client.rest.aio.cron_get( + cron_trigger.metadata.id + if isinstance(cron_trigger, CronWorkflows) + else cron_trigger + ), + ) diff --git a/hatchet_sdk/features/scheduled.py b/hatchet_sdk/features/scheduled.py index 6c232f53..312ad527 100644 --- a/hatchet_sdk/features/scheduled.py +++ b/hatchet_sdk/features/scheduled.py @@ -1,5 +1,5 @@ import datetime -from typing import Any, Coroutine, Dict, List, Optional, Union +from typing import Any, Coroutine, Dict, List, Optional, Union, cast from pydantic import BaseModel, Field @@ -30,7 +30,7 @@ class CreateScheduledTriggerJSONSerializableDict(BaseModel): input: JSONSerializableDict = Field(default_factory=dict) additional_metadata: JSONSerializableDict = Field(default_factory=dict) - trigger_at: datetime.datetime | None = None + trigger_at: datetime.datetime class ScheduledClient: @@ -78,11 +78,14 @@ def create( trigger_at=trigger_at, input=input, additional_metadata=additional_metadata ) - return self._client.rest.schedule_create( - workflow_name, - validated_input.trigger_at, - validated_input.input, - 
validated_input.additional_metadata, + return cast( + ScheduledWorkflows, + self._client.rest.schedule_create( + workflow_name, + validated_input.trigger_at, + validated_input.input, + validated_input.additional_metadata, + ), ) def delete(self, scheduled: Union[str, ScheduledWorkflows]) -> None: @@ -92,10 +95,11 @@ def delete(self, scheduled: Union[str, ScheduledWorkflows]) -> None: Args: scheduled (Union[str, ScheduledWorkflows]): The scheduled workflow trigger ID or ScheduledWorkflows instance to delete. """ - id_ = scheduled - if isinstance(scheduled, ScheduledWorkflows): - id_ = scheduled.metadata.id - self._client.rest.schedule_delete(id_) + self._client.rest.schedule_delete( + scheduled.metadata.id + if isinstance(scheduled, ScheduledWorkflows) + else scheduled + ) def list( self, @@ -120,13 +124,16 @@ def list( Returns: List[ScheduledWorkflows]: A list of scheduled workflows matching the criteria. """ - return self._client.rest.schedule_list( - offset=offset, - limit=limit, - workflow_id=workflow_id, - additional_metadata=additional_metadata, - order_by_field=order_by_field, - order_by_direction=order_by_direction, + return cast( + ScheduledWorkflowsList, + self._client.rest.schedule_list( + offset=offset, + limit=limit, + workflow_id=workflow_id, + additional_metadata=additional_metadata, + order_by_field=order_by_field, + order_by_direction=order_by_direction, + ), ) def get(self, scheduled: Union[str, ScheduledWorkflows]) -> ScheduledWorkflows: @@ -139,10 +146,14 @@ def get(self, scheduled: Union[str, ScheduledWorkflows]) -> ScheduledWorkflows: Returns: ScheduledWorkflows: The requested scheduled workflow instance. 
""" - id_ = scheduled - if isinstance(scheduled, ScheduledWorkflows): - id_ = scheduled.metadata.id - return self._client.rest.schedule_get(id_) + return cast( + ScheduledWorkflows, + self._client.rest.schedule_get( + scheduled.metadata.id + if isinstance(scheduled, ScheduledWorkflows) + else scheduled + ), + ) class ScheduledClientAsync: @@ -183,8 +194,11 @@ async def create( Returns: ScheduledWorkflows: The created scheduled workflow instance. """ - return await self._client.rest.aio.schedule_create( - workflow_name, trigger_at, input, additional_metadata + return cast( + ScheduledWorkflows, + await self._client.rest.aio.schedule_create( + workflow_name, trigger_at, input, additional_metadata + ), ) async def delete(self, scheduled: Union[str, ScheduledWorkflows]) -> None: @@ -194,10 +208,11 @@ async def delete(self, scheduled: Union[str, ScheduledWorkflows]) -> None: Args: scheduled (Union[str, ScheduledWorkflows]): The scheduled workflow trigger ID or ScheduledWorkflows instance to delete. """ - id_ = scheduled - if isinstance(scheduled, ScheduledWorkflows): - id_ = scheduled.metadata.id - await self._client.rest.aio.schedule_delete(id_) + await self._client.rest.aio.schedule_delete( + scheduled.metadata.id + if isinstance(scheduled, ScheduledWorkflows) + else scheduled + ) async def list( self, @@ -222,13 +237,16 @@ async def list( Returns: ScheduledWorkflowsList: A list of scheduled workflows matching the criteria. 
""" - return await self._client.rest.aio.schedule_list( - offset=offset, - limit=limit, - workflow_id=workflow_id, - additional_metadata=additional_metadata, - order_by_field=order_by_field, - order_by_direction=order_by_direction, + return cast( + ScheduledWorkflowsList, + await self._client.rest.aio.schedule_list( + offset=offset, + limit=limit, + workflow_id=workflow_id, + additional_metadata=additional_metadata, + order_by_field=order_by_field, + order_by_direction=order_by_direction, + ), ) async def get( @@ -243,7 +261,11 @@ async def get( Returns: ScheduledWorkflows: The requested scheduled workflow instance. """ - id_ = scheduled - if isinstance(scheduled, ScheduledWorkflows): - id_ = scheduled.metadata.id - return await self._client.rest.aio.schedule_get(id_) + return cast( + ScheduledWorkflows, + await self._client.rest.aio.schedule_get( + scheduled.metadata.id + if isinstance(scheduled, ScheduledWorkflows) + else scheduled + ), + ) diff --git a/hatchet_sdk/rate_limit.py b/hatchet_sdk/rate_limit.py index 0d7b9143..f651cdd8 100644 --- a/hatchet_sdk/rate_limit.py +++ b/hatchet_sdk/rate_limit.py @@ -1,7 +1,8 @@ from dataclasses import dataclass +from enum import Enum from typing import Union -from celpy import CELEvalError, Environment +from celpy import CELEvalError, Environment # type: ignore from hatchet_sdk.contracts.workflows_pb2 import CreateStepRateLimit @@ -15,7 +16,7 @@ def validate_cel_expression(expr: str) -> bool: return False -class RateLimitDuration: +class RateLimitDuration(str, Enum): SECOND = "SECOND" MINUTE = "MINUTE" HOUR = "HOUR" @@ -73,7 +74,7 @@ class RateLimit: _req: CreateStepRateLimit = None - def __post_init__(self): + def __post_init__(self) -> None: # juggle the key and key_expr fields key = self.static_key key_expression = self.dynamic_key diff --git a/hatchet_sdk/worker/action_listener_process.py b/hatchet_sdk/worker/action_listener_process.py index 08508607..2186624c 100644 --- a/hatchet_sdk/worker/action_listener_process.py 
+++ b/hatchet_sdk/worker/action_listener_process.py @@ -51,8 +51,8 @@ class WorkerActionListenerProcess: actions: List[str] max_runs: int config: ClientConfig - action_queue: Queue - event_queue: Queue + action_queue: Queue[Action] + event_queue: Queue[ActionEvent] handle_kill: bool = True debug: bool = False labels: dict = field(default_factory=dict) @@ -112,7 +112,7 @@ async def _get_event(self): loop = asyncio.get_running_loop() return await loop.run_in_executor(None, self.event_queue.get) - async def start_event_send_loop(self): + async def start_event_send_loop(self) -> None: while True: event: ActionEvent = await self._get_event() if event == STOP_LOOP: @@ -122,7 +122,7 @@ async def start_event_send_loop(self): logger.debug(f"tx: event: {event.action.action_id}/{event.type}") asyncio.create_task(self.send_event(event)) - async def start_blocked_main_loop(self): + async def start_blocked_main_loop(self) -> None: threshold = 1 while not self.killing: count = 0 @@ -135,7 +135,7 @@ async def start_blocked_main_loop(self): logger.warning(f"{BLOCKED_THREAD_WARNING}: Waiting Steps {count}") await asyncio.sleep(1) - async def send_event(self, event: ActionEvent, retry_attempt: int = 1): + async def send_event(self, event: ActionEvent, retry_attempt: int = 1) -> None: try: match event.action.action_type: # FIXME: all events sent from an execution of a function are of type ActionType.START_STEP_RUN since @@ -185,10 +185,10 @@ async def send_event(self, event: ActionEvent, retry_attempt: int = 1): await exp_backoff_sleep(retry_attempt, 1) await self.send_event(event, retry_attempt + 1) - def now(self): + def now(self) -> float: return time.time() - async def start_action_loop(self): + async def start_action_loop(self) -> None: try: async for action in self.listener: if action is None: @@ -241,7 +241,7 @@ async def start_action_loop(self): if not self.killing: await self.exit_gracefully(skip_unregister=True) - async def cleanup(self): + async def cleanup(self) -> None: 
self.killing = True if self.listener is not None: @@ -249,7 +249,7 @@ async def cleanup(self): self.event_queue.put(STOP_LOOP) - async def exit_gracefully(self, skip_unregister=False): + async def exit_gracefully(self) -> None: if self.killing: return @@ -262,12 +262,12 @@ async def exit_gracefully(self, skip_unregister=False): logger.info("action listener closed") - def exit_forcefully(self): + def exit_forcefully(self) -> None: asyncio.run(self.cleanup()) logger.debug("forcefully closing listener...") -def worker_action_listener_process(*args, **kwargs): +def worker_action_listener_process(*args: Any, **kwargs: Any) -> None: async def run(): process = WorkerActionListenerProcess(*args, **kwargs) await process.start() diff --git a/hatchet_sdk/worker/runner/run_loop_manager.py b/hatchet_sdk/worker/runner/run_loop_manager.py index 27ed788c..2039fb7c 100644 --- a/hatchet_sdk/worker/runner/run_loop_manager.py +++ b/hatchet_sdk/worker/runner/run_loop_manager.py @@ -10,6 +10,7 @@ from hatchet_sdk.loader import ClientConfig from hatchet_sdk.logger import logger from hatchet_sdk.utils.types import WorkflowValidator +from hatchet_sdk.worker.action_listener_process import ActionEvent from hatchet_sdk.worker.runner.runner import Runner from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs @@ -25,8 +26,8 @@ class WorkerActionRunLoopManager: validator_registry: dict[str, WorkflowValidator] max_runs: int | None config: ClientConfig - action_queue: Queue - event_queue: Queue + action_queue: Queue[str] + event_queue: Queue[ActionEvent] loop: asyncio.AbstractEventLoop handle_kill: bool = True debug: bool = False @@ -37,16 +38,16 @@ class WorkerActionRunLoopManager: killing: bool = field(init=False, default=False) runner: Runner = field(init=False, default=None) - def __post_init__(self): + def __post_init__(self) -> None: if self.debug: logger.setLevel(logging.DEBUG) self.client = new_client_raw(self.config, self.debug) self.start() - def start(self, 
retry_count=1): + def start(self, retry_count=1) -> None: k = self.loop.create_task(self.async_start(retry_count)) - async def async_start(self, retry_count=1): + async def async_start(self, retry_count: int = 1) -> None: await capture_logs( self.client.logInterceptor, self.client.event, @@ -91,7 +92,7 @@ async def _start_action_loop(self) -> None: self.runner.run(action) logger.debug("action runner loop stopped") - async def _get_action(self): + async def _get_action(self) -> str: return await self.loop.run_in_executor(None, self.action_queue.get) async def exit_gracefully(self) -> None: diff --git a/hatchet_sdk/worker/runner/runner.py b/hatchet_sdk/worker/runner/runner.py index aa15b8d2..06269f54 100644 --- a/hatchet_sdk/worker/runner/runner.py +++ b/hatchet_sdk/worker/runner/runner.py @@ -301,7 +301,7 @@ async def handle_start_step_run(self, action: Action) -> None: # Find the corresponding action function from the registry action_func = self.action_registry.get(action_name) - context = self.create_context(action, action_func) + context = self.create_context(action) self.contexts[action.step_run_id] = context diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py index e7622f53..74ef0931 100644 --- a/hatchet_sdk/worker/worker.py +++ b/hatchet_sdk/worker/worker.py @@ -360,6 +360,7 @@ def exit_forcefully(self) -> None: logger.debug(f"forcefully stopping worker: {self.name}") + ## TODO: `self.close` needs to be awaited / used self.close() if self.action_listener_process: diff --git a/hatchet_sdk/workflow_run.py b/hatchet_sdk/workflow_run.py index f29b47aa..c1fac873 100644 --- a/hatchet_sdk/workflow_run.py +++ b/hatchet_sdk/workflow_run.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Coroutine, Generic, TypeVar +from typing import Any, Coroutine, Generic, TypeVar, cast from hatchet_sdk.clients.run_event_listener import ( RunEventListener, @@ -37,11 +37,11 @@ def sync_result(self) -> dict: with EventLoopThread() as loop: coro = 
self.workflow_listener.result(self.workflow_run_id) future = asyncio.run_coroutine_threadsafe(coro, loop) - return future.result() + return cast(dict[str, Any], future.result()) else: coro = self.workflow_listener.result(self.workflow_run_id) future = asyncio.run_coroutine_threadsafe(coro, loop) - return future.result() + return cast(dict[str, Any], future.result()) T = TypeVar("T") From b1dee95619699f769906eb026876ca35337e3b2c Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 16:15:30 -0500 Subject: [PATCH 25/53] fix: whole bunch more --- examples/blocked_async/event.py | 5 +-- examples/bulk_fanout/bulk_trigger.py | 31 +++++++++---------- examples/bulk_fanout/trigger.py | 9 +++--- examples/fanout/trigger.py | 6 +--- examples/logger/event.py | 4 ++- examples/manual_trigger/stream.py | 2 +- hatchet_sdk/clients/workflow_listener.py | 17 +++++----- hatchet_sdk/connection.py | 6 ++++ hatchet_sdk/worker/action_listener_process.py | 8 ++--- hatchet_sdk/worker/runner/run_loop_manager.py | 10 +++--- .../worker/runner/utils/capture_logs.py | 4 +-- hatchet_sdk/workflow_run.py | 10 +++--- 12 files changed, 59 insertions(+), 53 deletions(-) diff --git a/examples/blocked_async/event.py b/examples/blocked_async/event.py index 116b227d..3dcc4fcc 100644 --- a/examples/blocked_async/event.py +++ b/examples/blocked_async/event.py @@ -6,7 +6,8 @@ client = new_client() -# client.event.push("user:create", {"test": "test"}) client.event.push( - "user:create", {"test": "test"}, options={"additional_metadata": {"hello": "moon"}} + "user:create", + {"test": "test"}, + options=PushEventOptions(additional_metadata={"hello": "moon"}), ) diff --git a/examples/bulk_fanout/bulk_trigger.py b/examples/bulk_fanout/bulk_trigger.py index d0606673..51905063 100644 --- a/examples/bulk_fanout/bulk_trigger.py +++ b/examples/bulk_fanout/bulk_trigger.py @@ -7,7 +7,7 @@ from dotenv import load_dotenv from hatchet_sdk import new_client -from hatchet_sdk.clients.admin import 
TriggerWorkflowOptions +from hatchet_sdk.clients.admin import TriggerWorkflowOptions, WorkflowRunDict from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun from hatchet_sdk.clients.run_event_listener import StepRunEventType @@ -16,25 +16,22 @@ async def main() -> None: load_dotenv() hatchet = new_client() - workflowRuns: list[dict[str, Any]] = [] - - # we are going to run the BulkParent workflow 20 which will trigger the Child workflows n times for each n in range(20) - for i in range(20): - workflowRuns.append( - { - "workflow_name": "BulkParent", - "input": {"n": i}, - "options": { - "additional_metadata": { - "bulk-trigger": i, - "hello-{i}": "earth-{i}", - }, - }, - } + workflow_runs = [ + WorkflowRunDict( + workflow_name="BulkParent", + input={"n": i}, + options=TriggerWorkflowOptions( + additional_metadata={ + "bulk-trigger": i, + "hello-{i}": "earth-{i}", + } + ), ) + for i in range(20) + ] workflowRunRefs = hatchet.admin.run_workflows( - workflowRuns, + workflow_runs, ) results = await asyncio.gather( diff --git a/examples/bulk_fanout/trigger.py b/examples/bulk_fanout/trigger.py index 1a1b3f17..fe44a627 100644 --- a/examples/bulk_fanout/trigger.py +++ b/examples/bulk_fanout/trigger.py @@ -7,6 +7,7 @@ from hatchet_sdk import new_client from hatchet_sdk.clients.admin import TriggerWorkflowOptions +from hatchet_sdk.clients.events import PushEventOptions from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun from hatchet_sdk.clients.run_event_listener import StepRunEventType @@ -15,10 +16,10 @@ async def main() -> None: load_dotenv() hatchet = new_client() - workflowRuns: WorkflowRun = [] # type: ignore[assignment] - - event = hatchet.event.push( - "parent:create", {"n": 999}, {"additional_metadata": {"no-dedupe": "world"}} + hatchet.event.push( + "parent:create", + {"n": 999}, + PushEventOptions(additional_metadata={"no-dedupe": "world"}), ) diff --git a/examples/fanout/trigger.py b/examples/fanout/trigger.py index 
c34d01b3..e156322c 100644 --- a/examples/fanout/trigger.py +++ b/examples/fanout/trigger.py @@ -1,13 +1,9 @@ import asyncio -import base64 -import json -import os from dotenv import load_dotenv from hatchet_sdk import new_client from hatchet_sdk.clients.admin import TriggerWorkflowOptions -from hatchet_sdk.clients.run_event_listener import StepRunEventType async def main() -> None: @@ -17,7 +13,7 @@ async def main() -> None: hatchet.admin.run_workflow( "Parent", {"test": "test"}, - options={"additional_metadata": {"hello": "moon"}}, + options=TriggerWorkflowOptions(additional_metadata={"hello": "moon"}), ) diff --git a/examples/logger/event.py b/examples/logger/event.py index 5f7818f6..3dcc4fcc 100644 --- a/examples/logger/event.py +++ b/examples/logger/event.py @@ -7,5 +7,7 @@ client = new_client() client.event.push( - "user:create", {"test": "test"}, options={"additional_metadata": {"hello": "moon"}} + "user:create", + {"test": "test"}, + options=PushEventOptions(additional_metadata={"hello": "moon"}), ) diff --git a/examples/manual_trigger/stream.py b/examples/manual_trigger/stream.py index bc4adfab..c05a73b6 100644 --- a/examples/manual_trigger/stream.py +++ b/examples/manual_trigger/stream.py @@ -17,7 +17,7 @@ async def main() -> None: workflowRun = hatchet.admin.run_workflow( "ManualTriggerWorkflow", {"test": "test"}, - options={"additional_metadata": {"hello": "moon"}}, + options=TriggerWorkflowOptions(additional_metadata={"hello": "moon"}), ) listener = workflowRun.stream() diff --git a/hatchet_sdk/clients/workflow_listener.py b/hatchet_sdk/clients/workflow_listener.py index b1131587..3e4e150b 100644 --- a/hatchet_sdk/clients/workflow_listener.py +++ b/hatchet_sdk/clients/workflow_listener.py @@ -1,7 +1,7 @@ import asyncio import json from collections.abc import AsyncIterator -from typing import AsyncGenerator +from typing import Any, AsyncGenerator, cast import grpc from grpc._cython import cygrpc @@ -224,7 +224,7 @@ async def subscribe(self, 
workflow_run_id: str): finally: self.cleanup_subscription(subscription_id) - async def result(self, workflow_run_id: str): + async def result(self, workflow_run_id: str) -> dict[str, Any]: from hatchet_sdk.clients.admin import DedupeViolationErr event = await self.subscribe(workflow_run_id) @@ -248,7 +248,7 @@ async def result(self, workflow_run_id: str): return results - async def _retry_subscribe(self): + async def _retry_subscribe(self) -> WorkflowRunEvent | None: retries = 0 while retries < DEFAULT_WORKFLOW_LISTENER_RETRY_COUNT: @@ -260,12 +260,13 @@ async def _retry_subscribe(self): if self.curr_requester != 0: self.requests.put_nowait(self.curr_requester) - listener = self.client.SubscribeToWorkflowRuns( - self._request(), - metadata=get_metadata(self.token), + return cast( + WorkflowRunEvent, + self.client.SubscribeToWorkflowRuns( + self._request(), + metadata=get_metadata(self.token), + ), ) - - return listener except grpc.RpcError as e: if e.code() == grpc.StatusCode.UNAVAILABLE: retries = retries + 1 diff --git a/hatchet_sdk/connection.py b/hatchet_sdk/connection.py index 787ad7c1..53ff8c43 100644 --- a/hatchet_sdk/connection.py +++ b/hatchet_sdk/connection.py @@ -7,9 +7,11 @@ from hatchet_sdk.loader import ClientConfig +@overload def new_conn(config: "ClientConfig", aio: Literal[False]) -> grpc.Channel: ... +@overload def new_conn(config: "ClientConfig", aio: Literal[True]) -> grpc.aio.Channel: ... 
@@ -26,6 +28,10 @@ def new_conn(config: "ClientConfig", aio: bool) -> grpc.Channel | grpc.aio.Chann credentials = grpc.ssl_channel_credentials(root_certificates=root) elif config.tls_config.tls_strategy == "mtls": + assert config.tls_config.ca_file + assert config.tls_config.key_file + assert config.tls_config.cert_file + root = open(config.tls_config.ca_file, "rb").read() private_key = open(config.tls_config.key_file, "rb").read() certificate_chain = open(config.tls_config.cert_file, "rb").read() diff --git a/hatchet_sdk/worker/action_listener_process.py b/hatchet_sdk/worker/action_listener_process.py index 2186624c..89d4ed02 100644 --- a/hatchet_sdk/worker/action_listener_process.py +++ b/hatchet_sdk/worker/action_listener_process.py @@ -77,7 +77,7 @@ def __post_init__(self): signal.SIGQUIT, lambda: asyncio.create_task(self.exit_gracefully()) ) - async def start(self, retry_attempt=0): + async def start(self, retry_attempt: int = 0) -> None: if retry_attempt > 5: logger.error("could not start action listener") return @@ -108,13 +108,13 @@ async def start(self, retry_attempt=0): self.blocked_main_loop = asyncio.create_task(self.start_blocked_main_loop()) # TODO move event methods to separate class - async def _get_event(self): + async def _get_event(self) -> ActionEvent: loop = asyncio.get_running_loop() return await loop.run_in_executor(None, self.event_queue.get) async def start_event_send_loop(self) -> None: while True: - event: ActionEvent = await self._get_event() + event = await self._get_event() if event == STOP_LOOP: logger.debug("stopping event send loop...") break @@ -126,7 +126,7 @@ async def start_blocked_main_loop(self) -> None: threshold = 1 while not self.killing: count = 0 - for step_run_id, start_time in self.running_step_runs.items(): + for _, start_time in self.running_step_runs.items(): diff = self.now() - start_time if diff > threshold: count += 1 diff --git a/hatchet_sdk/worker/runner/run_loop_manager.py 
b/hatchet_sdk/worker/runner/run_loop_manager.py index 2039fb7c..34a705f8 100644 --- a/hatchet_sdk/worker/runner/run_loop_manager.py +++ b/hatchet_sdk/worker/runner/run_loop_manager.py @@ -26,7 +26,7 @@ class WorkerActionRunLoopManager: validator_registry: dict[str, WorkflowValidator] max_runs: int | None config: ClientConfig - action_queue: Queue[str] + action_queue: Queue[Action] event_queue: Queue[ActionEvent] loop: asyncio.AbstractEventLoop handle_kill: bool = True @@ -44,7 +44,7 @@ def __post_init__(self) -> None: self.client = new_client_raw(self.config, self.debug) self.start() - def start(self, retry_count=1) -> None: + def start(self, retry_count: int = 1) -> None: k = self.loop.create_task(self.async_start(retry_count)) async def async_start(self, retry_count: int = 1) -> None: @@ -64,6 +64,7 @@ async def _async_start(self, retry_count: int = 1) -> None: def cleanup(self) -> None: self.killing = True + ## TODO: The action queue is a queue of `Action`, so I don't think this will work self.action_queue.put(STOP_LOOP) async def wait_for_tasks(self) -> None: @@ -84,7 +85,8 @@ async def _start_action_loop(self) -> None: logger.debug(f"'{self.name}' waiting for {list(self.action_registry.keys())}") while not self.killing: - action: Action = await self._get_action() + action = await self._get_action() + ## TODO: This is a queue of `Action`, so I don't think this will work if action == STOP_LOOP: logger.debug("stopping action runner loop...") break @@ -92,7 +94,7 @@ async def _start_action_loop(self) -> None: self.runner.run(action) logger.debug("action runner loop stopped") - async def _get_action(self) -> str: + async def _get_action(self) -> Action: return await self.loop.run_in_executor(None, self.action_queue.get) async def exit_gracefully(self) -> None: diff --git a/hatchet_sdk/worker/runner/utils/capture_logs.py b/hatchet_sdk/worker/runner/utils/capture_logs.py index 08c57de8..245de4c1 100644 --- a/hatchet_sdk/worker/runner/utils/capture_logs.py +++ 
b/hatchet_sdk/worker/runner/utils/capture_logs.py @@ -32,7 +32,7 @@ def filter(self, record): class CustomLogHandler(logging.StreamHandler): - def __init__(self, event_client: EventClient, stream=None): + def __init__(self, event_client: EventClient, stream: StringIO | None = None): super().__init__(stream) self.logger_thread_pool = ThreadPoolExecutor(max_workers=1) self.event_client = event_client @@ -46,7 +46,7 @@ def _log(self, line: str, step_run_id: str | None): except Exception as e: logger.error(f"Error logging: {e}") - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: super().emit(record) log_entry = self.format(record) diff --git a/hatchet_sdk/workflow_run.py b/hatchet_sdk/workflow_run.py index c1fac873..3233eba7 100644 --- a/hatchet_sdk/workflow_run.py +++ b/hatchet_sdk/workflow_run.py @@ -22,26 +22,26 @@ def __init__( self.workflow_listener = workflow_listener self.workflow_run_event_listener = workflow_run_event_listener - def __str__(self): + def __str__(self) -> str: return self.workflow_run_id def stream(self) -> RunEventListener: return self.workflow_run_event_listener.stream(self.workflow_run_id) - def result(self) -> Coroutine: + def result(self) -> Coroutine[None, None, dict[str, Any]]: return self.workflow_listener.result(self.workflow_run_id) - def sync_result(self) -> dict: + def sync_result(self) -> dict[str, Any]: loop = get_active_event_loop() if loop is None: with EventLoopThread() as loop: coro = self.workflow_listener.result(self.workflow_run_id) future = asyncio.run_coroutine_threadsafe(coro, loop) - return cast(dict[str, Any], future.result()) + return future.result() else: coro = self.workflow_listener.result(self.workflow_run_id) future = asyncio.run_coroutine_threadsafe(coro, loop) - return cast(dict[str, Any], future.result()) + return future.result() T = TypeVar("T") From 4bd64be4f742d9e64df4ce1681d59ab2623e2dee Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 16:30:22 -0500 Subject: [PATCH 
26/53] fix: more, down to 60ish --- hatchet_sdk/clients/event_ts.py | 4 +- hatchet_sdk/clients/run_event_listener.py | 53 +++++++++++-------- hatchet_sdk/clients/workflow_listener.py | 15 +++--- hatchet_sdk/connection.py | 11 ++-- hatchet_sdk/rate_limit.py | 2 +- .../worker/runner/utils/capture_logs.py | 25 +++++---- 6 files changed, 64 insertions(+), 46 deletions(-) diff --git a/hatchet_sdk/clients/event_ts.py b/hatchet_sdk/clients/event_ts.py index c32ea8c0..cb40cc98 100644 --- a/hatchet_sdk/clients/event_ts.py +++ b/hatchet_sdk/clients/event_ts.py @@ -9,7 +9,7 @@ class Event_ts(asyncio.Event): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - if self._loop is None: + if self._loop is None: # type: ignore[has-type] self._loop = asyncio.get_event_loop() def set(self) -> None: @@ -20,7 +20,7 @@ def clear(self) -> None: self._loop.call_soon_threadsafe(super().clear) -async def read_with_interrupt(listener: Any, interrupt: Event_ts): +async def read_with_interrupt(listener: Any, interrupt: Event_ts) -> Any: try: result = await listener.read() return result diff --git a/hatchet_sdk/clients/run_event_listener.py b/hatchet_sdk/clients/run_event_listener.py index b5db6a74..c1f6c650 100644 --- a/hatchet_sdk/clients/run_event_listener.py +++ b/hatchet_sdk/clients/run_event_listener.py @@ -1,6 +1,7 @@ import asyncio import json -from typing import AsyncGenerator +from enum import Enum +from typing import Any, AsyncGenerator, Callable, Generator, cast import grpc @@ -21,7 +22,7 @@ DEFAULT_ACTION_LISTENER_RETRY_COUNT = 5 -class StepRunEventType: +class StepRunEventType(str, Enum): STEP_RUN_EVENT_TYPE_STARTED = "STEP_RUN_EVENT_TYPE_STARTED" STEP_RUN_EVENT_TYPE_COMPLETED = "STEP_RUN_EVENT_TYPE_COMPLETED" STEP_RUN_EVENT_TYPE_FAILED = "STEP_RUN_EVENT_TYPE_FAILED" @@ -30,7 +31,7 @@ class StepRunEventType: STEP_RUN_EVENT_TYPE_STREAM = "STEP_RUN_EVENT_TYPE_STREAM" -class WorkflowRunEventType: +class WorkflowRunEventType(str, Enum): 
WORKFLOW_RUN_EVENT_TYPE_STARTED = "WORKFLOW_RUN_EVENT_TYPE_STARTED" WORKFLOW_RUN_EVENT_TYPE_COMPLETED = "WORKFLOW_RUN_EVENT_TYPE_COMPLETED" WORKFLOW_RUN_EVENT_TYPE_FAILED = "WORKFLOW_RUN_EVENT_TYPE_FAILED" @@ -62,14 +63,14 @@ def __init__(self, type: StepRunEventType, payload: str): self.payload = payload -def new_listener(config: ClientConfig): +def new_listener(config: ClientConfig) -> "RunEventListenerClient": return RunEventListenerClient(config=config) class RunEventListener: - workflow_run_id: str = None - additional_meta_kv: tuple[str, str] = None + workflow_run_id: str | None = None + additional_meta_kv: tuple[str, str] | None = None def __init__(self, client: DispatcherStub, token: str): self.client = client @@ -77,7 +78,9 @@ def __init__(self, client: DispatcherStub, token: str): self.token = token @classmethod - def for_run_id(cls, workflow_run_id: str, client: DispatcherStub, token: str): + def for_run_id( + cls, workflow_run_id: str, client: DispatcherStub, token: str + ) -> "RunEventListener": listener = RunEventListener(client, token) listener.workflow_run_id = workflow_run_id return listener @@ -85,21 +88,21 @@ def for_run_id(cls, workflow_run_id: str, client: DispatcherStub, token: str): @classmethod def for_additional_meta( cls, key: str, value: str, client: DispatcherStub, token: str - ): + ) -> "RunEventListener": listener = RunEventListener(client, token) listener.additional_meta_kv = (key, value) return listener - def abort(self): + def abort(self) -> None: self.stop_signal = True - def __aiter__(self): + def __aiter__(self) -> AsyncGenerator[StepRunEvent, None]: return self._generator() - async def __anext__(self): + async def __anext__(self) -> StepRunEvent: return await self._generator().__anext__() - def __iter__(self): + def __iter__(self) -> Generator[StepRunEvent, None, None]: try: loop = asyncio.get_event_loop() except RuntimeError as e: @@ -145,6 +148,7 @@ async def _generator(self) -> AsyncGenerator[StepRunEvent, None]: try: if 
workflow_event.eventPayload: + ## TODO: Should this be `dumps` instead? payload = json.loads(workflow_event.eventPayload) except Exception as e: payload = workflow_event.eventPayload @@ -194,7 +198,7 @@ async def _generator(self) -> AsyncGenerator[StepRunEvent, None]: break # Raise StopAsyncIteration to properly end the generator - async def retry_subscribe(self): + async def retry_subscribe(self) -> AsyncGenerator[WorkflowEvent, None]: retries = 0 while retries < DEFAULT_ACTION_LISTENER_RETRY_COUNT: @@ -203,11 +207,14 @@ async def retry_subscribe(self): await asyncio.sleep(DEFAULT_ACTION_LISTENER_RETRY_INTERVAL) if self.workflow_run_id is not None: - return self.client.SubscribeToWorkflowEvents( - SubscribeToWorkflowEventsRequest( - workflowRunId=self.workflow_run_id, + return cast( + WorkflowEvent, + self.client.SubscribeToWorkflowEvents( + SubscribeToWorkflowEventsRequest( + workflowRunId=self.workflow_run_id, + ), + metadata=get_metadata(self.token), ), - metadata=get_metadata(self.token), ) elif self.additional_meta_kv is not None: return self.client.SubscribeToWorkflowEvents( @@ -226,6 +233,8 @@ async def retry_subscribe(self): else: raise ValueError(f"gRPC error: {e}") + raise Exception("Failed to subscribe to workflow events") + class RunEventListenerClient: def __init__(self, config: ClientConfig): @@ -233,10 +242,10 @@ def __init__(self, config: ClientConfig): self.config = config self.client: DispatcherStub = None - def stream_by_run_id(self, workflow_run_id: str): + def stream_by_run_id(self, workflow_run_id: str) -> RunEventListener: return self.stream(workflow_run_id) - def stream(self, workflow_run_id: str): + def stream(self, workflow_run_id: str) -> RunEventListener: if not isinstance(workflow_run_id, str) and hasattr(workflow_run_id, "__str__"): workflow_run_id = str(workflow_run_id) @@ -246,14 +255,16 @@ def stream(self, workflow_run_id: str): return RunEventListener.for_run_id(workflow_run_id, self.client, self.token) - def 
stream_by_additional_metadata(self, key: str, value: str): + def stream_by_additional_metadata(self, key: str, value: str) -> RunEventListener: if not self.client: aio_conn = new_conn(self.config, True) self.client = DispatcherStub(aio_conn) return RunEventListener.for_additional_meta(key, value, self.client, self.token) - async def on(self, workflow_run_id: str, handler: callable = None): + async def on( + self, workflow_run_id: str, handler: Callable[[StepRunEvent], Any] | None = None + ) -> None: async for event in self.stream(workflow_run_id): # call the handler if provided if handler: diff --git a/hatchet_sdk/clients/workflow_listener.py b/hatchet_sdk/clients/workflow_listener.py index 3e4e150b..937846be 100644 --- a/hatchet_sdk/clients/workflow_listener.py +++ b/hatchet_sdk/clients/workflow_listener.py @@ -31,7 +31,7 @@ def __init__(self, id: int, workflow_run_id: str): self.workflow_run_id = workflow_run_id self.queue: asyncio.Queue[WorkflowRunEvent | None] = asyncio.Queue() - async def __aiter__(self): + async def __aiter__(self) -> "_Subscription": return self async def __anext__(self) -> WorkflowRunEvent: @@ -45,10 +45,10 @@ async def get(self) -> WorkflowRunEvent: return event - async def put(self, item: WorkflowRunEvent): + async def put(self, item: WorkflowRunEvent) -> None: await self.queue.put(item) - async def close(self): + async def close(self) -> None: await self.queue.put(None) @@ -187,8 +187,7 @@ def cleanup_subscription(self, subscription_id: int): del self.subscriptionsToWorkflows[subscription_id] del self.events[subscription_id] - async def subscribe(self, workflow_run_id: str): - init_producer: asyncio.Task = None + async def subscribe(self, workflow_run_id: str) -> WorkflowRunEvent: try: # create a new subscription id, place a mutex on the counter await self.subscription_counter_lock.acquire() @@ -216,9 +215,7 @@ async def subscribe(self, workflow_run_id: str): if not self.listener_task or self.listener_task.done(): self.listener_task = 
asyncio.create_task(self._init_producer()) - event = await self.events[subscription_id].get() - - return event + return await self.events[subscription_id].get() except asyncio.CancelledError: raise finally: @@ -272,3 +269,5 @@ async def _retry_subscribe(self) -> WorkflowRunEvent | None: retries = retries + 1 else: raise ValueError(f"gRPC error: {e}") + + raise ValueError("Failed to connect to workflow run listener") diff --git a/hatchet_sdk/connection.py b/hatchet_sdk/connection.py index 53ff8c43..2373d8dd 100644 --- a/hatchet_sdk/connection.py +++ b/hatchet_sdk/connection.py @@ -1,5 +1,5 @@ import os -from typing import TYPE_CHECKING, Any, Literal, overload +from typing import TYPE_CHECKING, Any, Literal, cast, overload import grpc @@ -16,7 +16,6 @@ def new_conn(config: "ClientConfig", aio: Literal[True]) -> grpc.aio.Channel: .. def new_conn(config: "ClientConfig", aio: bool) -> grpc.Channel | grpc.aio.Channel: - credentials: grpc.ChannelCredentials | None = None # load channel credentials @@ -44,7 +43,7 @@ def new_conn(config: "ClientConfig", aio: bool) -> grpc.Channel | grpc.aio.Chann start = grpc if not aio else grpc.aio - channel_options = [ + channel_options: list[tuple[str, str | int]] = [ ("grpc.max_send_message_length", config.grpc_max_send_message_length), ("grpc.max_receive_message_length", config.grpc_max_recv_message_length), ("grpc.keepalive_time_ms", 10 * 1000), @@ -73,4 +72,8 @@ def new_conn(config: "ClientConfig", aio: bool) -> grpc.Channel | grpc.aio.Chann credentials=credentials, options=channel_options, ) - return conn + + return cast( + grpc.Channel | grpc.aio.Channel, + conn, + ) diff --git a/hatchet_sdk/rate_limit.py b/hatchet_sdk/rate_limit.py index f651cdd8..f9f574a4 100644 --- a/hatchet_sdk/rate_limit.py +++ b/hatchet_sdk/rate_limit.py @@ -72,7 +72,7 @@ class RateLimit: limit: Union[int, str, None] = None duration: RateLimitDuration = RateLimitDuration.MINUTE - _req: CreateStepRateLimit = None + _req: CreateStepRateLimit | None = None def 
__post_init__(self) -> None: # juggle the key and key_expr fields diff --git a/hatchet_sdk/worker/runner/utils/capture_logs.py b/hatchet_sdk/worker/runner/utils/capture_logs.py index 245de4c1..30cc7e70 100644 --- a/hatchet_sdk/worker/runner/utils/capture_logs.py +++ b/hatchet_sdk/worker/runner/utils/capture_logs.py @@ -3,7 +3,7 @@ import logging from concurrent.futures import ThreadPoolExecutor from io import StringIO -from typing import Any, Coroutine +from typing import Any, Awaitable, Callable, Coroutine, ParamSpec, TypeVar from hatchet_sdk import logger from hatchet_sdk.clients.events import EventClient @@ -25,9 +25,10 @@ def copy_context_vars(ctx_vars, func, *args, **kwargs): class InjectingFilter(logging.Filter): # For some reason, only the InjectingFilter has access to the contextvars method sr.get(), # otherwise we would use emit within the CustomLogHandler - def filter(self, record): - record.workflow_run_id = wr.get() - record.step_run_id = sr.get() + def filter(self, record) -> bool: + ## TODO: Change how we do this to not assign to the log record + record.workflow_run_id = wr.get() # type: ignore + record.step_run_id = sr.get() # type: ignore return True @@ -50,16 +51,20 @@ def emit(self, record: logging.LogRecord) -> None: super().emit(record) log_entry = self.format(record) - self.logger_thread_pool.submit(self._log, log_entry, record.step_run_id) + + ## TODO: Change how we do this to not assign to the log record + self.logger_thread_pool.submit(self._log, log_entry, record.step_run_id) # type: ignore + + +T = TypeVar("T") +P = ParamSpec("P") def capture_logs( - logger: logging.Logger, - event_client: EventClient, - func: Coroutine[Any, Any, Any], -): + logger: logging.Logger, event_client: "EventClient", func: Callable[P, Awaitable[T]] +) -> Callable[P, Awaitable[T]]: @functools.wraps(func) - async def wrapper(*args, **kwargs): + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: if not logger: raise Exception("No logger configured on 
client") From b07ead8e82cc86513321d9ca5ec80125ac69acd7 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 16:58:02 -0500 Subject: [PATCH 27/53] fix: down to 30ish --- hatchet_sdk/client.py | 2 +- hatchet_sdk/clients/run_event_listener.py | 4 ++-- hatchet_sdk/clients/workflow_listener.py | 10 ++++---- hatchet_sdk/context/context.py | 4 ++-- hatchet_sdk/worker/action_listener_process.py | 22 ++++++++--------- hatchet_sdk/worker/runner/run_loop_manager.py | 4 ++-- .../worker/runner/utils/capture_logs.py | 24 ++++++++++++------- 7 files changed, 38 insertions(+), 32 deletions(-) diff --git a/hatchet_sdk/client.py b/hatchet_sdk/client.py index 4baf2702..a9bc8e60 100644 --- a/hatchet_sdk/client.py +++ b/hatchet_sdk/client.py @@ -20,7 +20,7 @@ class Client: dispatcher: DispatcherClient event: EventClient rest: RestApi - workflow_listener: PooledWorkflowRunListener + workflow_listener: PooledWorkflowRunListener | None logInterceptor: Logger debug: bool = False diff --git a/hatchet_sdk/clients/run_event_listener.py b/hatchet_sdk/clients/run_event_listener.py index c1f6c650..5faf6e43 100644 --- a/hatchet_sdk/clients/run_event_listener.py +++ b/hatchet_sdk/clients/run_event_listener.py @@ -251,14 +251,14 @@ def stream(self, workflow_run_id: str) -> RunEventListener: if not self.client: aio_conn = new_conn(self.config, True) - self.client = DispatcherStub(aio_conn) + self.client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call] return RunEventListener.for_run_id(workflow_run_id, self.client, self.token) def stream_by_additional_metadata(self, key: str, value: str) -> RunEventListener: if not self.client: aio_conn = new_conn(self.config, True) - self.client = DispatcherStub(aio_conn) + self.client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call] return RunEventListener.for_additional_meta(key, value, self.client, self.token) diff --git a/hatchet_sdk/clients/workflow_listener.py b/hatchet_sdk/clients/workflow_listener.py index 937846be..86302c40 
100644 --- a/hatchet_sdk/clients/workflow_listener.py +++ b/hatchet_sdk/clients/workflow_listener.py @@ -65,14 +65,14 @@ class PooledWorkflowRunListener: requests: asyncio.Queue[SubscribeToWorkflowRunsRequest] = asyncio.Queue() listener: AsyncGenerator[WorkflowRunEvent, None] = None - listener_task: asyncio.Task = None + listener_task: asyncio.Task[None] = None curr_requester: int = 0 # events have keys of the format workflow_run_id + subscription_id events: dict[int, _Subscription] = {} - interrupter: asyncio.Task = None + interrupter: asyncio.Task[None] | None = None def __init__(self, config: ClientConfig): conn = new_conn(config, True) @@ -80,7 +80,7 @@ def __init__(self, config: ClientConfig): self.token = config.token self.config = config - async def _interrupter(self): + async def _interrupter(self) -> None: """ _interrupter runs in a separate thread and interrupts the listener according to a configurable duration. """ @@ -89,7 +89,7 @@ async def _interrupter(self): if self.interrupt is not None: self.interrupt.set() - async def _init_producer(self): + async def _init_producer(self) -> None: try: if not self.listener: while True: @@ -178,7 +178,7 @@ async def _request(self) -> AsyncIterator[SubscribeToWorkflowRunsRequest]: yield request self.requests.task_done() - def cleanup_subscription(self, subscription_id: int): + def cleanup_subscription(self, subscription_id: int) -> None: workflow_run_id = self.subscriptionsToWorkflows[subscription_id] if workflow_run_id in self.workflowsToSubscriptions: diff --git a/hatchet_sdk/context/context.py b/hatchet_sdk/context/context.py index fa52219f..6b86ca36 100644 --- a/hatchet_sdk/context/context.py +++ b/hatchet_sdk/context/context.py @@ -81,7 +81,7 @@ def __init__( admin_client: AdminClient, event_client: EventClient, rest_client: RestApi, - workflow_listener: PooledWorkflowRunListener, + workflow_listener: PooledWorkflowRunListener | None, workflow_run_event_listener: RunEventListenerClient, worker: WorkerContext, 
namespace: str = "", @@ -149,7 +149,7 @@ def __init__( admin_client: AdminClient, event_client: EventClient, rest_client: RestApi, - workflow_listener: PooledWorkflowRunListener, + workflow_listener: PooledWorkflowRunListener | None, workflow_run_event_listener: RunEventListenerClient, worker: WorkerContext, namespace: str = "", diff --git a/hatchet_sdk/worker/action_listener_process.py b/hatchet_sdk/worker/action_listener_process.py index 89d4ed02..f69fa71e 100644 --- a/hatchet_sdk/worker/action_listener_process.py +++ b/hatchet_sdk/worker/action_listener_process.py @@ -8,12 +8,12 @@ import grpc -from hatchet_sdk.clients.dispatcher.action_listener import Action -from hatchet_sdk.clients.dispatcher.dispatcher import ( +from hatchet_sdk.clients.dispatcher.action_listener import ( + Action, ActionListener, GetActionListenerRequest, - new_dispatcher, ) +from hatchet_sdk.clients.dispatcher.dispatcher import new_dispatcher from hatchet_sdk.contracts.dispatcher_pb2 import ( GROUP_KEY_EVENT_TYPE_STARTED, STEP_EVENT_TYPE_STARTED, @@ -41,7 +41,7 @@ class ActionEvent: ) -def noop_handler(): +def noop_handler() -> None: pass @@ -55,18 +55,18 @@ class WorkerActionListenerProcess: event_queue: Queue[ActionEvent] handle_kill: bool = True debug: bool = False - labels: dict = field(default_factory=dict) + labels: dict[str, str | int] = field(default_factory=dict) - listener: ActionListener = field(init=False, default=None) + listener: ActionListener = field(init=False) killing: bool = field(init=False, default=False) - action_loop_task: asyncio.Task = field(init=False, default=None) - event_send_loop_task: asyncio.Task = field(init=False, default=None) + action_loop_task: asyncio.Task[None] | None = field(init=False, default=None) + event_send_loop_task: asyncio.Task[None] | None = field(init=False, default=None) running_step_runs: Mapping[str, float] = field(init=False, default_factory=dict) - def __post_init__(self): + def __post_init__(self) -> None: if self.debug: 
logger.setLevel(logging.DEBUG) @@ -239,7 +239,7 @@ async def start_action_loop(self) -> None: finally: logger.info("action loop closed") if not self.killing: - await self.exit_gracefully(skip_unregister=True) + await self.exit_gracefully() async def cleanup(self) -> None: self.killing = True @@ -268,7 +268,7 @@ def exit_forcefully(self) -> None: def worker_action_listener_process(*args: Any, **kwargs: Any) -> None: - async def run(): + async def run() -> None: process = WorkerActionListenerProcess(*args, **kwargs) await process.start() # Keep the process running diff --git a/hatchet_sdk/worker/runner/run_loop_manager.py b/hatchet_sdk/worker/runner/run_loop_manager.py index 34a705f8..95baed1f 100644 --- a/hatchet_sdk/worker/runner/run_loop_manager.py +++ b/hatchet_sdk/worker/runner/run_loop_manager.py @@ -33,10 +33,10 @@ class WorkerActionRunLoopManager: debug: bool = False labels: dict[str, str | int] = field(default_factory=dict) - client: Client = field(init=False, default=None) + client: Client = field(init=False) killing: bool = field(init=False, default=False) - runner: Runner = field(init=False, default=None) + runner: Runner | None = field(init=False, default=None) def __post_init__(self) -> None: if self.debug: diff --git a/hatchet_sdk/worker/runner/utils/capture_logs.py b/hatchet_sdk/worker/runner/utils/capture_logs.py index 30cc7e70..39429827 100644 --- a/hatchet_sdk/worker/runner/utils/capture_logs.py +++ b/hatchet_sdk/worker/runner/utils/capture_logs.py @@ -2,8 +2,9 @@ import functools import logging from concurrent.futures import ThreadPoolExecutor +from contextvars import ContextVar from io import StringIO -from typing import Any, Awaitable, Callable, Coroutine, ParamSpec, TypeVar +from typing import Any, Awaitable, Callable, Coroutine, ItemsView, ParamSpec, TypeVar from hatchet_sdk import logger from hatchet_sdk.clients.events import EventClient @@ -16,7 +17,16 @@ ) -def copy_context_vars(ctx_vars, func, *args, **kwargs): +T = TypeVar("T") +P = 
ParamSpec("P") + + +def copy_context_vars( + ctx_vars: ItemsView[ContextVar[Any], Any], + func: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs, +) -> T: for var, value in ctx_vars: var.set(value) return func(*args, **kwargs) @@ -27,8 +37,8 @@ class InjectingFilter(logging.Filter): # otherwise we would use emit within the CustomLogHandler def filter(self, record) -> bool: ## TODO: Change how we do this to not assign to the log record - record.workflow_run_id = wr.get() # type: ignore - record.step_run_id = sr.get() # type: ignore + record.workflow_run_id = wr.get() + record.step_run_id = sr.get() return True @@ -38,7 +48,7 @@ def __init__(self, event_client: EventClient, stream: StringIO | None = None): self.logger_thread_pool = ThreadPoolExecutor(max_workers=1) self.event_client = event_client - def _log(self, line: str, step_run_id: str | None): + def _log(self, line: str, step_run_id: str | None) -> None: try: if not step_run_id: return @@ -56,10 +66,6 @@ def emit(self, record: logging.LogRecord) -> None: self.logger_thread_pool.submit(self._log, log_entry, record.step_run_id) # type: ignore -T = TypeVar("T") -P = ParamSpec("P") - - def capture_logs( logger: logging.Logger, event_client: "EventClient", func: Callable[P, Awaitable[T]] ) -> Callable[P, Awaitable[T]]: From 79a80f077712c6cbedff67739a26314f0b67cfd1 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 17:12:10 -0500 Subject: [PATCH 28/53] fix: t-10 --- .../clients/dispatcher/action_listener.py | 2 +- hatchet_sdk/clients/run_event_listener.py | 17 ++++++++++------- hatchet_sdk/clients/workflow_listener.py | 8 ++++---- hatchet_sdk/features/scheduled.py | 5 ++++- hatchet_sdk/worker/action_listener_process.py | 15 +++++++++------ hatchet_sdk/worker/runner/run_loop_manager.py | 9 +++++---- hatchet_sdk/worker/runner/runner.py | 8 ++------ hatchet_sdk/worker/runner/utils/capture_logs.py | 4 ++-- hatchet_sdk/worker/worker.py | 2 +- hatchet_sdk/workflow_run.py | 2 +- 10 files changed, 39 
insertions(+), 33 deletions(-) diff --git a/hatchet_sdk/clients/dispatcher/action_listener.py b/hatchet_sdk/clients/dispatcher/action_listener.py index aec9869e..4dc5dfae 100644 --- a/hatchet_sdk/clients/dispatcher/action_listener.py +++ b/hatchet_sdk/clients/dispatcher/action_listener.py @@ -393,7 +393,7 @@ async def get_listen_client(self): return listener - def cleanup(self): + def cleanup(self) -> None: self.run_heartbeat = False self.heartbeat_task.cancel() diff --git a/hatchet_sdk/clients/run_event_listener.py b/hatchet_sdk/clients/run_event_listener.py index 5faf6e43..5a3e867f 100644 --- a/hatchet_sdk/clients/run_event_listener.py +++ b/hatchet_sdk/clients/run_event_listener.py @@ -208,7 +208,7 @@ async def retry_subscribe(self) -> AsyncGenerator[WorkflowEvent, None]: if self.workflow_run_id is not None: return cast( - WorkflowEvent, + AsyncGenerator[WorkflowEvent, None], self.client.SubscribeToWorkflowEvents( SubscribeToWorkflowEventsRequest( workflowRunId=self.workflow_run_id, @@ -217,12 +217,15 @@ async def retry_subscribe(self) -> AsyncGenerator[WorkflowEvent, None]: ), ) elif self.additional_meta_kv is not None: - return self.client.SubscribeToWorkflowEvents( - SubscribeToWorkflowEventsRequest( - additionalMetaKey=self.additional_meta_kv[0], - additionalMetaValue=self.additional_meta_kv[1], + return cast( + AsyncGenerator[WorkflowEvent, None], + self.client.SubscribeToWorkflowEvents( + SubscribeToWorkflowEventsRequest( + additionalMetaKey=self.additional_meta_kv[0], + additionalMetaValue=self.additional_meta_kv[1], + ), + metadata=get_metadata(self.token), ), - metadata=get_metadata(self.token), ) else: raise Exception("no listener method provided") @@ -240,7 +243,7 @@ class RunEventListenerClient: def __init__(self, config: ClientConfig): self.token = config.token self.config = config - self.client: DispatcherStub = None + self.client: DispatcherStub | None = None def stream_by_run_id(self, workflow_run_id: str) -> RunEventListener: return 
self.stream(workflow_run_id) diff --git a/hatchet_sdk/clients/workflow_listener.py b/hatchet_sdk/clients/workflow_listener.py index 86302c40..ecec272e 100644 --- a/hatchet_sdk/clients/workflow_listener.py +++ b/hatchet_sdk/clients/workflow_listener.py @@ -4,7 +4,7 @@ from typing import Any, AsyncGenerator, cast import grpc -from grpc._cython import cygrpc +from grpc._cython import cygrpc # type: ignore[attr-defined] from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt from hatchet_sdk.connection import new_conn @@ -64,8 +64,8 @@ class PooledWorkflowRunListener: requests: asyncio.Queue[SubscribeToWorkflowRunsRequest] = asyncio.Queue() - listener: AsyncGenerator[WorkflowRunEvent, None] = None - listener_task: asyncio.Task[None] = None + listener: AsyncGenerator[WorkflowRunEvent, None] | None = None + listener_task: asyncio.Task[None] | None = None curr_requester: int = 0 @@ -76,7 +76,7 @@ class PooledWorkflowRunListener: def __init__(self, config: ClientConfig): conn = new_conn(config, True) - self.client = DispatcherStub(conn) + self.client = DispatcherStub(conn) # type: ignore[no-untyped-call] self.token = config.token self.config = config diff --git a/hatchet_sdk/features/scheduled.py b/hatchet_sdk/features/scheduled.py index 312ad527..60a43614 100644 --- a/hatchet_sdk/features/scheduled.py +++ b/hatchet_sdk/features/scheduled.py @@ -12,6 +12,9 @@ from hatchet_sdk.clients.rest.models.scheduled_workflows_list import ( ScheduledWorkflowsList, ) +from hatchet_sdk.clients.rest.models.scheduled_workflows_order_by_field import ( + ScheduledWorkflowsOrderByField, +) from hatchet_sdk.clients.rest.models.workflow_run_order_by_direction import ( WorkflowRunOrderByDirection, ) @@ -220,7 +223,7 @@ async def list( limit: Optional[int] = None, workflow_id: Optional[str] = None, additional_metadata: Optional[List[str]] = None, - order_by_field: Optional[CronWorkflowsOrderByField] = None, + order_by_field: Optional[ScheduledWorkflowsOrderByField] = None, 
order_by_direction: Optional[WorkflowRunOrderByDirection] = None, ) -> ScheduledWorkflowsList: """ diff --git a/hatchet_sdk/worker/action_listener_process.py b/hatchet_sdk/worker/action_listener_process.py index f69fa71e..e1ef4b38 100644 --- a/hatchet_sdk/worker/action_listener_process.py +++ b/hatchet_sdk/worker/action_listener_process.py @@ -4,7 +4,7 @@ import time from dataclasses import dataclass, field from multiprocessing import Queue -from typing import Any, List, Mapping, Optional +from typing import Any, List, Literal, Mapping, Optional import grpc @@ -30,10 +30,11 @@ class ActionEvent: action: Action type: Any # TODO type - payload: Optional[str] = None + payload: str -STOP_LOOP = "STOP_LOOP" # Sentinel object to stop the loop +STOP_LOOP_TYPE = Literal["STOP_LOOP"] +STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP" # Sentinel object to stop the loop # TODO link to a block post BLOCKED_THREAD_WARNING = ( @@ -52,7 +53,7 @@ class WorkerActionListenerProcess: max_runs: int config: ClientConfig action_queue: Queue[Action] - event_queue: Queue[ActionEvent] + event_queue: Queue[ActionEvent | STOP_LOOP_TYPE] handle_kill: bool = True debug: bool = False labels: dict[str, str | int] = field(default_factory=dict) @@ -64,7 +65,7 @@ class WorkerActionListenerProcess: action_loop_task: asyncio.Task[None] | None = field(init=False, default=None) event_send_loop_task: asyncio.Task[None] | None = field(init=False, default=None) - running_step_runs: Mapping[str, float] = field(init=False, default_factory=dict) + running_step_runs: dict[str, float] = field(init=False, default_factory=dict) def __post_init__(self) -> None: if self.debug: @@ -108,7 +109,7 @@ async def start(self, retry_attempt: int = 0) -> None: self.blocked_main_loop = asyncio.create_task(self.start_blocked_main_loop()) # TODO move event methods to separate class - async def _get_event(self) -> ActionEvent: + async def _get_event(self) -> ActionEvent | STOP_LOOP_TYPE: loop = asyncio.get_running_loop() return await 
loop.run_in_executor(None, self.event_queue.get) @@ -201,6 +202,7 @@ async def start_action_loop(self) -> None: ActionEvent( action=action, type=STEP_EVENT_TYPE_STARTED, # TODO ack type + payload="", ) ) logger.info( @@ -220,6 +222,7 @@ async def start_action_loop(self) -> None: ActionEvent( action=action, type=GROUP_KEY_EVENT_TYPE_STARTED, # TODO ack type + payload="", ) ) logger.info( diff --git a/hatchet_sdk/worker/runner/run_loop_manager.py b/hatchet_sdk/worker/runner/run_loop_manager.py index 95baed1f..32b2b670 100644 --- a/hatchet_sdk/worker/runner/run_loop_manager.py +++ b/hatchet_sdk/worker/runner/run_loop_manager.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass, field from multiprocessing import Queue -from typing import Callable, TypeVar +from typing import Callable, Literal, TypeVar from hatchet_sdk import Context from hatchet_sdk.client import Client, new_client_raw @@ -14,7 +14,8 @@ from hatchet_sdk.worker.runner.runner import Runner from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs -STOP_LOOP = "STOP_LOOP" +STOP_LOOP_TYPE = Literal["STOP_LOOP"] +STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP" T = TypeVar("T") @@ -26,7 +27,7 @@ class WorkerActionRunLoopManager: validator_registry: dict[str, WorkflowValidator] max_runs: int | None config: ClientConfig - action_queue: Queue[Action] + action_queue: Queue[Action | STOP_LOOP_TYPE] event_queue: Queue[ActionEvent] loop: asyncio.AbstractEventLoop handle_kill: bool = True @@ -94,7 +95,7 @@ async def _start_action_loop(self) -> None: self.runner.run(action) logger.debug("action runner loop stopped") - async def _get_action(self) -> Action: + async def _get_action(self) -> Action | STOP_LOOP_TYPE: return await self.loop.run_in_executor(None, self.action_queue.get) async def exit_gracefully(self) -> None: diff --git a/hatchet_sdk/worker/runner/runner.py b/hatchet_sdk/worker/runner/runner.py index 06269f54..5ca95e9e 100644 --- a/hatchet_sdk/worker/runner/runner.py +++ 
b/hatchet_sdk/worker/runner/runner.py @@ -307,10 +307,7 @@ async def handle_start_step_run(self, action: Action) -> None: if action_func: self.event_queue.put( - ActionEvent( - action=action, - type=STEP_EVENT_TYPE_STARTED, - ) + ActionEvent(action=action, type=STEP_EVENT_TYPE_STARTED, payload="") ) loop = asyncio.get_event_loop() @@ -360,8 +357,7 @@ async def handle_start_group_key_run(self, action: Action) -> None: # send an event that the group key run has started self.event_queue.put( ActionEvent( - action=action, - type=GROUP_KEY_EVENT_TYPE_STARTED, + action=action, type=GROUP_KEY_EVENT_TYPE_STARTED, payload="" ) ) diff --git a/hatchet_sdk/worker/runner/utils/capture_logs.py b/hatchet_sdk/worker/runner/utils/capture_logs.py index 39429827..f2aa5a80 100644 --- a/hatchet_sdk/worker/runner/utils/capture_logs.py +++ b/hatchet_sdk/worker/runner/utils/capture_logs.py @@ -6,8 +6,8 @@ from io import StringIO from typing import Any, Awaitable, Callable, Coroutine, ItemsView, ParamSpec, TypeVar -from hatchet_sdk import logger from hatchet_sdk.clients.events import EventClient +from hatchet_sdk.logger import logger wr: contextvars.ContextVar[str | None] = contextvars.ContextVar( "workflow_run_id", default=None @@ -35,7 +35,7 @@ def copy_context_vars( class InjectingFilter(logging.Filter): # For some reason, only the InjectingFilter has access to the contextvars method sr.get(), # otherwise we would use emit within the CustomLogHandler - def filter(self, record) -> bool: + def filter(self, record: logging.LogRecord) -> bool: ## TODO: Change how we do this to not assign to the log record record.workflow_run_id = wr.get() record.step_run_id = sr.get() diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py index 74ef0931..50bf18f8 100644 --- a/hatchet_sdk/worker/worker.py +++ b/hatchet_sdk/worker/worker.py @@ -361,7 +361,7 @@ def exit_forcefully(self) -> None: logger.debug(f"forcefully stopping worker: {self.name}") ## TODO: `self.close` needs to be awaited 
/ used - self.close() + self.close() # type: ignore[unused-coroutine] if self.action_listener_process: self.action_listener_process.kill() # Forcefully kill the process diff --git a/hatchet_sdk/workflow_run.py b/hatchet_sdk/workflow_run.py index 3233eba7..fae1fc6a 100644 --- a/hatchet_sdk/workflow_run.py +++ b/hatchet_sdk/workflow_run.py @@ -48,7 +48,7 @@ def sync_result(self) -> dict[str, Any]: class RunRef(WorkflowRunRef, Generic[T]): - async def result(self) -> T: + async def result(self) -> Any | dict[str, Any]: res = await self.workflow_listener.result(self.workflow_run_id) if len(res) == 1: From 64a25493e9fc367dddad7c2f3e9513f4eed280a5 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 17:13:59 -0500 Subject: [PATCH 29/53] fix: 9 --- hatchet_sdk/clients/run_event_listener.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/hatchet_sdk/clients/run_event_listener.py b/hatchet_sdk/clients/run_event_listener.py index 5a3e867f..25e56a0b 100644 --- a/hatchet_sdk/clients/run_event_listener.py +++ b/hatchet_sdk/clients/run_event_listener.py @@ -154,10 +154,12 @@ async def _generator(self) -> AsyncGenerator[StepRunEvent, None]: payload = workflow_event.eventPayload pass + assert isinstance(payload, str) + yield StepRunEvent(type=eventType, payload=payload) elif workflow_event.resourceType == RESOURCE_TYPE_WORKFLOW_RUN: if workflow_event.eventType in workflow_run_event_type_mapping: - eventType = workflow_run_event_type_mapping[ + workflowRunEventType = workflow_run_event_type_mapping[ workflow_event.eventType ] else: @@ -173,7 +175,9 @@ async def _generator(self) -> AsyncGenerator[StepRunEvent, None]: except Exception as e: pass - yield StepRunEvent(type=eventType, payload=payload) + assert isinstance(payload, str) + + yield StepRunEvent(type=workflowRunEventType, payload=payload) if workflow_event.hangup: listener = None From 9df38c48c3b94c233e265f86e220aaf8cef000d3 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 
17:15:47 -0500 Subject: [PATCH 30/53] fix: 8 --- hatchet_sdk/clients/workflow_listener.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hatchet_sdk/clients/workflow_listener.py b/hatchet_sdk/clients/workflow_listener.py index ecec272e..bef767da 100644 --- a/hatchet_sdk/clients/workflow_listener.py +++ b/hatchet_sdk/clients/workflow_listener.py @@ -118,7 +118,8 @@ async def _init_producer(self) -> None: ) t.cancel() - self.listener.cancel() + if self.listener: + self.listener.cancel() await asyncio.sleep( DEFAULT_WORKFLOW_LISTENER_RETRY_INTERVAL ) From eb2dd356b909a99c51f042cc834b3de13ab7432b Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 17:16:24 -0500 Subject: [PATCH 31/53] fix: 7 --- hatchet_sdk/client.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/hatchet_sdk/client.py b/hatchet_sdk/client.py index a9bc8e60..f1715972 100644 --- a/hatchet_sdk/client.py +++ b/hatchet_sdk/client.py @@ -16,14 +16,6 @@ class Client: - admin: AdminClient - dispatcher: DispatcherClient - event: EventClient - rest: RestApi - workflow_listener: PooledWorkflowRunListener | None - logInterceptor: Logger - debug: bool = False - @classmethod def from_environment( cls, @@ -84,7 +76,7 @@ def __init__( event_client: EventClient, admin_client: AdminClient, dispatcher_client: DispatcherClient, - workflow_listener: PooledWorkflowRunListener, + workflow_listener: PooledWorkflowRunListener | None, rest_client: RestApi, config: ClientConfig, debug: bool = False, From 9ba0c58f4326af89fee034ce76de3306a23477cc Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 17:17:07 -0500 Subject: [PATCH 32/53] fix: 6 --- hatchet_sdk/clients/workflow_listener.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hatchet_sdk/clients/workflow_listener.py b/hatchet_sdk/clients/workflow_listener.py index bef767da..c6ab3a5d 100644 --- a/hatchet_sdk/clients/workflow_listener.py +++ b/hatchet_sdk/clients/workflow_listener.py @@ 
-62,7 +62,7 @@ class PooledWorkflowRunListener: subscription_counter: int = 0 subscription_counter_lock: asyncio.Lock = asyncio.Lock() - requests: asyncio.Queue[SubscribeToWorkflowRunsRequest] = asyncio.Queue() + requests: asyncio.Queue[SubscribeToWorkflowRunsRequest | int] = asyncio.Queue() listener: AsyncGenerator[WorkflowRunEvent, None] | None = None listener_task: asyncio.Task[None] | None = None From ed190ffa613ae52268dd2e79f1211c3d165bf617 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 17:17:58 -0500 Subject: [PATCH 33/53] fix: 5 --- hatchet_sdk/clients/run_event_listener.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hatchet_sdk/clients/run_event_listener.py b/hatchet_sdk/clients/run_event_listener.py index 25e56a0b..570d248b 100644 --- a/hatchet_sdk/clients/run_event_listener.py +++ b/hatchet_sdk/clients/run_event_listener.py @@ -158,8 +158,8 @@ async def _generator(self) -> AsyncGenerator[StepRunEvent, None]: yield StepRunEvent(type=eventType, payload=payload) elif workflow_event.resourceType == RESOURCE_TYPE_WORKFLOW_RUN: - if workflow_event.eventType in workflow_run_event_type_mapping: - workflowRunEventType = workflow_run_event_type_mapping[ + if workflow_event.eventType in step_run_event_type_mapping: + workflowRunEventType = step_run_event_type_mapping[ workflow_event.eventType ] else: From e81e3dd1cd7b22c6bf2667ce3a38effefaff79e8 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 17:18:24 -0500 Subject: [PATCH 34/53] fix: 4 --- hatchet_sdk/clients/workflow_listener.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hatchet_sdk/clients/workflow_listener.py b/hatchet_sdk/clients/workflow_listener.py index c6ab3a5d..0a3bcd6a 100644 --- a/hatchet_sdk/clients/workflow_listener.py +++ b/hatchet_sdk/clients/workflow_listener.py @@ -34,7 +34,7 @@ def __init__(self, id: int, workflow_run_id: str): async def __aiter__(self) -> "_Subscription": return self - async def __anext__(self) -> 
WorkflowRunEvent: + async def __anext__(self) -> WorkflowRunEvent | None: return await self.queue.get() async def get(self) -> WorkflowRunEvent: From 918e6fb02cccad4a943fce7138be1de2a57b373a Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 17:19:56 -0500 Subject: [PATCH 35/53] fix: 3 --- hatchet_sdk/clients/workflow_listener.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hatchet_sdk/clients/workflow_listener.py b/hatchet_sdk/clients/workflow_listener.py index 0a3bcd6a..28283b8e 100644 --- a/hatchet_sdk/clients/workflow_listener.py +++ b/hatchet_sdk/clients/workflow_listener.py @@ -246,7 +246,7 @@ async def result(self, workflow_run_id: str) -> dict[str, Any]: return results - async def _retry_subscribe(self) -> WorkflowRunEvent | None: + async def _retry_subscribe(self) -> AsyncGenerator[WorkflowRunEvent, None]: retries = 0 while retries < DEFAULT_WORKFLOW_LISTENER_RETRY_COUNT: @@ -259,7 +259,7 @@ async def _retry_subscribe(self) -> WorkflowRunEvent | None: self.requests.put_nowait(self.curr_requester) return cast( - WorkflowRunEvent, + AsyncGenerator[WorkflowRunEvent, None], self.client.SubscribeToWorkflowRuns( self._request(), metadata=get_metadata(self.token), From 8116949f0f46ad42f4acbe6f640d626f21913ae0 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 17:21:57 -0500 Subject: [PATCH 36/53] fix: 2 --- hatchet_sdk/workflow_run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hatchet_sdk/workflow_run.py b/hatchet_sdk/workflow_run.py index fae1fc6a..3bf570db 100644 --- a/hatchet_sdk/workflow_run.py +++ b/hatchet_sdk/workflow_run.py @@ -34,7 +34,7 @@ def result(self) -> Coroutine[None, None, dict[str, Any]]: def sync_result(self) -> dict[str, Any]: loop = get_active_event_loop() if loop is None: - with EventLoopThread() as loop: + with EventLoopThread() as loop: # type: ignore[call-arg] coro = self.workflow_listener.result(self.workflow_run_id) future = asyncio.run_coroutine_threadsafe(coro, 
loop) return future.result() From 9535dd1f78b0e67fbf5aa42199967cf8fd2571cd Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 17:22:03 -0500 Subject: [PATCH 37/53] fix: 1 --- hatchet_sdk/clients/workflow_listener.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hatchet_sdk/clients/workflow_listener.py b/hatchet_sdk/clients/workflow_listener.py index 28283b8e..7d5b715c 100644 --- a/hatchet_sdk/clients/workflow_listener.py +++ b/hatchet_sdk/clients/workflow_listener.py @@ -119,7 +119,7 @@ async def _init_producer(self) -> None: t.cancel() if self.listener: - self.listener.cancel() + self.listener.cancel() # type: ignore[attr-defined] await asyncio.sleep( DEFAULT_WORKFLOW_LISTENER_RETRY_INTERVAL ) From e5bf592cfb8fcc79024bceebf9287df291dba76c Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 22 Jan 2025 17:23:14 -0500 Subject: [PATCH 38/53] fix: 0 --- hatchet_sdk/clients/workflow_listener.py | 2 +- hatchet_sdk/worker/runner/utils/capture_logs.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hatchet_sdk/clients/workflow_listener.py b/hatchet_sdk/clients/workflow_listener.py index 7d5b715c..ada20040 100644 --- a/hatchet_sdk/clients/workflow_listener.py +++ b/hatchet_sdk/clients/workflow_listener.py @@ -119,7 +119,7 @@ async def _init_producer(self) -> None: t.cancel() if self.listener: - self.listener.cancel() # type: ignore[attr-defined] + self.listener.cancel() # type: ignore[attr-defined] await asyncio.sleep( DEFAULT_WORKFLOW_LISTENER_RETRY_INTERVAL ) diff --git a/hatchet_sdk/worker/runner/utils/capture_logs.py b/hatchet_sdk/worker/runner/utils/capture_logs.py index f2aa5a80..6fec015c 100644 --- a/hatchet_sdk/worker/runner/utils/capture_logs.py +++ b/hatchet_sdk/worker/runner/utils/capture_logs.py @@ -42,7 +42,7 @@ def filter(self, record: logging.LogRecord) -> bool: return True -class CustomLogHandler(logging.StreamHandler): +class CustomLogHandler(logging.StreamHandler[Any]): def __init__(self, event_client: 
EventClient, stream: StringIO | None = None): super().__init__(stream) self.logger_thread_pool = ThreadPoolExecutor(max_workers=1) From e04a2376ad1574448df957793439bb9e55d574fa Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Thu, 23 Jan 2025 09:25:28 -0500 Subject: [PATCH 39/53] fix: queue types --- hatchet_sdk/worker/action_listener_process.py | 4 ++-- hatchet_sdk/worker/runner/run_loop_manager.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/hatchet_sdk/worker/action_listener_process.py b/hatchet_sdk/worker/action_listener_process.py index e1ef4b38..6017ab68 100644 --- a/hatchet_sdk/worker/action_listener_process.py +++ b/hatchet_sdk/worker/action_listener_process.py @@ -52,8 +52,8 @@ class WorkerActionListenerProcess: actions: List[str] max_runs: int config: ClientConfig - action_queue: Queue[Action] - event_queue: Queue[ActionEvent | STOP_LOOP_TYPE] + action_queue: "Queue[Action]" + event_queue: "Queue[ActionEvent | STOP_LOOP_TYPE]" handle_kill: bool = True debug: bool = False labels: dict[str, str | int] = field(default_factory=dict) diff --git a/hatchet_sdk/worker/runner/run_loop_manager.py b/hatchet_sdk/worker/runner/run_loop_manager.py index 32b2b670..972c9cd5 100644 --- a/hatchet_sdk/worker/runner/run_loop_manager.py +++ b/hatchet_sdk/worker/runner/run_loop_manager.py @@ -27,8 +27,8 @@ class WorkerActionRunLoopManager: validator_registry: dict[str, WorkflowValidator] max_runs: int | None config: ClientConfig - action_queue: Queue[Action | STOP_LOOP_TYPE] - event_queue: Queue[ActionEvent] + action_queue: "Queue[Action | STOP_LOOP_TYPE]" + event_queue: "Queue[ActionEvent]" loop: asyncio.AbstractEventLoop handle_kill: bool = True debug: bool = False From 514511930d1dba8ab050896dc53bf82fa79dc20a Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Tue, 28 Jan 2025 17:04:09 -0500 Subject: [PATCH 40/53] feat: finally fix mypy config --- pyproject.toml | 75 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 55 insertions(+), 20 
deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 30426c65..1e061b35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,11 +10,11 @@ include = ["hatchet_sdk/py.typed"] python = "^3.10" grpcio = [ { version = ">=1.64.1, !=1.68.*", markers = "python_version < '3.13'" }, - { version = ">=1.69.0", markers = "python_version >= '3.13'" } + { version = ">=1.69.0", markers = "python_version >= '3.13'" }, ] grpcio-tools = [ { version = ">=1.64.1, !=1.68.*", markers = "python_version < '3.13'" }, - { version = ">=1.69.0", markers = "python_version >= '3.13'" } + { version = ">=1.69.0", markers = "python_version >= '3.13'" }, ] python-dotenv = "^1.0.0" protobuf = "^5.29.1" @@ -67,15 +67,15 @@ env = [ [tool.isort] profile = "black" known_third_party = [ - "grpcio", - "grpcio_tools", - "loguru", - "protobuf", - "pydantic", - "python_dotenv", - "python_dateutil", - "pyyaml", - "urllib3", + "grpcio", + "grpcio_tools", + "loguru", + "protobuf", + "pydantic", + "python_dotenv", + "python_dateutil", + "pyyaml", + "urllib3", ] extend_skip = ["hatchet_sdk/contracts/"] @@ -83,17 +83,53 @@ extend_skip = ["hatchet_sdk/contracts/"] extend_exclude = "hatchet_sdk/contracts/" [tool.mypy] -strict = true -files = [ - "." 
-] +files = ["."] +follow_imports = "silent" exclude = [ - "hatchet_sdk/clients/rest", - "hatchet_sdk/clients/dispatcher", - "hatchet_sdk/contracts", + "hatchet_sdk/clients/rest/api/*", + "hatchet_sdk/clients/rest/models/*", + "hatchet_sdk/contracts", ] -follow_imports = "silent" + explicit_package_bases = true +warn_unused_configs = true +warn_redundant_casts = true +warn_unused_ignores = true + +strict_equality = true + +check_untyped_defs = true + +disallow_subclassing_any = true +disallow_untyped_decorators = true +disallow_any_generics = true + +disallow_untyped_calls = true +disallow_incomplete_defs = true +disallow_untyped_defs = true + +no_implicit_reexport = true + +warn_return_any = true + +[[tool.mypy.overrides]] +module = ["hatchet_sdk/contracts/*", "hatchet_sdk/clients/rest/*"] + +warn_unused_ignores = false + +strict_equality = false + +disallow_subclassing_any = false +disallow_untyped_decorators = false +disallow_any_generics = false + +disallow_untyped_calls = false +disallow_incomplete_defs = false +disallow_untyped_defs = false + +no_implicit_reexport = false + +warn_return_any = false [tool.poetry.scripts] api = "examples.api.api:main" @@ -118,4 +154,3 @@ existing_loop = "examples.worker_existing_loop.worker:main" bulk_fanout = "examples.bulk_fanout.worker:main" retries_with_backoff = "examples.retries_with_backoff.worker:main" pydantic = "examples.pydantic.worker:main" - From 62948af2e16a0f30af26e8c93045cd43a3070ff0 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Tue, 28 Jan 2025 17:55:44 -0500 Subject: [PATCH 41/53] fix: rest of the mypy errors --- .../clients/dispatcher/action_listener.py | 89 +++++++------- hatchet_sdk/clients/event_ts.py | 15 ++- hatchet_sdk/clients/rest/api_client.py | 2 +- hatchet_sdk/clients/rest/tenacity_utils.py | 2 +- hatchet_sdk/clients/rest_client.py | 109 ++++++++++-------- hatchet_sdk/clients/workflow_listener.py | 19 ++- hatchet_sdk/features/cron.py | 71 +++++------- hatchet_sdk/features/scheduled.py | 76 
+++++------- 8 files changed, 192 insertions(+), 191 deletions(-) diff --git a/hatchet_sdk/clients/dispatcher/action_listener.py b/hatchet_sdk/clients/dispatcher/action_listener.py index 4dc5dfae..ff422805 100644 --- a/hatchet_sdk/clients/dispatcher/action_listener.py +++ b/hatchet_sdk/clients/dispatcher/action_listener.py @@ -2,10 +2,11 @@ import json import time from dataclasses import dataclass, field -from typing import Any, AsyncGenerator, List, Optional +from typing import Any, AsyncGenerator, AsyncIterable, AsyncIterator, Optional, cast import grpc -from grpc._cython import cygrpc +import grpc.aio +from grpc._cython import cygrpc # type: ignore[attr-defined] from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt from hatchet_sdk.clients.run_event_listener import ( @@ -40,14 +41,14 @@ @dataclass class GetActionListenerRequest: worker_name: str - services: List[str] - actions: List[str] + services: list[str] + actions: list[str] max_runs: Optional[int] = None _labels: dict[str, str | int] = field(default_factory=dict) labels: dict[str, WorkerLabels] = field(init=False) - def __post_init__(self): + def __post_init__(self) -> None: self.labels = {} for key, value in self._labels.items(): @@ -78,7 +79,7 @@ class Action: child_workflow_key: str | None = None parent_workflow_run_id: str | None = None - def __post_init__(self): + def __post_init__(self) -> None: if isinstance(self.additional_metadata, str) and self.additional_metadata != "": try: self.additional_metadata = json.loads(self.additional_metadata) @@ -114,11 +115,6 @@ def otel_attributes(self) -> dict[str, Any]: ) -START_STEP_RUN = 0 -CANCEL_STEP_RUN = 1 -START_GET_GROUP_KEY = 2 - - @dataclass class ActionListener: config: ClientConfig @@ -131,22 +127,22 @@ class ActionListener: last_connection_attempt: float = field(default=0, init=False) last_heartbeat_succeeded: bool = field(default=True, init=False) time_last_hb_succeeded: float = field(default=9999999999999, init=False) - 
heartbeat_task: Optional[asyncio.Task] = field(default=None, init=False) + heartbeat_task: Optional[asyncio.Task[None]] = field(default=None, init=False) run_heartbeat: bool = field(default=True, init=False) listen_strategy: str = field(default="v2", init=False) stop_signal: bool = field(default=False, init=False) missed_heartbeats: int = field(default=0, init=False) - def __post_init__(self): - self.client = DispatcherStub(new_conn(self.config, False)) - self.aio_client = DispatcherStub(new_conn(self.config, True)) + def __post_init__(self) -> None: + self.client = DispatcherStub(new_conn(self.config, False)) # type: ignore[no-untyped-call] + self.aio_client = DispatcherStub(new_conn(self.config, True)) # type: ignore[no-untyped-call] self.token = self.config.token - def is_healthy(self): + def is_healthy(self) -> bool: return self.last_heartbeat_succeeded - async def heartbeat(self): + async def heartbeat(self) -> None: # send a heartbeat every 4 seconds heartbeat_delay = 4 @@ -206,7 +202,7 @@ async def heartbeat(self): break await asyncio.sleep(heartbeat_delay) - async def start_heartbeater(self): + async def start_heartbeater(self) -> None: if self.heartbeat_task is not None: return @@ -220,10 +216,10 @@ async def start_heartbeater(self): raise e self.heartbeat_task = loop.create_task(self.heartbeat()) - def __aiter__(self): + def __aiter__(self) -> AsyncGenerator[Action | None, None]: return self._generator() - async def _generator(self) -> AsyncGenerator[Action, None]: + async def _generator(self) -> AsyncGenerator[Action | None, None]: listener = None while not self.stop_signal: @@ -239,6 +235,10 @@ async def _generator(self) -> AsyncGenerator[Action, None]: try: while not self.stop_signal: self.interrupt = Event_ts() + + if listener is None: + continue + t = asyncio.create_task( read_with_interrupt(listener, self.interrupt) ) @@ -251,7 +251,10 @@ async def _generator(self) -> AsyncGenerator[Action, None]: ) t.cancel() - listener.cancel() + + if listener: + 
listener.cancel() + break assigned_action = t.result() @@ -261,10 +264,9 @@ async def _generator(self) -> AsyncGenerator[Action, None]: break self.retries = 0 - assigned_action: AssignedAction # Process the received action - action_type = self.map_action_type(assigned_action.actionType) + action_type = assigned_action.actionType if ( assigned_action.actionPayload is None @@ -287,7 +289,8 @@ async def _generator(self) -> AsyncGenerator[Action, None]: step_id=assigned_action.stepId, step_run_id=assigned_action.stepRunId, action_id=assigned_action.actionId, - action_payload=action_payload, + ## TODO: Figure out this type - maybe needs to be dumped to JSON? + action_payload=action_payload, # type: ignore[arg-type] action_type=action_type, retry_count=assigned_action.retryCount, additional_metadata=assigned_action.additional_metadata, @@ -324,25 +327,15 @@ async def _generator(self) -> AsyncGenerator[Action, None]: self.retries = self.retries + 1 - def parse_action_payload(self, payload: str): + def parse_action_payload(self, payload: str) -> JSONSerializableDict: try: - payload_data = json.loads(payload) + return cast(JSONSerializableDict, json.loads(payload)) except json.JSONDecodeError as e: raise ValueError(f"Error decoding payload: {e}") - return payload_data - - def map_action_type(self, action_type): - if action_type == ActionType.START_STEP_RUN: - return START_STEP_RUN - elif action_type == ActionType.CANCEL_STEP_RUN: - return CANCEL_STEP_RUN - elif action_type == ActionType.START_GET_GROUP_KEY: - return START_GET_GROUP_KEY - else: - # logger.error(f"Unknown action type: {action_type}") - return None - async def get_listen_client(self): + async def get_listen_client( + self, + ) -> grpc.aio.UnaryStreamCall[WorkerListenRequest, AssignedAction]: current_time = int(time.time()) if ( @@ -370,7 +363,8 @@ async def get_listen_client(self): f"action listener connection interrupted, retrying... 
({self.retries}/{DEFAULT_ACTION_LISTENER_RETRY_COUNT})" ) - self.aio_client = DispatcherStub(new_conn(self.config, True)) + ## TODO: Figure out how to get type support for these + self.aio_client = DispatcherStub(new_conn(self.config, True)) # type: ignore[no-untyped-call] if self.listen_strategy == "v2": # we should await for the listener to be established before @@ -391,11 +385,14 @@ async def get_listen_client(self): self.last_connection_attempt = current_time - return listener + return cast( + grpc.aio.UnaryStreamCall[WorkerListenRequest, AssignedAction], listener + ) def cleanup(self) -> None: self.run_heartbeat = False - self.heartbeat_task.cancel() + if self.heartbeat_task is not None: + self.heartbeat_task.cancel() try: self.unregister() @@ -405,9 +402,11 @@ def cleanup(self) -> None: if self.interrupt: self.interrupt.set() - def unregister(self): + def unregister(self) -> WorkerUnsubscribeRequest: self.run_heartbeat = False - self.heartbeat_task.cancel() + + if self.heartbeat_task is not None: + self.heartbeat_task.cancel() try: req = self.aio_client.Unsubscribe( @@ -417,6 +416,6 @@ def unregister(self): ) if self.interrupt is not None: self.interrupt.set() - return req + return cast(WorkerUnsubscribeRequest, req) except grpc.RpcError as e: raise Exception(f"Failed to unsubscribe: {e}") diff --git a/hatchet_sdk/clients/event_ts.py b/hatchet_sdk/clients/event_ts.py index cb40cc98..694e7af6 100644 --- a/hatchet_sdk/clients/event_ts.py +++ b/hatchet_sdk/clients/event_ts.py @@ -1,5 +1,7 @@ import asyncio -from typing import Any +from typing import Any, TypeVar, cast + +import grpc.aio class Event_ts(asyncio.Event): @@ -20,9 +22,14 @@ def clear(self) -> None: self._loop.call_soon_threadsafe(super().clear) -async def read_with_interrupt(listener: Any, interrupt: Event_ts) -> Any: +TRequest = TypeVar("TRequest") +TResponse = TypeVar("TResponse") + + +async def read_with_interrupt( + listener: grpc.aio.UnaryStreamCall[TRequest, TResponse], interrupt: Event_ts +) 
-> Any: try: - result = await listener.read() - return result + return cast(Any, await listener.read()) finally: interrupt.set() diff --git a/hatchet_sdk/clients/rest/api_client.py b/hatchet_sdk/clients/rest/api_client.py index 76446dda..62bdc472 100644 --- a/hatchet_sdk/clients/rest/api_client.py +++ b/hatchet_sdk/clients/rest/api_client.py @@ -97,7 +97,7 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_value, traceback): await self.close() - async def close(self): + async def close(self) -> None: await self.rest_client.close() @property diff --git a/hatchet_sdk/clients/rest/tenacity_utils.py b/hatchet_sdk/clients/rest/tenacity_utils.py index 377266a1..c90f7352 100644 --- a/hatchet_sdk/clients/rest/tenacity_utils.py +++ b/hatchet_sdk/clients/rest/tenacity_utils.py @@ -27,7 +27,7 @@ def tenacity_alert_retry(retry_state: tenacity.RetryCallState) -> None: ) -def tenacity_should_retry(ex: Exception) -> bool: +def tenacity_should_retry(ex: BaseException) -> bool: if isinstance(ex, (grpc.aio.AioRpcError, grpc.RpcError)): if ex.code() in [ grpc.StatusCode.UNIMPLEMENTED, diff --git a/hatchet_sdk/clients/rest_client.py b/hatchet_sdk/clients/rest_client.py index 9676ef09..ba62490a 100644 --- a/hatchet_sdk/clients/rest_client.py +++ b/hatchet_sdk/clients/rest_client.py @@ -2,10 +2,11 @@ import atexit import datetime import threading -from typing import Any, Coroutine, List +from typing import Any, Coroutine, List, TypeVar, cast from pydantic import StrictInt +from hatchet_sdk.clients.rest import CronWorkflowsList, ScheduledWorkflowsList from hatchet_sdk.clients.rest.api.event_api import EventApi from hatchet_sdk.clients.rest.api.log_api import LogApi from hatchet_sdk.clients.rest.api.step_run_api import StepRunApi @@ -27,6 +28,9 @@ EventOrderByDirection, ) from hatchet_sdk.clients.rest.models.event_order_by_field import EventOrderByField +from hatchet_sdk.clients.rest.models.event_update_cancel200_response import ( + EventUpdateCancel200Response, +) 
from hatchet_sdk.clients.rest.models.log_line_level import LogLineLevel from hatchet_sdk.clients.rest.models.log_line_list import LogLineList from hatchet_sdk.clients.rest.models.log_line_order_by_direction import ( @@ -68,6 +72,17 @@ from hatchet_sdk.clients.rest.models.workflow_version import WorkflowVersion from hatchet_sdk.utils.types import JSONSerializableDict +## Type variables to use with coroutines. +## See https://stackoverflow.com/questions/73240620/the-right-way-to-type-hint-a-coroutine-function +## Return type +R = TypeVar("R") + +## Yield type +Y = TypeVar("Y") + +## Send type +S = TypeVar("S") + class AsyncRestApi: def __init__(self, host: str, api_key: str, tenant_id: str): @@ -78,50 +93,50 @@ def __init__(self, host: str, api_key: str, tenant_id: str): access_token=api_key, ) - self._api_client = None - self._workflow_api = None - self._workflow_run_api = None - self._step_run_api = None - self._event_api = None - self._log_api = None + self._api_client: ApiClient | None = None + self._workflow_api: WorkflowApi | None = None + self._workflow_run_api: WorkflowRunApi | None = None + self._step_run_api: StepRunApi | None = None + self._event_api: EventApi | None = None + self._log_api: LogApi | None = None @property - def api_client(self): + def api_client(self) -> ApiClient: if self._api_client is None: self._api_client = ApiClient(configuration=self.config) return self._api_client @property - def workflow_api(self): + def workflow_api(self) -> WorkflowApi: if self._workflow_api is None: self._workflow_api = WorkflowApi(self.api_client) return self._workflow_api @property - def workflow_run_api(self): + def workflow_run_api(self) -> WorkflowRunApi: if self._workflow_run_api is None: self._workflow_run_api = WorkflowRunApi(self.api_client) return self._workflow_run_api @property - def step_run_api(self): + def step_run_api(self) -> StepRunApi: if self._step_run_api is None: self._step_run_api = StepRunApi(self.api_client) return self._step_run_api 
@property - def event_api(self): + def event_api(self) -> EventApi: if self._event_api is None: self._event_api = EventApi(self.api_client) return self._event_api @property - def log_api(self): + def log_api(self) -> LogApi: if self._log_api is None: self._log_api = LogApi(self.api_client) return self._log_api - async def close(self): + async def close(self) -> None: # Ensure the aiohttp client session is closed if self._api_client is not None: await self._api_client.close() @@ -185,13 +200,13 @@ async def workflow_run_replay( return await self.workflow_run_api.workflow_run_update_replay( tenant=self.tenant_id, replay_workflow_runs_request=ReplayWorkflowRunsRequest( - workflow_run_ids=workflow_run_ids, + workflowRunIds=workflow_run_ids, ), ) async def workflow_run_cancel( self, workflow_run_id: str - ) -> WorkflowRunCancel200Response: + ) -> EventUpdateCancel200Response: return await self.workflow_run_api.workflow_run_cancel( tenant=self.tenant_id, workflow_runs_cancel_request=WorkflowRunsCancelRequest( @@ -201,7 +216,7 @@ async def workflow_run_cancel( async def workflow_run_bulk_cancel( self, workflow_run_ids: list[str] - ) -> WorkflowRunCancel200Response: + ) -> EventUpdateCancel200Response: return await self.workflow_run_api.workflow_run_cancel( tenant=self.tenant_id, workflow_runs_cancel_request=WorkflowRunsCancelRequest( @@ -219,9 +234,10 @@ async def workflow_run_create( return await self.workflow_run_api.workflow_run_create( workflow=workflow_id, version=version, + ## TODO: Fix this type error - maybe a list of strings is okay since it's still JSON? 
trigger_workflow_run_request=TriggerWorkflowRunRequest( input=input, - additional_metadata=additional_metadata, + additionalMetadata=additional_metadata, # type: ignore[arg-type] ), ) @@ -232,7 +248,7 @@ async def cron_create( expression: str, input: JSONSerializableDict, additional_metadata: JSONSerializableDict, - ): + ) -> CronWorkflows: return await self.workflow_run_api.cron_workflow_trigger_create( tenant=self.tenant_id, workflow=workflow_name, @@ -240,12 +256,12 @@ async def cron_create( cronName=cron_name, cronExpression=expression, input=input, - additional_metadata=additional_metadata, + additionalMetadata=additional_metadata, ), ) - async def cron_delete(self, cron_trigger_id: str): - return await self.workflow_api.workflow_cron_delete( + async def cron_delete(self, cron_trigger_id: str) -> None: + await self.workflow_api.workflow_cron_delete( tenant=self.tenant_id, cron_workflow=cron_trigger_id, ) @@ -258,7 +274,7 @@ async def cron_list( additional_metadata: list[str] | None = None, order_by_field: CronWorkflowsOrderByField | None = None, order_by_direction: WorkflowRunOrderByDirection | None = None, - ): + ) -> CronWorkflowsList: return await self.workflow_api.cron_workflow_list( tenant=self.tenant_id, offset=offset, @@ -269,7 +285,7 @@ async def cron_list( order_by_direction=order_by_direction, ) - async def cron_get(self, cron_trigger_id: str): + async def cron_get(self, cron_trigger_id: str) -> CronWorkflows: return await self.workflow_api.workflow_cron_get( tenant=self.tenant_id, cron_workflow=cron_trigger_id, @@ -281,19 +297,19 @@ async def schedule_create( trigger_at: datetime.datetime, input: JSONSerializableDict, additional_metadata: JSONSerializableDict, - ): + ) -> ScheduledWorkflows: return await self.workflow_run_api.scheduled_workflow_run_create( tenant=self.tenant_id, workflow=name, schedule_workflow_run_request=ScheduleWorkflowRunRequest( triggerAt=trigger_at, input=input, - additional_metadata=additional_metadata, + 
additionalMetadata=additional_metadata, ), ) - async def schedule_delete(self, scheduled_trigger_id: str): - return await self.workflow_api.workflow_scheduled_delete( + async def schedule_delete(self, scheduled_trigger_id: str) -> None: + await self.workflow_api.workflow_scheduled_delete( tenant=self.tenant_id, scheduled_workflow_run=scheduled_trigger_id, ) @@ -308,7 +324,7 @@ async def schedule_list( parent_step_run_id: str | None = None, order_by_field: ScheduledWorkflowsOrderByField | None = None, order_by_direction: WorkflowRunOrderByDirection | None = None, - ): + ) -> ScheduledWorkflowsList: return await self.workflow_api.workflow_scheduled_list( tenant=self.tenant_id, offset=offset, @@ -321,7 +337,7 @@ async def schedule_list( order_by_direction=order_by_direction, ) - async def schedule_get(self, scheduled_trigger_id: str): + async def schedule_get(self, scheduled_trigger_id: str) -> ScheduledWorkflows: return await self.workflow_api.workflow_scheduled_get( tenant=self.tenant_id, scheduled_workflow_run=scheduled_trigger_id, @@ -374,9 +390,10 @@ async def events_list( async def events_replay(self, event_ids: list[str] | EventList) -> EventList: if isinstance(event_ids, EventList): - event_ids = [r.metadata.id for r in event_ids.rows] + rows = event_ids.rows or [] + event_ids = [r.metadata.id for r in rows] - return self.event_api.event_update_replay( + return await self.event_api.event_update_replay( tenant=self.tenant_id, replay_event_request=ReplayEventRequest(eventIds=event_ids), ) @@ -394,7 +411,7 @@ def __init__(self, host: str, api_key: str, tenant_id: str): # Register the cleanup method to be called on exit atexit.register(self._cleanup) - def _cleanup(self): + def _cleanup(self) -> None: """ Stop the running thread and clean up the event loop. 
""" @@ -402,14 +419,14 @@ def _cleanup(self): self._loop.call_soon_threadsafe(self._loop.stop) self._thread.join() - def _run_event_loop(self): + def _run_event_loop(self) -> None: """ Run the asyncio event loop in a separate thread. """ asyncio.set_event_loop(self._loop) self._loop.run_forever() - def _run_coroutine(self, coro) -> Any: + def _run_coroutine(self, coro: Coroutine[Y, S, R]) -> R: """ Execute a coroutine in the event loop and return the result. """ @@ -460,12 +477,12 @@ def workflow_run_list( def workflow_run_get(self, workflow_run_id: str) -> WorkflowRun: return self._run_coroutine(self.aio.workflow_run_get(workflow_run_id)) - def workflow_run_cancel(self, workflow_run_id: str) -> WorkflowRunCancel200Response: + def workflow_run_cancel(self, workflow_run_id: str) -> EventUpdateCancel200Response: return self._run_coroutine(self.aio.workflow_run_cancel(workflow_run_id)) def workflow_run_bulk_cancel( self, workflow_run_ids: list[str] - ) -> WorkflowRunCancel200Response: + ) -> EventUpdateCancel200Response: return self._run_coroutine(self.aio.workflow_run_bulk_cancel(workflow_run_ids)) def workflow_run_create( @@ -495,8 +512,8 @@ def cron_create( ) ) - def cron_delete(self, cron_trigger_id: str): - return self._run_coroutine(self.aio.cron_delete(cron_trigger_id)) + def cron_delete(self, cron_trigger_id: str) -> None: + self._run_coroutine(self.aio.cron_delete(cron_trigger_id)) def cron_list( self, @@ -506,7 +523,7 @@ def cron_list( additional_metadata: list[str] | None = None, order_by_field: CronWorkflowsOrderByField | None = None, order_by_direction: WorkflowRunOrderByDirection | None = None, - ): + ) -> CronWorkflowsList: return self._run_coroutine( self.aio.cron_list( offset, @@ -518,7 +535,7 @@ def cron_list( ) ) - def cron_get(self, cron_trigger_id: str): + def cron_get(self, cron_trigger_id: str) -> CronWorkflows: return self._run_coroutine(self.aio.cron_get(cron_trigger_id)) def schedule_create( @@ -527,15 +544,15 @@ def schedule_create( 
trigger_at: datetime.datetime, input: JSONSerializableDict, additional_metadata: JSONSerializableDict, - ): + ) -> ScheduledWorkflows: return self._run_coroutine( self.aio.schedule_create( workflow_name, trigger_at, input, additional_metadata ) ) - def schedule_delete(self, scheduled_trigger_id: str): - return self._run_coroutine(self.aio.schedule_delete(scheduled_trigger_id)) + def schedule_delete(self, scheduled_trigger_id: str) -> None: + self._run_coroutine(self.aio.schedule_delete(scheduled_trigger_id)) def schedule_list( self, @@ -545,7 +562,7 @@ def schedule_list( additional_metadata: list[str] | None = None, order_by_field: CronWorkflowsOrderByField | None = None, order_by_direction: WorkflowRunOrderByDirection | None = None, - ): + ) -> ScheduledWorkflowsList: return self._run_coroutine( self.aio.schedule_list( offset, @@ -557,7 +574,7 @@ def schedule_list( ) ) - def schedule_get(self, scheduled_trigger_id: str): + def schedule_get(self, scheduled_trigger_id: str) -> ScheduledWorkflows: return self._run_coroutine(self.aio.schedule_get(scheduled_trigger_id)) def list_logs( diff --git a/hatchet_sdk/clients/workflow_listener.py b/hatchet_sdk/clients/workflow_listener.py index ada20040..d38d6be8 100644 --- a/hatchet_sdk/clients/workflow_listener.py +++ b/hatchet_sdk/clients/workflow_listener.py @@ -4,6 +4,7 @@ from typing import Any, AsyncGenerator, cast import grpc +import grpc.aio from grpc._cython import cygrpc # type: ignore[attr-defined] from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt @@ -64,7 +65,10 @@ class PooledWorkflowRunListener: requests: asyncio.Queue[SubscribeToWorkflowRunsRequest | int] = asyncio.Queue() - listener: AsyncGenerator[WorkflowRunEvent, None] | None = None + listener: ( + grpc.aio.UnaryStreamCall[SubscribeToWorkflowRunsRequest, WorkflowRunEvent] + | None + ) = None listener_task: asyncio.Task[None] | None = None curr_requester: int = 0 @@ -106,6 +110,9 @@ async def _init_producer(self) -> None: while True: 
self.interrupt = Event_ts() + if self.listener is None: + continue + t = asyncio.create_task( read_with_interrupt(self.listener, self.interrupt) ) @@ -119,7 +126,7 @@ async def _init_producer(self) -> None: t.cancel() if self.listener: - self.listener.cancel() # type: ignore[attr-defined] + self.listener.cancel() await asyncio.sleep( DEFAULT_WORKFLOW_LISTENER_RETRY_INTERVAL ) @@ -246,7 +253,9 @@ async def result(self, workflow_run_id: str) -> dict[str, Any]: return results - async def _retry_subscribe(self) -> AsyncGenerator[WorkflowRunEvent, None]: + async def _retry_subscribe( + self, + ) -> grpc.aio.UnaryStreamCall[SubscribeToWorkflowRunsRequest, WorkflowRunEvent]: retries = 0 while retries < DEFAULT_WORKFLOW_LISTENER_RETRY_COUNT: @@ -259,7 +268,9 @@ async def _retry_subscribe(self) -> AsyncGenerator[WorkflowRunEvent, None]: self.requests.put_nowait(self.curr_requester) return cast( - AsyncGenerator[WorkflowRunEvent, None], + grpc.aio.UnaryStreamCall[ + SubscribeToWorkflowRunsRequest, WorkflowRunEvent + ], self.client.SubscribeToWorkflowRuns( self._request(), metadata=get_metadata(self.token), diff --git a/hatchet_sdk/features/cron.py b/hatchet_sdk/features/cron.py index a03a1c19..413251d2 100644 --- a/hatchet_sdk/features/cron.py +++ b/hatchet_sdk/features/cron.py @@ -152,16 +152,13 @@ def list( Returns: CronWorkflowsList: A list of cron workflows. 
""" - return cast( - CronWorkflowsList, - self._client.rest.cron_list( - offset=offset, - limit=limit, - workflow_id=workflow_id, - additional_metadata=additional_metadata, - order_by_field=order_by_field, - order_by_direction=order_by_direction, - ), + return self._client.rest.cron_list( + offset=offset, + limit=limit, + workflow_id=workflow_id, + additional_metadata=additional_metadata, + order_by_field=order_by_field, + order_by_direction=order_by_direction, ) def get(self, cron_trigger: Union[str, CronWorkflows]) -> CronWorkflows: @@ -174,13 +171,10 @@ def get(self, cron_trigger: Union[str, CronWorkflows]) -> CronWorkflows: Returns: CronWorkflows: The requested cron workflow instance. """ - return cast( - CronWorkflows, - self._client.rest.cron_get( - cron_trigger.metadata.id - if isinstance(cron_trigger, CronWorkflows) - else cron_trigger - ), + return self._client.rest.cron_get( + cron_trigger.metadata.id + if isinstance(cron_trigger, CronWorkflows) + else cron_trigger ) @@ -228,15 +222,12 @@ async def create( expression=expression, input=input, additional_metadata=additional_metadata ) - return cast( - CronWorkflows, - await self._client.rest.aio.cron_create( - workflow_name=workflow_name, - cron_name=cron_name, - expression=validated_input.expression, - input=validated_input.input, - additional_metadata=validated_input.additional_metadata, - ), + return await self._client.rest.aio.cron_create( + workflow_name=workflow_name, + cron_name=cron_name, + expression=validated_input.expression, + input=validated_input.input, + additional_metadata=validated_input.additional_metadata, ) async def delete(self, cron_trigger: Union[str, CronWorkflows]) -> None: @@ -275,16 +266,13 @@ async def list( Returns: CronWorkflowsList: A list of cron workflows. 
""" - return cast( - CronWorkflowsList, - await self._client.rest.aio.cron_list( - offset=offset, - limit=limit, - workflow_id=workflow_id, - additional_metadata=additional_metadata, - order_by_field=order_by_field, - order_by_direction=order_by_direction, - ), + return await self._client.rest.aio.cron_list( + offset=offset, + limit=limit, + workflow_id=workflow_id, + additional_metadata=additional_metadata, + order_by_field=order_by_field, + order_by_direction=order_by_direction, ) async def get(self, cron_trigger: Union[str, CronWorkflows]) -> CronWorkflows: @@ -298,11 +286,8 @@ async def get(self, cron_trigger: Union[str, CronWorkflows]) -> CronWorkflows: CronWorkflows: The requested cron workflow instance. """ - return cast( - CronWorkflows, - await self._client.rest.aio.cron_get( - cron_trigger.metadata.id - if isinstance(cron_trigger, CronWorkflows) - else cron_trigger - ), + return await self._client.rest.aio.cron_get( + cron_trigger.metadata.id + if isinstance(cron_trigger, CronWorkflows) + else cron_trigger ) diff --git a/hatchet_sdk/features/scheduled.py b/hatchet_sdk/features/scheduled.py index 60a43614..cf8462df 100644 --- a/hatchet_sdk/features/scheduled.py +++ b/hatchet_sdk/features/scheduled.py @@ -81,14 +81,11 @@ def create( trigger_at=trigger_at, input=input, additional_metadata=additional_metadata ) - return cast( - ScheduledWorkflows, - self._client.rest.schedule_create( - workflow_name, - validated_input.trigger_at, - validated_input.input, - validated_input.additional_metadata, - ), + return self._client.rest.schedule_create( + workflow_name, + validated_input.trigger_at, + validated_input.input, + validated_input.additional_metadata, ) def delete(self, scheduled: Union[str, ScheduledWorkflows]) -> None: @@ -127,16 +124,13 @@ def list( Returns: List[ScheduledWorkflows]: A list of scheduled workflows matching the criteria. 
""" - return cast( - ScheduledWorkflowsList, - self._client.rest.schedule_list( - offset=offset, - limit=limit, - workflow_id=workflow_id, - additional_metadata=additional_metadata, - order_by_field=order_by_field, - order_by_direction=order_by_direction, - ), + return self._client.rest.schedule_list( + offset=offset, + limit=limit, + workflow_id=workflow_id, + additional_metadata=additional_metadata, + order_by_field=order_by_field, + order_by_direction=order_by_direction, ) def get(self, scheduled: Union[str, ScheduledWorkflows]) -> ScheduledWorkflows: @@ -149,13 +143,10 @@ def get(self, scheduled: Union[str, ScheduledWorkflows]) -> ScheduledWorkflows: Returns: ScheduledWorkflows: The requested scheduled workflow instance. """ - return cast( - ScheduledWorkflows, - self._client.rest.schedule_get( - scheduled.metadata.id - if isinstance(scheduled, ScheduledWorkflows) - else scheduled - ), + return self._client.rest.schedule_get( + scheduled.metadata.id + if isinstance(scheduled, ScheduledWorkflows) + else scheduled ) @@ -197,11 +188,8 @@ async def create( Returns: ScheduledWorkflows: The created scheduled workflow instance. """ - return cast( - ScheduledWorkflows, - await self._client.rest.aio.schedule_create( - workflow_name, trigger_at, input, additional_metadata - ), + return await self._client.rest.aio.schedule_create( + workflow_name, trigger_at, input, additional_metadata ) async def delete(self, scheduled: Union[str, ScheduledWorkflows]) -> None: @@ -240,16 +228,13 @@ async def list( Returns: ScheduledWorkflowsList: A list of scheduled workflows matching the criteria. 
""" - return cast( - ScheduledWorkflowsList, - await self._client.rest.aio.schedule_list( - offset=offset, - limit=limit, - workflow_id=workflow_id, - additional_metadata=additional_metadata, - order_by_field=order_by_field, - order_by_direction=order_by_direction, - ), + return await self._client.rest.aio.schedule_list( + offset=offset, + limit=limit, + workflow_id=workflow_id, + additional_metadata=additional_metadata, + order_by_field=order_by_field, + order_by_direction=order_by_direction, ) async def get( @@ -264,11 +249,8 @@ async def get( Returns: ScheduledWorkflows: The requested scheduled workflow instance. """ - return cast( - ScheduledWorkflows, - await self._client.rest.aio.schedule_get( - scheduled.metadata.id - if isinstance(scheduled, ScheduledWorkflows) - else scheduled - ), + return await self._client.rest.aio.schedule_get( + scheduled.metadata.id + if isinstance(scheduled, ScheduledWorkflows) + else scheduled ) From 5dcc29866a8a82c9be00d71730e41e8335ce391d Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Tue, 28 Jan 2025 18:03:47 -0500 Subject: [PATCH 42/53] fix: finally, it all works --- hatchet_sdk/clients/rest_client.py | 11 ++++++++--- pyproject.toml | 16 ++++++++++------ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/hatchet_sdk/clients/rest_client.py b/hatchet_sdk/clients/rest_client.py index ba62490a..763648e1 100644 --- a/hatchet_sdk/clients/rest_client.py +++ b/hatchet_sdk/clients/rest_client.py @@ -2,11 +2,10 @@ import atexit import datetime import threading -from typing import Any, Coroutine, List, TypeVar, cast +from typing import Coroutine, TypeVar from pydantic import StrictInt -from hatchet_sdk.clients.rest import CronWorkflowsList, ScheduledWorkflowsList from hatchet_sdk.clients.rest.api.event_api import EventApi from hatchet_sdk.clients.rest.api.log_api import LogApi from hatchet_sdk.clients.rest.api.step_run_api import StepRunApi @@ -15,11 +14,11 @@ from hatchet_sdk.clients.rest.api.workflow_runs_api import 
WorkflowRunsApi from hatchet_sdk.clients.rest.api_client import ApiClient from hatchet_sdk.clients.rest.configuration import Configuration -from hatchet_sdk.clients.rest.models import TriggerWorkflowRunRequest from hatchet_sdk.clients.rest.models.create_cron_workflow_trigger_request import ( CreateCronWorkflowTriggerRequest, ) from hatchet_sdk.clients.rest.models.cron_workflows import CronWorkflows +from hatchet_sdk.clients.rest.models.cron_workflows_list import CronWorkflowsList from hatchet_sdk.clients.rest.models.cron_workflows_order_by_field import ( CronWorkflowsOrderByField, ) @@ -48,9 +47,15 @@ ScheduleWorkflowRunRequest, ) from hatchet_sdk.clients.rest.models.scheduled_workflows import ScheduledWorkflows +from hatchet_sdk.clients.rest.models.scheduled_workflows_list import ( + ScheduledWorkflowsList, +) from hatchet_sdk.clients.rest.models.scheduled_workflows_order_by_field import ( ScheduledWorkflowsOrderByField, ) +from hatchet_sdk.clients.rest.models.trigger_workflow_run_request import ( + TriggerWorkflowRunRequest, +) from hatchet_sdk.clients.rest.models.workflow import Workflow from hatchet_sdk.clients.rest.models.workflow_kind import WorkflowKind from hatchet_sdk.clients.rest.models.workflow_list import WorkflowList diff --git a/pyproject.toml b/pyproject.toml index 1e061b35..06dc9ba4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,6 +89,10 @@ exclude = [ "hatchet_sdk/clients/rest/api/*", "hatchet_sdk/clients/rest/models/*", "hatchet_sdk/contracts", + "hatchet_sdk/clients/rest/api_client.py", + "hatchet_sdk/clients/rest/configuration.py", + "hatchet_sdk/clients/rest/exceptions.py", + "hatchet_sdk/clients/rest/rest.py", ] explicit_package_bases = true @@ -115,19 +119,19 @@ warn_return_any = true [[tool.mypy.overrides]] module = ["hatchet_sdk/contracts/*", "hatchet_sdk/clients/rest/*"] -warn_unused_ignores = false +warn_unused_ignores = true -strict_equality = false +strict_equality = true -disallow_subclassing_any = false 
-disallow_untyped_decorators = false -disallow_any_generics = false +disallow_subclassing_any = true +disallow_untyped_decorators = true +disallow_any_generics = true disallow_untyped_calls = false disallow_incomplete_defs = false disallow_untyped_defs = false -no_implicit_reexport = false +no_implicit_reexport = true warn_return_any = false From ab40bdde8c5bb8dea9a95ba4cb782cc06f12c025 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Tue, 28 Jan 2025 18:05:48 -0500 Subject: [PATCH 43/53] fix: clean up mypy config --- pyproject.toml | 42 +----------------------------------------- 1 file changed, 1 insertion(+), 41 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 06dc9ba4..3a21726f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,6 @@ extend_exclude = "hatchet_sdk/contracts/" [tool.mypy] files = ["."] -follow_imports = "silent" exclude = [ "hatchet_sdk/clients/rest/api/*", "hatchet_sdk/clients/rest/models/*", @@ -94,46 +93,7 @@ exclude = [ "hatchet_sdk/clients/rest/exceptions.py", "hatchet_sdk/clients/rest/rest.py", ] - -explicit_package_bases = true -warn_unused_configs = true -warn_redundant_casts = true -warn_unused_ignores = true - -strict_equality = true - -check_untyped_defs = true - -disallow_subclassing_any = true -disallow_untyped_decorators = true -disallow_any_generics = true - -disallow_untyped_calls = true -disallow_incomplete_defs = true -disallow_untyped_defs = true - -no_implicit_reexport = true - -warn_return_any = true - -[[tool.mypy.overrides]] -module = ["hatchet_sdk/contracts/*", "hatchet_sdk/clients/rest/*"] - -warn_unused_ignores = true - -strict_equality = true - -disallow_subclassing_any = true -disallow_untyped_decorators = true -disallow_any_generics = true - -disallow_untyped_calls = false -disallow_incomplete_defs = false -disallow_untyped_defs = false - -no_implicit_reexport = true - -warn_return_any = false +strict = true [tool.poetry.scripts] api = "examples.api.api:main" From 
49693cd1bf4385da25dbe04fabc1feda74cb594e Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Tue, 28 Jan 2025 18:07:46 -0500 Subject: [PATCH 44/53] fix: rm a couple type: ignore comments --- hatchet_sdk/hatchet.py | 2 +- hatchet_sdk/worker/runner/runner.py | 2 +- hatchet_sdk/workflow.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hatchet_sdk/hatchet.py b/hatchet_sdk/hatchet.py index 60e733ae..cdf17225 100644 --- a/hatchet_sdk/hatchet.py +++ b/hatchet_sdk/hatchet.py @@ -48,7 +48,7 @@ def workflow( version: str = "", timeout: str = "60m", schedule_timeout: str = "5m", - sticky: Union[StickyStrategy.Value, None] = None, # type: ignore[name-defined] + sticky: Union[StickyStrategy, None] = None, default_priority: int | None = None, concurrency: ConcurrencyExpression | None = None, input_validator: Type[T] | None = None, diff --git a/hatchet_sdk/worker/runner/runner.py b/hatchet_sdk/worker/runner/runner.py index 5ca95e9e..a5df9fc5 100644 --- a/hatchet_sdk/worker/runner/runner.py +++ b/hatchet_sdk/worker/runner/runner.py @@ -19,7 +19,7 @@ from hatchet_sdk.clients.dispatcher.dispatcher import new_dispatcher from hatchet_sdk.clients.run_event_listener import new_listener from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener -from hatchet_sdk.context import Context # type: ignore[attr-defined] +from hatchet_sdk.context.context import Context from hatchet_sdk.context.worker_context import WorkerContext from hatchet_sdk.contracts.dispatcher_pb2 import ( GROUP_KEY_EVENT_TYPE_COMPLETED, diff --git a/hatchet_sdk/workflow.py b/hatchet_sdk/workflow.py index 9c5cef90..4a1e045c 100644 --- a/hatchet_sdk/workflow.py +++ b/hatchet_sdk/workflow.py @@ -93,7 +93,7 @@ def get_create_opts(self, namespace: str) -> Any: ... 
version: str timeout: str schedule_timeout: str - sticky: Union[StickyStrategy.Value, None] # type: ignore[name-defined] + sticky: Union[StickyStrategy, None] default_priority: int | None concurrency_expression: ConcurrencyExpression | None input_validator: Type[BaseModel] | None From 097301782595f4b12f247d951bc80d452ba1c0fc Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 29 Jan 2025 18:51:52 -0500 Subject: [PATCH 45/53] fix: remove some Any types --- hatchet_sdk/clients/dispatcher/dispatcher.py | 28 +++++++++++-------- hatchet_sdk/clients/event_ts.py | 10 +++++-- hatchet_sdk/worker/worker.py | 29 +++++++++++--------- 3 files changed, 41 insertions(+), 26 deletions(-) diff --git a/hatchet_sdk/clients/dispatcher/dispatcher.py b/hatchet_sdk/clients/dispatcher/dispatcher.py index e956557e..d9cae493 100644 --- a/hatchet_sdk/clients/dispatcher/dispatcher.py +++ b/hatchet_sdk/clients/dispatcher/dispatcher.py @@ -1,5 +1,6 @@ from typing import Any, cast +import grpc.aio from google.protobuf.timestamp_pb2 import Timestamp from hatchet_sdk.clients.dispatcher.action_listener import ( @@ -69,7 +70,7 @@ async def get_action_listener( async def send_step_action_event( self, action: Action, event_type: StepActionEventType, payload: str - ) -> Any: + ) -> grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse] | None: try: return await self._try_send_step_action_event(action, event_type, payload) except Exception as e: @@ -84,12 +85,12 @@ async def send_step_action_event( "Failed to send finished event: " + str(e), ) - return + return None @tenacity_retry async def _try_send_step_action_event( self, action: Action, event_type: StepActionEventType, payload: str - ) -> Any: + ) -> grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse]: eventTimestamp = Timestamp() eventTimestamp.GetCurrentTime() @@ -105,15 +106,17 @@ async def _try_send_step_action_event( eventPayload=payload, ) - ## TODO: What does this return? 
- return await self.aio_client.SendStepActionEvent( - event, - metadata=get_metadata(self.token), + return cast( + grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse], + await self.aio_client.SendStepActionEvent( + event, + metadata=get_metadata(self.token), + ), ) async def send_group_key_action_event( self, action: Action, event_type: GroupKeyActionEventType, payload: str - ) -> Any: + ) -> grpc.aio.UnaryUnaryCall[GroupKeyActionEvent, ActionEventResponse]: eventTimestamp = Timestamp() eventTimestamp.GetCurrentTime() @@ -128,9 +131,12 @@ async def send_group_key_action_event( ) ## TODO: What does this return? - return await self.aio_client.SendGroupKeyActionEvent( - event, - metadata=get_metadata(self.token), + return cast( + grpc.aio.UnaryUnaryCall[GroupKeyActionEvent, ActionEventResponse], + await self.aio_client.SendGroupKeyActionEvent( + event, + metadata=get_metadata(self.token), + ), ) def put_overrides_data(self, data: OverridesData) -> ActionEventResponse: diff --git a/hatchet_sdk/clients/event_ts.py b/hatchet_sdk/clients/event_ts.py index 694e7af6..7a85d467 100644 --- a/hatchet_sdk/clients/event_ts.py +++ b/hatchet_sdk/clients/event_ts.py @@ -2,6 +2,7 @@ from typing import Any, TypeVar, cast import grpc.aio +from grpc._cython import cygrpc # type: ignore[attr-defined] class Event_ts(asyncio.Event): @@ -28,8 +29,13 @@ def clear(self) -> None: async def read_with_interrupt( listener: grpc.aio.UnaryStreamCall[TRequest, TResponse], interrupt: Event_ts -) -> Any: +) -> TResponse: try: - return cast(Any, await listener.read()) + result = await listener.read() + + if result is cygrpc.EOF: + raise ValueError("Unexpected EOF") + + return cast(TResponse, result) finally: interrupt.set() diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py index 50bf18f8..2c393245 100644 --- a/hatchet_sdk/worker/worker.py +++ b/hatchet_sdk/worker/worker.py @@ -19,13 +19,20 @@ from hatchet_sdk import Context from hatchet_sdk.client import Client, 
new_client_raw +from hatchet_sdk.clients.dispatcher.action_listener import Action from hatchet_sdk.contracts.workflows_pb2 import CreateWorkflowVersionOpts from hatchet_sdk.loader import ClientConfig from hatchet_sdk.logger import logger from hatchet_sdk.utils.types import WorkflowValidator from hatchet_sdk.utils.typing import is_basemodel_subclass -from hatchet_sdk.worker.action_listener_process import worker_action_listener_process -from hatchet_sdk.worker.runner.run_loop_manager import WorkerActionRunLoopManager +from hatchet_sdk.worker.action_listener_process import ( + ActionEvent, + worker_action_listener_process, +) +from hatchet_sdk.worker.runner.run_loop_manager import ( + STOP_LOOP_TYPE, + WorkerActionRunLoopManager, +) from hatchet_sdk.workflow import WorkflowInterface T = TypeVar("T") @@ -74,13 +81,13 @@ def __init__( self._status: WorkerStatus self.action_listener_process: BaseProcess - self.action_listener_health_check: asyncio.Task[Any] + self.action_listener_health_check: asyncio.Task[None] self.action_runner: WorkerActionRunLoopManager self.ctx = multiprocessing.get_context("spawn") - self.action_queue: "Queue[Any]" = self.ctx.Queue() - self.event_queue: "Queue[Any]" = self.ctx.Queue() + self.action_queue: "Queue[Action | STOP_LOOP_TYPE]" = self.ctx.Queue() + self.event_queue: "Queue[ActionEvent]" = self.ctx.Queue() self.loop: asyncio.AbstractEventLoop @@ -193,12 +200,10 @@ async def start_health_server(self) -> None: logger.info(f"healthcheck server running on port {port}") - def start( - self, options: WorkerStartOptions = WorkerStartOptions() - ) -> Future[asyncio.Task[Any] | None]: + def start(self, options: WorkerStartOptions = WorkerStartOptions()) -> None: self.owned_loop = self.setup_loop(options.loop) - f = asyncio.run_coroutine_threadsafe( + asyncio.run_coroutine_threadsafe( self.async_start(options, _from_start=True), self.loop ) @@ -209,14 +214,12 @@ def start( if self.handle_kill: sys.exit(0) - return f - ## Start methods async def 
async_start( self, options: WorkerStartOptions = WorkerStartOptions(), _from_start: bool = False, - ) -> Any | None: + ) -> None: main_pid = os.getpid() logger.info("------------------------------------------") logger.info("STARTING HATCHET...") @@ -245,7 +248,7 @@ async def async_start( self._check_listener_health() ) - return await self.action_listener_health_check + await self.action_listener_health_check def _run_action_runner(self) -> WorkerActionRunLoopManager: # Retrieve the shared queue From 9ced997c862e90256b0498a48fe17dc811bdc1ad Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 29 Jan 2025 19:05:01 -0500 Subject: [PATCH 46/53] fix: remove a few more Any types --- .../clients/dispatcher/action_listener.py | 27 ++++++++++--------- hatchet_sdk/context/context.py | 8 +++--- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/hatchet_sdk/clients/dispatcher/action_listener.py b/hatchet_sdk/clients/dispatcher/action_listener.py index ff422805..667f9cf2 100644 --- a/hatchet_sdk/clients/dispatcher/action_listener.py +++ b/hatchet_sdk/clients/dispatcher/action_listener.py @@ -70,9 +70,9 @@ class Action: step_id: str step_run_id: str action_id: str - action_payload: str action_type: ActionType retry_count: int + action_payload: JSONSerializableDict = field(default_factory=dict) additional_metadata: JSONSerializableDict = field(default_factory=dict) child_workflow_index: int | None = None @@ -268,15 +268,19 @@ async def _generator(self) -> AsyncGenerator[Action | None, None]: # Process the received action action_type = assigned_action.actionType - if ( - assigned_action.actionPayload is None - or assigned_action.actionPayload == "" - ): - action_payload = None - else: - action_payload = self.parse_action_payload( - assigned_action.actionPayload + action_payload = ( + {} + if not assigned_action.actionPayload + else self.parse_action_payload(assigned_action.actionPayload) + ) + + try: + additional_metadata = cast( + dict[str, Any], + 
json.loads(assigned_action.additional_metadata), ) + except json.JSONDecodeError: + additional_metadata = {} action = Action( tenant_id=assigned_action.tenantId, @@ -289,11 +293,10 @@ async def _generator(self) -> AsyncGenerator[Action | None, None]: step_id=assigned_action.stepId, step_run_id=assigned_action.stepRunId, action_id=assigned_action.actionId, - ## TODO: Figure out this type - maybe needs to be dumped to JSON? - action_payload=action_payload, # type: ignore[arg-type] + action_payload=action_payload, action_type=action_type, retry_count=assigned_action.retryCount, - additional_metadata=assigned_action.additional_metadata, + additional_metadata=additional_metadata, child_workflow_index=assigned_action.child_workflow_index, child_workflow_key=assigned_action.child_workflow_key, parent_workflow_run_id=assigned_action.parent_workflow_run_id, diff --git a/hatchet_sdk/context/context.py b/hatchet_sdk/context/context.py index 6b86ca36..aa72f208 100644 --- a/hatchet_sdk/context/context.py +++ b/hatchet_sdk/context/context.py @@ -170,6 +170,8 @@ def __init__( namespace, ) + self.data: dict[str, Any] + # Check the type of action.action_payload before attempting to load it as JSON if isinstance(action.action_payload, (str, bytes, bytearray)): try: @@ -180,16 +182,14 @@ def __init__( self.data: dict[str, Any] = {} # type: ignore[no-redef] else: # Directly assign the payload to self.data if it's already a dict - self.data = ( - action.action_payload if isinstance(action.action_payload, dict) else {} - ) + self.data = action.action_payload self.action = action # FIXME: stepRunId is a legacy field, we should remove it self.stepRunId = action.step_run_id - self.step_run_id = action.step_run_id + self.step_run_id: str = action.step_run_id self.exit_flag = False self.dispatcher_client = dispatcher_client self.admin_client = admin_client From e8722ca9a47b05b00a0de962a1e44f1b46614c35 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 29 Jan 2025 19:22:27 -0500 Subject: 
[PATCH 47/53] fix: couple more types --- hatchet_sdk/clients/dispatcher/action_listener.py | 1 - hatchet_sdk/clients/rest_client.py | 7 +++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/hatchet_sdk/clients/dispatcher/action_listener.py b/hatchet_sdk/clients/dispatcher/action_listener.py index 667f9cf2..763c0e6a 100644 --- a/hatchet_sdk/clients/dispatcher/action_listener.py +++ b/hatchet_sdk/clients/dispatcher/action_listener.py @@ -366,7 +366,6 @@ async def get_listen_client( f"action listener connection interrupted, retrying... ({self.retries}/{DEFAULT_ACTION_LISTENER_RETRY_COUNT})" ) - ## TODO: Figure out how to get type support for these self.aio_client = DispatcherStub(new_conn(self.config, True)) # type: ignore[no-untyped-call] if self.listen_strategy == "v2": diff --git a/hatchet_sdk/clients/rest_client.py b/hatchet_sdk/clients/rest_client.py index 763648e1..83266a55 100644 --- a/hatchet_sdk/clients/rest_client.py +++ b/hatchet_sdk/clients/rest_client.py @@ -234,15 +234,14 @@ async def workflow_run_create( workflow_id: str, input: JSONSerializableDict, version: str | None = None, - additional_metadata: list[str] | None = None, + additional_metadata: JSONSerializableDict = {}, ) -> WorkflowRun: return await self.workflow_run_api.workflow_run_create( workflow=workflow_id, version=version, - ## TODO: Fix this type error - maybe a list of strings is okay since it's still JSON? 
trigger_workflow_run_request=TriggerWorkflowRunRequest( input=input, - additionalMetadata=additional_metadata, # type: ignore[arg-type] + additionalMetadata=additional_metadata, ), ) @@ -495,7 +494,7 @@ def workflow_run_create( workflow_id: str, input: JSONSerializableDict, version: str | None = None, - additional_metadata: list[str] | None = None, + additional_metadata: JSONSerializableDict = {}, ) -> WorkflowRun: return self._run_coroutine( self.aio.workflow_run_create( From bde190afa621564f6eb49296f0433df08d693890 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 29 Jan 2025 20:32:34 -0500 Subject: [PATCH 48/53] chore: copy --- hatchet_sdk/v2.py | 307 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 307 insertions(+) create mode 100644 hatchet_sdk/v2.py diff --git a/hatchet_sdk/v2.py b/hatchet_sdk/v2.py new file mode 100644 index 00000000..cdf17225 --- /dev/null +++ b/hatchet_sdk/v2.py @@ -0,0 +1,307 @@ +import asyncio +import logging +from typing import Any, Callable, Optional, ParamSpec, Type, TypeVar, Union + +from pydantic import BaseModel +from typing_extensions import deprecated + +from hatchet_sdk.clients.rest_client import RestApi +from hatchet_sdk.context.context import Context +from hatchet_sdk.contracts.workflows_pb2 import ( + ConcurrencyLimitStrategy, + CreateStepRateLimit, + DesiredWorkerLabels, + StickyStrategy, + WorkerLabelComparator, +) +from hatchet_sdk.features.cron import CronClient +from hatchet_sdk.features.scheduled import ScheduledClient +from hatchet_sdk.labels import DesiredWorkerLabel +from hatchet_sdk.loader import ClientConfig +from hatchet_sdk.rate_limit import RateLimit + +from .client import Client, new_client, new_client_raw +from .clients.admin import AdminClient +from .clients.dispatcher.dispatcher import DispatcherClient +from .clients.events import EventClient +from .clients.run_event_listener import RunEventListenerClient +from .logger import logger +from .worker.worker import Worker +from .workflow import ( 
+ ConcurrencyExpression, + WorkflowInterface, + WorkflowMeta, + WorkflowStepProtocol, +) + +T = TypeVar("T", bound=BaseModel) +R = TypeVar("R") +P = ParamSpec("P") + +TWorkflow = TypeVar("TWorkflow", bound=object) + + +def workflow( + name: str = "", + on_events: list[str] | None = None, + on_crons: list[str] | None = None, + version: str = "", + timeout: str = "60m", + schedule_timeout: str = "5m", + sticky: Union[StickyStrategy, None] = None, + default_priority: int | None = None, + concurrency: ConcurrencyExpression | None = None, + input_validator: Type[T] | None = None, +) -> Callable[[Type[TWorkflow]], WorkflowMeta]: + on_events = on_events or [] + on_crons = on_crons or [] + + def inner(cls: Type[TWorkflow]) -> WorkflowMeta: + nonlocal name + name = name or str(cls.__name__) + + setattr(cls, "on_events", on_events) + setattr(cls, "on_crons", on_crons) + setattr(cls, "name", name) + setattr(cls, "version", version) + setattr(cls, "timeout", timeout) + setattr(cls, "schedule_timeout", schedule_timeout) + setattr(cls, "sticky", sticky) + setattr(cls, "default_priority", default_priority) + setattr(cls, "concurrency_expression", concurrency) + + # Define a new class with the same name and bases as the original, but + # with WorkflowMeta as its metaclass + + ## TODO: Figure out how to type this metaclass correctly + setattr(cls, "input_validator", input_validator) + + return WorkflowMeta(name, cls.__bases__, dict(cls.__dict__)) + + return inner + + +def step( + name: str = "", + timeout: str = "", + parents: list[str] | None = None, + retries: int = 0, + rate_limits: list[RateLimit] | None = None, + desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, + backoff_factor: float | None = None, + backoff_max_seconds: int | None = None, +) -> Callable[[Callable[P, R]], Callable[P, R]]: + parents = parents or [] + + def inner(func: Callable[P, R]) -> Callable[P, R]: + limits = None + if rate_limits: + limits = [rate_limit._req for rate_limit in rate_limits or []] 
+ + setattr(func, "_step_name", name.lower() or str(func.__name__).lower()) + setattr(func, "_step_parents", parents) + setattr(func, "_step_timeout", timeout) + setattr(func, "_step_retries", retries) + setattr(func, "_step_rate_limits", limits) + setattr(func, "_step_backoff_factor", backoff_factor) + setattr(func, "_step_backoff_max_seconds", backoff_max_seconds) + + def create_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels: + value = d.value + return DesiredWorkerLabels( + strValue=value if not isinstance(value, int) else None, + intValue=value if isinstance(value, int) else None, + required=d.required, + weight=d.weight, + comparator=d.comparator, # type: ignore[arg-type] + ) + + setattr( + func, + "_step_desired_worker_labels", + {key: create_label(d) for key, d in desired_worker_labels.items()}, + ) + + return func + + return inner + + +def on_failure_step( + name: str = "", + timeout: str = "", + retries: int = 0, + rate_limits: list[RateLimit] | None = None, + backoff_factor: float | None = None, + backoff_max_seconds: int | None = None, +) -> Callable[..., Any]: + def inner(func: Callable[[Context], Any]) -> Callable[[Context], Any]: + limits = None + if rate_limits: + limits = [ + CreateStepRateLimit(key=rate_limit.static_key, units=rate_limit.units) # type: ignore[arg-type] + for rate_limit in rate_limits or [] + ] + + setattr( + func, "_on_failure_step_name", name.lower() or str(func.__name__).lower() + ) + setattr(func, "_on_failure_step_timeout", timeout) + setattr(func, "_on_failure_step_retries", retries) + setattr(func, "_on_failure_step_rate_limits", limits) + setattr(func, "_on_failure_step_backoff_factor", backoff_factor) + setattr(func, "_on_failure_step_backoff_max_seconds", backoff_max_seconds) + + return func + + return inner + + +def concurrency( + name: str = "", + max_runs: int = 1, + limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS, +) -> Callable[..., Any]: + def inner(func: 
Callable[[Context], Any]) -> Callable[[Context], Any]: + setattr( + func, + "_concurrency_fn_name", + name.lower() or str(func.__name__).lower(), + ) + setattr(func, "_concurrency_max_runs", max_runs) + setattr(func, "_concurrency_limit_strategy", limit_strategy) + + return func + + return inner + + +class HatchetRest: + """ + Main client for interacting with the Hatchet API. + + This class provides access to various client interfaces and utility methods + for working with Hatchet via the REST API, + + Attributes: + rest (RestApi): Interface for REST API operations. + """ + + def __init__(self, config: ClientConfig = ClientConfig()): + self.rest = RestApi(config.server_url, config.token, config.tenant_id) + + +class Hatchet: + """ + Main client for interacting with the Hatchet SDK. + + This class provides access to various client interfaces and utility methods + for working with Hatchet workers, workflows, and steps. + + Attributes: + cron (CronClient): Interface for cron trigger operations. + + admin (AdminClient): Interface for administrative operations. + dispatcher (DispatcherClient): Interface for dispatching operations. + event (EventClient): Interface for event-related operations. + rest (RestApi): Interface for REST API operations. + """ + + _client: Client + cron: CronClient + scheduled: ScheduledClient + + @classmethod + def from_environment( + cls, defaults: ClientConfig = ClientConfig(), **kwargs: Any + ) -> "Hatchet": + return cls(client=new_client(defaults), **kwargs) + + @classmethod + def from_config(cls, config: ClientConfig, **kwargs: Any) -> "Hatchet": + return cls(client=new_client_raw(config), **kwargs) + + def __init__( + self, + debug: bool = False, + client: Optional[Client] = None, + config: ClientConfig = ClientConfig(), + ): + """ + Initialize a new Hatchet instance. + + Args: + debug (bool, optional): Enable debug logging. Defaults to False. + client (Optional[Client], optional): A pre-configured Client instance. Defaults to None. 
+ config (ClientConfig, optional): Configuration for creating a new Client. Defaults to ClientConfig(). + """ + if client is not None: + self._client = client + else: + self._client = new_client(config, debug) + + if debug: + logger.setLevel(logging.DEBUG) + + self.cron = CronClient(self._client) + self.scheduled = ScheduledClient(self._client) + + @property + @deprecated( + "Direct access to client is deprecated and will be removed in a future version. Use specific client properties (Hatchet.admin, Hatchet.dispatcher, Hatchet.event, Hatchet.rest) instead. [0.32.0]", + ) + def client(self) -> Client: + return self._client + + @property + def admin(self) -> AdminClient: + return self._client.admin + + @property + def dispatcher(self) -> DispatcherClient: + return self._client.dispatcher + + @property + def event(self) -> EventClient: + return self._client.event + + @property + def rest(self) -> RestApi: + return self._client.rest + + @property + def listener(self) -> RunEventListenerClient: + return self._client.listener + + @property + def config(self) -> ClientConfig: + return self._client.config + + @property + def tenant_id(self) -> str: + return self._client.config.tenant_id + + concurrency = staticmethod(concurrency) + + workflow = staticmethod(workflow) + + step = staticmethod(step) + + on_failure_step = staticmethod(on_failure_step) + + def worker( + self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {} + ) -> Worker: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + return Worker( + name=name, + max_runs=max_runs, + labels=labels, + config=self._client.config, + debug=self._client.debug, + owned_loop=loop is None, + ) From 45e21d3fab0ef7781693b75ec403c415a8f2dfa3 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 29 Jan 2025 21:00:20 -0500 Subject: [PATCH 49/53] feat: initial workflow impl --- hatchet_sdk/v2.py | 235 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 197 insertions(+), 
38 deletions(-) diff --git a/hatchet_sdk/v2.py b/hatchet_sdk/v2.py index cdf17225..a14a66db 100644 --- a/hatchet_sdk/v2.py +++ b/hatchet_sdk/v2.py @@ -1,9 +1,18 @@ import asyncio +from hatchet_sdk.contracts.workflows_pb2 import ( + CreateWorkflowJobOpts, + CreateWorkflowStepOpts, + CreateWorkflowVersionOpts, + StickyStrategy, + WorkflowConcurrencyOpts, + WorkflowKind, +) import logging from typing import Any, Callable, Optional, ParamSpec, Type, TypeVar, Union -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from typing_extensions import deprecated +from enum import Enum from hatchet_sdk.clients.rest_client import RestApi from hatchet_sdk.context.context import Context @@ -40,45 +49,195 @@ TWorkflow = TypeVar("TWorkflow", bound=object) +class EmptyModel(BaseModel): + model_config = ConfigDict(extra="allow") + + +class WorkflowConfig(BaseModel): + name: str = "" + on_events: list[str] = [] + on_crons: list[str] = [] + version: str = "" + timeout: str = "60m" + schedule_timeout: str = "5m" + sticky: Union[StickyStrategy, None] = None + default_priority: int = 0 + concurrency: ConcurrencyExpression | None = None + input_validator: Type[BaseModel] = EmptyModel + +class StepType(str, Enum): + DEFAULT = "default" + CONCURRENCY = "concurrency" + ON_FAILURE = "on_failure" + + +class Step: + def __init__(self) -> None: + self.type = StepType.DEFAULT + self.timeout = "60s" + self.name = "name" + self.parents: list[Step] = [] + self.retries: int = 0 + self.rate_limits: list[RateLimit] = [] + self.desired_worker_labels: dict[str, DesiredWorkerLabel] = {} + self.backoff_factor: float | None = None + self.backoff_max_seconds: int | None = None + self.concurrency__max_runs = 1 + self.concurrency__limit_strategy = ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS + + +class Workflow: + config: WorkflowConfig = WorkflowConfig() + + def get_service_name(self, namespace: str) -> str: + return f"{namespace}{self.config.name.lower()}" -def workflow( - name: str = 
"", - on_events: list[str] | None = None, - on_crons: list[str] | None = None, - version: str = "", - timeout: str = "60m", - schedule_timeout: str = "5m", - sticky: Union[StickyStrategy, None] = None, - default_priority: int | None = None, - concurrency: ConcurrencyExpression | None = None, - input_validator: Type[T] | None = None, -) -> Callable[[Type[TWorkflow]], WorkflowMeta]: - on_events = on_events or [] - on_crons = on_crons or [] - - def inner(cls: Type[TWorkflow]) -> WorkflowMeta: - nonlocal name - name = name or str(cls.__name__) - - setattr(cls, "on_events", on_events) - setattr(cls, "on_crons", on_crons) - setattr(cls, "name", name) - setattr(cls, "version", version) - setattr(cls, "timeout", timeout) - setattr(cls, "schedule_timeout", schedule_timeout) - setattr(cls, "sticky", sticky) - setattr(cls, "default_priority", default_priority) - setattr(cls, "concurrency_expression", concurrency) - - # Define a new class with the same name and bases as the original, but - # with WorkflowMeta as its metaclass - - ## TODO: Figure out how to type this metaclass correctly - setattr(cls, "input_validator", input_validator) - - return WorkflowMeta(name, cls.__bases__, dict(cls.__dict__)) + @property + def on_failure_steps(self) -> list[Step]: + return [ + inst + for attr in dir(self) + if isinstance(inst := getattr(self, attr), Step) and inst.type == StepType.ON_FAILURE + ] - return inner + @property + def concurrency_actions(self) -> list[Step]: + return [ + inst + for attr in dir(self) + if isinstance(inst := getattr(self, attr), Step) and inst.type == StepType.CONCURRENCY + ] + + @property + def default_steps(self) -> list[Step]: + return [ + inst + for attr in dir(self) + if isinstance(inst := getattr(self, attr), Step) and inst.type == StepType.DEFAULT + ] + + + @property + def steps(self) -> list[Step]: + return self.default_steps + self.concurrency_actions + self.on_failure_steps + + + @property + def actions(self, namespace: str) -> list[Step]: + 
service_name = self.get_service_name(namespace) + + return [ + service_name + ":" + step + for step in self.steps + ] + + + def __init__(self) -> None: + self.config.name = self.config.name or str(self.__class__.__name__) + + def get_name(self, namespace: str) -> str: + return namespace + self.config.name + + def validate_concurrency_actions(self, service_name: str) -> WorkflowConcurrencyOpts | None: + if len(self.concurrency_actions) > 0 and self.config.concurrency: + raise ValueError( + "Error: Both concurrencyActions and concurrency_expression are defined. Please use only one concurrency configuration method." + ) + + if len(self.concurrency_actions) > 0: + action = self.concurrency_actions[0] + + return WorkflowConcurrencyOpts( + action=service_name + ":" + action.name, + max_runs=action.concurrency__max_runs, + limit_strategy=action.concurrency__limit_strategy, + ) + + if self.config.concurrency: + return WorkflowConcurrencyOpts( + expression=self.config.concurrency.expression, + max_runs=self.config.concurrency.max_runs, + limit_strategy=self.config.concurrency.limit_strategy, + ) + + def validate_on_failure_steps(self, name: str, service_name: str) -> CreateWorkflowJobOpts | None: + if not self.on_failure_steps: + return None + + on_failure_step = next(iter(self.on_failure_steps)) + + return CreateWorkflowJobOpts( + name=name + "-on-failure", + steps=[ + CreateWorkflowStepOpts( + readable_id=on_failure_step.name, + action=service_name + ":" + on_failure_step.name, + timeout=on_failure_step.timeout or "60s", + inputs="{}", + parents=[], + retries=on_failure_step.retries, + rate_limits=on_failure_step.rate_limits, # type: ignore[arg-type] + backoff_factor=on_failure_step.backoff_factor, + backoff_max_seconds=on_failure_step.backoff_max_seconds, + ) + ], + ) + + def validate_priority(self, default_priority: int | None) -> int | None: + validated_priority = ( + max(1, min(3, default_priority)) if default_priority else None + ) + if validated_priority != 
default_priority: + logger.warning( + "Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range." + ) + + return validated_priority + + def get_create_opts(self, namespace: str) -> CreateWorkflowVersionOpts: + service_name = self.get_service_name(namespace) + + name = self.get_name(namespace) + event_triggers = [namespace + event for event in self.config.on_events] + + create_step_opts = [ + CreateWorkflowStepOpts( + readable_id=step.name, + action=service_name + ":" + step.name, + timeout=step.timeout or "60s", + inputs="{}", + parents=[x for x in step.parents], + retries=step.retries, + rate_limits=step.rate_limits, # type: ignore[arg-type] + worker_labels=step.desired_worker_labels, # type: ignore[arg-type] + backoff_factor=step.backoff_factor, + backoff_max_seconds=step.backoff_max_seconds, + ) + for step in self.steps + ] + + concurrency = self.validate_concurrency_actions(service_name) + on_failure_job = self.validate_on_failure_steps(name, service_name) + validated_priority = self.validate_priority(self.config.default_priority) + + return CreateWorkflowVersionOpts( + name=name, + kind=WorkflowKind.DAG, + version=self.config.version, + event_triggers=event_triggers, + cron_triggers=self.config.on_crons, + schedule_timeout=self.config.schedule_timeout, + sticky=self.config.sticky, + jobs=[ + CreateWorkflowJobOpts( + name=name, + steps=create_step_opts, + ) + ], + on_failure_job=on_failure_job, + concurrency=concurrency, + default_priority=validated_priority, + ) def step( From 50e7bf1d7959b0f3267a66d764360e917447dd56 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 29 Jan 2025 21:12:35 -0500 Subject: [PATCH 50/53] feat: step class --- hatchet_sdk/v2.py | 215 +++++++++++++++++++++------------------------- 1 file changed, 100 insertions(+), 115 deletions(-) diff --git a/hatchet_sdk/v2.py b/hatchet_sdk/v2.py index a14a66db..19efcbe1 100644 --- a/hatchet_sdk/v2.py +++ b/hatchet_sdk/v2.py @@ -1,27 +1,24 @@ import asyncio 
-from hatchet_sdk.contracts.workflows_pb2 import ( - CreateWorkflowJobOpts, - CreateWorkflowStepOpts, - CreateWorkflowVersionOpts, - StickyStrategy, - WorkflowConcurrencyOpts, - WorkflowKind, -) import logging +from enum import Enum from typing import Any, Callable, Optional, ParamSpec, Type, TypeVar, Union from pydantic import BaseModel, ConfigDict from typing_extensions import deprecated -from enum import Enum from hatchet_sdk.clients.rest_client import RestApi from hatchet_sdk.context.context import Context from hatchet_sdk.contracts.workflows_pb2 import ( ConcurrencyLimitStrategy, CreateStepRateLimit, + CreateWorkflowJobOpts, + CreateWorkflowStepOpts, + CreateWorkflowVersionOpts, DesiredWorkerLabels, StickyStrategy, WorkerLabelComparator, + WorkflowConcurrencyOpts, + WorkflowKind, ) from hatchet_sdk.features.cron import CronClient from hatchet_sdk.features.scheduled import ScheduledClient @@ -49,6 +46,7 @@ TWorkflow = TypeVar("TWorkflow", bound=object) + class EmptyModel(BaseModel): model_config = ConfigDict(extra="allow") @@ -65,6 +63,7 @@ class WorkflowConfig(BaseModel): concurrency: ConcurrencyExpression | None = None input_validator: Type[BaseModel] = EmptyModel + class StepType(str, Enum): DEFAULT = "default" CONCURRENCY = "concurrency" @@ -72,16 +71,29 @@ class StepType(str, Enum): class Step: - def __init__(self) -> None: - self.type = StepType.DEFAULT - self.timeout = "60s" - self.name = "name" - self.parents: list[Step] = [] - self.retries: int = 0 - self.rate_limits: list[RateLimit] = [] - self.desired_worker_labels: dict[str, DesiredWorkerLabel] = {} - self.backoff_factor: float | None = None - self.backoff_max_seconds: int | None = None + def __init__( + self, + fn: Callable[[Context], R], + type: StepType, + name: str = "", + timeout: str = "60m", + parents: list[str] = [], + retries: int = 0, + rate_limits: list[CreateStepRateLimit] = [], + desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, + backoff_factor: float | None = None, + 
backoff_max_seconds: int | None = None, + ) -> None: + self.fn = fn + self.type = type + self.timeout = timeout + self.name = name + self.parents = parents + self.retries = retries + self.rate_limits = rate_limits + self.desired_worker_labels = desired_worker_labels + self.backoff_factor = backoff_factor + self.backoff_max_seconds = backoff_max_seconds self.concurrency__max_runs = 1 self.concurrency__limit_strategy = ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS @@ -97,7 +109,8 @@ def on_failure_steps(self) -> list[Step]: return [ inst for attr in dir(self) - if isinstance(inst := getattr(self, attr), Step) and inst.type == StepType.ON_FAILURE + if isinstance(inst := getattr(self, attr), Step) + and inst.type == StepType.ON_FAILURE ] @property @@ -105,7 +118,8 @@ def concurrency_actions(self) -> list[Step]: return [ inst for attr in dir(self) - if isinstance(inst := getattr(self, attr), Step) and inst.type == StepType.CONCURRENCY + if isinstance(inst := getattr(self, attr), Step) + and inst.type == StepType.CONCURRENCY ] @property @@ -113,24 +127,18 @@ def default_steps(self) -> list[Step]: return [ inst for attr in dir(self) - if isinstance(inst := getattr(self, attr), Step) and inst.type == StepType.DEFAULT + if isinstance(inst := getattr(self, attr), Step) + and inst.type == StepType.DEFAULT ] - @property def steps(self) -> list[Step]: - return self.default_steps + self.concurrency_actions + self.on_failure_steps + return self.default_steps + self.concurrency_actions + self.on_failure_steps - - @property - def actions(self, namespace: str) -> list[Step]: + def get_actions(self, namespace: str) -> list[str]: service_name = self.get_service_name(namespace) - return [ - service_name + ":" + step - for step in self.steps - ] - + return [service_name + ":" + step.name for step in self.steps] def __init__(self) -> None: self.config.name = self.config.name or str(self.__class__.__name__) @@ -138,7 +146,9 @@ def __init__(self) -> None: def get_name(self, namespace: str) 
-> str: return namespace + self.config.name - def validate_concurrency_actions(self, service_name: str) -> WorkflowConcurrencyOpts | None: + def validate_concurrency_actions( + self, service_name: str + ) -> WorkflowConcurrencyOpts | None: if len(self.concurrency_actions) > 0 and self.config.concurrency: raise ValueError( "Error: Both concurrencyActions and concurrency_expression are defined. Please use only one concurrency configuration method." @@ -160,7 +170,11 @@ def validate_concurrency_actions(self, service_name: str) -> WorkflowConcurrency limit_strategy=self.config.concurrency.limit_strategy, ) - def validate_on_failure_steps(self, name: str, service_name: str) -> CreateWorkflowJobOpts | None: + return None + + def validate_on_failure_steps( + self, name: str, service_name: str + ) -> CreateWorkflowJobOpts | None: if not self.on_failure_steps: return None @@ -176,7 +190,7 @@ def validate_on_failure_steps(self, name: str, service_name: str) -> CreateWorkf inputs="{}", parents=[], retries=on_failure_step.retries, - rate_limits=on_failure_step.rate_limits, # type: ignore[arg-type] + rate_limits=on_failure_step.rate_limits, backoff_factor=on_failure_step.backoff_factor, backoff_max_seconds=on_failure_step.backoff_max_seconds, ) @@ -208,7 +222,7 @@ def get_create_opts(self, namespace: str) -> CreateWorkflowVersionOpts: inputs="{}", parents=[x for x in step.parents], retries=step.retries, - rate_limits=step.rate_limits, # type: ignore[arg-type] + rate_limits=step.rate_limits, worker_labels=step.desired_worker_labels, # type: ignore[arg-type] backoff_factor=step.backoff_factor, backoff_max_seconds=step.backoff_max_seconds, @@ -240,97 +254,73 @@ def get_create_opts(self, namespace: str) -> CreateWorkflowVersionOpts: ) +def transform_desired_worker_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels: + value = d.value + return DesiredWorkerLabels( + strValue=value if not isinstance(value, int) else None, + intValue=value if isinstance(value, int) else None, + 
required=d.required, + weight=d.weight, + comparator=d.comparator, # type: ignore[arg-type] + ) + + def step( name: str = "", - timeout: str = "", - parents: list[str] | None = None, + timeout: str = "60m", + parents: list[str] = [], retries: int = 0, - rate_limits: list[RateLimit] | None = None, + rate_limits: list[RateLimit] = [], desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, backoff_factor: float | None = None, backoff_max_seconds: int | None = None, -) -> Callable[[Callable[P, R]], Callable[P, R]]: - parents = parents or [] - - def inner(func: Callable[P, R]) -> Callable[P, R]: - limits = None - if rate_limits: - limits = [rate_limit._req for rate_limit in rate_limits or []] - - setattr(func, "_step_name", name.lower() or str(func.__name__).lower()) - setattr(func, "_step_parents", parents) - setattr(func, "_step_timeout", timeout) - setattr(func, "_step_retries", retries) - setattr(func, "_step_rate_limits", limits) - setattr(func, "_step_backoff_factor", backoff_factor) - setattr(func, "_step_backoff_max_seconds", backoff_max_seconds) - - def create_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels: - value = d.value - return DesiredWorkerLabels( - strValue=value if not isinstance(value, int) else None, - intValue=value if isinstance(value, int) else None, - required=d.required, - weight=d.weight, - comparator=d.comparator, # type: ignore[arg-type] - ) - - setattr( - func, - "_step_desired_worker_labels", - {key: create_label(d) for key, d in desired_worker_labels.items()}, +) -> Callable[[Callable[[Context], R]], Step]: + def inner(func: Callable[[Context], R]) -> Step: + return Step( + fn=func, + type=StepType.DEFAULT, + name=name.lower() or str(func.__name__).lower(), + timeout=timeout, + parents=parents, + retries=retries, + rate_limits=[r for rate_limit in rate_limits if (r := rate_limit._req)], + desired_worker_labels={ + key: transform_desired_worker_label(d) + for key, d in desired_worker_labels.items() + }, + 
backoff_factor=backoff_factor, + backoff_max_seconds=backoff_max_seconds, ) - return func - return inner def on_failure_step( name: str = "", - timeout: str = "", + timeout: str = "60m", + parents: list[str] = [], retries: int = 0, - rate_limits: list[RateLimit] | None = None, + rate_limits: list[RateLimit] = [], + desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, backoff_factor: float | None = None, backoff_max_seconds: int | None = None, -) -> Callable[..., Any]: - def inner(func: Callable[[Context], Any]) -> Callable[[Context], Any]: - limits = None - if rate_limits: - limits = [ - CreateStepRateLimit(key=rate_limit.static_key, units=rate_limit.units) # type: ignore[arg-type] - for rate_limit in rate_limits or [] - ] - - setattr( - func, "_on_failure_step_name", name.lower() or str(func.__name__).lower() +) -> Callable[[Callable[[Context], R]], Step]: + def inner(func: Callable[[Context], R]) -> Step: + return Step( + fn=func, + type=StepType.ON_FAILURE, + name=name.lower() or str(func.__name__).lower(), + timeout=timeout, + parents=parents, + retries=retries, + rate_limits=[r for rate_limit in rate_limits if (r := rate_limit._req)], + desired_worker_labels={ + key: transform_desired_worker_label(d) + for key, d in desired_worker_labels.items() + }, + backoff_factor=backoff_factor, + backoff_max_seconds=backoff_max_seconds, ) - setattr(func, "_on_failure_step_timeout", timeout) - setattr(func, "_on_failure_step_retries", retries) - setattr(func, "_on_failure_step_rate_limits", limits) - setattr(func, "_on_failure_step_backoff_factor", backoff_factor) - setattr(func, "_on_failure_step_backoff_max_seconds", backoff_max_seconds) - - return func - - return inner - - -def concurrency( - name: str = "", - max_runs: int = 1, - limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS, -) -> Callable[..., Any]: - def inner(func: Callable[[Context], Any]) -> Callable[[Context], Any]: - setattr( - func, - "_concurrency_fn_name", - 
name.lower() or str(func.__name__).lower(), - ) - setattr(func, "_concurrency_max_runs", max_runs) - setattr(func, "_concurrency_limit_strategy", limit_strategy) - - return func return inner @@ -440,12 +430,7 @@ def config(self) -> ClientConfig: def tenant_id(self) -> str: return self._client.config.tenant_id - concurrency = staticmethod(concurrency) - - workflow = staticmethod(workflow) - step = staticmethod(step) - on_failure_step = staticmethod(on_failure_step) def worker( From 908d446d4e04159e57d663b2843beb26042e5d68 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 29 Jan 2025 21:16:00 -0500 Subject: [PATCH 51/53] fix: type errors --- hatchet_sdk/v2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/hatchet_sdk/v2.py b/hatchet_sdk/v2.py index 19efcbe1..63f2c225 100644 --- a/hatchet_sdk/v2.py +++ b/hatchet_sdk/v2.py @@ -80,11 +80,13 @@ def __init__( parents: list[str] = [], retries: int = 0, rate_limits: list[CreateStepRateLimit] = [], - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, + desired_worker_labels: dict[str, DesiredWorkerLabels] = {}, backoff_factor: float | None = None, backoff_max_seconds: int | None = None, ) -> None: self.fn = fn + self.is_async_function = bool(asyncio.iscoroutinefunction(fn)) + self.type = type self.timeout = timeout self.name = name @@ -223,7 +225,7 @@ def get_create_opts(self, namespace: str) -> CreateWorkflowVersionOpts: parents=[x for x in step.parents], retries=step.retries, rate_limits=step.rate_limits, - worker_labels=step.desired_worker_labels, # type: ignore[arg-type] + worker_labels=step.desired_worker_labels, backoff_factor=step.backoff_factor, backoff_max_seconds=step.backoff_max_seconds, ) From dc3019fdbbe307f201fb683cb7fea9cf6bad074b Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 29 Jan 2025 22:17:06 -0500 Subject: [PATCH 52/53] feat: move files around + finish typing --- hatchet_sdk/v2/__init__.py | 3 + hatchet_sdk/v2/hatchet.py | 176 ++++++++++++++++ hatchet_sdk/{v2.py => 
v2/workflows.py} | 272 ++++++------------------- hatchet_sdk/worker/worker.py | 7 +- 4 files changed, 247 insertions(+), 211 deletions(-) create mode 100644 hatchet_sdk/v2/__init__.py create mode 100644 hatchet_sdk/v2/hatchet.py rename hatchet_sdk/{v2.py => v2/workflows.py} (54%) diff --git a/hatchet_sdk/v2/__init__.py b/hatchet_sdk/v2/__init__.py new file mode 100644 index 00000000..e4d009d2 --- /dev/null +++ b/hatchet_sdk/v2/__init__.py @@ -0,0 +1,3 @@ +from .hatchet import Hatchet as Hatchet +from .workflows import Workflow as Workflow +from .workflows import WorkflowConfig as WorkflowConfig diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py new file mode 100644 index 00000000..8037cc2a --- /dev/null +++ b/hatchet_sdk/v2/hatchet.py @@ -0,0 +1,176 @@ +import asyncio +import inspect +import logging +from enum import Enum +from functools import partial +from typing import Any, Callable, Optional, ParamSpec, Type, TypeVar, Union + +from pydantic import BaseModel, ConfigDict +from typing_extensions import deprecated + +from hatchet_sdk.clients.rest_client import RestApi +from hatchet_sdk.context.context import Context +from hatchet_sdk.contracts.workflows_pb2 import ( + ConcurrencyLimitStrategy, + CreateStepRateLimit, + CreateWorkflowJobOpts, + CreateWorkflowStepOpts, + CreateWorkflowVersionOpts, + DesiredWorkerLabels, + StickyStrategy, + WorkerLabelComparator, + WorkflowConcurrencyOpts, + WorkflowKind, +) +from hatchet_sdk.features.cron import CronClient +from hatchet_sdk.features.scheduled import ScheduledClient +from hatchet_sdk.labels import DesiredWorkerLabel +from hatchet_sdk.loader import ClientConfig +from hatchet_sdk.rate_limit import RateLimit +from hatchet_sdk.v2.workflows import ( + Step, + StepType, + Workflow, + WorkflowConfig, + step_factory, +) + +from ..client import Client, new_client, new_client_raw +from ..clients.admin import AdminClient +from ..clients.dispatcher.dispatcher import DispatcherClient +from ..clients.events import 
EventClient +from ..clients.run_event_listener import RunEventListenerClient +from ..logger import logger +from ..worker.worker import Worker +from ..workflow import ( + ConcurrencyExpression, + WorkflowInterface, + WorkflowMeta, + WorkflowStepProtocol, +) + + +class HatchetRest: + """ + Main client for interacting with the Hatchet API. + + This class provides access to various client interfaces and utility methods + for working with Hatchet via the REST API, + + Attributes: + rest (RestApi): Interface for REST API operations. + """ + + def __init__(self, config: ClientConfig = ClientConfig()): + self.rest = RestApi(config.server_url, config.token, config.tenant_id) + + +class Hatchet: + """ + Main client for interacting with the Hatchet SDK. + + This class provides access to various client interfaces and utility methods + for working with Hatchet workers, workflows, and steps. + + Attributes: + cron (CronClient): Interface for cron trigger operations. + + admin (AdminClient): Interface for administrative operations. + dispatcher (DispatcherClient): Interface for dispatching operations. + event (EventClient): Interface for event-related operations. + rest (RestApi): Interface for REST API operations. + """ + + _client: Client + cron: CronClient + scheduled: ScheduledClient + + @classmethod + def from_environment( + cls, defaults: ClientConfig = ClientConfig(), **kwargs: Any + ) -> "Hatchet": + return cls(client=new_client(defaults), **kwargs) + + @classmethod + def from_config(cls, config: ClientConfig, **kwargs: Any) -> "Hatchet": + return cls(client=new_client_raw(config), **kwargs) + + def __init__( + self, + debug: bool = False, + client: Optional[Client] = None, + config: ClientConfig = ClientConfig(), + ): + """ + Initialize a new Hatchet instance. + + Args: + debug (bool, optional): Enable debug logging. Defaults to False. + client (Optional[Client], optional): A pre-configured Client instance. Defaults to None. 
+ config (ClientConfig, optional): Configuration for creating a new Client. Defaults to ClientConfig(). + """ + if client is not None: + self._client = client + else: + self._client = new_client(config, debug) + + if debug: + logger.setLevel(logging.DEBUG) + + self.cron = CronClient(self._client) + self.scheduled = ScheduledClient(self._client) + + @property + @deprecated( + "Direct access to client is deprecated and will be removed in a future version. Use specific client properties (Hatchet.admin, Hatchet.dispatcher, Hatchet.event, Hatchet.rest) instead. [0.32.0]", + ) + def client(self) -> Client: + return self._client + + @property + def admin(self) -> AdminClient: + return self._client.admin + + @property + def dispatcher(self) -> DispatcherClient: + return self._client.dispatcher + + @property + def event(self) -> EventClient: + return self._client.event + + @property + def rest(self) -> RestApi: + return self._client.rest + + @property + def listener(self) -> RunEventListenerClient: + return self._client.listener + + @property + def config(self) -> ClientConfig: + return self._client.config + + @property + def tenant_id(self) -> str: + return self._client.config.tenant_id + + step = staticmethod(step_factory(type=StepType.DEFAULT)) + on_failure_step = staticmethod(step_factory(type=StepType.ON_FAILURE)) + + def worker( + self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {} + ) -> Worker: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + return Worker( + name=name, + max_runs=max_runs, + labels=labels, + config=self._client.config, + debug=self._client.debug, + owned_loop=loop is None, + ) diff --git a/hatchet_sdk/v2.py b/hatchet_sdk/v2/workflows.py similarity index 54% rename from hatchet_sdk/v2.py rename to hatchet_sdk/v2/workflows.py index 63f2c225..a506f06d 100644 --- a/hatchet_sdk/v2.py +++ b/hatchet_sdk/v2/workflows.py @@ -1,10 +1,8 @@ import asyncio -import logging from enum import Enum -from 
typing import Any, Callable, Optional, ParamSpec, Type, TypeVar, Union +from typing import Any, Callable, Concatenate, ParamSpec, Type, TypeVar, Union from pydantic import BaseModel, ConfigDict -from typing_extensions import deprecated from hatchet_sdk.clients.rest_client import RestApi from hatchet_sdk.context.context import Context @@ -16,35 +14,38 @@ CreateWorkflowVersionOpts, DesiredWorkerLabels, StickyStrategy, - WorkerLabelComparator, WorkflowConcurrencyOpts, WorkflowKind, ) -from hatchet_sdk.features.cron import CronClient -from hatchet_sdk.features.scheduled import ScheduledClient from hatchet_sdk.labels import DesiredWorkerLabel from hatchet_sdk.loader import ClientConfig from hatchet_sdk.rate_limit import RateLimit -from .client import Client, new_client, new_client_raw -from .clients.admin import AdminClient -from .clients.dispatcher.dispatcher import DispatcherClient -from .clients.events import EventClient -from .clients.run_event_listener import RunEventListenerClient -from .logger import logger -from .worker.worker import Worker -from .workflow import ( - ConcurrencyExpression, - WorkflowInterface, - WorkflowMeta, - WorkflowStepProtocol, -) +from ..logger import logger -T = TypeVar("T", bound=BaseModel) R = TypeVar("R") P = ParamSpec("P") -TWorkflow = TypeVar("TWorkflow", bound=object) + +class ConcurrencyExpression: + """ + Defines concurrency limits for a workflow using a CEL expression. + + Args: + expression (str): CEL expression to determine concurrency grouping. (i.e. "input.user_id") + max_runs (int): Maximum number of concurrent workflow runs. + limit_strategy (ConcurrencyLimitStrategy): Strategy for handling limit violations. 
+ + Example: + ConcurrencyExpression("input.user_id", 5, ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS) + """ + + def __init__( + self, expression: str, max_runs: int, limit_strategy: ConcurrencyLimitStrategy + ): + self.expression = expression + self.max_runs = max_runs + self.limit_strategy = limit_strategy class EmptyModel(BaseModel): @@ -52,6 +53,7 @@ class EmptyModel(BaseModel): class WorkflowConfig(BaseModel): + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) name: str = "" on_events: list[str] = [] on_crons: list[str] = [] @@ -73,7 +75,7 @@ class StepType(str, Enum): class Step: def __init__( self, - fn: Callable[[Context], R], + fn: Callable[[Any, Context], R], type: StepType, name: str = "", timeout: str = "60m", @@ -106,32 +108,24 @@ class Workflow: def get_service_name(self, namespace: str) -> str: return f"{namespace}{self.config.name.lower()}" - @property - def on_failure_steps(self) -> list[Step]: + def _get_steps_by_type(self, step_type: StepType) -> list[Step]: return [ - inst - for attr in dir(self) - if isinstance(inst := getattr(self, attr), Step) - and inst.type == StepType.ON_FAILURE + attr + for _, attr in self.__class__.__dict__.items() + if isinstance(attr, Step) and attr.type == step_type ] + @property + def on_failure_steps(self) -> list[Step]: + return self._get_steps_by_type(StepType.ON_FAILURE) + @property def concurrency_actions(self) -> list[Step]: - return [ - inst - for attr in dir(self) - if isinstance(inst := getattr(self, attr), Step) - and inst.type == StepType.CONCURRENCY - ] + return self._get_steps_by_type(StepType.CONCURRENCY) @property def default_steps(self) -> list[Step]: - return [ - inst - for attr in dir(self) - if isinstance(inst := getattr(self, attr), Step) - and inst.type == StepType.DEFAULT - ] + return self._get_steps_by_type(StepType.DEFAULT) @property def steps(self) -> list[Step]: @@ -267,64 +261,39 @@ def transform_desired_worker_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels ) -def 
step( - name: str = "", - timeout: str = "60m", - parents: list[str] = [], - retries: int = 0, - rate_limits: list[RateLimit] = [], - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - backoff_factor: float | None = None, - backoff_max_seconds: int | None = None, -) -> Callable[[Callable[[Context], R]], Step]: - def inner(func: Callable[[Context], R]) -> Step: - return Step( - fn=func, - type=StepType.DEFAULT, - name=name.lower() or str(func.__name__).lower(), - timeout=timeout, - parents=parents, - retries=retries, - rate_limits=[r for rate_limit in rate_limits if (r := rate_limit._req)], - desired_worker_labels={ - key: transform_desired_worker_label(d) - for key, d in desired_worker_labels.items() - }, - backoff_factor=backoff_factor, - backoff_max_seconds=backoff_max_seconds, - ) +def step_factory( + type: StepType, +) -> Callable[..., Callable[[Callable[[Any, Context], R]], Step]]: + def _step( + name: str = "", + timeout: str = "60m", + parents: list[str] = [], + retries: int = 0, + rate_limits: list[RateLimit] = [], + desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, + backoff_factor: float | None = None, + backoff_max_seconds: int | None = None, + ) -> Callable[[Callable[[Any, Context], R]], Step]: + def inner(func: Callable[[Any, Context], R]) -> Step: + return Step( + fn=func, + type=type, + name=name.lower() or str(func.__name__).lower(), + timeout=timeout, + parents=parents, + retries=retries, + rate_limits=[r for rate_limit in rate_limits if (r := rate_limit._req)], + desired_worker_labels={ + key: transform_desired_worker_label(d) + for key, d in desired_worker_labels.items() + }, + backoff_factor=backoff_factor, + backoff_max_seconds=backoff_max_seconds, + ) - return inner - - -def on_failure_step( - name: str = "", - timeout: str = "60m", - parents: list[str] = [], - retries: int = 0, - rate_limits: list[RateLimit] = [], - desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, - backoff_factor: float | None = None, - 
backoff_max_seconds: int | None = None, -) -> Callable[[Callable[[Context], R]], Step]: - def inner(func: Callable[[Context], R]) -> Step: - return Step( - fn=func, - type=StepType.ON_FAILURE, - name=name.lower() or str(func.__name__).lower(), - timeout=timeout, - parents=parents, - retries=retries, - rate_limits=[r for rate_limit in rate_limits if (r := rate_limit._req)], - desired_worker_labels={ - key: transform_desired_worker_label(d) - for key, d in desired_worker_labels.items() - }, - backoff_factor=backoff_factor, - backoff_max_seconds=backoff_max_seconds, - ) + return inner - return inner + return _step class HatchetRest: @@ -340,114 +309,3 @@ class HatchetRest: def __init__(self, config: ClientConfig = ClientConfig()): self.rest = RestApi(config.server_url, config.token, config.tenant_id) - - -class Hatchet: - """ - Main client for interacting with the Hatchet SDK. - - This class provides access to various client interfaces and utility methods - for working with Hatchet workers, workflows, and steps. - - Attributes: - cron (CronClient): Interface for cron trigger operations. - - admin (AdminClient): Interface for administrative operations. - dispatcher (DispatcherClient): Interface for dispatching operations. - event (EventClient): Interface for event-related operations. - rest (RestApi): Interface for REST API operations. - """ - - _client: Client - cron: CronClient - scheduled: ScheduledClient - - @classmethod - def from_environment( - cls, defaults: ClientConfig = ClientConfig(), **kwargs: Any - ) -> "Hatchet": - return cls(client=new_client(defaults), **kwargs) - - @classmethod - def from_config(cls, config: ClientConfig, **kwargs: Any) -> "Hatchet": - return cls(client=new_client_raw(config), **kwargs) - - def __init__( - self, - debug: bool = False, - client: Optional[Client] = None, - config: ClientConfig = ClientConfig(), - ): - """ - Initialize a new Hatchet instance. - - Args: - debug (bool, optional): Enable debug logging. Defaults to False. 
- client (Optional[Client], optional): A pre-configured Client instance. Defaults to None. - config (ClientConfig, optional): Configuration for creating a new Client. Defaults to ClientConfig(). - """ - if client is not None: - self._client = client - else: - self._client = new_client(config, debug) - - if debug: - logger.setLevel(logging.DEBUG) - - self.cron = CronClient(self._client) - self.scheduled = ScheduledClient(self._client) - - @property - @deprecated( - "Direct access to client is deprecated and will be removed in a future version. Use specific client properties (Hatchet.admin, Hatchet.dispatcher, Hatchet.event, Hatchet.rest) instead. [0.32.0]", - ) - def client(self) -> Client: - return self._client - - @property - def admin(self) -> AdminClient: - return self._client.admin - - @property - def dispatcher(self) -> DispatcherClient: - return self._client.dispatcher - - @property - def event(self) -> EventClient: - return self._client.event - - @property - def rest(self) -> RestApi: - return self._client.rest - - @property - def listener(self) -> RunEventListenerClient: - return self._client.listener - - @property - def config(self) -> ClientConfig: - return self._client.config - - @property - def tenant_id(self) -> str: - return self._client.config.tenant_id - - step = staticmethod(step) - on_failure_step = staticmethod(on_failure_step) - - def worker( - self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {} - ) -> Worker: - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - - return Worker( - name=name, - max_runs=max_runs, - labels=labels, - config=self._client.config, - debug=self._client.debug, - owned_loop=loop is None, - ) diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py index 2c393245..2c2f7788 100644 --- a/hatchet_sdk/worker/worker.py +++ b/hatchet_sdk/worker/worker.py @@ -113,12 +113,11 @@ def register_workflow_from_opts( logger.error(e) sys.exit(1) - def 
register_workflow(self, workflow: TWorkflow) -> None: - ## Hack for typing - assert isinstance(workflow, WorkflowInterface) - + def register_workflow(self, workflow) -> None: namespace = self.client.config.namespace + print(f"registering workflow: {workflow}", workflow.steps) + try: self.client.admin.put_workflow( workflow.get_name(namespace), workflow.get_create_opts(namespace) From 7807ae124c1369f147eb0b39fb7e4169267b0fd9 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Wed, 29 Jan 2025 22:44:16 -0500 Subject: [PATCH 53/53] feat: first working version! --- hatchet_sdk/v2/hatchet.py | 37 +++--------------------------------- hatchet_sdk/v2/workflows.py | 29 +++++++--------------------- hatchet_sdk/worker/worker.py | 30 +++++++++++++---------------- 3 files changed, 23 insertions(+), 73 deletions(-) diff --git a/hatchet_sdk/v2/hatchet.py b/hatchet_sdk/v2/hatchet.py index 8037cc2a..3e02cdae 100644 --- a/hatchet_sdk/v2/hatchet.py +++ b/hatchet_sdk/v2/hatchet.py @@ -1,39 +1,15 @@ import asyncio -import inspect import logging -from enum import Enum -from functools import partial -from typing import Any, Callable, Optional, ParamSpec, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, Optional -from pydantic import BaseModel, ConfigDict from typing_extensions import deprecated from hatchet_sdk.clients.rest_client import RestApi -from hatchet_sdk.context.context import Context -from hatchet_sdk.contracts.workflows_pb2 import ( - ConcurrencyLimitStrategy, - CreateStepRateLimit, - CreateWorkflowJobOpts, - CreateWorkflowStepOpts, - CreateWorkflowVersionOpts, - DesiredWorkerLabels, - StickyStrategy, - WorkerLabelComparator, - WorkflowConcurrencyOpts, - WorkflowKind, -) from hatchet_sdk.features.cron import CronClient from hatchet_sdk.features.scheduled import ScheduledClient -from hatchet_sdk.labels import DesiredWorkerLabel from hatchet_sdk.loader import ClientConfig -from hatchet_sdk.rate_limit import RateLimit -from hatchet_sdk.v2.workflows import ( - Step, - 
StepType, - Workflow, - WorkflowConfig, - step_factory, -) +from hatchet_sdk.v2.workflows import StepType, step_factory +from hatchet_sdk.worker import Worker from ..client import Client, new_client, new_client_raw from ..clients.admin import AdminClient @@ -41,13 +17,6 @@ from ..clients.events import EventClient from ..clients.run_event_listener import RunEventListenerClient from ..logger import logger -from ..worker.worker import Worker -from ..workflow import ( - ConcurrencyExpression, - WorkflowInterface, - WorkflowMeta, - WorkflowStepProtocol, -) class HatchetRest: diff --git a/hatchet_sdk/v2/workflows.py b/hatchet_sdk/v2/workflows.py index a506f06d..2c068e27 100644 --- a/hatchet_sdk/v2/workflows.py +++ b/hatchet_sdk/v2/workflows.py @@ -1,6 +1,6 @@ import asyncio from enum import Enum -from typing import Any, Callable, Concatenate, ParamSpec, Type, TypeVar, Union +from typing import Any, Callable, ParamSpec, Type, TypeVar, Union from pydantic import BaseModel, ConfigDict @@ -18,7 +18,6 @@ WorkflowKind, ) from hatchet_sdk.labels import DesiredWorkerLabel -from hatchet_sdk.loader import ClientConfig from hatchet_sdk.rate_limit import RateLimit from ..logger import logger @@ -61,7 +60,7 @@ class WorkflowConfig(BaseModel): timeout: str = "60m" schedule_timeout: str = "5m" sticky: Union[StickyStrategy, None] = None - default_priority: int = 0 + default_priority: int = 1 concurrency: ConcurrencyExpression | None = None input_validator: Type[BaseModel] = EmptyModel @@ -101,6 +100,9 @@ def __init__( self.concurrency__max_runs = 1 self.concurrency__limit_strategy = ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS + def __call__(self, *args: Any, **kwargs: Any) -> R: + return self.fn(*args, **kwargs) + class Workflow: config: WorkflowConfig = WorkflowConfig() @@ -131,10 +133,8 @@ def default_steps(self) -> list[Step]: def steps(self) -> list[Step]: return self.default_steps + self.concurrency_actions + self.on_failure_steps - def get_actions(self, namespace: str) -> 
list[str]: - service_name = self.get_service_name(namespace) - - return [service_name + ":" + step.name for step in self.steps] + def create_action_name(self, namespace: str, step: Step) -> str: + return self.get_service_name(namespace) + ":" + step.name def __init__(self) -> None: self.config.name = self.config.name or str(self.__class__.__name__) @@ -294,18 +294,3 @@ def inner(func: Callable[[Any, Context], R]) -> Step: return inner return _step - - -class HatchetRest: - """ - Main client for interacting with the Hatchet API. - - This class provides access to various client interfaces and utility methods - for working with Hatchet via the REST API, - - Attributes: - rest (RestApi): Interface for REST API operations. - """ - - def __init__(self, config: ClientConfig = ClientConfig()): - self.rest = RestApi(config.server_url, config.token, config.tenant_id) diff --git a/hatchet_sdk/worker/worker.py b/hatchet_sdk/worker/worker.py index 2c2f7788..9ed0da7c 100644 --- a/hatchet_sdk/worker/worker.py +++ b/hatchet_sdk/worker/worker.py @@ -10,7 +10,7 @@ from multiprocessing import Queue from multiprocessing.process import BaseProcess from types import FrameType -from typing import Any, Callable, TypeVar, get_type_hints +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, get_type_hints from aiohttp import web from aiohttp.web_request import Request @@ -33,7 +33,10 @@ STOP_LOOP_TYPE, WorkerActionRunLoopManager, ) -from hatchet_sdk.workflow import WorkflowInterface + +if TYPE_CHECKING: + from hatchet_sdk.v2 import Workflow + from hatchet_sdk.v2.workflows import Step T = TypeVar("T") @@ -50,9 +53,6 @@ class WorkerStartOptions: loop: asyncio.AbstractEventLoop | None = field(default=None) -TWorkflow = TypeVar("TWorkflow", bound=object) - - class Worker: def __init__( self, @@ -113,11 +113,9 @@ def register_workflow_from_opts( logger.error(e) sys.exit(1) - def register_workflow(self, workflow) -> None: + def register_workflow(self, workflow: Union["Workflow", Any]) 
-> None: namespace = self.client.config.namespace - print(f"registering workflow: {workflow}", workflow.steps) - try: self.client.admin.put_workflow( workflow.get_name(namespace), workflow.get_create_opts(namespace) @@ -128,24 +126,22 @@ def register_workflow(self, workflow) -> None: sys.exit(1) def create_action_function( - action_func: Callable[..., T] + action_func: "Step" ) -> Callable[[Context], T]: def action_function(context: Context) -> T: return action_func(workflow, context) - if asyncio.iscoroutinefunction(action_func): - setattr(action_function, "is_coroutine", True) - else: - setattr(action_function, "is_coroutine", False) + setattr(action_function, "is_coroutine", action_func.is_async_function) return action_function - for action_name, action_func in workflow.get_actions(namespace): - self.action_registry[action_name] = create_action_function(action_func) - return_type = get_type_hints(action_func).get("return") + for step in workflow.steps: + action_name = workflow.create_action_name(namespace, step) + self.action_registry[action_name] = create_action_function(step) + return_type = get_type_hints(step.fn).get("return") self.validator_registry[action_name] = WorkflowValidator( - workflow_input=workflow.input_validator, + workflow_input=workflow.config.input_validator, step_output=return_type if is_basemodel_subclass(return_type) else None, )