From ce882f66a97c989a9c79467cdcd0bc53287ef424 Mon Sep 17 00:00:00 2001 From: Dmitry Sharkov Date: Thu, 30 Oct 2025 13:22:50 -0400 Subject: [PATCH 1/4] polished dcv checker unit tests to get rid of warnings --- .../open_mpic_core/test_mpic_dcv_checker.py | 118 ++++++++---------- 1 file changed, 49 insertions(+), 69 deletions(-) diff --git a/tests/unit/open_mpic_core/test_mpic_dcv_checker.py b/tests/unit/open_mpic_core/test_mpic_dcv_checker.py index d11277b..a9bca4e 100644 --- a/tests/unit/open_mpic_core/test_mpic_dcv_checker.py +++ b/tests/unit/open_mpic_core/test_mpic_dcv_checker.py @@ -164,33 +164,18 @@ async def check_dcv__should_be_case_insensitive_for_challenge_values_for_all_val dcv_response = await self.dcv_checker.check_dcv(dcv_request) assert dcv_response.check_passed is is_case_insensitive - @pytest.mark.parametrize( - "record_type, target_record_data, mock_record_data", - [(DnsRecordType.A, "1.2.00.3", "1.2.0.3"), (DnsRecordType.AAAA, "1:00000::", "1::")], - ) - async def check_dcv__should_disallow_issuance_given_malformed_records_in_ip_address_lookup( - self, record_type, target_record_data, mock_record_data, mocker - ): - dcv_request = ValidCheckCreator.create_valid_dcv_check_request(DcvValidationMethod.IP_ADDRESS, record_type) - dcv_request.dcv_check_parameters.challenge_value = target_record_data - mock_record_data_with_value = {"value": mock_record_data} - dns_response = MockDnsObjectCreator.create_dns_query_answer( - dcv_request.domain_or_ip_target, "", record_type, mock_record_data_with_value, mocker - ) - self._patch_resolver_with_answer_or_exception(mocker, dns_response) - dcv_response = await self.dcv_checker.check_dcv(dcv_request) - assert dcv_response.check_passed is False - # fmt: off - @pytest.mark.parametrize("record_type, target_record_data, mock_record_data", [ - (DnsRecordType.A, "1.2.0.3", "1.2.0.3"), - (DnsRecordType.AAAA, "1:0:00:000:0000::", "1::"), # Expanding zero block - (DnsRecordType.AAAA, "2001:db8:3333:4444:5555:6666:1.2.3.4", "2001:db8:3333:4444:5555:6666:102:304"), # IPv4 in IPv6 - (DnsRecordType.AAAA, "::11.22.33.44", "::b16:212c") # IPv4 in IPv6, leading zeros + @pytest.mark.parametrize("record_type, target_record_data, mock_record_data, should_allow_issuance", [ + (DnsRecordType.A, "1.2.0.3", "1.2.0.3", True), + (DnsRecordType.AAAA, "1:0:00:000:0000::", "1::", True), # Expanding zero block + (DnsRecordType.AAAA, "2001:db8:3333:4444:5555:6666:1.2.3.4", "2001:db8:3333:4444:5555:6666:102:304", True), # IPv4 in IPv6 + (DnsRecordType.AAAA, "::11.22.33.44", "::b16:212c", True), # IPv4 in IPv6, leading zeros + (DnsRecordType.A, "1.2.00.3", "1.2.0.3", False), # malformed IPv4 + (DnsRecordType.AAAA, "1:00000::", "1::", False), # malformed IPv6 ]) # fmt: on - async def check_dcv__should_allow_issuance_for_different_record_formats_in_ip_address_lookup( - self, record_type, target_record_data, mock_record_data, mocker + async def check_dcv__should_allow_issuance_only_for_well_formed_ipv4_and_ipv6_in_ip_address_lookup( + self, record_type, target_record_data, mock_record_data, should_allow_issuance, mocker ): dcv_request = ValidCheckCreator.create_valid_dcv_check_request(DcvValidationMethod.IP_ADDRESS, record_type) dcv_request.dcv_check_parameters.challenge_value = target_record_data @@ -200,7 +185,7 @@ async def check_dcv__should_allow_issuance_for_different_record_formats_in_ip_ad ) self._patch_resolver_with_answer_or_exception(mocker, dns_response) dcv_response = await self.dcv_checker.check_dcv(dcv_request) - assert dcv_response.check_passed is True + assert dcv_response.check_passed is should_allow_issuance # fmt: off @pytest.mark.parametrize("dcv_method, domain, encoded_domain", [ @@ -290,14 +275,13 @@ async def http_based_dcv_checks__should_return_timestamp_and_response_url_and_st dcv_request = ValidCheckCreator.create_valid_dcv_check_request(dcv_method) self._mock_request_specific_http_response(dcv_request, mocker) dcv_response = await self.dcv_checker.check_dcv(dcv_request) - match dcv_method: - case DcvValidationMethod.WEBSITE_CHANGE: - url_scheme = dcv_request.dcv_check_parameters.url_scheme - http_token_path = dcv_request.dcv_check_parameters.http_token_path - expected_url = f"{url_scheme}://{dcv_request.domain_or_ip_target}/{MpicDcvChecker.WELL_KNOWN_PKI_PATH}/{http_token_path}" - case _: - token = dcv_request.dcv_check_parameters.token - expected_url = f"http://{dcv_request.domain_or_ip_target}/{MpicDcvChecker.WELL_KNOWN_ACME_PATH}/{token}" # noqa E501 (http) + if dcv_method == DcvValidationMethod.WEBSITE_CHANGE: + url_scheme = dcv_request.dcv_check_parameters.url_scheme + http_token_path = dcv_request.dcv_check_parameters.http_token_path + expected_url = f"{url_scheme}://{dcv_request.domain_or_ip_target}/{MpicDcvChecker.WELL_KNOWN_PKI_PATH}/{http_token_path}" + else: + token = dcv_request.dcv_check_parameters.token + expected_url = f"http://{dcv_request.domain_or_ip_target}/{MpicDcvChecker.WELL_KNOWN_ACME_PATH}/{token}" # noqa E501 (http) assert dcv_response.timestamp_ns is not None assert dcv_response.details.response_url == expected_url assert dcv_response.details.response_status_code == 200 @@ -364,13 +348,12 @@ async def http_based_dcv_checks__should_auto_insert_well_known_path_segment( self, dcv_method, expected_segment, mocker ): dcv_request = ValidCheckCreator.create_valid_dcv_check_request(dcv_method) - match dcv_method: - case DcvValidationMethod.WEBSITE_CHANGE: - dcv_request.dcv_check_parameters.http_token_path = "test-path" - url_scheme = dcv_request.dcv_check_parameters.url_scheme - case _: - dcv_request.dcv_check_parameters.token = "test-path" - url_scheme = "http" + if dcv_method == DcvValidationMethod.WEBSITE_CHANGE: + dcv_request.dcv_check_parameters.http_token_path = "test-path" + url_scheme = dcv_request.dcv_check_parameters.url_scheme + else: + dcv_request.dcv_check_parameters.token = "test-path" + url_scheme = "http" self._mock_request_specific_http_response(dcv_request, mocker) dcv_response = await self.dcv_checker.check_dcv(dcv_request) expected_url = f"{url_scheme}://{dcv_request.domain_or_ip_target}/{expected_segment}/test-path" @@ -381,11 +364,10 @@ async def http_based_dcv_checks__should_follow_redirects_and_track_redirect_hist self, dcv_method, mocker ): dcv_request = ValidCheckCreator.create_valid_dcv_check_request(dcv_method) - match dcv_request.dcv_check_parameters.validation_method: - case DcvValidationMethod.WEBSITE_CHANGE: - expected_challenge = dcv_request.dcv_check_parameters.challenge_value - case _: - expected_challenge = dcv_request.dcv_check_parameters.key_authorization + if dcv_request.dcv_check_parameters.validation_method == DcvValidationMethod.WEBSITE_CHANGE: + expected_challenge = dcv_request.dcv_check_parameters.challenge_value + else: + expected_challenge = dcv_request.dcv_check_parameters.key_authorization history = self._create_http_redirect_history() mock_response = TestMpicDcvChecker._create_mock_http_response(200, expected_challenge, {"history": history}) @@ -472,11 +454,10 @@ async def http_based_dcv_checks__should_not_pass_on_invalid_redirect_code_or_por self, dcv_method, code_or_port, mocker ): dcv_request = ValidCheckCreator.create_valid_dcv_check_request(dcv_method) - match dcv_request.dcv_check_parameters.validation_method: - case DcvValidationMethod.WEBSITE_CHANGE: - expected_challenge = dcv_request.dcv_check_parameters.challenge_value - case _: - expected_challenge = dcv_request.dcv_check_parameters.key_authorization + if dcv_request.dcv_check_parameters.validation_method == DcvValidationMethod.WEBSITE_CHANGE: + expected_challenge = dcv_request.dcv_check_parameters.challenge_value + else: + expected_challenge = dcv_request.dcv_check_parameters.key_authorization if code_or_port == "unacceptable_code": history = self._create_http_redirect_history_with_disallowed_code() @@ -842,16 +823,15 @@ def _create_mock_http_response_with_content_and_encoding(content: bytes, encodin return response def _mock_request_specific_http_response(self, dcv_request: DcvCheckRequest, mocker): - match dcv_request.dcv_check_parameters.validation_method: - case DcvValidationMethod.WEBSITE_CHANGE: - url_scheme = dcv_request.dcv_check_parameters.url_scheme - http_token_path = dcv_request.dcv_check_parameters.http_token_path - expected_url = f"{url_scheme}://{dcv_request.domain_or_ip_target}/{MpicDcvChecker.WELL_KNOWN_PKI_PATH}/{http_token_path}" - expected_challenge = dcv_request.dcv_check_parameters.challenge_value - case _: - token = dcv_request.dcv_check_parameters.token - expected_url = f"http://{dcv_request.domain_or_ip_target}/{MpicDcvChecker.WELL_KNOWN_ACME_PATH}/{token}" # noqa E501 (http) - expected_challenge = dcv_request.dcv_check_parameters.key_authorization + if dcv_request.dcv_check_parameters.validation_method == DcvValidationMethod.WEBSITE_CHANGE: + url_scheme = dcv_request.dcv_check_parameters.url_scheme + http_token_path = dcv_request.dcv_check_parameters.http_token_path + expected_url = f"{url_scheme}://{dcv_request.domain_or_ip_target}/{MpicDcvChecker.WELL_KNOWN_PKI_PATH}/{http_token_path}" + expected_challenge = dcv_request.dcv_check_parameters.challenge_value + else: + token = dcv_request.dcv_check_parameters.token + expected_url = f"http://{dcv_request.domain_or_ip_target}/{MpicDcvChecker.WELL_KNOWN_ACME_PATH}/{token}" # noqa E501 (http) + expected_challenge = dcv_request.dcv_check_parameters.key_authorization success_response = TestMpicDcvChecker._create_mock_http_response(200, expected_challenge) not_found_response = TestMpicDcvChecker._create_mock_http_response(404, "Not Found", {"reason": "Not Found"}) @@ -945,7 +925,7 @@ def _mock_dns_resolve_call_with_specific_flag(self, dcv_request: DcvCheckRequest def _mock_dns_resolve_call_with_cname_chain(self, dcv_request: DcvCheckRequest, mocker): test_dns_query_answer = self._create_basic_dns_response_for_mock(dcv_request, mocker) test_dns_query_answer.chaining_result = ChainingResult( - canonical_name="sub.example.com", + canonical_name=dns.name.from_text("sub.example.com"), answer=None, minimum_ttl=1, cnames=[ @@ -958,13 +938,12 @@ def _mock_dns_resolve_call_with_cname_chain(self, dcv_request: DcvCheckRequest, def _mock_dns_resolve_call_getting_multiple_txt_records(self, dcv_request: DcvCheckRequest, mocker): check_parameters = dcv_request.dcv_check_parameters - match check_parameters.validation_method: - case DcvValidationMethod.DNS_CHANGE: - record_data = {"value": check_parameters.challenge_value} - record_name_prefix = check_parameters.dns_name_prefix - case _: - record_data = {"value": check_parameters.key_authorization_hash} - record_name_prefix = "_acme-challenge" + if check_parameters.validation_method == DcvValidationMethod.DNS_CHANGE: + record_data = {"value": check_parameters.challenge_value} + record_name_prefix = check_parameters.dns_name_prefix + else: + record_data = {"value": check_parameters.key_authorization_hash} + record_name_prefix = "_acme-challenge" txt_record_1 = MockDnsObjectCreator.create_record_by_type(DnsRecordType.TXT, record_data) txt_record_2 = MockDnsObjectCreator.create_record_by_type(DnsRecordType.TXT, {"value": "whatever2"}) txt_record_3 = MockDnsObjectCreator.create_record_by_type(DnsRecordType.TXT, {"value": "whatever3"}) @@ -979,6 +958,7 @@ def _mock_dns_resolve_call_getting_multiple_txt_records(self, dcv_request: DcvCh def _create_basic_dns_response_for_mock(self, dcv_request: DcvCheckRequest, mocker) -> dns.resolver.Answer: check_parameters = dcv_request.dcv_check_parameters + record_data = None match check_parameters.validation_method: case ( DcvValidationMethod.DNS_CHANGE @@ -1042,7 +1022,7 @@ def shuffle_case(string_to_shuffle: str) -> str: if result.islower() or result.isupper(): for i, char in enumerate(result): if char.isalpha(): - result = result[:i] + char.swapcase() + result[i + 1 :] + result = result[:i] + char.swapcase() + result[i + 1:] break return result From 2e6609609d178605b22bae2d0a0786c93175d7e2 Mon Sep 17 00:00:00 2001 From: Dmitry Sharkov Date: Sun, 23 Nov 2025 20:32:37 -0500 Subject: [PATCH 2/4] modified domain encoder to allow pre-encoded labels, either 2008 or 2003 idna spec. --- .../common_util/domain_encoder.py | 43 ++++++++++++---- .../open_mpic_core/test_domain_encoder.py | 49 +++++++++++++++++-- 2 files changed, 78 insertions(+), 14 deletions(-) diff --git a/src/open_mpic_core/common_util/domain_encoder.py b/src/open_mpic_core/common_util/domain_encoder.py index dff01f6..218e132 100644 --- a/src/open_mpic_core/common_util/domain_encoder.py +++ b/src/open_mpic_core/common_util/domain_encoder.py @@ -1,5 +1,8 @@ import ipaddress + +import dns.name import idna +from dns.name import IDNAException class DomainEncoder: @@ -14,12 +17,34 @@ def prepare_target_for_lookup(domain_or_ip_target) -> str: pass # Convert to IDNA/Punycode - try: - is_wildcard = domain_or_ip_target.startswith("*.") - if is_wildcard: - domain_or_ip_target = domain_or_ip_target[2:] # Remove *. prefix - - encoded_domain = idna.encode(domain_or_ip_target, uts46=True).decode("ascii") - return encoded_domain - except idna.IDNAError as e: - raise ValueError(f"Invalid domain name: {str(e)}") + is_wildcard = domain_or_ip_target.startswith("*.") + if is_wildcard: + domain_or_ip_target = domain_or_ip_target[2:] # Remove *. prefix + + prepared_domain = domain_or_ip_target + is_already_encoded = False + + for label_text in domain_or_ip_target.split("."): + # check if any label is punycode encoded + if label_text.startswith("xn--"): + try: + label_bytes = label_text.encode("ascii") + try: + dns.name.IDNA_2008_Strict.decode(label_bytes) + except IDNAException: + try: + dns.name.IDNA_2003_Strict.decode(label_bytes) + except IDNAException as e2: + raise ValueError(f"Invalid domain name: {str(e2)}") + except UnicodeEncodeError: + raise ValueError(f"Invalid domain name: Label '{label_text}' is not valid ASCII.") + # if we made it here then we had a valid punycode label + is_already_encoded = True + + if not is_already_encoded: + try: + prepared_domain = idna.encode(domain_or_ip_target, uts46=True).decode("ascii") + except idna.IDNAError as e: + raise ValueError(f"Invalid domain name: {str(e)}") + + return prepared_domain diff --git a/tests/unit/open_mpic_core/test_domain_encoder.py b/tests/unit/open_mpic_core/test_domain_encoder.py index 238fc5a..3ee8ad9 100644 --- a/tests/unit/open_mpic_core/test_domain_encoder.py +++ b/tests/unit/open_mpic_core/test_domain_encoder.py @@ -4,15 +4,16 @@ class TestDomainEncoder: + @staticmethod @pytest.mark.parametrize( "input_domain, expected_output", [ ("café.example.com", "xn--caf-dma.example.com"), ("bücher.example.de", "xn--bcher-kva.example.de"), - ("свічка.example.com", "xn--80ady0a5a8f.example.com"), ("127.0.0.1", "127.0.0.1"), ("example.com", "example.com"), + ("subdomain.café.example.com", "subdomain.xn--caf-dma.example.com"), ], ) def prepare_domain_for_lookup__should_convert_nonascii_domain_to_punycode(input_domain, expected_output): @@ -23,8 +24,9 @@ def prepare_domain_for_lookup__should_convert_nonascii_domain_to_punycode(input_ @pytest.mark.parametrize( "input_domain, expected_output", [ - ("*.example.com", "example.com"), - ("*.café.example.com", "xn--caf-dma.example.com"), + ("*.example.com", "example.com"), # ascii + ("*.café.example.com", "xn--caf-dma.example.com"), # not encoded + ("*.xn--yaztura-tfb.com", "xn--yaztura-tfb.com"), # already encoded ], ) def prepare_domain_for_lookup__should_remove_leading_asterisk_from_wildcard_domain(input_domain, expected_output): @@ -32,6 +34,43 @@ def prepare_domain_for_lookup__should_remove_leading_asterisk_from_wildcard_doma assert result == expected_output @staticmethod - def prepare_domain_for_lookup__should_raise_value_error_if_idna_error_encountered(): + @pytest.mark.parametrize( + "input_domain, expected_output", + [ + ("xn--caf-dma.example.com", "xn--caf-dma.example.com"), # café.example.com idna2008 + ("*.xn--yaztura-tfb.com", "xn--yaztura-tfb.com"), + ("xn--nxasmm1c.com", "xn--nxasmm1c.com"), # "βόλος.com" idna2008 + ("xn--ls8h.la", "xn--ls8h.la"), # poop emoji idna2003 + ("xn--4ca.com", "xn--4ca.com"), # "√.com" idna2003 + ], + ) + def prepare_domain_for_lookup__should_allow_already_encoded_domain(input_domain, expected_output): + result = DomainEncoder.prepare_target_for_lookup(input_domain) + assert result == expected_output + + @staticmethod + @pytest.mark.parametrize( + "input_domain, expected_output", + [ + ("sub.xn--caf-dma.example.com", "sub.xn--caf-dma.example.com"), + ("sub.xn--4ca.com", "sub.xn--4ca.com"), # "√.com" idna2003 + ("sub.xn--ls8h.la", "sub.xn--ls8h.la"), # poop emoji idna2003 + ], + ) + def prepare_domain_for_lookup__should_detect_punycode_in_inner_labels(input_domain, expected_output): + result = DomainEncoder.prepare_target_for_lookup(input_domain) + assert result == expected_output + + @staticmethod + @pytest.mark.parametrize( + "input_domain", + [ + "*example.com", + "exa mple.com", + "exam!ple.com", + "xn--café.com", # invalid (punycode prefix on non-ascii string) + ], + ) + def prepare_domain_for_lookup__should_raise_value_error_given_malformed_domain(input_domain): with pytest.raises(ValueError): - DomainEncoder.prepare_target_for_lookup("*example.com") + DomainEncoder.prepare_target_for_lookup(input_domain) From 388a3ac3e31120a0fb50abda8028a6434b93118a Mon Sep 17 00:00:00 2001 From: Dmitry Sharkov Date: Mon, 24 Nov 2025 16:17:18 -0500 Subject: [PATCH 3/4] adding an ability to specify the cohort number to use in a single attempt without knowing what's in it --- src/open_mpic_core/__about__.py | 2 +- src/open_mpic_core/__init__.py | 7 +- .../common_domain/messages/ErrorMessages.py | 1 + .../domain/cohort_creation_exception.py | 4 - .../domain/mpic_orchestration_parameters.py | 1 + .../domain/mpic_request_errors.py | 10 +++ .../mpic_coordinator/mpic_coordinator.py | 17 ++++- .../mpic_request_validator.py | 5 ++ .../open_mpic_core/test_mpic_coordinator.py | 75 ++++++++++++++++++- 9 files changed, 112 insertions(+), 10 deletions(-) delete mode 100644 src/open_mpic_core/mpic_coordinator/domain/cohort_creation_exception.py diff --git a/src/open_mpic_core/__about__.py b/src/open_mpic_core/__about__.py index 7856d12..0a895f3 100644 --- a/src/open_mpic_core/__about__.py +++ b/src/open_mpic_core/__about__.py @@ -1 +1 @@ -__version__ = "6.1.0" +__version__ = "6.2.0" diff --git a/src/open_mpic_core/__init__.py b/src/open_mpic_core/__init__.py index 2fcc7c2..ca3c1a0 100644 --- a/src/open_mpic_core/__init__.py +++ b/src/open_mpic_core/__init__.py @@ -46,11 +46,14 @@ MpicEffectiveOrchestrationParameters, ) -from open_mpic_core.mpic_coordinator.domain.cohort_creation_exception import CohortCreationException from open_mpic_core.mpic_coordinator.domain.perspective_response import PerspectiveResponse from open_mpic_core.mpic_coordinator.domain.mpic_request import MpicRequest, MpicDcvRequest, MpicCaaRequest from open_mpic_core.mpic_coordinator.domain.mpic_response import MpicResponse, MpicCaaResponse, MpicDcvResponse -from open_mpic_core.mpic_coordinator.domain.mpic_request_errors import MpicRequestValidationException +from open_mpic_core.mpic_coordinator.domain.mpic_request_errors import ( + MpicRequestValidationException, + CohortCreationException, + CohortSelectionException +) from open_mpic_core.mpic_coordinator.domain.remote_check_call_configuration import RemoteCheckCallConfiguration from open_mpic_core.mpic_coordinator.domain.remote_check_exception import RemoteCheckException from open_mpic_core.mpic_coordinator.messages.mpic_request_validation_messages import MpicRequestValidationMessages diff --git a/src/open_mpic_core/common_domain/messages/ErrorMessages.py b/src/open_mpic_core/common_domain/messages/ErrorMessages.py index fbbf919..cdac874 100644 --- a/src/open_mpic_core/common_domain/messages/ErrorMessages.py +++ b/src/open_mpic_core/common_domain/messages/ErrorMessages.py @@ -11,6 +11,7 @@ class ErrorMessages(Enum): GENERAL_HTTP_ERROR = ('mpic_error:http', 'An HTTP error occurred: Response status {0}, Response reason: {1}') INVALID_REDIRECT_ERROR = ('mpic_error:redirect:invalid', 'Invalid redirect. Redirect code: {0}, target: {1}') COHORT_CREATION_ERROR = ('mpic_error:coordinator:cohort', 'The coordinator could not construct a cohort of size {0}') + COHORT_SELECTION_ERROR = ('mpic_error:coordinator:cohort_selection', 'The coordinator could not select cohort number {0} from available cohorts.') TLS_ALPN_ERROR_CERTIFICATE_EXTENSION_MISSING = ('mpic_error:dcv_checker:tls_alpn:certificate:extension_missing', 'The TLS ALPN certificate was missing an extension.') TLS_ALPN_ERROR_CERTIFICATE_ALPN_EXTENSION_NONCRITICAL = ('mpic_error:dcv_checker:tls_alpn:certificate:noncritical_alpn_extension', 'The TLS ALPN certificate has non-critical id-pe-acmeIdentifier extension') diff --git a/src/open_mpic_core/mpic_coordinator/domain/cohort_creation_exception.py b/src/open_mpic_core/mpic_coordinator/domain/cohort_creation_exception.py deleted file mode 100644 index d46f4df..0000000 --- a/src/open_mpic_core/mpic_coordinator/domain/cohort_creation_exception.py +++ /dev/null @@ -1,4 +0,0 @@ - -class CohortCreationException(Exception): - def __init__(self, message): - super().__init__(message) diff --git a/src/open_mpic_core/mpic_coordinator/domain/mpic_orchestration_parameters.py b/src/open_mpic_core/mpic_coordinator/domain/mpic_orchestration_parameters.py index 5e9b4c3..1e6a714 100644 --- a/src/open_mpic_core/mpic_coordinator/domain/mpic_orchestration_parameters.py +++ b/src/open_mpic_core/mpic_coordinator/domain/mpic_orchestration_parameters.py @@ -9,6 +9,7 @@ class BaseMpicOrchestrationParameters(BaseModel, ABC): class MpicRequestOrchestrationParameters(BaseMpicOrchestrationParameters): max_attempts: int | None = None + cohort_for_single_attempt: int | None = None # sets max_attempts to 1 if defined; must be > 0 class MpicEffectiveOrchestrationParameters(BaseMpicOrchestrationParameters): diff --git a/src/open_mpic_core/mpic_coordinator/domain/mpic_request_errors.py b/src/open_mpic_core/mpic_coordinator/domain/mpic_request_errors.py index 5600889..9eda1dd 100644 --- a/src/open_mpic_core/mpic_coordinator/domain/mpic_request_errors.py +++ b/src/open_mpic_core/mpic_coordinator/domain/mpic_request_errors.py @@ -1,2 +1,12 @@ class MpicRequestValidationException(Exception): pass + + +class CohortCreationException(Exception): + def __init__(self, message): + super().__init__(message) + + +class CohortSelectionException(Exception): + def __init__(self, message): + super().__init__(message) \ No newline at end of file diff --git a/src/open_mpic_core/mpic_coordinator/mpic_coordinator.py b/src/open_mpic_core/mpic_coordinator/mpic_coordinator.py index 292e0cf..73116f1 100644 --- a/src/open_mpic_core/mpic_coordinator/mpic_coordinator.py +++ b/src/open_mpic_core/mpic_coordinator/mpic_coordinator.py @@ -14,7 +14,7 @@ from open_mpic_core import MpicValidationError, ErrorMessages from open_mpic_core import CheckType from open_mpic_core import CohortCreator -from open_mpic_core import CohortCreationException +from open_mpic_core import CohortCreationException, CohortSelectionException from open_mpic_core import RemoteCheckException from open_mpic_core import RemoteCheckCallConfiguration from open_mpic_core import RemotePerspective @@ -79,12 +79,22 @@ async def coordinate_mpic(self, mpic_request: MpicRequest) -> MpicResponse: if len(perspective_cohorts) == 0: raise CohortCreationException(ErrorMessages.COHORT_CREATION_ERROR.message.format(perspective_count)) + # check if a specific cohort is requested for single attempt + cohort_to_use = None + if orchestration_parameters is not None and orchestration_parameters.cohort_for_single_attempt is not None: + cohort_to_use = orchestration_parameters.cohort_for_single_attempt + if not MpicRequestValidator.is_requested_cohort_for_single_attempt_valid( + cohort_to_use, len(perspective_cohorts) + ): + raise CohortSelectionException(ErrorMessages.COHORT_SELECTION_ERROR.message.format(cohort_to_use)) + quorum_count = self.determine_required_quorum_count(orchestration_parameters, perspective_count) if ( orchestration_parameters is not None and orchestration_parameters.max_attempts is not None and orchestration_parameters.max_attempts > 0 + and orchestration_parameters.cohort_for_single_attempt is None ): max_attempts = orchestration_parameters.max_attempts if self.global_max_attempts is not None and max_attempts > self.global_max_attempts: @@ -96,7 +106,10 @@ async def coordinate_mpic(self, mpic_request: MpicRequest) -> MpicResponse: cohort_cycle = cycle(perspective_cohorts) while attempts <= max_attempts: - perspectives_to_use = next(cohort_cycle) + if cohort_to_use is not None: + perspectives_to_use = perspective_cohorts[cohort_to_use - 1] # cohorts are 1-indexed for the user + else: + perspectives_to_use = next(cohort_cycle) # Collect async calls to invoke for each perspective. async_calls_to_issue = MpicCoordinator.collect_checker_calls_to_issue(mpic_request, perspectives_to_use) diff --git a/src/open_mpic_core/mpic_coordinator/mpic_request_validator.py b/src/open_mpic_core/mpic_coordinator/mpic_request_validator.py index c718797..707e8f7 100644 --- a/src/open_mpic_core/mpic_coordinator/mpic_request_validator.py +++ b/src/open_mpic_core/mpic_coordinator/mpic_request_validator.py @@ -54,6 +54,11 @@ def is_requested_perspective_count_valid(requested_perspective_count, target_per target_perspectives ) + @staticmethod + def is_requested_cohort_for_single_attempt_valid(cohort_for_single_attempt, number_of_cohorts) -> bool: + # check if cohort_for_single_attempt is an integer and is within the number of available cohorts + return isinstance(cohort_for_single_attempt, int) and 1 <= cohort_for_single_attempt <= number_of_cohorts + @staticmethod def validate_quorum_count(requested_perspective_count, quorum_count, request_validation_issues) -> None: # quorum_count can be no less than perspectives-1 if perspectives <= 5 diff --git a/tests/unit/open_mpic_core/test_mpic_coordinator.py b/tests/unit/open_mpic_core/test_mpic_coordinator.py index 4fec4cd..2925593 100644 --- a/tests/unit/open_mpic_core/test_mpic_coordinator.py +++ b/tests/unit/open_mpic_core/test_mpic_coordinator.py @@ -11,6 +11,8 @@ ErrorMessages, CaaCheckResponse, CaaCheckResponseDetails, + CohortCreationException, + CohortSelectionException, MpicRequestOrchestrationParameters, RemotePerspective, MpicRequestValidationException, @@ -21,7 +23,6 @@ ) from open_mpic_core.common_domain.enum.regional_internet_registry import RegionalInternetRegistry -from open_mpic_core.mpic_coordinator.domain.cohort_creation_exception import CohortCreationException from unit.test_util.valid_mpic_request_creator import ValidMpicRequestCreator @@ -458,6 +459,78 @@ async def coordinate_mpic__should_raise_exception_given_logically_invalid_mpic_r with pytest.raises(MpicRequestValidationException): await mpic_coordinator.coordinate_mpic(mpic_request) + @pytest.mark.parametrize("cohort_for_single_attempt", [1, 2]) + async def coordinate_mpic__should_perform_attempt_with_cohort_if_single_attempt_cohort_number_specified( + self, cohort_for_single_attempt + ): + # will create 2 cohorts with 3 perspectives each (2 in RIR 'ARIN', and 1 in 'RIPE NCC'). + perspectives = [ + RemotePerspective(rir=RegionalInternetRegistry.ARIN, code="us-east-1"), + RemotePerspective(rir=RegionalInternetRegistry.ARIN, code="us-west-1"), + RemotePerspective(rir=RegionalInternetRegistry.RIPE_NCC, code="eu-central-1"), + RemotePerspective(rir=RegionalInternetRegistry.ARIN, code="us-east-2"), + RemotePerspective(rir=RegionalInternetRegistry.ARIN, code="us-west-2"), + RemotePerspective(rir=RegionalInternetRegistry.RIPE_NCC, code="eu-central-2"), + ] + mpic_coordinator_config = self.create_mpic_coordinator_configuration() + mpic_coordinator_config.target_perspectives = perspectives + + mocked_call_remote_perspective_function = AsyncMock() + mocked_call_remote_perspective_function.side_effect = TestMpicCoordinator.SideEffectForMockedPayloads( + self.create_passing_caa_check_response + ) + mpic_coordinator = MpicCoordinator(mocked_call_remote_perspective_function, mpic_coordinator_config) + + mpic_request = ValidMpicRequestCreator.create_valid_caa_mpic_request() + mpic_request.orchestration_parameters = MpicRequestOrchestrationParameters( + perspective_count=3, + cohort_for_single_attempt=cohort_for_single_attempt + ) + + mpic_response = await mpic_coordinator.coordinate_mpic(mpic_request) + assert mpic_response.is_valid is True + assert mpic_response.actual_orchestration_parameters.attempt_count == 1 + + # fmt: off + @pytest.mark.parametrize("cohort_size, single_attempt_cohort_number", [ + (2, 0), + (2, -1), + (3, 4), + (6, 2) + ]) + # fmt: on + async def coordinate_mpic__should_raise_exception_given_invalid_single_attempt_cohort_number_specified( + self, cohort_size, single_attempt_cohort_number + ): + # If cohort_size is 2, should create cohorts with 2 perspectives each. (One cohort will be all 'ARIN'.) + # If cohort_size is 3, should create cohorts with 3 perspectives each (2 in RIR 'ARIN', and 1 in 'RIPE NCC'). + # If cohort_size is 6, should create cohort with 6 perspectives (4 in RIR 'ARIN', and 2 in 'RIPE NCC'). + perspectives = [ + RemotePerspective(rir=RegionalInternetRegistry.ARIN, code="us-east-1"), + RemotePerspective(rir=RegionalInternetRegistry.ARIN, code="us-west-1"), + RemotePerspective(rir=RegionalInternetRegistry.RIPE_NCC, code="eu-central-1"), + RemotePerspective(rir=RegionalInternetRegistry.ARIN, code="us-east-2"), + RemotePerspective(rir=RegionalInternetRegistry.ARIN, code="us-west-2"), + RemotePerspective(rir=RegionalInternetRegistry.RIPE_NCC, code="eu-central-2"), + ] + mpic_coordinator_config = self.create_mpic_coordinator_configuration() + mpic_coordinator_config.target_perspectives = perspectives + + mocked_call_remote_perspective_function = AsyncMock() + mocked_call_remote_perspective_function.side_effect = TestMpicCoordinator.SideEffectForMockedPayloads( + self.create_passing_caa_check_response + ) + mpic_coordinator = MpicCoordinator(mocked_call_remote_perspective_function, mpic_coordinator_config) + + mpic_request = ValidMpicRequestCreator.create_valid_caa_mpic_request() + mpic_request.orchestration_parameters = MpicRequestOrchestrationParameters( + perspective_count=cohort_size, + cohort_for_single_attempt=single_attempt_cohort_number + ) + + with pytest.raises(CohortSelectionException): + await mpic_coordinator.coordinate_mpic(mpic_request) + async def coordinate_mpic__should_return_trace_identifier_if_included_in_request(self): mpic_request = ValidMpicRequestCreator.create_valid_caa_mpic_request() mpic_request.trace_identifier = "test_trace_identifier" From b43e944e15c14da49cfc8299bee29b71c373798e Mon Sep 17 00:00:00 2001 From: Dmitry Sharkov Date: Sun, 30 Nov 2025 17:47:37 -0500 Subject: [PATCH 4/4] updated version of API being implemented --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f916457..5712f14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ build.targets.wheel.packages = ["src/open_mpic_core"] "./tests/unit/test_util" = "open_mpic_core_test/test_util" # include tests in the wheel to facilitate integration testing in wrapper projects [tool.api] -spec_version = "3.6.0" +spec_version = "3.7.0" spec_repository = "https://github.com/open-mpic/open-mpic-specification" [tool.hatch.envs.default]