From edfc0e385a140f8d5dd3f33e1381601f721a4f14 Mon Sep 17 00:00:00 2001 From: Mark Wiebe <399551+mwiebe@users.noreply.github.com> Date: Sat, 8 Feb 2025 19:15:17 -0800 Subject: [PATCH] feat: Implement the task chunking RFC 0001 Signed-off-by: Mark Wiebe <399551+mwiebe@users.noreply.github.com> --- requirements-testing.txt | 2 +- src/openjd/model/_internal/__init__.py | 6 +- .../_internal/_param_space_dim_validation.py | 46 +- src/openjd/model/v2023_09/__init__.py | 4 + src/openjd/model/v2023_09/_model.py | 248 ++++++- .../test_chunk_int_task_parameter_type.py | 658 ++++++++++++++++++ test/openjd/model/v2023_09/test_create.py | 122 +++- test/openjd/model/v2023_09/test_strings.py | 60 +- 8 files changed, 1084 insertions(+), 62 deletions(-) create mode 100644 test/openjd/model/v2023_09/test_chunk_int_task_parameter_type.py diff --git a/requirements-testing.txt b/requirements-testing.txt index 1d12c427..765c593e 100644 --- a/requirements-testing.txt +++ b/requirements-testing.txt @@ -4,6 +4,6 @@ pytest-cov == 6.0.* pytest-timeout == 2.3.* pytest-xdist == 3.6.* types-PyYAML ~= 6.0 -black == 24.* +black == 25.* ruff == 0.9.* mypy == 1.15.* diff --git a/src/openjd/model/_internal/__init__.py b/src/openjd/model/_internal/__init__.py index f3019676..b0dee5e2 100644 --- a/src/openjd/model/_internal/__init__.py +++ b/src/openjd/model/_internal/__init__.py @@ -8,12 +8,16 @@ from ._combination_expr import Parser as CombinationExpressionParser from ._combination_expr import ProductNode as CombinationExpressionProductNode from ._create_job import instantiate_model -from ._param_space_dim_validation import validate_step_parameter_space_dimensions +from ._param_space_dim_validation import ( + validate_step_parameter_space_chunk_constraint, + validate_step_parameter_space_dimensions, +) from ._variable_reference_validation import prevalidate_model_template_variable_references __all__ = ( "instantiate_model", "prevalidate_model_template_variable_references", + 
"validate_step_parameter_space_chunk_constraint", "validate_step_parameter_space_dimensions", "validate_unique_elements", "CombinationExpressionAssociationNode", diff --git a/src/openjd/model/_internal/_param_space_dim_validation.py b/src/openjd/model/_internal/_param_space_dim_validation.py index d8c5b6c6..e9ed665a 100644 --- a/src/openjd/model/_internal/_param_space_dim_validation.py +++ b/src/openjd/model/_internal/_param_space_dim_validation.py @@ -23,10 +23,42 @@ def validate_step_parameter_space_dimensions( ExpressionError if the combination expression violates constraints. """ parse_tree = Parser().parse(combination) - _validate_expr_tree(parse_tree, parameter_range_lengths) + _validate_expr_tree_dimensions(parse_tree, parameter_range_lengths) -def _validate_expr_tree(root: Node, parameter_range_lengths: dict[str, int]) -> int: +def validate_step_parameter_space_chunk_constraint(chunk_parameter: str, parse_tree: Node) -> bool: + """This validates that the task parameter of type CHUNK[INT] never appears + within scope of an associative expression. A single chunk consists of + individual values from the non-chunk parameters, and a set of values from the + chunk parameter. With this restriction, the session script code can interpret + these parameters easily, while without it the specification would need to define + how the associations are represented and provided to the script. + + Raises: + ExpressionError if the combination expression violates the chunk constraint. + """ + # Returns True if the subtree includes the chunk parameter, otherwise False + if isinstance(parse_tree, IdentifierNode): + return parse_tree.parameter == chunk_parameter + elif isinstance(parse_tree, AssociationNode): + for child in parse_tree.children: + if validate_step_parameter_space_chunk_constraint(chunk_parameter, child): + raise ExpressionError( + ( + f"CHUNK[INT] parameter {chunk_parameter} must not be part of an associative expression. 
" + ) + ) + return False + else: + # For type hinting + assert isinstance(parse_tree, ProductNode) + return any( + validate_step_parameter_space_chunk_constraint(chunk_parameter, child) + for child in parse_tree.children + ) + + +def _validate_expr_tree_dimensions(root: Node, parameter_range_lengths: dict[str, int]) -> int: # Returns the length of the subtree while recursively validating it. if isinstance(root, IdentifierNode): name = root.parameter @@ -35,7 +67,8 @@ def _validate_expr_tree(root: Node, parameter_range_lengths: dict[str, int]) -> # Association requires that all arguments are the exact same length. # Ensure that is the case arg_lengths = tuple( - _validate_expr_tree(child, parameter_range_lengths) for child in root.children + _validate_expr_tree_dimensions(child, parameter_range_lengths) + for child in root.children ) if len(set(arg_lengths)) > 1: raise ExpressionError( @@ -49,5 +82,10 @@ def _validate_expr_tree(root: Node, parameter_range_lengths: dict[str, int]) -> # For type hinting assert isinstance(root, ProductNode) return reduce( - mul, (_validate_expr_tree(child, parameter_range_lengths) for child in root.children), 1 + mul, + ( + _validate_expr_tree_dimensions(child, parameter_range_lengths) + for child in root.children + ), + 1, ) diff --git a/src/openjd/model/v2023_09/__init__.py b/src/openjd/model/v2023_09/__init__.py index 1cae7919..704e5252 100644 --- a/src/openjd/model/v2023_09/__init__.py +++ b/src/openjd/model/v2023_09/__init__.py @@ -21,6 +21,7 @@ CancelationMethodNotifyThenTerminate, CancelationMethodTerminate, CancelationMode, + ChunkIntTaskParameterDefinition, CombinationExpr, CommandString, DataString, @@ -73,6 +74,7 @@ StepTemplateList, StringRangeList, StringTaskParameterDefinition, + TaskChunksDefinition, TaskParameterList, TaskParameterStringValue, TaskParameterStringValueAsJob, @@ -101,6 +103,7 @@ "CancelationMethodNotifyThenTerminate", "CancelationMethodTerminate", "CancelationMode", + "ChunkIntTaskParameterDefinition", 
"CombinationExpr", "CommandString", "DataString", @@ -153,6 +156,7 @@ "StepTemplateList", "StringRangeList", "StringTaskParameterDefinition", + "TaskChunksDefinition", "TaskParameterList", "TaskParameterStringValue", "TaskParameterStringValueAsJob", diff --git a/src/openjd/model/v2023_09/_model.py b/src/openjd/model/v2023_09/_model.py index 43a79006..b8f22046 100644 --- a/src/openjd/model/v2023_09/_model.py +++ b/src/openjd/model/v2023_09/_model.py @@ -8,6 +8,7 @@ from graphlib import CycleError, TopologicalSorter from typing import Any, ClassVar, Literal, Optional, Type, Union, cast, Iterable from typing_extensions import Annotated, Self +import annotated_types from pydantic import ( field_validator, @@ -16,12 +17,13 @@ Field, PositiveInt, PositiveFloat, + Strict, StrictBool, StrictInt, ValidationError, ValidationInfo, ) -from pydantic_core import InitErrorDetails +from pydantic_core import InitErrorDetails, PydanticKnownError from pydantic.fields import ModelPrivateAttr from .._format_strings import FormatString @@ -33,6 +35,7 @@ from .._internal import ( CombinationExpressionParser, validate_step_parameter_space_dimensions, + validate_step_parameter_space_chunk_constraint, validate_unique_elements, ) from .._internal._variable_reference_validation import ( @@ -92,7 +95,7 @@ class ExtensionName(str, Enum): """ # # https://github.com/OpenJobDescription/openjd-specifications/blob/mainline/rfcs/0001-task-chunking.md - # TASK_CHUNKING = "TASK_CHUNKING" + TASK_CHUNKING = "TASK_CHUNKING" ExtensionNameList = Annotated[list[str], Field(min_length=1)] @@ -495,6 +498,7 @@ class TaskParameterType(str, Enum): FLOAT = "FLOAT" STRING = "STRING" PATH = "PATH" + CHUNK_INT = "CHUNK[INT]" class RangeString(FormatString): @@ -522,6 +526,8 @@ class RangeListTaskParameterDefinition(OpenJDModel_v2023_09): type: TaskParameterType # NOTE: Pydantic V1 was allowing non-string values in this range, V2 is enforcing that type. 
range: TaskRangeList + # has a value when type is CHUNK[INT], which is only possible from the TASK_CHUNKING extension + chunks: Optional[TaskChunksDefinition] = None @field_validator("range", mode="before") @classmethod @@ -539,6 +545,8 @@ class RangeExpressionTaskParameterDefinition(OpenJDModel_v2023_09): # element type of items in the range type: TaskParameterType range: TaskRangeExpression + # has a value when type is CHUNK[INT], which is only possible from the TASK_CHUNKING extension + chunks: Optional[TaskChunksDefinition] = None @field_validator("range") @classmethod @@ -553,6 +561,52 @@ def _validate_range_expression(cls, value: Any) -> Any: return value +class TaskChunksRangeConstraint(str, Enum): + CONTIGUOUS = "CONTIGUOUS" + NONCONTIGUOUS = "NONCONTIGUOUS" + + +class TaskChunksDefinition(OpenJDModel_v2023_09): + defaultTaskCount: Union[Annotated[int, annotated_types.Ge(1), Strict()], FormatString] + targetRuntimeSeconds: Optional[ + Union[Annotated[int, annotated_types.Ge(0), Strict()], FormatString] + ] = None + rangeConstraint: TaskChunksRangeConstraint + + _job_creation_metadata = JobCreationMetadata( + resolve_fields={"defaultTaskCount", "targetRuntimeSeconds"}, + ) + + @field_validator("defaultTaskCount") + @classmethod + def _validate_default_task_count(cls, value: Any) -> Any: + if isinstance(value, FormatString): + # If the string value has no expressions, can validate the value now. 
+ # Otherwise validation is deferred until its expressions are resolved. + if len(value.expressions) == 0: + try: + int_value = int(value) + except ValueError: + raise ValueError("String literal must contain an integer.") + if int_value < 1: + raise PydanticKnownError("greater_than_equal", {"ge": 1}) + return value + + @field_validator("targetRuntimeSeconds") + @classmethod + def _validate_target_runtime_seconds(cls, value: Any) -> Any: + if isinstance(value, FormatString): + # If the string value has no expressions, can validate it now + if len(value.expressions) == 0: + try: + int_value = int(value) + except ValueError: + raise ValueError("String literal must contain an integer.") + if int_value < 0: + raise PydanticKnownError("greater_than_equal", {"ge": 0}) + return value + + +class IntTaskParameterDefinition(OpenJDModel_v2023_09): """Definition of an integer-typed Task Parameter and its value range. @@ -577,8 +631,8 @@ class IntTaskParameterDefinition(OpenJDModel_v2023_09): ) _template_variable_sources = {"__export__": {"__self__"}} - def _get_range_task_param_type(model: Any) -> Type[OpenJDModel]: - if isinstance(model.range, RangeString): + def _get_range_task_param_type(self: Any) -> Type[OpenJDModel]: + if isinstance(self.range, RangeString): return RangeExpressionTaskParameterDefinition return RangeListTaskParameterDefinition @@ -591,9 +645,9 @@ def _get_range_task_param_type(model: Any) -> Type[OpenJDModel]: @field_validator("range", mode="before") @classmethod def _validate_range_element_type(cls, value: Any) -> Any: - # pydantic will automatically type coerse values into integers. We explicitly + # pydantic will automatically type coerce values into integers. We explicitly # want to reject non-integer values, so this *pre* validator validates the - # value *before* pydantic tries to type coerse it. + # value *before* pydantic tries to type coerce it.
# We do allow coersion from a string since we want to allow "1", and # "1.2" or "a" will fail the type coersion if isinstance(value, list): @@ -613,7 +667,7 @@ def _validate_range_element_type(cls, value: Any) -> Any: if errors: raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) elif isinstance(value, RangeString): - # TODO: nothing to do - it's guaranteed to be a format string at this point + # Nothing to do - it's guaranteed to be a format string at this point pass return value @@ -797,11 +851,134 @@ class PathTaskParameterDefinition(OpenJDModel_v2023_09): ) +class ChunkIntTaskParameterDefinition(OpenJDModel_v2023_09): + """Definition of an integer-typed Task Parameter, that is processed as + chunks of tasks instead of as individual tasks when running. + + Attributes: + name (Identifier): A name by which the parameter is referenced. + type (TaskParameterType.CHUNK_INT): discriminator to identify the type of the parameter. + range (IntRangeList | RangeString): The list of values that the parameter takes on. + chunks (TaskChunksDefinition): Properties that specify how to form chunks of tasks. + """ + + name: Identifier + type: Literal[TaskParameterType.CHUNK_INT] + # Note: Ordering here is important. Pydantic will try to match in + # the order given.
+ range: Union[IntRangeList, RangeString] + chunks: TaskChunksDefinition + + _template_variable_definitions = DefinesTemplateVariables( + defines={ + TemplateVariableDef(prefix="|Task.Param.", resolves=ResolutionScope.TASK), + TemplateVariableDef(prefix="|Task.RawParam.", resolves=ResolutionScope.TASK), + }, + field="name", + ) + _template_variable_sources = {"__export__": {"__self__"}} + + def _get_range_task_param_type(self: Any) -> Type[OpenJDModel]: + if isinstance(self.range, RangeString): + return RangeExpressionTaskParameterDefinition + return RangeListTaskParameterDefinition + + _job_creation_metadata = JobCreationMetadata( + create_as=JobCreateAsMetadata(callable=_get_range_task_param_type), + resolve_fields={"range"}, + exclude_fields={"name"}, + ) + + @model_validator(mode="before") + @classmethod + def _validate_task_chunking_extension( + cls, values: dict[str, Any], info: ValidationInfo + ) -> dict[str, Any]: + if info.context: + context = cast(ModelParsingContext, info.context) + if ExtensionName.TASK_CHUNKING not in context.extensions: + raise ValueError( + "The CHUNK[INT] task parameter requires the TASK_CHUNKING extension." + ) + return values + + @field_validator("range", mode="before") + @classmethod + def _validate_range_element_type(cls, value: Any) -> Any: + # pydantic will automatically type coerce values into integers. We explicitly + # want to reject non-integer values, so this *pre* validator validates the + # value *before* pydantic tries to type coerce it. 
+ # We do allow coercion from a string since we want to allow "1", and + # "1.2" or "a" will fail the type coercion + if isinstance(value, list): + errors = list[InitErrorDetails]() + for i, item in enumerate(value): + if isinstance(item, bool) or not isinstance(item, (int, str)): + errors.append( + InitErrorDetails( + type="value_error", + loc=(i,), + ctx={ + "error": ValueError("Value must be an integer or integer string.") + }, + input=item, + ) + ) + if errors: + raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) + elif isinstance(value, RangeString): + # Nothing to do - it's guaranteed to be a format string at this point + pass + + return value + + @field_validator("range") + @classmethod + def _validate_range_elements(cls, value: Any) -> Any: + if isinstance(value, list): + errors = list[InitErrorDetails]() + for i, item in enumerate(value): + if isinstance(item, TaskParameterStringValue): + # A TaskParameterStringValue is a FormatString. + # FormatString.expressions is the list of all expressions in the format string + # ( e.g. "{{ Param.Foo }}"). + # Only a string without expressions can be checked as an integer literal here. + if len(item.expressions) == 0: + try: + int(item) + except ValueError: + errors.append( + InitErrorDetails( + type="value_error", + loc=(i,), + ctx={ + "error": ValueError( + "String literal must contain an integer." + ) + }, + input=item, + ) + ) + if errors: + raise ValidationError.from_exception_data(cls.__name__, line_errors=errors) + else: + # If there are no format expressions, we can validate the range expression.
+ # otherwise we defer to the RangeExpressionTaskParameterDefinition model when + # they've all been evaluated + if len(value.expressions) == 0: + try: + IntRangeExpr.from_str(value) + except Exception as e: + raise ValueError(str(e)) + return value + + TaskParameterDefinition = Union[ IntTaskParameterDefinition, FloatTaskParameterDefinition, StringTaskParameterDefinition, PathTaskParameterDefinition, + ChunkIntTaskParameterDefinition, ] TaskParameterList = Annotated[ @@ -890,6 +1067,7 @@ def _validate_combination(self) -> Self: # Ensure that the 'combination' string: # a) is a properly formed combination expression; and # b) references all available task parameters exactly once each + # c) does not include a CHUNK[INT] parameter in an associative expression try: parse_tree = CombinationExpressionParser().parse(combination) @@ -949,6 +1127,25 @@ def _validate_combination(self) -> Self: ) ) + # If a parameter has type CHUNK[INT], get its name + chunk_parameter = None + for param in self.taskParameterDefinitions: + if param.type == TaskParameterType.CHUNK_INT: + chunk_parameter = param.name + + try: + if chunk_parameter is not None: + validate_step_parameter_space_chunk_constraint(chunk_parameter, parse_tree) + except ExpressionError as e: + errors.append( + InitErrorDetails( + type="value_error", + loc=("combination",), + ctx={"error": ValueError(str(e))}, + input=combination, + ) + ) + if errors: raise ValidationError.from_exception_data(self.__class__.__name__, errors) @@ -2454,25 +2651,28 @@ def _unique_extension_names( def _permitted_extension_names( cls, value: Optional[ExtensionNameList], info: ValidationInfo ) -> Optional[ExtensionNameList]: - context = cast(ModelParsingContext, info.context) - if value is not None: - # Before processing the extensions field, context.extensions is the list of supported extensions - # that were requested in the call of the parse_job_template function.
- # Take the intersection of the input supported extensions with what is implemented - # in this list, as the implementation needs to support an extension for it to be supported. - supported_extensions = context.extensions.intersection(cls.supported_extension_names()) - - unsupported_extensions = set(value).difference(supported_extensions) - if unsupported_extensions: - raise ValueError( - f"Unsupported extension names: {', '.join(sorted(unsupported_extensions))}" + if info.context: + context = cast(ModelParsingContext, info.context) + if value is not None: + # Before processing the extensions field, context.extensions is the list of supported extensions + # that were requested in the call of the parse_job_template function. + # Take the intersection of the input supported extensions with what is implemented + # in this list, as the implementation needs to support an extension for it to be supported. + supported_extensions = context.extensions.intersection( + cls.supported_extension_names() ) - # After processing the extensions field, context.extensions is the list of - # extension names used by the template. - context.extensions = set(value) - else: - context.extensions = set() + unsupported_extensions = set(value).difference(supported_extensions) + if unsupported_extensions: + raise ValueError( + f"Unsupported extension names: {', '.join(sorted(unsupported_extensions))}" + ) + + # After processing the extensions field, context.extensions is the list of + # extension names used by the template. + context.extensions = set(value) + else: + context.extensions = set() return value @field_validator("steps") diff --git a/test/openjd/model/v2023_09/test_chunk_int_task_parameter_type.py b/test/openjd/model/v2023_09/test_chunk_int_task_parameter_type.py new file mode 100644 index 00000000..23a7669c --- /dev/null +++ b/test/openjd/model/v2023_09/test_chunk_int_task_parameter_type.py @@ -0,0 +1,658 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +from typing import Any + +import pytest +from pydantic import ValidationError + +from openjd.model._parse import _parse_model +from openjd.model.v2023_09 import ( + ChunkIntTaskParameterDefinition, + ModelParsingContext, + StepParameterSpaceDefinition, +) + + +@pytest.mark.parametrize( + "data", + ( + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": [1], + "chunks": {"defaultTaskCount": 1, "rangeConstraint": "CONTIGUOUS"}, + }, + id="min len int list", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": [1] * 1024, + "chunks": {"defaultTaskCount": 1, "rangeConstraint": "CONTIGUOUS"}, + }, + id="max len int list", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": ["1"], + "chunks": {"defaultTaskCount": 1, "rangeConstraint": "CONTIGUOUS"}, + }, + id="int as string", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": ["1", 2], + "chunks": {"defaultTaskCount": 1, "rangeConstraint": "CONTIGUOUS"}, + }, + id="mixed int types", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": ["{{Param.Value}}"], + "chunks": {"defaultTaskCount": 1, "rangeConstraint": "CONTIGUOUS"}, + }, + id="format string", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": [1, "2", "{{Param.Value}}"], + "chunks": {"defaultTaskCount": 1, "rangeConstraint": "CONTIGUOUS"}, + }, + id="mix of item types", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": {"defaultTaskCount": 10, "rangeConstraint": "NONCONTIGUOUS"}, + }, + id="non-contiguous chunks", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": 0, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + id="target runtime seconds of 0", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": { + "defaultTaskCount": 10, 
+ "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + id="target runtime seconds of 1000", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": { + "defaultTaskCount": "10", + "targetRuntimeSeconds": 100, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + id="defaultTaskCount is str", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": { + "defaultTaskCount": "{{Param.ChunkSize}}", + "targetRuntimeSeconds": 100, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + id="defaultTaskCount is str with expression", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": "100", + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + id="targetRuntimeSeconds is str", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": "{{Param.TargetChunkRuntime}}", + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + id="targetRuntimeSeconds is str expression", + ), + ), +) +def test_chunk_int_task_parameter_parse_success(data: dict[str, Any]) -> None: + # It parses successfully when the TASK_CHUNKING extension is requested + _parse_model( + model=ChunkIntTaskParameterDefinition, + obj=data, + context=ModelParsingContext(supported_extensions=["TASK_CHUNKING"]), + ) + + # It fails to parse without the TASK_CHUNKING extension + with pytest.raises(ValidationError) as excinfo: + _parse_model(model=ChunkIntTaskParameterDefinition, obj=data) + assert "The CHUNK[INT] task parameter requires the TASK_CHUNKING extension." 
in str( + excinfo.value + ) + assert excinfo.value.error_count() == 1 + + +@pytest.mark.parametrize( + "data,error_message,error_count", + ( + pytest.param({}, "Field required", 4, id="empty object"), + pytest.param( + { + "name": "foo", + "type": "FLOAT", + "range": [1], + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "type\n Input should be", + 1, + id="wrong type", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": [1], + }, + "chunks\n Field required", + 1, + id="missing chunks", + ), + pytest.param( + { + "type": "CHUNK[INT]", + "range": [1], + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "name\n Field required", + 1, + id="missing name", + ), + pytest.param( + { + "name": "foo", + "range": [1], + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "type\n Field required", + 1, + id="missing type", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "range\n Field required", + 1, + id="missing range", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": [], + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "List should have at least 1 item after validation, not 0", + 2, + id="range too short", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": [1], + "unknown": "key", + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Extra inputs are not permitted", + 1, + id="unknown key", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": [1] * 1025, + "chunks": { + "defaultTaskCount": 
10, + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "List should have at most 1024 items after validation, not 1025", + 2, + id="range too long", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": [1.1], + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Value must be an integer or integer string.", + 1, + id="disallow floats", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": [True], + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Value must be an integer or integer string.", + 1, + id="disallow bool", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": ["1.1"], + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "String literal must contain an integer.", + 1, + id="disallow float strings", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": ["{{ Job.Parameter.Foo"], + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Failed to parse interpolation expression at [0, 20]. 
Reason: Braces mismatch.", + 3, + id="malformed format string", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": ["notint"], + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "String literal must contain an integer.", + 1, + id="literal string not an int", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": { + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "chunks.defaultTaskCount\n Field required", + 1, + id="missing defaultTaskCount", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": { + "defaultTaskCount": 10, + "targetRuntimeSeconds": 1000, + }, + }, + "chunks.rangeConstraint\n Field required", + 1, + id="missing rangeConstraint", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": { + "defaultTaskCount": 0, + "targetRuntimeSeconds": 1000, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Input should be greater than or equal to 1", + 2, + id="defaultTaskCount 0 (too small)", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": { + "defaultTaskCount": 1, + "targetRuntimeSeconds": -1, + "rangeConstraint": "NONCONTIGUOUS", + }, + }, + "Input should be greater than or equal to 0", + 2, + id="targetRuntimeSeconds -1 (too small)", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": { + "defaultTaskCount": 1, + "targetRuntimeSeconds": 0, + "rangeConstraint": "UNCONTIGUOUS", + }, + }, + "chunks.rangeConstraint\n Input should be 'CONTIGUOUS' or 'NONCONTIGUOUS'", + 1, + id="rangeConstraint incorrect value UNCONTIGUOUS", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": { + "defaultTaskCount": "0", + "targetRuntimeSeconds": 0, + "rangeConstraint": "CONTIGUOUS", + 
}, + }, + "Input should be greater than or equal to 1", + 1, + id="defaultTaskCount is str with non-positive integer value", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": { + "defaultTaskCount": "1.5", + "targetRuntimeSeconds": 0, + "rangeConstraint": "CONTIGUOUS", + }, + }, + "String literal must contain an integer.", + 1, + id="defaultTaskCount is str with non-integer value", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": { + "defaultTaskCount": "{{Param.ChunkSize}", + "targetRuntimeSeconds": 0, + "rangeConstraint": "CONTIGUOUS", + }, + }, + "Failed to parse interpolation expression at [0, 18]. Reason: Braces mismatch.", + 2, + id="defaultTaskCount is str with incorrect expression", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": { + "defaultTaskCount": 2, + "targetRuntimeSeconds": "0.1", + "rangeConstraint": "CONTIGUOUS", + }, + }, + "String literal must contain an integer.", + 1, + id="targetRuntimeSeconds is str with non-integer value", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": { + "defaultTaskCount": 2, + "targetRuntimeSeconds": "-1", + "rangeConstraint": "CONTIGUOUS", + }, + }, + "Input should be greater than or equal to 0", + 1, + id="targetRuntimeSeconds is str with negative integer value", + ), + pytest.param( + { + "name": "foo", + "type": "CHUNK[INT]", + "range": "1-100", + "chunks": { + "defaultTaskCount": 2, + "targetRuntimeSeconds": "{{Param.TargetChunkRuntime}", + "rangeConstraint": "CONTIGUOUS", + }, + }, + "Failed to parse interpolation expression at [0, 27]. 
Reason: Braces mismatch.", + 2, + id="targetRuntimeSeconds is str with incorrect expression", + ), + ), +) +def test_chunk_int_task_parameter_parse_fails( + data: dict[str, Any], error_message: str, error_count: int +) -> None: + # It fails to parse with a test-specific message with the TASK_CHUNKING extension + with pytest.raises(ValidationError) as excinfo: + _parse_model( + model=ChunkIntTaskParameterDefinition, + obj=data, + context=ModelParsingContext(supported_extensions=["TASK_CHUNKING"]), + ) + print(excinfo.value) + assert error_message in str(excinfo.value) + assert excinfo.value.error_count() == error_count + + # It fails to parse without the TASK_CHUNKING extension + with pytest.raises(ValidationError) as excinfo: + _parse_model(model=ChunkIntTaskParameterDefinition, obj=data) + assert "The CHUNK[INT] task parameter requires the TASK_CHUNKING extension." in str( + excinfo.value + ) + assert excinfo.value.error_count() == 1 + + +@pytest.mark.parametrize( + "data", + ( + pytest.param( + { + "taskParameterDefinitions": [ + {"name": "foo", "type": "INT", "range": "1-5"}, + {"name": "bar", "type": "INT", "range": "6-10"}, + { + "name": "baz", + "type": "CHUNK[INT]", + "range": "1-10", + "chunks": {"defaultTaskCount": 1, "rangeConstraint": "CONTIGUOUS"}, + }, + ], + "combination": "(foo, bar) * baz", + }, + id="combination expr with CHUNK[INT]", + ), + ), +) +def test_param_space_with_chunk_int_parse_success(data: dict[str, Any]) -> None: + # It parses successfully when the TASK_CHUNKING extension is requested + _parse_model( + model=StepParameterSpaceDefinition, + obj=data, + context=ModelParsingContext(supported_extensions=["TASK_CHUNKING"]), + ) + + # It fails to parse without the TASK_CHUNKING extension + with pytest.raises(ValidationError) as excinfo: + _parse_model(model=StepParameterSpaceDefinition, obj=data) + assert "The CHUNK[INT] task parameter requires the TASK_CHUNKING extension." 
in str( + excinfo.value + ) + assert excinfo.value.error_count() == 1 + + +@pytest.mark.parametrize( + "data,error_message,error_count", + ( + pytest.param( + { + "taskParameterDefinitions": [ + {"name": "foo", "type": "INT", "range": "1-5"}, + {"name": "bar", "type": "INT", "range": "11-20"}, + { + "name": "baz", + "type": "CHUNK[INT]", + "range": "1-10", + "chunks": {"defaultTaskCount": 1, "rangeConstraint": "CONTIGUOUS"}, + }, + ], + "combination": "foo * (bar, baz)", + }, + "CHUNK[INT] parameter baz must not be part of an associative expression.", + 1, + id="CHUNK[INT] directly in associative expression", + ), + pytest.param( + { + "taskParameterDefinitions": [ + {"name": "foo", "type": "INT", "range": "11-20"}, + {"name": "bar", "type": "INT", "range": "12"}, + { + "name": "baz", + "type": "CHUNK[INT]", + "range": "1-10", + "chunks": {"defaultTaskCount": 1, "rangeConstraint": "CONTIGUOUS"}, + }, + ], + "combination": "(foo, bar * baz)", + }, + "CHUNK[INT] parameter baz must not be part of an associative expression.", + 1, + id="CHUNK[INT] nested in product before associative expression", + ), + ), +) +def test_param_space_with_chunk_int_parse_fails( + data: dict[str, Any], error_message: str, error_count: int +) -> None: + # It fails to parse with a test-specific message with the TASK_CHUNKING extension + with pytest.raises(ValidationError) as excinfo: + _parse_model( + model=StepParameterSpaceDefinition, + obj=data, + context=ModelParsingContext(supported_extensions=["TASK_CHUNKING"]), + ) + print(excinfo.value) + assert error_message in str(excinfo.value) + assert excinfo.value.error_count() == error_count + + # It fails to parse without the TASK_CHUNKING extension + with pytest.raises(ValidationError) as excinfo: + _parse_model(model=StepParameterSpaceDefinition, obj=data) + assert "The CHUNK[INT] task parameter requires the TASK_CHUNKING extension." 
in str( + excinfo.value + ) + assert excinfo.value.error_count() == 1 diff --git a/test/openjd/model/v2023_09/test_create.py b/test/openjd/model/v2023_09/test_create.py index 96b7c7a7..34bad514 100644 --- a/test/openjd/model/v2023_09/test_create.py +++ b/test/openjd/model/v2023_09/test_create.py @@ -12,6 +12,7 @@ AttributeRequirementTemplate, CancelationMethodNotifyThenTerminate, CancelationMethodTerminate, + ChunkIntTaskParameterDefinition, EmbeddedFileText, Environment, EnvironmentActions, @@ -35,13 +36,14 @@ StepScript, StepTemplate, StringTaskParameterDefinition, + TaskChunksDefinition, ) class TestCreateJob: - def test(self) -> None: + def test_v2023_09(self) -> None: # One big test that does everything relevant for the create-job annotations. - # Should be the only test that we need. + # Should be the only test that we need for the baseline specification. # # Key things: # 1) Every format string has a job parameter reference - only some should be @@ -394,3 +396,119 @@ def test(self) -> None: assert result.model_dump() == expected.model_dump() # This is the important assertion. assert result == expected + + def test_v2023_09_extension_task_chunking(self) -> None: + # An end-to-end test for the TASK_CHUNKING extension. + # + # Key things: + # 1) Every format string has a job parameter reference - only some should be + # evaluated when creating jobs + # Specifically, only job name & task parameter range values should be evaluated. + # 2) Every entity and every field that exists is defined at least once. + # 3) Testing of _internal.create_job covers corner cases & exceptions; we don't worry + # about those here. 
+ + # GIVEN + extra_kwargs = {"$schema": "blah "} # special snowflake due to field naming + template = JobTemplate( + **extra_kwargs, + specificationVersion="jobtemplate-2023-09", + extensions=["TASK_CHUNKING"], + name="Job {{ Param.IntParam }}", + parameterDefinitions=[ + JobIntParameterDefinition( + name="RangeExpressionParam", + type="INT", + description="desc", + minValue=0, + maxValue=100, + allowedValues=[3, 75], + default=75, + ), + JobIntParameterDefinition( + name="IntParam", + type="INT", + description="desc", + minValue=0, + maxValue=100, + allowedValues=[5, 10, 20], + default=20, + ), + ], + steps=[ + StepTemplate( + name="StepName", + parameterSpace=StepParameterSpaceDefinition( + taskParameterDefinitions=[ + ChunkIntTaskParameterDefinition( + name="ParamE", + type="CHUNK[INT]", + range="2 - {{ Param.RangeExpressionParam }}", + chunks=TaskChunksDefinition( + defaultTaskCount="{{Param.RangeExpressionParam}}", + targetRuntimeSeconds="{{Param.IntParam}}", + rangeConstraint="CONTIGUOUS", + ), + ), + ], + combination="ParamE", + ), + script=StepScript( + actions=StepActions( + onRun=Action( + command="{{ Param.IntParam }}", + ) + ), + ), + ) + ], + ) + job_parameter_values = { + "IntParam": ParameterValue(type=ParameterValueType.INT, value="10"), + "RangeExpressionParam": ParameterValue(type=ParameterValueType.STRING, value="3"), + } + + # WHEN + result = create_job(job_template=template, job_parameter_values=job_parameter_values) + + # THEN + expected = Job( + extensions=["TASK_CHUNKING"], + name="Job 10", + parameters={ + "RangeExpressionParam": JobParameter(type="INT", description="desc", value="3"), + "IntParam": JobParameter(type="INT", description="desc", value="10"), + }, + steps=[ + Step( + name="StepName", + parameterSpace=StepParameterSpace( + taskParameterDefinitions={ + "ParamE": RangeExpressionTaskParameterDefinition( + type="CHUNK[INT]", + range="2 - 3", + chunks=TaskChunksDefinition( + defaultTaskCount="3", + targetRuntimeSeconds="10", + 
rangeConstraint="CONTIGUOUS", + ), + ), + }, + combination="ParamE", + ), + script=StepScript( + actions=StepActions( + onRun=Action( + command="{{ Param.IntParam }}", + ) + ), + ), + ) + ], + ) + + # Note: The dict compare generates an easier to read diff if there's a test failure. + # It is not essential to the test. + assert result.model_dump() == expected.model_dump() + # This is the important assertion. + assert result == expected diff --git a/test/openjd/model/v2023_09/test_strings.py b/test/openjd/model/v2023_09/test_strings.py index 260277ce..984fa863 100644 --- a/test/openjd/model/v2023_09/test_strings.py +++ b/test/openjd/model/v2023_09/test_strings.py @@ -153,9 +153,9 @@ def test_parse_success(self, value: str) -> None: pytest.param({"name": "\n"}, id="no newline"), # Just testing the boundary points of the allowable characters pytest.param({"name": "\u0000"}, id="NULL disallowed"), - pytest.param({"name": "\u001F"}, id="1f disallowed"), - pytest.param({"name": "\u007F"}, id="DEL disallowed"), - pytest.param({"name": "\u009F"}, id="9f disallowed"), + pytest.param({"name": "\u001f"}, id="1f disallowed"), + pytest.param({"name": "\u007f"}, id="DEL disallowed"), + pytest.param({"name": "\u009f"}, id="9f disallowed"), ), ) def test_parse_fails(self, data: dict[str, Any]) -> None: @@ -197,9 +197,9 @@ def test_parse_success(self, value: str) -> None: pytest.param({"name": "\n"}, id="no newline"), # Just testing the boundary points of the allowable characters pytest.param({"name": "\u0000"}, id="NULL disallowed"), - pytest.param({"name": "\u001F"}, id="1f disallowed"), - pytest.param({"name": "\u007F"}, id="DEL disallowed"), - pytest.param({"name": "\u009F"}, id="9f disallowed"), + pytest.param({"name": "\u001f"}, id="1f disallowed"), + pytest.param({"name": "\u007f"}, id="DEL disallowed"), + pytest.param({"name": "\u009f"}, id="9f disallowed"), ), ) def test_parse_fails(self, data: dict[str, Any]) -> None: @@ -241,9 +241,9 @@ def test_parse_success(self, 
value: str) -> None: pytest.param({"name": "\n"}, id="no newline"), # Just testing the boundary points of the allowable characters pytest.param({"name": "\u0000"}, id="NULL disallowed"), - pytest.param({"name": "\u001F"}, id="1f disallowed"), - pytest.param({"name": "\u007F"}, id="DEL disallowed"), - pytest.param({"name": "\u009F"}, id="9f disallowed"), + pytest.param({"name": "\u001f"}, id="1f disallowed"), + pytest.param({"name": "\u007f"}, id="DEL disallowed"), + pytest.param({"name": "\u009f"}, id="9f disallowed"), ), ) def test_parse_fails(self, data: dict[str, Any]) -> None: @@ -428,8 +428,8 @@ class TestDescription: pytest.param("A" * 2048, id="max length"), # Control character exlusion cases pytest.param("\u0020", id="start of first printable range"), - pytest.param("\u007E", id="end of first printable range"), - pytest.param("\u00A0", id="start of second printable range"), + pytest.param("\u007e", id="end of first printable range"), + pytest.param("\u00a0", id="start of second printable range"), ), ) def test_parse_success(self, value: str) -> None: @@ -449,9 +449,9 @@ def test_parse_success(self, value: str) -> None: pytest.param({"desc": ""}, id="too short"), pytest.param({"desc": "a" * 2049}, id="too long"), pytest.param({"desc": "\u0000"}, id="start of first control character range"), - pytest.param({"desc": "\u001F"}, id="end of first control character range"), - pytest.param({"desc": "\u007F"}, id="start of second control character range"), - pytest.param({"desc": "\u009F"}, id="end of second control character range"), + pytest.param({"desc": "\u001f"}, id="end of first control character range"), + pytest.param({"desc": "\u007f"}, id="start of second control character range"), + pytest.param({"desc": "\u009f"}, id="end of second control character range"), pytest.param({"desc": "a\n\u0000"}, id="disallowed after newline"), ), ) @@ -510,8 +510,8 @@ class TestArgString: pytest.param("A" * (32 * 1024), id="long length"), # Control character exlusion 
cases pytest.param("\u0020", id="start of first printable range"), - pytest.param("\u007E", id="end of first printable range"), - pytest.param("\u00A0", id="start of second printable range"), + pytest.param("\u007e", id="end of first printable range"), + pytest.param("\u00a0", id="start of second printable range"), ), ) def test_parse_success(self, value: str) -> None: @@ -532,9 +532,9 @@ def test_parse_success(self, value: str) -> None: pytest.param({"cmd": "\r"}, id="carriage return"), pytest.param({"cmd": "\t"}, id="horizontal tab"), pytest.param({"arg": "\u0000"}, id="start of first control character range"), - pytest.param({"arg": "\u001F"}, id="end of first control character range"), - pytest.param({"arg": "\u007F"}, id="start of second control character range"), - pytest.param({"arg": "\u009F"}, id="end of second control character range"), + pytest.param({"arg": "\u001f"}, id="end of first control character range"), + pytest.param({"arg": "\u007f"}, id="start of second control character range"), + pytest.param({"arg": "\u009f"}, id="end of second control character range"), ), ) def test_parse_fails(self, data: dict[str, Any]) -> None: @@ -556,8 +556,8 @@ class TestCommandString: pytest.param("A" * (32 * 1024), id="long length"), # Control character exlusion cases pytest.param("\u0020", id="start of first printable range"), - pytest.param("\u007E", id="end of first printable range"), - pytest.param("\u00A0", id="start of second printable range"), + pytest.param("\u007e", id="end of first printable range"), + pytest.param("\u00a0", id="start of second printable range"), ), ) def test_parse_success(self, value: str) -> None: @@ -579,9 +579,9 @@ def test_parse_success(self, value: str) -> None: pytest.param({"cmd": "\r"}, id="carriage return"), pytest.param({"cmd": "\t"}, id="horizontal tab"), pytest.param({"cmd": "\u0000"}, id="start of first control character range"), - pytest.param({"cmd": "\u001F"}, id="end of first control character range"), - 
pytest.param({"cmd": "\u007F"}, id="start of second control character range"), - pytest.param({"cmd": "\u009F"}, id="end of second control character range"), + pytest.param({"cmd": "\u001f"}, id="end of first control character range"), + pytest.param({"cmd": "\u007f"}, id="start of second control character range"), + pytest.param({"cmd": "\u009f"}, id="end of second control character range"), ), ) def test_parse_fails(self, data: dict[str, Any]) -> None: @@ -770,9 +770,9 @@ def test_parse_success(self, value: str) -> None: pytest.param({"str": "\n"}, id="no newline"), # Just testing the boundary points of the allowable characters pytest.param({"str": "\u0000"}, id="NULL disallowed"), - pytest.param({"str": "\u001F"}, id="1f disallowed"), - pytest.param({"str": "\u007F"}, id="DEL disallowed"), - pytest.param({"str": "\u009F"}, id="9f disallowed"), + pytest.param({"str": "\u001f"}, id="1f disallowed"), + pytest.param({"str": "\u007f"}, id="DEL disallowed"), + pytest.param({"str": "\u009f"}, id="9f disallowed"), ), ) def test_parse_fails(self, data: dict[str, Any]) -> None: @@ -817,9 +817,9 @@ def test_parse_success(self, value: str) -> None: pytest.param({"name": "*.\n"}, id="no newline"), # Just testing the boundary points of the allowable characters pytest.param({"name": "*.\u0000"}, id="NULL disallowed"), - pytest.param({"name": "*.\u001F"}, id="1f disallowed"), - pytest.param({"name": "*.\u007F"}, id="DEL disallowed"), - pytest.param({"name": "*.\u009F"}, id="9f disallowed"), + pytest.param({"name": "*.\u001f"}, id="1f disallowed"), + pytest.param({"name": "*.\u007f"}, id="DEL disallowed"), + pytest.param({"name": "*.\u009f"}, id="9f disallowed"), # The list of characters explicitly disallowed # b. Path separators "\" and "/". pytest.param({"name": "*.\\"}, id="no '\\'"),