Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/open_mpic_core/__about__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "6.3.0"
__version__ = "6.3.1"
21 changes: 21 additions & 0 deletions src/open_mpic_core/common_domain/check_parameters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC
from typing import Literal, Union, Any, Set, Annotated
from urllib.parse import urlparse

from pydantic import BaseModel, field_validator, Field

Expand Down Expand Up @@ -61,6 +62,26 @@ class DcvDnsPersistentValidationParameters(DcvValidationParameters):
issuer_domain_names: list[str] # Disclosed issuer domain names from CA's CP/CPS
expected_account_uri: str # The specific account URI to validate

# expected_account_uri should be a URI with a scheme and host (e.g. "https://example.com/acct/123")
@field_validator("expected_account_uri")
@classmethod
def validate_account_uri(cls, v: str) -> str:
parsed_account_uri = urlparse(v)
if not parsed_account_uri.scheme or not parsed_account_uri.netloc:
raise ValueError(f"expected_account_uri must be a valid URI with scheme and host, got {v}")
return v

@field_validator("issuer_domain_names")
@classmethod
def validate_issuer_domain_names(cls, v: list[str]) -> list[str]:
if not v:
raise ValueError("issuer_domain_names must be a non-empty list of domain names")
for domain in v:
# check that v is non-empty
if not domain:
raise ValueError("issuer_domain_names must not contain empty strings")
return v


class DcvContactEmailTxtValidationParameters(DcvGeneralDnsValidationParameters):
validation_method: Literal[DcvValidationMethod.CONTACT_EMAIL_TXT] = DcvValidationMethod.CONTACT_EMAIL_TXT
Expand Down
2 changes: 1 addition & 1 deletion src/open_mpic_core/common_domain/messages/ErrorMessages.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
class ErrorMessages(Enum):
CAA_LOOKUP_ERROR = ('mpic_error:caa_checker:lookup', 'There was an error looking up the CAA record: {0}')
DCV_LOOKUP_ERROR = ('mpic_error:dcv_checker:lookup', 'There was an error looking up the DCV record. Error type: {0}, Error message: {1}')
DCV_PARAMETER_ERROR = ('mpic_error:dcv_checker:parameter:key_authorization_hash', 'The provided key_authorization_hash contained invalid characters: {0}')
DCV_PARAMETER_ERROR = ('mpic_error:dcv_checker:parameter', 'The provided parameter was invalid: {0}, Value provided: {1}')
COORDINATOR_COMMUNICATION_ERROR = ('mpic_error:coordinator:communication', 'Communication with the remote perspective failed.')
COORDINATOR_REMOTE_CHECK_ERROR = ('mpic_error:coordinator:remote_check', 'The remote check failed to complete: {0}')
TOO_MANY_FAILED_PERSPECTIVES_ERROR = ('mpic_error:coordinator:too_many_failed_perspectives', 'Too many perspectives failed to complete the check.')
Expand Down
33 changes: 23 additions & 10 deletions src/open_mpic_core/mpic_dcv_checker/dcv_tls_alpn_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,24 @@ async def perform_tls_alpn_validation(self, request: DcvCheckRequest) -> DcvChec
context.minimum_version = TLSVersion.TLSv1_2
with getpeercert_with_binary_info(): # monkeypatch overrides default behavior and gets binary cert info
reader, writer = await asyncio.open_connection(
hostname, 443, ssl=context, server_hostname=sni_target # use the real host name # pass in the context.
hostname,
443,
ssl=context,
server_hostname=sni_target, # use the real host name # pass in the context.
)
binary_cert = writer.get_extra_info("peercert")
try:
binary_cert = writer.get_extra_info("peercert")
finally:
writer.close()
await writer.wait_closed()

dcv_check_response.check_completed = True # check will be considered "complete" whether it passes or fails

x509_cert = x509.load_der_x509_certificate(binary_cert)

subject_alt_name_extension = None
acme_tls_alpn_extension = None

for extension in x509_cert.extensions:
if extension.oid.dotted_string == self.ACME_TLS_ALPN_OID_DOTTED_STRING:
acme_tls_alpn_extension = extension
Expand Down Expand Up @@ -120,24 +127,28 @@ async def perform_tls_alpn_validation(self, request: DcvCheckRequest) -> DcvChec
+ key_authorization_hash_binary
)
self.logger.debug(f"tls-alpn-01: binary_challenge_seen: {binary_challenge_seen}")
self.logger.debug(f"tls-alpn-01: key_authorization_hash_binary: {key_authorization_hash_binary}")

self.logger.debug(
f"tls-alpn-01: key_authorization_hash_binary: {key_authorization_hash_binary}"
)

except ValueError:
dcv_check_response.errors = [
MpicValidationError.create(ErrorMessages.DCV_PARAMETER_ERROR, key_authorization_hash)
MpicValidationError.create(
ErrorMessages.DCV_PARAMETER_ERROR, "key_authorization_hash", key_authorization_hash
)
]
else:
# Only assign the check_passed attribute if we properly parsed the challenge.
dcv_check_response.check_passed = binary_challenge_seen == key_authorization_hash_binary

# Obtain the certs common name for logging.
common_name_attributes = x509_cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME)
common_name = None
if len(common_name_attributes) > 0:
common_name = str(common_name_attributes[0].value)
self.logger.debug(f"common name: {common_name}")
dcv_check_response.details.common_name = common_name # Cert common name for logging info.
dcv_check_response.details.common_name = common_name # Cert common name for logging info.

self.logger.debug(f"tls-alpn-01: passed? {dcv_check_response.check_passed}")
dcv_check_response.timestamp_ns = time.time_ns()
except asyncio.TimeoutError as e:
Expand Down Expand Up @@ -169,7 +180,9 @@ def _validate_san_entry(
if type(san_target) == str:
if not isinstance(single_san_name, x509.general_name.DNSName):
errors = [MpicValidationError.create(ErrorMessages.TLS_ALPN_ERROR_CERTIFICATE_SAN_NOT_DNSNAME)]
elif single_san_name.value.lower() != san_target.lower(): # Comparison is case insensitive per RFC4343 rules.
elif (
single_san_name.value.lower() != san_target.lower()
): # Comparison is case insensitive per RFC4343 rules.
errors = [MpicValidationError.create(ErrorMessages.TLS_ALPN_ERROR_CERTIFICATE_SAN_NOT_HOSTNAME)]
else:
if not isinstance(single_san_name, x509.general_name.IPAddress):
Expand Down
16 changes: 15 additions & 1 deletion src/open_mpic_core/mpic_dcv_checker/mpic_dcv_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,21 +443,35 @@ def evaluate_persistent_dns_response(
# Look for required accounturi parameter
valid_account_uri = False
within_allowed_time = True # Assume valid unless proven otherwise
well_formed_record = True
found_accounturi_param = False
found_persistuntil_param = False

if not (len(param_list) == 1 and param_list[0].strip() == ""): # if actual parameters follow the semicolon
for parameter in param_list:
name_and_value = parameter.split("=", 1)
if len(name_and_value) != 2:
well_formed_record = False
break # malformed parameter; skip to next record
param_name = name_and_value[0].strip().lower()
param_value = name_and_value[1].strip()

if param_name == "accounturi":
if found_accounturi_param: # check if duplicate param; if so, mark as malformed and skip
well_formed_record = False
break
else:
found_accounturi_param = True
if param_value.lower() == expected_account_uri:
valid_account_uri = True
else:
break # accounturi does not match; skip to next record
elif param_name == "persistuntil":
if found_persistuntil_param: # check if duplicate param; if so, mark as malformed and skip
well_formed_record = False
break
else:
found_persistuntil_param = True
try:
persist_until_in_seconds = int(param_value) # seconds since epoch
current_seconds = int(time.time())
Expand All @@ -470,7 +484,7 @@ def evaluate_persistent_dns_response(
# Additional parameters are ignored per CA/Browser Forum spec

# Record is valid if issuer matches, account URI matches, and not expired
if valid_account_uri and within_allowed_time:
if valid_account_uri and within_allowed_time and well_formed_record:
found_valid_record = True

return found_valid_record
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/open_mpic_core/test_check_request_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def check_request_parameters__should_automatically_deserialize_into_correct_obje
"should fail validation when required issuer_domain_names is missing for DNS Persistent"),
('{"validation_method": "dns-persistent", "issuer_domain_names": ["authority.example"]}',
"should fail validation when required expected_account_uri is missing for DNS Persistent"),
('{"validation_method": "dns-persistent", "issuer_domain_names": ["authority.example"], "expected_account_uri": "not-a-valid-uri"}',
"should fail validation when expected_account_uri is not a valid URI for DNS Persistent"),
('{"validation_method": "dns-persistent", "issuer_domain_names": [""], "expected_account_uri": "https://authority.example/acct/123"}',
"should fail validation when issuer_domain_names contains an empty string for DNS Persistent"),
])
# fmt: on
def check_request_parameters__should_fail_validation_when_serialized_object_is_malformed(
Expand Down
36 changes: 17 additions & 19 deletions tests/unit/open_mpic_core/test_dcv_tls_alpn_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
import ssl
import socket
from unittest.mock import MagicMock
from unittest.mock import MagicMock, AsyncMock
from io import StringIO
from cryptography import x509
from cryptography.x509 import SubjectAlternativeName, Extension, NameAttribute
Expand Down Expand Up @@ -144,26 +144,31 @@ async def perform_tls_alpn_validation__should_fail_given_noncritical_alpn_extens
assert response.check_completed is True
assert response.check_passed is False
assert len(response.errors) > 0
assert response.errors[0].error_message == ErrorMessages.TLS_ALPN_ERROR_CERTIFICATE_ALPN_EXTENSION_NONCRITICAL.message
assert (
response.errors[0].error_message
== ErrorMessages.TLS_ALPN_ERROR_CERTIFICATE_ALPN_EXTENSION_NONCRITICAL.message
)

async def perform_tls_alpn_validation__should_fail_given_invalid_key_authorization_hash(self, mocker):
dcv_request = ValidCheckCreator.create_valid_acme_tls_alpn_01_check_request()
hostname = dcv_request.domain_or_ip_target
key_authorization_hash = dcv_request.dcv_check_parameters.key_authorization_hash
mock_cert = self._create_mock_certificate(hostname, key_authorization_hash)
# Modify the mock certificate to have a different key authorization hash
dcv_request.dcv_check_parameters.key_authorization_hash = 'invalid_hash_value'
dcv_request.dcv_check_parameters.key_authorization_hash = "invalid_hash_value"
self._mock_socket_and_ssl_context(mocker, mock_cert)
response = await self.validator.perform_tls_alpn_validation(dcv_request)
assert response.check_completed is True
assert response.check_passed is False
expected_error = ErrorMessages.DCV_PARAMETER_ERROR.message.format("invalid_hash_value")
expected_error = ErrorMessages.DCV_PARAMETER_ERROR.message.format(
"key_authorization_hash", "invalid_hash_value"
)
assert response.errors[0].error_message == expected_error

async def perform_tls_alpn_validation__should_handle_connection_errors(self, mocker):
dcv_request = ValidCheckCreator.create_valid_acme_tls_alpn_01_check_request()
# Mock asyncio to raise an exception (haven't checked if this specific exception could be raised, but whatever)
mocker.patch('asyncio.open_connection', side_effect=socket.timeout("Connection timed out"))
mocker.patch("asyncio.open_connection", side_effect=socket.timeout("Connection timed out"))
response = await self.validator.perform_tls_alpn_validation(dcv_request)
assert response.check_completed is False
assert response.check_passed is False
Expand All @@ -172,7 +177,7 @@ async def perform_tls_alpn_validation__should_handle_connection_errors(self, moc
async def perform_tls_alpn_validation__should_handle_ssl_errors(self, mocker):
dcv_request = ValidCheckCreator.create_valid_acme_tls_alpn_01_check_request()
# Mock SSL context to raise an exception
mocker.patch('ssl.create_default_context', side_effect=ssl.SSLError("SSL error"))
mocker.patch("ssl.create_default_context", side_effect=ssl.SSLError("SSL error"))
response = await self.validator.perform_tls_alpn_validation(dcv_request)
assert response.check_completed is False
assert response.check_passed is False
Expand Down Expand Up @@ -230,15 +235,15 @@ def _create_mock_certificate_with_nonmatching_san_entry(self, hostname, key_auth
invalid_san = x509.general_name.DNSName("invalid.example.com")
san_extension.value._general_names = [invalid_san]
return mock_cert

def _create_mock_certificate_with_ip_san(self, ipstring, key_authorization_hash):
mock_cert = self._create_mock_certificate('foo.baa', key_authorization_hash)
mock_cert = self._create_mock_certificate("foo.baa", key_authorization_hash)
# Create a SAN entry that does not match the hostname
san_extension = mock_cert.extensions[0]
ip_san = x509.general_name.IPAddress(ipaddress.ip_address(ipstring))
san_extension.value._general_names = [ip_san]
return mock_cert

def _create_mock_certificate_with_noncritical_alpn_extension(self, hostname, key_authorization_hash):
mock_cert = self._create_mock_certificate(hostname, key_authorization_hash)
alpn_extension = mock_cert.extensions[1]
Expand All @@ -249,19 +254,12 @@ def _mock_socket_and_ssl_context(self, mocker, mock_cert):
# Mock socket.create_connection
mock_writer = MagicMock()
mock_reader = MagicMock()
mocker.patch('asyncio.open_connection', return_value=(mock_reader, mock_writer))
mocker.patch("asyncio.open_connection", return_value=(mock_reader, mock_writer))

mock_writer.get_extra_info.return_value = mock_cert
# Mock SSL context and wrapped socket
#mock_ssl_socket = MagicMock()
#mock_ssl_socket.getpeercert.return_value = b'mock_binary_cert'

## Mock SSL context's wrap_socket method
#mock_context = MagicMock()
#mock_context.wrap_socket.return_value.__enter__.return_value = mock_ssl_socket
#mocker.patch('ssl.create_default_context', return_value=mock_context)
mock_writer.wait_closed = AsyncMock() # Make wait_closed an async mock

# Mock x509.load_der_x509_certificate to return our mock certificate
mocker.patch('cryptography.x509.load_der_x509_certificate', return_value=mock_cert)
mocker.patch("cryptography.x509.load_der_x509_certificate", return_value=mock_cert)

return mock_writer
Loading