Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions open-tofu/aws-perspective.tf.template
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ resource "aws_lambda_function" "mpic_dcv_checker_lambda_{{region}}" {
timeout = 60
runtime = "python3.11"
architectures = ["arm64"]
memory_size = var.perspective_memory_size
layers = [
aws_lambda_layer_version.python3_open_mpic_layer_{{region}}.arn,
]
Expand Down Expand Up @@ -240,6 +241,7 @@ resource "aws_lambda_function" "mpic_caa_checker_lambda_{{region}}" {
timeout = 60
runtime = "python3.11"
architectures = ["arm64"]
memory_size = var.perspective_memory_size
layers = [
aws_lambda_layer_version.python3_open_mpic_layer_{{region}}.arn,
]
Expand Down
7 changes: 6 additions & 1 deletion open-tofu/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ variable "dnssec_enabled" {
variable "coordinator_memory_size" {
type = number
description = "MPIC Coordinator Lambda Function Memory"
default = 512
default = 256
}

variable "perspective_memory_size" {
type = number
description = "MPIC Perspective Lambda Function Memory"
default = 256
}
2 changes: 1 addition & 1 deletion src/aws_lambda_mpic/__about__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.4"
__version__ = "1.0.5"
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,7 @@ def __init__(self):
)

def process_invocation(self, caa_request: CaaCheckRequest):
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
# No running event loop, create a new one
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)

caa_response = event_loop.run_until_complete(self.caa_checker.check_caa(caa_request))
caa_response = asyncio.get_event_loop().run_until_complete(self.caa_checker.check_caa(caa_request))
result = {
"statusCode": 200, # note: must be snakeCase
"headers": {"Content-Type": "application/json"},
Expand All @@ -50,10 +43,16 @@ def get_handler() -> MpicCaaCheckerLambdaHandler:
"""
global _handler
if _handler is None:
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)
_handler = MpicCaaCheckerLambdaHandler()
return _handler


if os.environ.get("AWS_LAMBDA_FUNCTION_NAME") is not None:
get_handler()


# noinspection PyUnusedLocal
# for now, we are not using context, but it is required by the lambda handler signature
@event_parser(model=CaaCheckRequest)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,27 +82,13 @@ def __init__(self):
self.mpic_request_adapter = TypeAdapter(MpicRequest)
self.check_response_adapter = TypeAdapter(CheckResponse)

self._session = aioboto3.Session()
self._client_pools = defaultdict(lambda: Queue(maxsize=10)) # pool of 10 clients per region
self._clients = {}
asyncio.get_event_loop().run_until_complete(self.initialize_clients())

async def initialize_client_pools(self):
# Call this during cold start
async def initialize_clients(self):
session = aioboto3.Session()
for perspective_code in self._all_target_perspective_codes:
for _ in range(10): # prepopulate pool
client = await self._session.client("lambda", perspective_code).__aenter__()
await self._client_pools[perspective_code].put(client)

async def get_lambda_client(self, perspective_code: str):
return await self._client_pools[perspective_code].get()

async def release_lambda_client(self, perspective_code: str, client):
await self._client_pools[perspective_code].put(client)

# async def cleanup(self): # Call this during shutdown if needed (maybe not needed in Lambda)
# for pool in self._client_pools.values():
# while not pool.empty():
# client = await pool.get()
# await client.__aexit__(None, None, None)
self._clients[perspective_code] = await session.client("lambda", perspective_code).__aenter__()

@staticmethod
def load_aws_region_config() -> dict[str, RemotePerspective]:
Expand Down Expand Up @@ -136,24 +122,21 @@ def convert_codes_to_remote_perspectives(
async def call_remote_perspective(
self, perspective: RemotePerspective, check_type: CheckType, check_request: CheckRequest
) -> CheckResponse:
client = await self.get_lambda_client(perspective.code)
try:
function_endpoint_info = self.remotes_per_perspective_per_check_type[check_type][perspective.code]
response = await client.invoke( # AWS Lambda-specific structure
FunctionName=function_endpoint_info.arn,
InvocationType="RequestResponse",
Payload=check_request.model_dump_json(), # AWS Lambda functions expect a JSON string for payload
)
response_payload = await response["Payload"].read()
if 'FunctionError' in response:
raise LambdaExecutionException(f"Lambda execution error: {response_payload.decode('utf-8')}")
response_payload = json.loads(response_payload)
return self.check_response_adapter.validate_json(response_payload["body"])
finally:
await self.release_lambda_client(perspective.code, client)

async def process_invocation(self, mpic_request: MpicRequest) -> dict:
mpic_response = await self.mpic_coordinator.coordinate_mpic(mpic_request)
client = self._clients[perspective.code]
function_endpoint_info = self.remotes_per_perspective_per_check_type[check_type][perspective.code]
response = await client.invoke( # AWS Lambda-specific structure
FunctionName=function_endpoint_info.arn,
InvocationType="RequestResponse",
Payload=check_request.model_dump_json(), # AWS Lambda functions expect a JSON string for payload
)
response_payload = await response["Payload"].read()
if 'FunctionError' in response:
raise LambdaExecutionException(f"Lambda execution error: {response_payload.decode('utf-8')}")
response_payload = json.loads(response_payload)
return self.check_response_adapter.validate_json(response_payload["body"])

def process_invocation(self, mpic_request: MpicRequest) -> dict:
mpic_response = asyncio.get_event_loop().run_until_complete(self.mpic_coordinator.coordinate_mpic(mpic_request))
return {
"statusCode": 200,
"headers": {"Content-Type": "application/json"},
Expand All @@ -165,30 +148,24 @@ async def process_invocation(self, mpic_request: MpicRequest) -> dict:
_handler = None


async def initialize_handler() -> MpicCoordinatorLambdaHandler:
handler = MpicCoordinatorLambdaHandler()
await handler.initialize_client_pools()
return handler


def get_handler() -> MpicCoordinatorLambdaHandler:
"""
Singleton pattern to avoid recreating the handler on every Lambda invocation.
Performs lazy initialization using event loop.
"""
global _handler
if _handler is None:
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
# No running event loop, create a new one
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)

_handler = event_loop.run_until_complete(initialize_handler())
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)
_handler = MpicCoordinatorLambdaHandler()
return _handler


# Not eagerly initialize the handler when running in a test environment as it leads to errors.
if os.environ.get("AWS_LAMBDA_FUNCTION_NAME") is not None:
get_handler()


def handle_lambda_exceptions(func):
def build_400_response(error_name, issues_list):
return {
Expand Down Expand Up @@ -224,5 +201,4 @@ def wrapper(*args, **kwargs):
@handle_lambda_exceptions
@event_parser(model=MpicRequest, envelope=envelopes.ApiGatewayEnvelope) # AWS Lambda Powertools decorator
def lambda_handler(event: MpicRequest, context): # AWS Lambda entry point
handler = get_handler()
return asyncio.get_event_loop().run_until_complete(handler.process_invocation(event))
return get_handler().process_invocation(event)
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,10 @@ def __init__(self):
self.dcv_checker = MpicDcvChecker(reuse_http_client=False, log_level=self.logger.level)

def process_invocation(self, dcv_request: DcvCheckRequest):
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
# No running event loop, create a new one
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)

self.logger.debug("(debug log) Processing DCV check request: %s", dcv_request)
print("(print) Processing DCV check request: %s", dcv_request)

dcv_response = event_loop.run_until_complete(self.dcv_checker.check_dcv(dcv_request))
dcv_response = asyncio.get_event_loop().run_until_complete(self.dcv_checker.check_dcv(dcv_request))
status_code = 200
if dcv_response.errors is not None and len(dcv_response.errors) > 0:
if dcv_response.errors[0].error_type == "404":
Expand All @@ -55,10 +48,16 @@ def get_handler() -> MpicDcvCheckerLambdaHandler:
"""
global _handler
if _handler is None:
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)
_handler = MpicDcvCheckerLambdaHandler()
return _handler


if os.environ.get("AWS_LAMBDA_FUNCTION_NAME") is not None:
get_handler()


# noinspection PyUnusedLocal
# for now, we are not using context, but it is required by the lambda handler signature
@event_parser(model=DcvCheckRequest)
Expand Down
21 changes: 11 additions & 10 deletions tests/unit/aws_lambda_mpic/test_mpic_coordinator_lambda.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import io
import json
from datetime import datetime
Expand Down Expand Up @@ -76,16 +77,16 @@ def set_env_variables():
class_scoped_monkeypatch.setenv(k, v)
yield class_scoped_monkeypatch # restore the environment afterward

async def call_remote_perspective__should_make_aws_lambda_call_with_provided_arguments_and_return_check_response(
def call_remote_perspective__should_make_aws_lambda_call_with_provided_arguments_and_return_check_response(
self, set_env_variables, mocker
):
lambda_handler, mock_client = await self.mock_lambda_handler_for_lambda_invoke(mocker, self.create_successful_aioboto3_response_for_dcv_check)
lambda_handler, mock_client = self.mock_lambda_handler_for_lambda_invoke(mocker, self.create_successful_aioboto3_response_for_dcv_check)

dcv_check_request = ValidCheckCreator.create_valid_dns_check_request()
perspective_code = "us-west-1"
check_response = await lambda_handler.call_remote_perspective(
check_response = asyncio.get_event_loop().run_until_complete(lambda_handler.call_remote_perspective(
RemotePerspective(code=perspective_code, rir="arin"), CheckType.DCV, dcv_check_request
)
))
assert check_response.check_passed is True
# hijacking the value of 'details.found_at' to verify that the right arguments got passed to the call
assert check_response.details.found_at == dcv_check_request.domain_or_ip_target
Expand All @@ -99,18 +100,18 @@ async def call_remote_perspective__should_make_aws_lambda_call_with_provided_arg
Payload=dcv_check_request.model_dump_json(),
)

async def call_remote_perspective__should_make_aws_lambda_call_and_handle_lambda_execution_exceptions(
def call_remote_perspective__should_make_aws_lambda_call_and_handle_lambda_execution_exceptions(
self, set_env_variables, mocker
):
lambda_handler, mock_client = await self.mock_lambda_handler_for_lambda_invoke(mocker, self.create_error_aioboto3_response)
lambda_handler, mock_client = self.mock_lambda_handler_for_lambda_invoke(mocker, self.create_error_aioboto3_response)

class Dummy(BaseModel):
pass

with pytest.raises(LambdaExecutionException) as exc_info:
await lambda_handler.call_remote_perspective(
asyncio.get_event_loop().run_until_complete(lambda_handler.call_remote_perspective(
RemotePerspective(code="us-west-1", rir="arin"), CheckType.DCV, Dummy()
)
))
assert exc_info.value.args[0] == "Lambda execution error: {\"errorMessage\": \"some message\"}"

def lambda_handler__should_return_400_error_and_details_given_invalid_request_body(self):
Expand Down Expand Up @@ -319,7 +320,7 @@ def get_perspectives_by_code_dict_from_file() -> dict[str, RemotePerspective]:
perspectives = perspective_type_adapter.validate_python(perspectives_yaml["aws_available_regions"])
return {perspective.code: perspective for perspective in perspectives}

async def mock_lambda_handler_for_lambda_invoke(self, mocker, lambda_invoke_side_effect):
def mock_lambda_handler_for_lambda_invoke(self, mocker, lambda_invoke_side_effect):
# Mock the aioboto3 client creation and context manager
mock_client = AsyncMock()
mock_client.invoke = AsyncMock(side_effect=lambda_invoke_side_effect)
Expand All @@ -329,8 +330,8 @@ async def mock_lambda_handler_for_lambda_invoke(self, mocker, lambda_invoke_side
mock_session = mocker.patch("aioboto3.Session")
mock_session.return_value.client.return_value = mock_client
lambda_handler = MpicCoordinatorLambdaHandler()
await lambda_handler.initialize_client_pools()
return lambda_handler, mock_client


if __name__ == "__main__":
pytest.main()