diff --git a/hatchet_sdk/clients/dispatcher/dispatcher.py b/hatchet_sdk/clients/dispatcher/dispatcher.py index e52aca2a..05871970 100644 --- a/hatchet_sdk/clients/dispatcher/dispatcher.py +++ b/hatchet_sdk/clients/dispatcher/dispatcher.py @@ -52,6 +52,13 @@ def __init__(self, config: ClientConfig): async def get_action_listener( self, req: GetActionListenerRequest ) -> ActionListener: + + # Override labels with the preset labels + preset_labels = self.config.worker_preset_labels + + for key, value in preset_labels.items(): + req.labels[key] = WorkerLabels(strValue=str(value)) + # Register the worker response: WorkerRegisterResponse = await self.aio_client.Register( WorkerRegisterRequest( diff --git a/hatchet_sdk/loader.py b/hatchet_sdk/loader.py index d754c2ae..38b0b2bf 100644 --- a/hatchet_sdk/loader.py +++ b/hatchet_sdk/loader.py @@ -45,6 +45,7 @@ def __init__( otel_exporter_oltp_protocol: str | None = None, worker_healthcheck_port: int | None = None, worker_healthcheck_enabled: bool | None = None, + worker_preset_labels: dict[str, str] = {}, ): self.tenant_id = tenant_id self.tls_config = tls_config @@ -61,6 +62,7 @@ def __init__( self.otel_exporter_oltp_protocol = otel_exporter_oltp_protocol self.worker_healthcheck_port = worker_healthcheck_port self.worker_healthcheck_enabled = worker_healthcheck_enabled + self.worker_preset_labels = worker_preset_labels if not self.logInterceptor: self.logInterceptor = getLogger() @@ -184,6 +186,16 @@ def get_config_value(key, env_var): == "True" ) + # Add preset labels to the worker config + worker_preset_labels: dict[str, str] = defaults.worker_preset_labels + + autoscaling_target = get_config_value( + "autoscaling_target", "HATCHET_CLIENT_AUTOSCALING_TARGET" + ) + + if autoscaling_target: + worker_preset_labels["hatchet-autoscaling-target"] = autoscaling_target + return ClientConfig( tenant_id=tenant_id, tls_config=tls_config, @@ -201,6 +213,7 @@ def get_config_value(key, env_var): otel_exporter_oltp_protocol=otel_exporter_oltp_protocol, worker_healthcheck_port=worker_healthcheck_port, worker_healthcheck_enabled=worker_healthcheck_enabled, + worker_preset_labels=worker_preset_labels, ) def _load_tls_config(self, tls_data: Dict, host_port) -> ClientTLSConfig: diff --git a/pyproject.toml b/pyproject.toml index 099b3e2c..5004db3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "hatchet-sdk" -version = "0.44.2" +version = "0.45.0" description = "" authors = ["Alexander Belanger "] readme = "README.md"