4 changes: 4 additions & 0 deletions src/agents/repository_analysis_agent/agent.py
@@ -38,6 +38,10 @@ async def execute(self, **kwargs) -> AgentResult:
         await analyze_pr_history(state, request.max_prs)
         await analyze_contributing_guidelines(state)
 
+        # Only generate recommendations if we have basic repository data
+        if not state.repository_features.language:
+            raise ValueError("Unable to determine repository language - cannot generate appropriate rules")
+
         state.recommendations = _default_recommendations(state)
         validate_recommendations(state)
         response = summarize_analysis(state, request)
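Note: the new guard turns a missing primary language into a hard failure. A minimal sketch of how a caller might surface this, assuming execute() is awaited from some orchestration layer; the wrapper below is an illustrative assumption, not part of this PR:

    import logging

    # Hypothetical caller; RepositoryAnalysisAgent.execute is from this PR,
    # everything else here is illustrative.
    async def run_analysis(agent, repository_full_name: str):
        try:
            return await agent.execute(repository_full_name=repository_full_name)
        except ValueError as exc:
            # Raised when GitHub reports no primary language (e.g. an empty
            # repository); retrying will not help, so surface the error.
            logging.getLogger(__name__).warning("repository analysis skipped: %s", exc)
            return None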
86 changes: 67 additions & 19 deletions src/agents/repository_analysis_agent/nodes.py
@@ -8,6 +8,7 @@
 
 from __future__ import annotations
 
+import logging
 import textwrap
 from typing import Any
 
@@ -32,6 +33,9 @@ async def analyze_repository_structure(state: RepositoryAnalysisState) -> None:
     installation_id = state.installation_id
 
     repo_data = await github_client.get_repository(repo, installation_id=installation_id)
+    if not repo_data:
+        raise ValueError(f"Could not fetch repository data for {repo}")
+
     workflows = await github_client.list_directory_any_auth(
         repo_full_name=repo, path=".github/workflows", installation_id=installation_id
     )
@@ -42,7 +46,7 @@ async def analyze_repository_structure(state: RepositoryAnalysisState) -> None:
         has_codeowners=bool(await github_client.get_file_content(repo, ".github/CODEOWNERS", installation_id)),
         has_workflows=bool(workflows),
         workflow_count=len(workflows or []),
-        language=(repo_data or {}).get("language"),
+        language=repo_data.get("language"),
         contributor_count=len(contributors),
         pr_count=0,
     )
@@ -54,8 +58,14 @@ async def analyze_pr_history(state: RepositoryAnalysisState, max_prs: int) -> None:
     installation_id = state.installation_id
     prs = await github_client.list_pull_requests(repo, installation_id=installation_id, state="all", per_page=max_prs)
 
+    if prs is None:
+        # If PR listing fails, continue with empty samples rather than failing
+        state.pr_samples = []
+        state.repository_features.pr_count = 0
+        return
+
     samples: list[PullRequestSample] = []
-    for pr in prs or []:
+    for pr in prs:
         samples.append(
             PullRequestSample(
                 number=pr.get("number", 0),
@@ -215,19 +225,28 @@ def _default_recommendations(
 
     Currently, validators like `author_team_is` and `file_patterns` operate independently.
     """
+    logger = logging.getLogger(__name__)
+
     recommendations: list[RuleRecommendation] = []
 
     # Get language-specific patterns based on repository analysis
-    source_patterns, test_patterns = _get_language_specific_patterns(state.repository_features.language)
+    language = state.repository_features.language
+    source_patterns, test_patterns = _get_language_specific_patterns(language)
+
+    logger.info(
+        f"Generating recommendations for {state.repository_full_name}: language={language}, pr_count={state.repository_features.pr_count}"
+    )
 
     # Analyze PR history for bad habits
     pr_issues = _analyze_pr_bad_habits(state)
 
     # Require tests when source code changes.
     # This is especially important if we detect missing tests in PR history
-    test_reasoning = f"Default guardrail for code changes without tests. Patterns adapted for {state.repository_features.language or 'multi-language'} repository."
+    test_reasoning = f"Repository analysis for {state.repository_full_name}. Language: {language or 'unknown'}. Patterns adapted for {language or 'multi-language'} repository."
     if pr_issues.get("missing_tests", 0) > 0:
         test_reasoning += f" Detected {pr_issues['missing_tests']} recent PRs without test files."
+    if state.contributing_analysis.content and state.contributing_analysis.requires_tests:
+        test_reasoning += " Contributing guidelines explicitly require tests."
 
     # Build YAML rule with proper indentation
     # parameters: is at column 0, source_patterns: at column 2, list items at column 4
@@ -239,55 +258,84 @@ def _default_recommendations(
 severity: medium
 event_types:
   - pull_request
-  parameters:
+parameters:
   source_patterns:
 {source_patterns_yaml}
   test_patterns:
 {test_patterns_yaml}
 """
 
+    confidence = 0.74
+    if pr_issues.get("missing_tests", 0) > 0:
+        confidence = 0.85
+    if state.contributing_analysis.content and state.contributing_analysis.requires_tests:
+        confidence = min(0.95, confidence + 0.1)
+
     recommendations.append(
         RuleRecommendation(
             yaml_rule=yaml_content.strip(),
-            confidence=0.74 if pr_issues.get("missing_tests", 0) == 0 else 0.85,
+            confidence=confidence,
             reasoning=test_reasoning,
             strategy_used="hybrid",
         )
     )
 
     # Require description in PR body.
     # Increase confidence if we detect short titles in PR history (indicator of missing context)
-    desc_reasoning = "Encourage context for reviewers; lightweight default."
+    desc_reasoning = f"Repository analysis for {state.repository_full_name}."
     if pr_issues.get("short_titles", 0) > 0:
         desc_reasoning += f" Detected {pr_issues['short_titles']} PRs with very short titles (likely missing context)."
+    else:
+        desc_reasoning += " Encourages context for reviewers; lightweight default."
 
+    desc_confidence = 0.68
+    if pr_issues.get("short_titles", 0) > 0:
+        desc_confidence = 0.80
+
     recommendations.append(
         RuleRecommendation(
             yaml_rule=textwrap.dedent(
                 """
                 description: "Ensure PRs include context"
-                 enabled: true
+                enabled: true
                 severity: low
-                 event_types:
-                 - pull_request
-                 parameters:
+                event_types:
+                  - pull_request
+                parameters:
                   min_description_length: 50
                 """
             ).strip(),
-            confidence=0.68 if pr_issues.get("short_titles", 0) == 0 else 0.80,
+            confidence=desc_confidence,
             reasoning=desc_reasoning,
             strategy_used="static",
         )
     )
 
-    # If contributing guidelines require tests, increase confidence
-    if state.contributing_analysis.content is not None and state.contributing_analysis.requires_tests:
-        # Find the test rule and boost its confidence
-        for rec in recommendations:
-            if "tests" in rec.yaml_rule.lower():
-                rec.confidence = min(0.95, rec.confidence + 0.1)
-                rec.reasoning += " Contributing guidelines explicitly require tests."
+    # Add a repository-specific rule if we detect specific patterns
+    if state.repository_features.has_workflows:
+        workflow_rule = textwrap.dedent(
+            """
+            description: "Protect CI/CD workflows"
+            enabled: true
+            severity: high
+            event_types:
+              - pull_request
+            parameters:
+              file_patterns:
+                - ".github/workflows/**"
+            """
+        ).strip()
+
+        recommendations.append(
+            RuleRecommendation(
+                yaml_rule=workflow_rule,
+                confidence=0.90,
+                reasoning=f"Repository {state.repository_full_name} has {state.repository_features.workflow_count} workflows that should be protected.",
+                strategy_used="static",
+            )
+        )
+
+    logger.info(f"Generated {len(recommendations)} recommendations for {state.repository_full_name}")
     return recommendations
 
 
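Note: _get_language_specific_patterns is called above but its body sits outside this diff. A minimal sketch of the shape it presumably has, returning (source_patterns, test_patterns) per language; the concrete globs below are illustrative assumptions, not taken from the PR:

    # Assumed mapping from primary language to glob patterns; the real table
    # lives elsewhere in nodes.py and may differ.
    _PATTERNS: dict[str, tuple[list[str], list[str]]] = {
        "Python": (["src/**/*.py"], ["tests/**/*.py"]),
        "TypeScript": (["src/**/*.ts"], ["**/*.test.ts"]),
    }

    def _get_language_specific_patterns(language: str | None) -> tuple[list[str], list[str]]:
        # Fall back to broad multi-language globs when the language is unknown.
        return _PATTERNS.get(language or "", (["src/**"], ["tests/**"]))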
10 changes: 2 additions & 8 deletions src/agents/repository_analysis_agent/prompts.py
@@ -83,14 +83,8 @@
 severity: "medium"
 event_types:
   - pull_request
-conditions:
-  - type: "condition_type"
-    parameters:
-      key: "value"
-actions:
-  - type: "action_type"
-    parameters:
-      key: "value"
+parameters:
+  key: "value"
 ```
 
 Make sure the rule is functional and follows best practices.
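Note: with the conditions and actions sections dropped, a complete rule under the simplified schema would look like the following. The field names come from the template above; the values are illustrative:

    description: "Require tests for source changes"
    enabled: true
    severity: "medium"
    event_types:
      - pull_request
    parameters:
      source_patterns:
        - "src/**"
      test_patterns:
        - "tests/**"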
98 changes: 91 additions & 7 deletions src/api/recommendations.py
@@ -48,7 +48,9 @@ async def recommend_rules(
     if not request.repository_full_name or "/" not in request.repository_full_name:
         raise HTTPException(status_code=400, detail="Invalid repository name format. Expected 'owner/repo'")
 
-    cache_key = f"repo_analysis:{request.repository_full_name}"
+    # Include authentication context in cache key to ensure different access levels get different results
+    auth_context = request.installation_id or request.user_token or "anonymous"
+    cache_key = f"repo_analysis:{request.repository_full_name}:{auth_context}"
     cached_result = await get_cache(cache_key)
 
     if cached_result:
@@ -57,6 +59,7 @@ async def recommend_rules(
             "cache_hit",
             operation="repository_analysis",
             subject_ids=[request.repository_full_name],
+            auth_context=auth_context,
             cached=True,
         )
         return RepositoryAnalysisResponse(**cached_result)
@@ -85,6 +88,8 @@ async def recommend_rules(
             decision="failed",
             error=result.message,
         )
+        # Clear any cached results for this repository to ensure fresh analysis on retry
+        await set_cache(cache_key, None, ttl=1)  # Use 1 second TTL to effectively clear cache
        raise HTTPException(status_code=500, detail=result.message)
 
     analysis_response = result.data.get("analysis_response")
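Note: set_cache(cache_key, None, ttl=1) emulates deletion by overwriting the entry with a value that expires almost immediately. If the cache layer ever grows a delete primitive, the intent reads more directly; delete_cache below is hypothetical, not an existing helper in this codebase:

    # Hypothetical helper; the PR itself uses the 1-second TTL overwrite.
    await delete_cache(cache_key)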
@@ -161,6 +166,18 @@ async def proceed_with_pr(request: ProceedWithPullRequestRequest) -> ProceedWithPullRequestResponse:
             branch=request.branch_name,
             existing_sha=existing_branch_sha,
         )
+        # Verify the branch points to the correct base
+        if existing_branch_sha != base_sha:
+            log_structured(
+                logger,
+                "branch_sha_mismatch",
+                operation="proceed_with_pr",
+                subject_ids=[repo],
+                branch=request.branch_name,
+                existing_sha=existing_branch_sha,
+                expected_sha=base_sha,
+                warning="Branch exists but points to different SHA than base branch",
+            )
     else:
         # Create new branch
         created_ref = await github_client.create_git_ref(repo, request.branch_name, base_sha, **auth_ctx)
@@ -182,6 +199,15 @@ async def proceed_with_pr(request: ProceedWithPullRequestRequest) -> ProceedWithPullRequestResponse:
                 "The branch may already exist or you may not have permission to create branches."
             ),
         )
+        log_structured(
+            logger,
+            "branch_created",
+            operation="proceed_with_pr",
+            subject_ids=[repo],
+            branch=request.branch_name,
+            base_branch=base_branch,
+            new_sha=created_ref.get("object", {}).get("sha"),
+        )
 
     file_result = await github_client.create_or_update_file(
         repo_full_name=repo,
@@ -209,6 +235,17 @@ async def proceed_with_pr(request: ProceedWithPullRequestRequest) -> ProceedWithPullRequestResponse:
         ),
     )
 
+    commit_sha = (file_result.get("commit") or {}).get("sha")
+    log_structured(
+        logger,
+        "file_created",
+        operation="proceed_with_pr",
+        subject_ids=[repo],
+        branch=request.branch_name,
+        file_path=request.file_path,
+        commit_sha=commit_sha,
+    )
+
     pr = await github_client.create_pull_request(
         repo_full_name=repo,
         title=request.pr_title,
@@ -237,16 +274,62 @@ async def proceed_with_pr(request: ProceedWithPullRequestRequest) -> ProceedWithPullRequestResponse:
         )
 
     pr_url = pr.get("html_url", "")
-    if not pr_url:
+    pr_number = pr.get("number")
+    if not pr_url or not pr_number:
         log_structured(
             logger,
-            "pr_url_missing",
+            "pr_creation_incomplete",
             operation="proceed_with_pr",
             subject_ids=[repo],
-            pr_data=pr,
-            error="PR created but html_url is missing",
+            pr_url=pr_url,
+            pr_number=pr_number,
+            error="PR creation response missing required fields",
         )
+        raise HTTPException(status_code=500, detail="PR was created but response is incomplete")
+
+    # Validate the PR URL is a proper GitHub URL format
+    if not pr_url.startswith("https://github.com/") or "/pull/" not in pr_url:
+        log_structured(
+            logger,
+            "pr_url_invalid",
+            operation="proceed_with_pr",
+            subject_ids=[repo],
+            pr_url=pr_url,
+            pr_number=pr_number,
+            error="PR URL is not a valid GitHub pull request URL",
+        )
+        raise HTTPException(status_code=500, detail="PR was created but returned invalid URL format")
+
+    # Validate PR number is reasonable
+    if not isinstance(pr_number, int) or pr_number <= 0:
+        log_structured(
+            logger,
+            "pr_number_invalid",
+            operation="proceed_with_pr",
+            subject_ids=[repo],
+            pr_url=pr_url,
+            pr_number=pr_number,
+            error="PR number is invalid",
+        )
+        raise HTTPException(status_code=500, detail="PR was created but returned invalid PR number")
+
+    # Double-check URL format one more time
+    expected_url_pattern = f"https://github.com/{repo}/pull/{pr_number}"
+    if pr_url != expected_url_pattern:
+        log_structured(
+            logger,
+            "pr_url_mismatch",
+            operation="proceed_with_pr",
+            subject_ids=[repo],
+            expected_url=expected_url_pattern,
+            actual_url=pr_url,
+            pr_number=pr_number,
+            error="PR URL doesn't match expected pattern",
+        )
+        raise HTTPException(
+            status_code=500, detail=f"PR URL mismatch: expected {expected_url_pattern} but got {pr_url}"
+        )
-        raise HTTPException(status_code=500, detail="PR was created but URL is missing")
 
     log_structured(
@@ -255,11 +338,12 @@ async def proceed_with_pr(request: ProceedWithPullRequestRequest) -> ProceedWithPullRequestResponse:
         subject_ids=[repo],
         decision="success",
         branch=request.branch_name,
-        pr_number=pr.get("number"),
+        pr_number=pr_number,
+        pr_url=pr_url,
     )
 
     return ProceedWithPullRequestResponse(
-        pull_request_url=pr.get("html_url", ""),
+        pull_request_url=pr_url,
         branch_name=request.branch_name,
         base_branch=base_branch,
         file_path=request.file_path,
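Note on the auth-aware cache key: the handler interpolates the raw credential (installation id or user token) into the key, which keeps access levels separate but lets a raw bearer token land in cache storage and any logs that echo keys. A minimal sketch of a hashed variant, using only the standard library; build_cache_key is a hypothetical helper, not part of this PR:

    import hashlib

    def build_cache_key(repo_full_name: str, installation_id: str | None, user_token: str | None) -> str:
        # Same per-credential separation as the handler above, but only a
        # short digest of the credential ever appears in the key.
        auth_context = installation_id or user_token or "anonymous"
        digest = hashlib.sha256(str(auth_context).encode()).hexdigest()[:16]
        return f"repo_analysis:{repo_full_name}:{digest}"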