diff --git a/pyproject.toml b/pyproject.toml index a3a8229..75a2f4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "requests==2.32.4", "dnspython==2.7.0", "pydantic==2.11.7", - "aiohttp==3.12.13", + "aiohttp==3.12.14", "black==25.1.0", "cryptography==45.0.4", ] diff --git a/src/open_mpic_core/__about__.py b/src/open_mpic_core/__about__.py index 1b1e67f..0f607a5 100644 --- a/src/open_mpic_core/__about__.py +++ b/src/open_mpic_core/__about__.py @@ -1 +1 @@ -__version__ = "5.10.0" +__version__ = "6.0.0" diff --git a/src/open_mpic_core/mpic_caa_checker/mpic_caa_checker.py b/src/open_mpic_core/mpic_caa_checker/mpic_caa_checker.py index f74e99a..56cbc42 100644 --- a/src/open_mpic_core/mpic_caa_checker/mpic_caa_checker.py +++ b/src/open_mpic_core/mpic_caa_checker/mpic_caa_checker.py @@ -59,7 +59,9 @@ async def find_caa_records_and_domain(self, caa_request) -> tuple[RRset, Name]: except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN): domain = domain.parent() except Exception as e: - self.logger.error(f"Exception during CAA lookup for {caa_request.domain_or_ip_target}: {e}. Trace identifier: {caa_request.trace_identifier}") + self.logger.error( + f"Exception during CAA lookup for {caa_request.domain_or_ip_target}: {e}. Trace ID: {caa_request.trace_identifier}" + ) raise MpicCaaLookupException(f"{e}") from e return rrset, domain @@ -105,7 +107,7 @@ async def check_caa(self, caa_request: CaaCheckRequest) -> CaaCheckResponse: caa_found = rrset is not None except (MpicCaaLookupException, ValueError) as e: caa_lookup_error = True - error_message = f"Error during CAA lookup for {caa_request.domain_or_ip_target}: {e}. Trace identifier: {caa_request.trace_identifier}" + error_message = f"Error during CAA lookup for {caa_request.domain_or_ip_target}: {e}. Trace ID: {caa_request.trace_identifier}" caa_check_response.errors = [MpicValidationError.create(ErrorMessages.CAA_LOOKUP_ERROR, error_message)] caa_check_response.details.found_at = None caa_check_response.details.records_seen = None diff --git a/src/open_mpic_core/mpic_coordinator/mpic_coordinator.py b/src/open_mpic_core/mpic_coordinator/mpic_coordinator.py index 1a85ed4..292e0cf 100644 --- a/src/open_mpic_core/mpic_coordinator/mpic_coordinator.py +++ b/src/open_mpic_core/mpic_coordinator/mpic_coordinator.py @@ -214,14 +214,16 @@ async def call_remote_perspective( """ try: # noinspection PyUnresolvedReferences - async with self.logger.trace_timing(f"MPIC round-trip with perspective {call_config.perspective.code}"): + async with self.logger.trace_timing( + f"MPIC round-trip with perspective {call_config.perspective.code}; trace ID: {call_config.check_request.trace_identifier}" + ): response = await call_remote_perspective_function( call_config.perspective, call_config.check_type, call_config.check_request ) except Exception as exc: error_message = str(exc) if str(exc) else exc.__class__.__name__ raise RemoteCheckException( - f"Check failed for perspective {call_config.perspective.code}, target {call_config.check_request.domain_or_ip_target}: {error_message}", + f"Check failed for perspective {call_config.perspective.code}, target {call_config.check_request.domain_or_ip_target}: {error_message}; trace ID: {call_config.check_request.trace_identifier}", call_config=call_config, ) from exc return PerspectiveResponse(perspective_code=call_config.perspective.code, check_response=response) @@ -270,7 +272,9 @@ async def call_checkers_and_collect_responses( ] # noinspection PyUnresolvedReferences - async with self.logger.trace_timing(f"MPIC round-trip with {len(perspectives_to_use)} perspectives"): + async with self.logger.trace_timing( + f"MPIC round-trip with {len(perspectives_to_use)} perspectives; trace ID: {mpic_request.trace_identifier}" + ): responses = await asyncio.gather(*tasks, return_exceptions=True) for response in responses: @@ -279,7 +283,7 @@ async def call_checkers_and_collect_responses( # (trying to handle other Exceptions should be unreachable code) if isinstance(response, RemoteCheckException): response_as_string = str(response) - log_msg = f"{response_as_string} - Trace identifier: {mpic_request.trace_identifier}" + log_msg = f"{response_as_string} - trace ID: {mpic_request.trace_identifier}" logger.warning(log_msg) error_response = MpicCoordinator.build_error_perspective_response_from_exception(response) perspective_responses.append(error_response) diff --git a/src/open_mpic_core/mpic_dcv_checker/mpic_dcv_checker.py b/src/open_mpic_core/mpic_dcv_checker/mpic_dcv_checker.py index a938e04..62166fd 100644 --- a/src/open_mpic_core/mpic_dcv_checker/mpic_dcv_checker.py +++ b/src/open_mpic_core/mpic_dcv_checker/mpic_dcv_checker.py @@ -31,92 +31,57 @@ class MpicDcvChecker: WELL_KNOWN_ACME_PATH = ".well-known/acme-challenge" CONTACT_EMAIL_TAG = "contactemail" CONTACT_PHONE_TAG = "contactphone" -# acme_tls_alpn related constants are in ./dcv_tls_alpn_validator.py + # acme_tls_alpn related constants are in ./dcv_tls_alpn_validator.py def __init__( self, http_client_timeout: float = 30, - reuse_http_client: bool = False, verify_ssl: bool = False, log_level: int = None, dns_timeout: float = None, dns_resolution_lifetime: float = None, ): self.verify_ssl = verify_ssl - self._reuse_http_client = reuse_http_client self._async_http_client = None self._http_client_loop = None # track which loop the http client was created on - self._http_client_timeout = http_client_timeout self.logger = logger.getChild(self.__class__.__name__) if log_level is not None: self.logger.setLevel(log_level) - self.resolver = dns.asyncresolver.get_default_resolver() self.resolver.timeout = dns_timeout if dns_timeout is not None else self.resolver.timeout self.resolver.lifetime = ( dns_resolution_lifetime if dns_resolution_lifetime is not None else self.resolver.lifetime ) self.acme_tls_alpn_validator = DcvTlsAlpnValidator(log_level=log_level) + self._http_client_timeout = http_client_timeout @asynccontextmanager async def get_async_http_client(self): - current_loop = asyncio.get_running_loop() - - if self._reuse_http_client: # implementations such as FastAPI may want this for efficiency - reason_for_new_client = None - # noinspection PyProtectedMember - if self._async_http_client is None or self._async_http_client.closed: - reason_for_new_client = "Creating new async HTTP client because there isn't an active one" - elif self._http_client_loop is not current_loop: - reason_for_new_client = "Creating new async HTTP client due to a mismatch in running event loops" - - if reason_for_new_client is not None: - self.logger.debug(reason_for_new_client) - if self._async_http_client and not self._async_http_client.closed: - await self._async_http_client.close() - - connector = aiohttp.TCPConnector(ssl=self.verify_ssl, limit=0) # no limit on simultaneous connections - dummy_cookie_jar = aiohttp.DummyCookieJar() # disable cookie processing - self._async_http_client = aiohttp.ClientSession( - connector=connector, - timeout=aiohttp.ClientTimeout(total=self._http_client_timeout), - trust_env=True, - cookie_jar=dummy_cookie_jar, - ) - self._http_client_loop = current_loop - yield self._async_http_client - else: # implementations such as AWS Lambda will need a new client for each invocation - connector = aiohttp.TCPConnector(ssl=self.verify_ssl, limit=0) - dummy_cookie_jar = aiohttp.DummyCookieJar() # disable cookie processing - client = aiohttp.ClientSession( - connector=connector, - timeout=aiohttp.ClientTimeout(total=self._http_client_timeout), - trust_env=True, - cookie_jar=dummy_cookie_jar, - ) - try: - yield client - finally: - if not client.closed: - await client.close() - - async def shutdown(self): - """Close the async HTTP client. - - Will need to call this as part of shutdown in wrapping code. - For example, FastAPI's lifespan (https://fastapi.tiangolo.com/advanced/events/) - :return: - """ - if self._async_http_client and not self._async_http_client.closed: - await self._async_http_client.close() - self._async_http_client = None + connector = aiohttp.TCPConnector(ssl=self.verify_ssl, limit=0, force_close=True) + dummy_cookie_jar = aiohttp.DummyCookieJar() # disable cookie processing + client = aiohttp.ClientSession( + connector=connector, + timeout=aiohttp.ClientTimeout(total=self._http_client_timeout), + trust_env=True, + cookie_jar=dummy_cookie_jar, + ) + try: + yield client + finally: + if not client.closed: + await client.close() async def check_dcv(self, dcv_request: DcvCheckRequest) -> DcvCheckResponse: validation_method = dcv_request.dcv_check_parameters.validation_method # noinspection PyUnresolvedReferences - self.logger.trace(f"Checking DCV for {dcv_request.domain_or_ip_target} with method {validation_method}.") + self.logger.trace( + "Checking DCV for %s with method %s. Trace ID: %s", + dcv_request.domain_or_ip_target, + validation_method, + dcv_request.trace_identifier, + ) # encode domain if needed dcv_request.domain_or_ip_target = DomainEncoder.prepare_target_for_lookup(dcv_request.domain_or_ip_target) @@ -131,7 +96,13 @@ async def check_dcv(self, dcv_request: DcvCheckRequest) -> DcvCheckResponse: result = await self.perform_general_dns_validation(dcv_request) # noinspection PyUnresolvedReferences - self.logger.trace(f"Completed DCV for {dcv_request.domain_or_ip_target} with method {validation_method}") + + self.logger.trace( + "Completed DCV for %s with method %s. Trace ID: %s", + dcv_request.domain_or_ip_target, + validation_method, + dcv_request.trace_identifier, + ) return result async def perform_general_dns_validation(self, request: DcvCheckRequest) -> DcvCheckResponse: @@ -158,13 +129,15 @@ async def perform_general_dns_validation(self, request: DcvCheckRequest) -> DcvC try: # noinspection PyUnresolvedReferences - async with self.logger.trace_timing(f"DNS lookup for target {name_to_resolve}"): + async with self.logger.trace_timing( + f"DNS lookup for target {name_to_resolve}. Trace ID: {request.trace_identifier}" + ): lookup = await self.perform_dns_resolution(name_to_resolve, validation_method, dns_record_type) MpicDcvChecker.evaluate_dns_lookup_response( dcv_check_response, lookup, validation_method, dns_record_type, expected_dns_record_content, exact_match ) except dns.exception.DNSException as e: - log_msg = f"DNS lookup error for {name_to_resolve}: {str(e)}. Trace identifier: {request.trace_identifier}" + log_msg = f"DNS lookup error for {name_to_resolve}: {str(e)}. Trace ID: {request.trace_identifier}" if isinstance(e, dns.resolver.NoAnswer) or isinstance(e, dns.resolver.NXDOMAIN): dcv_check_response.check_completed = True # errors on the target domain, not the lookup # noinspection PyUnresolvedReferences @@ -220,21 +193,23 @@ async def perform_http_based_validation(self, request: DcvCheckRequest) -> DcvCh try: async with self.get_async_http_client() as async_http_client: # noinspection PyUnresolvedReferences - async with self.logger.trace_timing(f"HTTP lookup for target {token_url}"): + async with self.logger.trace_timing( + f"HTTP lookup for target {token_url}, trace ID: {request.trace_identifier}" + ): async with async_http_client.get(url=token_url, headers=http_headers, max_redirects=20) as response: dcv_check_response = await MpicDcvChecker.evaluate_http_lookup_response( request, dcv_check_response, response, token_url, expected_response_content ) except asyncio.TimeoutError as e: dcv_check_response.timestamp_ns = time.time_ns() - log_message = f"Timeout connecting to {token_url}: {str(e)}. Trace identifier: {request.trace_identifier}" + log_message = f"Timeout connecting to {token_url}: {str(e)}. Trace ID: {request.trace_identifier}" self.logger.warning(log_message) message = f"Connection timed out while attempting to connect to {token_url}" dcv_check_response.errors = [ MpicValidationError.create(ErrorMessages.DCV_LOOKUP_ERROR, e.__class__.__name__, message) ] except (ClientError, HTTPException, OSError) as e: - log_message = f"Error connecting to {token_url}: {str(e)}. Trace identifier: {request.trace_identifier}" + log_message = f"Error connecting to {token_url}: {str(e)}. Trace ID: {request.trace_identifier}" self.logger.error(log_message) dcv_check_response.timestamp_ns = time.time_ns() dcv_check_response.errors = [ @@ -304,6 +279,8 @@ async def evaluate_http_lookup_response( dcv_check_response.details.response_history = response_history dcv_check_response.details.response_page = base64.b64encode(content).decode() + http_response.close() # ensure connection is closed + return dcv_check_response @staticmethod @@ -363,7 +340,7 @@ def evaluate_dns_lookup_response( # This code will flatten a list of record sets should a CNAME come with multiple records in the record set. # Per the RFCs, there can only ever be one CNAME in a CNAME record set. for cname_record in cname_record_set: - cname_chain_str.append( b".".join(cname_record.target.labels).decode("utf-8")) + cname_chain_str.append(b".".join(cname_record.target.labels).decode("utf-8")) dcv_check_response.details.cname_chain = cname_chain_str dcv_check_response.details.found_at = dns_response.qname.to_text(omit_final_dot=True) diff --git a/tests/unit/open_mpic_core/test_mpic_dcv_checker.py b/tests/unit/open_mpic_core/test_mpic_dcv_checker.py index 31b1c75..d11277b 100644 --- a/tests/unit/open_mpic_core/test_mpic_dcv_checker.py +++ b/tests/unit/open_mpic_core/test_mpic_dcv_checker.py @@ -80,17 +80,6 @@ def mpic_dcv_checker__should_be_able_to_log_at_trace_level(self): log_contents = self.log_output.getvalue() assert all(text in log_contents for text in [test_message, "TRACE", dcv_checker.logger.name]) - @pytest.mark.parametrize("reuse_http_client", [True, False]) - async def mpic_dcv_checker__should_optionally_reuse_http_client(self, reuse_http_client): - dcv_checker = MpicDcvChecker(reuse_http_client=reuse_http_client, log_level=TRACE_LEVEL) - async with dcv_checker.get_async_http_client() as client_1: - async with dcv_checker.get_async_http_client() as client_2: - try: - assert (client_1 is client_2) == reuse_http_client - finally: - if reuse_http_client: - await dcv_checker.shutdown() - # integration test of a sort -- only mocking dns methods rather than remaining class methods @pytest.mark.parametrize( "dcv_method, record_type", @@ -938,7 +927,6 @@ async def side_effect(qname, rdtype): return test_dns_query_answer raise self.raise_(dns.resolver.NoAnswer) - return self.patch_resolver_resolve_with_side_effect(mocker, self.dcv_checker.resolver, side_effect) def _mock_dns_resolve_call_with_specific_response_code(