2 changes: 1 addition & 1 deletion pyproject.toml
@@ -27,7 +27,7 @@ dependencies = [
"requests==2.32.4",
"dnspython==2.7.0",
"pydantic==2.11.7",
"aiohttp==3.12.13",
"aiohttp==3.12.14",
"black==25.1.0",
"cryptography==45.0.4",
]
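The change above tightens the aiohttp pin from 3.12.13 to 3.12.14 (an exact `==` pin, so the environment must resolve that precise release). Below is a minimal sketch of a start-up guard that confirms the installed package matches the pin; it is illustrative only, not part of this PR, and the `EXPECTED_AIOHTTP` constant is a hypothetical name.

```python
# Hypothetical start-up check (not in this PR): fail fast if the environment
# did not resolve the exact aiohttp release pinned in pyproject.toml.
from importlib.metadata import version

EXPECTED_AIOHTTP = "3.12.14"  # mirrors the pin in pyproject.toml

installed = version("aiohttp")
if installed != EXPECTED_AIOHTTP:
    raise RuntimeError(f"Expected aiohttp=={EXPECTED_AIOHTTP}, found {installed}")
```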
2 changes: 1 addition & 1 deletion src/open_mpic_core/__about__.py
@@ -1 +1 @@
__version__ = "5.10.0"
__version__ = "6.0.0"
6 changes: 4 additions & 2 deletions src/open_mpic_core/mpic_caa_checker/mpic_caa_checker.py
@@ -59,7 +59,9 @@ async def find_caa_records_and_domain(self, caa_request) -> tuple[RRset, Name]:
except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN):
domain = domain.parent()
except Exception as e:
self.logger.error(f"Exception during CAA lookup for {caa_request.domain_or_ip_target}: {e}. Trace identifier: {caa_request.trace_identifier}")
self.logger.error(
f"Exception during CAA lookup for {caa_request.domain_or_ip_target}: {e}. Trace ID: {caa_request.trace_identifier}"
)
raise MpicCaaLookupException(f"{e}") from e

return rrset, domain
@@ -105,7 +107,7 @@ async def check_caa(self, caa_request: CaaCheckRequest) -> CaaCheckResponse:
caa_found = rrset is not None
except (MpicCaaLookupException, ValueError) as e:
caa_lookup_error = True
error_message = f"Error during CAA lookup for {caa_request.domain_or_ip_target}: {e}. Trace identifier: {caa_request.trace_identifier}"
error_message = f"Error during CAA lookup for {caa_request.domain_or_ip_target}: {e}. Trace ID: {caa_request.trace_identifier}"
caa_check_response.errors = [MpicValidationError.create(ErrorMessages.CAA_LOOKUP_ERROR, error_message)]
caa_check_response.details.found_at = None
caa_check_response.details.records_seen = None
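The CAA checker hunks reformat the error log to include the request's trace identifier and keep the `raise MpicCaaLookupException(f"{e}") from e` chaining, which leaves the original resolver error reachable on `__cause__`. A minimal sketch of how a caller might use that chain follows; it assumes `MpicCaaLookupException` is importable and that `checker` and `caa_request` are an already-configured checker and request.

```python
# Sketch: consuming the chained exception raised by find_caa_records_and_domain.
# MpicCaaLookupException, checker, and caa_request are assumed to be in scope.
import dns.resolver

async def lookup_caa_or_classify_failure(checker, caa_request):
    try:
        return await checker.find_caa_records_and_domain(caa_request)
    except MpicCaaLookupException as e:
        # "raise ... from e" preserves the underlying resolver error on __cause__,
        # so callers can distinguish, e.g., resolver timeouts from other failures.
        if isinstance(e.__cause__, dns.resolver.LifetimeTimeout):
            ...  # e.g. surface as a transient lookup failure
        raise
```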
12 changes: 8 additions & 4 deletions src/open_mpic_core/mpic_coordinator/mpic_coordinator.py
@@ -214,14 +214,16 @@ async def call_remote_perspective(
"""
try:
# noinspection PyUnresolvedReferences
async with self.logger.trace_timing(f"MPIC round-trip with perspective {call_config.perspective.code}"):
async with self.logger.trace_timing(
f"MPIC round-trip with perspective {call_config.perspective.code}; trace ID: {call_config.check_request.trace_identifier}"
):
response = await call_remote_perspective_function(
call_config.perspective, call_config.check_type, call_config.check_request
)
except Exception as exc:
error_message = str(exc) if str(exc) else exc.__class__.__name__
raise RemoteCheckException(
f"Check failed for perspective {call_config.perspective.code}, target {call_config.check_request.domain_or_ip_target}: {error_message}",
f"Check failed for perspective {call_config.perspective.code}, target {call_config.check_request.domain_or_ip_target}: {error_message}; trace ID: {call_config.check_request.trace_identifier}",
call_config=call_config,
) from exc
return PerspectiveResponse(perspective_code=call_config.perspective.code, check_response=response)
@@ -270,7 +272,9 @@ async def call_checkers_and_collect_responses(
]

# noinspection PyUnresolvedReferences
async with self.logger.trace_timing(f"MPIC round-trip with {len(perspectives_to_use)} perspectives"):
async with self.logger.trace_timing(
f"MPIC round-trip with {len(perspectives_to_use)} perspectives; trace ID: {mpic_request.trace_identifier}"
):
responses = await asyncio.gather(*tasks, return_exceptions=True)

for response in responses:
@@ -279,7 +283,7 @@
# (trying to handle other Exceptions should be unreachable code)
if isinstance(response, RemoteCheckException):
response_as_string = str(response)
log_msg = f"{response_as_string} - Trace identifier: {mpic_request.trace_identifier}"
log_msg = f"{response_as_string} - trace ID: {mpic_request.trace_identifier}"
logger.warning(log_msg)
error_response = MpicCoordinator.build_error_perspective_response_from_exception(response)
perspective_responses.append(error_response)
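Both coordinator hunks thread the trace identifier into the `trace_timing` log lines that bracket the remote perspective calls. The helper itself is not shown in this diff; as a rough sketch under that caveat, an async timing context manager of that shape could look like the following, where `TRACE_LEVEL` is an assumed custom level below DEBUG.

```python
# Sketch only: one plausible shape for the trace_timing helper used above.
# The real implementation lives elsewhere in open_mpic_core and may differ.
import logging
import time
from contextlib import asynccontextmanager

TRACE_LEVEL = 5  # assumption: a custom level registered below DEBUG

@asynccontextmanager
async def trace_timing(logger: logging.Logger, label: str):
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed_ms = (time.perf_counter() - start) * 1000
        logger.log(TRACE_LEVEL, "%s took %.1f ms", label, elapsed_ms)
```

Used roughly as `async with trace_timing(self.logger, f"MPIC round-trip with {n} perspectives; trace ID: {trace_identifier}"):`, which matches the pattern in the hunks above.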
103 changes: 40 additions & 63 deletions src/open_mpic_core/mpic_dcv_checker/mpic_dcv_checker.py
@@ -31,92 +31,57 @@ class MpicDcvChecker:
WELL_KNOWN_ACME_PATH = ".well-known/acme-challenge"
CONTACT_EMAIL_TAG = "contactemail"
CONTACT_PHONE_TAG = "contactphone"
# acme_tls_alpn related constants are in ./dcv_tls_alpn_validator.py
# acme_tls_alpn related constants are in ./dcv_tls_alpn_validator.py

def __init__(
self,
http_client_timeout: float = 30,
reuse_http_client: bool = False,
verify_ssl: bool = False,
log_level: int = None,
dns_timeout: float = None,
dns_resolution_lifetime: float = None,
):
self.verify_ssl = verify_ssl
self._reuse_http_client = reuse_http_client
self._async_http_client = None
self._http_client_loop = None # track which loop the http client was created on
self._http_client_timeout = http_client_timeout

self.logger = logger.getChild(self.__class__.__name__)
if log_level is not None:
self.logger.setLevel(log_level)


self.resolver = dns.asyncresolver.get_default_resolver()
self.resolver.timeout = dns_timeout if dns_timeout is not None else self.resolver.timeout
self.resolver.lifetime = (
dns_resolution_lifetime if dns_resolution_lifetime is not None else self.resolver.lifetime
)
self.acme_tls_alpn_validator = DcvTlsAlpnValidator(log_level=log_level)
self._http_client_timeout = http_client_timeout

@asynccontextmanager
async def get_async_http_client(self):
current_loop = asyncio.get_running_loop()

if self._reuse_http_client: # implementations such as FastAPI may want this for efficiency
reason_for_new_client = None
# noinspection PyProtectedMember
if self._async_http_client is None or self._async_http_client.closed:
reason_for_new_client = "Creating new async HTTP client because there isn't an active one"
elif self._http_client_loop is not current_loop:
reason_for_new_client = "Creating new async HTTP client due to a mismatch in running event loops"

if reason_for_new_client is not None:
self.logger.debug(reason_for_new_client)
if self._async_http_client and not self._async_http_client.closed:
await self._async_http_client.close()

connector = aiohttp.TCPConnector(ssl=self.verify_ssl, limit=0) # no limit on simultaneous connections
dummy_cookie_jar = aiohttp.DummyCookieJar() # disable cookie processing
self._async_http_client = aiohttp.ClientSession(
connector=connector,
timeout=aiohttp.ClientTimeout(total=self._http_client_timeout),
trust_env=True,
cookie_jar=dummy_cookie_jar,
)
self._http_client_loop = current_loop
yield self._async_http_client
else: # implementations such as AWS Lambda will need a new client for each invocation
connector = aiohttp.TCPConnector(ssl=self.verify_ssl, limit=0)
dummy_cookie_jar = aiohttp.DummyCookieJar() # disable cookie processing
client = aiohttp.ClientSession(
connector=connector,
timeout=aiohttp.ClientTimeout(total=self._http_client_timeout),
trust_env=True,
cookie_jar=dummy_cookie_jar,
)
try:
yield client
finally:
if not client.closed:
await client.close()

async def shutdown(self):
"""Close the async HTTP client.

Will need to call this as part of shutdown in wrapping code.
For example, FastAPI's lifespan (https://fastapi.tiangolo.com/advanced/events/)
:return:
"""
if self._async_http_client and not self._async_http_client.closed:
await self._async_http_client.close()
self._async_http_client = None
connector = aiohttp.TCPConnector(ssl=self.verify_ssl, limit=0, force_close=True)
dummy_cookie_jar = aiohttp.DummyCookieJar() # disable cookie processing
client = aiohttp.ClientSession(
connector=connector,
timeout=aiohttp.ClientTimeout(total=self._http_client_timeout),
trust_env=True,
cookie_jar=dummy_cookie_jar,
)
try:
yield client
finally:
if not client.closed:
await client.close()

async def check_dcv(self, dcv_request: DcvCheckRequest) -> DcvCheckResponse:
validation_method = dcv_request.dcv_check_parameters.validation_method
# noinspection PyUnresolvedReferences
self.logger.trace(f"Checking DCV for {dcv_request.domain_or_ip_target} with method {validation_method}.")
self.logger.trace(
"Checking DCV for %s with method %s. Trace ID: %s",
dcv_request.domain_or_ip_target,
validation_method,
dcv_request.trace_identifier,
)

# encode domain if needed
dcv_request.domain_or_ip_target = DomainEncoder.prepare_target_for_lookup(dcv_request.domain_or_ip_target)
@@ -131,7 +96,13 @@ async def check_dcv(dcv_request: DcvCheckRequest) -> DcvCheckResponse:
result = await self.perform_general_dns_validation(dcv_request)

# noinspection PyUnresolvedReferences
self.logger.trace(f"Completed DCV for {dcv_request.domain_or_ip_target} with method {validation_method}")

self.logger.trace(
"Completed DCV for %s with method %s. Trace ID: %s",
dcv_request.domain_or_ip_target,
validation_method,
dcv_request.trace_identifier,
)
return result

async def perform_general_dns_validation(self, request: DcvCheckRequest) -> DcvCheckResponse:
@@ -158,13 +129,15 @@ async def perform_general_dns_validation(self, request: DcvCheckRequest) -> DcvC

try:
# noinspection PyUnresolvedReferences
async with self.logger.trace_timing(f"DNS lookup for target {name_to_resolve}"):
async with self.logger.trace_timing(
f"DNS lookup for target {name_to_resolve}. Trace ID: {request.trace_identifier}"
):
lookup = await self.perform_dns_resolution(name_to_resolve, validation_method, dns_record_type)
MpicDcvChecker.evaluate_dns_lookup_response(
dcv_check_response, lookup, validation_method, dns_record_type, expected_dns_record_content, exact_match
)
except dns.exception.DNSException as e:
log_msg = f"DNS lookup error for {name_to_resolve}: {str(e)}. Trace identifier: {request.trace_identifier}"
log_msg = f"DNS lookup error for {name_to_resolve}: {str(e)}. Trace ID: {request.trace_identifier}"
if isinstance(e, dns.resolver.NoAnswer) or isinstance(e, dns.resolver.NXDOMAIN):
dcv_check_response.check_completed = True # errors on the target domain, not the lookup
# noinspection PyUnresolvedReferences
@@ -220,21 +193,23 @@ async def perform_http_based_validation(self, request: DcvCheckRequest) -> DcvCh
try:
async with self.get_async_http_client() as async_http_client:
# noinspection PyUnresolvedReferences
async with self.logger.trace_timing(f"HTTP lookup for target {token_url}"):
async with self.logger.trace_timing(
f"HTTP lookup for target {token_url}, trace ID: {request.trace_identifier}"
):
async with async_http_client.get(url=token_url, headers=http_headers, max_redirects=20) as response:
dcv_check_response = await MpicDcvChecker.evaluate_http_lookup_response(
request, dcv_check_response, response, token_url, expected_response_content
)
except asyncio.TimeoutError as e:
dcv_check_response.timestamp_ns = time.time_ns()
log_message = f"Timeout connecting to {token_url}: {str(e)}. Trace identifier: {request.trace_identifier}"
log_message = f"Timeout connecting to {token_url}: {str(e)}. Trace ID: {request.trace_identifier}"
self.logger.warning(log_message)
message = f"Connection timed out while attempting to connect to {token_url}"
dcv_check_response.errors = [
MpicValidationError.create(ErrorMessages.DCV_LOOKUP_ERROR, e.__class__.__name__, message)
]
except (ClientError, HTTPException, OSError) as e:
log_message = f"Error connecting to {token_url}: {str(e)}. Trace identifier: {request.trace_identifier}"
log_message = f"Error connecting to {token_url}: {str(e)}. Trace ID: {request.trace_identifier}"
self.logger.error(log_message)
dcv_check_response.timestamp_ns = time.time_ns()
dcv_check_response.errors = [
@@ -304,6 +279,8 @@ async def evaluate_http_lookup_response(
dcv_check_response.details.response_history = response_history
dcv_check_response.details.response_page = base64.b64encode(content).decode()

http_response.close() # ensure connection is closed

return dcv_check_response

@staticmethod
@@ -363,7 +340,7 @@ def evaluate_dns_lookup_response(
# This code will flatten a list of record sets should a CNAME come with multiple records in the record set.
# Per the RFCs, there can only ever be one CNAME in a CNAME record set.
for cname_record in cname_record_set:
cname_chain_str.append( b".".join(cname_record.target.labels).decode("utf-8"))
cname_chain_str.append(b".".join(cname_record.target.labels).decode("utf-8"))
dcv_check_response.details.cname_chain = cname_chain_str
dcv_check_response.details.found_at = dns_response.qname.to_text(omit_final_dot=True)

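The largest change above removes the client-reuse path from `get_async_http_client`: every check now builds a short-lived `aiohttp.ClientSession` on a connector with `force_close=True`, so connections do not persist across invocations, and the old `shutdown()` hook goes away. A self-contained sketch of that per-invocation pattern follows; the parameter defaults are illustrative, not taken from the PR.

```python
# Stand-alone sketch of the per-invocation HTTP client pattern used in the checker.
import aiohttp
from contextlib import asynccontextmanager

@asynccontextmanager
async def fresh_http_client(verify_ssl: bool = False, total_timeout: float = 30):
    connector = aiohttp.TCPConnector(ssl=verify_ssl, limit=0, force_close=True)  # no connection reuse
    client = aiohttp.ClientSession(
        connector=connector,
        timeout=aiohttp.ClientTimeout(total=total_timeout),
        trust_env=True,
        cookie_jar=aiohttp.DummyCookieJar(),  # disable cookie processing
    )
    try:
        yield client
    finally:
        if not client.closed:
            await client.close()

# usage: async with fresh_http_client() as client: ...
```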
12 changes: 0 additions & 12 deletions tests/unit/open_mpic_core/test_mpic_dcv_checker.py
@@ -80,17 +80,6 @@ def mpic_dcv_checker__should_be_able_to_log_at_trace_level(self):
log_contents = self.log_output.getvalue()
assert all(text in log_contents for text in [test_message, "TRACE", dcv_checker.logger.name])

@pytest.mark.parametrize("reuse_http_client", [True, False])
async def mpic_dcv_checker__should_optionally_reuse_http_client(self, reuse_http_client):
dcv_checker = MpicDcvChecker(reuse_http_client=reuse_http_client, log_level=TRACE_LEVEL)
async with dcv_checker.get_async_http_client() as client_1:
async with dcv_checker.get_async_http_client() as client_2:
try:
assert (client_1 is client_2) == reuse_http_client
finally:
if reuse_http_client:
await dcv_checker.shutdown()

# integration test of a sort -- only mocking dns methods rather than remaining class methods
@pytest.mark.parametrize(
"dcv_method, record_type",
@@ -938,7 +927,6 @@ async def side_effect(qname, rdtype):
return test_dns_query_answer
raise self.raise_(dns.resolver.NoAnswer)


return self.patch_resolver_resolve_with_side_effect(mocker, self.dcv_checker.resolver, side_effect)

def _mock_dns_resolve_call_with_specific_response_code(