Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,11 @@ jobs:
- name: Install package with dev extras
run: pip install -e ".[dev]"

- name: pytest
run: pytest tests/ -v --tb=short
- name: Install pytest-cov
run: pip install pytest-cov

- name: pytest with coverage (≥80% required)
run: pytest tests/ -v --tb=short --cov=app --cov-report=term-missing --cov-fail-under=80

secret-scan:
name: Secret Scan
Expand Down
36 changes: 23 additions & 13 deletions app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .metrics import (
get_metrics, get_metrics_content_type, set_service_info,
record_precheck_request, record_postcheck_request, record_policy_evaluation,
set_active_requests
record_request_error, set_active_requests
)
from .settings import settings
from .auth import require_api_key
Expand All @@ -26,6 +26,12 @@

router = APIRouter()


def _ensure_correlation_id(corr_id: Optional[str]) -> str:
"""Return an existing correlation ID or generate a new one."""
return corr_id or f"corr-{secrets.token_hex(12)}"


def extract_pii_info_from_reasons(reasons: Optional[List[str]]) -> Tuple[List[str], float]:
"""Extract PII types and calculate confidence from reason codes"""
pii_types = []
Expand Down Expand Up @@ -219,6 +225,7 @@ async def precheck(
"""Precheck endpoint for policy evaluation and PII redaction"""
# User ID is optional - websocket will resolve from API key if needed
user_id = req.user_id
correlation_id = _ensure_correlation_id(req.corr_id)

# Rate limiting (100 requests per minute per user/api_key)
if user_id:
Expand All @@ -235,7 +242,7 @@ async def precheck(
start_ts = int(start_time)

try:
logger.debug("precheck request", extra={"tool": req.tool, "corr_id": req.corr_id})
logger.debug("precheck request", extra={"tool": req.tool, "corr_id": correlation_id})

# Use new policy evaluation with payload policies
policy_config = req.policy_config.model_dump() if req.policy_config else None
Expand Down Expand Up @@ -281,7 +288,7 @@ async def precheck(
"type": "INGEST",
"channel": webhook_channel,
"schema": "decision.v1",
"idempotencyKey": f"precheck-{start_ts}-{req.corr_id or 'unknown'}",
"idempotencyKey": f"precheck-{start_ts}-{correlation_id}",
"data": {
"orgId": webhook_org_id,
"direction": "precheck",
Expand All @@ -298,7 +305,7 @@ async def precheck(
},
"payloadHash": f"sha256:{hashlib.sha256(req.raw_text.encode()).hexdigest()}",
"latencyMs": int((time.time() - start_time) * 1000),
"correlationId": req.corr_id,
"correlationId": correlation_id,
"tags": [], # TODO: Extract from request or make configurable
"ts": f"{datetime.fromtimestamp(start_ts).isoformat()}Z",
"authentication": {
Expand All @@ -310,17 +317,17 @@ async def precheck(

# Fire and forget (don't block response path)
try:
asyncio.create_task(emit_event(event))
asyncio.create_task(emit_event(event, correlation_id=correlation_id))
except RuntimeError:
# If no running loop (tests), do it inline once
await emit_event(event)
await emit_event(event, correlation_id=correlation_id)

# Audit log before response
audit_log("precheck",
user_id=user_id,
tool=req.tool,
decision=result["decision"],
corr_id=req.corr_id,
corr_id=correlation_id,
policy_id=result.get("policy_id"),
reasons=result.get("reasons", []))

Expand All @@ -337,6 +344,7 @@ async def precheck(
return DecisionResponse(**result)

except Exception as e:
record_request_error("precheck", type(e).__name__)
# Re-raise the exception after clearing metrics
raise e

Expand All @@ -352,6 +360,7 @@ async def postcheck(
"""Postcheck endpoint for post-execution validation"""
# User ID is optional - websocket will resolve from API key if needed
user_id = req.user_id
correlation_id = _ensure_correlation_id(req.corr_id)

# Rate limiting (100 requests per minute per user/api_key)
if user_id:
Expand All @@ -368,7 +377,7 @@ async def postcheck(
start_ts = int(start_time)

try:
logger.debug("postcheck request", extra={"tool": req.tool, "corr_id": req.corr_id})
logger.debug("postcheck request", extra={"tool": req.tool, "corr_id": correlation_id})

# Use new policy evaluation with payload policies
policy_config = req.policy_config.model_dump() if req.policy_config else None
Expand Down Expand Up @@ -414,7 +423,7 @@ async def postcheck(
"type": "INGEST",
"channel": webhook_channel,
"schema": "decision.v1",
"idempotencyKey": f"postcheck-{start_ts}-{req.corr_id or 'unknown'}",
"idempotencyKey": f"postcheck-{start_ts}-{correlation_id}",
"data": {
"orgId": webhook_org_id,
"direction": "postcheck",
Expand All @@ -431,7 +440,7 @@ async def postcheck(
},
"payloadHash": f"sha256:{hashlib.sha256(req.raw_text.encode()).hexdigest()}",
"latencyMs": int((time.time() - start_time) * 1000),
"correlationId": req.corr_id,
"correlationId": correlation_id,
"tags": [], # TODO: Extract from request or make configurable
"ts": f"{datetime.fromtimestamp(start_ts).isoformat()}Z",
"authentication": {
Expand All @@ -443,17 +452,17 @@ async def postcheck(

# Fire and forget (don't block response path)
try:
asyncio.create_task(emit_event(event))
asyncio.create_task(emit_event(event, correlation_id=correlation_id))
except RuntimeError:
# If no running loop (tests), do it inline once
await emit_event(event)
await emit_event(event, correlation_id=correlation_id)

# Audit log before response
audit_log("postcheck",
user_id=user_id,
tool=req.tool,
decision=result["decision"],
corr_id=req.corr_id,
corr_id=correlation_id,
policy_id=result.get("policy_id"),
reasons=result.get("reasons", []))

Expand All @@ -470,6 +479,7 @@ async def postcheck(
return DecisionResponse(**result)

except Exception as e:
record_request_error("postcheck", type(e).__name__)
# Re-raise the exception after clearing metrics
raise e

Expand Down
4 changes: 4 additions & 0 deletions app/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime
from typing import Optional
from .storage import get_db, APIKey
from .metrics import record_auth_failure


async def require_api_key(
Expand All @@ -11,14 +12,17 @@ async def require_api_key(
) -> str:
"""Require and validate API key from header against the database."""
if not x_governs_key:
record_auth_failure("missing_api_key")
raise HTTPException(status_code=401, detail="missing api key")

record = db.query(APIKey).filter(APIKey.key == x_governs_key).first()

if record is None or not record.is_active:
record_auth_failure("invalid_api_key")
raise HTTPException(status_code=401, detail="invalid api key")

if record.expires_at is not None and record.expires_at < datetime.utcnow():
record_auth_failure("expired_api_key")
raise HTTPException(status_code=401, detail="api key expired")

return x_governs_key
44 changes: 40 additions & 4 deletions app/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from urllib.parse import urlparse, parse_qs
from typing import Any, Dict, Optional, Tuple
from .settings import settings
from .metrics import record_webhook_event, record_dlq_event, set_dlq_size

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -53,18 +54,44 @@ def _write_dlq(event: Dict[str, Any], err: str, dlq_path: Optional[str] = None)
with open(path, "a", encoding="utf-8") as f:
f.write(json.dumps({"err": err, "event": event}) + "\n")

record_dlq_event(_error_type(err))
_set_dlq_size(path)


def _error_type(err: str) -> str:
"""Normalize raw error strings to a stable error type label."""
if not err:
return "unknown"
return err.split(":", 1)[0]


def _set_dlq_size(path: str) -> None:
    """Update DLQ size gauge from the current DLQ file line count.

    Best-effort: a missing file means an empty DLQ (gauge set to 0);
    any other failure is logged at warning level and otherwise ignored
    so metrics bookkeeping never breaks the DLQ write path.
    """
    try:
        with open(path, "r", encoding="utf-8") as dlq_file:
            line_count = 0
            for _ in dlq_file:
                line_count += 1
        set_dlq_size(line_count)
    except FileNotFoundError:
        set_dlq_size(0)
    except Exception as exc:
        logger.warning("Failed to set DLQ size: %s", type(exc).__name__)


async def _sleep_ms(ms: int):
"""Sleep for specified milliseconds"""
await asyncio.sleep(ms / 1000.0)


async def _send_via_websocket(url: str, message: str, api_key: Optional[str]) -> None:
async def _send_via_websocket(
url: str, message: str, api_key: Optional[str], correlation_id: Optional[str] = None
) -> None:
"""Open a WebSocket connection, authenticate if key is available, then send message."""
headers = {"X-Correlation-ID": correlation_id} if correlation_id else None
async with websockets.connect(
url,
open_timeout=settings.webhook_timeout_s,
close_timeout=settings.webhook_timeout_s,
extra_headers=headers,
) as websocket:
if api_key:
auth_msg = json.dumps({"type": "AUTH", "apiKey": api_key})
Expand All @@ -81,16 +108,22 @@ async def _send_via_websocket(url: str, message: str, api_key: Optional[str]) ->
await websocket.send(message)


async def emit_event(event: Dict[str, Any]) -> None:
async def emit_event(event: Dict[str, Any], correlation_id: Optional[str] = None) -> None:
"""Sends the event via WebSocket to WEBHOOK_URL.
Authenticates the connection before sending, so the raw API key never
travels inside the INGEST payload.
Falls back to DLQ (jsonl) after retries."""
webhook_url = settings.webhook_url
dlq_path = settings.precheck_dlq
event_type = str(event.get("schema") or event.get("type") or "unknown")
correlation = correlation_id or event.get("correlationId")
if not correlation and isinstance(event.get("data"), dict):
correlation = event["data"].get("correlationId")
emit_started_at = time.time()

if not webhook_url:
_write_dlq(event, "webhook_url_not_configured", dlq_path)
record_webhook_event(event_type, "failed", 0.0)
return

websocket_url = webhook_url
Expand All @@ -103,8 +136,9 @@ async def emit_event(event: Dict[str, Any]) -> None:
err = "no_attempts"
for attempt in range(1, settings.webhook_max_retries + 1):
try:
await _send_via_websocket(websocket_url, message, conn_api_key)
await _send_via_websocket(websocket_url, message, conn_api_key, correlation)
logger.debug("event emitted attempt=%d", attempt)
record_webhook_event(event_type, "success", time.time() - emit_started_at)
return
except Exception as e:
err = f"websocket_exception:{type(e).__name__}:{str(e)[:200]}"
Expand All @@ -116,8 +150,9 @@ async def emit_event(event: Dict[str, Any]) -> None:
if "SSL" in str(e) and websocket_url.startswith("wss://"):
try:
fallback_url = websocket_url.replace("wss://", "ws://", 1)
await _send_via_websocket(fallback_url, message, conn_api_key)
await _send_via_websocket(fallback_url, message, conn_api_key, correlation)
logger.debug("event emitted via ssl fallback attempt=%d", attempt)
record_webhook_event(event_type, "success", time.time() - emit_started_at)
return
except Exception as fallback_e:
err = f"websocket_fallback_exception:{type(fallback_e).__name__}:{str(fallback_e)[:200]}"
Expand All @@ -128,6 +163,7 @@ async def emit_event(event: Dict[str, Any]) -> None:

if attempt == settings.webhook_max_retries:
_write_dlq(event, err, dlq_path)
record_webhook_event(event_type, "failed", time.time() - emit_started_at)
return
await _sleep_ms(delay_ms)
delay_ms *= 2
25 changes: 25 additions & 0 deletions app/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@
['error_type']
)

auth_failures_total = Counter(
'auth_failures_total',
'Total number of authentication failures',
['reason']
)

request_errors_total = Counter(
'request_errors_total',
'Total number of request processing errors',
['endpoint', 'error_type']
)

# Histogram metrics
precheck_duration_seconds = Histogram(
'precheck_duration_seconds',
Expand Down Expand Up @@ -171,6 +183,19 @@ def record_dlq_event(error_type: str):
error_type=error_type
).inc()

def record_auth_failure(reason: str):
    """Record an authentication failure.

    Increments the ``auth_failures_total`` counter labelled with the
    failure *reason* (e.g. ``missing_api_key``, ``expired_api_key``).
    """
    auth_failures_total.labels(reason=reason).inc()

def record_request_error(endpoint: str, error_type: str):
    """Record request processing errors by endpoint.

    Increments the ``request_errors_total`` counter labelled with the
    *endpoint* name and the exception class name in *error_type*.
    """
    request_errors_total.labels(endpoint=endpoint, error_type=error_type).inc()

def set_active_requests(endpoint: str, count: int):
"""Set the number of active requests"""
active_requests.labels(endpoint=endpoint).set(count)
Expand Down
21 changes: 20 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ dependencies = [
dev = [
"pytest>=7.4.0",
"pytest-asyncio>=0.21.0",
"pytest-cov>=4.1.0",
"httpx>=0.25.0",
"black>=23.0.0",
"isort>=5.12.0",
Expand Down Expand Up @@ -103,5 +104,23 @@ testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = "-v --tb=short"
addopts = "-v --tb=short --cov=app --cov-report=term-missing --cov-fail-under=80"
asyncio_mode = "auto"

[tool.coverage.run]
source = ["app"]
omit = [
"app/migrations/*",
"app/__pycache__/*",
]

[tool.coverage.report]
# Enforce 80% coverage on the critical policy engine path
fail_under = 80
exclude_lines = [
"pragma: no cover",
"def __repr__",
"if __name__ == .__main__.:",
"raise NotImplementedError",
"pass$",
]
Loading
Loading