diff --git a/crategen/models/__init__.py b/crategen/models/__init__.py
index 2da85d7..0dcf03e 100644
--- a/crategen/models/__init__.py
+++ b/crategen/models/__init__.py
@@ -17,8 +17,16 @@
     TESState,
     TESTaskLog,
 )
+from .wes_models import (
+    Log,
+    Run,
+    RunRequest,
+    State,
+    TaskLog,
+)
 
 __all__ = [
+    # TES Models
     "TESData",
     "TESInput",
     "TESOutput",
@@ -29,4 +37,11 @@
     "TESOutputFileLog",
     "TESFileType",
     "TESState",
+
+    # WES Models
+    "State",
+    "Log",
+    "TaskLog",
+    "RunRequest",
+    "Run",
 ]
diff --git a/crategen/models/wes_models.py b/crategen/models/wes_models.py
new file mode 100644
index 0000000..01c5f63
--- /dev/null
+++ b/crategen/models/wes_models.py
@@ -0,0 +1,154 @@
+"""Each model in this module conforms to the corresponding WES model names as specified by the GA4GH schema (https://ga4gh.github.io/workflow-execution-service-schemas/docs/)."""
+
+import warnings
+from enum import Enum
+from typing import List, Optional, Union
+
+from pydantic import BaseModel, Field, root_validator, validator
+from rfc3339_validator import validate_rfc3339  # type: ignore
+
+
+class State(str, Enum):
+    """Enumeration of workflow states.
+
+    Attributes:
+        UNKNOWN: The state of the workflow is unknown. This provides a safe default for messages where this field is missing.
+        QUEUED: The workflow is queued.
+        INITIALIZING: The workflow is initializing.
+        RUNNING: The workflow is running.
+        PAUSED: The workflow is paused.
+        COMPLETE: The workflow has completed successfully.
+        EXECUTOR_ERROR: The workflow encountered an executor error.
+        SYSTEM_ERROR: The workflow encountered a system error.
+        CANCELED: The workflow was canceled by the user.
+        CANCELING: The workflow was canceled by the user, and is in the process of stopping.
+        PREEMPTED: The workflow is stopped (preempted) by the system.
+    """
+    UNKNOWN = "UNKNOWN"
+    QUEUED = "QUEUED"
+    INITIALIZING = "INITIALIZING"
+    RUNNING = "RUNNING"
+    PAUSED = "PAUSED"
+    COMPLETE = "COMPLETE"
+    EXECUTOR_ERROR = "EXECUTOR_ERROR"
+    SYSTEM_ERROR = "SYSTEM_ERROR"
+    CANCELED = "CANCELED"
+    CANCELING = "CANCELING"
+    PREEMPTED = "PREEMPTED"
+
+
+class Log(BaseModel):
+    """Log information for a workflow run or task.
+
+    Attributes:
+        name (`Optional[str]`): Task or workflow name
+        cmd (`Optional[List[str]]`): Command line executed
+        start_time (`Optional[str]`): When the task started executing (RFC 3339)
+        end_time (`Optional[str]`): When the task ended (RFC 3339)
+        stdout (`Optional[str]`): URL to retrieve standard output logs
+        stderr (`Optional[str]`): URL to retrieve standard error logs
+        exit_code (`Optional[int]`): Exit code of the program
+        system_logs (`Optional[List[str]]`): Any logs the system decides are relevant
+    """
+
+    name: Optional[str] = None
+    cmd: Optional[List[str]] = None
+    start_time: Optional[str] = None
+    end_time: Optional[str] = None
+    stdout: Optional[str] = None
+    stderr: Optional[str] = None
+    exit_code: Optional[int] = None
+    system_logs: Optional[List[str]] = None
+
+    @validator("start_time", "end_time", allow_reuse=True)
+    def validate_datetime(cls, value, field):
+        """Check that a non-empty datetime value is in RFC 3339 format."""
+        if value and not validate_rfc3339(value):
+            raise ValueError(
+                f"The '{field.name}' property must be in RFC 3339 format"
+            )
+        return value
+
+
+class TaskLog(Log):
+    """Task execution log information.
+
+    Attributes:
+        id (`str`): Unique identifier which may be used to reference the task
+        tes_uri (`Optional[str]`): Optional URL pointing to an extended task definition defined by a TES API
+        name (`str`): REQUIRED The name of the task
+    """
+
+    id: str
+    tes_uri: Optional[str] = None
+    name: str = Field(...)
+
+
+class RunRequest(BaseModel):
+    """A workflow run request.
+
+    Attributes:
+        workflow_params (`dict[str, str]`): REQUIRED The workflow run parameterizations (JSON encoded)
+        workflow_type (`str`): REQUIRED The workflow descriptor type (e.g., "CWL" or "WDL")
+        workflow_type_version (`str`): REQUIRED The workflow descriptor type version
+        tags (`Optional[dict[str, str]]`): Arbitrary key/value tags for the workflow
+        workflow_engine_parameters (`Optional[dict[str, str]]`): Workflow engine specific parameters
+        workflow_engine (`Optional[str]`): The workflow engine that should run this workflow
+        workflow_engine_version (`Optional[str]`): The version of the workflow engine
+        workflow_url (`str`): The workflow CWL or WDL document
+    """
+
+    workflow_params: dict[str, str]
+    workflow_type: str
+    workflow_type_version: str
+    tags: Optional[dict[str, str]] = {}
+    workflow_engine_parameters: Optional[dict[str, str]] = None
+    workflow_engine: Optional[str] = None
+    workflow_engine_version: Optional[str] = None
+    workflow_url: str
+
+    @root_validator()
+    def validate_workflow_engine(cls, values):
+        """Validate workflow engine dependencies."""
+        engine_version = values.get("workflow_engine_version")
+        engine = values.get("workflow_engine")
+        if engine_version is not None and engine is None:
+            raise ValueError(
+                "The 'workflow_engine' attribute is required when the 'workflow_engine_version' attribute is set"
+            )
+        return values
+
+
+class Run(BaseModel):
+    """A workflow run.
+
+    Attributes:
+        run_id (`str`): Workflow run ID
+        request (`Optional[RunRequest]`): The original workflow run request
+        state (`Optional[State]`): Current state of the workflow run
+        run_log (`Optional[Log]`): Log information about the workflow run
+        task_logs_url (`Optional[str]`): URL for obtaining task logs
+        task_logs (`Optional[List[Union[TaskLog, Log]]]`): DEPRECATED Task logs, use task_logs_url instead
+        outputs (`dict[str, str]`): Output files produced by the workflow run
+    """
+
+    run_id: str
+    request: Optional[RunRequest] = None
+    state: Optional[State] = None
+    run_log: Optional[Log] = None
+    task_logs_url: Optional[str] = None
+    # TaskLog is listed first: pydantic tries Union members in order, and Log
+    # (all fields optional) would otherwise accept task data and drop its id
+    task_logs: Optional[List[Union[TaskLog, Log]]] = None
+    outputs: dict[str, str] = {}
+
+    @root_validator
+    def check_deprecated_fields(cls, values):
+        """Emit a DeprecationWarning when the deprecated 'task_logs' field is set."""
+        if values.get("task_logs") is not None:
+            warnings.warn(
+                "The 'task_logs' field is deprecated and will be removed in future versions. Use 'task_logs_url' instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+        return values
diff --git a/tests/unit/test_wes_models.py b/tests/unit/test_wes_models.py
new file mode 100644
index 0000000..206cf54
--- /dev/null
+++ b/tests/unit/test_wes_models.py
@@ -0,0 +1,224 @@
+"""Tests for WES models"""
+import pytest
+
+from crategen.models.wes_models import (
+    Log,
+    Run,
+    RunRequest,
+    State,
+    TaskLog,
+)
+
+# Test data constants
+valid_datetime_strings = [
+    "2020-10-02T16:00:00.000Z",
+    "2024-10-15T18:14:34+00:00",
+    "2024-10-15T18:14:34.948996+00:00",
+    "2024-10-15T19:01:06.872464+00:00",
+]
+
+invalid_datetime_strings = [
+    "2020-10-02 16:00:00",  # Missing 'T' separator
+    "2020-10-02T16:00:00",  # Missing timezone
+    "20201002T160000Z",  # Missing separators
+    "2020-10-02T16:00:00.000+0200",  # Invalid timezone format
+    "2020-10-02T16:00:00.000 GMT",  # Invalid timezone format
+    "02-10-2020T16:00:00.000Z",  # Incorrect date order
+]
+
+test_url = "https://raw.githubusercontent.com/elixir-cloud-aai/CrateGen/refs/heads/main/README.md"
+
+# Test data samples
+test_workflow_params = {
+    "reads": f"{test_url}/reads.fastq",
+    "reference": f"{test_url}/reference.fa",
+    "output_dir": f"{test_url}/results/"
+}
+
+test_workflow_engine_params = {
+    "memory": "16GB",
+    "cpu": "4",
+    "disk_size": "100GB"
+}
+
+test_tags = {
+    "project": "genomics-pipeline",
+    "sample": "TCGA-AB-2823",
+    "analysis": "variant-calling"
+}
+
+
+class TestState:
+    """Test suite for State enum"""
+
+    def test_state_enum_values(self):
+        """Test that State enum has correct values from GA4GH spec"""
+        assert State.UNKNOWN == "UNKNOWN"
+        assert State.QUEUED == "QUEUED"
+        assert State.INITIALIZING == "INITIALIZING"
+        assert State.RUNNING == "RUNNING"
+        assert State.PAUSED == "PAUSED"
+        assert State.COMPLETE == "COMPLETE"
+        assert State.EXECUTOR_ERROR == "EXECUTOR_ERROR"
+        assert State.SYSTEM_ERROR == "SYSTEM_ERROR"
+        assert State.CANCELED == "CANCELED"
+        assert State.CANCELING == "CANCELING"
+        assert State.PREEMPTED == "PREEMPTED"
+
+
+class TestLog:
+    """Test suite for Log model"""
+
+    def test_log_datetime_validation(self):
+        """Test datetime validation in Log"""
+        for valid_datetime in valid_datetime_strings:
+            log = Log(
+                name="workflow_123",
+                start_time=valid_datetime,
+                end_time=valid_datetime
+            )
+            assert log.start_time == valid_datetime
+            assert log.end_time == valid_datetime
+
+        for invalid_datetime in invalid_datetime_strings:
+            with pytest.raises(ValueError) as exc_info:
+                Log(
+                    name="workflow_123",
+                    start_time=invalid_datetime
+                )
+            assert "format" in str(exc_info.value)
+
+
+class TestTaskLog:
+    """Test suite for TaskLog model"""
+
+    def test_task_log_required_fields(self):
+        """Test that required fields must be provided"""
+        with pytest.raises(ValueError):
+            TaskLog(id="task-123")  # Missing required name field
+
+        # Test with required fields
+        task_log = TaskLog(
+            id="task-123",
+            name="alignment"
+        )
+        assert task_log.id == "task-123"
+        assert task_log.name == "alignment"
+
+    def test_task_log_all_fields(self):
+        """Test TaskLog with all fields"""
+        task_log = TaskLog(
+            id="task-bwa-mem-123",
+            name="bwa_mem_alignment",
+            cmd=["bwa", "mem", "-t", "4", "reference.fa", "reads.fastq"],
+            stdout="https://storage.googleapis.com/workflow-logs/task123/stdout.log",
+            stderr="https://storage.googleapis.com/workflow-logs/task123/stderr.log",
+            exit_code=0,
+            tes_uri=test_url
+        )
+        assert task_log.id.startswith("task-")
+        assert task_log.name == "bwa_mem_alignment"
+        assert task_log.stdout.startswith("https://")
+        assert task_log.stderr.startswith("https://")
+        assert task_log.tes_uri == test_url
+
+
+class TestRunRequest:
+    """Test suite for RunRequest model"""
+
+    def test_run_request_required_fields(self):
+        """Test that required fields must be provided"""
+        with pytest.raises(ValueError):
+            RunRequest(workflow_type="CWL")  # Missing other required fields
+
+        request = RunRequest(
+            workflow_params={"input": "test.txt"},
+            workflow_type="CWL",
+            workflow_type_version="v1.0",
+            workflow_url=test_url
+        )
+        assert request.workflow_type == "CWL"
+        assert request.workflow_url == test_url
+
+    def test_workflow_engine_validation(self):
+        """Test workflow engine validation rules"""
+        # Version without engine should fail
+        with pytest.raises(ValueError) as exc_info:
+            RunRequest(
+                workflow_params={},
+                workflow_type="CWL",
+                workflow_type_version="v1.0",
+                workflow_url=test_url,
+                workflow_engine_version="3.1.0"
+            )
+        assert "workflow_engine" in str(exc_info.value)
+
+        # Both engine and version should work
+        request = RunRequest(
+            workflow_params={},
+            workflow_type="CWL",
+            workflow_type_version="v1.0",
+            workflow_url=test_url,
+            workflow_engine="cwltool",
+            workflow_engine_version="3.1.0"
+        )
+        assert request.workflow_engine == "cwltool"
+        assert request.workflow_engine_version == "3.1.0"
+
+
+class TestRun:
+    """Test suite for Run model"""
+
+    def test_run_required_fields(self):
+        """Test that required fields must be provided"""
+        run = Run(run_id="run-123")
+        assert run.run_id == "run-123"
+        assert run.outputs == {}
+
+    def test_task_logs_deprecation(self):
+        """Test deprecation warning for task_logs field"""
+        task_log = TaskLog(
+            id="task-123",
+            name="alignment"
+        )
+
+        # The deprecated field must emit a real DeprecationWarning that
+        # points users at its replacement
+        with pytest.warns(DeprecationWarning, match="task_logs_url"):
+            Run(
+                run_id="run-123",
+                task_logs=[task_log]
+            )
+
+    def test_task_logs_keep_task_fields(self):
+        """Test task logs given as plain dicts keep their TaskLog fields"""
+        with pytest.warns(DeprecationWarning):
+            run = Run(
+                run_id="run-123",
+                task_logs=[
+                    {"id": "task-123", "name": "alignment"},
+                    {"name": "workflow"}
+                ]
+            )
+
+        assert isinstance(run.task_logs[0], TaskLog)
+        assert run.task_logs[0].id == "task-123"
+        assert not isinstance(run.task_logs[1], TaskLog)
+
+    def test_run_output_urls(self):
+        """Test Run accepts outputs with different URL schemes"""
+        output_urls = {
+            "http_url": "http://example.com/output.txt",
+            "https_url": "https://storage.googleapis.com/output.txt",
+            "s3_url": "s3://my-bucket/output.txt",
+            "gs_url": "gs://my-bucket/output.txt",
+            "file_url": "file:///local/path/output.txt",
+            "absolute_path": "/absolute/path/output.txt",
+            "relative_path": "./relative/path/output.txt"
+        }
+
+        run = Run(run_id="run-123", outputs=output_urls)
+
+        # Verify all output URLs are preserved
+        for key, value in output_urls.items():
+            assert run.outputs[key] == value