Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,19 @@ dependencies = [
"requests>=2.32.3",
"dnspython==2.6.1",
"pydantic==2.8.2",
"aiohttp==3.11.11",
"aws-lambda-powertools[parser]==3.2.0",
"open-mpic-core==3.1.0",
"open-mpic-core==4.1.2",
"aioboto3~=13.3.0",
]

[project.optional-dependencies]
provided = [
"boto3~=1.34.141",
]
test = [
"pytest==8.2.2",
"pytest-cov==5.0.0",
"pytest-mock==3.14.0",
"pytest-html==4.1.1",
"pytest-asyncio==0.25.1",
]

[project.urls]
Expand Down Expand Up @@ -88,7 +88,6 @@ install = "pip install . --platform manylinux2014_aarch64 --only-binary=:all: --
skip-install = false
features = [
"test",
"provided"
]
installer = "pip"

Expand Down Expand Up @@ -123,6 +122,8 @@ markers = [
addopts = [
"--import-mode=prepend", # explicit default, as the tests rely on it for proper import resolution
]
asyncio_mode = "auto" # defaults to "strict"
asyncio_default_fixture_loop_scope = "function"

[tool.coverage.run]
source = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

from aws_lambda_powertools.utilities.parser import event_parser

from open_mpic_core.common_domain.check_request import CaaCheckRequest
Expand All @@ -12,7 +14,14 @@ def __init__(self):
self.caa_checker = MpicCaaChecker(self.default_caa_domain_list, self.perspective_code)

def process_invocation(self, caa_request: CaaCheckRequest):
caa_response = self.caa_checker.check_caa(caa_request)
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
# No running event loop, create a new one
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)

caa_response = event_loop.run_until_complete(self.caa_checker.check_caa(caa_request))
result = {
'statusCode': 200, # note: must be snakeCase
'headers': {'Content-Type': 'application/json'},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import os
import json
import yaml
import asyncio
import aioboto3

from asyncio import Queue
from collections import defaultdict
from importlib import resources

import yaml
from aws_lambda_powertools.utilities.parser import event_parser, envelopes
from pydantic import TypeAdapter, ValidationError, BaseModel
from open_mpic_core.common_domain.check_request import BaseCheckRequest
Expand All @@ -12,10 +19,6 @@
from open_mpic_core.common_domain.enum.check_type import CheckType
from open_mpic_core.mpic_coordinator.domain.remote_perspective import RemotePerspective

import boto3
import os
import json


class PerspectiveEndpointInfo(BaseModel):
arn: str
Expand All @@ -30,7 +33,7 @@ class MpicCoordinatorLambdaHandler:
def __init__(self):
perspectives_json = os.environ['perspectives']
perspectives = {code: PerspectiveEndpoints.model_validate(endpoints) for code, endpoints in json.loads(perspectives_json).items()}
self.all_target_perspective_codes = list(perspectives.keys())
self._all_target_perspective_codes = list(perspectives.keys())
self.default_perspective_count = int(os.environ['default_perspective_count'])
self.global_max_attempts = int(os.environ['absolute_max_attempts']) if 'absolute_max_attempts' in os.environ else None
self.hash_secret = os.environ['hash_secret']
Expand All @@ -42,7 +45,7 @@ def __init__(self):

all_possible_perspectives_by_code = MpicCoordinatorLambdaHandler.load_aws_region_config()
self.target_perspectives = MpicCoordinatorLambdaHandler.convert_codes_to_remote_perspectives(
self.all_target_perspective_codes, all_possible_perspectives_by_code)
self._all_target_perspective_codes, all_possible_perspectives_by_code)

self.mpic_coordinator_configuration = MpicCoordinatorConfiguration(
self.target_perspectives,
Expand All @@ -60,6 +63,28 @@ def __init__(self):
self.mpic_request_adapter = TypeAdapter(MpicRequest)
self.check_response_adapter = TypeAdapter(CheckResponse)

self._session = aioboto3.Session()
self._client_pools = defaultdict(lambda: Queue(maxsize=10)) # pool of 10 clients per region

async def initialize_client_pools(self):
# Call this during cold start
for perspective_code in self._all_target_perspective_codes:
for _ in range(10): # pre-populate pool
client = await self._session.client('lambda', perspective_code).__aenter__()
await self._client_pools[perspective_code].put(client)

async def get_lambda_client(self, perspective_code: str):
return await self._client_pools[perspective_code].get()

async def release_lambda_client(self, perspective_code: str, client):
await self._client_pools[perspective_code].put(client)

# async def cleanup(self): # Call this during shutdown if needed (maybe not needed in Lambda)
# for pool in self._client_pools.values():
# while not pool.empty():
# client = await pool.get()
# await client.__aexit__(None, None, None)

@staticmethod
def load_aws_region_config() -> dict[str, RemotePerspective]:
"""
Expand Down Expand Up @@ -88,24 +113,25 @@ def convert_codes_to_remote_perspectives(perspective_codes: list[str],
return remote_perspectives

# This function MUST validate its response and return a proper open_mpic_core object type.
def call_remote_perspective(self, perspective: RemotePerspective, check_type: CheckType, check_request: BaseCheckRequest) -> CheckResponse:
# Uses dcv_arn_list, caa_arn_list
client = boto3.client('lambda', perspective.code)
function_endpoint_info = self.remotes_per_perspective_per_check_type[check_type][perspective.code]
response = client.invoke( # AWS Lambda-specific structure
async def call_remote_perspective(self, perspective: RemotePerspective, check_type: CheckType, check_request: BaseCheckRequest) -> CheckResponse:
client = await self.get_lambda_client(perspective.code)
try:
function_endpoint_info = self.remotes_per_perspective_per_check_type[check_type][perspective.code]
response = await client.invoke( # AWS Lambda-specific structure
FunctionName=function_endpoint_info.arn,
InvocationType='RequestResponse',
Payload=check_request.model_dump_json() # AWS Lambda functions expect a JSON string for payload
)
response_payload = json.loads(response['Payload'].read().decode('utf-8'))
try:
response_payload = json.loads(await response['Payload'].read())
return self.check_response_adapter.validate_json(response_payload['body'])
except ValidationError as ve:
# We might want to handle this differently later.
raise ve
finally:
await self.release_lambda_client(perspective.code, client)

def process_invocation(self, mpic_request: MpicRequest) -> dict:
mpic_response = self.mpic_coordinator.coordinate_mpic(mpic_request)
async def process_invocation(self, mpic_request: MpicRequest) -> dict:
mpic_response = await self.mpic_coordinator.coordinate_mpic(mpic_request)
return {
'statusCode': 200,
'headers': {'Content-Type': 'application/json'},
Expand All @@ -117,13 +143,27 @@ def process_invocation(self, mpic_request: MpicRequest) -> dict:
_handler = None


async def initialize_handler() -> MpicCoordinatorLambdaHandler:
handler = MpicCoordinatorLambdaHandler()
await handler.initialize_client_pools()
return handler


def get_handler() -> MpicCoordinatorLambdaHandler:
"""
Singleton pattern to avoid recreating the handler on every Lambda invocation
Singleton pattern to avoid recreating the handler on every Lambda invocation.
Performs lazy initialization using event loop.
"""
global _handler
if _handler is None:
_handler = MpicCoordinatorLambdaHandler()
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
# No running event loop, create a new one
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)

_handler = event_loop.run_until_complete(initialize_handler())
return _handler


Expand Down Expand Up @@ -157,4 +197,5 @@ def wrapper(*args, **kwargs):
@handle_lambda_exceptions
@event_parser(model=MpicRequest, envelope=envelopes.ApiGatewayEnvelope) # AWS Lambda Powertools decorator
def lambda_handler(event: MpicRequest, context): # AWS Lambda entry point
return get_handler().process_invocation(event)
handler = get_handler()
return asyncio.get_event_loop().run_until_complete(handler.process_invocation(event))
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

from aws_lambda_powertools.utilities.parser import event_parser

from open_mpic_core.common_domain.check_request import DcvCheckRequest
Expand All @@ -11,7 +13,14 @@ def __init__(self):
self.dcv_checker = MpicDcvChecker(self.perspective_code)

def process_invocation(self, dcv_request: DcvCheckRequest):
dcv_response = self.dcv_checker.check_dcv(dcv_request)
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
# No running event loop, create a new one
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)

dcv_response = event_loop.run_until_complete(self.dcv_checker.check_dcv(dcv_request))
status_code = 200
if dcv_response.errors is not None and len(dcv_response.errors) > 0:
if dcv_response.errors[0].error_type == '404':
Expand Down
59 changes: 46 additions & 13 deletions tests/unit/aws_lambda_mpic/test_mpic_coordinator_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
from datetime import datetime
from importlib import resources
from unittest.mock import AsyncMock

import pytest
import yaml
Expand Down Expand Up @@ -55,17 +56,43 @@ def set_env_variables():
class_scoped_monkeypatch.setenv(k, v)
yield class_scoped_monkeypatch # restore the environment afterward

def call_remote_perspective__should_make_aws_lambda_call_with_provided_arguments_and_return_check_response(self, set_env_variables, mocker):
mocker.patch('botocore.client.BaseClient._make_api_call', side_effect=self.create_successful_boto3_api_call_response_for_dcv_check)
async def call_remote_perspective__should_make_aws_lambda_call_with_provided_arguments_and_return_check_response(self, set_env_variables, mocker):
# Mock the aioboto3 client creation and context manager
mock_client = AsyncMock()
mock_client.invoke = AsyncMock(side_effect=self.create_successful_aioboto3_response_for_dcv_check)

# Mock the __aenter__ method that gets called in initialize_client_pools()
mock_client.__aenter__.return_value = mock_client

# Mock the session creation and client initialization
mock_session = mocker.patch('aioboto3.Session')
mock_session.return_value.client.return_value = mock_client

# mocker.patch('botocore.client.BaseClient._make_api_call', side_effect=self.create_successful_boto3_api_call_response_for_dcv_check)
dcv_check_request = ValidCheckCreator.create_valid_dns_check_request()
mpic_coordinator_lambda_handler = MpicCoordinatorLambdaHandler()
check_response = mpic_coordinator_lambda_handler.call_remote_perspective(RemotePerspective(code='us-west-1', rir='arin'),
CheckType.DCV,
dcv_check_request)

await mpic_coordinator_lambda_handler.initialize_client_pools()

perspective_code = 'us-west-1'
check_response = await mpic_coordinator_lambda_handler.call_remote_perspective(
RemotePerspective(code=perspective_code, rir='arin'),
CheckType.DCV,
dcv_check_request
)
assert check_response.check_passed is True
# hijacking the value of 'perspective_code' to verify that the right arguments got passed to the call
assert check_response.perspective_code == dcv_check_request.domain_or_ip_target

function_endpoint_info = mpic_coordinator_lambda_handler.remotes_per_perspective_per_check_type[CheckType.DCV][perspective_code]

# Verify the mock was called correctly
mock_client.invoke.assert_called_once_with(
FunctionName=function_endpoint_info.arn,
InvocationType='RequestResponse',
Payload=dcv_check_request.model_dump_json()
)

def lambda_handler__should_return_400_error_and_details_given_invalid_request_body(self):
request = ValidMpicRequestCreator.create_valid_dcv_mpic_request()
# noinspection PyTypeChecker
Expand All @@ -89,7 +116,7 @@ def lambda_handler__should_return_400_error_and_details_given_invalid_check_type
result_body = json.loads(result['body'])
assert result_body['validation_issues'][0]['type'] == 'literal_error'

def lambda_handler__should_return_400_error_given_logically_invalid_request(self):
def lambda_handler__should_return_400_error_given_logically_invalid_request(self, set_env_variables):
request = ValidMpicRequestCreator.create_valid_dcv_mpic_request()
request.orchestration_parameters.perspective_count = 1
api_request = TestMpicCoordinatorLambda.create_api_gateway_request()
Expand Down Expand Up @@ -158,14 +185,20 @@ def create_successful_boto3_api_call_response_for_dcv_check(self, lambda_method,
return {'Payload': streaming_body_response}

# noinspection PyUnusedLocal
def create_error_boto3_api_call_response(self, lambda_method, lambda_configuration):
# note: all perspective response details will be identical in these tests due to this mocking
expected_response_body = 'Something went wrong'
expected_response = {'statusCode': 500, 'body': expected_response_body}
async def create_successful_aioboto3_response_for_dcv_check(self, *args, **kwargs):
check_request = DcvCheckRequest.model_validate_json(kwargs['Payload'])
# hijacking the value of 'perspective_code' to verify that the right arguments got passed to the call
expected_response_body = DcvCheckResponse(perspective_code=check_request.domain_or_ip_target,
check_passed=True, details=DcvDnsCheckResponseDetails(validation_method=DcvValidationMethod.ACME_DNS_01))
expected_response = {'statusCode': 200, 'body': expected_response_body.model_dump_json()}
json_bytes = json.dumps(expected_response).encode('utf-8')
file_like_response = io.BytesIO(json_bytes)
streaming_body_response = StreamingBody(file_like_response, len(json_bytes))
return {'Payload': streaming_body_response}

# Mock the response structure that aioboto3 would return
class MockStreamingBody:
# noinspection PyMethodMayBeStatic
async def read(self):
return json_bytes
return {'Payload': MockStreamingBody()}

@staticmethod
def create_caa_mpic_response():
Expand Down
Loading