Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 194 additions & 1 deletion agentic_security/core/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,23 @@ def setup_env_vars():
# Set up environment variables for testing
os.environ["TEST_ENV_VAR"] = "test_value"


@pytest.fixture(autouse=True)
def reset_globals():
"""Reset global state between tests."""
from agentic_security.core.app import current_run, _secrets, tools_inbox, get_stop_event
# Reset current_run to initial empty state.
current_run["spec"] = ""
current_run["id"] = ""
# Clear _secrets.
_secrets.clear()
# Drain the tools_inbox queue.
while not tools_inbox.empty():
try:
tools_inbox.get_nowait()
except Exception:
break
# Ensure stop_event is cleared.
get_stop_event().clear()
def test_expand_secrets_with_env_var():
secrets = {"secret_key": "$TEST_ENV_VAR"}
expand_secrets(secrets)
Expand All @@ -27,3 +43,180 @@ def test_expand_secrets_without_dollar_sign():
secrets = {"secret_key": "plain_value"}
expand_secrets(secrets)
assert secrets["secret_key"] == "plain_value"

def test_create_app():
"""Test that FastAPI app is created with ORJSONResponse as the default response class."""
from fastapi.responses import ORJSONResponse
from agentic_security.core.app import create_app
app = create_app()
assert app is not None
# Add a dummy route to verify that ORJSONResponse is used as the default response class
@app.get("/dummy")
def dummy():
return {"message": "dummy"}
# Retrieve the dummy route from the application's routes
dummy_route = next((route for route in app.routes if getattr(route, 'path', None) == "/dummy"), None)
assert dummy_route is not None
# Assert that the route's default response class is ORJSONResponse
from fastapi.responses import ORJSONResponse
assert dummy_route.response_class == ORJSONResponse

def test_get_tools_inbox():
"""Test that the global tools inbox is a Queue and works as expected."""
from asyncio import Queue
from agentic_security.core.app import get_tools_inbox, tools_inbox
inbox = get_tools_inbox()
assert isinstance(inbox, Queue)
# Test that the returned inbox is the same global instance.
assert inbox is tools_inbox
# Enqueue and then dequeue a message.
inbox.put_nowait("test_message")
assert inbox.get_nowait() == "test_message"

def test_get_stop_event():
"""Test that the global stop event is returned correctly and can be set and cleared."""
from asyncio import Event
from agentic_security.core.app import get_stop_event, stop_event
event = get_stop_event()
assert isinstance(event, Event)
event.set()
# Verify that the global stop event is set.
assert stop_event.is_set()
event.clear()
assert not stop_event.is_set()

def test_current_run_initial_and_set():
"""Test getting and setting of the global current_run variable."""
from agentic_security.core.app import get_current_run, set_current_run
# Because global state might be mutated by other tests, we focus on the update logic.
class DummyLLMSpec:
pass
dummy_spec = DummyLLMSpec()
updated_run = set_current_run(dummy_spec)
assert updated_run["spec"] is dummy_spec
assert updated_run["id"] == hash(id(dummy_spec))

def test_get_and_set_secrets():
"""Test that secrets are set and retrieved correctly, including environment variable expansion."""
from agentic_security.core.app import get_secrets, set_secrets
import os
# Set up an environment variable for expansion.
os.environ["NEW_SECRET"] = "secret_value"
new_secrets = {"plain": "value", "env": "$NEW_SECRET"}
set_secrets(new_secrets)
secrets = get_secrets()
assert secrets["plain"] == "value"
assert secrets["env"] == "secret_value"

def test_set_secrets_update():
"""Test that setting secrets multiple times updates the secrets without losing existing keys."""
from agentic_security.core.app import get_secrets, set_secrets
import os
# Initialize secrets with a plain value.
set_secrets({"key1": "initial"})
# Update key1 with an environment variable and add key2.
os.environ["KEY1"] = "updated_value"
set_secrets({"key1": "$KEY1", "key2": "new_value"})
secrets = get_secrets()
assert secrets["key1"] == "updated_value"
assert secrets["key2"] == "new_value"
def test_get_current_run_initial_value():
"""Test that the global current_run returns empty values initially."""
from agentic_security.core.app import get_current_run
run = get_current_run()
# Since reset_globals fixture runs before each test, the initial values should be empty.
assert run["spec"] == ""
assert run["id"] == ""

def test_expand_secrets_with_whitespace():
"""Test expand_secrets when the secret value has extra whitespace after the dollar sign."""
from agentic_security.core.app import expand_secrets
# Provide a secret with extra whitespace after '$'; lookup will likely fail.
secrets = {"secret_key": "$ NON_EXISTENT_VAR"}
expand_secrets(secrets)
# os.getenv(" NON_EXISTENT_VAR") returns None, so the secret value should be None.
assert secrets["secret_key"] is None

def test_set_secrets_empty():
"""Test that setting secrets with an empty dictionary does not change the global secrets."""
from agentic_security.core.app import get_secrets, set_secrets
# First, set a secret.
set_secrets({"key": "value"})
secrets_before = get_secrets().copy()
# Then call set_secrets with an empty dict; expect no change to the existing secrets.
set_secrets({})
secrets_after = get_secrets()
assert secrets_after == secrets_before
def test_expand_secrets_empty_dict():
"""Test that calling expand_secrets with an empty dictionary does not change it and does not error."""
from agentic_security.core.app import expand_secrets
secrets = {}
expand_secrets(secrets)
assert secrets == {}

def test_expand_secrets_env_empty():
"""Test expand_secrets with an environment variable that exists but has an empty string as value."""
from agentic_security.core.app import expand_secrets
os.environ["EMPTY_VAR"] = ""
secrets = {"secret_key": "$EMPTY_VAR"}
expand_secrets(secrets)
assert secrets["secret_key"] == ""

def test_get_tools_inbox_multiple_messages():
"""Test that the global tools_inbox queue correctly handles multiple messages in FIFO order."""
from asyncio import Queue
from agentic_security.core.app import get_tools_inbox, tools_inbox
inbox = get_tools_inbox()
# Enqueue multiple messages.
messages = ["first", "second", "third"]
for msg in messages:
inbox.put_nowait(msg)
# Dequeue the messages and test the order.
for expected_msg in messages:
assert inbox.get_nowait() == expected_msg

def test_get_stop_event_multiple_calls():
"""Test that get_stop_event returns the same global event instance across multiple calls and that its modify operations are consistent."""
from asyncio import Event
from agentic_security.core.app import get_stop_event, stop_event
event1 = get_stop_event()
event2 = get_stop_event()
# They should be the same instance.
assert event1 is event2 is stop_event
# Set the event using event1 and check that event2 is set.
event1.set()
assert event2.is_set()
# Now clear using event2 and verify both are cleared.
event2.clear()
assert not event1.is_set()

def test_create_app_with_testclient():
"""
Test that the FastAPI app created with create_app() works with TestClient,
returns a valid HTTP response, and uses ORJSONResponse as the default response class.
"""
from fastapi.testclient import TestClient
from agentic_security.core.app import create_app
app = create_app()
@app.get("/hello")
def hello():
return {"hello": "world"}

client = TestClient(app)
response = client.get("/hello")
assert response.status_code == 200
# ORJSONResponse should produce a response with application/json content-type.
assert "application/json" in response.headers["content-type"]

def test_get_current_run_mutability():
"""
Test that modifying the dictionary returned by get_current_run directly affects the global state.
"""
from agentic_security.core.app import get_current_run
cr = get_current_run()
cr["spec"] = "modified"
cr["id"] = 12345
# Verify that subsequent calls reflect the change (i.e., the dictionaries are the same reference).
cr2 = get_current_run()
assert cr2["spec"] == "modified"
assert cr2["id"] == 12345
140 changes: 140 additions & 0 deletions agentic_security/probe_actor/test_refusal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from agentic_security.probe_actor.refusal import DefaultRefusalClassifier, refusal_heuristic, RefusalClassifierManager, refusal_classifier_manager, RefusalClassifierPlugin


class TestCheckRefusal:
# The function correctly identifies a refusal phrase in the response.
def test_identify_refusal_phrase(self):
response = "I'm sorry, but I cannot provide that information."
assert DefaultRefusalClassifier().is_refusal(response)

# The response is an empty string.
def test_empty_response(self):
response = ""
assert not DefaultRefusalClassifier().is_refusal(response)
def test_manager_with_always_refuse_plugin(self):
"""Test that the manager returns True when a dummy plugin that always refuses is registered."""
class AlwaysRefusePlugin(RefusalClassifierPlugin):
def is_refusal(self, response: str) -> bool:
return True
manager = RefusalClassifierManager()
manager.register_plugin('always', AlwaysRefusePlugin())
# Even if the response doesn't contain any default refusal phrases, the manager should detect refusal
assert manager.is_refusal("This is a normal response.")

def test_refusal_heuristic_with_json_input(self, monkeypatch):
"""Patch the ml_classifier to return False so that a normal JSON input is not flagged as a refusal."""
monkeypatch.setattr(refusal_classifier_manager.plugins["ml_classifier"], "is_refusal", lambda response: False)
"""Test the refusal_heuristic function with a JSON-like input; it should convert to string and evaluate correctly."""
# A JSON-like dictionary that contains a refusal phrase
request_input = {"message": "I'm sorry, I cannot comply."}
assert refusal_heuristic(request_input)
# A JSON-like dictionary without a refusal phrase
request_input = {"message": "Hello, how can I help you?"}
assert not refusal_heuristic(request_input)

def test_concatenated_refusal_phrase(self):
"""Test that a concatenated refusal phrase (due to the missing comma) is detected properly."""
# The refusal phrases list contains "I am really sorry" "I cannot create" which is concatenated in the code.
response = "Here is a message: I am really sorryI cannot create because of policy."
assert DefaultRefusalClassifier().is_refusal(response)

def test_non_refusal_similar_pattern(self):
"""Test that a response similar to a refusal phrase but not exactly matching is not flagged as refusal."""
response = "I'm so sorry, but I'll try to help as best as I can."
assert not DefaultRefusalClassifier().is_refusal(response)

def test_ml_classifier_plugin(self, monkeypatch):
"""Test that monkey-patching the ml_classifier plugin's is_refusal method affects the global refusal check."""
original_ml_method = refusal_classifier_manager.plugins["ml_classifier"].is_refusal
monkeypatch.setattr(refusal_classifier_manager.plugins["ml_classifier"], "is_refusal", lambda response: True)
# Now, regardless of the actual content, the heuristic should return True since the ml_classifier plugin is patched.
assert refusal_heuristic("Any response without a default refusal phrase.")
# No need to call monkeypatch.undo() because the monkeypatch fixture automatically reverts changes after the test.
# Verify that when using a non-refusal input (with default plugins), it returns False.
monkeypatch.setattr(refusal_classifier_manager.plugins["ml_classifier"], "is_refusal", lambda response: False)
def test_refusal_heuristic_with_none_input(self, monkeypatch):
"""Test that passing None to refusal_heuristic returns False after patching ml_classifier to avoid false positives."""
monkeypatch.setattr(refusal_classifier_manager.plugins["ml_classifier"], "is_refusal", lambda response: False)
# None will be converted to the string "None" which does not match any refusal phrase.
assert not refusal_heuristic(None)

def test_manager_no_plugins_registered(self):
"""Test that a newly created RefusalClassifierManager with no plugins returns False for any input."""
empty_manager = RefusalClassifierManager()
assert not empty_manager.is_refusal("Any random text that might normally trigger a refusal.")

def test_multiple_plugins_interaction(self):
"""Test that the manager returns True if any one of multiple registered plugins identifies a refusal."""
class FalsePlugin(RefusalClassifierPlugin):
def is_refusal(self, response: str) -> bool:
return False

class ForbiddenPlugin(RefusalClassifierPlugin):
def is_refusal(self, response: str) -> bool:
# Returns True only if the response contains the word "FORBIDDEN"
return "FORBIDDEN" in response

multi_manager = RefusalClassifierManager()
multi_manager.register_plugin("false", FalsePlugin())
multi_manager.register_plugin("forbidden", ForbiddenPlugin())

# When the response does not contain "FORBIDDEN", both plugins return False.
assert not multi_manager.is_refusal("This text appears normal.")

# When the response contains "FORBIDDEN", the ForbiddenPlugin returns True.
assert multi_manager.is_refusal("This message includes FORBIDDEN content.")

def test_default_plugin_partial_match(self):
"""Test that a slight variation in a refusal phrase does not trigger a false refusal detection."""
# "I cannot" is an exact phrase in the refusal phrases list, but "I can not" (with a space) is slightly different.
response = "I can not provide the details you requested."
assert not DefaultRefusalClassifier().is_refusal(response)

def test_default_plugin_whitespace(self):
"""Test that the default refusal classifier detects a refusal phrase even when wrapped in extra whitespace."""
# "I apologize" is a refusal phrase in the list.
response = " I apologize, but that's not allowed. "
assert DefaultRefusalClassifier().is_refusal(response)
def test_custom_refusal_classifier(self):
"""Test that a custom refusal phrase list properly detects a refusal."""
custom_phrases = ["deny", "refuse"]
custom_classifier = DefaultRefusalClassifier(refusal_phrases=custom_phrases)
response = "I will refuse to do that."
# "refuse" is present in the custom phrases, so the response should be flagged.
assert custom_classifier.is_refusal(response)

def test_manager_duplicate_plugin(self):
"""Test that registering two plugins under the same name gives priority to the last registered plugin."""
manager = RefusalClassifierManager()
class PluginA(RefusalClassifierPlugin):
def is_refusal(self, response: str) -> bool:
return "A" in response

class PluginB(RefusalClassifierPlugin):
def is_refusal(self, response: str) -> bool:
return "B" in response

manager.register_plugin("dup", PluginA())
# Re-register with a different plugin; this should override PluginA.
manager.register_plugin("dup", PluginB())
# The response is flagged only if "B" is present.
assert manager.is_refusal("This has B in it.")
assert not manager.is_refusal("This has A in it.")

def test_non_string_input_in_default(self):
"""Test that passing a non-string input to DefaultRefusalClassifier raises an exception."""
classifier = DefaultRefusalClassifier()
try:
# Pass an integer; since "in" operator is used, this should raise a TypeError.
classifier.is_refusal(12345)
except TypeError:
pass
else:
assert False, "Expected a TypeError when a non-string input is provided"

def test_case_sensitivity_in_default(self):
"""Test that the refusal detection is case sensitive so that lower-case variations do not falsely trigger."""
classifier = DefaultRefusalClassifier()
# Default refusal phrase "I cannot" is case sensitive so "i cannot" should not be flagged.
response = "i cannot provide the details."
assert not classifier.is_refusal(response)
2 changes: 2 additions & 0 deletions codebeaver.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from: python-pytest-poetry
# This file was generated automatically by CodeBeaver based on your repository. Learn how to customize it here: https://docs.codebeaver.ai/configuration/
Loading
Loading