From 10901b5f3805dda48045562d0f775614999910ab Mon Sep 17 00:00:00 2001 From: Brian Axelson <86568017+baxeaz@users.noreply.github.com> Date: Fri, 11 Apr 2025 17:15:20 +0000 Subject: [PATCH] feat: Adding support for redacted environment variable values through openjd_redacted_env Signed-off-by: Brian Axelson <86568017+baxeaz@users.noreply.github.com> --- src/openjd/sessions/_action_filter.py | 407 ++++++++++++++--- src/openjd/sessions/_session.py | 21 +- src/openjd/sessions/_subprocess.py | 10 +- test/openjd/sessions/conftest.py | 102 ++++- test/openjd/sessions/test_action_filter.py | 338 ++++++++++---- test/openjd/sessions/test_redacted_env.py | 505 +++++++++++++++++++++ test/openjd/sessions/test_redaction.py | 191 ++++++++ test/openjd/sessions/test_session.py | 381 +++++++++++++++- 8 files changed, 1806 insertions(+), 149 deletions(-) create mode 100644 test/openjd/sessions/test_redacted_env.py create mode 100644 test/openjd/sessions/test_redaction.py diff --git a/src/openjd/sessions/_action_filter.py b/src/openjd/sessions/_action_filter.py index a2b6de32..9304c94e 100644 --- a/src/openjd/sessions/_action_filter.py +++ b/src/openjd/sessions/_action_filter.py @@ -6,11 +6,51 @@ import logging import re from enum import Enum -from typing import Any, Callable +from typing import Any, Callable, Optional + +from openjd.model import RevisionExtensions, SpecificationRevision from ._logging import LOG, LogContent, LogExtraInfo -__all__ = ("ActionMessageKind", "ActionMonitoringFilter") +__all__ = ("ActionMessageKind", "ActionMonitoringFilter", "redact_openjd_redacted_env_requests") + + +def redact_openjd_redacted_env_requests(command_str: str) -> str: + """Redact sensitive information in command strings before they're processed by the regular redaction mechanism. + + For example, if an openjd session is about to run the following command as a subprocess: + + python -c "print('openjd_redacted_env: SECRETKEY=SECRETVAL')" + + Once that print statement is received by the logger filter it will become redacted, but if we + were to log the full line before executing it in the subprocess, it would be unredacted. + + This method will turn: + + python -c "print('openjd_redacted_env: SECRETKEY=SECRETVAL')" + + to + + python -c "print('openjd_redacted_env: ******** + + So it may be safely logged. + + Args: + command_str: The command string that might contain sensitive information + + Returns: + The command string with sensitive information redacted + """ + # Find the position of the redaction token + token = "openjd_redacted_env:" + pos = command_str.find(token) + + # Fast path for the common case where there's no redaction needed + if pos == -1: + return command_str + + # If this is a redacted env command, redact everything after the token + return command_str[: pos + len(token)] + " ********" class ActionMessageKind(Enum): @@ -18,6 +58,7 @@ class ActionMessageKind(Enum): STATUS = "status" # A status message FAIL = "fail" # A failure message ENV = "env" # Defining an environment variable + REDACTED_ENV = "redacted_env" # Defining an environment variable with redacted value in logs UNSET_ENV = "unset_env" # Unsetting an environment variable # The following are not in the spec, but are utility provided by this runtime. @@ -33,7 +74,7 @@ class ActionMessageKind(Enum): ) filter_matcher = re.compile(filter_regex) -openjd_env_actions_filter_regex = "^(openjd_env|openjd_unset_env)" +openjd_env_actions_filter_regex = "^(openjd_env|openjd_redacted_env|openjd_unset_env)" openjd_env_actions_filter_matcher = re.compile(openjd_env_actions_filter_regex) # A regex for matching the assignment of a value to an environment variable @@ -99,6 +140,7 @@ def __init__( session_id: str, callback: Callable[[ActionMessageKind, Any, bool], None], suppress_filtered: bool = False, + revision_extensions: Optional[RevisionExtensions] = None, ): """ Args: @@ -111,20 +153,109 @@ def __init__( with a message payload when an Open Job Description message is found in the log. suppress_filtered (bool, optional): If True, then all Open Job Description messages will be filtered out of the log. Defaults to True. + revision_extensions (Optional[RevisionExtensions]): Contains information about the + specification revision and supported extensions. """ super().__init__(name) self._session_id = session_id self._callback = callback self._suppress_filtered = suppress_filtered + self._revision_extensions = revision_extensions + + # Initialize set to store sensitive values for redaction + self._redacted_values: set[str] = set() + # Initialize set to store line-specific redactions (for multi-line secrets) + self._redacted_lines: set[str] = set() self._internal_handlers = { ActionMessageKind.PROGRESS: self._handle_progress, ActionMessageKind.STATUS: self._handle_status, ActionMessageKind.FAIL: self._handle_fail, ActionMessageKind.ENV: self._handle_env, + ActionMessageKind.REDACTED_ENV: self._handle_redacted_env, ActionMessageKind.UNSET_ENV: self._handle_unset_env, ActionMessageKind.SESSION_RUNTIME_LOGLEVEL: self._handle_session_runtime_loglevel, } + def _redactions_enabled(self) -> bool: + """Check if redacted environment variables are enabled. + + Redactions are enabled if either: + 1. The specification revision is newer than v2023_09, OR + 2. The REDACTED_ENV_VARS extension is explicitly enabled + + Returns: + bool: True if redactions are enabled, False otherwise. + """ + return self._revision_extensions is not None and ( + self._revision_extensions.spec_rev > SpecificationRevision.v2023_09 + or "REDACTED_ENV_VARS" in self._revision_extensions.extensions + ) + + def apply_message_redaction(self, record: logging.LogRecord): + """Redact the log message if it contains any substrings which have been registered for redaction + + Args: + record (logging.LogRecord): The log record to check. + """ + # Check if we need to redact any sensitive values from the log message + if (self._redacted_values or self._redacted_lines) and isinstance(record.msg, str): + + # If we have args, first do string formatting, then redact + try: + record.msg = record.msg % record.args + record.args = () # Clear args since we've done the formatting + except Exception: + # If string formatting fails, fall back to just redacting the message + LOG.warning( + "Failed to format log message for redaction. Proceeding with redaction on unformatted message." + ) + + # Check if the entire message matches a line in the redacted_lines set + if record.msg in self._redacted_lines: + record.msg = "*" * 8 + record.args = () + return True + + # Find all segments that need redaction + segments_to_redact = [] + for value in self._redacted_values: + if value: + start = 0 + while True: + pos = record.msg.find(value, start) + if pos == -1: + break + segments_to_redact.append((pos, pos + len(value))) + start = pos + 1 + + # If we found segments to redact, merge overlapping segments + if segments_to_redact: + # Sort segments by start position + segments_to_redact.sort() + + # Merge overlapping segments + merged_segments = [] + current_start, current_end = segments_to_redact[0] + + for start, end in segments_to_redact[1:]: + if start <= current_end: + # Segments overlap, extend current segment + current_end = max(current_end, end) + else: + # No overlap, add current segment and start new one + merged_segments.append((current_start, current_end)) + current_start, current_end = start, end + + # Add the last segment + merged_segments.append((current_start, current_end)) + + # Apply redactions from end to start to avoid position shifts + msg_chars = list(record.msg) + for start, end in reversed(merged_segments): + msg_chars[start:end] = list("*" * 8) # Always use 8 asterisks for redaction + record.msg = "".join(msg_chars) + record.args = () + def filter(self, record: logging.LogRecord) -> bool: """Called automatically by Python's logging subsystem when a log record is sent to a log to which this filter class is applied. @@ -139,57 +270,78 @@ def filter(self, record: logging.LogRecord) -> bool: bool: If true then the Python logger will keep the record in the log, else it will remove it. """ - if not hasattr(record, "session_id") or getattr(record, "session_id") != self._session_id: - # Not a record for us to process - return True - if not isinstance(record.msg, str): - # If something sends a non-string to the logger (e.g. via logger.exception) then - # don't try to string match it. - return True - match = filter_matcher.match(record.msg) - if match and match.lastindex is not None: - message = match.group(match.lastindex) - # Note: keys of match.groupdict() are the names of named groups in the regex - matched_named_groups = tuple(k for k, v in match.groupdict().items() if v is not None) - if len(matched_named_groups) > 1: - # The only way that this happens is if filter_matcher is constructed incorrectly. - all_matched_groups = ",".join(k for k in matched_named_groups) - LOG.error( - f"Open Job Description: Malformed output stream filter matched multiple kinds ({all_matched_groups})", - extra=LogExtraInfo(openjd_log_content=LogContent.COMMAND_OUTPUT), - ) + try: + + if getattr(record, "session_id", None) != self._session_id: + # Not a record for us to process return True - message_kind = ActionMessageKind(matched_named_groups[0]) - try: - handler = self._internal_handlers[message_kind] - except KeyError: - LOG.error( - f"Open Job Description: Unhandled message kind ({message_kind.value})", - extra=LogExtraInfo(openjd_log_content=LogContent.COMMAND_OUTPUT), - ) + if not isinstance(record.msg, str): + # If something sends a non-string to the logger (e.g. via logger.exception) then + # don't try to string match it. return True - try: - handler(message) - except ValueError as e: - record.msg = record.msg + f" -- ERROR: {str(e)}" - # There was an error. Don't suppress the message from the log. + + match = filter_matcher.match(record.msg) + if match and match.lastindex is not None: + message = match.group(match.lastindex) + # Note: keys of match.groupdict() are the names of named groups in the regex + matched_named_groups = tuple( + k for k, v in match.groupdict().items() if v is not None + ) + if len(matched_named_groups) > 1: + # The only way that this happens is if filter_matcher is constructed incorrectly. + all_matched_groups = ",".join(k for k in matched_named_groups) + LOG.error( + f"Open Job Description: Malformed output stream filter matched multiple kinds ({all_matched_groups})", + extra=LogExtraInfo(openjd_log_content=LogContent.COMMAND_OUTPUT), + ) + return True + message_kind = ActionMessageKind(matched_named_groups[0]) + try: + handler = self._internal_handlers[message_kind] + except KeyError: + LOG.error( + f"Open Job Description: Unhandled message kind ({message_kind.value})", + extra=LogExtraInfo(openjd_log_content=LogContent.COMMAND_OUTPUT), + ) + return True + + # Check if this is a redacted_env message and the extension is not enabled + if ( + message_kind == ActionMessageKind.REDACTED_ENV + and not self._redactions_enabled() + ): + LOG.warning( + "Received openjd_redacted_env message but REDACTED_ENV_VARS extension is not enabled", + extra=LogExtraInfo(openjd_log_content=LogContent.COMMAND_OUTPUT), + ) + # We still process the message - just log the warning + + try: + handler(message) + + except ValueError as e: + record.msg = record.msg + f" -- ERROR: {str(e)}" + # There was an error. Don't suppress the message from the log. + return True + return not self._suppress_filtered + + # Check for "almost" matching openjd_env and openjd_unset_env commands + lower_case_trimmed_msg: str = record.msg.lstrip().lower() + if openjd_env_actions_filter_matcher.match(lower_case_trimmed_msg): + # There was a minor error like spaces or case in the env commands + err_message = ( + f"Open Job Description: Incorrectly formatted openjd env command ({record.msg})" + ) + record.msg = record.msg + f" -- ERROR: {err_message}" + + # Callback to cancel the action and mark it as FAILED + self._callback(ActionMessageKind.FAIL, err_message, True) return True - return not self._suppress_filtered - - # Check for "almost" matching openjd_env and openjd_unset_env commands - lower_case_trimmed_msg: str = record.msg.lstrip().lower() - if openjd_env_actions_filter_matcher.match(lower_case_trimmed_msg): - # There was a minor error like spaces or case in the env commands - err_message = ( - f"Open Job Description: Incorrectly formatted openjd env command ({record.msg})" - ) - record.msg = record.msg + f" -- ERROR: {err_message}" - # Callback to cancel the action and mark it as FAILED - self._callback(ActionMessageKind.FAIL, err_message, True) return True - - return True + finally: + # Always check for redaction before returning + self.apply_message_redaction(record) def _handle_progress(self, message: str) -> None: """Local handling of Progress messages. Processes the message and then @@ -227,29 +379,87 @@ def _handle_fail(self, message: str) -> None: """ self._callback(ActionMessageKind.FAIL, message, False) + def _parse_env_variable(self, message: str) -> tuple[str, str, bool, int, Optional[str]]: + """Parse an environment variable assignment string. + + Args: + message (str): The message containing the variable assignment + + Returns: + tuple: (variable_name, variable_value, is_valid, equals_position, original_value) + where equals_position is the index of the equals sign in the original message + and original_value is the value before JSON parsing (if applicable) + + A correctly formed message is of the form: + = + where: + consists of latin alphanumeric characters and the underscore, + and starts with a non-digit + can be any characters including empty. + """ + message = message.lstrip() + + # Find the position of the equals sign + equals_position = message.find("=") + if equals_position == -1: + return "", "", False, -1, None + + # Check if the message is valid + is_valid = envvar_set_matcher_str.match(message) or envvar_set_matcher_json.match(message) + + if not is_valid: + return "", "", False, equals_position, None + + # Parse the variable name and value + try: + original_value = None + if envvar_set_matcher_str.match(message): + name, _, value = message.partition("=") + else: + # Handle JSON format + try: + # Store the original value before JSON parsing + original_value = message[equals_position + 1 :] + message_json_str = json.loads(message) + name, _, value = message_json_str.partition("=") + except json.JSONDecodeError as e: + raise ValueError( + f"Unterminated string starting at: line {e.lineno} column {e.colno} (char {e.pos})" + ) + return name, value, True, equals_position, original_value + except json.JSONDecodeError as e: + raise ValueError( + f"Unterminated string starting at: line {e.lineno} column {e.colno} (char {e.pos})" + ) + + def _handle_env_error(self, error_message: str, is_redacted: bool = False) -> None: + """Handle errors in environment variable processing. + + Args: + error_message (str): The error message + is_redacted (bool): Whether this is for a redacted env var + """ + if is_redacted and self._redactions_enabled(): + LOG.warning( + f"Malformed openjd_redacted_env command: {error_message} No environment variable will be set.", + extra=LogExtraInfo(openjd_log_content=LogContent.COMMAND_OUTPUT), + ) + else: + # Callback to fail and cancel action on this error + self._callback(ActionMessageKind.ENV, error_message, True) + + raise ValueError(error_message) + def _handle_env(self, message: str) -> None: """Local handling of the Env messages. Args: message (str): The message after the leading 'openjd_env: ' prefix """ - message = message.lstrip() - # A correctly formed message is of the form: - # = - # where: - # consists of latin alphanumeric characters and the underscore, - # and starts with a non-digit - # can be any characters including empty. - if not envvar_set_matcher_str.match(message) and not envvar_set_matcher_json.match(message): - err_message = "Failed to parse environment variable assignment." - # Callback to fail and cancel action on this error - self._callback(ActionMessageKind.ENV, err_message, True) - raise ValueError(err_message) - elif envvar_set_matcher_str.match(message): - name, _, value = message.partition("=") - else: - message_json_str = json.loads(message) - name, _, value = message_json_str.partition("=") + name, value, is_valid, _, _ = self._parse_env_variable(message) + + if not is_valid: + self._handle_env_error("Failed to parse environment variable assignment.") self._callback(ActionMessageKind.ENV, {"name": name, "value": value}, False) @@ -292,3 +502,68 @@ def _handle_session_runtime_loglevel(self, message: str) -> None: raise ValueError( f"Unknown log level: {message}. Known values: {','.join(levels.keys())}" ) + + def _handle_redacted_env(self, message: str) -> None: + """Local handling of the Redacted Env messages. Similar to _handle_env but + redacts the value in logs and adds it to the set of values to redact in future logs. + + Args: + message (str): The message after the leading 'openjd_redacted_env: ' prefix + """ + + message = message.lstrip() + + # Use the shared parsing logic to validate and extract the variable + name, value, is_valid, equals_position, original_value = self._parse_env_variable(message) + + # Determine the value to redact + redaction_value = None + if is_valid: + # Use the properly parsed value for redaction + redaction_value = value + + # If we have an original value (from JSON parsing), add it to redaction set too + if original_value is not None: + self._redacted_values.add(original_value) + elif equals_position != -1: # Invalid format but has equals sign + # Fall back to extracting value directly from the message + redaction_value = message[equals_position + 1 :] + else: + # No equals sign, use entire content + redaction_value = message + + # Add the value to the redaction set + if redaction_value: + self._redacted_values.add(redaction_value) + + # Add the individual parts if we've received a string with newlines + parts = redaction_value.split("\n") + for i, part in enumerate(parts): + if part: # Skip empty parts + if i == 0 or i == len(parts) - 1: + # First and last parts go in the regular redaction set + self._redacted_values.add(part) + else: + # Middle parts go in the line redaction set + self._redacted_lines.add(part) + + # Handle invalid format + if not is_valid: + if self._redactions_enabled(): + if "=" not in message: + self._handle_env_error( + "Malformed openjd_redacted_env command: missing equals sign.", + is_redacted=True, + ) + else: + self._handle_env_error( + "Malformed openjd_redacted_env command: invalid format.", is_redacted=True + ) + else: + err_message = "Failed to parse environment variable assignment." + self._callback(ActionMessageKind.ENV, err_message, True) + return + + # Only set the environment variable if the extension is enabled + if self._redactions_enabled(): + self._callback(ActionMessageKind.ENV, {"name": name, "value": value}, False) diff --git a/src/openjd/sessions/_session.py b/src/openjd/sessions/_session.py index b59f0441..189984c3 100644 --- a/src/openjd/sessions/_session.py +++ b/src/openjd/sessions/_session.py @@ -20,6 +20,7 @@ JobParameterValues, ParameterValue, ParameterValueType, + RevisionExtensions, SpecificationRevision, SymbolTable, TaskParameterSet, @@ -318,6 +319,9 @@ def __init__( callback: Optional[SessionCallbackType] = None, os_env_vars: Optional[dict[str, str]] = None, session_root_directory: Optional[Path] = None, + revision_extensions: RevisionExtensions = RevisionExtensions( + spec_rev=SpecificationRevision.v2023_09, supported_extensions=[] + ), ): """ Arguments: @@ -358,6 +362,8 @@ def __init__( 2. The 'user' (if given) must have at least read permissions to it; and 3. The Working Directory for this Session will be created in the given directory. If not provided, then the default of gettempdir()/"openjd" is used instead. + revision_extensions (RevisionExtensions): Specification revision and supported extensions + for this session. Defaults to SpecificationRevision.v2023_09 with no extensions. Raises: RuntimeError - If the Session initialization fails for any reason. @@ -389,9 +395,14 @@ def __init__( ) self._reset_action_state() + # Store the revision_extensions + self._revision_extensions = revision_extensions + # Set up our logging hook & callback self._log_filter = ActionMonitoringFilter( - session_id=self._session_id, callback=self._action_log_filter_callback + session_id=self._session_id, + callback=self._action_log_filter_callback, + revision_extensions=revision_extensions, ) LOG.addFilter(self._log_filter) self._logger = LoggerAdapter(LOG, extra={"session_id": self._session_id}) @@ -832,6 +843,14 @@ def run_task( # ========================= # Helpers + def get_enabled_extensions(self) -> list[str]: + """Return the list of enabled extensions for this session. + + Returns: + list[str]: The list of enabled extensions + """ + return list(self._revision_extensions.extensions) + def _reset_action_state(self) -> None: """Reset the internal action state. This resets to a state equivalent to having nothing running. diff --git a/src/openjd/sessions/_subprocess.py b/src/openjd/sessions/_subprocess.py index ea29011c..f9f02a3c 100644 --- a/src/openjd/sessions/_subprocess.py +++ b/src/openjd/sessions/_subprocess.py @@ -19,6 +19,7 @@ from ._logging import LoggerAdapter, LogContent, LogExtraInfo from ._os_checker import is_linux, is_posix, is_windows from ._session_user import PosixSessionUser, WindowsSessionUser, SessionUser +from ._action_filter import redact_openjd_redacted_env_requests if is_windows(): # pragma: nocover from subprocess import CREATE_NEW_PROCESS_GROUP, CREATE_NO_WINDOW # type: ignore @@ -274,11 +275,18 @@ def _start_subprocess(self) -> Optional[Popen]: # https://docs.python.org/2/library/subprocess.html#subprocess.CREATE_NEW_PROCESS_GROUP popen_args["creationflags"] = CREATE_NEW_PROCESS_GROUP + # Get the command string for logging cmd_line_for_logger: str if is_posix(): cmd_line_for_logger = shlex.join(command) else: - cmd_line_for_logger = list2cmdline(self._args) + + cmd_line = list2cmdline(self._args) + # Command line could contain openjd_redacted_env: token lines not yet processed by the + # session logger. If the token appears in the command line we'll redact everything + # in the line after it for the logs. Note that on Linux currently the command including + # args are in a .sh script, so the full argument list isn't printed by default. + cmd_line_for_logger = redact_openjd_redacted_env_requests(cmd_line) self._logger.info( "Running command %s", cmd_line_for_logger, diff --git a/test/openjd/sessions/conftest.py b/test/openjd/sessions/conftest.py index 222a4eb7..caa854e7 100644 --- a/test/openjd/sessions/conftest.py +++ b/test/openjd/sessions/conftest.py @@ -6,12 +6,16 @@ from logging import INFO, getLogger from logging.handlers import QueueHandler from queue import Empty, SimpleQueue -from typing import Generator +from typing import Generator, Optional +from hashlib import sha256 +from unittest.mock import MagicMock import pytest from openjd.sessions import PosixSessionUser, WindowsSessionUser, BadCredentialsException from openjd.sessions._os_checker import is_posix, is_windows from openjd.sessions._logging import LoggerAdapter +from openjd.sessions._action_filter import ActionMonitoringFilter +from openjd.model import RevisionExtensions, SpecificationRevision if is_windows(): from openjd.sessions._win32._helpers import ( # type: ignore @@ -55,15 +59,107 @@ def pytest_collection_modifyitems(config, items): config.option.markexpr = mark_expr +def create_unique_logger_name(prefix: str = "", seed: Optional[str] = None) -> str: + """Create a unique logger name using a hash to avoid collisions. + + Args: + prefix: Optional prefix for the logger name + seed: Optional seed string to use for generating the hash + + Returns: + A unique logger name + """ + if seed: + h = sha256() + h.update(seed.encode("utf-8")) + suffix = h.hexdigest()[0:32] + else: + charset = string.ascii_letters + string.digits + suffix = "".join(random.choices(charset, k=32)) + + return f"{prefix}{suffix}" + + def build_logger(handler: QueueHandler) -> LoggerAdapter: - charset = string.ascii_letters + string.digits + string.punctuation - name_suffix = "".join(random.choices(charset, k=32)) + """Build a logger for testing purposes. + + Args: + handler: The queue handler to attach to the logger + + Returns: + A configured LoggerAdapter + """ + name_suffix = create_unique_logger_name() log = getLogger(".".join((__name__, name_suffix))) log.setLevel(INFO) log.addHandler(handler) return LoggerAdapter(log, extra=dict()) +def setup_action_filter_test( + queue_handler: QueueHandler, + session_id: str = "foo", + callback: Optional[MagicMock] = None, + suppress_filtered: bool = False, + enabled_extensions: Optional[list[str]] = None, +) -> tuple[LoggerAdapter, ActionMonitoringFilter, MagicMock]: + """Set up a test environment for testing ActionMonitoringFilter. + + This helper method creates a unique logger name, sets up the ActionMonitoringFilter, + and configures the logger with the filter. + + Args: + queue_handler: The QueueHandler to attach to the logger + session_id: The session ID to use for the filter + callback: Optional mock callback to use for the filter + suppress_filtered: Whether to suppress filtered messages + enabled_extensions: Optional list of extensions to enable + + Returns: + A tuple containing (logger_adapter, action_filter, callback_mock) + + Note: + This helper works for most tests, but for tests that need to verify specific + callback behavior with redacted values, it's better to create the filter and + logger directly in the test. This is because when multiple filters are applied + to the same log message (which can happen when running multiple tests), the + redaction can happen before the callback is invoked, resulting in the callback + receiving redacted values instead of the original values. + """ + # Create a unique logger name WITHOUT using the message as seed + # This ensures each test gets a truly unique logger name + logger_name = create_unique_logger_name(prefix="action_filter_") + + # Create a mock callback if one wasn't provided + if callback is None: + callback = MagicMock() + + # Create a RevisionExtensions with the provided extensions or an empty set + revision_extensions = RevisionExtensions( + spec_rev=SpecificationRevision.v2023_09, supported_extensions=enabled_extensions or [] + ) + + # Create the filter directly with the provided parameters + action_filter = ActionMonitoringFilter( + session_id=session_id, + callback=callback, + suppress_filtered=suppress_filtered, + revision_extensions=revision_extensions, + ) + + # Set up the logger + log = getLogger(".".join((__name__, logger_name))) + log.setLevel(INFO) + log.addHandler(queue_handler) + log.addFilter(action_filter) + + # Create and return the logger adapter with the session_id + # This is critical for the filter to work properly + logger_adapter = LoggerAdapter(log, extra={"session_id": session_id}) + + return logger_adapter, action_filter, callback + + def collect_queue_messages(queue: SimpleQueue) -> list[str]: """Extract the text of messages from a SimpleQueue containing LogRecords""" messages: list[str] = [] diff --git a/test/openjd/sessions/test_action_filter.py b/test/openjd/sessions/test_action_filter.py index 42b3110e..baea2afd 100644 --- a/test/openjd/sessions/test_action_filter.py +++ b/test/openjd/sessions/test_action_filter.py @@ -5,8 +5,6 @@ from __future__ import annotations import logging -from hashlib import sha256 -from openjd.sessions._logging import LoggerAdapter from logging.handlers import QueueHandler from queue import SimpleQueue from typing import Union @@ -18,6 +16,7 @@ ActionMessageKind, ActionMonitoringFilter, ) +from .conftest import setup_action_filter_test class TestActionMonitoringFilter: @@ -29,15 +28,6 @@ def message_queue(self) -> SimpleQueue: def queue_handler(self, message_queue: SimpleQueue) -> QueueHandler: return QueueHandler(message_queue) - def build_logger( - self, name: str, handler: QueueHandler, filter: ActionMonitoringFilter - ) -> logging.Logger: - log = logging.getLogger(".".join((__name__, name))) - log.setLevel(logging.INFO) - log.addHandler(handler) - log.addFilter(filter) - return log - @pytest.mark.parametrize( "message,kind,value", ( @@ -117,19 +107,19 @@ def build_logger( "openjd_session_runtime_loglevel: INFO", ActionMessageKind.SESSION_RUNTIME_LOGLEVEL, logging.INFO, - id="loglevel debug", + id="loglevel info", ), pytest.param( "openjd_session_runtime_loglevel: WARNING", ActionMessageKind.SESSION_RUNTIME_LOGLEVEL, logging.WARNING, - id="loglevel debug", + id="loglevel warning", ), pytest.param( "openjd_session_runtime_loglevel: ERROR", ActionMessageKind.SESSION_RUNTIME_LOGLEVEL, logging.ERROR, - id="loglevel debug", + id="loglevel error", ), ), ) @@ -142,15 +132,10 @@ def test_captures_suppress( value: Union[float, str], ) -> None: # GIVEN - h = sha256() - h.update(message.encode("utf-8")) - logger_name = "suppress" + h.hexdigest()[0:32] - callback_mock = MagicMock() - filter = ActionMonitoringFilter( - session_id="foo", callback=callback_mock, suppress_filtered=True + loga, _, callback_mock = setup_action_filter_test( + queue_handler=queue_handler, + suppress_filtered=True, ) - log = self.build_logger(logger_name, queue_handler, filter) - loga = LoggerAdapter(log, extra={"session_id": "foo"}) # WHEN loga.info(message) @@ -166,16 +151,16 @@ def test_ignores_different_session( ) -> None: # GIVEN message = "openjd_fail: an error message" - h = sha256() - h.update(message.encode("utf-8")) - logger_name = "suppress" + h.hexdigest()[0:32] callback_mock = MagicMock() filter = ActionMonitoringFilter( session_id="foo", callback=callback_mock, suppress_filtered=True ) - log = self.build_logger(logger_name, queue_handler, filter) + log = logging.getLogger("test.different_session") + log.setLevel(logging.INFO) + log.addHandler(queue_handler) + log.addFilter(filter) - # WHEN + # WHEN - Note we don't use LoggerAdapter with session_id here log.info(message) # THEN @@ -226,13 +211,7 @@ def test_captures_no_suppress( value: Union[float, str], ) -> None: # GIVEN - h = sha256() - h.update(message.encode("utf-8")) - logger_name = "no_suppress" + h.hexdigest()[0:32] - callback_mock = MagicMock() - filter = ActionMonitoringFilter(session_id="foo", callback=callback_mock) - log = self.build_logger(logger_name, queue_handler, filter) - loga = LoggerAdapter(log, extra={"session_id": "foo"}) + loga, _, callback_mock = setup_action_filter_test(queue_handler=queue_handler) # WHEN loga.info(message) @@ -274,13 +253,7 @@ def test_malformed_does_not_match_no_callback( self, queue_handler: QueueHandler, message: str ) -> None: # GIVEN - h = sha256() - h.update(message.encode("utf-8")) - logger_name = "malformed" + h.hexdigest()[0:32] - callback_mock = MagicMock() - filter = ActionMonitoringFilter(session_id="foo", callback=callback_mock) - log = self.build_logger(logger_name, queue_handler, filter) - loga = LoggerAdapter(log, extra={"session_id": "foo"}) + loga, _, callback_mock = setup_action_filter_test(queue_handler=queue_handler) # WHEN loga.info(message) @@ -307,17 +280,27 @@ def test_malformed_does_not_match_no_callback( "openjd_env: F😁=bar", id="env, non-latin", ), + pytest.param( + "openjd_redacted_env: foo", + id="redacted_env, missing assignment", + ), + pytest.param( + "openjd_redacted_env: foo =value", + id="redacted_env, extra whitespace", + ), + pytest.param( + "openjd_redacted_env: 1F_F_12=bar", + id="redacted_env, start with digit", + ), + pytest.param( + "openjd_redacted_env: F😁=bar", + id="redacted_env, non-latin", + ), ), ) def test_malformed_set_env_assigment(self, queue_handler: QueueHandler, message: str) -> None: # GIVEN - h = sha256() - h.update(message.encode("utf-8")) - logger_name = "malformed" + h.hexdigest()[0:32] - callback_mock = MagicMock() - filter = ActionMonitoringFilter(session_id="foo", callback=callback_mock) - log = self.build_logger(logger_name, queue_handler, filter) - loga = LoggerAdapter(log, extra={"session_id": "foo"}) + loga, _, callback_mock = setup_action_filter_test(queue_handler=queue_handler) # WHEN loga.info(message) @@ -341,6 +324,18 @@ def test_malformed_set_env_assigment(self, queue_handler: QueueHandler, message: " openjd_env: foo=bar", id="env, leading whitespace", ), + pytest.param( + "openjd_redacted_env:foo=bar", + id="redacted_env, no space", + ), + pytest.param( + "OPENJD_REDACTED_ENV: foo=bar", + id="redacted_env, uppercase", + ), + pytest.param( + " openjd_redacted_env: foo=bar", + id="redacted_env, leading whitespace", + ), pytest.param( "openjd_unset_env:foo", id="unset_env, no space", @@ -357,13 +352,7 @@ def test_malformed_set_env_assigment(self, queue_handler: QueueHandler, message: ) def test_malformed_openjd_regex(self, queue_handler: QueueHandler, message: str) -> None: # GIVEN - h = sha256() - h.update(message.encode("utf-8")) - logger_name = "malformed" + h.hexdigest()[0:32] - callback_mock = MagicMock() - filter = ActionMonitoringFilter(session_id="foo", callback=callback_mock) - log = self.build_logger(logger_name, queue_handler, filter) - loga = LoggerAdapter(log, extra={"session_id": "foo"}) + loga, _, callback_mock = setup_action_filter_test(queue_handler=queue_handler) # WHEN loga.info(message) @@ -393,13 +382,7 @@ def test_malformed_does_not_match_unset_env( self, queue_handler: QueueHandler, message: str ) -> None: # GIVEN - h = sha256() - h.update(message.encode("utf-8")) - logger_name = "malformed" + h.hexdigest()[0:32] - callback_mock = MagicMock() - filter = ActionMonitoringFilter(session_id="foo", callback=callback_mock) - log = self.build_logger(logger_name, queue_handler, filter) - loga = LoggerAdapter(log, extra={"session_id": "foo"}) + loga, _, callback_mock = setup_action_filter_test(queue_handler=queue_handler) # WHEN loga.info(message) @@ -424,13 +407,7 @@ def test_progress_appends_error( # message through to the log and we append an error message to it. # # GIVEN - h = sha256() - h.update(message.encode("utf-8")) - logger_name = "appends" + h.hexdigest()[0:32] - callback_mock = MagicMock() - filter = ActionMonitoringFilter(session_id="foo", callback=callback_mock) - log = self.build_logger(logger_name, queue_handler, filter) - loga = LoggerAdapter(log, extra={"session_id": "foo"}) + loga, _, callback_mock = setup_action_filter_test(queue_handler=queue_handler) expected_message = ( message + " -- ERROR: Progress must be a floating point value between 0.0 and 100.0, inclusive." @@ -450,15 +427,10 @@ def test_handles_non_string( queue_handler: QueueHandler, ) -> None: # GIVEN - h = sha256() - h.update("exception-test".encode("utf-8")) - logger_name = "non_string" + h.hexdigest()[0:32] - callback_mock = MagicMock() - filter = ActionMonitoringFilter( - session_id="foo", callback=callback_mock, suppress_filtered=True + loga, _, callback_mock = setup_action_filter_test( + queue_handler=queue_handler, + suppress_filtered=True, ) - log = self.build_logger(logger_name, queue_handler, filter) - loga = LoggerAdapter(log, extra={"session_id": "foo"}) # WHEN try: @@ -470,3 +442,215 @@ def test_handles_non_string( callback_mock.assert_not_called() assert message_queue.qsize() == 1 assert "Exception: Surprise!" in message_queue.get(block=False).getMessage() + + def test_redacted_env_redacts_value( + self, message_queue: SimpleQueue, queue_handler: QueueHandler + ) -> None: + """Test that openjd_redacted_env properly redacts values in logs.""" + # GIVEN + message = "openjd_redacted_env: PASSWORD=secret123" + loga, _, callback_mock = setup_action_filter_test( + queue_handler=queue_handler, + enabled_extensions=["REDACTED_ENV_VARS"], + ) + + # WHEN + loga.info(message) + + # THEN + # Check that the callback was called with the correct parameters + callback_mock.assert_called_once_with( + ActionMessageKind.ENV, {"name": "PASSWORD", "value": "secret123"}, False + ) + + # Check that the message in the log is redacted + assert message_queue.qsize() == 1 + log_message = message_queue.get(block=False).getMessage() + assert "openjd_redacted_env: PASSWORD=********" in log_message + assert "secret123" not in log_message + + def test_redacted_env_with_warning( + self, message_queue: SimpleQueue, queue_handler: QueueHandler, monkeypatch + ) -> None: + """Test that redacted_env messages log a warning when the extension is not enabled.""" + # GIVEN + mock_log = MagicMock() + monkeypatch.setattr("openjd.sessions._action_filter.LOG", mock_log) + + message = "openjd_redacted_env: SECRET_VAR=secret_value" + loga, _, callback_mock = setup_action_filter_test( + queue_handler=queue_handler, + enabled_extensions=[], # No extensions enabled + ) + + # WHEN + loga.info(message) + + # THEN + mock_log.warning.assert_called_once() + assert "REDACTED_ENV_VARS extension is not enabled" in mock_log.warning.call_args[0][0] + + # The callback should NOT be called since the extension is not enabled + callback_mock.assert_not_called() + + # Check that the message in the log is redacted + assert message_queue.qsize() == 1, "Message passed through" + log_message = message_queue.get(block=False).getMessage() + assert "SECRET_VAR=********" in log_message + assert "secret_value" not in log_message + + def test_redacted_env_uses_fixed_length_redaction( + self, message_queue: SimpleQueue, queue_handler: QueueHandler + ) -> None: + """Test that openjd_redacted_env uses a fixed-length redaction regardless of value length.""" + # GIVEN + # Create a single logger setup for both tests + loga, _, callback_mock = setup_action_filter_test(queue_handler=queue_handler) + + short_message = "openjd_redacted_env: KEY=x" + long_message = "openjd_redacted_env: TOKEN=abcdefghijklmnopqrstuvwxyz1234567890" + expected_redacted_format = "********" + + # WHEN + loga.info(short_message) + loga.info(long_message) + + # THEN + # Check that both messages use the same fixed-length redaction + assert message_queue.qsize() == 2, "Both messages passed through" + + log_message1 = message_queue.get(block=False).getMessage() + assert log_message1 == f"openjd_redacted_env: KEY={expected_redacted_format}" + + log_message2 = message_queue.get(block=False).getMessage() + assert log_message2 == f"openjd_redacted_env: TOKEN={expected_redacted_format}" + + def test_redacted_env_redacts_subsequent_occurrences( + self, message_queue: SimpleQueue, queue_handler: QueueHandler + ) -> None: + """Test that values from openjd_redacted_env are redacted in all subsequent log messages.""" + # GIVEN + loga, _, callback_mock = setup_action_filter_test( + queue_handler=queue_handler, + enabled_extensions=["REDACTED_ENV_VARS"], + ) + + # WHEN + # First log the redacted_env message to set up the value + redacted_message = "openjd_redacted_env: PASSWORD=supersecret123" + loga.info(redacted_message) + + # Then log a regular message containing the sensitive value + regular_message = "Here is the password: supersecret123 for your reference" + loga.info(regular_message) + + # THEN + # Check that both messages were logged + assert message_queue.qsize() == 2, "Both messages should be in the queue" + + # First message should have redacted the value in the openjd_redacted_env line + first_log = message_queue.get(block=False).getMessage() + assert first_log == "openjd_redacted_env: PASSWORD=********" + assert "supersecret123" not in first_log + + # Second message should have redacted the sensitive value + second_log = message_queue.get(block=False).getMessage() + assert "supersecret123" not in second_log + assert "Here is the password: ********" in second_log + + # The callback should have been called with the actual value for env processing + callback_mock.assert_any_call( + ActionMessageKind.ENV, {"name": "PASSWORD", "value": "supersecret123"}, False + ) + + def test_redacted_env_handles_multiple_values( + self, message_queue: SimpleQueue, queue_handler: QueueHandler + ) -> None: + """Test that multiple redacted values are all properly redacted in logs.""" + # GIVEN + loga, _, callback_mock = setup_action_filter_test( + queue_handler=queue_handler, + enabled_extensions=["REDACTED_ENV_VARS"], + ) + + # WHEN + # Set up multiple redacted values + loga.info("openjd_redacted_env: PASSWORD=password123") + loga.info("openjd_redacted_env: API_KEY=abcdef123456") + + # Log a message containing both sensitive values + loga.info("Using PASSWORD=password123 and API_KEY=abcdef123456 for authentication") + + # THEN + # Skip the first two messages which are the redacted_env declarations + message_queue.get(block=False) + message_queue.get(block=False) + + # Check that the third message has both values redacted + final_message = message_queue.get(block=False).getMessage() + assert "password123" not in final_message + assert "abcdef123456" not in final_message + assert "Using PASSWORD=******** and API_KEY=******** for authentication" in final_message + + def test_redacted_env_with_extension( + self, message_queue: SimpleQueue, queue_handler: QueueHandler + ) -> None: + """Test that redacted_env messages set environment variables when the extension is enabled.""" + # GIVEN + message = "openjd_redacted_env: PASSWORD=secret123" + loga, action_filter, callback_mock = setup_action_filter_test( + queue_handler=queue_handler, + enabled_extensions=["REDACTED_ENV_VARS"], + ) + + # WHEN + loga.info(message) + + # THEN + # The callback should be called with the environment variable info + callback_mock.assert_called_once_with( + ActionMessageKind.ENV, {"name": "PASSWORD", "value": "secret123"}, False + ) + + # The message should be redacted in the logs + assert message_queue.qsize() == 1 + log_message = message_queue.get(block=False).getMessage() + assert "openjd_redacted_env: PASSWORD=********" in log_message + assert "secret123" not in log_message + + def test_malformed_redacted_env_commands( + self, message_queue: SimpleQueue, queue_handler: QueueHandler + ) -> None: + """Test handling of malformed redacted_env commands with spaces or missing equals sign.""" + # GIVEN + loga, _, callback_mock = setup_action_filter_test( + queue_handler=queue_handler, + enabled_extensions=["REDACTED_ENV_VARS"], + ) + + # Case 1: Space after key (key =value) + message1 = "openjd_redacted_env: PASSWORD =secret123" + + # Case 2: Missing equals sign (keyvalue) + message2 = "openjd_redacted_env: SECRETsensitivedata" + + # WHEN + loga.info(message1) + loga.info(message2) + + # THEN + # Check that both messages were processed + assert message_queue.qsize() == 2 + + # For Case 1 (key =value), we should still try to redact the value + log_message1 = message_queue.get(block=False).getMessage() + assert "openjd_redacted_env: PASSWORD =********" in log_message1 + assert "secret123" not in log_message1 + + # For Case 2 (missing equals), we should redact the entire content after the prefix + log_message2 = message_queue.get(block=False).getMessage() + assert "openjd_redacted_env: ********" in log_message2 + assert "SECRETsensitivedata" not in log_message2 + + # Neither case should set an environment variable + callback_mock.assert_not_called() diff --git a/test/openjd/sessions/test_redacted_env.py b/test/openjd/sessions/test_redacted_env.py new file mode 100644 index 00000000..41bc55fe --- /dev/null +++ b/test/openjd/sessions/test_redacted_env.py @@ -0,0 +1,505 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +"""Tests for redacted environment variables functionality.""" + +import logging +import logging.handlers +import pytest +from queue import SimpleQueue +from unittest.mock import MagicMock + +from openjd.sessions._action_filter import ActionMessageKind +from .conftest import setup_action_filter_test + + +class TestRedactedEnv: + """Tests for redacted environment variables functionality.""" + + # Basic functionality tests + def test_basic_redacted_env( + self, message_queue: SimpleQueue, queue_handler: logging.handlers.QueueHandler + ) -> None: + """Test basic functionality of redacted environment variables.""" + # Setup + callback_mock = MagicMock() + loga, _, _ = setup_action_filter_test( + queue_handler=queue_handler, + callback=callback_mock, + enabled_extensions=["REDACTED_ENV_VARS"], + ) + + # WHEN + loga.info("openjd_redacted_env: KEY=VALUE") + + # THEN + # Check that the callback was called with the correct parameters + callback_mock.assert_called_once_with( + ActionMessageKind.ENV, {"name": "KEY", "value": "VALUE"}, False + ) + + # Check that the message in the log is redacted + assert message_queue.qsize() == 1 + log_record = message_queue.get(block=False).getMessage() + assert "VALUE" not in log_record + assert "********" in log_record + + # Edge cases tests using parameterization + @pytest.mark.parametrize( + "case,expected_key,expected_value,should_set_env", + [ + # Space after equals + ("openjd_redacted_env: KEY= VALUE", "KEY", " VALUE", True), + # Space before equals + ("openjd_redacted_env: KEY =VALUE", None, "VALUE", False), + # No equals + ("openjd_redacted_env: KEYVALUE", None, "KEYVALUE", False), + # Multiple equals + ("openjd_redacted_env: KEY=VALUE=MORE", "KEY", "VALUE=MORE", True), + # Empty value + ("openjd_redacted_env: KEY=", "KEY", "", True), + ], + ) + def test_redacted_env_edge_cases( + self, + message_queue: SimpleQueue, + queue_handler: logging.handlers.QueueHandler, + case: str, + expected_key: str, + expected_value: str, + should_set_env: bool, + ) -> None: + """Test edge cases for redacted environment variables.""" + # Setup + callback_mock = MagicMock() + loga, _, _ = setup_action_filter_test( + queue_handler=queue_handler, + callback=callback_mock, + enabled_extensions=["REDACTED_ENV_VARS"], + ) + + # WHEN + loga.info(case) + + # THEN + if should_set_env: + # Check that the callback was called with the correct parameters + callback_mock.assert_called_once_with( + ActionMessageKind.ENV, {"name": expected_key, "value": expected_value}, False + ) + else: + # Check that the callback was NOT called (no env var should be set) + callback_mock.assert_not_called() + + # Check that the message in the log is redacted + assert message_queue.qsize() == 1 + log_message = message_queue.get(block=False).getMessage() + if expected_value: # Skip empty string check + assert expected_value not in log_message + + # Check that subsequent logs with the value are redacted + if expected_value: # Skip empty string + loga.info(f"The value is: {expected_value}") + assert message_queue.qsize() == 1 + log_message = message_queue.get(block=False).getMessage() + assert expected_value not in log_message + assert "********" in log_message + + # Special cases tests + @pytest.mark.parametrize( + "value,description", + [ + ("p@$$w0rd!*&^%", "special characters"), + ("line1\\nline2\\nline3", "newlines"), + ("C:\\Program Files\\App\\bin;D:\\Tools", "Windows paths"), + ], + ) + def test_redacted_env_special_values( + self, + message_queue: SimpleQueue, + queue_handler: logging.handlers.QueueHandler, + value: str, + description: str, + ) -> None: + """Test redacted environment variables with special values.""" + # Setup + callback_mock = MagicMock() + loga, _, _ = setup_action_filter_test( + queue_handler=queue_handler, + callback=callback_mock, + enabled_extensions=["REDACTED_ENV_VARS"], + ) + + # WHEN + loga.info(f"openjd_redacted_env: TEST_VAR={value}") + + # THEN + # Check that the callback was called with the correct parameters + callback_mock.assert_called_once_with( + ActionMessageKind.ENV, {"name": "TEST_VAR", "value": value}, False + ) + + # Check that the message in the log is redacted + assert message_queue.qsize() == 1 + log_record = message_queue.get(block=False).getMessage() + assert value not in log_record + assert "********" in log_record + + # JSON format tests + @pytest.mark.parametrize( + "json_case,should_succeed", + [ + # Standard JSON format (properly escaped strings) + ('openjd_redacted_env: "FOO=BAR"', True), + ('openjd_redacted_env: "FOO=BAR\\nBAZ"', True), + ('openjd_redacted_env: "FOO="', True), + ], + ) + def test_redacted_env_json_format( + self, + message_queue: SimpleQueue, + queue_handler: logging.handlers.QueueHandler, + json_case: str, + should_succeed: bool, + ) -> None: + """Test JSON format for redacted environment variables.""" + # Setup + callback_mock = MagicMock() + loga, _, _ = setup_action_filter_test( + queue_handler=queue_handler, + callback=callback_mock, + enabled_extensions=["REDACTED_ENV_VARS"], + ) + + # WHEN + loga.info(json_case) + + # THEN + if should_succeed: + # At least one call should be for setting an environment variable + env_calls = [ + call + for call in callback_mock.call_args_list + if call[0][0] == ActionMessageKind.ENV and isinstance(call[0][1], dict) + ] + assert len(env_calls) > 0, f"Case '{json_case}': No environment variable was set" + + # Check that the message in the log is redacted + assert message_queue.qsize() == 1 + log_record = message_queue.get(block=False).getMessage() + + # Extract the value that should be redacted + if "\\n" in json_case: + # For the line break case, check that both parts are redacted + assert "BAR" not in log_record, f"Value 'BAR' was not redacted in: {log_record}" + assert "BAZ" not in log_record, f"Value 'BAZ' was not redacted in: {log_record}" + assert "********" in log_record, f"Redaction marker not found in: {log_record}" + elif "FOO=BAR" in json_case: + assert "BAR" not in log_record, f"Value 'BAR' was not redacted in: {log_record}" + assert "********" in log_record, f"Redaction marker not found in: {log_record}" + # For empty value case, we don't check for redaction marker since there's nothing to redact + else: + # No calls should be for setting an environment variable + env_calls = [ + call + for call in callback_mock.call_args_list + if call[0][0] == ActionMessageKind.ENV and isinstance(call[0][1], dict) + ] + assert ( + len(env_calls) == 0 + ), f"Case '{json_case}': Environment variable was set unexpectedly" + + # Consistency tests between openjd_env and openjd_redacted_env + @pytest.mark.parametrize( + "case,should_succeed", + [ + # Success cases - standard format + ("openjd_env: KEY=VALUE", True), + ("openjd_env: KEY= VALUE", True), + ("openjd_env: KEY=VALUE=MORE", True), + ("openjd_env: KEY=", True), + # Success cases - quoted format + ('openjd_env: "FOO=12\\n34"', True), + ('openjd_env: "FOO="', True), + # Success cases - JSON format (pre-encoded as strings) + ('openjd_env: "FOO=BAR"', True), + ('openjd_env: "FOO=BAR\\nBAZ"', True), + # Success case - whitespace after prefix + ("openjd_env: \t foo=bar", True), + # Failure cases + ("openjd_env: KEY =VALUE", False), + ("openjd_env: KEYVALUE", False), + ("openjd_env: 1F_F_12=bar", False), + ("openjd_env: F😁=bar", False), + # Format issue cases + ("openjd_env:foo=bar", False), + ("OPENJD_ENV: foo=bar", False), + (" openjd_env: foo=bar", False), + ], + ) + def test_env_redacted_env_consistency( + self, + message_queue: SimpleQueue, + queue_handler: logging.handlers.QueueHandler, + case: str, + should_succeed: bool, + ) -> None: + """Test that openjd_redacted_env behaves the same as openjd_env for setting environment variables, + except for the redaction behavior.""" + + # Create a fresh filter and mock for each test case + callback_mock_env = MagicMock() + loga_env, _, _ = setup_action_filter_test( + queue_handler=queue_handler, + callback=callback_mock_env, + ) + + callback_mock_redacted = MagicMock() + loga_redacted, _, _ = setup_action_filter_test( + queue_handler=queue_handler, + callback=callback_mock_redacted, + enabled_extensions=["REDACTED_ENV_VARS"], + ) + + # Run both commands with the same input + loga_env.info(case) + # Replace "openjd_env" with "openjd_redacted_env" in the command + redacted_case = case.replace("openjd_env", "openjd_redacted_env") + redacted_case = redacted_case.replace("OPENJD_ENV", "OPENJD_REDACTED_ENV") + loga_redacted.info(redacted_case) + + # Clear the queue + while not message_queue.empty(): + message_queue.get() + + if should_succeed: + # For success cases, verify both set the environment variable correctly + callback_mock_env.assert_called_once() + callback_mock_redacted.assert_called_once() + + # Check that the parameters match + env_args = callback_mock_env.call_args[0] + redacted_args = callback_mock_redacted.call_args[0] + + # The first argument should be ActionMessageKind.ENV for both + assert ( + env_args[0] == redacted_args[0] == ActionMessageKind.ENV + ), f"Case '{case}': Different message kinds" + + # The second argument should be the same dictionary (name and value) + assert ( + env_args[1] == redacted_args[1] + ), f"Case '{case}': Different environment variable settings" + + # The third argument should be False for both + assert ( + env_args[2] == redacted_args[2] and not env_args[2] + ), f"Case '{case}': Different third argument" + else: + # For failure cases, verify neither sets an environment variable with a dict + env_success_calls = [ + call + for call in callback_mock_env.call_args_list + if call[0][0] == ActionMessageKind.ENV and isinstance(call[0][1], dict) + ] + redacted_success_calls = [ + call + for call in callback_mock_redacted.call_args_list + if call[0][0] == ActionMessageKind.ENV and isinstance(call[0][1], dict) + ] + + assert ( + len(env_success_calls) == 0 + ), f"Case '{case}': openjd_env should not set environment variable" + assert ( + len(redacted_success_calls) == 0 + ), f"Case '{case}': openjd_redacted_env should not set environment variable" + + # Additional tests for specific behaviors + def test_subsequent_redaction( + self, message_queue: SimpleQueue, queue_handler: logging.handlers.QueueHandler + ) -> None: + """Test that values are redacted in subsequent logs.""" + # Setup + callback_mock = MagicMock() + loga, _, _ = setup_action_filter_test( + queue_handler=queue_handler, + callback=callback_mock, + enabled_extensions=["REDACTED_ENV_VARS"], + ) + + # WHEN + # First set the environment variable + loga.info("openjd_redacted_env: API_KEY=abcdef123456") + + # Clear the queue + while not message_queue.empty(): + message_queue.get() + + # Then log a message containing the secret value + loga.info("Using API key: abcdef123456") + + # THEN + # Check that the secret value is redacted in the subsequent log + assert message_queue.qsize() == 1 + log_record = message_queue.get(block=False).getMessage() + assert "abcdef123456" not in log_record + assert "Using API key: ********" in log_record + + def test_redaction_persists_after_unset( + self, message_queue: SimpleQueue, queue_handler: logging.handlers.QueueHandler + ) -> None: + """Test that when we unset a redacted environment variable: + 1. The variable is unset (via callback) + 2. The value continues to be redacted in logs""" + # Setup + callback_mock = MagicMock() + loga, _, _ = setup_action_filter_test( + queue_handler=queue_handler, + callback=callback_mock, + enabled_extensions=["REDACTED_ENV_VARS"], + ) + + # Set up redaction + loga.info("openjd_redacted_env: SECRETVAR=SECRETVAL") + + # Clear the queue of the setup messages + while not message_queue.empty(): + message_queue.get() + + # WHEN - Unset the variable + loga.info("openjd_unset_env: SECRETVAR") + + # THEN - The callback should be called to unset the var + callback_mock.assert_called_with(ActionMessageKind.UNSET_ENV, "SECRETVAR", False) + + # Clear the queue of the unset message + while not message_queue.empty(): + message_queue.get() + + # AND - The value should still be redacted in logs + loga.info("The value is: SECRETVAL") + assert message_queue.qsize() == 1 + log_message = message_queue.get(block=False).getMessage() + assert "The value is: ********" in log_message + assert "SECRETVAL" not in log_message + + def test_redacted_env_with_linebreak( + self, message_queue: SimpleQueue, queue_handler: logging.handlers.QueueHandler + ) -> None: + """Test that values with actual line breaks are properly redacted in subsequent logs.""" + # Setup + callback_mock = MagicMock() + loga, _, _ = setup_action_filter_test( + queue_handler=queue_handler, + callback=callback_mock, + enabled_extensions=["REDACTED_ENV_VARS"], + ) + + # WHEN - First set the environment variable with a line break using JSON format + loga.info('openjd_redacted_env: "SECRETVAR2=line\\nbreak"') + + # Clear the queue + while not message_queue.empty(): + message_queue.get() + + # Then log a message containing the secret value split across lines + loga.info("We set SECRETVAR2 to line\nbreak") + + # THEN + # Check that both parts of the secret value are redacted in the subsequent log + assert message_queue.qsize() == 1 + log_record = message_queue.get(block=False).getMessage() + + assert "line" not in log_record, "First part of the secret value was not redacted" + assert "break" not in log_record, "Second part of the secret value was not redacted" + + def test_redacted_env_with_multiline_redaction( + self, message_queue: SimpleQueue, queue_handler: logging.handlers.QueueHandler + ) -> None: + """Test that subsequent lines of a multi-line secret are properly redacted.""" + # Setup + callback_mock = MagicMock() + loga, _, _ = setup_action_filter_test( + queue_handler=queue_handler, + callback=callback_mock, + enabled_extensions=["REDACTED_ENV_VARS"], + ) + + # WHEN - First set the environment variable with a multi-line value + loga.info('openjd_redacted_env: "SECRETVAR=first_line\\nsecond_line\\nthird_line"') + + # Clear the queue + while not message_queue.empty(): + message_queue.get() + + # Then log messages containing the individual lines + loga.info("The first part is: first_line") + + # THEN + # Check that the first line is redacted as part of a larger string + assert message_queue.qsize() == 1 + log_record = message_queue.get(block=False).getMessage() + assert "first_line" not in log_record, "First part of the secret value was not redacted" + assert "The first part is: ********" in log_record + + # Log just the second line by itself (should be fully redacted) + loga.info("second_line") + + # Check that the second line is completely redacted + assert message_queue.qsize() == 1 + log_record = message_queue.get(block=False).getMessage() + assert log_record == "********", "Second line was not fully redacted" + + # Log the third line by itself + loga.info("third_line") + + # Check that the third line is completely redacted + assert message_queue.qsize() == 1 + log_record = message_queue.get(block=False).getMessage() + assert log_record == "********", "Third line was not fully redacted" + + # Log the second line with some prefix text (should NOT be fully redacted) + loga.info("Prefix second_line") + + # Check that the "second_line" part is NOT redacted at all in "Prefix second_line" + # since it's only in _redacted_lines (exact match only) and not in _redacted_values + assert message_queue.qsize() == 1 + log_record = message_queue.get(block=False).getMessage() + assert "Prefix second_line" in log_record, "Line was incorrectly redacted" + + def test_redacted_env_with_multiline_redaction_last_part( + self, message_queue: SimpleQueue, queue_handler: logging.handlers.QueueHandler + ) -> None: + """Test that the last part of a multi-line secret is properly redacted.""" + # Setup + callback_mock = MagicMock() + loga, _, _ = setup_action_filter_test( + queue_handler=queue_handler, + callback=callback_mock, + enabled_extensions=["REDACTED_ENV_VARS"], + ) + + # WHEN - First set the environment variable with a multi-line value + loga.info('openjd_redacted_env: "SECRETVAR=first_line\\nmiddle_line\\nlast_line"') + + # Clear the queue + while not message_queue.empty(): + message_queue.get() + + # Log the last line with some prefix text (should be partially redacted) + loga.info("Prefix last_line") + + # THEN + # Check that the last line is redacted even with a prefix + assert message_queue.qsize() == 1 + log_record = message_queue.get(block=False).getMessage() + assert "last_line" not in log_record, "Last part of the secret value was not redacted" + assert "Prefix ********" in log_record, "Last part was not properly redacted" + + # Log the middle line with some prefix text (should NOT be redacted) + loga.info("Prefix middle_line") + + # Check that the middle line is NOT redacted when it has a prefix + assert message_queue.qsize() == 1 + log_record = message_queue.get(block=False).getMessage() + assert "Prefix middle_line" in log_record, "Middle line was incorrectly redacted" diff --git a/test/openjd/sessions/test_redaction.py b/test/openjd/sessions/test_redaction.py new file mode 100644 index 00000000..13978cac --- /dev/null +++ b/test/openjd/sessions/test_redaction.py @@ -0,0 +1,191 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +"""Tests for the redaction functionality in the action filter.""" + +import logging +import logging.handlers +from queue import SimpleQueue + +from openjd.sessions._action_filter import ( + ActionMonitoringFilter, + ActionMessageKind, + redact_openjd_redacted_env_requests, +) +from openjd.model import RevisionExtensions, SpecificationRevision + +from .conftest import setup_action_filter_test + + +def test_redact_openjd_redacted_env_requests(): + """Test that redact_openjd_redacted_env_requests correctly redacts sensitive information in command strings.""" + # Test command without redaction needed + command = "echo hello world" + assert redact_openjd_redacted_env_requests(command) == command + + # Test command with redacted env + command = "python -c \"print('openjd_redacted_env: PASSWORD=secret123')\"" + assert ( + redact_openjd_redacted_env_requests(command) + == "python -c \"print('openjd_redacted_env: ********" + ) + + # Test command with multiple redacted env values + command = ( + 'echo "openjd_redacted_env: PASSWORD=secret123"; echo "openjd_redacted_env: API_KEY=abc123"' + ) + assert redact_openjd_redacted_env_requests(command) == 'echo "openjd_redacted_env: ********' + + +def test_redaction_with_string_formatting(): + """Test that redaction works correctly with string formatting.""" + # Create a list to capture callback calls + callback_calls = [] + + def callback(kind: ActionMessageKind, message: str, cancel: bool): + callback_calls.append((kind, message, cancel)) + + # Create a RevisionExtensions with REDACTED_ENV_VARS enabled + revision_extensions = RevisionExtensions( + spec_rev=SpecificationRevision.v2023_09, supported_extensions=["REDACTED_ENV_VARS"] + ) + + # Create the filter + action_filter = ActionMonitoringFilter( + session_id="test_session", callback=callback, revision_extensions=revision_extensions + ) + + # Add a value to redact via redacted_env message + record = logging.LogRecord( + name="test", + level=logging.INFO, + pathname="", + lineno=0, + msg="openjd_redacted_env: PASSWORD=secret123", + args=(), + exc_info=None, + ) + record.session_id = "test_session" + action_filter.filter(record) + + # Test string formatting with args + record = logging.LogRecord( + name="test", + level=logging.INFO, + msg="Command: %s", + args=("echo secret123",), + pathname="", + lineno=0, + exc_info=None, + ) + record.session_id = "test_session" + action_filter.filter(record) + + # Verify redaction happened after string formatting + assert record.msg == "Command: echo ********" + assert not record.args # Args should be cleared after formatting + + # Test multiple args + record = logging.LogRecord( + name="test", + level=logging.INFO, + msg="First: %s, Second: %s", + args=("secret123", "hello"), + pathname="", + lineno=0, + exc_info=None, + ) + record.session_id = "test_session" + action_filter.filter(record) + + assert record.msg == "First: ********, Second: hello" + assert not record.args + + # Test with non-string args + record = logging.LogRecord( + name="test", + level=logging.INFO, + msg="Number: %d, Secret: %s", + args=(42, "secret123"), + pathname="", + lineno=0, + exc_info=None, + ) + record.session_id = "test_session" + action_filter.filter(record) + + assert record.msg == "Number: 42, Secret: ********" + assert not record.args + + +class TestRedactionCore: + """Tests for the core redaction functionality.""" + + def test_redaction_preserves_spaces( + self, message_queue: SimpleQueue, queue_handler: logging.handlers.QueueHandler + ) -> None: + """Test that when redacting values in an f-string, spaces around the value are preserved.""" + # Setup + loga, _, _ = setup_action_filter_test( + queue_handler=queue_handler, + enabled_extensions=["REDACTED_ENV_VARS"], + ) + + # Set up redaction + loga.info("openjd_redacted_env: SECRETVAR=SECRETVAL") + + # Clear the queue of the setup messages + while not message_queue.empty(): + message_queue.get() + + # WHEN - Message with token + loga.info("SECRETVAR is . SECRETVAL ;") + + # THEN - The spaces should be preserved in the redacted output + assert message_queue.qsize() == 1 + log_message = message_queue.get(block=False).getMessage() + assert "SECRETVAR is . ******** ;" in log_message # Spaces should be preserved + assert "SECRETVAL" not in log_message + + def test_overlapping_redactions( + self, message_queue: SimpleQueue, queue_handler: logging.handlers.QueueHandler + ) -> None: + """Test that overlapping redactions are handled correctly.""" + # Setup + loga, _, _ = setup_action_filter_test( + queue_handler=queue_handler, + enabled_extensions=["REDACTED_ENV_VARS"], + ) + + # Test case 1: Overlapping redactions at boundary + loga.info("openjd_redacted_env: KEY1=FOOOBAR") + loga.info("openjd_redacted_env: KEY2=BARKEY") + + # Clear the queue of the setup messages + while not message_queue.empty(): + message_queue.get() + + # Log a message containing the overlapping string + loga.info("The value is: FOOOBARKEY") + + # The entire overlapping string should be redacted + assert message_queue.qsize() == 1 + log_message = message_queue.get(block=False).getMessage() + assert "The value is: ********" in log_message + assert "FOOOBARKEY" not in log_message + + # Test case 2: One redaction completely contained within another + loga.info("openjd_redacted_env: KEY3=SUPERSECRETPASSWORD") + loga.info("openjd_redacted_env: KEY4=SECRET") + + # Clear the queue of the setup messages + while not message_queue.empty(): + message_queue.get() + + # Log a message containing the nested redaction + loga.info("The value is: SUPERSECRETPASSWORD") + + # The entire string should be redacted with a single redaction + assert message_queue.qsize() == 1 + log_message = message_queue.get(block=False).getMessage() + assert "The value is: ********" in log_message + assert "SUPERSECRETPASSWORD" not in log_message diff --git a/test/openjd/sessions/test_session.py b/test/openjd/sessions/test_session.py index c550b109..eddafb8e 100644 --- a/test/openjd/sessions/test_session.py +++ b/test/openjd/sessions/test_session.py @@ -15,7 +15,13 @@ import pytest -from openjd.model import ParameterValue, ParameterValueType, SpecificationRevision, SymbolTable +from openjd.model import ( + ParameterValue, + ParameterValueType, + SpecificationRevision, + SymbolTable, + RevisionExtensions, +) from openjd.model.v2023_09 import Action as Action_2023_09 from openjd.model.v2023_09 import ( EmbeddedFileText as EmbeddedFileText_2023_09, @@ -3014,3 +3020,376 @@ def test_undef_via_stdout( # THEN assert "FOO=FOO-not-set" in caplog.messages assert "BAR=BAR-value" in caplog.messages + + @pytest.mark.usefixtures("caplog") # builtin fixture + def test_def_via_redacted_env_stdout( + self, caplog: pytest.LogCaptureFixture, step_script_definition: StepScript_2023_09 + ) -> None: + # Test that when an environment defines variables via a stdout handler with openjd_redacted_env + # the variable is set correctly but the value is redacted in logs + + # GIVEN + environment = Environment_2023_09( + name="Env", + script=EnvironmentScript_2023_09( + actions=EnvironmentActions_2023_09( + onEnter=Action_2023_09( + command=CommandString_2023_09(sys.executable), + args=[ + ArgString_2023_09("-c"), + ArgString_2023_09("print('openjd_redacted_env: PASSWORD=secret123')"), + ], + ) + ) + ), + ) + + # Create a script that will print the environment variable to verify it was set correctly + script = StepScript_2023_09( + actions=StepActions_2023_09( + onRun=Action_2023_09( + command=CommandString_2023_09(sys.executable), + args=[ + ArgString_2023_09("-c"), + ArgString_2023_09( + "import os; print(f'PASSWORD={os.environ[\"PASSWORD\"]}')" + ), + ], + ) + ), + ) + + session_id = uuid.uuid4().hex + job_params = dict[str, ParameterValue]() + with Session(session_id=session_id, job_parameter_values=job_params) as session: + session.enter_environment(environment=environment) + while session.state == SessionState.RUNNING: + time.sleep(0.1) + + # WHEN + session.run_task( + step_script=script, + task_parameter_values=dict[str, ParameterValue](), + ) + while session.state == SessionState.RUNNING: + time.sleep(0.1) + + # THEN + # Check that the redacted message appears in the logs + assert "openjd_redacted_env: PASSWORD=********" in caplog.messages + # Check that the actual value is not in the logs from the environment setup + assert "openjd_redacted_env: PASSWORD=secret123" not in caplog.messages + # Check that the script output shows a KeyError since the env var wasn't set + # (extension not enabled by default) + assert "KeyError: 'PASSWORD'" in "\n".join(caplog.messages) + # Check that the sensitive value doesn't appear anywhere in the logs + assert "secret123" not in "\n".join(caplog.messages) + + @pytest.mark.usefixtures("caplog") # builtin fixture + def test_def_via_redacted_env_json_stdout( + self, caplog: pytest.LogCaptureFixture, step_script_definition: StepScript_2023_09 + ) -> None: + # Test that when an environment defines variables via a stdout handler with openjd_redacted_env + # using JSON format, the variable is set correctly but the value is redacted in logs + + # GIVEN + environment = Environment_2023_09( + name="Env", + script=EnvironmentScript_2023_09( + actions=EnvironmentActions_2023_09( + onEnter=Action_2023_09( + command=CommandString_2023_09(sys.executable), + args=[ + ArgString_2023_09("-c"), + ArgString_2023_09("print('openjd_redacted_env: API_KEY=abc123def456')"), + ], + ) + ) + ), + ) + + # Create a script that will print the environment variable to verify it was set correctly + script = StepScript_2023_09( + actions=StepActions_2023_09( + onRun=Action_2023_09( + command=CommandString_2023_09(sys.executable), + args=[ + ArgString_2023_09("-c"), + ArgString_2023_09("import os; print(f'API_KEY={os.environ[\"API_KEY\"]}')"), + ], + ) + ), + ) + + session_id = uuid.uuid4().hex + job_params = dict[str, ParameterValue]() + with Session(session_id=session_id, job_parameter_values=job_params) as session: + session.enter_environment(environment=environment) + while session.state == SessionState.RUNNING: + time.sleep(0.1) + + # WHEN + session.run_task( + step_script=script, + task_parameter_values=dict[str, ParameterValue](), + ) + while session.state == SessionState.RUNNING: + time.sleep(0.1) + + # THEN + # Check that the redacted message appears in the logs with fixed-length redaction + assert "openjd_redacted_env: API_KEY=********" in "\n".join(caplog.messages) + # Check that the actual value is not in the logs from the environment setup + assert "openjd_redacted_env: API_KEY=abc123def456" not in "\n".join(caplog.messages) + # Check that the script output shows a KeyError since the env var wasn't set + # (extension not enabled by default) + assert "KeyError: 'API_KEY'" in "\n".join(caplog.messages) + # Check that the sensitive value doesn't appear anywhere in the logs + assert "abc123def456" not in "\n".join(caplog.messages) + + @pytest.mark.usefixtures("caplog") # builtin fixture + def test_session_with_enabled_extensions(self, caplog: pytest.LogCaptureFixture) -> None: + """Test that the Session constructor accepts and stores enabled_extensions.""" + # GIVEN + session_id = str(uuid.uuid4()) + job_parameter_values = {"Foo": ParameterValue(type=ParameterValueType.STRING, value="Bar")} + enabled_extensions = ["REDACTED_ENV_VARS"] + revision_extensions = RevisionExtensions( + spec_rev=SpecificationRevision.v2023_09, supported_extensions=enabled_extensions + ) + + # WHEN + with Session( + session_id=session_id, + job_parameter_values=job_parameter_values, + revision_extensions=revision_extensions, + ) as session: + # THEN + # Check that the revision_extensions was properly set + assert session.get_enabled_extensions() == enabled_extensions + + @pytest.mark.usefixtures("caplog") # builtin fixture + def test_session_with_no_extensions(self, caplog: pytest.LogCaptureFixture) -> None: + """Test that the Session constructor handles None for enabled_extensions.""" + # GIVEN + session_id = str(uuid.uuid4()) + job_parameter_values = {"Foo": ParameterValue(type=ParameterValueType.STRING, value="Bar")} + + # WHEN + with Session( + session_id=session_id, + job_parameter_values=job_parameter_values, + ) as session: + # THEN + # Check that the default empty list is used + assert session.get_enabled_extensions() == [] + + @pytest.mark.usefixtures("caplog") # builtin fixture + def test_def_via_redacted_env_with_variables(self, caplog: pytest.LogCaptureFixture) -> None: + """Test that redacted env vars override directly defined variables.""" + # GIVEN + environment = Environment_2023_09( + name="Env", + script=EnvironmentScript_2023_09( + actions=EnvironmentActions_2023_09( + onEnter=Action_2023_09( + command=CommandString_2023_09(sys.executable), + args=[ + ArgString_2023_09("-c"), + ArgString_2023_09("print('openjd_redacted_env: TOKEN=secret-token')"), + ], + ) + ) + ), + variables={"TOKEN": EnvironmentVariableValueString_2023_09("public-token")}, + ) + + # Create a script that will print the environment variable to verify it was set correctly + script = StepScript_2023_09( + actions=StepActions_2023_09( + onRun=Action_2023_09( + command=CommandString_2023_09(sys.executable), + args=[ + ArgString_2023_09("-c"), + ArgString_2023_09("import os; print(f'TOKEN={os.environ[\"TOKEN\"]}')"), + ], + ) + ), + ) + + session_id = uuid.uuid4().hex + job_params = dict[str, ParameterValue]() + with Session(session_id=session_id, job_parameter_values=job_params) as session: + session.enter_environment(environment=environment) + while session.state == SessionState.RUNNING: + time.sleep(0.1) + + # WHEN + session.run_task( + step_script=script, + task_parameter_values=dict[str, ParameterValue](), + ) + while session.state == SessionState.RUNNING: + time.sleep(0.1) + + # THEN + # Check that the redacted message appears in the logs + assert "openjd_redacted_env: TOKEN=********" in caplog.messages + # Check that the actual value is not in the logs from the environment setup + assert "openjd_redacted_env: TOKEN=secret-token" not in caplog.messages + # Check that the script output shows the original value since the redacted value wasn't set + # (extension not enabled by default) + assert "TOKEN=public-token" in caplog.messages + # Check that the sensitive value doesn't appear anywhere in the logs + assert "secret-token" not in "\n".join(caplog.messages) + # The original value is visible when setting up the environment + assert "Setting: TOKEN=public-token" in caplog.messages + + @pytest.mark.usefixtures("caplog") # builtin fixture + def test_def_via_redacted_env_with_extension(self, caplog: pytest.LogCaptureFixture) -> None: + """Test that redacted env vars are set when the extension is enabled.""" + # GIVEN + environment = Environment_2023_09( + name="Env", + script=EnvironmentScript_2023_09( + actions=EnvironmentActions_2023_09( + onEnter=Action_2023_09( + command=CommandString_2023_09(sys.executable), + args=[ + ArgString_2023_09("-c"), + ArgString_2023_09("print('openjd_redacted_env: PASSWORD=secret123')"), + ], + ) + ) + ), + ) + + # Create a script that will print the environment variable to verify it was set correctly + script = StepScript_2023_09( + actions=StepActions_2023_09( + onRun=Action_2023_09( + command=CommandString_2023_09(sys.executable), + args=[ + ArgString_2023_09("-c"), + ArgString_2023_09( + "import os; print(f'PASSWORD={os.environ[\"PASSWORD\"]}') if 'PASSWORD' in os.environ else print('PASSWORD not set')" + ), + ], + ) + ), + ) + + session_id = uuid.uuid4().hex + job_params = dict[str, ParameterValue]() + revision_extensions = RevisionExtensions( + spec_rev=SpecificationRevision.v2023_09, supported_extensions=["REDACTED_ENV_VARS"] + ) + with Session( + session_id=session_id, + job_parameter_values=job_params, + revision_extensions=revision_extensions, # Enable the extension + ) as session: + session.enter_environment(environment=environment) + while session.state == SessionState.RUNNING: + time.sleep(0.1) + + # WHEN + session.run_task( + step_script=script, + task_parameter_values=dict[str, ParameterValue](), + ) + while session.state == SessionState.RUNNING: + time.sleep(0.1) + + # THEN + # Check that the redacted message appears in the logs + assert "openjd_redacted_env: PASSWORD=********" in caplog.messages + # Check that the actual value is not in the logs from the environment setup + assert "openjd_redacted_env: PASSWORD=secret123" not in caplog.messages + # Check that the script was able to access the actual value but it's redacted in logs + assert "PASSWORD=********" in caplog.messages + # Check that the sensitive value doesn't appear anywhere in the logs + assert "secret123" not in "\n".join(caplog.messages) + + @pytest.mark.usefixtures("caplog") # builtin fixture + def test_multiple_different_redacted_env_vars(self, caplog: pytest.LogCaptureFixture) -> None: + """Test that multiple redacted env vars with similar but different values are handled correctly.""" + # GIVEN + # Create an environment that sets two similar but different redacted env vars + environment = Environment_2023_09( + name="Env", + script=EnvironmentScript_2023_09( + actions=EnvironmentActions_2023_09( + onEnter=Action_2023_09( + command=CommandString_2023_09(sys.executable), + args=[ + ArgString_2023_09("-c"), + ArgString_2023_09( + "print('openjd_redacted_env: PASSWORD=secret123'); print('openjd_redacted_env: PASSWORD2=mysecret123')" + ), + ], + ) + ) + ), + ) + + # Create a script that will print both environment variables to verify they were set correctly + script = StepScript_2023_09( + actions=StepActions_2023_09( + onRun=Action_2023_09( + command=CommandString_2023_09(sys.executable), + args=[ + ArgString_2023_09("-c"), + ArgString_2023_09( + 'import os; print(f\'PASSWORD={os.environ.get("PASSWORD", "not-set")}\'); ' + 'print(f\'PASSWORD2={os.environ.get("PASSWORD2", "not-set")}\'); ' + "print('Both values are present in this log: secret123 mysecret123')" + ), + ], + ) + ), + ) + + session_id = uuid.uuid4().hex + job_params = dict[str, ParameterValue]() + + # WHEN + revision_extensions = RevisionExtensions( + spec_rev=SpecificationRevision.v2023_09, supported_extensions=["REDACTED_ENV_VARS"] + ) + with Session( + session_id=session_id, + job_parameter_values=job_params, + revision_extensions=revision_extensions, # Enable the extension + ) as session: + session.enter_environment(environment=environment) + while session.state == SessionState.RUNNING: + time.sleep(0.1) + + session.run_task( + step_script=script, + task_parameter_values=dict[str, ParameterValue](), + ) + while session.state == SessionState.RUNNING: + time.sleep(0.1) + + # THEN + # Check that both redacted messages appear in the logs with redacted values + assert "openjd_redacted_env: PASSWORD=********" in caplog.messages + assert "openjd_redacted_env: PASSWORD2=********" in caplog.messages + + # Check that the actual values are not in the logs from the environment setup + assert "openjd_redacted_env: PASSWORD=secret123" not in caplog.messages + assert "openjd_redacted_env: PASSWORD2=mysecret123" not in caplog.messages + + # Check that the script output shows the variables were set but values are redacted + assert "PASSWORD=********" in caplog.messages + assert "PASSWORD2=********" in caplog.messages + + # Check that both sensitive values are redacted in the log line that contains both + assert "Both values are present in this log: ******** ********" in caplog.messages + + # Check that neither sensitive value appears anywhere in the logs + log_content = "\n".join(caplog.messages) + assert "secret123" not in log_content + assert "mysecret123" not in log_content