diff --git a/open-tofu/aws-perspective.tf.template b/open-tofu/aws-perspective.tf.template index f2665e1..9a6b8f6 100644 --- a/open-tofu/aws-perspective.tf.template +++ b/open-tofu/aws-perspective.tf.template @@ -210,6 +210,7 @@ resource "aws_lambda_function" "mpic_dcv_checker_lambda_{{region}}" { timeout = 60 runtime = "python3.11" architectures = ["arm64"] + memory_size = var.perspective_memory_size layers = [ aws_lambda_layer_version.python3_open_mpic_layer_{{region}}.arn, ] @@ -240,6 +241,7 @@ resource "aws_lambda_function" "mpic_caa_checker_lambda_{{region}}" { timeout = 60 runtime = "python3.11" architectures = ["arm64"] + memory_size = var.perspective_memory_size layers = [ aws_lambda_layer_version.python3_open_mpic_layer_{{region}}.arn, ] diff --git a/open-tofu/variables.tf b/open-tofu/variables.tf index f91a357..d3c545a 100644 --- a/open-tofu/variables.tf +++ b/open-tofu/variables.tf @@ -7,6 +7,11 @@ variable "dnssec_enabled" { variable "coordinator_memory_size" { type = number description = "MPIC Coordinator Lambda Function Memory" - default = 512 + default = 256 +} +variable "perspective_memory_size" { + type = number + description = "MPIC Perspective Lambda Function Memory" + default = 256 } diff --git a/src/aws_lambda_mpic/__about__.py b/src/aws_lambda_mpic/__about__.py index 92192ee..68cdeee 100644 --- a/src/aws_lambda_mpic/__about__.py +++ b/src/aws_lambda_mpic/__about__.py @@ -1 +1 @@ -__version__ = "1.0.4" +__version__ = "1.0.5" diff --git a/src/aws_lambda_mpic/mpic_caa_checker_lambda/mpic_caa_checker_lambda_function.py b/src/aws_lambda_mpic/mpic_caa_checker_lambda/mpic_caa_checker_lambda_function.py index 32c0ec7..e1188ed 100644 --- a/src/aws_lambda_mpic/mpic_caa_checker_lambda/mpic_caa_checker_lambda_function.py +++ b/src/aws_lambda_mpic/mpic_caa_checker_lambda/mpic_caa_checker_lambda_function.py @@ -24,14 +24,7 @@ def __init__(self): ) def process_invocation(self, caa_request: CaaCheckRequest): - try: - event_loop = asyncio.get_running_loop() - except RuntimeError: - # No running event loop, create a new one - event_loop = asyncio.new_event_loop() - asyncio.set_event_loop(event_loop) - - caa_response = event_loop.run_until_complete(self.caa_checker.check_caa(caa_request)) + caa_response = asyncio.get_event_loop().run_until_complete(self.caa_checker.check_caa(caa_request)) result = { "statusCode": 200, # note: must be snakeCase "headers": {"Content-Type": "application/json"}, @@ -50,10 +43,16 @@ def get_handler() -> MpicCaaCheckerLambdaHandler: """ global _handler if _handler is None: + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) _handler = MpicCaaCheckerLambdaHandler() return _handler +if os.environ.get("AWS_LAMBDA_FUNCTION_NAME") is not None: + get_handler() + + # noinspection PyUnusedLocal # for now, we are not using context, but it is required by the lambda handler signature @event_parser(model=CaaCheckRequest) diff --git a/src/aws_lambda_mpic/mpic_coordinator_lambda/mpic_coordinator_lambda_function.py b/src/aws_lambda_mpic/mpic_coordinator_lambda/mpic_coordinator_lambda_function.py index a6daec1..b80c7c9 100644 --- a/src/aws_lambda_mpic/mpic_coordinator_lambda/mpic_coordinator_lambda_function.py +++ b/src/aws_lambda_mpic/mpic_coordinator_lambda/mpic_coordinator_lambda_function.py @@ -82,27 +82,13 @@ def __init__(self): self.mpic_request_adapter = TypeAdapter(MpicRequest) self.check_response_adapter = TypeAdapter(CheckResponse) - self._session = aioboto3.Session() - self._client_pools = defaultdict(lambda: Queue(maxsize=10)) # pool of 10 clients per region + self._clients = {} + asyncio.get_event_loop().run_until_complete(self.initialize_clients()) - async def initialize_client_pools(self): - # Call this during cold start + async def initialize_clients(self): + session = aioboto3.Session() for perspective_code in self._all_target_perspective_codes: - for _ in range(10): # prepopulate pool - client = await self._session.client("lambda", perspective_code).__aenter__() - await self._client_pools[perspective_code].put(client) - - async def get_lambda_client(self, perspective_code: str): - return await self._client_pools[perspective_code].get() - - async def release_lambda_client(self, perspective_code: str, client): - await self._client_pools[perspective_code].put(client) - - # async def cleanup(self): # Call this during shutdown if needed (maybe not needed in Lambda) - # for pool in self._client_pools.values(): - # while not pool.empty(): - # client = await pool.get() - # await client.__aexit__(None, None, None) + self._clients[perspective_code] = await session.client("lambda", perspective_code).__aenter__() @staticmethod def load_aws_region_config() -> dict[str, RemotePerspective]: @@ -136,24 +122,21 @@ def convert_codes_to_remote_perspectives( async def call_remote_perspective( self, perspective: RemotePerspective, check_type: CheckType, check_request: CheckRequest ) -> CheckResponse: - client = await self.get_lambda_client(perspective.code) - try: - function_endpoint_info = self.remotes_per_perspective_per_check_type[check_type][perspective.code] - response = await client.invoke( # AWS Lambda-specific structure - FunctionName=function_endpoint_info.arn, - InvocationType="RequestResponse", - Payload=check_request.model_dump_json(), # AWS Lambda functions expect a JSON string for payload - ) - response_payload = await response["Payload"].read() - if 'FunctionError' in response: - raise LambdaExecutionException(f"Lambda execution error: {response_payload.decode('utf-8')}") - response_payload = json.loads(response_payload) - return self.check_response_adapter.validate_json(response_payload["body"]) - finally: - await self.release_lambda_client(perspective.code, client) - - async def process_invocation(self, mpic_request: MpicRequest) -> dict: - mpic_response = await self.mpic_coordinator.coordinate_mpic(mpic_request) + client = self._clients[perspective.code] + function_endpoint_info = self.remotes_per_perspective_per_check_type[check_type][perspective.code] + response = await client.invoke( # AWS Lambda-specific structure + FunctionName=function_endpoint_info.arn, + InvocationType="RequestResponse", + Payload=check_request.model_dump_json(), # AWS Lambda functions expect a JSON string for payload + ) + response_payload = await response["Payload"].read() + if 'FunctionError' in response: + raise LambdaExecutionException(f"Lambda execution error: {response_payload.decode('utf-8')}") + response_payload = json.loads(response_payload) + return self.check_response_adapter.validate_json(response_payload["body"]) + + def process_invocation(self, mpic_request: MpicRequest) -> dict: + mpic_response = asyncio.get_event_loop().run_until_complete(self.mpic_coordinator.coordinate_mpic(mpic_request)) return { "statusCode": 200, "headers": {"Content-Type": "application/json"}, @@ -165,12 +148,6 @@ async def process_invocation(self, mpic_request: MpicRequest) -> dict: _handler = None -async def initialize_handler() -> MpicCoordinatorLambdaHandler: - handler = MpicCoordinatorLambdaHandler() - await handler.initialize_client_pools() - return handler - - def get_handler() -> MpicCoordinatorLambdaHandler: """ Singleton pattern to avoid recreating the handler on every Lambda invocation. @@ -178,17 +155,17 @@ def get_handler() -> MpicCoordinatorLambdaHandler: """ global _handler if _handler is None: - try: - event_loop = asyncio.get_running_loop() - except RuntimeError: - # No running event loop, create a new one - event_loop = asyncio.new_event_loop() - asyncio.set_event_loop(event_loop) - - _handler = event_loop.run_until_complete(initialize_handler()) + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + _handler = MpicCoordinatorLambdaHandler() return _handler +# Not eagerly initialize the handler when running in a test environment as it leads to errors. +if os.environ.get("AWS_LAMBDA_FUNCTION_NAME") is not None: + get_handler() + + def handle_lambda_exceptions(func): def build_400_response(error_name, issues_list): return { @@ -224,5 +201,4 @@ def wrapper(*args, **kwargs): @handle_lambda_exceptions @event_parser(model=MpicRequest, envelope=envelopes.ApiGatewayEnvelope) # AWS Lambda Powertools decorator def lambda_handler(event: MpicRequest, context): # AWS Lambda entry point - handler = get_handler() - return asyncio.get_event_loop().run_until_complete(handler.process_invocation(event)) + return get_handler().process_invocation(event) diff --git a/src/aws_lambda_mpic/mpic_dcv_checker_lambda/mpic_dcv_checker_lambda_function.py b/src/aws_lambda_mpic/mpic_dcv_checker_lambda/mpic_dcv_checker_lambda_function.py index b41868c..ed923e5 100644 --- a/src/aws_lambda_mpic/mpic_dcv_checker_lambda/mpic_dcv_checker_lambda_function.py +++ b/src/aws_lambda_mpic/mpic_dcv_checker_lambda/mpic_dcv_checker_lambda_function.py @@ -20,17 +20,10 @@ def __init__(self): self.dcv_checker = MpicDcvChecker(reuse_http_client=False, log_level=self.logger.level) def process_invocation(self, dcv_request: DcvCheckRequest): - try: - event_loop = asyncio.get_running_loop() - except RuntimeError: - # No running event loop, create a new one - event_loop = asyncio.new_event_loop() - asyncio.set_event_loop(event_loop) - self.logger.debug("(debug log) Processing DCV check request: %s", dcv_request) print("(print) Processing DCV check request: %s", dcv_request) - dcv_response = event_loop.run_until_complete(self.dcv_checker.check_dcv(dcv_request)) + dcv_response = asyncio.get_event_loop().run_until_complete(self.dcv_checker.check_dcv(dcv_request)) status_code = 200 if dcv_response.errors is not None and len(dcv_response.errors) > 0: if dcv_response.errors[0].error_type == "404": @@ -55,10 +48,16 @@ def get_handler() -> MpicDcvCheckerLambdaHandler: """ global _handler if _handler is None: + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) _handler = MpicDcvCheckerLambdaHandler() return _handler +if os.environ.get("AWS_LAMBDA_FUNCTION_NAME") is not None: + get_handler() + + # noinspection PyUnusedLocal # for now, we are not using context, but it is required by the lambda handler signature @event_parser(model=DcvCheckRequest) diff --git a/tests/unit/aws_lambda_mpic/test_mpic_coordinator_lambda.py b/tests/unit/aws_lambda_mpic/test_mpic_coordinator_lambda.py index 27d2a07..db7b2f4 100644 --- a/tests/unit/aws_lambda_mpic/test_mpic_coordinator_lambda.py +++ b/tests/unit/aws_lambda_mpic/test_mpic_coordinator_lambda.py @@ -1,3 +1,4 @@ +import asyncio import io import json from datetime import datetime @@ -76,16 +77,16 @@ def set_env_variables(): class_scoped_monkeypatch.setenv(k, v) yield class_scoped_monkeypatch # restore the environment afterward - async def call_remote_perspective__should_make_aws_lambda_call_with_provided_arguments_and_return_check_response( + def call_remote_perspective__should_make_aws_lambda_call_with_provided_arguments_and_return_check_response( self, set_env_variables, mocker ): - lambda_handler, mock_client = await self.mock_lambda_handler_for_lambda_invoke(mocker, self.create_successful_aioboto3_response_for_dcv_check) + lambda_handler, mock_client = self.mock_lambda_handler_for_lambda_invoke(mocker, self.create_successful_aioboto3_response_for_dcv_check) dcv_check_request = ValidCheckCreator.create_valid_dns_check_request() perspective_code = "us-west-1" - check_response = await lambda_handler.call_remote_perspective( + check_response = asyncio.get_event_loop().run_until_complete(lambda_handler.call_remote_perspective( RemotePerspective(code=perspective_code, rir="arin"), CheckType.DCV, dcv_check_request - ) + )) assert check_response.check_passed is True # hijacking the value of 'details.found_at' to verify that the right arguments got passed to the call assert check_response.details.found_at == dcv_check_request.domain_or_ip_target @@ -99,18 +100,18 @@ async def call_remote_perspective__should_make_aws_lambda_call_with_provided_arg Payload=dcv_check_request.model_dump_json(), ) - async def call_remote_perspective__should_make_aws_lambda_call_and_handle_lambda_execution_exceptions( + def call_remote_perspective__should_make_aws_lambda_call_and_handle_lambda_execution_exceptions( self, set_env_variables, mocker ): - lambda_handler, mock_client = await self.mock_lambda_handler_for_lambda_invoke(mocker, self.create_error_aioboto3_response) + lambda_handler, mock_client = self.mock_lambda_handler_for_lambda_invoke(mocker, self.create_error_aioboto3_response) class Dummy(BaseModel): pass with pytest.raises(LambdaExecutionException) as exc_info: - await lambda_handler.call_remote_perspective( + asyncio.get_event_loop().run_until_complete(lambda_handler.call_remote_perspective( RemotePerspective(code="us-west-1", rir="arin"), CheckType.DCV, Dummy() - ) + )) assert exc_info.value.args[0] == "Lambda execution error: {\"errorMessage\": \"some message\"}" def lambda_handler__should_return_400_error_and_details_given_invalid_request_body(self): @@ -319,7 +320,7 @@ def get_perspectives_by_code_dict_from_file() -> dict[str, RemotePerspective]: perspectives = perspective_type_adapter.validate_python(perspectives_yaml["aws_available_regions"]) return {perspective.code: perspective for perspective in perspectives} - async def mock_lambda_handler_for_lambda_invoke(self, mocker, lambda_invoke_side_effect): + def mock_lambda_handler_for_lambda_invoke(self, mocker, lambda_invoke_side_effect): # Mock the aioboto3 client creation and context manager mock_client = AsyncMock() mock_client.invoke = AsyncMock(side_effect=lambda_invoke_side_effect) @@ -329,8 +330,8 @@ async def mock_lambda_handler_for_lambda_invoke(self, mocker, lambda_invoke_side mock_session = mocker.patch("aioboto3.Session") mock_session.return_value.client.return_value = mock_client lambda_handler = MpicCoordinatorLambdaHandler() - await lambda_handler.initialize_client_pools() return lambda_handler, mock_client + if __name__ == "__main__": pytest.main()