Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ def __init__(self):
self.perspective_code = os.environ['AWS_REGION']
self.dcv_checker = MpicDcvChecker(self.perspective_code)

async def initialize(self):
await self.dcv_checker.initialize()

def process_invocation(self, dcv_request: DcvCheckRequest):
try:
event_loop = asyncio.get_running_loop()
Expand Down Expand Up @@ -39,13 +42,26 @@ def process_invocation(self, dcv_request: DcvCheckRequest):
_handler = None


async def initialize_handler() -> MpicDcvCheckerLambdaHandler:
handler = MpicDcvCheckerLambdaHandler()
await handler.initialize()
return handler


def get_handler() -> MpicDcvCheckerLambdaHandler:
"""
Singleton pattern to avoid recreating the handler on every Lambda invocation
"""
global _handler
if _handler is None:
_handler = MpicDcvCheckerLambdaHandler()
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
# No running event loop, create a new one
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)

_handler = event_loop.run_until_complete(initialize_handler())
return _handler


Expand Down
33 changes: 33 additions & 0 deletions tests/unit/aws_lambda_mpic/test_dcv_checker_lambda.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import asyncio
import time
from asyncio import StreamReader
from unittest.mock import MagicMock, AsyncMock

import pytest
from aiohttp import ClientResponse
from multidict import CIMultiDictProxy, CIMultiDict
from yarl import URL

import aws_lambda_mpic.mpic_dcv_checker_lambda.mpic_dcv_checker_lambda_function as mpic_dcv_checker_lambda_function
from open_mpic_core.common_domain.validation_error import MpicValidationError
Expand All @@ -9,6 +16,7 @@
from open_mpic_core_test.test_util.valid_check_creator import ValidCheckCreator


# noinspection PyMethodMayBeStatic
class TestDcvCheckerLambda:
@staticmethod
@pytest.fixture(scope='class')
Expand Down Expand Up @@ -53,6 +61,31 @@ def lambda_handler__should_return_appropriate_status_code_given_errors_in_respon
result = mpic_dcv_checker_lambda_function.lambda_handler(dcv_check_request, None)
assert result == mock_return_value

def lambda_handler__should_ensure_dcv_checker_is_fully_initialized_to_perform_http_based_checks(self, set_env_variables, mocker):
dcv_check_request = ValidCheckCreator.create_valid_http_check_request()
expected_challenge_value = dcv_check_request.dcv_check_parameters.validation_details.challenge_value

# this test requires getting pretty far into the Dcv Checker execution; need to mock an aiohttp.ClientResponse
event_loop = asyncio.get_event_loop()
response = ClientResponse(
method='GET', url=URL('http://example.com'), writer=MagicMock(), continue100=None,
timer=AsyncMock(), request_info=AsyncMock(), traces=[], loop=event_loop, session=AsyncMock()
)
response.status = 200
response.content = StreamReader(loop=event_loop)
response.content.feed_data(bytes(expected_challenge_value.encode('utf-8')))
response.content.feed_eof()
response._headers = CIMultiDictProxy(CIMultiDict({
'Content-Type': 'text/plain; charset=utf-8', 'Content-Length': str(len(expected_challenge_value))
}))

mocker.patch(
'aiohttp.ClientSession.get',
side_effect=lambda *args, **kwargs: AsyncMock(__aenter__=AsyncMock(return_value=response))
)
result = mpic_dcv_checker_lambda_function.lambda_handler(dcv_check_request, None)
assert result['statusCode'] == 200

@staticmethod
def create_dcv_check_response():
return DcvCheckResponse(perspective_code='us-east-1', check_passed=True,
Expand Down
Loading