diff --git a/src/open_mpic_core/__about__.py b/src/open_mpic_core/__about__.py index 833eacf..221cd93 100644 --- a/src/open_mpic_core/__about__.py +++ b/src/open_mpic_core/__about__.py @@ -1 +1 @@ -__version__ = "6.3.0" +__version__ = "6.3.1" diff --git a/src/open_mpic_core/common_domain/check_parameters.py b/src/open_mpic_core/common_domain/check_parameters.py index c23e419..8d24c44 100644 --- a/src/open_mpic_core/common_domain/check_parameters.py +++ b/src/open_mpic_core/common_domain/check_parameters.py @@ -1,5 +1,6 @@ from abc import ABC from typing import Literal, Union, Any, Set, Annotated +from urllib.parse import urlparse from pydantic import BaseModel, field_validator, Field @@ -61,6 +62,26 @@ class DcvDnsPersistentValidationParameters(DcvValidationParameters): issuer_domain_names: list[str] # Disclosed issuer domain names from CA's CP/CPS expected_account_uri: str # The specific account URI to validate + # expected_account_uri should be a URI with a scheme and host (e.g. "https://example.com/acct/123") + @field_validator("expected_account_uri") + @classmethod + def validate_account_uri(cls, v: str) -> str: + parsed_account_uri = urlparse(v) + if not parsed_account_uri.scheme or not parsed_account_uri.netloc: + raise ValueError(f"expected_account_uri must be a valid URI with scheme and host, got {v}") + return v + + @field_validator("issuer_domain_names") + @classmethod + def validate_issuer_domain_names(cls, v: list[str]) -> list[str]: + if not v: + raise ValueError("issuer_domain_names must be a non-empty list of domain names") + for domain in v: + # check that v is non-empty + if not domain: + raise ValueError("issuer_domain_names must not contain empty strings") + return v + class DcvContactEmailTxtValidationParameters(DcvGeneralDnsValidationParameters): validation_method: Literal[DcvValidationMethod.CONTACT_EMAIL_TXT] = DcvValidationMethod.CONTACT_EMAIL_TXT diff --git a/src/open_mpic_core/common_domain/messages/ErrorMessages.py b/src/open_mpic_core/common_domain/messages/ErrorMessages.py index cdac874..f161d11 100644 --- a/src/open_mpic_core/common_domain/messages/ErrorMessages.py +++ b/src/open_mpic_core/common_domain/messages/ErrorMessages.py @@ -4,7 +4,7 @@ class ErrorMessages(Enum): CAA_LOOKUP_ERROR = ('mpic_error:caa_checker:lookup', 'There was an error looking up the CAA record: {0}') DCV_LOOKUP_ERROR = ('mpic_error:dcv_checker:lookup', 'There was an error looking up the DCV record. Error type: {0}, Error message: {1}') - DCV_PARAMETER_ERROR = ('mpic_error:dcv_checker:parameter:key_authorization_hash', 'The provided key_authorization_hash contained invalid characters: {0}') + DCV_PARAMETER_ERROR = ('mpic_error:dcv_checker:parameter', 'The provided parameter was invalid: {0}, Value provided: {1}') COORDINATOR_COMMUNICATION_ERROR = ('mpic_error:coordinator:communication', 'Communication with the remote perspective failed.') COORDINATOR_REMOTE_CHECK_ERROR = ('mpic_error:coordinator:remote_check', 'The remote check failed to complete: {0}') TOO_MANY_FAILED_PERSPECTIVES_ERROR = ('mpic_error:coordinator:too_many_failed_perspectives', 'Too many perspectives failed to complete the check.') diff --git a/src/open_mpic_core/mpic_dcv_checker/dcv_tls_alpn_validator.py b/src/open_mpic_core/mpic_dcv_checker/dcv_tls_alpn_validator.py index 90c90af..233144a 100644 --- a/src/open_mpic_core/mpic_dcv_checker/dcv_tls_alpn_validator.py +++ b/src/open_mpic_core/mpic_dcv_checker/dcv_tls_alpn_validator.py @@ -79,9 +79,16 @@ async def perform_tls_alpn_validation(self, request: DcvCheckRequest) -> DcvChec context.minimum_version = TLSVersion.TLSv1_2 with getpeercert_with_binary_info(): # monkeypatch overrides default behavior and gets binary cert info reader, writer = await asyncio.open_connection( - hostname, 443, ssl=context, server_hostname=sni_target # use the real host name # pass in the context. + hostname, + 443, + ssl=context, + server_hostname=sni_target, # use the real host name # pass in the context. ) - binary_cert = writer.get_extra_info("peercert") + try: + binary_cert = writer.get_extra_info("peercert") + finally: + writer.close() + await writer.wait_closed() dcv_check_response.check_completed = True # check will be considered "complete" whether it passes or fails @@ -89,7 +96,7 @@ async def perform_tls_alpn_validation(self, request: DcvCheckRequest) -> DcvChec subject_alt_name_extension = None acme_tls_alpn_extension = None - + for extension in x509_cert.extensions: if extension.oid.dotted_string == self.ACME_TLS_ALPN_OID_DOTTED_STRING: acme_tls_alpn_extension = extension @@ -120,24 +127,28 @@ async def perform_tls_alpn_validation(self, request: DcvCheckRequest) -> DcvChec + key_authorization_hash_binary ) self.logger.debug(f"tls-alpn-01: binary_challenge_seen: {binary_challenge_seen}") - self.logger.debug(f"tls-alpn-01: key_authorization_hash_binary: {key_authorization_hash_binary}") - + self.logger.debug( + f"tls-alpn-01: key_authorization_hash_binary: {key_authorization_hash_binary}" + ) + except ValueError: dcv_check_response.errors = [ - MpicValidationError.create(ErrorMessages.DCV_PARAMETER_ERROR, key_authorization_hash) + MpicValidationError.create( + ErrorMessages.DCV_PARAMETER_ERROR, "key_authorization_hash", key_authorization_hash + ) ] else: # Only assign the check_passed attribute if we properly parsed the challenge. dcv_check_response.check_passed = binary_challenge_seen == key_authorization_hash_binary - + # Obtain the certs common name for logging. common_name_attributes = x509_cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME) common_name = None if len(common_name_attributes) > 0: common_name = str(common_name_attributes[0].value) self.logger.debug(f"common name: {common_name}") - dcv_check_response.details.common_name = common_name # Cert common name for logging info. - + dcv_check_response.details.common_name = common_name # Cert common name for logging info. + self.logger.debug(f"tls-alpn-01: passed? {dcv_check_response.check_passed}") dcv_check_response.timestamp_ns = time.time_ns() except asyncio.TimeoutError as e: @@ -169,7 +180,9 @@ def _validate_san_entry( if type(san_target) == str: if not isinstance(single_san_name, x509.general_name.DNSName): errors = [MpicValidationError.create(ErrorMessages.TLS_ALPN_ERROR_CERTIFICATE_SAN_NOT_DNSNAME)] - elif single_san_name.value.lower() != san_target.lower(): # Comparison is case insensitive per RFC4343 rules. + elif ( + single_san_name.value.lower() != san_target.lower() + ): # Comparison is case insensitive per RFC4343 rules. errors = [MpicValidationError.create(ErrorMessages.TLS_ALPN_ERROR_CERTIFICATE_SAN_NOT_HOSTNAME)] else: if not isinstance(single_san_name, x509.general_name.IPAddress): diff --git a/src/open_mpic_core/mpic_dcv_checker/mpic_dcv_checker.py b/src/open_mpic_core/mpic_dcv_checker/mpic_dcv_checker.py index d18ffee..24426c4 100644 --- a/src/open_mpic_core/mpic_dcv_checker/mpic_dcv_checker.py +++ b/src/open_mpic_core/mpic_dcv_checker/mpic_dcv_checker.py @@ -443,21 +443,35 @@ def evaluate_persistent_dns_response( # Look for required accounturi parameter valid_account_uri = False within_allowed_time = True # Assume valid unless proven otherwise + well_formed_record = True + found_accounturi_param = False + found_persistuntil_param = False if not (len(param_list) == 1 and param_list[0].strip() == ""): # if actual parameters follow the semicolon for parameter in param_list: name_and_value = parameter.split("=", 1) if len(name_and_value) != 2: + well_formed_record = False break # malformed parameter; skip to next record param_name = name_and_value[0].strip().lower() param_value = name_and_value[1].strip() if param_name == "accounturi": + if found_accounturi_param: # check if duplicate param; if so, mark as malformed and skip + well_formed_record = False + break + else: + found_accounturi_param = True if param_value.lower() == expected_account_uri: valid_account_uri = True else: break # accounturi does not match; skip to next record elif param_name == "persistuntil": + if found_persistuntil_param: # check if duplicate param; if so, mark as malformed and skip + well_formed_record = False + break + else: + found_persistuntil_param = True try: persist_until_in_seconds = int(param_value) # seconds since epoch current_seconds = int(time.time()) @@ -470,7 +484,7 @@ def evaluate_persistent_dns_response( # Additional parameters are ignored per CA/Browser Forum spec # Record is valid if issuer matches, account URI matches, and not expired - if valid_account_uri and within_allowed_time: + if valid_account_uri and within_allowed_time and well_formed_record: found_valid_record = True return found_valid_record diff --git a/tests/unit/open_mpic_core/test_check_request_parameters.py b/tests/unit/open_mpic_core/test_check_request_parameters.py index 5f61b30..9b73c9d 100644 --- a/tests/unit/open_mpic_core/test_check_request_parameters.py +++ b/tests/unit/open_mpic_core/test_check_request_parameters.py @@ -64,6 +64,10 @@ def check_request_parameters__should_automatically_deserialize_into_correct_obje "should fail validation when required issuer_domain_names is missing for DNS Persistent"), ('{"validation_method": "dns-persistent", "issuer_domain_names": ["authority.example"]}', "should fail validation when required expected_account_uri is missing for DNS Persistent"), + ('{"validation_method": "dns-persistent", "issuer_domain_names": ["authority.example"], "expected_account_uri": "not-a-valid-uri"}', + "should fail validation when expected_account_uri is not a valid URI for DNS Persistent"), + ('{"validation_method": "dns-persistent", "issuer_domain_names": [""], "expected_account_uri": "https://authority.example/acct/123"}', + "should fail validation when issuer_domain_names contains an empty string for DNS Persistent"), ]) # fmt: on def check_request_parameters__should_fail_validation_when_serialized_object_is_malformed( diff --git a/tests/unit/open_mpic_core/test_dcv_tls_alpn_validator.py b/tests/unit/open_mpic_core/test_dcv_tls_alpn_validator.py index 80b71c6..700dbb4 100644 --- a/tests/unit/open_mpic_core/test_dcv_tls_alpn_validator.py +++ b/tests/unit/open_mpic_core/test_dcv_tls_alpn_validator.py @@ -3,7 +3,7 @@ import pytest import ssl import socket -from unittest.mock import MagicMock +from unittest.mock import MagicMock, AsyncMock from io import StringIO from cryptography import x509 from cryptography.x509 import SubjectAlternativeName, Extension, NameAttribute @@ -144,7 +144,10 @@ async def perform_tls_alpn_validation__should_fail_given_noncritical_alpn_extens assert response.check_completed is True assert response.check_passed is False assert len(response.errors) > 0 - assert response.errors[0].error_message == ErrorMessages.TLS_ALPN_ERROR_CERTIFICATE_ALPN_EXTENSION_NONCRITICAL.message + assert ( + response.errors[0].error_message + == ErrorMessages.TLS_ALPN_ERROR_CERTIFICATE_ALPN_EXTENSION_NONCRITICAL.message + ) async def perform_tls_alpn_validation__should_fail_given_invalid_key_authorization_hash(self, mocker): dcv_request = ValidCheckCreator.create_valid_acme_tls_alpn_01_check_request() @@ -152,18 +155,20 @@ async def perform_tls_alpn_validation__should_fail_given_invalid_key_authorizati key_authorization_hash = dcv_request.dcv_check_parameters.key_authorization_hash mock_cert = self._create_mock_certificate(hostname, key_authorization_hash) # Modify the mock certificate to have a different key authorization hash - dcv_request.dcv_check_parameters.key_authorization_hash = 'invalid_hash_value' + dcv_request.dcv_check_parameters.key_authorization_hash = "invalid_hash_value" self._mock_socket_and_ssl_context(mocker, mock_cert) response = await self.validator.perform_tls_alpn_validation(dcv_request) assert response.check_completed is True assert response.check_passed is False - expected_error = ErrorMessages.DCV_PARAMETER_ERROR.message.format("invalid_hash_value") + expected_error = ErrorMessages.DCV_PARAMETER_ERROR.message.format( + "key_authorization_hash", "invalid_hash_value" + ) assert response.errors[0].error_message == expected_error async def perform_tls_alpn_validation__should_handle_connection_errors(self, mocker): dcv_request = ValidCheckCreator.create_valid_acme_tls_alpn_01_check_request() # Mock asyncio to raise an exception (haven't checked if this specific exception could be raised, but whatever) - mocker.patch('asyncio.open_connection', side_effect=socket.timeout("Connection timed out")) + mocker.patch("asyncio.open_connection", side_effect=socket.timeout("Connection timed out")) response = await self.validator.perform_tls_alpn_validation(dcv_request) assert response.check_completed is False assert response.check_passed is False @@ -172,7 +177,7 @@ async def perform_tls_alpn_validation__should_handle_connection_errors(self, moc async def perform_tls_alpn_validation__should_handle_ssl_errors(self, mocker): dcv_request = ValidCheckCreator.create_valid_acme_tls_alpn_01_check_request() # Mock SSL context to raise an exception - mocker.patch('ssl.create_default_context', side_effect=ssl.SSLError("SSL error")) + mocker.patch("ssl.create_default_context", side_effect=ssl.SSLError("SSL error")) response = await self.validator.perform_tls_alpn_validation(dcv_request) assert response.check_completed is False assert response.check_passed is False @@ -230,15 +235,15 @@ def _create_mock_certificate_with_nonmatching_san_entry(self, hostname, key_auth invalid_san = x509.general_name.DNSName("invalid.example.com") san_extension.value._general_names = [invalid_san] return mock_cert - + def _create_mock_certificate_with_ip_san(self, ipstring, key_authorization_hash): - mock_cert = self._create_mock_certificate('foo.baa', key_authorization_hash) + mock_cert = self._create_mock_certificate("foo.baa", key_authorization_hash) # Create a SAN entry that does not match the hostname san_extension = mock_cert.extensions[0] ip_san = x509.general_name.IPAddress(ipaddress.ip_address(ipstring)) san_extension.value._general_names = [ip_san] return mock_cert - + def _create_mock_certificate_with_noncritical_alpn_extension(self, hostname, key_authorization_hash): mock_cert = self._create_mock_certificate(hostname, key_authorization_hash) alpn_extension = mock_cert.extensions[1] @@ -249,19 +254,12 @@ def _mock_socket_and_ssl_context(self, mocker, mock_cert): # Mock socket.create_connection mock_writer = MagicMock() mock_reader = MagicMock() - mocker.patch('asyncio.open_connection', return_value=(mock_reader, mock_writer)) + mocker.patch("asyncio.open_connection", return_value=(mock_reader, mock_writer)) mock_writer.get_extra_info.return_value = mock_cert - # Mock SSL context and wrapped socket - #mock_ssl_socket = MagicMock() - #mock_ssl_socket.getpeercert.return_value = b'mock_binary_cert' - - ## Mock SSL context's wrap_socket method - #mock_context = MagicMock() - #mock_context.wrap_socket.return_value.__enter__.return_value = mock_ssl_socket - #mocker.patch('ssl.create_default_context', return_value=mock_context) + mock_writer.wait_closed = AsyncMock() # Make wait_closed an async mock # Mock x509.load_der_x509_certificate to return our mock certificate - mocker.patch('cryptography.x509.load_der_x509_certificate', return_value=mock_cert) + mocker.patch("cryptography.x509.load_der_x509_certificate", return_value=mock_cert) return mock_writer diff --git a/tests/unit/open_mpic_core/test_mpic_dcv_checker.py b/tests/unit/open_mpic_core/test_mpic_dcv_checker.py index 244a3aa..b0c7978 100644 --- a/tests/unit/open_mpic_core/test_mpic_dcv_checker.py +++ b/tests/unit/open_mpic_core/test_mpic_dcv_checker.py @@ -726,9 +726,7 @@ async def dns_based_dcv_checks__should_not_pass_given_non_matching_dns_record(se assert dcv_response.check_passed is False @pytest.mark.parametrize("set_persist_until_parameter", [True, False]) - def evaluate_persistent_dns_response__should_return_true_given_valid_record( - self, set_persist_until_parameter - ): + def evaluate_persistent_dns_response__should_return_true_given_valid_record(self, set_persist_until_parameter): issuer_domain_name = "ca.example.com" expected_account_uri = "https://ca.example.com/acct/123" @@ -749,7 +747,7 @@ def evaluate_persistent_dns_response__should_return_true_given_valid_record( def evaluate_persistent_dns_response__should_be_case_insensitive(self): issuer_domain_name = "cA.EXaMPle.com" expected_account_uri = "https://cA.EXaMPle.com/acct/123" - records = [f"{issuer_domain_name}; accounturi={expected_account_uri}"] + records = [f"{issuer_domain_name}; acCoUntUrI={expected_account_uri}"] expected_dns_record_content = ExpectedDnsRecordContent( possible_values=[issuer_domain_name.lower()], @@ -873,12 +871,17 @@ def evaluate_persistent_dns_response__should_accept_match_for_any_issuer_in_the_ def evaluate_persistent_dns_response__should_return_false_given_malformed_record(self): issuer_domain_names = ["ca.example"] expected_account_uri = "https://ca.example/acct/123" + time_now = int(time.time()) malformed_records = [ ";;;", # Only separators "ca.example", # Missing parameters "ca.example;", # Parameter separator but no parameters "ca.example; =value", # Missing parameter name "ca.example; accounturi", # Missing value + "ca.example; accounturi=https://ca.example/acct/123; badparam", # malformed param after expected param + "ca.example; badparam; accountURI=https://ca.example/acct/123", # malformed param before expected param + "ca.example; accounturi=https://ca.example/acct/234; accounturi=https://example/acct/123", # duplicate param + f"ca.example; accounturi=https://ca.example/acct/123; persistuntil={time_now+10}; persistuntil={time_now+20}", # duplicate persistUntil param ] expected_dns_record_content = ExpectedDnsRecordContent( @@ -1129,9 +1132,7 @@ async def side_effect(qname, rdtype): return self.patch_resolver_resolve_with_side_effect(mocker, self.dcv_checker.resolver, side_effect) - def _mock_dns_resolve_call_with_specific_response_code( - self, dcv_request: DcvCheckRequest, response_code, mocker - ): + def _mock_dns_resolve_call_with_specific_response_code(self, dcv_request: DcvCheckRequest, response_code, mocker): test_dns_query_answer = self._create_basic_dns_response_for_mock(dcv_request, mocker) test_dns_query_answer.response.rcode = lambda: response_code @@ -1194,7 +1195,7 @@ def _create_basic_dns_response_for_mock(self, dcv_request: DcvCheckRequest, mock case DcvValidationMethod.DNS_PERSISTENT: issuer_domain = check_parameters.issuer_domain_names[0] account_uri = check_parameters.expected_account_uri - persist_until = int(time.time()) + 365*24*60*60 # 1 year from now + persist_until = int(time.time()) + 365 * 24 * 60 * 60 # 1 year from now persistent_value = f"{issuer_domain}; accounturi={account_uri}; persistUntil={persist_until}" record_data = {"value": persistent_value} case DcvValidationMethod.CONTACT_EMAIL_CAA: @@ -1234,9 +1235,7 @@ def _mock_successful_tls_alpn_validation_entirely(self, dcv_request, mocker): details=DcvCheckResponseDetailsBuilder.build_response_details(DcvValidationMethod.ACME_TLS_ALPN_01), ) response.details.common_name = dcv_request.domain_or_ip_target - mocker.patch.object( - DcvTlsAlpnValidator, "perform_tls_alpn_validation", return_value=response - ) + mocker.patch.object(DcvTlsAlpnValidator, "perform_tls_alpn_validation", return_value=response) # fmt: off @pytest.mark.parametrize("input_target, expected_output", [ @@ -1262,7 +1261,7 @@ def shuffle_case(string_to_shuffle: str) -> str: if result.islower() or result.isupper(): for i, char in enumerate(result): if char.isalpha(): - result = result[:i] + char.swapcase() + result[i + 1:] + result = result[:i] + char.swapcase() + result[i + 1 :] break return result