diff --git a/pyproject.toml b/pyproject.toml index c2789d0..e08cc41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,19 +31,19 @@ dependencies = [ "requests>=2.32.3", "dnspython==2.6.1", "pydantic==2.8.2", + "aiohttp==3.11.11", "aws-lambda-powertools[parser]==3.2.0", - "open-mpic-core==3.1.0", + "open-mpic-core==4.1.2", + "aioboto3~=13.3.0", ] [project.optional-dependencies] -provided = [ - "boto3~=1.34.141", -] test = [ "pytest==8.2.2", "pytest-cov==5.0.0", "pytest-mock==3.14.0", "pytest-html==4.1.1", + "pytest-asyncio==0.25.1", ] [project.urls] @@ -88,7 +88,6 @@ install = "pip install . --platform manylinux2014_aarch64 --only-binary=:all: -- skip-install = false features = [ "test", - "provided" ] installer = "pip" @@ -123,6 +122,8 @@ markers = [ addopts = [ "--import-mode=prepend", # explicit default, as the tests rely on it for proper import resolution ] +asyncio_mode = "auto" # defaults to "strict" +asyncio_default_fixture_loop_scope = "function" [tool.coverage.run] source = [ diff --git a/src/aws_lambda_mpic/mpic_caa_checker_lambda/mpic_caa_checker_lambda_function.py b/src/aws_lambda_mpic/mpic_caa_checker_lambda/mpic_caa_checker_lambda_function.py index 3be5ade..25cb41a 100644 --- a/src/aws_lambda_mpic/mpic_caa_checker_lambda/mpic_caa_checker_lambda_function.py +++ b/src/aws_lambda_mpic/mpic_caa_checker_lambda/mpic_caa_checker_lambda_function.py @@ -1,3 +1,5 @@ +import asyncio + from aws_lambda_powertools.utilities.parser import event_parser from open_mpic_core.common_domain.check_request import CaaCheckRequest @@ -12,7 +14,14 @@ def __init__(self): self.caa_checker = MpicCaaChecker(self.default_caa_domain_list, self.perspective_code) def process_invocation(self, caa_request: CaaCheckRequest): - caa_response = self.caa_checker.check_caa(caa_request) + try: + event_loop = asyncio.get_running_loop() + except RuntimeError: + # No running event loop, create a new one + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + + caa_response = event_loop.run_until_complete(self.caa_checker.check_caa(caa_request)) result = { 'statusCode': 200, # note: must be snakeCase 'headers': {'Content-Type': 'application/json'}, diff --git a/src/aws_lambda_mpic/mpic_coordinator_lambda/mpic_coordinator_lambda_function.py b/src/aws_lambda_mpic/mpic_coordinator_lambda/mpic_coordinator_lambda_function.py index df87a6a..fba8820 100644 --- a/src/aws_lambda_mpic/mpic_coordinator_lambda/mpic_coordinator_lambda_function.py +++ b/src/aws_lambda_mpic/mpic_coordinator_lambda/mpic_coordinator_lambda_function.py @@ -1,6 +1,13 @@ +import os +import json +import yaml +import asyncio +import aioboto3 + +from asyncio import Queue +from collections import defaultdict from importlib import resources -import yaml from aws_lambda_powertools.utilities.parser import event_parser, envelopes from pydantic import TypeAdapter, ValidationError, BaseModel from open_mpic_core.common_domain.check_request import BaseCheckRequest @@ -12,10 +19,6 @@ from open_mpic_core.common_domain.enum.check_type import CheckType from open_mpic_core.mpic_coordinator.domain.remote_perspective import RemotePerspective -import boto3 -import os -import json - class PerspectiveEndpointInfo(BaseModel): arn: str @@ -30,7 +33,7 @@ class MpicCoordinatorLambdaHandler: def __init__(self): perspectives_json = os.environ['perspectives'] perspectives = {code: PerspectiveEndpoints.model_validate(endpoints) for code, endpoints in json.loads(perspectives_json).items()} - self.all_target_perspective_codes = list(perspectives.keys()) + self._all_target_perspective_codes = list(perspectives.keys()) self.default_perspective_count = int(os.environ['default_perspective_count']) self.global_max_attempts = int(os.environ['absolute_max_attempts']) if 'absolute_max_attempts' in os.environ else None self.hash_secret = os.environ['hash_secret'] @@ -42,7 +45,7 @@ def __init__(self): all_possible_perspectives_by_code = MpicCoordinatorLambdaHandler.load_aws_region_config() self.target_perspectives = MpicCoordinatorLambdaHandler.convert_codes_to_remote_perspectives( - self.all_target_perspective_codes, all_possible_perspectives_by_code) + self._all_target_perspective_codes, all_possible_perspectives_by_code) self.mpic_coordinator_configuration = MpicCoordinatorConfiguration( self.target_perspectives, @@ -60,6 +63,28 @@ def __init__(self): self.mpic_request_adapter = TypeAdapter(MpicRequest) self.check_response_adapter = TypeAdapter(CheckResponse) + self._session = aioboto3.Session() + self._client_pools = defaultdict(lambda: Queue(maxsize=10)) # pool of 10 clients per region + + async def initialize_client_pools(self): + # Call this during cold start + for perspective_code in self._all_target_perspective_codes: + for _ in range(10): # pre-populate pool + client = await self._session.client('lambda', perspective_code).__aenter__() + await self._client_pools[perspective_code].put(client) + + async def get_lambda_client(self, perspective_code: str): + return await self._client_pools[perspective_code].get() + + async def release_lambda_client(self, perspective_code: str, client): + await self._client_pools[perspective_code].put(client) + + # async def cleanup(self): # Call this during shutdown if needed (maybe not needed in Lambda) + # for pool in self._client_pools.values(): + # while not pool.empty(): + # client = await pool.get() + # await client.__aexit__(None, None, None) + @staticmethod def load_aws_region_config() -> dict[str, RemotePerspective]: """ @@ -88,24 +113,25 @@ def convert_codes_to_remote_perspectives(perspective_codes: list[str], return remote_perspectives # This function MUST validate its response and return a proper open_mpic_core object type. - def call_remote_perspective(self, perspective: RemotePerspective, check_type: CheckType, check_request: BaseCheckRequest) -> CheckResponse: - # Uses dcv_arn_list, caa_arn_list - client = boto3.client('lambda', perspective.code) - function_endpoint_info = self.remotes_per_perspective_per_check_type[check_type][perspective.code] - response = client.invoke( # AWS Lambda-specific structure + async def call_remote_perspective(self, perspective: RemotePerspective, check_type: CheckType, check_request: BaseCheckRequest) -> CheckResponse: + client = await self.get_lambda_client(perspective.code) + try: + function_endpoint_info = self.remotes_per_perspective_per_check_type[check_type][perspective.code] + response = await client.invoke( # AWS Lambda-specific structure FunctionName=function_endpoint_info.arn, InvocationType='RequestResponse', Payload=check_request.model_dump_json() # AWS Lambda functions expect a JSON string for payload ) - response_payload = json.loads(response['Payload'].read().decode('utf-8')) - try: + response_payload = json.loads(await response['Payload'].read()) return self.check_response_adapter.validate_json(response_payload['body']) except ValidationError as ve: # We might want to handle this differently later. raise ve + finally: + await self.release_lambda_client(perspective.code, client) - def process_invocation(self, mpic_request: MpicRequest) -> dict: - mpic_response = self.mpic_coordinator.coordinate_mpic(mpic_request) + async def process_invocation(self, mpic_request: MpicRequest) -> dict: + mpic_response = await self.mpic_coordinator.coordinate_mpic(mpic_request) return { 'statusCode': 200, 'headers': {'Content-Type': 'application/json'}, @@ -117,13 +143,27 @@ def process_invocation(self, mpic_request: MpicRequest) -> dict: _handler = None +async def initialize_handler() -> MpicCoordinatorLambdaHandler: + handler = MpicCoordinatorLambdaHandler() + await handler.initialize_client_pools() + return handler + + def get_handler() -> MpicCoordinatorLambdaHandler: """ - Singleton pattern to avoid recreating the handler on every Lambda invocation + Singleton pattern to avoid recreating the handler on every Lambda invocation. + Performs lazy initialization using event loop. """ global _handler if _handler is None: - _handler = MpicCoordinatorLambdaHandler() + try: + event_loop = asyncio.get_running_loop() + except RuntimeError: + # No running event loop, create a new one + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + + _handler = event_loop.run_until_complete(initialize_handler()) return _handler @@ -157,4 +197,5 @@ def wrapper(*args, **kwargs): @handle_lambda_exceptions @event_parser(model=MpicRequest, envelope=envelopes.ApiGatewayEnvelope) # AWS Lambda Powertools decorator def lambda_handler(event: MpicRequest, context): # AWS Lambda entry point - return get_handler().process_invocation(event) + handler = get_handler() + return asyncio.get_event_loop().run_until_complete(handler.process_invocation(event)) diff --git a/src/aws_lambda_mpic/mpic_dcv_checker_lambda/mpic_dcv_checker_lambda_function.py b/src/aws_lambda_mpic/mpic_dcv_checker_lambda/mpic_dcv_checker_lambda_function.py index faebb3f..7b9c357 100644 --- a/src/aws_lambda_mpic/mpic_dcv_checker_lambda/mpic_dcv_checker_lambda_function.py +++ b/src/aws_lambda_mpic/mpic_dcv_checker_lambda/mpic_dcv_checker_lambda_function.py @@ -1,3 +1,5 @@ +import asyncio + from aws_lambda_powertools.utilities.parser import event_parser from open_mpic_core.common_domain.check_request import DcvCheckRequest @@ -11,7 +13,14 @@ def __init__(self): self.dcv_checker = MpicDcvChecker(self.perspective_code) def process_invocation(self, dcv_request: DcvCheckRequest): - dcv_response = self.dcv_checker.check_dcv(dcv_request) + try: + event_loop = asyncio.get_running_loop() + except RuntimeError: + # No running event loop, create a new one + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + + dcv_response = event_loop.run_until_complete(self.dcv_checker.check_dcv(dcv_request)) status_code = 200 if dcv_response.errors is not None and len(dcv_response.errors) > 0: if dcv_response.errors[0].error_type == '404': diff --git a/tests/unit/aws_lambda_mpic/test_mpic_coordinator_lambda.py b/tests/unit/aws_lambda_mpic/test_mpic_coordinator_lambda.py index 47d9408..ba726fd 100644 --- a/tests/unit/aws_lambda_mpic/test_mpic_coordinator_lambda.py +++ b/tests/unit/aws_lambda_mpic/test_mpic_coordinator_lambda.py @@ -2,6 +2,7 @@ import json from datetime import datetime from importlib import resources +from unittest.mock import AsyncMock import pytest import yaml @@ -55,17 +56,43 @@ def set_env_variables(): class_scoped_monkeypatch.setenv(k, v) yield class_scoped_monkeypatch # restore the environment afterward - def call_remote_perspective__should_make_aws_lambda_call_with_provided_arguments_and_return_check_response(self, set_env_variables, mocker): - mocker.patch('botocore.client.BaseClient._make_api_call', side_effect=self.create_successful_boto3_api_call_response_for_dcv_check) + async def call_remote_perspective__should_make_aws_lambda_call_with_provided_arguments_and_return_check_response(self, set_env_variables, mocker): + # Mock the aioboto3 client creation and context manager + mock_client = AsyncMock() + mock_client.invoke = AsyncMock(side_effect=self.create_successful_aioboto3_response_for_dcv_check) + + # Mock the __aenter__ method that gets called in initialize_client_pools() + mock_client.__aenter__.return_value = mock_client + + # Mock the session creation and client initialization + mock_session = mocker.patch('aioboto3.Session') + mock_session.return_value.client.return_value = mock_client + + # mocker.patch('botocore.client.BaseClient._make_api_call', side_effect=self.create_successful_boto3_api_call_response_for_dcv_check) dcv_check_request = ValidCheckCreator.create_valid_dns_check_request() mpic_coordinator_lambda_handler = MpicCoordinatorLambdaHandler() - check_response = mpic_coordinator_lambda_handler.call_remote_perspective(RemotePerspective(code='us-west-1', rir='arin'), - CheckType.DCV, - dcv_check_request) + + await mpic_coordinator_lambda_handler.initialize_client_pools() + + perspective_code = 'us-west-1' + check_response = await mpic_coordinator_lambda_handler.call_remote_perspective( + RemotePerspective(code=perspective_code, rir='arin'), + CheckType.DCV, + dcv_check_request + ) assert check_response.check_passed is True # hijacking the value of 'perspective_code' to verify that the right arguments got passed to the call assert check_response.perspective_code == dcv_check_request.domain_or_ip_target + function_endpoint_info = mpic_coordinator_lambda_handler.remotes_per_perspective_per_check_type[CheckType.DCV][perspective_code] + + # Verify the mock was called correctly + mock_client.invoke.assert_called_once_with( + FunctionName=function_endpoint_info.arn, + InvocationType='RequestResponse', + Payload=dcv_check_request.model_dump_json() + ) + def lambda_handler__should_return_400_error_and_details_given_invalid_request_body(self): request = ValidMpicRequestCreator.create_valid_dcv_mpic_request() # noinspection PyTypeChecker @@ -89,7 +116,7 @@ def lambda_handler__should_return_400_error_and_details_given_invalid_check_type result_body = json.loads(result['body']) assert result_body['validation_issues'][0]['type'] == 'literal_error' - def lambda_handler__should_return_400_error_given_logically_invalid_request(self): + def lambda_handler__should_return_400_error_given_logically_invalid_request(self, set_env_variables): request = ValidMpicRequestCreator.create_valid_dcv_mpic_request() request.orchestration_parameters.perspective_count = 1 api_request = TestMpicCoordinatorLambda.create_api_gateway_request() @@ -158,14 +185,20 @@ def create_successful_boto3_api_call_response_for_dcv_check(self, lambda_method, return {'Payload': streaming_body_response} # noinspection PyUnusedLocal - def create_error_boto3_api_call_response(self, lambda_method, lambda_configuration): - # note: all perspective response details will be identical in these tests due to this mocking - expected_response_body = 'Something went wrong' - expected_response = {'statusCode': 500, 'body': expected_response_body} + async def create_successful_aioboto3_response_for_dcv_check(self, *args, **kwargs): + check_request = DcvCheckRequest.model_validate_json(kwargs['Payload']) + # hijacking the value of 'perspective_code' to verify that the right arguments got passed to the call + expected_response_body = DcvCheckResponse(perspective_code=check_request.domain_or_ip_target, + check_passed=True, details=DcvDnsCheckResponseDetails(validation_method=DcvValidationMethod.ACME_DNS_01)) + expected_response = {'statusCode': 200, 'body': expected_response_body.model_dump_json()} json_bytes = json.dumps(expected_response).encode('utf-8') - file_like_response = io.BytesIO(json_bytes) - streaming_body_response = StreamingBody(file_like_response, len(json_bytes)) - return {'Payload': streaming_body_response} + + # Mock the response structure that aioboto3 would return + class MockStreamingBody: + # noinspection PyMethodMayBeStatic + async def read(self): + return json_bytes + return {'Payload': MockStreamingBody()} @staticmethod def create_caa_mpic_response():