Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ dependencies = [
"pydantic==2.8.2",
"aiohttp==3.11.11",
"aws-lambda-powertools[parser]==3.2.0",
"open-mpic-core==4.6.1",
"open-mpic-core==4.7.2",
"aioboto3~=13.3.0",
"black==24.8.0",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -135,3 +136,6 @@ include_namespace_packages = true
omit = [
"*/src/*/__about__.py",
]

[tool.black]
line-length = 120
2 changes: 1 addition & 1 deletion src/aws_lambda_mpic/__about__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.0"
__version__ = "0.4.1"
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,28 @@

from aws_lambda_powertools.utilities.parser import event_parser

from open_mpic_core.common_domain.check_request import CaaCheckRequest
from open_mpic_core.mpic_caa_checker.mpic_caa_checker import MpicCaaChecker
from open_mpic_core.common_util.trace_level_logger import get_logger
from open_mpic_core import CaaCheckRequest
from open_mpic_core import MpicCaaChecker
from open_mpic_core import get_logger

logger = get_logger(__name__)


class MpicCaaCheckerLambdaHandler:
def __init__(self):
self.perspective_code = os.environ['AWS_REGION']
self.default_caa_domain_list = os.environ['default_caa_domains'].split("|")
self.log_level = os.environ['log_level'] if 'log_level' in os.environ else None
self.perspective_code = os.environ["AWS_REGION"]
self.default_caa_domain_list = os.environ["default_caa_domains"].split("|")
self.log_level = os.environ["log_level"] if "log_level" in os.environ else None

self.logger = logger.getChild(self.__class__.__name__)
if self.log_level:
self.logger.setLevel(self.log_level)

self.caa_checker = MpicCaaChecker(default_caa_domain_list=self.default_caa_domain_list,
perspective_code=self.perspective_code,
log_level=self.logger.level)
self.caa_checker = MpicCaaChecker(
default_caa_domain_list=self.default_caa_domain_list,
perspective_code=self.perspective_code,
log_level=self.logger.level,
)

def process_invocation(self, caa_request: CaaCheckRequest):
try:
Expand All @@ -34,9 +36,9 @@ def process_invocation(self, caa_request: CaaCheckRequest):

caa_response = event_loop.run_until_complete(self.caa_checker.check_caa(caa_request))
result = {
'statusCode': 200, # note: must be snakeCase
'headers': {'Content-Type': 'application/json'},
'body': caa_response.model_dump_json()
"statusCode": 200, # note: must be snakeCase
"headers": {"Content-Type": "application/json"},
"body": caa_response.model_dump_json(),
}
return result

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
import os
import json
import traceback
Expand All @@ -13,15 +12,12 @@
from pydantic import TypeAdapter, ValidationError, BaseModel
from aws_lambda_powertools.utilities.parser import event_parser, envelopes

from open_mpic_core.common_domain.check_request import BaseCheckRequest
from open_mpic_core.common_domain.check_response import CheckResponse
from open_mpic_core.mpic_coordinator.domain.mpic_request import MpicRequest
from open_mpic_core.mpic_coordinator.domain.mpic_request_validation_error import MpicRequestValidationError
from open_mpic_core.mpic_coordinator.messages.mpic_request_validation_messages import MpicRequestValidationMessages
from open_mpic_core.mpic_coordinator.mpic_coordinator import MpicCoordinator, MpicCoordinatorConfiguration
from open_mpic_core.common_domain.enum.check_type import CheckType
from open_mpic_core.mpic_coordinator.domain.remote_perspective import RemotePerspective
from open_mpic_core.common_util.trace_level_logger import get_logger
from open_mpic_core import MpicRequest, CheckRequest, CheckResponse
from open_mpic_core import MpicRequestValidationError, MpicRequestValidationMessages
from open_mpic_core import MpicCoordinator, MpicCoordinatorConfiguration
from open_mpic_core import CheckType
from open_mpic_core import RemotePerspective
from open_mpic_core import get_logger

logger = get_logger(__name__)

Expand All @@ -37,35 +33,46 @@ class PerspectiveEndpoints(BaseModel):

class MpicCoordinatorLambdaHandler:
def __init__(self):
perspectives_json = os.environ['perspectives']
perspectives = {code: PerspectiveEndpoints.model_validate(endpoints) for code, endpoints in json.loads(perspectives_json).items()}
perspectives_json = os.environ["perspectives"]
perspectives = {
code: PerspectiveEndpoints.model_validate(endpoints)
for code, endpoints in json.loads(perspectives_json).items()
}
self._all_target_perspective_codes = list(perspectives.keys())
self.default_perspective_count = int(os.environ['default_perspective_count'])
self.global_max_attempts = int(os.environ['absolute_max_attempts']) if 'absolute_max_attempts' in os.environ else None
self.hash_secret = os.environ['hash_secret']
self.log_level = os.getenv('log_level', None)
self.default_perspective_count = int(os.environ["default_perspective_count"])
self.global_max_attempts = (
int(os.environ["absolute_max_attempts"]) if "absolute_max_attempts" in os.environ else None
)
self.hash_secret = os.environ["hash_secret"]
self.log_level = os.getenv("log_level", None)

self.logger = logger.getChild(self.__class__.__name__)
if self.log_level:
self.logger.setLevel(self.log_level)

self.remotes_per_perspective_per_check_type = {
CheckType.DCV: {perspective_code: perspective_config.dcv_endpoint_info for perspective_code, perspective_config in perspectives.items()},
CheckType.CAA: {perspective_code: perspective_config.caa_endpoint_info for perspective_code, perspective_config in perspectives.items()}
CheckType.DCV: {
perspective_code: perspective_config.dcv_endpoint_info
for perspective_code, perspective_config in perspectives.items()
},
CheckType.CAA: {
perspective_code: perspective_config.caa_endpoint_info
for perspective_code, perspective_config in perspectives.items()
},
}

all_possible_perspectives_by_code = MpicCoordinatorLambdaHandler.load_aws_region_config()
self.target_perspectives = MpicCoordinatorLambdaHandler.convert_codes_to_remote_perspectives(
self._all_target_perspective_codes, all_possible_perspectives_by_code)
self._all_target_perspective_codes, all_possible_perspectives_by_code
)

self.mpic_coordinator_configuration = MpicCoordinatorConfiguration(
self.target_perspectives,
self.default_perspective_count,
self.global_max_attempts,
self.hash_secret
self.target_perspectives, self.default_perspective_count, self.global_max_attempts, self.hash_secret
)

self.mpic_coordinator = MpicCoordinator(self.call_remote_perspective, self.mpic_coordinator_configuration, self.logger.level)
self.mpic_coordinator = MpicCoordinator(
self.call_remote_perspective, self.mpic_coordinator_configuration, self.logger.level
)

# for correct deserialization of responses based on discriminator field (check type)
self.mpic_request_adapter = TypeAdapter(MpicRequest)
Expand All @@ -78,7 +85,7 @@ async def initialize_client_pools(self):
# Call this during cold start
for perspective_code in self._all_target_perspective_codes:
for _ in range(10): # pre-populate pool
client = await self._session.client('lambda', perspective_code).__aenter__()
client = await self._session.client("lambda", perspective_code).__aenter__()
await self._client_pools[perspective_code].put(client)

async def get_lambda_client(self, perspective_code: str):
Expand All @@ -99,16 +106,17 @@ def load_aws_region_config() -> dict[str, RemotePerspective]:
Reads in the available perspectives from a configuration yaml and returns them as a dict (map).
:return: dict of available perspectives with region code as key
"""
with resources.files('resources').joinpath('aws_region_config.yaml').open('r') as file:
with resources.files("resources").joinpath("aws_region_config.yaml").open("r") as file:
aws_region_config_yaml = yaml.safe_load(file)
aws_region_type_adapter = TypeAdapter(list[RemotePerspective])
aws_regions_list = aws_region_type_adapter.validate_python(aws_region_config_yaml['aws_available_regions'])
aws_regions_list = aws_region_type_adapter.validate_python(aws_region_config_yaml["aws_available_regions"])
aws_regions_dict = {region.code: region for region in aws_regions_list}
return aws_regions_dict

@staticmethod
def convert_codes_to_remote_perspectives(perspective_codes: list[str],
all_possible_perspectives_by_code: dict[str, RemotePerspective]) -> list[RemotePerspective]:
def convert_codes_to_remote_perspectives(
perspective_codes: list[str], all_possible_perspectives_by_code: dict[str, RemotePerspective]
) -> list[RemotePerspective]:
remote_perspectives = []

for perspective_code in perspective_codes:
Expand All @@ -121,29 +129,31 @@ def convert_codes_to_remote_perspectives(perspective_codes: list[str],
return remote_perspectives

# This function MUST validate its response and return a proper open_mpic_core object type.
async def call_remote_perspective(self, perspective: RemotePerspective, check_type: CheckType, check_request: BaseCheckRequest) -> CheckResponse:
async def call_remote_perspective(
self, perspective: RemotePerspective, check_type: CheckType, check_request: CheckRequest
) -> CheckResponse:
client = await self.get_lambda_client(perspective.code)
try:
function_endpoint_info = self.remotes_per_perspective_per_check_type[check_type][perspective.code]
response = await client.invoke( # AWS Lambda-specific structure
FunctionName=function_endpoint_info.arn,
InvocationType='RequestResponse',
Payload=check_request.model_dump_json() # AWS Lambda functions expect a JSON string for payload
InvocationType="RequestResponse",
Payload=check_request.model_dump_json(), # AWS Lambda functions expect a JSON string for payload
)
response_payload = json.loads(await response['Payload'].read())
return self.check_response_adapter.validate_json(response_payload['body'])
response_payload = json.loads(await response["Payload"].read())
return self.check_response_adapter.validate_json(response_payload["body"])
except ValidationError as ve:
self.logger.log(level=logging.ERROR, msg=f"Validation error in response from {perspective.code}: {ve}")
self.logger.error(msg=f"Validation error in response from {perspective.code}: {ve}")
raise ve
finally:
await self.release_lambda_client(perspective.code, client)

async def process_invocation(self, mpic_request: MpicRequest) -> dict:
mpic_response = await self.mpic_coordinator.coordinate_mpic(mpic_request)
return {
'statusCode': 200,
'headers': {'Content-Type': 'application/json'},
'body': mpic_response.model_dump_json()
"statusCode": 200,
"headers": {"Content-Type": "application/json"},
"body": mpic_response.model_dump_json(),
}


Expand Down Expand Up @@ -178,9 +188,9 @@ def get_handler() -> MpicCoordinatorLambdaHandler:
def handle_lambda_exceptions(func):
def build_400_response(error_name, issues_list):
return {
'statusCode': 400,
'headers': {'Content-Type': 'application/json'},
'body': json.dumps({'error': error_name, 'validation_issues': issues_list})
"statusCode": 400,
"headers": {"Content-Type": "application/json"},
"body": json.dumps({"error": error_name, "validation_issues": issues_list}),
}

def wrapper(*args, **kwargs):
Expand All @@ -190,16 +200,18 @@ def wrapper(*args, **kwargs):
validation_issues = json.loads(e.__notes__[0])
return build_400_response(MpicRequestValidationMessages.REQUEST_VALIDATION_FAILED.key, validation_issues)
except ValidationError as validation_error:
return build_400_response(MpicRequestValidationMessages.REQUEST_VALIDATION_FAILED.key, validation_error.errors())
return build_400_response(
MpicRequestValidationMessages.REQUEST_VALIDATION_FAILED.key, validation_error.errors()
)
except Exception as e:
logger.error(f"An error occurred: {str(e)}")
print(traceback.format_exc())
print(f"BOY HOWDY error occurred: {str(e)}")
return {
'statusCode': 500,
'headers': {'Content-Type': 'application/json'},
'body': json.dumps({'error': str(e)})
"statusCode": 500,
"headers": {"Content-Type": "application/json"},
"body": json.dumps({"error": str(e)}),
}

return wrapper


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,24 @@

from aws_lambda_powertools.utilities.parser import event_parser

from open_mpic_core.common_domain.check_request import DcvCheckRequest
from open_mpic_core.mpic_dcv_checker.mpic_dcv_checker import MpicDcvChecker
from open_mpic_core.common_util.trace_level_logger import get_logger
from open_mpic_core import DcvCheckRequest, MpicDcvChecker
from open_mpic_core import get_logger

logger = get_logger(__name__)


class MpicDcvCheckerLambdaHandler:
def __init__(self):
self.perspective_code = os.environ['AWS_REGION']
self.log_level = os.environ['log_level'] if 'log_level' in os.environ else None
self.perspective_code = os.environ["AWS_REGION"]
self.log_level = os.environ["log_level"] if "log_level" in os.environ else None

self.logger = logger.getChild(self.__class__.__name__)
if self.log_level:
self.logger.setLevel(self.log_level)

self.dcv_checker = MpicDcvChecker(perspective_code=self.perspective_code,
reuse_http_client=False,
log_level=self.logger.level)
self.dcv_checker = MpicDcvChecker(
perspective_code=self.perspective_code, reuse_http_client=False, log_level=self.logger.level
)

def process_invocation(self, dcv_request: DcvCheckRequest):
try:
Expand All @@ -37,14 +36,14 @@ def process_invocation(self, dcv_request: DcvCheckRequest):
dcv_response = event_loop.run_until_complete(self.dcv_checker.check_dcv(dcv_request))
status_code = 200
if dcv_response.errors is not None and len(dcv_response.errors) > 0:
if dcv_response.errors[0].error_type == '404':
if dcv_response.errors[0].error_type == "404":
status_code = 404
else:
status_code = 500
result = {
'statusCode': status_code,
'headers': {'Content-Type': 'application/json'},
'body': dcv_response.model_dump_json()
"statusCode": status_code,
"headers": {"Content-Type": "application/json"},
"body": dcv_response.model_dump_json(),
}
return result

Expand Down
Loading