From eaf88ced9f732123005b87e1afe9f0d2d0466ed0 Mon Sep 17 00:00:00 2001 From: Mark Wiebe <399551+mwiebe@users.noreply.github.com> Date: Thu, 30 Jan 2025 17:54:21 -0800 Subject: [PATCH 1/3] feat: Implement the extensions RFC 0002 Signed-off-by: Mark Wiebe <399551+mwiebe@users.noreply.github.com> --- README.md | 34 ++++---- src/openjd/model/_parse.py | 56 +++++++++++-- src/openjd/model/_types.py | 4 + src/openjd/model/v2023_09/__init__.py | 2 + src/openjd/model/v2023_09/_model.py | 115 +++++++++++++++++++++++++- test/openjd/model/test_parse.py | 94 +++++++++++++++++++++ 6 files changed, 278 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index b3b70a65..ff0fc743 100644 --- a/README.md +++ b/README.md @@ -23,13 +23,13 @@ This library requires: ## Versioning -This package's version follows [Semantic Versioning 2.0](https://semver.org/), but is still considered to be in its +This package's version follows [Semantic Versioning 2.0](https://semver.org/), but is still considered to be in its initial development, thus backwards incompatible versions are denoted by minor version bumps. To help illustrate how versions will increment during this initial development stage, they are described below: -1. The MAJOR version is currently 0, indicating initial development. -2. The MINOR version is currently incremented when backwards incompatible changes are introduced to the public API. -3. The PATCH version is currently incremented when bug fixes or backwards compatible changes are introduced to the public API. +1. The MAJOR version is currently 0, indicating initial development. +2. The MINOR version is currently incremented when backwards incompatible changes are introduced to the public API. +3. The PATCH version is currently incremented when bug fixes or backwards compatible changes are introduced to the public API. ## Contributing @@ -303,7 +303,7 @@ For example, if you would like to verify your download of the wheel for version 3) Save the following contents to a file called `openjobdescription-pgp.asc`: ``` -----BEGIN PGP PUBLIC KEY BLOCK----- - + mQINBGXGjx0BEACdChrQ/nch2aYGJ4fxHNQwlPE42jeHECqTdlc1V/mug+7qN7Pc C4NQk4t68Y72WX/NG49gRfpAxPlSeNt18c3vJ9/sWTukmonWYGK0jQGnDWjuVgFT XtvJAAQBFilQXN8h779Th2lEuD4bQX+mGB7l60Xvh7vIehE3C4Srbp6KJXskPLPo @@ -350,36 +350,36 @@ For example, if you would like to verify your download of the wheel for version gpg (GnuPG) 2.0.22; Copyright (C) 2013 Free Software Foundation, Inc. This is free software: you are free to change and redistribute it. There is NO WARRANTY, to the extent permitted by law. - - + + pub 4096R/BCC40987 created: 2024-02-09 expires: 2026-02-08 usage: SCEA trust: unknown validity: unknown [ unknown] (1). Open Job Description - + gpg> trust pub 4096R/BCC40987 created: 2024-02-09 expires: 2026-02-08 usage: SCEA trust: unknown validity: unknown [ unknown] (1). Open Job Description - + Please decide how far you trust this user to correctly verify other users' keys (by looking at passports, checking fingerprints from different sources, etc.) - + 1 = I don't know or won't say 2 = I do NOT trust 3 = I trust marginally 4 = I trust fully 5 = I trust ultimately m = back to the main menu - + Your decision? 5 Do you really want to set this key to ultimate trust? (y/N) y - + pub 4096R/BCC40987 created: 2024-02-09 expires: 2026-02-08 usage: SCEA trust: ultimate validity: unknown [ unknown] (1). Open Job Description Please note that the shown key validity is not necessarily correct unless you restart the program. - + gpg> quit ``` @@ -391,11 +391,11 @@ For example, if you would like to verify your download of the wheel for version ## Security -We take all security reports seriously. When we receive such reports, we will -investigate and subsequently address any potential vulnerabilities as quickly -as possible. If you discover a potential security issue in this project, please +We take all security reports seriously. When we receive such reports, we will +investigate and subsequently address any potential vulnerabilities as quickly +as possible. If you discover a potential security issue in this project, please notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/) -or directly via email to [AWS Security](aws-security@amazon.com). Please do not +or directly via email to [AWS Security](aws-security@amazon.com). Please do not create a public GitHub issue in this project. ## License diff --git a/src/openjd/model/_parse.py b/src/openjd/model/_parse.py index 66f5ec56..c33dda5c 100644 --- a/src/openjd/model/_parse.py +++ b/src/openjd/model/_parse.py @@ -4,6 +4,7 @@ from dataclasses import is_dataclass from decimal import Decimal from enum import Enum +from collections.abc import Iterable from typing import Any, ClassVar, Optional, Type, TypeVar, Union, cast import yaml @@ -44,9 +45,14 @@ class PydanticDataclass: T = TypeVar("T", bound=OpenJDModel) -def _parse_model(*, model: Type[T], obj: Any) -> T: +def _parse_model(*, model: Type[T], obj: Any, context: Any = None) -> T: + if context is None: + context = model.model_parsing_context_type() if is_dataclass(model): - return cast(T, cast(PydanticDataclass, model).__pydantic_model__.model_validate(obj)) + return cast( + T, + cast(PydanticDataclass, model).__pydantic_model__.model_validate(obj, context=context), + ) else: prevalidator_error: Optional[PydanticValidationError] = None if hasattr(model, "_root_template_prevalidator"): @@ -55,7 +61,7 @@ def _parse_model(*, model: Type[T], obj: Any) -> T: except PydanticValidationError as exc: prevalidator_error = exc try: - result = cast(T, cast(BaseModel, model).model_validate(obj)) + result = cast(T, cast(BaseModel, model).model_validate(obj, context=context)) except PydanticValidationError as exc: if prevalidator_error is not None: errors = list[InitErrorDetails]() @@ -74,9 +80,15 @@ def _parse_model(*, model: Type[T], obj: Any) -> T: return result -def parse_model(*, model: Type[T], obj: Any) -> T: +def parse_model( + *, model: Type[T], obj: Any, supported_extensions: Optional[Iterable[str]] = None +) -> T: try: - return _parse_model(model=model, obj=obj) + return _parse_model( + model=model, + obj=obj, + context=model.model_parsing_context_type(supported_extensions=supported_extensions), + ) except PydanticValidationError as exc: errors: list[ErrorDetails] = exc.errors() raise DecodeValidationError(pydantic_validationerrors_to_str(model, errors)) @@ -141,12 +153,22 @@ def decimal_to_str(data: Union[dict[str, Any], list[Any]]) -> None: return as_dict -def decode_job_template(*, template: dict[str, Any]) -> JobTemplate: +def decode_job_template( + *, template: dict[str, Any], supported_extensions: Optional[Iterable[str]] = None +) -> JobTemplate: """Given a dictionary containing a Job Template, this will decode the template, run validation checks on it, and then return the decoded template. + This function places no restriction on the version of the specification. The caller + can inspect the `specificationVersion` property of the returned object to validate this. + + By default, no extensions are supported. The caller can opt in to specific extensions, + by providing them as a list. + Args: template (dict[str, Any]): A Job Template as a dictionary object. + supported_extensions (list[str]): A list of extension names to support. This list is intersected + with the extensions names supported by the implementation before processing. Returns: JobTemplate: The decoded job template. @@ -189,7 +211,9 @@ def decode_job_template(*, template: dict[str, Any]) -> JobTemplate: ) if schema_version == TemplateSpecificationVersion.JOBTEMPLATE_v2023_09: - return parse_model(model=JobTemplate_2023_09, obj=template) + return parse_model( + model=JobTemplate_2023_09, obj=template, supported_extensions=supported_extensions + ) else: raise NotImplementedError( f"Template decode for schema {schema_version.value} is not yet implemented." @@ -203,12 +227,22 @@ def decode_template(*, template: dict[str, Any]) -> JobTemplate: return decode_job_template(template=template) -def decode_environment_template(*, template: dict[str, Any]) -> EnvironmentTemplate: +def decode_environment_template( + *, template: dict[str, Any], supported_extensions: Optional[Iterable[str]] = None +) -> EnvironmentTemplate: """Given a dictionary containing an Environment Template, this will decode the template, run validation checks on it, and then return the decoded template. + This function places no restriction on the version of the specification. The caller + can inspect the `specificationVersion` property of the returned object to validate this. + + By default, no extensions are supported. The caller can opt in to specific extensions, + by providing them as a list. + Args: template (dict[str, Any]): An Environment Template as a dictionary object. + supported_extensions (list[str]): A list of extension names to support. This list is intersected + with the extensions names supported by the implementation before processing. Returns: EnvironmentTemplate: The decoded environment template. @@ -246,7 +280,11 @@ def decode_environment_template(*, template: dict[str, Any]) -> EnvironmentTempl ) if schema_version == TemplateSpecificationVersion.ENVIRONMENT_v2023_09: - return parse_model(model=EnvironmentTemplate_2023_09, obj=template) + return parse_model( + model=EnvironmentTemplate_2023_09, + obj=template, + supported_extensions=supported_extensions, + ) else: raise NotImplementedError( f"Template decode for schema {schema_version.value} is not yet implemented." diff --git a/src/openjd/model/_types.py b/src/openjd/model/_types.py index af13dc2e..7742953b 100644 --- a/src/openjd/model/_types.py +++ b/src/openjd/model/_types.py @@ -274,6 +274,10 @@ class OpenJDModel(BaseModel): # The specific schema revision that the model implements. revision: ClassVar[SpecificationRevision] + # The model parsing context required by this model. Each revision of + # the specification defines this, and it must be default-constructible. + model_parsing_context_type: ClassVar[Type] + # ---- # Metadata used for defining template variables for use in FormatStrings diff --git a/src/openjd/model/v2023_09/__init__.py b/src/openjd/model/v2023_09/__init__.py index 936d308c..1cae7919 100644 --- a/src/openjd/model/v2023_09/__init__.py +++ b/src/openjd/model/v2023_09/__init__.py @@ -57,6 +57,7 @@ JobStringParameterDefinition, JobTemplate, JobTemplateName, + ModelParsingContext, ParameterStringValue, PathTaskParameterDefinition, RangeExpressionTaskParameterDefinition, @@ -136,6 +137,7 @@ "JobStringParameterDefinition", "JobTemplate", "JobTemplateName", + "ModelParsingContext", "ParameterStringValue", "PathTaskParameterDefinition", "RangeExpressionTaskParameterDefinition", diff --git a/src/openjd/model/v2023_09/_model.py b/src/openjd/model/v2023_09/_model.py index 7dac58d3..a56729bf 100644 --- a/src/openjd/model/v2023_09/_model.py +++ b/src/openjd/model/v2023_09/_model.py @@ -6,7 +6,7 @@ from decimal import Decimal, InvalidOperation from enum import Enum from graphlib import CycleError, TopologicalSorter -from typing import Any, ClassVar, Literal, Optional, Type, Union, cast +from typing import Any, ClassVar, Literal, Optional, Type, Union, cast, Iterable from typing_extensions import Annotated, Self from pydantic import ( @@ -52,8 +52,44 @@ ) +class ModelParsingContext: + """Context required while parsing an OpenJDModel. An instance of this class + must be provided when calling model_validate. + + OpenJDModelSubclass.model_validate(data, context=ModelParsingContext()) + + Individual validators receive this value as ValidationInfo.context. + """ + + extensions: set[str] + """Initially, is the set of supported extension names. When the 'extensions' + field of the template is processed, this becomes the set of extensions that + the the template requested.""" + + def __init__(self, *, supported_extensions: Optional[Iterable[str]] = None) -> None: + self.extensions = set(supported_extensions or []) + + class OpenJDModel_v2023_09(OpenJDModel): # noqa: N801 revision = SpecificationRevision.v2023_09 + model_parsing_context_type = ModelParsingContext + + @staticmethod + def supported_extension_names() -> set[str]: + """Returns the list of all extension names supported by the 2023-09 specification version.""" + return {v.value for v in ExtensionName} + + +class ExtensionName(str, Enum): + """Enumerant of all extensions supported for the 2023-09 specification revision. + This appears in the 'extensions' list property of all model instances. + """ + + # # https://github.com/OpenJobDescription/openjd-specifications/blob/mainline/rfcs/0001-task-chunking.md + # TASK_CHUNKING = "TASK_CHUNKING" + + +ExtensionNameList = Annotated[list[str], Field(min_length=1)] class ValueReferenceConstants(Enum): @@ -2351,6 +2387,7 @@ class Job(OpenJDModel_v2023_09): description: Optional[Description] = None parameters: Optional[JobParameters] = None jobEnvironments: Optional[JobEnvironmentsList] = None + extensions: Optional[list[ExtensionName]] = None class JobTemplate(OpenJDModel_v2023_09): @@ -2359,6 +2396,7 @@ class JobTemplate(OpenJDModel_v2023_09): Attributes: specificationVersion (TemplateSpecificationVersion.v2023_09): The OpenJD schema version whose data model this follows. + extensions (Optional[ExtensionNameList]): If provided, a non-empty list of named extensions to enable. name (JobTemplateName): The name of Jobs constructed by this template. steps (StepTemplateList): The Step Templates that comprise the Job Template. description (Optional[str]): A free form string that can be used to describe the Job. @@ -2371,6 +2409,7 @@ class JobTemplate(OpenJDModel_v2023_09): """ specificationVersion: Literal[TemplateSpecificationVersion.JOBTEMPLATE_v2023_09] # noqa: N815 + extensions: Optional[ExtensionNameList] = None name: JobTemplateName steps: StepTemplateList description: Optional[Description] = None @@ -2393,6 +2432,42 @@ class JobTemplate(OpenJDModel_v2023_09): rename_fields={"parameterDefinitions": "parameters"}, ) + @field_validator("extensions") + @classmethod + def _unique_extension_names( + cls, value: Optional[ExtensionNameList] + ) -> Optional[ExtensionNameList]: + if value is not None: + return validate_unique_elements( + value, item_value=lambda v: v, property="extension name" + ) + return value + + @field_validator("extensions") + @classmethod + def _permitted_extension_names( + cls, value: Optional[ExtensionNameList], info: ValidationInfo + ) -> Optional[ExtensionNameList]: + context = cast(ModelParsingContext, info.context) + if value is not None: + # Before processing the extensions field, context.extensions is the list of supported extensions. + # Take the intersection of the input supported extensions with what is implemented + # in this list, as the implementation needs to support an extension for it to be supported. + supported_extensions = context.extensions.intersection(cls.supported_extension_names()) + + unsupported_extensions = set(value).difference(supported_extensions) + if unsupported_extensions: + raise ValueError( + f"Unsupported extension names: {', '.join(sorted(unsupported_extensions))}" + ) + + # After processing the extensions field, context.extensions is the list of + # extension names used by the template. + context.extensions = set(value) + else: + context.extensions = set() + return value + @field_validator("steps") @classmethod def _unique_step_names(cls, v: StepTemplateList) -> StepTemplateList: @@ -2523,6 +2598,7 @@ class EnvironmentTemplate(OpenJDModel_v2023_09): Attributes: specificationVersion (TemplateSpecificationVersion.ENVIRONMENT_v2023_09): The OpenJD schema version whose data model this follows. + extensions (Optional[ExtensionNameList]): If provided, a non-empty list of named extensions to enable. parameterDefinitions (Optional[JobParameterDefinitionList]): The job parameters that are available for use within this template, and that must have values defined for them when creating jobs while this environment template is included. @@ -2530,6 +2606,7 @@ class EnvironmentTemplate(OpenJDModel_v2023_09): """ specificationVersion: Literal[TemplateSpecificationVersion.ENVIRONMENT_v2023_09] + extensions: Optional[ExtensionNameList] = None parameterDefinitions: Optional[JobParameterDefinitionList] = None environment: Environment @@ -2538,6 +2615,42 @@ class EnvironmentTemplate(OpenJDModel_v2023_09): "environment": {"parameterDefinitions"}, } + @field_validator("extensions") + @classmethod + def _unique_extension_names( + cls, value: Optional[ExtensionNameList] + ) -> Optional[ExtensionNameList]: + if value is not None: + return validate_unique_elements( + value, item_value=lambda v: v, property="extension name" + ) + return value + + @field_validator("extensions") + @classmethod + def _permitted_extension_names( + cls, value: Optional[ExtensionNameList], info: ValidationInfo + ) -> Optional[ExtensionNameList]: + context = cast(ModelParsingContext, info.context) + if value is not None: + # Before processing the extensions field, context.extensions is the list of supported extensions. + # Take the intersection of the input supported extensions with what is implemented + # in this list, as the implementation needs to support an extension for it to be supported. + supported_extensions = context.extensions.intersection(cls.supported_extension_names()) + + unsupported_extensions = set(value).difference(supported_extensions) + if unsupported_extensions: + raise ValueError( + f"Unsupported extension names: {', '.join(sorted(unsupported_extensions))}" + ) + + # After processing the extensions field, context.extensions is the list of + # extension names used by the template. + context.extensions = set(value) + else: + context.extensions = set() + return value + @field_validator("parameterDefinitions") @classmethod def _unique_parameter_names( diff --git a/test/openjd/model/test_parse.py b/test/openjd/model/test_parse.py index bb097752..2a07a6bf 100644 --- a/test/openjd/model/test_parse.py +++ b/test/openjd/model/test_parse.py @@ -1,7 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +from enum import Enum import json from typing import Any, Type +from unittest.mock import patch import pytest import yaml @@ -15,6 +17,7 @@ model_to_object, ) from openjd.model._types import OpenJDModel +import openjd.model.v2023_09 from openjd.model.v2023_09 import JobTemplate as JobTemplate_2023_09 from openjd.model.v2023_09 import EnvironmentTemplate as EnvironmentTemplate_2023_09 @@ -185,3 +188,94 @@ def test_success(self, template: dict[str, Any], expected_class: Type[OpenJDMode # THEN assert isinstance(result, expected_class) + + +class MockExtensionName(str, Enum): + """A mock enum with only SUPPORTED_NAME for testing.""" + + SUPPORTED_NAME = "SUPPORTED_NAME" + + +@pytest.mark.parametrize( + "template,template_type,decode_function", + [ + pytest.param( + { + "name": "DemoJob", + "specificationVersion": "jobtemplate-2023-09", + "parameterDefinitions": [{"name": "Foo", "type": "FLOAT", "default": "12"}], + "steps": [ + { + "name": "DemoStep", + "script": {"actions": {"onRun": {"command": "echo"}}}, + } + ], + }, + "JobTemplate", + decode_job_template, + id="job template", + ), + pytest.param( + { + "specificationVersion": "environment-2023-09", + "environment": { + "name": "FooEnv", + "description": "A description", + "script": {"actions": {"onEnter": {"command": "echo"}}}, + }, + }, + "EnvironmentTemplate", + decode_environment_template, + id="environment template", + ), + ], +) +def test_template_extensions_list(template, template_type, decode_function) -> None: + with patch.object(openjd.model.v2023_09._model, "ExtensionName", MockExtensionName): + # Confirm the template doesn't include extensions yet and can be decoded + assert "extensions" not in template + decode_function(template=template) + + # If an unimplemented name is provided to supported_extensions, it is ignored + decode_function(template=template, supported_extensions=["UNSUPPORTED_NAME"]) + + # When the requested extension name is in the supported list + template["extensions"] = ["SUPPORTED_NAME"] + model = decode_function(template=template, supported_extensions=["SUPPORTED_NAME"]) + assert model.extensions == ["SUPPORTED_NAME"] + + # If provided, the extensions list cannot be empty + template["extensions"] = [] + with pytest.raises(DecodeValidationError) as excinfo: + decode_function(template=template) + assert ( + f"1 validation errors for {template_type}\nextensions:\n\tList should have at least 1 item after validation, not 0" + in str(excinfo.value) + ) + + # By default no extensions are supported + template["extensions"] = ["SUPPORTED_NAME"] + with pytest.raises(DecodeValidationError) as excinfo: + decode_function(template=template) + assert ( + f"1 validation errors for {template_type}\nextensions:\n\tUnsupported extension names: SUPPORTED_NAME" + in str(excinfo.value) + ) + + # When the request list includes an unsupported extension name + template["extensions"] = ["SUPPORTED_NAME"] + with pytest.raises(DecodeValidationError) as excinfo: + decode_function(template=template, supported_extensions=["UNSUPPORTED_NAME"]) + assert ( + f"1 validation errors for {template_type}\nextensions:\n\tUnsupported extension names: SUPPORTED_NAME" + in str(excinfo.value) + ) + + # If an unimplemented name is provided to supported_extensions, it still can't be requested by the template + template["extensions"] = ["UNSUPPORTED_NAME"] + with pytest.raises(DecodeValidationError) as excinfo: + decode_function(template=template, supported_extensions=["UNSUPPORTED_NAME"]) + assert ( + f"1 validation errors for {template_type}\nextensions:\n\tUnsupported extension names: UNSUPPORTED_NAME" + in str(excinfo.value) + ) From 797fc8be46ce02be60130d0e8a140d5eb7e39d6b Mon Sep 17 00:00:00 2001 From: Mark Wiebe <399551+mwiebe@users.noreply.github.com> Date: Tue, 11 Feb 2025 09:35:40 -0800 Subject: [PATCH 2/3] fix: Change for CodeQL comment Signed-off-by: Mark Wiebe <399551+mwiebe@users.noreply.github.com> --- test/openjd/model/test_parse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/openjd/model/test_parse.py b/test/openjd/model/test_parse.py index 2a07a6bf..dd619206 100644 --- a/test/openjd/model/test_parse.py +++ b/test/openjd/model/test_parse.py @@ -17,7 +17,7 @@ model_to_object, ) from openjd.model._types import OpenJDModel -import openjd.model.v2023_09 +import openjd from openjd.model.v2023_09 import JobTemplate as JobTemplate_2023_09 from openjd.model.v2023_09 import EnvironmentTemplate as EnvironmentTemplate_2023_09 From 03ccc01bb0748d75d999f37be32c5792962e1d59 Mon Sep 17 00:00:00 2001 From: Mark Wiebe <399551+mwiebe@users.noreply.github.com> Date: Tue, 11 Feb 2025 10:59:11 -0800 Subject: [PATCH 3/3] fix: Clarify comments and extend tests for PR feedback Signed-off-by: Mark Wiebe <399551+mwiebe@users.noreply.github.com> --- src/openjd/model/_parse.py | 4 ++-- src/openjd/model/v2023_09/_model.py | 15 +++++++++++---- test/openjd/model/test_parse.py | 20 ++++++++++++++++++++ 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/openjd/model/_parse.py b/src/openjd/model/_parse.py index c33dda5c..0b8f4dd6 100644 --- a/src/openjd/model/_parse.py +++ b/src/openjd/model/_parse.py @@ -167,7 +167,7 @@ def decode_job_template( Args: template (dict[str, Any]): A Job Template as a dictionary object. - supported_extensions (list[str]): A list of extension names to support. This list is intersected + supported_extensions (Optional[Iterable[str]]): A list of extension names to support. This list is intersected with the extensions names supported by the implementation before processing. Returns: @@ -241,7 +241,7 @@ def decode_environment_template( Args: template (dict[str, Any]): An Environment Template as a dictionary object. - supported_extensions (list[str]): A list of extension names to support. This list is intersected + supported_extensions (Optional[Iterable[str]]): A list of extension names to support. This list is intersected with the extensions names supported by the implementation before processing. Returns: diff --git a/src/openjd/model/v2023_09/_model.py b/src/openjd/model/v2023_09/_model.py index a56729bf..43a79006 100644 --- a/src/openjd/model/v2023_09/_model.py +++ b/src/openjd/model/v2023_09/_model.py @@ -62,9 +62,15 @@ class ModelParsingContext: """ extensions: set[str] - """Initially, is the set of supported extension names. When the 'extensions' - field of the template is processed, this becomes the set of extensions that - the the template requested.""" + """When parsing a top-level model instance, this is the set of supported extension names. + The 'extensions' field is second in the list of model properties for both the job template + and environment template, and when that field is processed it becomes the set of extensions + that the template requested. + + When fields of a model that depend on an extension are processed, its validators should + check whether the needed extension is included in the context and adjust its parsing + as written in the specification. + """ def __init__(self, *, supported_extensions: Optional[Iterable[str]] = None) -> None: self.extensions = set(supported_extensions or []) @@ -2450,7 +2456,8 @@ def _permitted_extension_names( ) -> Optional[ExtensionNameList]: context = cast(ModelParsingContext, info.context) if value is not None: - # Before processing the extensions field, context.extensions is the list of supported extensions. + # Before processing the extensions field, context.extensions is the list of supported extensions + # that were requested in the call of the parse_job_template function. # Take the intersection of the input supported extensions with what is implemented # in this list, as the implementation needs to support an extension for it to be supported. supported_extensions = context.extensions.intersection(cls.supported_extension_names()) diff --git a/test/openjd/model/test_parse.py b/test/openjd/model/test_parse.py index dd619206..80437a42 100644 --- a/test/openjd/model/test_parse.py +++ b/test/openjd/model/test_parse.py @@ -196,6 +196,13 @@ class MockExtensionName(str, Enum): SUPPORTED_NAME = "SUPPORTED_NAME" +class MockExtensionNameWithTwoNames(str, Enum): + """A mock enum with only SUPPORTED_NAME for testing.""" + + SUPPORTED_NAME = "SUPPORTED_NAME" + ANOTHER_SUPPORTED_NAME = "ANOTHER_SUPPORTED_NAME" + + @pytest.mark.parametrize( "template,template_type,decode_function", [ @@ -262,6 +269,12 @@ def test_template_extensions_list(template, template_type, decode_function) -> N in str(excinfo.value) ) + # Extension names cannot be repeated + template["extensions"] = ["SUPPORTED_NAME", "SUPPORTED_NAME"] + with pytest.raises(DecodeValidationError) as excinfo: + decode_function(template=template) + assert "Duplicate values for extension name are not allowed." in str(excinfo.value) + # When the request list includes an unsupported extension name template["extensions"] = ["SUPPORTED_NAME"] with pytest.raises(DecodeValidationError) as excinfo: @@ -279,3 +292,10 @@ def test_template_extensions_list(template, template_type, decode_function) -> N f"1 validation errors for {template_type}\nextensions:\n\tUnsupported extension names: UNSUPPORTED_NAME" in str(excinfo.value) ) + + # For this test, there are two different extension names supported + with patch.object(openjd.model.v2023_09._model, "ExtensionName", MockExtensionNameWithTwoNames): + # When the requested extension name is in the supported list + template["extensions"] = ["ANOTHER_SUPPORTED_NAME"] + model = decode_function(template=template, supported_extensions=["ANOTHER_SUPPORTED_NAME"]) + assert model.extensions == ["ANOTHER_SUPPORTED_NAME"]