diff --git a/aws_lambda_powertools/logging/lambda_context.py b/aws_lambda_powertools/logging/lambda_context.py index 65e9e652a92..56fac28cee1 100644 --- a/aws_lambda_powertools/logging/lambda_context.py +++ b/aws_lambda_powertools/logging/lambda_context.py @@ -1,4 +1,9 @@ -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from aws_lambda_powertools.utilities.typing import LambdaContext class LambdaContextModel: @@ -34,25 +39,47 @@ def __init__( self.function_request_id = function_request_id +def _unwrap_durable_context(context: Any) -> LambdaContext: + """Unwrap Lambda Context from DurableContext if applicable. + + Parameters + ---------- + context : object + Lambda context object or DurableContext + + Returns + ------- + LambdaContext + The unwrapped Lambda context + """ + # Check if this is a DurableContext by duck typing + if hasattr(context, "lambda_context") and hasattr(context, "state"): + return context.lambda_context + + return context + + def build_lambda_context_model(context: Any) -> LambdaContextModel: """Captures Lambda function runtime info to be used across all log statements Parameters ---------- context : object - Lambda context object + Lambda context object or DurableContext Returns ------- LambdaContextModel Lambda context only with select fields """ + # Unwrap DurableContext if applicable + lambda_context = _unwrap_durable_context(context) context = { - "function_name": context.function_name, - "function_memory_size": context.memory_limit_in_mb, - "function_arn": context.invoked_function_arn, - "function_request_id": context.aws_request_id, + "function_name": lambda_context.function_name, + "function_memory_size": lambda_context.memory_limit_in_mb, + "function_arn": lambda_context.invoked_function_arn, + "function_request_id": lambda_context.aws_request_id, } return LambdaContextModel(**context) diff --git a/aws_lambda_powertools/logging/logger.py b/aws_lambda_powertools/logging/logger.py index 89e1fb01da6..28eec0facc3 100644 --- a/aws_lambda_powertools/logging/logger.py +++ b/aws_lambda_powertools/logging/logger.py @@ -474,6 +474,11 @@ def inject_lambda_context( POWERTOOLS_LOGGER_LOG_EVENT : str instruct logger to log Lambda Event (e.g. `"true", "True", "TRUE"`) + Notes + ----- + Supports both standard Lambda Context and DurableContext from AWS Durable Execution SDK. + When DurableContext is passed, it automatically unwraps the underlying Lambda Context. + Example ------- **Captures Lambda contextual runtime info (e.g memory, arn, req_id)** diff --git a/aws_lambda_powertools/metrics/provider/base.py b/aws_lambda_powertools/metrics/provider/base.py index 4db047eae45..4da7f58b665 100644 --- a/aws_lambda_powertools/metrics/provider/base.py +++ b/aws_lambda_powertools/metrics/provider/base.py @@ -15,6 +15,26 @@ logger = logging.getLogger(__name__) +def _unwrap_durable_context(context: Any) -> LambdaContext: + """Unwrap Lambda Context from DurableContext if applicable. + + Parameters + ---------- + context : object + Lambda context object or DurableContext + + Returns + ------- + LambdaContext + The unwrapped Lambda context + """ + # Check if this is a DurableContext by duck typing + if hasattr(context, "lambda_context") and hasattr(context, "state"): + return context.lambda_context + + return context + + class BaseProvider(ABC): """ Interface to create a metrics provider. @@ -178,6 +198,11 @@ def handler(event, context): default_dimensions: dict[str, str], optional metric dimensions as key=value that will always be present + Notes + ----- + Supports both standard Lambda Context and DurableContext from AWS Durable Execution SDK. + When DurableContext is passed, it automatically unwraps the underlying Lambda Context. + Raises ------ e @@ -223,13 +248,15 @@ def _add_cold_start_metric(self, context: Any) -> None: Parameters ---------- context : Any - Lambda context + Lambda context or DurableContext """ if not cold_start.is_cold_start: return logger.debug("Adding cold start metric and function_name dimension") - self.add_cold_start_metric(context=context) + # Unwrap DurableContext if applicable before passing to add_cold_start_metric + lambda_context = _unwrap_durable_context(context) + self.add_cold_start_metric(context=lambda_context) cold_start.is_cold_start = False diff --git a/tests/functional/logger/required_dependencies/test_logger_durable_context.py b/tests/functional/logger/required_dependencies/test_logger_durable_context.py new file mode 100644 index 00000000000..0ec8f98bfbb --- /dev/null +++ b/tests/functional/logger/required_dependencies/test_logger_durable_context.py @@ -0,0 +1,158 @@ +"""Tests for Logger with DurableContext support.""" + +import io +import json +import random +import string +from collections import namedtuple +from unittest.mock import Mock + +import pytest + +from aws_lambda_powertools import Logger +from aws_lambda_powertools.utilities.typing import DurableContextProtocol + + +@pytest.fixture +def stdout(): + return io.StringIO() + + +@pytest.fixture +def lambda_context(): + lambda_context = { + "function_name": "test", + "memory_limit_in_mb": 128, + "invoked_function_arn": "arn:aws:lambda:eu-west-1:809313241:function:test", + "aws_request_id": "52fdfc07-2182-154f-163f-5f0f9a621d72", + } + return namedtuple("LambdaContext", lambda_context.keys())(*lambda_context.values()) + + +@pytest.fixture +def service_name(): + chars = string.ascii_letters + string.digits + return "".join(random.SystemRandom().choice(chars) for _ in range(15)) + + +def capture_logging_output(stdout): + return json.loads(stdout.getvalue().strip()) + + +def capture_multiple_logging_statements_output(stdout): + return [json.loads(line.strip()) for line in stdout.getvalue().split("\n") if line] + + +@pytest.fixture +def durable_context(lambda_context): + """Create a mock DurableContext with embedded Lambda context.""" + durable_ctx = Mock(spec=DurableContextProtocol) + durable_ctx.lambda_context = lambda_context + durable_ctx.state = Mock(operations=[{"id": "op1"}]) + return durable_ctx + + +def test_inject_lambda_context_with_durable_context(durable_context, stdout, service_name): + """Test that inject_lambda_context works with DurableContext.""" + # GIVEN Logger is initialized + logger = Logger(service=service_name, stream=stdout) + + # WHEN a lambda function is decorated with logger and receives DurableContext + @logger.inject_lambda_context + def handler(event, context): + logger.info("Hello from durable function") + + handler({}, durable_context) + + # THEN lambda contextual info from the unwrapped context should be in the logs + log = capture_logging_output(stdout) + + expected_logger_context_keys = ( + "function_name", + "function_memory_size", + "function_arn", + "function_request_id", + ) + for key in expected_logger_context_keys: + assert key in log + + # Verify the actual values match the embedded lambda_context + assert log["function_name"] == durable_context.lambda_context.function_name + assert log["function_memory_size"] == durable_context.lambda_context.memory_limit_in_mb + assert log["function_arn"] == durable_context.lambda_context.invoked_function_arn + assert log["function_request_id"] == durable_context.lambda_context.aws_request_id + assert log["message"] == "Hello from durable function" + + +def test_inject_lambda_context_with_durable_context_log_event(durable_context, stdout, service_name): + """Test that inject_lambda_context with log_event=True works with DurableContext.""" + # GIVEN Logger is initialized + logger = Logger(service=service_name, stream=stdout) + + test_event = {"test_key": "test_value"} + + # WHEN a lambda function is decorated with log_event=True and receives DurableContext + @logger.inject_lambda_context(log_event=True) + def handler(event, context): + logger.info("Processing event") + + handler(test_event, durable_context) + + # THEN both the event and lambda contextual info should be logged + logs = capture_multiple_logging_statements_output(stdout) + assert len(logs) >= 2 # At least event log and info log + + # First log should be the event + assert logs[0]["message"] == test_event + + +def test_inject_lambda_context_with_durable_context_clear_state(durable_context, stdout, service_name): + """Test that inject_lambda_context with clear_state works with DurableContext.""" + # GIVEN Logger is initialized with custom keys + logger = Logger(service=service_name, stream=stdout) + logger.append_keys(custom_key="initial_value") + + # WHEN a lambda function is decorated with clear_state=True and receives DurableContext + @logger.inject_lambda_context(clear_state=True) + def handler(event, context): + logger.info("After clear state") + + handler({}, durable_context) + + # THEN the custom key should be cleared and lambda context should be present + log = capture_logging_output(stdout) + + # Lambda context fields should be present + assert "function_name" in log + assert log["function_name"] == durable_context.lambda_context.function_name + + # Custom key should not be present (cleared) + assert "custom_key" not in log or log.get("custom_key") != "initial_value" + + +def test_inject_lambda_context_standard_context_still_works(lambda_context, stdout, service_name): + """Test that standard Lambda context still works (regression test).""" + # GIVEN Logger is initialized + logger = Logger(service=service_name, stream=stdout) + + # WHEN a lambda function is decorated with logger and receives standard LambdaContext + @logger.inject_lambda_context + def handler(event, context): + logger.info("Hello from standard lambda") + + handler({}, lambda_context) + + # THEN lambda contextual info should be in the logs + log = capture_logging_output(stdout) + + expected_logger_context_keys = ( + "function_name", + "function_memory_size", + "function_arn", + "function_request_id", + ) + for key in expected_logger_context_keys: + assert key in log + + assert log["function_name"] == lambda_context.function_name + assert log["message"] == "Hello from standard lambda" diff --git a/tests/functional/metrics/required_dependencies/test_metrics_durable_context.py b/tests/functional/metrics/required_dependencies/test_metrics_durable_context.py new file mode 100644 index 00000000000..ad233376057 --- /dev/null +++ b/tests/functional/metrics/required_dependencies/test_metrics_durable_context.py @@ -0,0 +1,174 @@ +"""Tests for Metrics with DurableContext support.""" + +import json +from collections import namedtuple +from unittest.mock import Mock + +import pytest + +from aws_lambda_powertools import Metrics + +# Reset cold start flag before each test +from aws_lambda_powertools.metrics.provider import cold_start +from aws_lambda_powertools.utilities.typing import DurableContextProtocol + + +def capture_metrics_output(capsys): + return json.loads(capsys.readouterr().out.strip()) + + +def capture_metrics_output_multiple_emf_objects(capsys): + return [json.loads(line.strip()) for line in capsys.readouterr().out.split("\n") if line] + + +def reset_cold_start_flag(): + cold_start.is_cold_start = True + + +@pytest.fixture +def lambda_context(): + lambda_context = { + "function_name": "test", + "memory_limit_in_mb": 128, + "invoked_function_arn": "arn:aws:lambda:eu-west-1:809313241:function:test", + "aws_request_id": "52fdfc07-2182-154f-163f-5f0f9a621d72", + } + return namedtuple("LambdaContext", lambda_context.keys())(*lambda_context.values()) + + +@pytest.fixture +def durable_context(lambda_context): + """Create a mock DurableContext with embedded Lambda context.""" + durable_ctx = Mock(spec=DurableContextProtocol) + durable_ctx.lambda_context = lambda_context + durable_ctx.state = Mock(operations=[{"id": "op1"}]) + return durable_ctx + + +@pytest.fixture +def lambda_context_with_function_name(): + """Create a simple lambda context with function_name.""" + LambdaContext = namedtuple("LambdaContext", "function_name") + return LambdaContext("test_function") + + +def test_log_metrics_with_durable_context_basic(capsys, namespace, service, durable_context): + """Test that log_metrics works with DurableContext.""" + # GIVEN Metrics is initialized + my_metrics = Metrics(service=service, namespace=namespace) + + # WHEN log_metrics decorator is used with a handler that receives DurableContext + @my_metrics.log_metrics + def lambda_handler(evt, context): + my_metrics.add_metric(name="test_metric", value=1.0, unit="Count") + + lambda_handler({}, durable_context) + + # THEN metrics should be emitted successfully + output = capture_metrics_output(capsys) + + assert output["test_metric"] == [1.0] + assert output["service"] == service + + +def test_log_metrics_capture_cold_start_with_durable_context(capsys, namespace, service): + """Test that capture_cold_start_metric works with DurableContext.""" + reset_cold_start_flag() + + # GIVEN Metrics is initialized + my_metrics = Metrics(service=service, namespace=namespace) + + # Create a DurableContext with embedded Lambda context + LambdaContext = namedtuple("LambdaContext", "function_name") + lambda_ctx = LambdaContext("durable_test_function") + + durable_ctx = Mock(spec=DurableContextProtocol) + durable_ctx.lambda_context = lambda_ctx + durable_ctx.state = Mock(operations=[{"id": "op1"}]) + + # WHEN log_metrics is used with capture_cold_start_metric and DurableContext + @my_metrics.log_metrics(capture_cold_start_metric=True) + def lambda_handler(evt, context): + my_metrics.add_metric(name="test_metric", value=1.0, unit="Count") + + lambda_handler({}, durable_ctx) + + # THEN ColdStart metric should be captured with the function name from unwrapped context + outputs = capture_metrics_output_multiple_emf_objects(capsys) + + # Cold start is in a separate EMF blob + cold_start_output = outputs[0] + assert cold_start_output["ColdStart"] == [1.0] + assert cold_start_output["function_name"] == "durable_test_function" + assert cold_start_output["service"] == service + + +def test_log_metrics_capture_cold_start_with_durable_context_explicit_function_name(capsys, namespace, service): + """Test capture_cold_start_metric with explicit function_name and DurableContext.""" + reset_cold_start_flag() + + # GIVEN Metrics is initialized with explicit function_name + my_metrics = Metrics(service=service, namespace=namespace, function_name="explicit_function") + + # Create a DurableContext + LambdaContext = namedtuple("LambdaContext", "function_name") + lambda_ctx = LambdaContext("context_function") + + durable_ctx = Mock(spec=DurableContextProtocol) + durable_ctx.lambda_context = lambda_ctx + durable_ctx.state = Mock(operations=[{"id": "op1"}]) + + # WHEN log_metrics is used with capture_cold_start_metric + @my_metrics.log_metrics(capture_cold_start_metric=True) + def lambda_handler(evt, context): + pass + + lambda_handler({}, durable_ctx) + + # THEN explicit function_name should take priority + output = capture_metrics_output(capsys) + + assert output.get("function_name") == "explicit_function" + + +def test_log_metrics_with_standard_context_still_works(capsys, namespace, service, lambda_context): + """Test that standard Lambda context still works (regression test).""" + # GIVEN Metrics is initialized + my_metrics = Metrics(service=service, namespace=namespace) + + # WHEN log_metrics decorator is used with standard LambdaContext + @my_metrics.log_metrics + def lambda_handler(evt, context): + my_metrics.add_metric(name="regression_test", value=42.0, unit="Count") + + lambda_handler({}, lambda_context) + + # THEN metrics should be emitted successfully + output = capture_metrics_output(capsys) + + assert output["regression_test"] == [42.0] + assert output["service"] == service + + +def test_log_metrics_capture_cold_start_standard_context_still_works(capsys, namespace, service): + """Test that capture_cold_start_metric with standard context still works (regression test).""" + reset_cold_start_flag() + + # GIVEN Metrics is initialized + my_metrics = Metrics(service=service, namespace=namespace) + + LambdaContext = namedtuple("LambdaContext", "function_name") + standard_context = LambdaContext("standard_function") + + # WHEN log_metrics is used with capture_cold_start_metric and standard context + @my_metrics.log_metrics(capture_cold_start_metric=True) + def lambda_handler(evt, context): + pass + + lambda_handler({}, standard_context) + + # THEN ColdStart metric should be captured + output = capture_metrics_output(capsys) + + assert "ColdStart" in output or output.get("ColdStart") == [1.0] + assert output.get("function_name") == "standard_function"