diff --git a/aikido_zen/sinks/psycopg.py b/aikido_zen/sinks/psycopg.py index fd2f436cb..5518848ef 100644 --- a/aikido_zen/sinks/psycopg.py +++ b/aikido_zen/sinks/psycopg.py @@ -11,8 +11,7 @@ @before def _copy(func, instance, args, kwargs): statement = get_argument(args, kwargs, 0, "statement") - - op = "psycopg.Cursor.copy" + op = f"psycopg.{instance.__class__.__name__}.copy" register_call(op, "sql_op") vulns.run_vulnerability_scan( @@ -23,7 +22,7 @@ def _copy(func, instance, args, kwargs): @before def _execute(func, instance, args, kwargs): query = get_argument(args, kwargs, 0, "query") - op = f"psycopg.Cursor.{func.__name__}" + op = f"psycopg.{instance.__class__.__name__}.{func.__name__}" vulns.run_vulnerability_scan(kind="sql_injection", op=op, args=(query, "postgres")) @@ -38,3 +37,13 @@ def patch(m): patch_function(m, "Cursor.copy", _copy) patch_function(m, "Cursor.execute", _execute) patch_function(m, "Cursor.executemany", _execute) + + +@on_import("psycopg.cursor_async", "psycopg", version_requirement="3.1.0") +def patch_async(m): + """ + patching module psycopg.cursor_async (similar to normal patch) + """ + patch_function(m, "AsyncCursor.copy", _copy) + patch_function(m, "AsyncCursor.execute", _execute) + patch_function(m, "AsyncCursor.executemany", _execute) diff --git a/aikido_zen/sinks/tests/psycopg_test.py b/aikido_zen/sinks/tests/psycopg_test.py index 94e9a6580..e9a8a89f3 100644 --- a/aikido_zen/sinks/tests/psycopg_test.py +++ b/aikido_zen/sinks/tests/psycopg_test.py @@ -1,5 +1,8 @@ import pytest from unittest.mock import patch + +import pytest_asyncio + import aikido_zen.sinks.psycopg from aikido_zen.background_process.comms import reset_comms @@ -99,3 +102,113 @@ def test_cursor_copy(database_conn): mock_run_vulnerability_scan.assert_called_once() database_conn.close() + + +@pytest.mark.asyncio +async def test_async_cursor_execute(): + import psycopg + + reset_comms() + async with await psycopg.AsyncConnection.connect( + host="127.0.0.1", user="user", password="password", dbname="db" + ) as conn: + async with conn.cursor() as cursor: + with patch( + "aikido_zen.vulnerabilities.run_vulnerability_scan" + ) as mock_run_vulnerability_scan: + query = "SELECT * FROM dogs" + await cursor.execute(query) + + called_with = mock_run_vulnerability_scan.call_args[1] + assert called_with["args"][0] == query + assert called_with["args"][1] == "postgres" + assert called_with["op"] == "psycopg.AsyncCursor.execute" + assert called_with["kind"] == "sql_injection" + mock_run_vulnerability_scan.assert_called_once() + + await cursor.fetchall() + mock_run_vulnerability_scan.assert_called_once() + + +@pytest.mark.asyncio +async def test_async_cursor_execute_parameterized(): + import psycopg + + reset_comms() + query = "SELECT * FROM dogs WHERE dog_name = %s" + params = ("Fido",) + + async with await psycopg.AsyncConnection.connect( + host="127.0.0.1", user="user", password="password", dbname="db" + ) as conn: + async with conn.cursor() as cursor: + with patch( + "aikido_zen.vulnerabilities.run_vulnerability_scan" + ) as mock_run_vulnerability_scan: + await cursor.execute(query, params) + + called_with = mock_run_vulnerability_scan.call_args[1] + assert ( + called_with["args"][0] == "SELECT * FROM dogs WHERE dog_name = %s" + ) + assert called_with["args"][1] == "postgres" + assert called_with["op"] == "psycopg.AsyncCursor.execute" + assert called_with["kind"] == "sql_injection" + mock_run_vulnerability_scan.assert_called_once() + + await cursor.fetchall() + + +@pytest.mark.asyncio +async def test_async_cursor_executemany(): + import psycopg + + reset_comms() + query = "INSERT INTO dogs (dog_name, isadmin) VALUES (%s, %s)" + params = [("doggo1", False), ("Rex", False), ("Buddy", True)] + + async with await psycopg.AsyncConnection.connect( + host="127.0.0.1", user="user", password="password", dbname="db" + ) as conn: + async with conn.cursor() as cursor: + with patch( + "aikido_zen.vulnerabilities.run_vulnerability_scan" + ) as mock_run_vulnerability_scan: + await cursor.executemany(query, params) + + # Check the last call to run_vulnerability_scan + called_with = mock_run_vulnerability_scan.call_args[1] + assert ( + called_with["args"][0] + == "INSERT INTO dogs (dog_name, isadmin) VALUES (%s, %s)" + ) + assert called_with["args"][1] == "postgres" + assert called_with["op"] == "psycopg.AsyncCursor.executemany" + assert called_with["kind"] == "sql_injection" + mock_run_vulnerability_scan.assert_called() + + await conn.commit() + + +@pytest.mark.asyncio +async def test_async_cursor_copy(): + import psycopg + + reset_comms() + query = "COPY dogs FROM STDIN" + + async with await psycopg.AsyncConnection.connect( + host="127.0.0.1", user="user", password="password", dbname="db" + ) as conn: + async with conn.cursor() as cursor: + with patch( + "aikido_zen.vulnerabilities.run_vulnerability_scan" + ) as mock_run_vulnerability_scan: + cursor.copy(query) + + called_with = mock_run_vulnerability_scan.call_args[1] + assert called_with["args"][0] == query + assert called_with["args"][1] == "postgres" + assert called_with["op"] == "psycopg.AsyncCursor.copy" + assert called_with["kind"] == "sql_injection" + mock_run_vulnerability_scan.assert_called_once()