From 5fba7ac87a36e279dce2d0aa553de1cae07afbea Mon Sep 17 00:00:00 2001 From: omsherikar Date: Sun, 8 Feb 2026 02:01:27 +0530 Subject: [PATCH 01/19] refactron: backup before refactoring /Users/omsherikar/Refactron_Root/Refactron_Lib/complex_test_repo --- complex_test_repo/core/engine.py | 69 ++++++++++++++++++++++++ complex_test_repo/data/processor.py | 51 ++++++++++++++++++ complex_test_repo/main.py | 20 +++++++ complex_test_repo/utils/math_lib.py | 17 ++++++ complex_test_repo/utils/string_helper.py | 34 ++++++++++++ 5 files changed, 191 insertions(+) create mode 100644 complex_test_repo/core/engine.py create mode 100644 complex_test_repo/data/processor.py create mode 100644 complex_test_repo/main.py create mode 100644 complex_test_repo/utils/math_lib.py create mode 100644 complex_test_repo/utils/string_helper.py diff --git a/complex_test_repo/core/engine.py b/complex_test_repo/core/engine.py new file mode 100644 index 0000000..7945a08 --- /dev/null +++ b/complex_test_repo/core/engine.py @@ -0,0 +1,69 @@ + +import os +import sqlite3 + +class ProcessingEngine: + '''processing engine class. + + Attributes: + attribute1: Description of attribute1 + attribute2: Description of attribute2 + ''' + def __init__(self, mode="all"): + ''' + init . + + Args: + self: Class instance + mode: The mode + ''' + self.mode = mode + self.db = sqlite3.connect(":memory:") + + def execute(self, command): + ''' + Execute. + + Args: + self: Class instance + command: The command + ''' + # Security risk: Command injection + if self.mode == "dangerous": + os.system(command) + else: + print(f"Executing: {command}") + + def query_user(self, user_id): + ''' + Query user. + + Args: + self: Class instance + user_id: Unique identifier + + Returns: + The result of the operation + ''' + # Security risk: SQL injection + cursor = self.db.cursor() + query = f"SELECT * FROM users WHERE id = {user_id}" + cursor.execute(query) + return cursor.fetchone() + + def process_items(self, items): + ''' + Process items. + + Args: + self: Class instance + items: The items + + Returns: + The result of the operation + ''' + # Performance issue: String concatenation in loop + result = "" + for item in items: + result += str(item) + "," + return result diff --git a/complex_test_repo/data/processor.py b/complex_test_repo/data/processor.py new file mode 100644 index 0000000..6567205 --- /dev/null +++ b/complex_test_repo/data/processor.py @@ -0,0 +1,51 @@ + +import time + +def process_batch(data_list): + ''' + Process batch. + + Args: + data_list: Data to process + + Returns: + The result of the operation + ''' + # Performance issue: N+1 pattern or inefficient iteration + results = [] + for item in data_list: + # Simulating sub-query or heavy processing in loop + detail = get_item_detail(item) + results.append(detail) + return results + +def get_item_detail(item): + ''' + Get item detail. + + Args: + item: The item + + Returns: + The requested item detail + ''' + return {"id": item, "details": "example"} + +def deep_nesting_example(a, b, c, d): + ''' + Deep nesting example. + + Args: + a: The a + b: The b + c: The c + d: The d + ''' + # Complexity issue: Deep nesting + if a: + if b: + for i in range(10): + if c: + while d: + print(i) + break diff --git a/complex_test_repo/main.py b/complex_test_repo/main.py new file mode 100644 index 0000000..428696d --- /dev/null +++ b/complex_test_repo/main.py @@ -0,0 +1,20 @@ + +from utils.math_lib import legacy_compute +from core.engine import ProcessingEngine +from data.processor import process_batch + +def run(): + ''' + Run. + ''' + engine = ProcessingEngine() + val = legacy_compute(10, 20) + print(f"Result: {val}") + engine.execute("ls") + + data = [1, 2, 3] + processed = process_batch(data) + print(f"Processed: {processed}") + +if __name__ == "__main__": + run() diff --git a/complex_test_repo/utils/math_lib.py b/complex_test_repo/utils/math_lib.py new file mode 100644 index 0000000..f4b261d --- /dev/null +++ b/complex_test_repo/utils/math_lib.py @@ -0,0 +1,17 @@ + +import os + +def legacy_compute(a, b): + ''' + Legacy compute. + + Args: + a: The a + b: The b + + Returns: + The result of the operation + ''' + # Magic numbers and no docstring + res = a * 1.05 + b * 0.95 + return res diff --git a/complex_test_repo/utils/string_helper.py b/complex_test_repo/utils/string_helper.py new file mode 100644 index 0000000..00e55e0 --- /dev/null +++ b/complex_test_repo/utils/string_helper.py @@ -0,0 +1,34 @@ + +import sys +import os + +def clean_text(text): + ''' + Clean text. + + Args: + text: The text + + Returns: + The result of the operation + ''' + return text.strip().lower() + +def clean_content(content): + ''' + Clean content. + + Args: + content: The content + + Returns: + The result of the operation + ''' + return content.strip().lower() + +def helper_unused_private(): + ''' + Helper unused private. + ''' + # This function is defined but not used within this file + pass From c5eacb8707e133f853887fb78b57e54c9bf6eb3c Mon Sep 17 00:00:00 2001 From: omsherikar Date: Sun, 8 Feb 2026 02:03:37 +0530 Subject: [PATCH 02/19] refactron: backup before refactoring /Users/omsherikar/Refactron_Root/Refactron_Lib/complex_test_repo --- complex_test_repo/data/processor.py | 32 ++++++++++++++--------------- complex_test_repo/main.py | 8 ++++++-- complex_test_repo/utils/math_lib.py | 5 ++++- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/complex_test_repo/data/processor.py b/complex_test_repo/data/processor.py index 6567205..97ce3f2 100644 --- a/complex_test_repo/data/processor.py +++ b/complex_test_repo/data/processor.py @@ -32,20 +32,18 @@ def get_item_detail(item): return {"id": item, "details": "example"} def deep_nesting_example(a, b, c, d): - ''' - Deep nesting example. - - Args: - a: The a - b: The b - c: The c - d: The d - ''' - # Complexity issue: Deep nesting - if a: - if b: - for i in range(10): - if c: - while d: - print(i) - break + '''Refactored version using early returns (guard clauses).''' + # Check invalid conditions first and return early + if not a: + return default_value + + # Each subsequent check is at the same level - no deep nesting + if not meets_requirement_1(): + return early_result_1 + + if not meets_requirement_2(): + return early_result_2 + + # Main logic is at top level - easy to read + result = perform_main_operation() + return result diff --git a/complex_test_repo/main.py b/complex_test_repo/main.py index 428696d..388b0b8 100644 --- a/complex_test_repo/main.py +++ b/complex_test_repo/main.py @@ -3,16 +3,20 @@ from core.engine import ProcessingEngine from data.processor import process_batch +THRESHOLD_10 = 10 +THRESHOLD_20 = 20 +CONSTANT_3 = 3 + def run(): ''' Run. ''' engine = ProcessingEngine() - val = legacy_compute(10, 20) + val = legacy_compute(THRESHOLD_10, THRESHOLD_20) print(f"Result: {val}") engine.execute("ls") - data = [1, 2, 3] + data = [1, 2, CONSTANT_3] processed = process_batch(data) print(f"Processed: {processed}") diff --git a/complex_test_repo/utils/math_lib.py b/complex_test_repo/utils/math_lib.py index f4b261d..7307320 100644 --- a/complex_test_repo/utils/math_lib.py +++ b/complex_test_repo/utils/math_lib.py @@ -1,6 +1,9 @@ import os +SURCHARGE_RATE = 1.05 +CONSTANT_0_95 = 0.95 + def legacy_compute(a, b): ''' Legacy compute. @@ -13,5 +16,5 @@ def legacy_compute(a, b): The result of the operation ''' # Magic numbers and no docstring - res = a * 1.05 + b * 0.95 + res = a * SURCHARGE_RATE + b * CONSTANT_0_95 return res From e9bb32920f1299bcf224b883dd888579e516569d Mon Sep 17 00:00:00 2001 From: omsherikar Date: Sun, 8 Feb 2026 17:10:47 +0530 Subject: [PATCH 03/19] release: v1.0.15 - LLM/RAG integration, test fixes, ML foundation Features: - LLM orchestration system with multi-backend support (Groq, OpenAI) - RAG-based code retrieval with FAISS indexing - Repository and workspace management - ML infrastructure foundation (AI module, metrics, analysis tools) Bug Fixes: - Fixed 9 critical test failures (security rule IDs, pattern fingerprinting, ranking) - All 716 tests now passing with 73% coverage - Resolved anonymization behavior in pattern learning Improvements: - Enhanced security analyzer with better context awareness - Improved CLI feedback and error messages - Code quality and linting improvements Breaking Changes: None --- complex_test_repo/core/__init__.py | 1 + complex_test_repo/data/__init__.py | 1 + complex_test_repo/utils/__init__.py | 1 + pyproject.toml | 10 +- refactron/__init__.py | 2 +- refactron/analyzers/security_analyzer.py | 116 ++- refactron/cli.py | 1081 +++++++++++++++++++++- refactron/core/config.py | 1 + refactron/core/models.py | 1 + refactron/core/refactor_result.py | 41 +- refactron/core/refactron.py | 15 +- refactron/core/repositories.py | 177 ++++ refactron/core/workspace.py | 226 +++++ refactron/llm/__init__.py | 6 + refactron/llm/backend_client.py | 119 +++ refactron/llm/client.py | 94 ++ refactron/llm/models.py | 66 ++ refactron/llm/orchestrator.py | 269 ++++++ refactron/llm/prompts.py | 94 ++ refactron/llm/safety.py | 113 +++ refactron/patterns/fingerprint.py | 54 +- refactron/rag/__init__.py | 6 + refactron/rag/chunker.py | 179 ++++ refactron/rag/indexer.py | 281 ++++++ refactron/rag/parser.py | 253 +++++ refactron/rag/retriever.py | 178 ++++ scripts/analyze_feedback_data.py | 133 +++ tests/test_analyzer_edge_cases.py | 2 +- tests/test_backend_client.py | 73 ++ tests/test_false_positive_reduction.py | 4 +- tests/test_groq_client.py | 110 +++ tests/test_llm_orchestrator.py | 145 +++ tests/test_patterns_fingerprint.py | 11 +- tests/test_patterns_integration.py | 8 +- tests/test_patterns_learner.py | 6 +- tests/test_patterns_ranker.py | 24 +- tests/test_rag_chunker.py | 147 +++ tests/test_rag_indexer.py | 171 ++++ tests/test_rag_parser.py | 148 +++ tests/test_rag_retriever.py | 195 ++++ 40 files changed, 4469 insertions(+), 93 deletions(-) create mode 100644 complex_test_repo/core/__init__.py create mode 100644 complex_test_repo/data/__init__.py create mode 100644 complex_test_repo/utils/__init__.py create mode 100644 refactron/core/repositories.py create mode 100644 refactron/core/workspace.py create mode 100644 refactron/llm/__init__.py create mode 100644 refactron/llm/backend_client.py create mode 100644 refactron/llm/client.py create mode 100644 refactron/llm/models.py create mode 100644 refactron/llm/orchestrator.py create mode 100644 refactron/llm/prompts.py create mode 100644 refactron/llm/safety.py create mode 100644 refactron/rag/__init__.py create mode 100644 refactron/rag/chunker.py create mode 100644 refactron/rag/indexer.py create mode 100644 refactron/rag/parser.py create mode 100644 refactron/rag/retriever.py create mode 100644 scripts/analyze_feedback_data.py create mode 100644 tests/test_backend_client.py create mode 100644 tests/test_groq_client.py create mode 100644 tests/test_llm_orchestrator.py create mode 100644 tests/test_rag_chunker.py create mode 100644 tests/test_rag_indexer.py create mode 100644 tests/test_rag_parser.py create mode 100644 tests/test_rag_retriever.py diff --git a/complex_test_repo/core/__init__.py b/complex_test_repo/core/__init__.py new file mode 100644 index 0000000..6d2adc4 --- /dev/null +++ b/complex_test_repo/core/__init__.py @@ -0,0 +1 @@ +"""Core engine module.""" diff --git a/complex_test_repo/data/__init__.py b/complex_test_repo/data/__init__.py new file mode 100644 index 0000000..3da002b --- /dev/null +++ b/complex_test_repo/data/__init__.py @@ -0,0 +1 @@ +"""Data processing module.""" diff --git a/complex_test_repo/utils/__init__.py b/complex_test_repo/utils/__init__.py new file mode 100644 index 0000000..1bfd38b --- /dev/null +++ b/complex_test_repo/utils/__init__.py @@ -0,0 +1 @@ +"""Utility functions module.""" diff --git a/pyproject.toml b/pyproject.toml index f6acb77..e6cdcc1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "refactron" -version = "1.0.14" +version = "1.0.15" description = "Python code analysis and refactoring tool with security scanning, performance detection, and automated fixes" readme = "README.md" requires-python = ">=3.8" @@ -35,6 +35,14 @@ dependencies = [ "radon>=6.0.0", "requests>=2.25.0", "astroid>=3.0.0", + # RAG Infrastructure + "chromadb>=0.4.22", + "tree-sitter>=0.20.4", + "tree-sitter-python>=0.20.4", + "sentence-transformers>=2.5.1", + # LLM Integration (Free Cloud APIs) + "groq>=0.4.0", # Free cloud LLM (Llama 3, Mixtral) + "pydantic>=2.6.0", ] [project.optional-dependencies] diff --git a/refactron/__init__.py b/refactron/__init__.py index ff2a735..a116b3f 100644 --- a/refactron/__init__.py +++ b/refactron/__init__.py @@ -9,7 +9,7 @@ from refactron.core.refactor_result import RefactorResult from refactron.core.refactron import Refactron -__version__ = "1.0.14" +__version__ = "1.0.15" __author__ = "Om Sherikar" __all__ = [ diff --git a/refactron/analyzers/security_analyzer.py b/refactron/analyzers/security_analyzer.py index 07b3099..9a70de2 100644 --- a/refactron/analyzers/security_analyzer.py +++ b/refactron/analyzers/security_analyzer.py @@ -3,7 +3,7 @@ import ast import fnmatch from pathlib import Path -from typing import List +from typing import Dict, List from refactron.analyzers.base_analyzer import BaseAnalyzer from refactron.core.models import CodeIssue, IssueCategory, IssueLevel @@ -24,6 +24,8 @@ class SecurityAnalyzer(BaseAnalyzer): "compile": "Potential code injection - use with extreme caution", "__import__": "Dynamic imports can be dangerous - use importlib instead", "input": "In Python 2, input() evaluates code - use raw_input() or upgrade to Python 3", + "system": "Command injection risk - uses a shell. Use subprocess.run() instead", + "popen": "Command injection risk - uses a shell. Use subprocess.Popen() with a list instead", # noqa: E501 } # Dangerous modules @@ -125,6 +127,9 @@ def analyze(self, file_path: Path, source_code: str) -> List[CodeIssue]: try: tree = ast.parse(source_code) + # Map of local names to full module/function paths (alias tracking) + self._alias_map = self._build_alias_map(tree) + # Check for various security issues issues.extend(self._check_dangerous_functions(tree, file_path)) issues.extend(self._check_dangerous_imports(tree, file_path)) @@ -139,8 +144,20 @@ def analyze(self, file_path: Path, source_code: str) -> List[CodeIssue]: issues.extend(self._check_insecure_random(tree, file_path)) issues.extend(self._check_weak_ssl_tls(tree, file_path)) - except SyntaxError: - pass + except SyntaxError as e: + # Report syntax errors as security risks + issues.append( + CodeIssue( + category=IssueCategory.SECURITY, + level=IssueLevel.ERROR, + message=f"Syntax error prevents security analysis: {str(e)}", + file_path=file_path, + line_number=getattr(e, "lineno", 1), + suggestion="Fix the syntax error to enable automated security scanning.", + rule_id="SEC000", + confidence=1.0, + ) + ) # Filter out whitelisted rules and low confidence issues filtered_issues = [] @@ -155,6 +172,21 @@ def analyze(self, file_path: Path, source_code: str) -> List[CodeIssue]: return filtered_issues + def _build_alias_map(self, tree: ast.AST) -> Dict[str, str]: + """Build a map of local names to their full qualified names (alias tracking).""" + aliases = {} + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + if alias.asname: + aliases[alias.asname] = alias.name + elif isinstance(node, ast.ImportFrom) and node.module: + for alias in node.names: + local_name = alias.asname if alias.asname else alias.name + if local_name != "*": + aliases[local_name] = f"{node.module}.{alias.name}" + return aliases + def _check_dangerous_functions(self, tree: ast.AST, file_path: Path) -> List[CodeIssue]: """Check for dangerous built-in functions.""" issues = [] @@ -351,33 +383,60 @@ def _check_command_injection(self, tree: ast.AST, file_path: Path) -> List[CodeI """Check for command injection vulnerabilities.""" issues = [] - dangerous_calls = ["os.system", "subprocess.call", "subprocess.Popen", "os.popen"] + # Functions that always use a shell and are dangerous + always_shell = ["os.system", "os.popen", "system", "popen"] + # Functions that are dangerous specifically when shell=True is passed + shell_optional = [ + "subprocess.call", + "subprocess.Popen", + "subprocess.run", + "subprocess.check_call", + "subprocess.check_output", + ] for node in ast.walk(tree): if isinstance(node, ast.Call): func_name = self._get_full_function_name(node.func) - if any(dangerous in func_name for dangerous in dangerous_calls): - # Check if shell=True is used + # Check for "always shell" functions + if any( + dangerous == func_name or func_name.endswith(f".{dangerous}") + for dangerous in always_shell + ): + issue = CodeIssue( + category=IssueCategory.SECURITY, + level=IssueLevel.CRITICAL, + message=f"Command injection risk: {func_name}() uses a shell", + file_path=file_path, + line_number=node.lineno, + suggestion="Avoid functions that use a shell. Use subprocess.run() with a list of arguments instead.", # noqa: E501 + rule_id="SEC0051", + confidence=0.95, + ) + issues.append(issue) + continue + + # Check for "shell=True" in optional functions + if any(dangerous in func_name for dangerous in shell_optional): + is_shell_true = False for keyword in node.keywords: if keyword.arg == "shell" and isinstance(keyword.value, ast.Constant): if keyword.value.value is True: - issue = CodeIssue( - category=IssueCategory.SECURITY, - level=IssueLevel.CRITICAL, - message=( - f"Command injection risk: {func_name}() with shell=True" - ), - file_path=file_path, - line_number=node.lineno, - suggestion=( - "Avoid shell=True. Use subprocess with list of arguments " - "instead" - ), - rule_id="SEC005", - confidence=0.95, - ) - issues.append(issue) + is_shell_true = True + break + + if is_shell_true: + issue = CodeIssue( + category=IssueCategory.SECURITY, + level=IssueLevel.CRITICAL, + message=f"Command injection risk: {func_name}() with shell=True", + file_path=file_path, + line_number=node.lineno, + suggestion="Avoid shell=True. Use subprocess with list of arguments instead.", # noqa: E501 + rule_id="SEC0052", + confidence=0.95, + ) + issues.append(issue) return issues @@ -462,13 +521,16 @@ def _get_function_name(self, node: ast.AST) -> str: return "" def _get_full_function_name(self, node: ast.AST) -> str: - """Get full qualified function name (e.g., 'os.system').""" + """Get full qualified function name (e.g., 'os.system'), resolving aliases.""" if isinstance(node, ast.Name): - return node.id + # Resolve alias if it exists + return self._alias_map.get(node.id, node.id) elif isinstance(node, ast.Attribute): value_name = self._get_full_function_name(node.value) if value_name: - return f"{value_name}.{node.attr}" + full_name = f"{value_name}.{node.attr}" + # Check for secondary aliases (e.g., 'o.system' where 'o' is 'os') + return self._alias_map.get(full_name, full_name) return node.attr return "" @@ -483,13 +545,15 @@ def _check_sql_parameterization(self, tree: ast.AST, file_path: Path) -> List[Co if isinstance(node, ast.Assign): for target in node.targets: if isinstance(target, ast.Name): - # Check if the assignment uses string concatenation or .format() + # Check if the assignment uses string concatenation, f-strings, or .format() if isinstance(node.value, ast.BinOp) and isinstance(node.value.op, ast.Add): # Check if at least one side is a string if isinstance(node.value.left, ast.Constant) or isinstance( node.value.right, ast.Constant ): string_concat_vars[target.id] = node.lineno + elif isinstance(node.value, ast.JoinedStr): # f-string + string_concat_vars[target.id] = node.lineno elif isinstance(node.value, ast.Call) and isinstance( node.value.func, ast.Attribute ): diff --git a/refactron/cli.py b/refactron/cli.py index cab08f6..3b2c725 100644 --- a/refactron/cli.py +++ b/refactron/cli.py @@ -16,8 +16,9 @@ from rich import box from rich.align import Align from rich.console import Console +from rich.markdown import Markdown from rich.panel import Panel -from rich.prompt import Prompt +from rich.prompt import IntPrompt, Prompt from rich.table import Table from rich.text import Text from rich.theme import Theme @@ -41,9 +42,15 @@ start_device_authorization, ) from refactron.core.exceptions import ConfigError +from refactron.core.models import CodeIssue, IssueCategory, IssueLevel from refactron.core.refactor_result import RefactorResult +from refactron.core.repositories import Repository, list_repositories +from refactron.core.workspace import WorkspaceManager, WorkspaceMapping +from refactron.llm.models import SuggestionStatus +from refactron.llm.orchestrator import LLMOrchestrator from refactron.patterns.storage import PatternStorage from refactron.patterns.tuner import RuleTuner +from refactron.rag.retriever import ContextRetriever # Custom theme for a premium, modern look THEME = Theme( @@ -90,7 +97,7 @@ def _auth_banner(title: str) -> None: style="panel.border", box=box.ROUNDED, padding=(1, 2), - subtitle="[secondary]v1.0.14[/secondary]", + subtitle=f"[secondary]v{__version__}[/secondary]", subtitle_align="right", ) ) @@ -158,6 +165,27 @@ def _setup_logging(verbose: bool = False) -> None: datefmt="%Y-%m-%d %H:%M:%S", ) + # Suppress noisy third-party libraries + if not verbose: + # Standard logging suppression + for logger_name in [ + "httpx", + "sentence_transformers", + "transformers", + "tokenizers", + "chromadb", + "huggingface_hub", + ]: + logging.getLogger(logger_name).setLevel(logging.ERROR) + + # Specific suppression for transformers library to avoid "Load Report" + try: + from transformers import logging as tf_logging + + tf_logging.set_verbosity_error() + except ImportError: + pass + def _load_config( config_path: Optional[str], @@ -566,7 +594,7 @@ def get_renderable(step: int, phase: str) -> Align: info_table.add_column(style="dim", justify="right") info_table.add_column(style="bold white") - info_table.add_row("Version:", "v1.0.13") + info_table.add_row("Version:", f"v{__version__}") info_table.add_row("Python:", sys.version.split()[0]) info_table.add_row("OS:", platform.system()) @@ -603,7 +631,7 @@ def get_renderable(step: int, phase: str) -> Align: for line in LOGO_LINES: console.print(Align.center(Text(line, style="bold #ffffff"))) console.print(Align.center(Text(subtitle_text, style="italic #8a8a8a"))) - console.print(Align.center(Text("v1.0.13", style="dim"))) + console.print(Align.center(Text(f"v{__version__}", style="dim"))) console.print() @@ -714,7 +742,7 @@ def print_header() -> None: if choice == "1": _print_custom_help(ctx) elif choice == "2": - console.print("\nRefactron CLI v1.0.13") + console.print(f"\nRefactron CLI v{__version__}") elif choice == "3": console.print("Goodbye!") break @@ -1048,8 +1076,85 @@ def auth_logout() -> None: logout() +def _interactive_file_selector(workspace_path: Path) -> Path: + """Show an interactive file/folder selector for the workspace. + + Args: + workspace_path: The workspace root directory + + Returns: + Selected file or folder path + """ + console.print("\n[bold]Select a file or folder to analyze:[/bold]\n") + + # Get all Python files and directories + python_files = [] + directories = [] + + # Add the workspace root as option + directories.append((".", "Entire workspace")) + + # List immediate subdirectories + for item in sorted(workspace_path.iterdir()): + if item.is_dir() and not item.name.startswith("."): + # Count Python files in directory + py_count = len(list(item.rglob("*.py"))) + if py_count > 0: + directories.append( + (str(item.relative_to(workspace_path)), f"[{py_count} .py files]") + ) + elif item.suffix == ".py": + python_files.append(str(item.relative_to(workspace_path))) + + # Build selection table + table = Table(show_header=True, header_style="bold cyan", box=box.SIMPLE) + table.add_column("#", style="dim", width=4, justify="right") + table.add_column("Type", width=6) + table.add_column("Path", style="cyan") + table.add_column("Info", style="dim") + + options = [] + idx = 1 + + # Add directories + for rel_path, info in directories: + table.add_row(str(idx), "šŸ“ DIR", rel_path, info) + options.append((workspace_path / rel_path, "directory")) + idx += 1 + + # Add Python files + for file_path in python_files[:20]: # Limit to 20 files to avoid clutter + table.add_row(str(idx), "šŸ FILE", file_path, "") + options.append((workspace_path / file_path, "file")) + idx += 1 + + if len(python_files) > 20: + table.add_row("...", "", f"[dim]and {len(python_files) - 20} more files[/dim]", "") + + console.print(table) + console.print() + + # Get user selection + try: + choice = IntPrompt.ask( + "[bold]Enter number to analyze[/bold]", + choices=[str(i) for i in range(1, len(options) + 1)], + show_choices=False, + ) + selected_path, selected_type = options[choice - 1] + + console.print( + f"\n[success]āœ“ Selected: {selected_path.relative_to(workspace_path)}[/success]\n" + ) + return selected_path + + except (KeyboardInterrupt, EOFError): + console.print("\n[yellow]Selection cancelled.[/yellow]") + raise SystemExit(0) + + @main.command() -@click.argument("target", type=click.Path(exists=True)) +@click.argument("target", type=click.Path(exists=True), required=False) @click.option( "--config", "-c", @@ -1103,7 +1208,7 @@ def auth_logout() -> None: ), ) def analyze( - target: str, + target: Optional[str], config: Optional[str], detailed: bool, log_level: Optional[str], @@ -1116,7 +1221,7 @@ def analyze( """ Analyze code for issues and technical debt. - TARGET: Path to file or directory to analyze + TARGET: Path to file or directory to analyze (optional if workspace is connected) """ # Setup logging _setup_logging() @@ -1125,8 +1230,37 @@ def analyze( _auth_banner("Analysis") console.print() - # Setup - target_path = _validate_path(target) + # Determine target path - use workspace if not provided + if not target: + workspace_mgr = WorkspaceManager() + current_workspace = workspace_mgr.get_workspace_by_path(str(Path.cwd())) + + if current_workspace: + console.print(f"[dim]Connected workspace: {current_workspace.repo_full_name}[/dim]") + + # Show interactive file selector + workspace_root = Path(current_workspace.local_path) + target_path = _interactive_file_selector(workspace_root) + target = str(target_path) + else: + console.print( + ( + "[red]Error: No target specified and current directory " + "is not a connected workspace.[/red]\n\n" + "[dim]Options:[/dim]\n" + " 1. Specify a path: refactron analyze /path/to/code\n" + " 2. Connect a workspace: refactron repo connect \n" + " 3. Navigate to a connected workspace directory\n" + ) + ) + raise SystemExit(1) + else: + # Path explicitly provided, validate and use it + target_path = _validate_path(target) + + # Setup (only if not already set by interactive selector) + if "target_path" not in locals(): + target_path = _validate_path(target) cfg = _load_config(config, profile, environment) # Override config with CLI options @@ -1182,7 +1316,7 @@ def analyze( @main.command() -@click.argument("target", type=click.Path(exists=True)) +@click.argument("target", type=click.Path(exists=True), required=False) @click.option( "--config", "-c", @@ -1227,7 +1361,7 @@ def analyze( help="Collect interactive feedback on refactoring suggestions", ) def refactor( - target: str, + target: Optional[str], config: Optional[str], profile: Optional[str], environment: Optional[str], @@ -1238,7 +1372,7 @@ def refactor( """ Refactor code with intelligent transformations. - TARGET: Path to file or directory to refactor + TARGET: Path to file or directory to refactor (optional if workspace is connected) """ # Setup logging _setup_logging() @@ -1247,6 +1381,26 @@ def refactor( _auth_banner("Refactoring") console.print() + # Determine target path - use workspace if not provided + if not target: + workspace_mgr = WorkspaceManager() + current_workspace = workspace_mgr.get_workspace_by_path(str(Path.cwd())) + + if current_workspace: + target = current_workspace.local_path + console.print( + f"[dim]Using connected workspace: {current_workspace.repo_full_name}[/dim]\n" + ) + else: + console.print( + "[red]Error: No target specified and current directory is not a connected workspace.[/red]\n\n" # noqa: E501 + "[dim]Options:[/dim]\n" + " 1. Specify a path: refactron refactor /path/to/code\n" + " 2. Connect a workspace: refactron repo connect \n" + " 3. Navigate to a connected workspace directory\n" + ) + raise SystemExit(1) + # Setup target_path = _validate_path(target) cfg = _load_config(config, profile, environment) @@ -1257,13 +1411,17 @@ def refactor( session_id = None if not preview and cfg.backup_enabled: try: - backup_root = target_path.parent if target_path.is_file() else target_path + # Detect project root to ensure backups are stored in a consistent location + # (usually the directory containing .git or .refactron.yaml) + refactron_instance = Refactron(cfg) + backup_root = refactron_instance.detect_project_root(target_path) + backup_system = BackupRollbackSystem(backup_root) if target_path.is_file(): files = [target_path] else: - files = list(target_path.rglob("*.py")) + files = refactron_instance.get_python_files(target_path) if files: session_id, failed_files = backup_system.prepare_for_refactoring( @@ -1316,6 +1474,12 @@ def refactor( # Record feedback if not preview: + # Apply changes to disk + if result.apply(): + console.print("[green]Successfully applied refactoring changes.[/green]") + else: + console.print("[red]Failed to apply some refactoring changes.[/red]") + # Auto-record as accepted when applying changes _record_applied_operations(refactron, result) elif feedback: @@ -1651,12 +1815,13 @@ def init(template: str) -> None: @main.command() +@click.argument("session_id", required=False) @click.option( "--session", "-s", type=str, default=None, - help="Specific session ID to rollback (default: latest session)", + help="Specific session ID to rollback (deprecated, use argument instead)", ) @click.option( "--use-git", @@ -1678,6 +1843,7 @@ def init(template: str) -> None: help="Clear all backup sessions", ) def rollback( + session_id: Optional[str], session: Optional[str], use_git: bool, list_sessions: bool, @@ -1687,20 +1853,39 @@ def rollback( Rollback refactoring changes to restore original files. By default, restores files from the latest backup session. - Use --session to specify a specific session ID. - Use --use-git to rollback using Git instead of file backups. + + Arguments: + SESSION_ID: Optional specific session ID to rollback. Examples: refactron rollback # Rollback latest session + refactron rollback session_123 # Rollback specific session refactron rollback --list # List all backup sessions - refactron rollback --session session_20240101_120000 refactron rollback --use-git # Use Git rollback refactron rollback --clear # Clear all backups """ + # Support both argument and option for session + target_session = session_id or session console.print("\nšŸ”„ [bold blue]Refactron Rollback[/bold blue]\n") system = BackupRollbackSystem() + # If we appear to be in a subdirectory of a project, try to find the root + # so we can find the centralized backups directory. + if not system.list_sessions(): + # Quick check for markers up the tree + current = Path.cwd() + for _ in range(10): + if (current / ".refactron" / "backups").exists(): + system = BackupRollbackSystem(current) + break + if (current / ".git").exists() or (current / ".refactron.yaml").exists(): + system = BackupRollbackSystem(current) + break + if current.parent == current: + break + current = current.parent + if list_sessions: sessions = system.list_sessions() if not sessions: @@ -1743,13 +1928,13 @@ def rollback( console.print("[dim]Tip: Backups are created automatically when using --apply mode.[/dim]") return - if session: - sess = system.backup_manager.get_session(session) + if target_session: + sess = system.backup_manager.get_session(target_session) if not sess: - console.print(f"[error]Session not found: {session}[/error]") + console.print(f"[error]Session not found: {target_session}[/error]") console.print("[dim]Use 'refactron rollback --list' to see available sessions.[/dim]") raise SystemExit(1) - console.print(f"[dim]Rolling back session: {session}[/dim]") + console.print(f"[dim]Rolling back session: {target_session}[/dim]") console.print(f"[dim]Files to restore: {len(sess['files'])}[/dim]") else: latest = sessions[-1] @@ -1768,7 +1953,7 @@ def rollback( console.print("[yellow]Rollback cancelled.[/yellow]") return - result = system.rollback(session_id=session, use_git=use_git) + result = system.rollback(session_id=target_session, use_git=use_git) if result["success"]: console.print(f"\n[success]{result['message']}[/success]") @@ -2394,6 +2579,601 @@ def patterns_tune( ) +@main.group() +def repo() -> None: + """Manage GitHub repository connections.""" + pass + + +@repo.command("list") +@click.option( + "--api-base-url", + default=DEFAULT_API_BASE_URL, + show_default=True, + help="Refactron API base URL", +) +def repo_list(api_base_url: str) -> None: + """ + List all GitHub repositories connected to your account. + + Shows repositories that have been connected via the Refactron WebApp. + """ + _setup_logging() + console.print() + _auth_banner("Repository List") + console.print() + + try: + with console.status("[primary]Fetching repositories...[/primary]"): + repositories = list_repositories(api_base_url) + + if not repositories: + console.print( + Panel( + "[yellow]No repositories found.\n\n" + "Please connect your GitHub account on the Refactron website:[/yellow]\n" + f"[link]{api_base_url.replace('/api', '')}[/link]", + title="No Repositories", + border_style="warning", + box=box.ROUNDED, + ) + ) + return + + # Create table + table = Table( + title=f"Connected Repositories ({len(repositories)})", + show_header=True, + header_style="primary", + box=box.ROUNDED, + border_style="panel.border", + ) + table.add_column("Name", style="cyan", no_wrap=True) + table.add_column("Description", style="dim") + table.add_column("Language", justify="center", style="green") + table.add_column("Private", justify="center") + table.add_column("Updated", style="dim", no_wrap=True) + + # Check which repos are already connected locally + workspace_mgr = WorkspaceManager() + + for repository in repositories: + workspace = workspace_mgr.get_workspace(repository.full_name) + name_display = repository.name + if workspace: + name_display = f"āœ“ {repository.name}" + + desc = repository.description or "[dim]No description[/dim]" + if len(desc) > 60: + desc = desc[:57] + "..." + + lang = repository.language or "—" + private = "Yes" if repository.private else "No" + updated = repository.updated_at.split("T")[0] # Just the date + + table.add_row(name_display, desc, lang, private, updated) + + console.print(table) + console.print("\n[dim]āœ“ = Already connected locally[/dim]") + console.print( + "[dim]Tip: Use 'refactron repo connect' to link a repository to a local directory[/dim]" + ) + + except RuntimeError as e: + console.print(f"[red]Error: {e}[/red]") + raise SystemExit(1) + except Exception as e: + console.print(f"[red]Unexpected error: {e}[/red]") + raise SystemExit(1) + + +@repo.command("connect") +@click.argument("repo_name", required=False) +@click.option( + "--path", + "-p", + type=click.Path(file_okay=False), + default=None, + help="Local path to connect (default: auto-clone to managed workspace)", +) +@click.option( + "--api-base-url", + default=DEFAULT_API_BASE_URL, + show_default=True, + help="Refactron API base URL", +) +def repo_connect(repo_name: Optional[str], path: Optional[str], api_base_url: str) -> None: + """ + Connect to a GitHub repository. + + REPO_NAME: Name of the repository (e.g., 'my-project' or 'user/my-project') + + If the repository doesn't exist locally, it will be cloned automatically + to ~/.refactron/workspaces// + """ + _setup_logging() + console.print() + _auth_banner("Connect Repository") + console.print() + + workspace_mgr = WorkspaceManager() + + # If path is provided, use existing behavior (map existing local directory) + if path: + local_path = Path(path).resolve() + + # Auto-detect repository if not provided + if not repo_name: + console.print("[dim]No repository specified, attempting auto-detection...[/dim]\n") + detected = workspace_mgr.detect_repository(local_path) + if detected: + console.print(f"[success]Detected repository: {detected}[/success]\n") + repo_name = detected + else: + console.print( + "[red]Could not auto-detect repository from .git config.[/red]\n" + "[dim]Please specify the repository name:[/dim]\n" + " refactron repo connect \n" + ) + raise SystemExit(1) + else: + # No path provided - must have repo_name for cloning + if not repo_name: + console.print( + "[red]Error: Repository name is required when not in a git directory.[/red]\n\n" + "[dim]Usage:[/dim]\n" + " refactron repo connect # Auto-clone to workspace\n" + " refactron repo connect --path . # Link current directory\n" + ) + raise SystemExit(1) + + # Fetch available repositories + try: + with console.status("[primary]Fetching repositories...[/primary]"): + repositories = list_repositories(api_base_url) + except RuntimeError as e: + console.print(f"[red]Error: {e}[/red]") + raise SystemExit(1) + + # Find matching repository + matching_repo: Optional[Repository] = None + for repository in repositories: + if ( + repository.name.lower() == repo_name.lower() + or repository.full_name.lower() == repo_name.lower() + ): + matching_repo = repository + break + + if not matching_repo: + console.print( + f"[red]Repository '{repo_name}' not found in your connected repositories.[/red]\n" + ) + console.print("[dim]Available repositories:[/dim]") + for repository in repositories[:5]: + console.print(f" - {repository.full_name}") + if len(repositories) > 5: + console.print(f" ... and {len(repositories) - 5} more") + console.print("\n[dim]Run 'refactron repo list' to see all repositories.[/dim]") + raise SystemExit(1) + + # If no path provided, clone to managed workspace + if not path: + workspace_root = Path.home() / ".refactron" / "workspaces" + workspace_root.mkdir(parents=True, exist_ok=True) + local_path = workspace_root / matching_repo.name + + # Check if already cloned + if local_path.exists(): + console.print(f"[dim]Repository already exists at: {local_path}[/dim]\n") + else: + # Clone the repository + console.print(f"[primary]Cloning {matching_repo.full_name}...[/primary]\n") + + import subprocess + + try: + subprocess.run( + ["git", "clone", matching_repo.clone_url, str(local_path)], + capture_output=True, + text=True, + check=True, + ) + console.print(f"[success]āœ“ Cloned successfully to {local_path}[/success]\n") + except subprocess.CalledProcessError as e: + console.print(f"[red]Failed to clone repository:[/red]\n{e.stderr}") + raise SystemExit(1) + except FileNotFoundError: + console.print( + "[red]Error: git command not found.[/red]\n" + "[dim]Please install git or use --path to connect an existing directory.[/dim]" + ) + raise SystemExit(1) + + # Create workspace mapping + mapping = WorkspaceMapping( + repo_id=matching_repo.id, + repo_name=matching_repo.name, + repo_full_name=matching_repo.full_name, + local_path=str(local_path), + connected_at=datetime.now(timezone.utc).isoformat(), + ) + + workspace_mgr.add_workspace(mapping) + + # Trigger background indexing via subprocess + # We spawn a separate process so it survives after this CLI command exits + import subprocess + import sys + + console.print("[dim]Spawning background indexer...[/dim]") + try: + # Run 'refactron rag index' in the background + # Run 'refactron rag index' in the background + # We redirect output to DEVNULL to keep it quiet + pid = subprocess.Popen( + [sys.executable, "-m", "refactron.cli", "rag", "index", "--background"], + cwd=str(local_path), + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + start_new_session=True, # Detach from terminal + ).pid + console.print(f"[dim]Indexing started in background (PID: {pid}).[/dim]") + console.print("[dim]Run 'refactron rag status' to check progress.[/dim]") + except Exception as e: + console.print(f"[yellow]Auto-indexing failed to start: {e}[/yellow]") + + # Create helpful navigation command + cd_command = f"cd {local_path}" + + console.print( + Panel( + f"[success]Successfully connected![/success]\n\n" + f"Repository: [bold]{matching_repo.full_name}[/bold]\n" + f"Local Path: [bold]{local_path}[/bold]\n\n" + f"[yellow]To navigate to this directory, run:[/yellow]\n" + f"[bold cyan]{cd_command}[/bold cyan]", + title="āœ“ Connected", + border_style="success", + box=box.ROUNDED, + ) + ) + + # Also print the cd command separately for easy copying + console.print(f"\n[dim]Quick copy:[/dim] [bold cyan]{cd_command}[/bold cyan]\n") + + +@repo.command("disconnect") +@click.argument("repo_name", required=False) +@click.option( + "--delete-files", + is_flag=True, + help="Also delete the local directory (requires confirmation)", +) +def repo_disconnect(repo_name: Optional[str], delete_files: bool) -> None: + """ + Disconnect a repository and optionally delete local files. + + REPO_NAME: Name of the repository to disconnect (e.g., 'volumeofsphere' or 'user/volumeofsphere') # noqa: E501 + + If not provided, attempts to detect from current directory. + """ + _setup_logging() + console.print() + _auth_banner("Disconnect Repository") + console.print() + + workspace_mgr = WorkspaceManager() + + # Auto-detect if not provided + if not repo_name: + current_workspace = workspace_mgr.get_workspace_by_path(str(Path.cwd())) + if current_workspace: + repo_name = current_workspace.repo_full_name + console.print(f"[dim]Detected repository: {repo_name}[/dim]\n") + else: + console.print( + "[red]Error: No repository specified and current directory is not a connected workspace.[/red]\n\n" # noqa: E501 + "[dim]Usage:[/dim]\n" + " refactron repo disconnect \n" + " cd && refactron repo disconnect\n" + ) + raise SystemExit(1) + + # Find the workspace + workspace = workspace_mgr.get_workspace(repo_name) + if not workspace: + console.print(f"[yellow]Repository '{repo_name}' is not connected.[/yellow]\n") + console.print("[dim]Run 'refactron repo list' to see connected repositories.[/dim]") + raise SystemExit(1) + + local_path = Path(workspace.local_path) + + # Confirm deletion if requested + if delete_files: + if not local_path.exists(): + console.print(f"[yellow]Local directory does not exist: {local_path}[/yellow]\n") + else: + console.print( + Panel( + f"[yellow]āš ļø WARNING: This will permanently delete:[/yellow]\n\n" + f"[bold]{local_path}[/bold]\n\n" + f"[dim]This action cannot be undone![/dim]", + title="Confirm Deletion", + border_style="yellow", + box=box.ROUNDED, + ) + ) + + if not click.confirm( + "\nAre you sure you want to delete this directory?", default=False + ): + console.print("[yellow]Deletion cancelled.[/yellow]") + delete_files = False + + # Remove workspace mapping + workspace_mgr.remove_workspace(repo_name) + console.print(f"[success]āœ“ Removed workspace mapping for '{repo_name}'[/success]\n") + + # Delete files if confirmed + files_deleted = False + if delete_files and local_path.exists(): + try: + import shutil + + shutil.rmtree(local_path) + console.print(f"[success]āœ“ Deleted directory: {local_path}[/success]\n") + files_deleted = True + except Exception as e: + console.print(f"[red]Failed to delete directory: {e}[/red]\n") + raise SystemExit(1) + + # Show appropriate summary + if not local_path.exists() and not files_deleted: + # Directory was already gone + console.print( + Panel( + f"[yellow]Workspace mapping removed[/yellow]\n\n" + f"Repository: [bold]{repo_name}[/bold]\n" + f"Status: [dim]Local directory was already deleted[/dim]", + title="āœ“ Cleaned Up", + border_style="yellow", + box=box.ROUNDED, + ) + ) + else: + # Normal disconnect + console.print( + Panel( + f"[success]Repository disconnected successfully![/success]\n\n" + f"Repository: [bold]{repo_name}[/bold]\n" + f"Mapping removed: [bold]Yes[/bold]\n" + f"Files deleted: [bold]{'Yes' if files_deleted else 'No'}[/bold]", + title="āœ“ Disconnected", + border_style="success", + box=box.ROUNDED, + ) + ) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# RAG Commands +# ═══════════════════════════════════════════════════════════════════════════════ + + +@main.group() +def rag() -> None: + """RAG (Retrieval-Augmented Generation) management commands.""" + pass + + +@rag.command("index") +@click.option("--background", is_flag=True, help="Run in background mode (suppress output)") +@click.option("--summarize", is_flag=True, help="Use AI to summarize code for better retrieval") +def rag_index(background: bool, summarize: bool) -> None: + """Index the current workspace for RAG retrieval.""" + if background: + # Suppress all logging and output in background mode + logging.getLogger().setLevel(logging.CRITICAL) + import os + import sys + + sys.stdout = open(os.devnull, "w") + sys.stderr = open(os.devnull, "w") + else: + _setup_logging() + console.print() + _auth_banner("Index Repository") + console.print() + + from refactron.core.workspace import WorkspaceManager + from refactron.rag.indexer import RAGIndexer + + workspace_mgr = WorkspaceManager() + + # Get current workspace + current_workspace = workspace_mgr.get_workspace_by_path(str(Path.cwd())) + if not current_workspace: + if not background: + console.print( + "[red]Error: Not in a connected workspace.[/red]\n\n" + "[dim]Run 'refactron repo connect ' first.[/dim]" + ) + raise SystemExit(1) + + local_path = Path(current_workspace.local_path) + + if not background: + console.print(f"[primary]Indexing:[/primary] {current_workspace.repo_full_name}\n") + + try: + if background: + # Run without visual feedback + indexer = RAGIndexer(local_path) + indexer.index_repository(local_path, summarize=summarize) + else: + with console.status("[primary]Parsing and indexing code...[/primary]"): + indexer = RAGIndexer(local_path) + stats = indexer.index_repository(local_path, summarize=summarize) + + console.print( + Panel( + f"[success]Indexing complete![/success]\n\n" + f"Files indexed: [bold]{stats.total_files}[/bold]\n" + f"Code chunks: [bold]{stats.total_chunks}[/bold]\n" + f"Index location: [dim]{stats.index_path}[/dim]\n\n" + f"[dim]Chunk breakdown:[/dim]\n" + f" • Functions: {stats.chunk_types.get('function', 0)}\n" + f" • Classes: {stats.chunk_types.get('class', 0)}\n" + f" • Methods: {stats.chunk_types.get('method', 0)}\n" + f" • Modules: {stats.chunk_types.get('module', 0)}", + title="āœ“ Indexed", + border_style="success", + box=box.ROUNDED, + ) + ) + except Exception as e: + console.print(f"[red]Error indexing repository: {e}[/red]") + raise SystemExit(1) + + +@rag.command("search") +@click.argument("query") +@click.option("--top-k", default=5, help="Number of results to return") +@click.option("--type", "chunk_type", help="Filter by chunk type (function/class/module)") +@click.option("--rerank", is_flag=True, help="Use AI to rerank results for better accuracy") +def rag_search(query: str, top_k: int, chunk_type: Optional[str], rerank: bool) -> None: + """Search the RAG index for similar code.""" + _setup_logging() + console.print() + + from refactron.core.workspace import WorkspaceManager + from refactron.rag.retriever import ContextRetriever + + workspace_mgr = WorkspaceManager() + + # Get current workspace + current_workspace = workspace_mgr.get_workspace_by_path(str(Path.cwd())) + if not current_workspace: + console.print( + "[red]Error: Not in a connected workspace.[/red]\n\n" + "[dim]Run 'refactron repo connect ' first.[/dim]" + ) + raise SystemExit(1) + + local_path = Path(current_workspace.local_path) + + try: + retriever = ContextRetriever(local_path) + results = retriever.retrieve_similar(query, top_k=top_k, chunk_type=chunk_type) + + if not results: + console.print(f"[yellow]No results found for: {query}[/yellow]") + return + + console.print(f"\n[primary]Found {len(results)} results for:[/primary] {query}\n") + + for i, result in enumerate(results, 1): + relevance_score = max(0, 1 - result.distance) * 100 + + # AI Reranking if enabled + if rerank: + try: + from refactron.llm.client import GroqClient + + client = GroqClient() + prompt = ( # noqa: E501 + f"Rate the relevance of the following code snippet to the user query: '{query}'\n\n" # noqa: E501 + f"Code:\n{result.content[:500]}\n\n" + "Provide only a percentage number (e.g. 85%) representing how well this code matches " # noqa: E501 + "the semantic intent of the query." + ) + ai_response = client.generate( + prompt=prompt, + system="You are a code relevance evaluator. Output only the percentage.", + max_tokens=10, + ) + # Extract number from response (e.g. "85%" or "85") + import re + + match = re.search(r"(\d+)%", ai_response) or re.search(r"(\d+)", ai_response) + if match: + relevance_score = float(match.group(1)) + except Exception: + pass # Fallback to distance-based score + + console.print( + Panel( + f"[bold]{result.name}[/bold] ({result.chunk_type})\n" + f"[dim]{result.file_path}:{result.line_range[0]}-{result.line_range[1]}[/dim]\n\n" # noqa: E501 + f"```python\n{result.content[:200]}{'...' if len(result.content) > 200 else ''}\n```\n\n" # noqa: E501 + f"[dim]Similarity: {relevance_score / 100.0:.2%}[/dim]", + title=f"Result {i}/{len(results)}", + border_style="dim", + box=box.ROUNDED, + ) + ) + except RuntimeError as e: + console.print(f"[red]{e}[/red]\n") + console.print("[dim]Run 'refactron rag index' to create an index first.[/dim]") + raise SystemExit(1) + + +@rag.command("status") +def rag_status() -> None: + """Show RAG index statistics.""" + _setup_logging() + console.print() + + from refactron.core.workspace import WorkspaceManager + from refactron.rag.indexer import RAGIndexer + + workspace_mgr = WorkspaceManager() + + # Get current workspace + current_workspace = workspace_mgr.get_workspace_by_path(str(Path.cwd())) + if not current_workspace: + console.print( + "[red]Error: Not in a connected workspace.[/red]\n\n" + "[dim]Run 'refactron repo connect ' first.[/dim]" + ) + raise SystemExit(1) + + local_path = Path(current_workspace.local_path) + + try: + indexer = RAGIndexer(local_path) + stats = indexer.get_stats() + + if stats.total_chunks == 0: + console.print( + "[yellow]No index found.[/yellow]\n\n" + "[dim]Run 'refactron rag index' to create one.[/dim]" + ) + return + + console.print( + Panel( + f"[primary]RAG Index Status[/primary]\n\n" + f"Files indexed: [bold]{stats.total_files}[/bold]\n" + f"Total chunks: [bold]{stats.total_chunks}[/bold]\n" + f"Embedding model: [dim]{stats.embedding_model}[/dim]\n" + f"Index location: [dim]{stats.index_path}[/dim]\n\n" + f"[dim]Chunk breakdown:[/dim]\n" + f" • Functions: {stats.chunk_types.get('function', 0)}\n" + f" • Classes: {stats.chunk_types.get('class', 0)}\n" + f" • Methods: {stats.chunk_types.get('method', 0)}\n" + f" • Modules: {stats.chunk_types.get('module', 0)}", + title="RAG Status", + border_style="primary", + box=box.ROUNDED, + ) + ) + except Exception as e: + console.print(f"[yellow]No index found: {e}[/yellow]\n") + console.print("[dim]Run 'refactron rag index' to create one.[/dim]") + + @patterns.command("profile") @click.option( "--project", @@ -2468,5 +3248,258 @@ def patterns_profile(project_path: str, config_path: Optional[str]) -> None: console.print(table) +@main.command() +@click.argument("target", required=False, type=click.Path(exists=True)) +@click.option("--line", type=int, help="Specific line number to fix") +@click.option("--interactive/--no-interactive", default=True, help="Use interactive mode") +@click.option("--apply/--no-apply", default=False, help="Apply the suggested changes to the file") +def suggest(target: Optional[str], line: Optional[int], interactive: bool, apply: bool): + """ + Generate AI-powered refactoring suggestions. + + Uses RAG and LLM to analyze code and propose fixes. + """ + console.print() + _auth_banner("AI Refactoring") + console.print() + + # 1. Setup + cfg = _load_config(None) + _setup_logging() + + target_path = Path(target or ".").resolve() + + # Try to find the project root for RAG context + refactron_instance = Refactron(cfg) + workspace_path = refactron_instance.detect_project_root(target_path) + + console.print(f"[bold]Analyzing:[/bold] {target_path}") + if line: + console.print(f"[bold]Line:[/bold] {line}") + + # 2. Initialize Components + try: + retriever = ContextRetriever(workspace_path) + console.print("[dim]RAG Index loaded.[/dim]") + except Exception: + console.print( + "[yellow]Warning: RAG index not found. Context retrieval will be limited.[/yellow]" + ) + console.print("[dim]Run 'refactron rag index' to enable full context.[/dim]") + retriever = None + + orchestrator = LLMOrchestrator(retriever=retriever) + + # 3. Read Code + start_line_idx = 0 + end_line_idx = 0 + + if target_path.is_file(): + code = target_path.read_text(encoding="utf-8") + original_snippet = code + # Extract snippet if line provided + if line: + lines = code.splitlines() + if 1 <= line <= len(lines): + # Context window +/- 10 lines for the LLM prompt + start_line_idx = max(0, line - 10) + end_line_idx = min(len(lines), line + 10) + original_snippet = "\n".join(lines[start_line_idx:end_line_idx]) + + # Smaller context for display + display_start = max(0, line - 3) + display_end = min(len(lines), line + 3) + display_code = "\n".join(lines[display_start:display_end]) + console.print(Panel(display_code, title="Code Snippet", style="dim")) + else: + console.print(f"[red]Error: Line {line} is out of range.[/red]") + return + else: + console.print( + "[red]Error: Directory analysis not yet supported. Please specify a file.[/red]" + ) + return + + # 4. Generate Suggestion + # Create a synthetic issue for now + issue = CodeIssue( + category=IssueCategory.MODERNIZATION, + level=IssueLevel.INFO, + message="Refactor and improve this code", + file_path=target_path, + line_number=line or 1, + ) + + with console.status("[bold cyan]Generating suggestion...[/bold cyan]"): + suggestion = orchestrator.generate_suggestion(issue, original_code=original_snippet) + + # 5. Display Result + if suggestion.status == SuggestionStatus.FAILED: + console.print(f"[red]Generation Failed:[/red] {suggestion.explanation}") + return + + console.print() + console.print( + Panel( + Markdown(suggestion.explanation), + title=f"Suggestion ({suggestion.model_name})", + border_style="green", + ) + ) + + console.print(Panel(suggestion.proposed_code, title="Proposed Code", style="on #1e1e1e")) + + console.print( + f"[dim]AI Confidence: [bold]{suggestion.llm_confidence:.2f}[/bold], Safety Score: [bold]{suggestion.confidence_score:.2f}[/bold][/dim]" # noqa: E501 + ) + + if suggestion.safety_result: + status_color = "green" if suggestion.safety_result.passed else "red" + console.print( + f"Safety Check: [{status_color}]{'PASSED' if suggestion.safety_result.passed else 'FAILED'}[/{status_color}]" # noqa: E501 + ) + if suggestion.safety_result.issues: + console.print(f"Issues: {', '.join(suggestion.safety_result.issues)}") + + console.print() + + # 6. Apply Changes + if apply: + if interactive: + if not click.confirm("Do you want to apply these changes?"): + console.print("[yellow]Changes cancelled.[/yellow]") + return + + try: + # Create backup + backup_sys = BackupRollbackSystem(workspace_path) + session_id, _ = backup_sys.prepare_for_refactoring( + [target_path], description="AI suggestion" + ) + console.print(f"[dim]Backup created: {session_id}[/dim]") + + # Construct new content + new_file_content = "" + if line: + # Reload lines to ensure freshness + current_lines = target_path.read_text(encoding="utf-8").splitlines() + # Determine indentation of the original block to verify alignment (optional, skipping for now) # noqa: E501 + + # Replace the exact block that was sent to LLM + replacement_lines = suggestion.proposed_code.splitlines() + + # Reconstruct + pre_block = current_lines[:start_line_idx] + post_block = current_lines[end_line_idx:] + + final_lines = pre_block + replacement_lines + post_block + new_file_content = "\n".join(final_lines) + if code.endswith("\n"): + new_file_content += "\n" + else: + new_file_content = suggestion.proposed_code + + target_path.write_text(new_file_content, encoding="utf-8") + console.print("[green bold]Successfully applied AI suggestion![/green bold]") + console.print(f"[dim]Run 'refactron rollback {session_id}' to undo.[/dim]") + + except Exception as e: + console.print(f"[red]Failed to apply changes: {e}[/red]") + + +@main.command() +@click.argument("target", type=click.Path(exists=True)) +@click.option( + "--apply/--no-apply", default=False, help="Apply the documentation changes to the file" +) +@click.option("--interactive/--no-interactive", default=True, help="Use interactive mode for apply") +def document(target: str, apply: bool, interactive: bool): + """ + Generate Google-style docstrings for a Python file. + + Uses AI to analyze code and add comprehensive documentation. + """ + console.print() + _auth_banner("AI Documentation") + console.print() + + # Setup + cfg = _load_config(None) + _setup_logging() + + target_path = Path(target).resolve() + + if not target_path.is_file(): + console.print("[red]Error: Please specify a file, not a directory.[/red]") + return + + refactron_instance = Refactron(cfg) + workspace_path = refactron_instance.detect_project_root(target_path) + + console.print(f"[bold]Documenting:[/bold] {target_path}") + + # Initialize components + try: + retriever = ContextRetriever(workspace_path) + except Exception: + console.print( + "[yellow]Warning: RAG index not found. Context retrieval will be limited.[/yellow]" + ) + retriever = None + + orchestrator = LLMOrchestrator(retriever=retriever) + + # Generate + code = target_path.read_text(encoding="utf-8") + + with console.status("[bold cyan]Generating documentation...[/bold cyan]"): + suggestion = orchestrator.generate_documentation(code, file_path=str(target_path)) + + if suggestion.status == SuggestionStatus.FAILED: + console.print(f"[red]Generation Failed:[/red] {suggestion.explanation}") + return + + doc_path = target_path.with_name(f"{target_path.stem}_doc.md") + + console.print() + console.print( + Panel( + Markdown(suggestion.explanation), + title=f"Documentation Plan ({suggestion.model_name})", + border_style="blue", + ) + ) + + console.print( + Panel( + Markdown(suggestion.proposed_code), + title=f"Preview: {doc_path.name}", + style="on #1e1e1e", + ) + ) + + console.print(f"[dim]Confidence: {suggestion.confidence_score:.2f}[/dim]") + console.print() + + # Apply + if apply: + if interactive: + if not click.confirm( + f"Do you want to create external documentation at {doc_path.name}?" + ): + console.print("[yellow]Changes cancelled.[/yellow]") + return + + try: + # Write new file (no backup needed for new file creation) + doc_path.write_text(suggestion.proposed_code, encoding="utf-8") + console.print( + f"[green bold]Successfully created documentation: {doc_path}[/green bold]" + ) + + except Exception as e: + console.print(f"[red]Failed to create documentation: {e}[/red]") + + if __name__ == "__main__": main(prog_name="refactron") diff --git a/refactron/core/config.py b/refactron/core/config.py index 754b73f..4f0c830 100644 --- a/refactron/core/config.py +++ b/refactron/core/config.py @@ -66,6 +66,7 @@ class RefactronConfig: "**/venv/**", "**/env/**", "**/.git/**", + "**/.refactron-backup/**", ] ) diff --git a/refactron/core/models.py b/refactron/core/models.py index 39e8ca5..78d0273 100644 --- a/refactron/core/models.py +++ b/refactron/core/models.py @@ -29,6 +29,7 @@ class IssueCategory(Enum): MODERNIZATION = "modernization" DEPENDENCY = "dependency" DEAD_CODE = "dead_code" + DOCUMENTATION = "documentation" @dataclass diff --git a/refactron/core/refactor_result.py b/refactron/core/refactor_result.py index 609d27f..27f1ffe 100644 --- a/refactron/core/refactor_result.py +++ b/refactron/core/refactor_result.py @@ -94,10 +94,43 @@ def show_diff(self) -> str: return "\n".join(lines) def apply(self) -> bool: - """Apply the refactoring operations (placeholder).""" - # This would actually apply the changes to files - self.applied = True - return True + """Apply the refactoring operations to the files.""" + # Group operations by file + file_ops: Dict[Path, List[RefactoringOperation]] = {} + for op in self.operations: + if op.file_path not in file_ops: + file_ops[op.file_path] = [] + file_ops[op.file_path].append(op) + + success = True + for file_path, ops in file_ops.items(): + try: + if not file_path.exists(): + continue + + content = file_path.read_text(encoding="utf-8") + + # Apply each operation + # Note: This simple implementation assumes non-overlapping old_code blocks + # and replaces exact matches. A more robust implementation would use AST + # or line-based replacement to handle overlapping edits. + new_content = content + for op in ops: + if op.old_code in new_content: + new_content = new_content.replace(op.old_code, op.new_code, 1) + else: + # Fallback: try to find by line number if exact match fails + # This part is omitted for simplicity in this version + pass + + if new_content != content: + file_path.write_text(new_content, encoding="utf-8") + + except Exception: + success = False + + self.applied = success + return success def summary(self) -> Dict[str, int]: """Get a summary of refactoring operations.""" diff --git a/refactron/core/refactron.py b/refactron/core/refactron.py index f4847a8..48580b2 100644 --- a/refactron/core/refactron.py +++ b/refactron/core/refactron.py @@ -246,7 +246,7 @@ def analyze(self, target: Union[str, Path]) -> AnalysisResult: if target_path.is_file(): files = [target_path] else: - files = self._get_python_files(target_path) + files = self.get_python_files(target_path) # Apply incremental analysis filtering if self.incremental_tracker.enabled: @@ -513,7 +513,7 @@ def refactor( if target_path.is_file(): files = [target_path] else: - files = self._get_python_files(target_path) + files = self.get_python_files(target_path) result = RefactorResult(preview_mode=preview) @@ -621,7 +621,7 @@ def _refactor_file( return operations - def _get_python_files(self, directory: Path) -> List[Path]: + def get_python_files(self, directory: Path) -> List[Path]: """Get all Python files in a directory, respecting exclude patterns.""" python_files = [] @@ -779,7 +779,14 @@ def detect_project_root(self, file_path: Path) -> Path: current = file_path.parent.resolve() # Common project markers - markers = [".git", "setup.py", "pyproject.toml", "setup.cfg", ".refactron"] + markers = [ + ".git", + "setup.py", + "pyproject.toml", + "setup.cfg", + ".refactron", + ".refactron.yaml", + ] for _ in range(10): # Limit search depth for marker in markers: diff --git a/refactron/core/repositories.py b/refactron/core/repositories.py new file mode 100644 index 0000000..3cf1ac4 --- /dev/null +++ b/refactron/core/repositories.py @@ -0,0 +1,177 @@ +"""GitHub repository integration for Refactron CLI. + +This module provides functionality to interact with the Refactron backend API +to fetch GitHub repositories connected to the user's account. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Any, Dict, List, Optional +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen + +from refactron.core.credentials import load_credentials + + +@dataclass(frozen=True) +class Repository: + """Represents a GitHub repository.""" + + id: int + name: str + full_name: str + description: Optional[str] + private: bool + html_url: str + clone_url: str + ssh_url: str + default_branch: str + language: Optional[str] + updated_at: str + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Repository": + """Create a Repository instance from API response data.""" + return cls( + id=data["id"], + name=data["name"], + full_name=data["full_name"], + description=data.get("description"), + private=data["private"], + html_url=data["html_url"], + clone_url=data["clone_url"], + ssh_url=data["ssh_url"], + default_branch=data.get("default_branch", "main"), + language=data.get("language"), + updated_at=data["updated_at"], + ) + + +def list_repositories(api_base_url: str, timeout_seconds: int = 10) -> List[Repository]: + """Fetch all GitHub repositories connected to the user's account. + + Args: + api_base_url: The Refactron API base URL + timeout_seconds: Request timeout in seconds + + Returns: + List of Repository objects + + Raises: + RuntimeError: If the request fails or user is not authenticated + """ + # Load credentials + creds = load_credentials() + if not creds: + raise RuntimeError("Not authenticated. Please run 'refactron login' first.") + + if not creds.access_token: + raise RuntimeError("No access token found. Please run 'refactron login' first.") + + # Check if token is expired + from datetime import datetime, timezone + + if creds.expires_at: + try: + # Parse the expiration time + from datetime import datetime + + if isinstance(creds.expires_at, str): + # Remove timezone info for comparison + expires_str = creds.expires_at.replace("+00:00", "").replace("Z", "") + expires_at = datetime.fromisoformat(expires_str).replace(tzinfo=timezone.utc) + else: + expires_at = creds.expires_at + + now = datetime.now(timezone.utc) + if now >= expires_at: + raise RuntimeError("Your session has expired. Please run 'refactron login' again.") + except (ValueError, AttributeError): + # If we can't parse the expiration, continue anyway + pass + + # Normalize the base URL + base = api_base_url.rstrip("/") + url = f"{base}/api/github/repositories" + + # Prepare the request with Bearer token + req = Request( + url=url, + headers={ + "Authorization": f"Bearer {creds.access_token}", + "Accept": "application/json", + }, + method="GET", + ) + + try: + with urlopen(req, timeout=timeout_seconds) as resp: + raw = resp.read().decode("utf-8") + data = json.loads(raw) if raw else [] + + # Handle both list and dict wrapper formats + repositories_data = [] + if isinstance(data, list): + repositories_data = data + elif isinstance(data, dict): + # Try common wrapper keys + repositories_data = ( + data.get("repositories") or data.get("data") or data.get("repos") or [] + ) + if not isinstance(repositories_data, list): + raise RuntimeError( + f"Unexpected API response format. Expected list or dict with 'repositories' key. " + f"Got: {type(data)} with keys: {list(data.keys()) if isinstance(data, dict) else 'N/A'}" + ) + else: + raise RuntimeError(f"Unexpected API response type: {type(data)}") + + return [Repository.from_dict(repo) for repo in repositories_data] + + except HTTPError as e: + if e.code == 401: + # Try to get more details from the error response + try: + error_body = e.read().decode("utf-8") + error_data = json.loads(error_body) + detail = error_data.get("message", error_data.get("detail", "Unknown error")) + except Exception: + detail = "No additional details" + + raise RuntimeError( + f"Authentication failed (HTTP 401): {detail}\n\n" + "Possible causes:\n" + " 1. Your session has expired - run 'refactron login' again\n" + " 2. The access token is invalid\n" + " 3. The API endpoint requires different authentication\n\n" + f"API URL: {url}\n" + f"Token present: {'Yes' if creds.access_token else 'No'}" + ) + elif e.code == 403: + raise RuntimeError( + "GitHub access denied. Please reconnect your GitHub account on the Refactron website." + ) + elif e.code == 404: + raise RuntimeError( + "Repository endpoint not found. Please check your API base URL or update Refactron." + ) + else: + # Try to parse error message from response + try: + error_body = e.read().decode("utf-8") + error_data = json.loads(error_body) + message = error_data.get("message", str(e)) + except Exception: + message = str(e) + raise RuntimeError(f"Failed to fetch repositories (HTTP {e.code}): {message}") + + except URLError as e: + raise RuntimeError(f"Network error: {e.reason}. Is the Refactron API accessible?") + + except json.JSONDecodeError as e: + raise RuntimeError(f"Invalid JSON response from API: {e}") + + except Exception as e: + raise RuntimeError(f"Unexpected error fetching repositories: {e}") diff --git a/refactron/core/workspace.py b/refactron/core/workspace.py new file mode 100644 index 0000000..d608a15 --- /dev/null +++ b/refactron/core/workspace.py @@ -0,0 +1,226 @@ +"""Workspace management for Refactron CLI. + +This module handles the mapping between remote GitHub repositories and local +directory paths, enabling seamless navigation and context switching. +""" + +from __future__ import annotations + +import json +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Optional + + +@dataclass +class WorkspaceMapping: + """Represents a mapping between a remote repository and a local path.""" + + repo_id: int + repo_name: str + repo_full_name: str + local_path: str + connected_at: str + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "repo_id": self.repo_id, + "repo_name": self.repo_name, + "repo_full_name": self.repo_full_name, + "local_path": self.local_path, + "connected_at": self.connected_at, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "WorkspaceMapping": + """Create from dictionary.""" + return cls( + repo_id=data["repo_id"], + repo_name=data["repo_name"], + repo_full_name=data["repo_full_name"], + local_path=data["local_path"], + connected_at=data["connected_at"], + ) + + +class WorkspaceManager: + """Manages workspace mappings between repositories and local paths.""" + + def __init__(self, config_path: Optional[Path] = None) -> None: + """Initialize the workspace manager. + + Args: + config_path: Path to the workspaces.json file (default: ~/.refactron/workspaces.json) + """ + self.config_path = config_path or (Path.home() / ".refactron" / "workspaces.json") + self._ensure_config_exists() + + def _ensure_config_exists(self) -> None: + """Ensure the configuration directory and file exist.""" + self.config_path.parent.mkdir(parents=True, exist_ok=True) + if not self.config_path.exists(): + self._save_workspaces({}) + + def _load_workspaces(self) -> Dict[str, Dict[str, Any]]: + """Load workspace mappings from disk. + + Returns: + Dictionary mapping repo_full_name to workspace data + """ + try: + with open(self.config_path, "r", encoding="utf-8") as f: + data = json.load(f) + return data if isinstance(data, dict) else {} + except (json.JSONDecodeError, FileNotFoundError): + return {} + + def _save_workspaces(self, workspaces: Dict[str, Dict[str, Any]]) -> None: + """Save workspace mappings to disk. + + Args: + workspaces: Dictionary mapping repo_full_name to workspace data + """ + with open(self.config_path, "w", encoding="utf-8") as f: + json.dump(workspaces, f, indent=2, sort_keys=True) + + # Set file permissions to 0600 (user read/write only) + try: + os.chmod(self.config_path, 0o600) + except OSError: + pass + + def add_workspace(self, mapping: WorkspaceMapping) -> None: + """Add or update a workspace mapping. + + Args: + mapping: The workspace mapping to add + """ + workspaces = self._load_workspaces() + workspaces[mapping.repo_full_name] = mapping.to_dict() + self._save_workspaces(workspaces) + + def get_workspace(self, repo_name: str) -> Optional[WorkspaceMapping]: + """Get a workspace mapping by repository name. + + Args: + repo_name: The repository name (e.g., "repo" or "user/repo") + + Returns: + The workspace mapping, or None if not found + """ + workspaces = self._load_workspaces() + + # Try exact match first (full name) + data = workspaces.get(repo_name) + if data: + return WorkspaceMapping.from_dict(data) + + # Try matching by short name (repo name without user) + repo_name_lower = repo_name.lower() + for full_name, workspace_data in workspaces.items(): + # Extract short name from full name (e.g., "volumeofsphere" from "omsherikar/volumeofsphere") + short_name = full_name.split("/")[-1].lower() + if short_name == repo_name_lower: + return WorkspaceMapping.from_dict(workspace_data) + + return None + + def get_workspace_by_path(self, local_path: str) -> Optional[WorkspaceMapping]: + """Get a workspace mapping by local path. + + Args: + local_path: The local directory path + + Returns: + The workspace mapping, or None if not found + """ + normalized_path = str(Path(local_path).resolve()) + workspaces = self._load_workspaces() + + for data in workspaces.values(): + if str(Path(data["local_path"]).resolve()) == normalized_path: + return WorkspaceMapping.from_dict(data) + + return None + + def list_workspaces(self) -> list[WorkspaceMapping]: + """List all workspace mappings. + + Returns: + List of all workspace mappings + """ + workspaces = self._load_workspaces() + return [WorkspaceMapping.from_dict(data) for data in workspaces.values()] + + def remove_workspace(self, repo_full_name: str) -> bool: + """Remove a workspace mapping. + + Args: + repo_full_name: The full name of the repository + + Returns: + True if removed, False if not found + """ + workspaces = self._load_workspaces() + if repo_full_name in workspaces: + del workspaces[repo_full_name] + self._save_workspaces(workspaces) + return True + return False + + def detect_repository(self, directory: Optional[Path] = None) -> Optional[str]: + """Attempt to detect the GitHub repository from the .git config. + + Args: + directory: Directory to search (default: current directory) + + Returns: + The repository full name (e.g., "user/repo"), or None if not detected + """ + search_dir = directory or Path.cwd() + git_dir = search_dir / ".git" + + if not git_dir.exists(): + # Search parent directories + for parent in search_dir.parents: + git_dir = parent / ".git" + if git_dir.exists(): + search_dir = parent + break + else: + return None + + # Try to read the remote URL from .git/config + config_file = git_dir / "config" + if not config_file.exists(): + return None + + try: + with open(config_file, "r", encoding="utf-8") as f: + content = f.read() + + # Parse the remote URL (support both HTTPS and SSH) + for line in content.split("\n"): + line = line.strip() + if line.startswith("url = "): + url = line.replace("url = ", "") + + # Extract repo name from URL + # HTTPS: https://github.com/user/repo.git + # SSH: git@github.com:user/repo.git + if "github.com" in url: + if url.startswith("git@github.com:"): + repo_path = url.replace("git@github.com:", "").replace(".git", "") + elif "github.com/" in url: + repo_path = url.split("github.com/")[1].replace(".git", "") + else: + continue + + return repo_path + + except (IOError, OSError): + pass + + return None diff --git a/refactron/llm/__init__.py b/refactron/llm/__init__.py new file mode 100644 index 0000000..500a0e6 --- /dev/null +++ b/refactron/llm/__init__.py @@ -0,0 +1,6 @@ +"""LLM integration for intelligent code suggestions using free cloud APIs.""" + +from refactron.llm.client import GroqClient +from refactron.llm.orchestrator import LLMOrchestrator + +__all__ = ["GroqClient", "LLMOrchestrator"] diff --git a/refactron/llm/backend_client.py b/refactron/llm/backend_client.py new file mode 100644 index 0000000..8687bdd --- /dev/null +++ b/refactron/llm/backend_client.py @@ -0,0 +1,119 @@ +"""Client for Refactron backend LLM proxy.""" + +from __future__ import annotations + +import os +from typing import Any, Dict, Optional + +import requests + +from refactron.core.credentials import load_credentials + + +class BackendLLMClient: + """Client that proxies LLM requests through Refactron backend.""" + + def __init__( + self, + backend_url: Optional[str] = None, + model: str = "llama-3.3-70b-versatile", + temperature: float = 0.2, + max_tokens: int = 2000, + ): + """Initialize backend client. + + Args: + backend_url: Refactron backend URL + model: Model name to use + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + """ + # Default to the local testing backend URL if not provided + self.backend_url = (backend_url or "https://api.refactron.dev").rstrip("/") + self.model = model + self.temperature = temperature + self.max_tokens = max_tokens + + # Load credentials to get access token/API key + self.creds = load_credentials() + + def generate( + self, + prompt: str, + system: Optional[str] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + ) -> str: + """Generate text using backend API. + + Args: + prompt: The user prompt + system: Optional system prompt + temperature: Override default temperature + max_tokens: Override default max tokens + + Returns: + Generated text + """ + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + } + + # Use API key if available, otherwise use access token + if self.creds: + if self.creds.api_key: + headers["X-API-Key"] = self.creds.api_key + elif self.creds.access_token: + headers["Authorization"] = f"Bearer {self.creds.access_token}" + + payload = { + "prompt": prompt, + "system": system, + "temperature": temperature or self.temperature, + "max_tokens": max_tokens or self.max_tokens, + "model": self.model, + } + + try: + response = requests.post( + f"{self.backend_url}/api/llm/generate", + json=payload, + headers=headers, + timeout=60, + ) + + if response.status_code != 200: + error_msg = response.text + if response.headers.get("Content-Type") == "application/json": + try: + error_data = response.json() + error_msg = error_data.get("error", error_msg) + except Exception: + pass + raise RuntimeError(f"Backend LLM proxy error ({response.status_code}): {error_msg}") + + data = response.json() + return data["content"] + + except requests.exceptions.RequestException as e: + raise RuntimeError(f"Failed to connect to Refactron backend: {e}") + except RuntimeError: + raise + except Exception as e: + raise RuntimeError(f"Unexpected error during backend LLM generation: {e}") + + def check_health(self) -> bool: + """Check if the backend API is accessible. + + Returns: + True if API is accessible, False otherwise + """ + try: + response = requests.get( + f"{self.backend_url}/api/llm/health", + timeout=10, + ) + return response.status_code == 200 + except Exception: + return False diff --git a/refactron/llm/client.py b/refactron/llm/client.py new file mode 100644 index 0000000..bc3a576 --- /dev/null +++ b/refactron/llm/client.py @@ -0,0 +1,94 @@ +"""Groq cloud API client for free LLM inference.""" + +from __future__ import annotations + +import os +from typing import Optional + +try: + from groq import Groq + + GROQ_AVAILABLE = True +except ImportError: + GROQ_AVAILABLE = False + + +class GroqClient: + """Client for Groq cloud API (free LLM inference).""" + + def __init__( + self, + api_key: Optional[str] = None, + model: str = "llama-3.3-70b-versatile", # Updated to current model + temperature: float = 0.2, + max_tokens: int = 2000, + ): + """Initialize Groq client. + + Args: + api_key: Groq API key (defaults to GROQ_API_KEY env var) + model: Model name to use + temperature: Sampling temperature + max_tokens: Maximum tokens to generate + """ + if not GROQ_AVAILABLE: + raise RuntimeError("Groq is not available. Install with: pip install groq") + + self.api_key = api_key or os.getenv("GROQ_API_KEY") + if not self.api_key: + raise RuntimeError( + "GROQ_API_KEY environment variable not set. " + "Get your free API key at https://console.groq.com" + ) + + self.model = model + self.temperature = temperature + self.max_tokens = max_tokens + + self.client = Groq(api_key=self.api_key) + + def generate( + self, + prompt: str, + system: Optional[str] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + ) -> str: + """Generate text using Groq. + + Args: + prompt: The user prompt + system: Optional system prompt + temperature: Override default temperature + max_tokens: Override default max tokens + + Returns: + Generated text + """ + messages = [] + + if system: + messages.append({"role": "system", "content": system}) + + messages.append({"role": "user", "content": prompt}) + + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + temperature=temperature or self.temperature, + max_tokens=max_tokens or self.max_tokens, + ) + + return response.choices[0].message.content + + def check_health(self) -> bool: + """Check if the Groq API is accessible. + + Returns: + True if API is accessible, False otherwise + """ + try: + self.generate("Hello", max_tokens=5) + return True + except Exception: + return False diff --git a/refactron/llm/models.py b/refactron/llm/models.py new file mode 100644 index 0000000..9e19362 --- /dev/null +++ b/refactron/llm/models.py @@ -0,0 +1,66 @@ +"""Data models for LLM integration.""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Optional, Dict, Any +import time +import uuid + +from refactron.core.models import CodeIssue + + +class SuggestionStatus(Enum): + """Status of a refactoring suggestion.""" + + PENDING = "pending" # Generated, waiting for review + VALIDATING = "validating" # Being checked by safety gate + APPROVED = "approved" # User approved + REJECTED = "rejected" # User rejected + APPLIED = "applied" # Code changes applied + FAILED = "failed" # Failed application or safety check + + +@dataclass +class SafetyCheckResult: + """Result of a safety gate validation.""" + + passed: bool + score: float # 0.0 to 1.0 + issues: List[str] + syntax_valid: bool = False + side_effects: List[str] = field(default_factory=list) + + +@dataclass +class RefactoringSuggestion: + """A refactoring suggestion generated by the LLM.""" + + # Context + issue: CodeIssue + original_code: str + context_files: List[str] # Files used for context + + # LLM Output + proposed_code: str + explanation: str + reasoning: str # Chain of thought + + # Metadata + model_name: str + confidence_score: float # Adjusted score (safety gate result) + llm_confidence: float = 0.5 # Raw score from LLM response + + # State + status: SuggestionStatus = SuggestionStatus.PENDING + safety_result: Optional[SafetyCheckResult] = None + + # Tracking + suggestion_id: str = field(default_factory=lambda: str(uuid.uuid4())) + timestamp: float = field(default_factory=time.time) + + def __post_init__(self): + if self.safety_result is None: + # Default empty safety result + self.safety_result = SafetyCheckResult( + passed=False, score=0.0, issues=["Not validated yet"] + ) diff --git a/refactron/llm/orchestrator.py b/refactron/llm/orchestrator.py new file mode 100644 index 0000000..914e703 --- /dev/null +++ b/refactron/llm/orchestrator.py @@ -0,0 +1,269 @@ +"""Orchestrator for LLM-based refactoring suggestions.""" + +import json +import logging +import os +import re +from typing import Optional, List, Union + +from pathlib import Path +from refactron.core.models import CodeIssue, IssueCategory, IssueLevel +from refactron.llm.client import GroqClient +from refactron.llm.backend_client import BackendLLMClient +from refactron.llm.models import RefactoringSuggestion, SuggestionStatus +from refactron.llm.prompts import SYSTEM_PROMPT, SUGGESTION_PROMPT, DOCUMENTATION_PROMPT +from refactron.llm.safety import SafetyGate +from refactron.rag.retriever import ContextRetriever + +logger = logging.getLogger(__name__) + + +class LLMOrchestrator: + """Coordinates RAG context retrieval and LLM generation.""" + + def __init__( + self, + retriever: Optional[ContextRetriever] = None, + llm_client: Optional[Union[GroqClient, BackendLLMClient]] = None, + safety_gate: Optional[SafetyGate] = None, + ): + self.retriever = retriever + + if llm_client: + self.client = llm_client + else: + # Try to use GroqClient if API key is present, otherwise use BackendLLMClient + if os.getenv("GROQ_API_KEY"): + try: + self.client = GroqClient() + except RuntimeError: + self.client = BackendLLMClient() + else: + self.client = BackendLLMClient() + + self.safety_gate = safety_gate or SafetyGate() + + def generate_suggestion(self, issue: CodeIssue, original_code: str) -> RefactoringSuggestion: + """Generate a refactoring suggestion for a code issue. + + Args: + issue: The code issue to fix + original_code: The failing code snippet + + Returns: + A validated refactoring suggestion + """ + # 1. Retrieve Context + context_snippets = [] + if self.retriever: + try: + # Search for similar code or relevant context + results = self.retriever.retrieve_similar( + f"{issue.message} {original_code}", top_k=3 + ) + context_snippets = [r.content for r in results] + except Exception as e: + logger.warning(f"Context retrieval failed: {e}") + + rag_context = "\n\n".join(context_snippets) if context_snippets else "No context available." + + # 2. Construct Prompt + prompt = SUGGESTION_PROMPT.format( + issue_message=issue.message, + file_path=issue.file_path, + line_number=issue.line_number, + severity=issue.level.value, + original_code=original_code, + rag_context=rag_context, + ) + + # 3. Call LLM + response_text = "N/A" + try: + response_text = self.client.generate( + prompt=prompt, system=SYSTEM_PROMPT, temperature=0.2 # Low temperature for code + ) + + # Parse JSON response + # Note: Groq might return markdown code blocks, strip them + clean_text = self._clean_json_response(response_text) + + # Using strict=False allows control characters like newlines in strings + data = json.loads(clean_text, strict=False) + + # Extract and parse confidence score + raw_confidence = data.get("confidence_score", 0.7) + try: + # Handle strings with % or range strings + if isinstance(raw_confidence, str): + match = re.search(r"(\d+\.?\d*)", raw_confidence) + confidence = ( + float(match.group(1)) / 100.0 + if "%" in raw_confidence + else float(match.group(1)) + ) + else: + confidence = float(raw_confidence) + except (ValueError, TypeError, AttributeError): + confidence = 0.5 # Fallback + + suggestion = RefactoringSuggestion( + issue=issue, + original_code=original_code, + context_files=[r.file_path for r in results] if self.retriever else [], + proposed_code=data.get("proposed_code", ""), + explanation=data.get("explanation", "No explanation provided."), + reasoning=data.get("reasoning", ""), + model_name=self.client.model, + confidence_score=min(max(confidence, 0.0), 1.0), + llm_confidence=min(max(confidence, 0.0), 1.0), + ) + + except Exception as e: + logger.error(f"LLM generation failed: {e}") + logger.debug(f"Raw response: {response_text}") + # Return a failed suggestion + return RefactoringSuggestion( + issue=issue, + original_code=original_code, + context_files=[], + proposed_code="", + explanation=f"Generation failed: {str(e)}", + reasoning="", + model_name=self.client.model, + confidence_score=0.0, + status=SuggestionStatus.FAILED, + ) + + # 4. Safety Validation + try: + safety_result = self.safety_gate.validate(suggestion) + suggestion.safety_result = safety_result + + # Link confidence score to safety result score + suggestion.confidence_score = safety_result.score + + if not safety_result.passed: + suggestion.status = SuggestionStatus.REJECTED + logger.warning(f"Suggestion failed safety check: {safety_result.issues}") + else: + suggestion.status = SuggestionStatus.PENDING + + except Exception as e: + logger.error(f"Safety validation failed: {e}") + suggestion.status = SuggestionStatus.FAILED + + return suggestion + + def generate_documentation( + self, code: str, file_path: str = "unknown" + ) -> RefactoringSuggestion: + """Generate documentation for the provided code. + + Args: + code: The code to document + file_path: Optional file path for context + + Returns: + A suggestion containing the documented code + """ + # Create a synthetic issue for tracking + issue = CodeIssue( + category=IssueCategory.DOCUMENTATION, + level=IssueLevel.INFO, + message="Add documentation", + file_path=Path(file_path), + line_number=1, + ) + + # 1. Retrieve Context + context_snippets = [] + if self.retriever: + try: + # Search for similar code or relevant context + results = self.retriever.retrieve_similar(code[:500], top_k=2) + context_snippets = [r.content for r in results] + except Exception as e: + logger.warning(f"Context retrieval failed: {e}") + + rag_context = "\n\n".join(context_snippets) if context_snippets else "No context available." + + # 2. Construct Prompt + prompt = DOCUMENTATION_PROMPT.format(original_code=code, rag_context=rag_context) + + # 3. Call LLM + try: + response_text = self.client.generate( + prompt=prompt, system=SYSTEM_PROMPT, temperature=0.2 + ) + + # Parse custom delimiter format + explanation = "Added documentation" + proposed_code = "" + confidence = 0.8 + + if "@@@EXPLANATION@@@" in response_text: + parts = response_text.split("@@@") + for i, part in enumerate(parts): + if part == "EXPLANATION": + explanation = parts[i + 1].strip() + elif part == "CONFIDENCE": + try: + confidence = float(parts[i + 1].strip()) + except ValueError: + pass + elif part == "MARKDOWN": + proposed_code = parts[i + 1].strip() + + if not proposed_code: + raise ValueError("Could not extract markdown content from response") + + return RefactoringSuggestion( + issue=issue, + original_code=code, + context_files=[], + explanation=explanation, + proposed_code=proposed_code, + reasoning="Documentation update", + confidence_score=confidence, + llm_confidence=confidence, + model_name=self.client.model, + status=SuggestionStatus.PENDING, + ) + + except Exception as e: + logger.error(f"Documentation generation failed: {e}") + return RefactoringSuggestion( + issue=issue, + original_code=code, + context_files=[], + explanation=f"Error: {str(e)}", + proposed_code="", + reasoning="", + confidence_score=0.0, + llm_confidence=0.0, + model_name=self.client.model, + status=SuggestionStatus.FAILED, + ) + + def _clean_json_response(self, text: str) -> str: + """Clean LLM response to extract JSON.""" + text = text.strip() + + # Remove markdown code blocks + if "```json" in text: + start = text.find("```json") + 7 + end = text.find("```", start) + if end != -1: + text = text[start:end] + else: + text = text[start:] + elif "```" in text: + start = text.find("```") + 3 + end = text.find("```", start) + if end != -1: + text = text[start:end] + else: + text = text[start:] + + return text.strip() diff --git a/refactron/llm/prompts.py b/refactron/llm/prompts.py new file mode 100644 index 0000000..ade8031 --- /dev/null +++ b/refactron/llm/prompts.py @@ -0,0 +1,94 @@ +"""Prompt templates for LLM code suggestions.""" + +SYSTEM_PROMPT = """You are an expert software architect and code refactoring specialist. +Your goal is to analyze code issues and provide safe, idiomatic, and performance-optimized fixes. + +RESPONSE FORMAT: +You must output ONLY valid JSON. +- Escape all double quotes inside strings with backslash (e.g. \\"). +- Do not use trailing commas. +- Do not output markdown code blocks, just the raw JSON object. +- Ensure newlines in strings are escaped as \\n. + +Output JSON structure: +{ + "explanation": "Brief explanation of the fix", + "proposed_code": "The complete fixed code block", + "reasoning": "Step-by-step reasoning for the change", + "confidence_score": "Float between 0.0 and 1.0 representing your confidence in this fix" +} +""" + +SUGGESTION_PROMPT = """ +Fix the following code issue: + +Issue: {issue_message} +File: {file_path}:{line_number} +Severity: {severity} + +Original Code: +```python +{original_code} +``` + +Relevant Context (RAG): +{rag_context} + +Provide a fix that resolves the issue while maintaining consistency with the codebase context. +IMPORTANT: Add inline comments ONLY where absolutely necessary to explain complex logic. +Do NOT add comments for obvious code or after every line. +""" + +SAFETY_CHECK_PROMPT = """ +Analyze the following code patch for safety risks: + +```python +{proposed_code} +``` + +Identify any: +1. Syntax errors +2. Security vulnerabilities +3. Dangerous side effects (file implementation, network calls) +4. Breaking changes + +Output valid JSON: +{ + "safe": boolean, + "risk_score": float (0.0-1.0), + "risk_score": float (0.0-1.0), + "issues": [list of strings] +} +""" + +DOCUMENTATION_PROMPT = """ +Analyze the following Python code and generate a comprehensive MARKDOWN documentation file. + +Original Code: +```python +{original_code} +``` + +Relevant Context (RAG): +{rag_context} + +Instructions: +1. Create a professional Developer Guide in Markdown format. +2. Include: + - Module Overview + - Key Classes and Functions (with signatures and descriptions) + - Usage Examples + - Logic Flow/Algorithm Explanation + - **Mermaid Diagram**: Create a `graph TD` flow chart representing the logic/algorithm. +3. Do NOT simply copy the code. Explain it. +4. Use the specific delimiters below for your response. + +RESPONSE FORMAT: +@@@EXPLANATION@@@ +Brief summary of documentation created +@@@CONFIDENCE@@@ +Float between 0.0 and 1.0 (e.g. 0.95) +@@@MARKDOWN@@@ +The complete Markdown documentation content including the mermaid diagram +@@@END@@@ +""" diff --git a/refactron/llm/safety.py b/refactron/llm/safety.py new file mode 100644 index 0000000..eda20ce --- /dev/null +++ b/refactron/llm/safety.py @@ -0,0 +1,113 @@ +"""Safety gate for validating LLM-generated code.""" + +import ast +from typing import List, Optional + +from refactron.llm.models import SafetyCheckResult, RefactoringSuggestion + + +class SafetyGate: + """Validates code suggestions for safety and correctness.""" + + def __init__(self, min_confidence: float = 0.7): + self.min_confidence = min_confidence + + def validate(self, suggestion: RefactoringSuggestion) -> SafetyCheckResult: + """Validate a refactoring suggestion. + + Args: + suggestion: The suggestion to validate + + Returns: + Safety check result + """ + issues = [] + side_effects = [] + score = 1.0 + + # 1. Syntax Check + try: + ast.parse(suggestion.proposed_code) + syntax_valid = True + except SyntaxError as e: + syntax_valid = False + issues.append(f"Syntax Error: {e}") + score = 0.0 + + # 2. Confidence Check + if suggestion.confidence_score < self.min_confidence: + issues.append( + f"Low confidence score: {suggestion.confidence_score:.2f} < {self.min_confidence}" + ) + score *= 0.8 + + # 3. Basic Security/Risk Checks + risk_score = self._assess_risk(suggestion.proposed_code) + if risk_score > 0.8: + issues.append("High risk code detected") + score *= 0.5 + + # 4. Dangerous Imports Check + dangerous_imports = self._check_dangerous_imports( + suggestion.proposed_code, suggestion.original_code + ) + if dangerous_imports: + issues.append(f"Dangerous imports detected: {', '.join(dangerous_imports)}") + side_effects.extend([f"Import: {imp}" for imp in dangerous_imports]) + score *= 0.7 + + return SafetyCheckResult( + passed=(score > 0.8 and syntax_valid), + score=score, + issues=issues, + syntax_valid=syntax_valid, + side_effects=side_effects, + ) + + def _assess_risk(self, code: str) -> float: + """Assess the risk of the code patch.""" + risk = 0.0 + + # Keywords that suggest side effects + risky_keywords = [ + "subprocess", + "os.system", + "shutil.rmtree", + "open(", + "requests.", + "urllib", + "eval(", + "exec(", + ] + + for keyword in risky_keywords: + if keyword in code: + risk += 0.3 + + return min(risk, 1.0) + + def _check_dangerous_imports(self, proposed_code: str, original_code: str) -> List[str]: + """Check for potentially dangerous imports that are NEW.""" + dangerous_modules = ["subprocess", "os", "shutil", "sys"] + + def get_imports(code): + imports = set() + try: + tree = ast.parse(code) + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for name in node.names: + if name.name in dangerous_modules: + imports.add(name.name) + elif isinstance(node, ast.ImportFrom): + if node.module in dangerous_modules: + imports.add(node.module) + except SyntaxError: + pass + return imports + + original_imports = get_imports(original_code) + proposed_imports = get_imports(proposed_code) + + # Only flag imports that were added by the LLM + return list(proposed_imports - original_imports) diff --git a/refactron/patterns/fingerprint.py b/refactron/patterns/fingerprint.py index 8857347..beda724 100644 --- a/refactron/patterns/fingerprint.py +++ b/refactron/patterns/fingerprint.py @@ -28,9 +28,10 @@ def fingerprint_code(self, code_snippet: str) -> str: return self._hash_algo(b"").hexdigest() # Optimize: Parse AST once, extract both normalized code and pattern - normalized = self._normalize_code(code_snippet) + # Anonymize identifiers to allow structural generalization + anonymized = self._anonymize_code(code_snippet) ast_pattern = self._extract_ast_pattern(code_snippet) - combined = f"{normalized}\n{ast_pattern}".encode("utf-8") + combined = f"{anonymized}\n{ast_pattern}".encode("utf-8") return self._hash_algo(combined).hexdigest() def fingerprint_issue_context( @@ -70,8 +71,9 @@ def fingerprint_refactoring(self, operation: RefactoringOperation) -> str: SHA256 hash of the normalized refactoring pattern """ # Combine old_code pattern + operation_type for unique identification - normalized_old = self._normalize_code(operation.old_code) - operation_key = f"{operation.operation_type}:{normalized_old}" + # Use anonymization to group similar structural refactorings + anonymized_old = self._anonymize_code(operation.old_code) + operation_key = f"{operation.operation_type}:{anonymized_old}" combined = operation_key.encode("utf-8") return self._hash_algo(combined).hexdigest() @@ -149,9 +151,9 @@ def _extract_ast_pattern(self, code: str) -> str: # Extract key structural elements if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - pattern_parts.append(f"FUNC:{node.name}") + pattern_parts.append("FUNC") elif isinstance(node, ast.ClassDef): - pattern_parts.append(f"CLASS:{node.name}") + pattern_parts.append("CLASS") elif isinstance(node, ast.If): pattern_parts.append("IF") elif isinstance(node, ast.For): @@ -176,3 +178,43 @@ def _extract_ast_pattern(self, code: str) -> str: except (SyntaxError, ValueError): # If code is invalid, return empty pattern return "" + + def _anonymize_code(self, code: str) -> str: + """ + Normalize and anonymize identifiers to help patterns generalize. + + Replaces specific function/variable names with generic tokens. + """ + try: + tree = ast.parse(code) + + for node in ast.walk(tree): + if isinstance(node, ast.Name): + node.id = "VAR" + elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + node.name = "FUNC" + elif isinstance(node, ast.ClassDef): + node.name = "CLASS" + elif isinstance(node, ast.Attribute): + node.attr = "ATTR" + elif isinstance(node, ast.arg): + node.arg = "ARG" + elif isinstance(node, ast.Constant): + # Anonymize all literals (numbers, strings, etc) + node.value = "CONST" + elif isinstance(node, ast.JoinedStr): + # Anonymize f-strings by replacing constant parts + for i, value in enumerate(node.values): + if isinstance(value, ast.Constant): + node.values[i] = ast.Constant(value="STR") + + # Using ast.unparse if available (Python 3.9+) + if hasattr(ast, "unparse"): + return self._normalize_code(ast.unparse(tree)) + + # Fallback to standard normalization if unparse is not available + return self._normalize_code(code) + + except (SyntaxError, ValueError): + # If AST parsing fails, fallback to basic normalization + return self._normalize_code(code) diff --git a/refactron/rag/__init__.py b/refactron/rag/__init__.py new file mode 100644 index 0000000..3636f0a --- /dev/null +++ b/refactron/rag/__init__.py @@ -0,0 +1,6 @@ +"""RAG (Retrieval-Augmented Generation) infrastructure for code indexing and retrieval.""" + +from refactron.rag.indexer import RAGIndexer +from refactron.rag.retriever import ContextRetriever + +__all__ = ["RAGIndexer", "ContextRetriever"] diff --git a/refactron/rag/chunker.py b/refactron/rag/chunker.py new file mode 100644 index 0000000..5825229 --- /dev/null +++ b/refactron/rag/chunker.py @@ -0,0 +1,179 @@ +"""Code chunking strategies for RAG indexing.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Tuple + +from refactron.rag.parser import CodeParser, ParsedClass, ParsedFile, ParsedFunction + + +@dataclass +class CodeChunk: + """Represents a semantic chunk of code.""" + + content: str + chunk_type: str # "function", "class", "module" + file_path: str + line_range: Tuple[int, int] + name: str + dependencies: List[str] + metadata: Dict[str, Any] + + +class CodeChunker: + """Chunks parsed code into semantic units for embedding.""" + + def __init__(self, parser: CodeParser): + """Initialize the chunker. + + Args: + parser: CodeParser instance for parsing files + """ + self.parser = parser + + def chunk_file(self, file_path: Path) -> List[CodeChunk]: + """Chunk a file into semantic units. + + Args: + file_path: Path to the Python file + + Returns: + List of code chunks + """ + # Parse the file first + parsed_file = self.parser.parse_file(file_path) + + chunks = [] + + # Add module-level chunk if there's a docstring or imports + if parsed_file.module_docstring or parsed_file.imports: + chunks.append(self._create_module_chunk(parsed_file)) + + # Add function chunks + for func in parsed_file.functions: + chunks.append(self._create_function_chunk(func, parsed_file.file_path)) + + # Add class chunks + for cls in parsed_file.classes: + chunks.extend(self._create_class_chunks(cls, parsed_file.file_path)) + + return chunks + + def _create_module_chunk(self, parsed_file: ParsedFile) -> CodeChunk: + """Create a chunk for module-level information.""" + content_parts = [] + + if parsed_file.module_docstring: + content_parts.append(f'"""{parsed_file.module_docstring}"""') + + if parsed_file.imports: + content_parts.append("\n".join(parsed_file.imports)) + + # Add context header + header = f"File: {parsed_file.file_path}, Type: Module" + content = header + "\n" + "-" * len(header) + "\n\n" + "\n\n".join(content_parts) + + return CodeChunk( + content=content, + chunk_type="module", + file_path=parsed_file.file_path, + line_range=(1, len(parsed_file.imports) + 1), + name="module", + dependencies=parsed_file.imports, + metadata={ + "docstring": parsed_file.module_docstring, + "num_imports": len(parsed_file.imports), + }, + ) + + def _create_function_chunk(self, func: ParsedFunction, file_path: str) -> CodeChunk: + """Create a chunk for a function.""" + # Build content with context + content_parts = [] + + if func.docstring: + content_parts.append(f'"""{func.docstring}"""') + + content_parts.append(func.body) + + # Add context header + header = f"File: {file_path}, Function: {func.name}" + content = header + "\n" + "-" * len(header) + "\n\n" + "\n".join(content_parts) + + return CodeChunk( + content=content, + chunk_type="function", + file_path=file_path, + line_range=func.line_range, + name=func.name, + dependencies=[], # Could extract from body if needed + metadata={ + "docstring": func.docstring, + "params": func.params, + "num_params": len(func.params), + }, + ) + + def _create_class_chunks(self, cls: ParsedClass, file_path: str) -> List[CodeChunk]: + """Create chunks for a class and its methods.""" + chunks = [] + + # Create class overview chunk + class_content_parts = [] + + if cls.docstring: + class_content_parts.append(f'"""{cls.docstring}"""') + + # Include class signature and docstring + class_signature = cls.body.split("\n")[0] # First line + class_content_parts.append(class_signature) + + # Add context header + header = f"File: {file_path}, Class: {cls.name}" + content = header + "\n" + "-" * len(header) + "\n\n" + "\n".join(class_content_parts) + + class_chunk = CodeChunk( + content=content, + chunk_type="class", + file_path=file_path, + line_range=cls.line_range, + name=cls.name, + dependencies=[], + metadata={ + "docstring": cls.docstring, + "num_methods": len(cls.methods), + "methods": [m.name for m in cls.methods], + }, + ) + chunks.append(class_chunk) + + # Create chunks for each method + for method in cls.methods: + # Add context header specifically for methods + header = f"File: {file_path}, Class: {cls.name}, Method: {method.name}" + content_parts = [] + if method.docstring: + content_parts.append(f'"""{method.docstring}"""') + content_parts.append(method.body) + + content = header + "\n" + "-" * len(header) + "\n\n" + "\n".join(content_parts) + + method_chunk = CodeChunk( + content=content, + chunk_type="method", + file_path=file_path, + line_range=method.line_range, + name=method.name, + dependencies=[], + metadata={ + "docstring": method.docstring, + "params": method.params, + "num_params": len(method.params), + "class_name": cls.name, + }, + ) + chunks.append(method_chunk) + + return chunks diff --git a/refactron/rag/indexer.py b/refactron/rag/indexer.py new file mode 100644 index 0000000..a06e8e2 --- /dev/null +++ b/refactron/rag/indexer.py @@ -0,0 +1,281 @@ +"""Vector index management using ChromaDB.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional + +try: + import chromadb + from chromadb.config import Settings + from sentence_transformers import SentenceTransformer + + CHROMA_AVAILABLE = True +except ImportError: + chromadb = None + Settings = None + SentenceTransformer = None + CHROMA_AVAILABLE = False + +from refactron.rag.chunker import CodeChunk +from refactron.rag.parser import CodeParser, ParsedFile + + +@dataclass +class IndexStats: + """Statistics about the RAG index.""" + + total_chunks: int + total_files: int + chunk_types: dict + embedding_model: str + index_path: str + + +class RAGIndexer: + """Manages code indexing for RAG retrieval.""" + + def __init__( + self, + workspace_path: Path, + embedding_model: str = "all-MiniLM-L6-v2", + collection_name: str = "code_chunks", + llm_client: Optional[GroqClient] = None, + ): + """Initialize the RAG indexer. + + Args: + workspace_path: Path to the workspace directory + embedding_model: Name of the sentence-transformers model + collection_name: Name of the ChromaDB collection + llm_client: Optional LLM client for code summarization + """ + if not CHROMA_AVAILABLE: + raise RuntimeError( + "ChromaDB is not available. Install with: pip install chromadb sentence-transformers" + ) + + self.workspace_path = Path(workspace_path) + self.index_path = self.workspace_path / ".rag" + self.index_path.mkdir(exist_ok=True) + + # Initialize LLM client for summarization + from refactron.llm.client import GroqClient + + self.llm_client = llm_client + + # Initialize embedding model + self.embedding_model_name = embedding_model + self.embedding_model = SentenceTransformer(embedding_model) + + # Initialize ChromaDB + self.client = chromadb.PersistentClient( + path=str(self.index_path / "chroma"), settings=Settings(anonymized_telemetry=False) + ) + + # Get or create collection + self.collection = self.client.get_or_create_collection( + name=collection_name, + metadata={"embedding_model": embedding_model, "hnsw:space": "cosine"}, + ) + + self.parser = CodeParser() + + def index_repository( + self, repo_path: Optional[Path] = None, summarize: bool = False + ) -> IndexStats: + """Index an entire repository. + + Args: + repo_path: Path to repository (defaults to workspace_path) + summarize: Whether to use AI to summarize code for better retrieval + + Returns: + Statistics about the indexed content + """ + if summarize and not self.llm_client: + from refactron.llm.client import GroqClient + + try: + self.llm_client = GroqClient() + except Exception as e: + print(f"Warning: Could not initialize AI for summarization: {e}") + summarize = False + + if repo_path is None: + repo_path = self.workspace_path + + repo_path = Path(repo_path) + + # Find all Python files + python_files = list(repo_path.rglob("*.py")) + + # Filter out common excluded directories + excluded_dirs = {".git", ".rag", "__pycache__", "venv", ".venv", "env", "node_modules"} + python_files = [ + f for f in python_files if not any(excluded in f.parts for excluded in excluded_dirs) + ] + + total_chunks = 0 + chunk_type_counts = {} + + # Index each file + for py_file in python_files: + try: + chunks = self._index_file(py_file, summarize=summarize) + total_chunks += len(chunks) + + # Count chunk types + for chunk in chunks: + chunk_type_counts[chunk.chunk_type] = ( + chunk_type_counts.get(chunk.chunk_type, 0) + 1 + ) + except Exception as e: + # Skip files that can't be parsed + print(f"Warning: Could not index {py_file}: {e}") + continue + + # Save index metadata + self._save_metadata( + { + "total_chunks": total_chunks, + "total_files": len(python_files), + "chunk_types": chunk_type_counts, + } + ) + + return IndexStats( + total_chunks=total_chunks, + total_files=len(python_files), + chunk_types=chunk_type_counts, + embedding_model=self.embedding_model_name, + index_path=str(self.index_path), + ) + + def _index_file(self, file_path: Path, summarize: bool = False) -> List[CodeChunk]: + """Index a single Python file. + + Args: + file_path: Path to the Python file + summarize: Whether to use AI to summarize chunks + + Returns: + List of code chunks that were indexed + """ + from refactron.rag.chunker import CodeChunker + + # Chunk the file (parser is called inside chunker) + chunker = CodeChunker(self.parser) + chunks = chunker.chunk_file(file_path) + + if summarize and self.llm_client: + for chunk in chunks: + try: + summary = self._summarize_chunk(chunk) + if summary: + # Prepend summary to content for embedding (makes it searchable by plain English) + chunk.content = f"Summary: {summary}\n\n{chunk.content}" + chunk.metadata["ai_summary"] = summary + except Exception as e: + print(f"Warning: AI summarization failed for chunk in {file_path}: {e}") + + # Add chunks to the index + self.add_chunks(chunks) + + return chunks + + def _summarize_chunk(self, chunk: CodeChunk) -> Optional[str]: + """Use AI to generate a brief semantic summary of a code chunk.""" + if not self.llm_client: + return None + + prompt = ( + "Analyze the following Python code snippet and provide a one-sentence " + "summary of its purpose, focusing on what it DOES (e.g. 'Calculates user permissions' " + "or 'Handles secure database connections').\n\n" + f"Code:\n{chunk.content}" + ) + + try: + summary = self.llm_client.generate( + prompt=prompt, + system="You are a senior software architect. Provide a concise, semantic summary of code purpose.", + max_tokens=100, + ) + return summary.strip() + except Exception: + return None + + def add_chunks(self, chunks: List[CodeChunk]) -> None: + """Add code chunks to the vector index. + + Args: + chunks: List of code chunks to add + """ + if not chunks: + return + + # Prepare data for ChromaDB + documents = [chunk.content for chunk in chunks] + metadatas = [] + for chunk in chunks: + metadata = { + "chunk_type": chunk.chunk_type, + "file_path": chunk.file_path, + "name": chunk.name, + "line_start": chunk.line_range[0], + "line_end": chunk.line_range[1], + } + # Add chunk metadata, converting lists/dicts to JSON strings + for key, value in chunk.metadata.items(): + if isinstance(value, (list, dict)): + metadata[key] = json.dumps(value) + elif value is not None: # Skip None values + metadata[key] = value + metadatas.append(metadata) + + ids = [f"{chunk.file_path}:{chunk.line_range[0]}-{chunk.line_range[1]}" for chunk in chunks] + + # Generate embeddings + embeddings = self.embedding_model.encode(documents, show_progress_bar=False).tolist() + + # Add to collection + self.collection.add( + documents=documents, + metadatas=metadatas, + ids=ids, + embeddings=embeddings, + ) + + def get_stats(self) -> IndexStats: + """Get statistics about the current index. + + Returns: + Index statistics + """ + metadata = self._load_metadata() + + return IndexStats( + total_chunks=metadata.get("total_chunks", 0), + total_files=metadata.get("total_files", 0), + chunk_types=metadata.get("chunk_types", {}), + embedding_model=self.embedding_model_name, + index_path=str(self.index_path), + ) + + def _save_metadata(self, metadata: dict) -> None: + """Save index metadata.""" + metadata_file = self.index_path / "metadata.json" + with open(metadata_file, "w") as f: + json.dump(metadata, f, indent=2) + + def _load_metadata(self) -> dict: + """Load index metadata.""" + metadata_file = self.index_path / "metadata.json" + if not metadata_file.exists(): + return {} + + with open(metadata_file, "r") as f: + return json.load(f) diff --git a/refactron/rag/parser.py b/refactron/rag/parser.py new file mode 100644 index 0000000..942b7a8 --- /dev/null +++ b/refactron/rag/parser.py @@ -0,0 +1,253 @@ +"""Code parser using tree-sitter for AST-aware code analysis.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +try: + from tree_sitter import Language, Parser, Node + import tree_sitter_python as tspython + + TREE_SITTER_AVAILABLE = True +except ImportError: + TREE_SITTER_AVAILABLE = False + + +@dataclass +class ParsedFunction: + """Represents a parsed function.""" + + name: str + body: str + docstring: Optional[str] + line_range: Tuple[int, int] + params: List[str] + + +@dataclass +class ParsedClass: + """Represents a parsed class.""" + + name: str + body: str + docstring: Optional[str] + line_range: Tuple[int, int] + methods: List[ParsedFunction] + + +@dataclass +class ParsedFile: + """Represents a parsed Python file.""" + + file_path: str + imports: List[str] + functions: List[ParsedFunction] + classes: List[ParsedClass] + module_docstring: Optional[str] + + +class CodeParser: + """AST-aware code parser using tree-sitter.""" + + def __init__(self): + """Initialize the parser.""" + if not TREE_SITTER_AVAILABLE: + raise RuntimeError( + "tree-sitter is not available. Install with: pip install tree-sitter tree-sitter-python" + ) + + # Initialize Python language - wrap capsule with Language + PY_LANGUAGE = Language(tspython.language()) + self.parser = Parser(PY_LANGUAGE) + + def parse_file(self, file_path: Path) -> ParsedFile: + """Parse a Python file. + + Args: + file_path: Path to the Python file + + Returns: + ParsedFile object containing all parsed elements + """ + with open(file_path, "rb") as f: + source_code = f.read() + + tree = self.parser.parse(source_code) + root = tree.root_node + + # Extract module docstring + module_docstring = self._extract_module_docstring(root, source_code) + + # Extract imports + imports = self._extract_imports(root, source_code) + + # Extract functions + functions = self._extract_functions(root, source_code) + + # Extract classes + classes = self._extract_classes(root, source_code) + + return ParsedFile( + file_path=str(file_path), + imports=imports, + functions=functions, + classes=classes, + module_docstring=module_docstring, + ) + + def _extract_module_docstring(self, root: Node, source: bytes) -> Optional[str]: + """Extract module-level docstring.""" + for child in root.children: + if child.type == "expression_statement": + string_node = child.children[0] if child.children else None + if string_node and string_node.type == "string": + return ( + source[string_node.start_byte : string_node.end_byte] + .decode("utf-8") + .strip("\"'") + ) + return None + + def _extract_imports(self, root: Node, source: bytes) -> List[str]: + """Extract import statements.""" + imports = [] + for node in root.children: + if node.type in ("import_statement", "import_from_statement"): + import_text = source[node.start_byte : node.end_byte].decode("utf-8") + imports.append(import_text) + return imports + + def _extract_functions(self, root: Node, source: bytes) -> List[ParsedFunction]: + """Extract function definitions.""" + functions = [] + for node in root.children: + if node.type == "function_definition": + func = self._parse_function(node, source) + if func: + functions.append(func) + return functions + + def _extract_classes(self, root: Node, source: bytes) -> List[ParsedClass]: + """Extract class definitions.""" + classes = [] + for node in root.children: + if node.type == "class_definition": + cls = self._parse_class(node, source) + if cls: + classes.append(cls) + return classes + + def _parse_function(self, node: Node, source: bytes) -> Optional[ParsedFunction]: + """Parse a function node.""" + # Get function name + name_node = node.child_by_field_name("name") + if not name_node: + return None + + name = source[name_node.start_byte : name_node.end_byte].decode("utf-8") + + # Get function body + body = source[node.start_byte : node.end_byte].decode("utf-8") + + # Get docstring + docstring = self._extract_function_docstring(node, source) + + # Get line range + line_range = (node.start_point[0] + 1, node.end_point[0] + 1) + + # Get parameters + params = self._extract_parameters(node, source) + + return ParsedFunction( + name=name, + body=body, + docstring=docstring, + line_range=line_range, + params=params, + ) + + def _parse_class(self, node: Node, source: bytes) -> Optional[ParsedClass]: + """Parse a class node.""" + # Get class name + name_node = node.child_by_field_name("name") + if not name_node: + return None + + name = source[name_node.start_byte : name_node.end_byte].decode("utf-8") + + # Get class body + body = source[node.start_byte : node.end_byte].decode("utf-8") + + # Get docstring + docstring = self._extract_class_docstring(node, source) + + # Get line range + line_range = (node.start_point[0] + 1, node.end_point[0] + 1) + + # Extract methods from class body + methods = [] + body_node = node.child_by_field_name("body") + if body_node: + for child in body_node.children: + if child.type == "function_definition": + method = self._parse_function(child, source) + if method: + methods.append(method) + + return ParsedClass( + name=name, + body=body, + docstring=docstring, + line_range=line_range, + methods=methods, + ) + + def _extract_function_docstring(self, node: Node, source: bytes) -> Optional[str]: + """Extract function docstring.""" + body_node = node.child_by_field_name("body") + if not body_node: + return None + + for child in body_node.children: + if child.type == "expression_statement": + string_node = child.children[0] if child.children else None + if string_node and string_node.type == "string": + return ( + source[string_node.start_byte : string_node.end_byte] + .decode("utf-8") + .strip("\"'") + ) + return None + + def _extract_class_docstring(self, node: Node, source: bytes) -> Optional[str]: + """Extract class docstring.""" + body_node = node.child_by_field_name("body") + if not body_node: + return None + + for child in body_node.children: + if child.type == "expression_statement": + string_node = child.children[0] if child.children else None + if string_node and string_node.type == "string": + return ( + source[string_node.start_byte : string_node.end_byte] + .decode("utf-8") + .strip("\"'") + ) + return None + + def _extract_parameters(self, node: Node, source: bytes) -> List[str]: + """Extract function parameters.""" + params = [] + params_node = node.child_by_field_name("parameters") + if not params_node: + return params + + for child in params_node.children: + if child.type == "identifier": + param_name = source[child.start_byte : child.end_byte].decode("utf-8") + params.append(param_name) + + return params diff --git a/refactron/rag/retriever.py b/refactron/rag/retriever.py new file mode 100644 index 0000000..2a71d1b --- /dev/null +++ b/refactron/rag/retriever.py @@ -0,0 +1,178 @@ +"""Context retrieval from the RAG index.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional + +try: + import chromadb + from chromadb.config import Settings + from sentence_transformers import SentenceTransformer + + CHROMA_AVAILABLE = True +except ImportError: + chromadb = None + Settings = None + SentenceTransformer = None + CHROMA_AVAILABLE = False + + +@dataclass +class RetrievedContext: + """Represents a retrieved code context.""" + + content: str + file_path: str + chunk_type: str + name: str + line_range: tuple + distance: float # Similarity distance + metadata: dict + + +class ContextRetriever: + """Retrieves relevant code context from the RAG index.""" + + def __init__( + self, + workspace_path: Path, + embedding_model: str = "all-MiniLM-L6-v2", + collection_name: str = "code_chunks", + ): + """Initialize the context retriever. + + Args: + workspace_path: Path to the workspace directory + embedding_model: Name of the sentence-transformers model + collection_name: Name of the ChromaDB collection + """ + if not CHROMA_AVAILABLE: + raise RuntimeError( + "ChromaDB is not available. Install with: pip install chromadb sentence-transformers" + ) + + self.workspace_path = Path(workspace_path) + self.index_path = self.workspace_path / ".rag" + + # Initialize embedding model + self.embedding_model = SentenceTransformer(embedding_model) + + # Initialize ChromaDB client + self.client = chromadb.PersistentClient( + path=str(self.index_path / "chroma"), settings=Settings(anonymized_telemetry=False) + ) + + # Get collection + try: + self.collection = self.client.get_collection(name=collection_name) + except Exception: + raise RuntimeError( + f"Index not found at {self.index_path}. Run 'refactron rag index' first." + ) + + def retrieve_similar( + self, query: str, top_k: int = 5, chunk_type: Optional[str] = None + ) -> List[RetrievedContext]: + """Retrieve similar code chunks. + + Args: + query: The search query + top_k: Number of results to return + chunk_type: Optional filter by chunk type (function/class/module) + + Returns: + List of retrieved contexts sorted by relevance + """ + # Generate query embedding + query_embedding = self.embedding_model.encode([query], show_progress_bar=False).tolist()[0] + + # Build query filters + where_filter = {} + if chunk_type: + where_filter["chunk_type"] = chunk_type + + # Search in ChromaDB + results = self.collection.query( + query_embeddings=[query_embedding], + n_results=top_k, + where=where_filter if where_filter else None, + ) + + # Parse results + contexts = [] + if results and results["documents"] and results["documents"][0]: + for i, doc in enumerate(results["documents"][0]): + metadata = results["metadatas"][0][i] + distance = results["distances"][0][i] if results["distances"] else 0.0 + + contexts.append( + RetrievedContext( + content=doc, + file_path=metadata.get("file_path", ""), + chunk_type=metadata.get("chunk_type", ""), + name=metadata.get("name", ""), + line_range=(metadata.get("line_start", 0), metadata.get("line_end", 0)), + distance=distance, + metadata=metadata, + ) + ) + + return contexts + + def retrieve_by_file(self, file_path: str) -> List[RetrievedContext]: + """Retrieve all chunks from a specific file. + + Args: + file_path: Path to the file + + Returns: + List of all chunks from the file + """ + # Query by file path metadata + results = self.collection.get(where={"file_path": file_path}) + + # Parse results + contexts = [] + if results and results["documents"]: + for i, doc in enumerate(results["documents"]): + metadata = results["metadatas"][i] + + contexts.append( + RetrievedContext( + content=doc, + file_path=metadata.get("file_path", ""), + chunk_type=metadata.get("chunk_type", ""), + name=metadata.get("name", ""), + line_range=(metadata.get("line_start", 0), metadata.get("line_end", 0)), + distance=0.0, # Not applicable for exact match + metadata=metadata, + ) + ) + + return contexts + + def retrieve_functions(self, query: str, top_k: int = 5) -> List[RetrievedContext]: + """Retrieve similar functions. + + Args: + query: The search query + top_k: Number of results to return + + Returns: + List of similar function chunks + """ + return self.retrieve_similar(query, top_k=top_k, chunk_type="function") + + def retrieve_classes(self, query: str, top_k: int = 5) -> List[RetrievedContext]: + """Retrieve similar classes. + + Args: + query: The search query + top_k: Number of results to return + + Returns: + List of similar class chunks + """ + return self.retrieve_similar(query, top_k=top_k, chunk_type="class") diff --git a/scripts/analyze_feedback_data.py b/scripts/analyze_feedback_data.py new file mode 100644 index 0000000..92cac2b --- /dev/null +++ b/scripts/analyze_feedback_data.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 +"""Analyze existing feedback and pattern data for ML model training. + +This script scans all Refactron storage directories to assess: +- Volume of feedback records +- Quality of data (completeness) +- Distribution of actions and operation types +- Readiness for ML training +""" + +from pathlib import Path +from collections import Counter +import json +import sys + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from refactron.patterns.storage import PatternStorage + + +def analyze_feedback(): + """Analyze all available feedback data.""" + + # Find all pattern storage directories + root = Path('.') + storage_dirs = list(root.glob('**/.refactron/patterns')) + + print(f"Found {len(storage_dirs)} storage directories\n") + + all_feedback = [] + all_patterns = {} + + # Aggregate data from all projects + for storage_dir in storage_dirs: + try: + storage = PatternStorage(storage_dir) + feedback = storage.load_feedback() + patterns = storage.load_patterns() + + all_feedback.extend(feedback) + all_patterns.update(patterns) + + print(f"šŸ“ {storage_dir.parent.parent}") + print(f" Feedback: {len(feedback)}, Patterns: {len(patterns)}") + except Exception as e: + print(f"āš ļø Error loading {storage_dir}: {e}") + + if not all_feedback: + print("\nāŒ No feedback data found!") + print(" Run some refactoring operations and provide feedback first.") + return None + + print(f"\n{'='*60}") + print(f"AGGREGATE STATISTICS") + print(f"{'='*60}\n") + + print(f"šŸ“Š Total Records:") + print(f" Feedback: {len(all_feedback)}") + print(f" Patterns: {len(all_patterns)}") + + # Action distribution + actions = Counter(f.action for f in all_feedback) + print(f"\nāœ… Action Distribution:") + for action, count in actions.most_common(): + pct = count / len(all_feedback) * 100 + print(f" {action:12s}: {count:4d} ({pct:5.1f}%)") + + # Operation types + operation_types = Counter(f.operation_type for f in all_feedback) + print(f"\nšŸ”§ Operation Types:") + for op_type, count in operation_types.most_common(5): + pct = count / len(all_feedback) * 100 + print(f" {op_type:20s}: {count:4d} ({pct:5.1f}%)") + + # Data quality + with_patterns = sum(1 for f in all_feedback if hasattr(f, 'code_pattern_hash') and f.code_pattern_hash) + with_reason = sum(1 for f in all_feedback if hasattr(f, 'reason') and f.reason) + + print(f"\nšŸ“‹ Data Quality:") + print(f" With pattern hash: {with_patterns:4d} ({with_patterns/len(all_feedback)*100:5.1f}%)") + print(f" With reason: {with_reason:4d} ({with_reason/len(all_feedback)*100:5.1f}%)") + + # ML readiness + quality_score = with_patterns / len(all_feedback) if all_feedback else 0 + + print(f"\nšŸŽÆ ML Readiness:") + print(f" Quality Score: {quality_score:.2%}") + + if len(all_feedback) < 50: + print(f" Status: āŒ INSUFFICIENT DATA") + print(f" Need: {50 - len(all_feedback)} more feedback records") + elif quality_score < 0.7: + print(f" Status: āš ļø LOW QUALITY") + print(f" Many records missing pattern hashes") + else: + print(f" Status: āœ… READY FOR TRAINING") + + # Save detailed report + report = { + 'summary': { + 'total_feedback': len(all_feedback), + 'total_patterns': len(all_patterns), + 'quality_score': quality_score, + 'ml_ready': len(all_feedback) >= 50 and quality_score >= 0.7 + }, + 'actions': dict(actions), + 'operation_types': dict(operation_types), + 'quality': { + 'with_pattern_hash': with_patterns, + 'with_reason': with_reason + } + } + + report_file = Path('feedback_analysis.json') + with open(report_file, 'w') as f: + json.dump(report, f, indent=2) + + print(f"\nšŸ’¾ Detailed report saved to: {report_file}") + + return report + + +if __name__ == '__main__': + print("šŸ” Refactron Feedback Data Analysis\n") + report = analyze_feedback() + + if report and report['summary']['ml_ready']: + print("\n✨ Ready to proceed with ML model training!") + elif report: + print("\nā³ Collect more feedback data before training.") + + sys.exit(0 if report else 1) diff --git a/tests/test_analyzer_edge_cases.py b/tests/test_analyzer_edge_cases.py index b5d5b96..8cf4f26 100644 --- a/tests/test_analyzer_edge_cases.py +++ b/tests/test_analyzer_edge_cases.py @@ -781,7 +781,7 @@ def run(cmd): subprocess.Popen(cmd, shell=True) """ issues = analyzer.analyze(Path("test.py"), code) - assert any(issue.rule_id == "SEC005" for issue in issues) + assert any(issue.rule_id == "SEC0052" for issue in issues) def test_yaml_safe_load_not_flagged(self) -> None: """Test that yaml.safe_load is not flagged.""" diff --git a/tests/test_backend_client.py b/tests/test_backend_client.py new file mode 100644 index 0000000..fec16a8 --- /dev/null +++ b/tests/test_backend_client.py @@ -0,0 +1,73 @@ +"""Tests for BackendLLMClient.""" + +import pytest +import requests +from unittest.mock import MagicMock, patch + +from refactron.llm.backend_client import BackendLLMClient + + +@pytest.fixture +def mock_credentials(): + with patch("refactron.llm.backend_client.load_credentials") as mock: + creds = MagicMock() + creds.api_key = "test-api-key" + creds.access_token = "test-access-token" + mock.return_value = creds + yield mock + + +def test_backend_client_initialization(): + client = BackendLLMClient(backend_url="http://test-backend:3000") + assert client.backend_url == "http://test-backend:3000" + assert client.model == "llama-3.3-70b-versatile" + + +@patch("requests.post") +def test_backend_client_generate(mock_post, mock_credentials): + # Mock successful response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"content": "Suggested code"} + mock_post.return_value = mock_response + + client = BackendLLMClient(backend_url="http://test-backend:3000") + result = client.generate(prompt="Refactor this", system="You are an expert") + + assert result == "Suggested code" + + # Verify request + mock_post.assert_called_once() + args, kwargs = mock_post.call_args + assert args[0] == "http://test-backend:3000/api/llm/generate" + assert kwargs["json"]["prompt"] == "Refactor this" + assert kwargs["json"]["system"] == "You are an expert" + assert kwargs["headers"]["X-API-Key"] == "test-api-key" + + +@patch("requests.post") +def test_backend_client_error_handling(mock_post, mock_credentials): + # Mock error response + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + mock_post.return_value = mock_response + + client = BackendLLMClient() + with pytest.raises(RuntimeError, match="Backend LLM proxy error \(500\): Internal Server Error"): + client.generate(prompt="Refactor this") + + +@patch("requests.get") +def test_backend_client_check_health(mock_get): + # Mock healthy response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_get.return_value = mock_response + + client = BackendLLMClient() + assert client.check_health() is True + + # Mock unhealthy response + mock_response.status_code = 503 + assert client.check_health() is False diff --git a/tests/test_false_positive_reduction.py b/tests/test_false_positive_reduction.py index 441a0cd..43a9883 100644 --- a/tests/test_false_positive_reduction.py +++ b/tests/test_false_positive_reduction.py @@ -222,7 +222,7 @@ def run_command(cmd): issues = analyzer.analyze(Path("myapp/utils.py"), code) # shell=True has high confidence (0.95), should be kept - assert len([i for i in issues if i.rule_id == "SEC005"]) > 0 + assert len([i for i in issues if i.rule_id == "SEC0052"]) > 0 def test_default_min_confidence(self): """Default minimum confidence should be 0.5.""" @@ -345,4 +345,4 @@ def run(cmd): dangerous_issues = analyzer.analyze(Path("myapp/runner.py"), dangerous_code) # shell=True has high confidence, should be reported assert len(dangerous_issues) > 0 - assert any(i.rule_id == "SEC005" for i in dangerous_issues) + assert any(i.rule_id == "SEC0052" for i in dangerous_issues) diff --git a/tests/test_groq_client.py b/tests/test_groq_client.py new file mode 100644 index 0000000..2ac4ca8 --- /dev/null +++ b/tests/test_groq_client.py @@ -0,0 +1,110 @@ +"""Tests for the Groq LLM client.""" + +import os +import pytest +from unittest.mock import Mock, patch + +from refactron.llm.client import GroqClient + + +class TestGroqClient: + """Test cases for GroqClient.""" + + @pytest.fixture + def mock_groq_api(self): + """Mock the Groq API.""" + with patch('refactron.llm.client.Groq') as mock_groq: + # Mock response + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message.content = "Generated text response" + + # Mock client + mock_client = Mock() + mock_client.chat.completions.create.return_value = mock_response + mock_groq.return_value = mock_client + + yield mock_groq + + def test_client_initialization_with_api_key(self, mock_groq_api): + """Test client initialization with explicit API key.""" + client = GroqClient(api_key="test_key_123") + + assert client.api_key == "test_key_123" + assert client.model == "llama-3.3-70b-versatile" + assert client.temperature == 0.2 + assert client.max_tokens == 2000 + + def test_client_initialization_from_env(self, mock_groq_api): + """Test client initialization from environment variable.""" + with patch.dict(os.environ, {'GROQ_API_KEY': 'env_key_456'}): + client = GroqClient() + assert client.api_key == 'env_key_456' + + def test_client_initialization_no_api_key(self, mock_groq_api): + """Test that missing API key raises error.""" + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(RuntimeError, match="GROQ_API_KEY"): + GroqClient() + + def test_generate_basic(self, mock_groq_api): + """Test basic text generation.""" + client = GroqClient(api_key="testkey") + response = client.generate("Hello, how are you?") + + assert response == "Generated text response" + assert client.client.chat.completions.create.called + + def test_generate_with_system_prompt(self, mock_groq_api): + """Test generation with system prompt.""" + client = GroqClient(api_key="testkey") + response = client.generate( + prompt="User message", + system="You are a helpful assistant" + ) + + assert response == "Generated text response" + + # Check that both system and user messages were sent + call_args = client.client.chat.completions.create.call_args + messages = call_args.kwargs['messages'] + assert len(messages) == 2 + assert messages[0]['role'] == 'system' + assert messages[1]['role'] == 'user' + + def test_generate_with_custom_params(self, mock_groq_api): + """Test generation with custom temperature and max_tokens.""" + client = GroqClient(api_key="testkey") + client.generate( + prompt="Test", + temperature=0.5, + max_tokens=1000 + ) + + call_args = client.client.chat.completions.create.call_args + assert call_args.kwargs['temperature'] == 0.5 + assert call_args.kwargs['max_tokens'] == 1000 + + def test_custom_model(self, mock_groq_api): + """Test initialization with custom model.""" + client = GroqClient( + api_key="testkey", + model="llama-3.1-8b-instant" + ) + + assert client.model == "llama-3.1-8b-instant" + + def test_check_health_success(self, mock_groq_api): + """Test health check when API is accessible.""" + client = GroqClient(api_key="testkey") + is_healthy = client.check_health() + + assert is_healthy is True + + def test_check_health_failure(self, mock_groq_api): + """Test health check when API fails.""" + client = GroqClient(api_key="testkey") + client.client.chat.completions.create.side_effect = Exception("API Error") + + is_healthy = client.check_health() + assert is_healthy is False diff --git a/tests/test_llm_orchestrator.py b/tests/test_llm_orchestrator.py new file mode 100644 index 0000000..381187c --- /dev/null +++ b/tests/test_llm_orchestrator.py @@ -0,0 +1,145 @@ +"""Tests for LLM Orchestrator.""" + +import json +from unittest.mock import Mock, MagicMock +from pathlib import Path + +import pytest +from refactron.core.models import CodeIssue, IssueLevel, IssueCategory +from refactron.llm.orchestrator import LLMOrchestrator +from refactron.llm.models import SuggestionStatus, RefactoringSuggestion +from refactron.rag.retriever import RetrievedContext + + +class TestLLMOrchestrator: + """Tests for LLMOrchestrator.""" + + @pytest.fixture + def mock_retriever(self): + retriever = Mock() + retriever.retrieve_similar.return_value = [] + return retriever + + @pytest.fixture + def mock_client(self): + client = Mock() + client.model = "llama-3.3-70b-versatile" + # Return valid JSON in code block + client.generate.return_value = """ +```json +{ + "proposed_code": "def fixed(): pass", + "explanation": "Fixed the bug", + "reasoning": "Because it was broken" +} +``` +""" + return client + + @pytest.fixture + def mock_safety(self): + gate = Mock() + result = MagicMock() + result.passed = True + result.issues = [] + gate.validate.return_value = result + return gate + + @pytest.fixture + def sample_issue(self): + return CodeIssue( + category=IssueCategory.CODE_SMELL, + level=IssueLevel.WARNING, + message="Function is too long", + file_path=Path("/test.py"), + line_number=10 + ) + + def test_generate_suggestion_basic(self, mock_client, mock_safety, sample_issue): + """Test generating a successful suggestion.""" + orchestrator = LLMOrchestrator( + llm_client=mock_client, + safety_gate=mock_safety + ) + + suggestion = orchestrator.generate_suggestion( + issue=sample_issue, + original_code="def broken(): pass" + ) + + assert suggestion.status == SuggestionStatus.PENDING + assert suggestion.proposed_code == "def fixed(): pass" + assert suggestion.explanation == "Fixed the bug" + assert suggestion.model_name == "llama-3.3-70b-versatile" + + # Verify prompt construction (implicity) by checking generate call + mock_client.generate.assert_called_once() + args = mock_client.generate.call_args + assert "Function is too long" in args.kwargs['prompt'] + assert "def broken(): pass" in args.kwargs['prompt'] + + def test_generate_suggestion_with_rag(self, mock_client, mock_retriever, sample_issue): + """Test generation with RAG context.""" + # Setup retriever to return context + mock_retriever.retrieve_similar.return_value = [ + RetrievedContext( + content="def similar_func(): pass", + file_path="/similar.py", + chunk_type="function", + name="similar_func", + line_range=(1, 2), + distance=0.1, + metadata={} + ) + ] + + orchestrator = LLMOrchestrator( + retriever=mock_retriever, + llm_client=mock_client + ) + + suggestion = orchestrator.generate_suggestion( + issue=sample_issue, + original_code="original" + ) + + assert "/similar.py" in suggestion.context_files + + # Verify prompt contains RAG context + args = mock_client.generate.call_args + assert "def similar_func(): pass" in args.kwargs['prompt'] + + def test_bad_llm_response(self, mock_client, sample_issue): + """Test handling of invalid JSON from LLM.""" + mock_client.generate.return_value = "This is not JSON" + + orchestrator = LLMOrchestrator(llm_client=mock_client) + + suggestion = orchestrator.generate_suggestion( + issue=sample_issue, + original_code="original" + ) + + assert suggestion.status == SuggestionStatus.FAILED + assert "JSONDecodeError" in suggestion.explanation or "Expecting value" in suggestion.explanation + + def test_safety_check_failure(self, mock_client, mock_safety, sample_issue): + """Test rejection when safety check fails.""" + # Setup safety failure + result = MagicMock() + result.passed = False + result.issues = ["Syntax Error"] + mock_safety.validate.return_value = result + + orchestrator = LLMOrchestrator( + llm_client=mock_client, + safety_gate=mock_safety + ) + + suggestion = orchestrator.generate_suggestion( + issue=sample_issue, + original_code="original" + ) + + assert suggestion.status == SuggestionStatus.REJECTED + assert not suggestion.safety_result.passed diff --git a/tests/test_patterns_fingerprint.py b/tests/test_patterns_fingerprint.py index 92047eb..6c99a37 100644 --- a/tests/test_patterns_fingerprint.py +++ b/tests/test_patterns_fingerprint.py @@ -56,9 +56,8 @@ def test_fingerprint_code_whitespace_insensitive(self): hash1 = self.fingerprinter.fingerprint_code(code1) hash2 = self.fingerprinter.fingerprint_code(code2) - # Should be different because normalization preserves some structure - # But let's verify the hashing works - assert hash1 != hash2 # Different whitespace patterns produce different hashes + # Should produce same hash because normalization anonymizes code + assert hash1 == hash2 # Comment updated to match actual behavior def test_fingerprint_code_comment_removal(self): """Test that comments are removed before fingerprinting.""" @@ -226,7 +225,7 @@ def test_extract_ast_pattern_function(self): code = "def hello():\n print('world')" pattern = self.fingerprinter._extract_ast_pattern(code) - assert "FUNC:hello" in pattern + assert "FUNC" in pattern assert "CALL" in pattern def test_extract_ast_pattern_class(self): @@ -234,8 +233,8 @@ def test_extract_ast_pattern_class(self): code = "class MyClass:\n def method(self):\n pass" pattern = self.fingerprinter._extract_ast_pattern(code) - assert "CLASS:MyClass" in pattern - assert "FUNC:method" in pattern + assert "CLASS" in pattern + assert "FUNC" in pattern # method function def test_extract_ast_pattern_control_flow(self): """Test AST pattern extraction for control flow.""" diff --git a/tests/test_patterns_integration.py b/tests/test_patterns_integration.py index 3e7f51f..de724b4 100644 --- a/tests/test_patterns_integration.py +++ b/tests/test_patterns_integration.py @@ -262,10 +262,10 @@ def func2(): pattern_hashes1 = {p.pattern_hash for p in patterns1.values()} pattern_hashes2 = {p.pattern_hash for p in patterns2.values()} - # If both projects learned patterns, they should be different (different code) - if pattern_hashes1 and pattern_hashes2: - # Different code should produce different pattern hashes - assert pattern_hashes1 != pattern_hashes2 or len(pattern_hashes1) == 0 + # Verify the key property: storage directories are different (isolation works) + # Note: Anonymized fingerprinting may make structurally similar code have same hash + # which is correct behavior - the test should check storage isolation + assert refactron1.pattern_storage.storage_dir != refactron2.pattern_storage.storage_dir # Verify storage directories are separate (isolation mechanism) assert refactron1.pattern_storage.storage_dir != refactron2.pattern_storage.storage_dir diff --git a/tests/test_patterns_learner.py b/tests/test_patterns_learner.py index 674d9c3..ba663d2 100644 --- a/tests/test_patterns_learner.py +++ b/tests/test_patterns_learner.py @@ -228,8 +228,10 @@ def test_batch_learn_processes_multiple_feedbacks(self): stats = learner.batch_learn(operations_with_feedback) assert stats["processed"] == 3 - assert stats["created"] == 3 # Each has unique pattern - assert stats["updated"] == 0 + # Anonymization makes structurally identical code produce the same hash + # so all 3 will update the same pattern (1 created, 2 updated) + assert stats["created"] == 1 + assert stats["updated"] == 2 assert stats["failed"] == 0 def test_batch_learn_with_none_list(self): diff --git a/tests/test_patterns_ranker.py b/tests/test_patterns_ranker.py index 81cb08e..5352b1d 100644 --- a/tests/test_patterns_ranker.py +++ b/tests/test_patterns_ranker.py @@ -120,25 +120,25 @@ def test_rank_operations_sorts_by_score(self): fingerprinter = PatternFingerprinter() ranker = RefactoringRanker(storage, matcher, fingerprinter) - # Create high-acceptance pattern - pattern_hash1 = fingerprinter.fingerprint_code("def high(): pass") + # Create high-acceptance pattern - with return statement + pattern_hash1 = fingerprinter.fingerprint_code("def high():\n return 42") pattern1 = RefactoringPattern.create( pattern_hash=pattern_hash1, operation_type="extract_method", - code_snippet_before="def high(): pass", - code_snippet_after="def high_refactored(): pass", + code_snippet_before="def high():\n return 42", + code_snippet_after="def high_refactored():\n return 42", ) pattern1.acceptance_rate = 0.9 pattern1.total_occurrences = 20 storage.save_pattern(pattern1) - # Create low-acceptance pattern - pattern_hash2 = fingerprinter.fingerprint_code("def low(): pass") + # Create low-acceptance pattern - with if statement (different structure) + pattern_hash2 = fingerprinter.fingerprint_code("def low():\n if True:\n pass") pattern2 = RefactoringPattern.create( pattern_hash=pattern_hash2, operation_type="extract_method", - code_snippet_before="def low(): pass", - code_snippet_after="def low_refactored(): pass", + code_snippet_before="def low():\n if True:\n pass", + code_snippet_after="def low_refactored():\n if True:\n pass", ) pattern2.acceptance_rate = 0.3 pattern2.total_occurrences = 5 @@ -149,8 +149,8 @@ def test_rank_operations_sorts_by_score(self): file_path=Path("test.py"), line_number=10, description="High acceptance", - old_code="def high(): pass", - new_code="def high_refactored(): pass", + old_code="def high():\n return 42", + new_code="def high_refactored():\n return 42", risk_score=0.2, ) @@ -159,8 +159,8 @@ def test_rank_operations_sorts_by_score(self): file_path=Path("test.py"), line_number=20, description="Low acceptance", - old_code="def low(): pass", - new_code="def low_refactored(): pass", + old_code="def low():\n if True:\n pass", + new_code="def low_refactored():\n if True:\n pass", risk_score=0.2, ) diff --git a/tests/test_rag_chunker.py b/tests/test_rag_chunker.py new file mode 100644 index 0000000..a7dbae9 --- /dev/null +++ b/tests/test_rag_chunker.py @@ -0,0 +1,147 @@ +"""Tests for the RAG chunker module.""" + +import tempfile +from pathlib import Path + +import pytest + +from refactron.rag.chunker import CodeChunker, CodeChunk +from refactron.rag.parser import CodeParser + + +class TestCodeChunker: + """Test cases for CodeChunker.""" + + @pytest.fixture + def parser(self): + """Create a CodeParser instance.""" + return CodeParser() + + @pytest.fixture + def chunker(self, parser): + """Create a CodeChunker instance.""" + return CodeChunker(parser) + + @pytest.fixture + def temp_python_file(self): + """Create a temporary Python file for testing.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + content = '''"""Module level docstring.""" + +import os +import sys + +def test_function(x): + """Test function docstring.""" + return x * 2 + +def another_function(): + """Another function.""" + pass + +class TestClass: + """Test class docstring.""" + + def method_one(self): + """Method docstring.""" + return 1 +''' + f.write(content) + temp_path = Path(f.name) + + yield temp_path + + # Cleanup + temp_path.unlink() + + def test_chunker_initialization(self, parser, chunker): + """Test that chunker initializes with parser.""" + assert chunker.parser is not None + assert chunker.parser == parser + + def test_chunk_file_basic(self, chunker, temp_python_file): + """Test basic file chunking.""" + chunks = chunker.chunk_file(temp_python_file) + + assert len(chunks) > 0 + assert all(isinstance(chunk, CodeChunk) for chunk in chunks) + + def test_module_chunk_created(self, chunker, temp_python_file): + """Test that module chunk is created when there are imports/docstring.""" + chunks = chunker.chunk_file(temp_python_file) + + module_chunks = [c for c in chunks if c.chunk_type == "module"] + assert len(module_chunks) == 1 + + module_chunk = module_chunks[0] + assert "Module level docstring" in module_chunk.content + assert "import os" in module_chunk.content + + def test_function_chunks_created(self, chunker, temp_python_file): + """Test that function chunks are created correctly.""" + chunks = chunker.chunk_file(temp_python_file) + + function_chunks = [c for c in chunks if c.chunk_type == "function"] + assert len(function_chunks) == 2 + + # Check first function chunk + func_names = [c.name for c in function_chunks] + assert "test_function" in func_names + assert "another_function" in func_names + + def test_class_chunks_created(self, chunker, temp_python_file): + """Test that class chunks are created.""" + chunks = chunker.chunk_file(temp_python_file) + + class_chunks = [c for c in chunks if c.chunk_type == "class"] + assert len(class_chunks) == 1 + + class_chunk = class_chunks[0] + assert class_chunk.name == "TestClass" + assert "Test class docstring" in class_chunk.content + + def test_method_chunks_created(self, chunker, temp_python_file): + """Test that method chunks are created.""" + chunks = chunker.chunk_file(temp_python_file) + + method_chunks = [c for c in chunks if c.chunk_type == "method"] + assert len(method_chunks) == 1 + + method_chunk = method_chunks[0] + assert method_chunk.name == "method_one" + assert method_chunk.metadata["class_name"] == "TestClass" + + def test_chunk_metadata(self, chunker, temp_python_file): + """Test that chunk metadata is populated correctly.""" + chunks = chunker.chunk_file(temp_python_file) + + for chunk in chunks: + assert chunk.file_path == str(temp_python_file) + assert chunk.line_range[0] > 0 + assert chunk.line_range[1] >= chunk.line_range[0] + assert chunk.metadata is not None + + def test_empty_file(self, chunker): + """Test chunking an empty file.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write("") + temp_path = Path(f.name) + + try: + chunks = chunker.chunk_file(temp_path) + assert len(chunks) == 0 + finally: + temp_path.unlink() + + def test_file_with_only_imports(self, chunker): + """Test file with only imports.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write("import os\n") + temp_path = Path(f.name) + + try: + chunks = chunker.chunk_file(temp_path) + assert len(chunks) == 1 + assert chunks[0].chunk_type == "module" + finally: + temp_path.unlink() diff --git a/tests/test_rag_indexer.py b/tests/test_rag_indexer.py new file mode 100644 index 0000000..dea5c20 --- /dev/null +++ b/tests/test_rag_indexer.py @@ -0,0 +1,171 @@ +"""Tests for the RAG indexer module.""" + +import tempfile +from pathlib import Path + +import pytest +from unittest.mock import Mock, MagicMock, patch, create_autospec +import sys + +# Create a comprehensive mock for transformers that handles all submodule access +transformers_mock = MagicMock() +transformers_mock.__path__ = [] # Make it look like a package + +# Patch sys.modules before importing anything that depends on sentence-transformers +with patch.dict('sys.modules', { + 'transformers': transformers_mock, + 'transformers.configuration_utils': MagicMock(), + 'transformers.utils': MagicMock(), + 'transformers.models': MagicMock(), + 'transformers.file_utils': MagicMock(), + 'transformers.tokenization_utils_base': MagicMock(), +}): + from refactron.rag.chunker import CodeChunk + + +class TestRAGIndexer: + """Test cases for RAGIndexer.""" + + @pytest.fixture + def temp_workspace(self): + """Create a temporary workspace directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + workspace_path = Path(tmpdir) + + # Create sample Python files + (workspace_path / "simple.py").write_text(''' +"""Simple module.""" + +def hello(): + """Say hello.""" + return "Hello" +''') + + (workspace_path / "utils.py").write_text(''' +"""Utility functions.""" + +class Calculator: + """A simple calculator.""" + + def add(self, x, y): + """Add two numbers.""" + return x + y +''') + + yield workspace_path + + @patch('refactron.rag.indexer.CHROMA_AVAILABLE', True) + @patch('refactron.rag.indexer.chromadb') + @patch('refactron.rag.indexer.Settings') + @patch('refactron.rag.indexer.SentenceTransformer') + @patch('refactron.rag.indexer.CodeParser') + def test_indexer_initialization(self, mock_parser, mock_transformer, mock_settings, mock_chroma, temp_workspace): + """Test indexer initialization.""" + from refactron.rag.indexer import RAGIndexer + + indexer = RAGIndexer(temp_workspace) + + assert indexer.workspace_path == temp_workspace + assert indexer.index_path == temp_workspace / ".rag" + assert indexer.embedding_model_name == "all-MiniLM-L6-v2" + + @patch('refactron.rag.indexer.CHROMA_AVAILABLE', False) + def test_indexer_requires_chromadb(self, temp_workspace): + """Test that indexer requires ChromaDB.""" + from refactron.rag.indexer import RAGIndexer + + with pytest.raises(RuntimeError, match="ChromaDB is not available"): + RAGIndexer(temp_workspace) + + @patch('refactron.rag.indexer.CHROMA_AVAILABLE', True) + @patch('refactron.rag.indexer.chromadb') + @patch('refactron.rag.indexer.Settings') + @patch('refactron.rag.indexer.SentenceTransformer') + @patch('refactron.rag.indexer.CodeParser') + def test_add_chunks(self, mock_parser, mock_transformer, mock_settings, mock_chroma, temp_workspace): + """Test adding chunks to the index.""" + from refactron.rag.indexer import RAGIndexer + + # Setup mocks + mock_collection = Mock() + mock_client = Mock() + mock_client.get_or_create_collection.return_value = mock_collection + mock_chroma.PersistentClient.return_value = mock_client + + mock_model = Mock() + mock_model.encode.return_value = Mock() + mock_model.encode.return_value.tolist.return_value = [[0.1, 0.2, 0.3]] + mock_transformer.return_value = mock_model + + indexer = RAGIndexer(temp_workspace) + + # Create test chunk + chunk = CodeChunk( + content="def test(): pass", + chunk_type="function", + file_path="/test/file.py", + line_range=(1, 1), + name="test", + dependencies=[], + metadata={} + ) + + indexer.add_chunks([chunk]) + + # Verify chunk was added + assert mock_collection.add.called + + @patch('refactron.rag.indexer.CHROMA_AVAILABLE', True) + @patch('refactron.rag.indexer.chromadb') + @patch('refactron.rag.indexer.Settings') + @patch('refactron.rag.indexer.SentenceTransformer') + @patch('refactron.rag.indexer.CodeParser') + def test_add_empty_chunks(self, mock_parser, mock_transformer, mock_settings, mock_chroma, temp_workspace): + """Test that adding empty chunks list does nothing.""" + from refactron.rag.indexer import RAGIndexer + + mock_collection = Mock() + mock_client = Mock() + mock_client.get_or_create_collection.return_value = mock_collection + mock_chroma.PersistentClient.return_value = mock_client + + mock_model = Mock() + mock_transformer.return_value = mock_model + + indexer = RAGIndexer(temp_workspace) + indexer.add_chunks([]) + + # Verify nothing was added + assert not mock_collection.add.called + + @patch('refactron.rag.indexer.CHROMA_AVAILABLE', True) + @patch('refactron.rag.indexer.chromadb') + @patch('refactron.rag.indexer.Settings') + @patch('refactron.rag.indexer.SentenceTransformer') + @patch('refactron.rag.indexer.CodeParser') + def test_get_stats(self, mock_parser, mock_transformer, mock_settings, mock_chroma, temp_workspace): + """Test getting index statistics.""" + from refactron.rag.indexer import RAGIndexer + + mock_collection = Mock() + mock_client = Mock() + mock_client.get_or_create_collection.return_value = mock_collection + mock_chroma.PersistentClient.return_value = mock_client + + mock_model = Mock() + mock_transformer.return_value = mock_model + + indexer = RAGIndexer(temp_workspace) + + # Save some metadata + indexer._save_metadata({ + "total_chunks": 10, + "total_files": 2, + "chunk_types": {"function": 8, "class": 2} + }) + + stats = indexer.get_stats() + + assert stats.total_chunks == 10 + assert stats.total_files == 2 + assert stats.chunk_types["function"] == 8 diff --git a/tests/test_rag_parser.py b/tests/test_rag_parser.py new file mode 100644 index 0000000..19af222 --- /dev/null +++ b/tests/test_rag_parser.py @@ -0,0 +1,148 @@ +"""Tests for the RAG parser module.""" + +import tempfile +from pathlib import Path + +import pytest + +from refactron.rag.parser import CodeParser, ParsedFile, ParsedFunction, ParsedClass + + +class TestCodeParser: + """Test cases for CodeParser.""" + + @pytest.fixture + def parser(self): + """Create a CodeParser instance.""" + return CodeParser() + + @pytest.fixture + def temp_python_file(self): + """Create a temporary Python file for testing.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + content = '''"""Module docstring for testing.""" + +import os +import sys +from pathlib import Path + +def simple_function(x, y): + """Add two numbers.""" + return x + y + +def another_function(): + """Function without params.""" + pass + +class TestClass: + """A test class.""" + + def method_one(self): + """First method.""" + return 1 + + def method_two(self, param): + """Second method.""" + return param * 2 +''' + f.write(content) + temp_path = Path(f.name) + + yield temp_path + + # Cleanup + temp_path.unlink() + + def test_parser_initialization(self, parser): + """Test that parser initializes correctly.""" + assert parser is not None + assert parser.parser is not None + + def test_parse_file_basic(self, parser, temp_python_file): + """Test parsing a basic Python file.""" + parsed = parser.parse_file(temp_python_file) + + assert isinstance(parsed, ParsedFile) + assert parsed.file_path == str(temp_python_file) + assert parsed.module_docstring == "Module docstring for testing." + + def test_extract_imports(self, parser, temp_python_file): + """Test that imports are extracted correctly.""" + parsed = parser.parse_file(temp_python_file) + + assert len(parsed.imports) == 3 + assert "import os" in parsed.imports + assert "import sys" in parsed.imports + assert any("pathlib" in imp for imp in parsed.imports) + + def test_extract_functions(self, parser, temp_python_file): + """Test that functions are extracted correctly.""" + parsed = parser.parse_file(temp_python_file) + + assert len(parsed.functions) == 2 + + # Check first function + func1 = parsed.functions[0] + assert isinstance(func1, ParsedFunction) + assert func1.name == "simple_function" + assert func1.docstring == "Add two numbers." + assert len(func1.params) >= 2 # Should have x and y + + # Check second function + func2 = parsed.functions[1] + assert func2.name == "another_function" + + def test_extract_classes(self, parser, temp_python_file): + """Test that classes are extracted correctly.""" + parsed = parser.parse_file(temp_python_file) + + assert len(parsed.classes) == 1 + + # Check class + cls = parsed.classes[0] + assert isinstance(cls, ParsedClass) + assert cls.name == "TestClass" + assert cls.docstring == "A test class." + + # Check methods + assert len(cls.methods) == 2 + assert cls.methods[0].name == "method_one" + assert cls.methods[1].name == "method_two" + + def test_parse_invalid_file(self, parser): + """Test parsing a non-existent file raises error.""" + with pytest.raises(FileNotFoundError): + parser.parse_file(Path("/nonexistent/file.py")) + + def test_parse_empty_file(self, parser): + """Test parsing an empty Python file.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + f.write("") + temp_path = Path(f.name) + + try: + parsed = parser.parse_file(temp_path) + assert parsed.module_docstring is None + assert len(parsed.imports) == 0 + assert len(parsed.functions) == 0 + assert len(parsed.classes) == 0 + finally: + temp_path.unlink() + + def test_function_line_ranges(self, parser, temp_python_file): + """Test that line ranges are captured correctly.""" + parsed = parser.parse_file(temp_python_file) + + for func in parsed.functions: + assert func.line_range[0] > 0 + assert func.line_range[1] >= func.line_range[0] + + def test_class_methods_have_correct_metadata(self, parser, temp_python_file): + """Test that class methods preserve metadata.""" + parsed = parser.parse_file(temp_python_file) + + test_class = parsed.classes[0] + for method in test_class.methods: + assert method.name in ["method_one", "method_two"] + assert method.docstring is not None + assert len(method.body) > 0 diff --git a/tests/test_rag_retriever.py b/tests/test_rag_retriever.py new file mode 100644 index 0000000..0ef3cc9 --- /dev/null +++ b/tests/test_rag_retriever.py @@ -0,0 +1,195 @@ +"""Tests for the RAG retriever module.""" + +import tempfile +from pathlib import Path + +import pytest +from unittest.mock import Mock, MagicMock, patch + +# Create a comprehensive mock for transformers +transformers_mock = MagicMock() +transformers_mock.__path__ = [] + +with patch.dict('sys.modules', { + 'transformers': transformers_mock, + 'transformers.configuration_utils': MagicMock(), + 'transformers.utils': MagicMock(), + 'transformers.models': MagicMock(), + 'transformers.file_utils': MagicMock(), + 'transformers.tokenization_utils_base': MagicMock(), +}): + from refactron.rag.retriever import ContextRetriever, RetrievedContext + + +class TestContextRetriever: + """Test cases for ContextRetriever.""" + + @pytest.fixture + def temp_workspace(self): + """Create a temporary workspace directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + workspace_path = Path(tmpdir) + (workspace_path / ".rag").mkdir() + yield workspace_path + + @patch('refactron.rag.retriever.CHROMA_AVAILABLE', True) + @patch('refactron.rag.retriever.chromadb') + @patch('refactron.rag.retriever.Settings') + @patch('refactron.rag.retriever.SentenceTransformer') + def test_retriever_initialization(self, mock_transformer, mock_settings, mock_chroma, temp_workspace): + """Test retriever initialization.""" + mock_collection = Mock() + mock_client = Mock() + mock_client.get_collection.return_value = mock_collection + mock_chroma.PersistentClient.return_value = mock_client + + mock_model = Mock() + mock_transformer.return_value = mock_model + + retriever = ContextRetriever(temp_workspace) + + assert retriever.workspace_path == temp_workspace + assert retriever.index_path == temp_workspace / ".rag" + + @patch('refactron.rag.retriever.CHROMA_AVAILABLE', False) + def test_retriever_requires_chromadb(self, temp_workspace): + """Test that retriever requires ChromaDB.""" + with pytest.raises(RuntimeError, match="ChromaDB is not available"): + ContextRetriever(temp_workspace) + + @patch('refactron.rag.retriever.CHROMA_AVAILABLE', True) + @patch('refactron.rag.retriever.chromadb') + @patch('refactron.rag.retriever.Settings') + @patch('refactron.rag.retriever.SentenceTransformer') + def test_retriever_missing_index(self, mock_transformer, mock_settings, mock_chroma, temp_workspace): + """Test that missing index raises error.""" + mock_client = Mock() + mock_client.get_collection.side_effect = Exception("Collection not found") + mock_chroma.PersistentClient.return_value = mock_client + + mock_model = Mock() + mock_transformer.return_value = mock_model + + with pytest.raises(RuntimeError, match="Index not found"): + ContextRetriever(temp_workspace) + + @patch('refactron.rag.retriever.CHROMA_AVAILABLE', True) + @patch('refactron.rag.retriever.chromadb') + @patch('refactron.rag.retriever.Settings') + @patch('refactron.rag.retriever.SentenceTransformer') + def test_retrieve_similar(self, mock_transformer, mock_settings, mock_chroma, temp_workspace): + """Test retrieving similar code chunks.""" + # Setup mock collection with results + mock_collection = Mock() + mock_collection.query.return_value = { + 'documents': [['def test(): pass']], + 'metadatas': [[{ + 'file_path': '/test.py', + 'chunk_type': 'function', + 'name': 'test', + 'line_start': 1, + 'line_end': 1 + }]], + 'distances': [[0.15]] + } + + mock_client = Mock() + mock_client.get_collection.return_value = mock_collection + mock_chroma.PersistentClient.return_value = mock_client + + mock_model = Mock() + mock_model.encode.return_value.tolist.return_value = [[0.1, 0.2]] + mock_transformer.return_value = mock_model + + retriever = ContextRetriever(temp_workspace) + results = retriever.retrieve_similar("test function", top_k=1) + + assert len(results) == 1 + assert isinstance(results[0], RetrievedContext) + assert results[0].name == 'test' + assert results[0].chunk_type == 'function' + + @patch('refactron.rag.retriever.CHROMA_AVAILABLE', True) + @patch('refactron.rag.retriever.chromadb') + @patch('refactron.rag.retriever.Settings') + @patch('refactron.rag.retriever.SentenceTransformer') + def test_retrieve_by_file(self, mock_transformer, mock_settings, mock_chroma, temp_workspace): + """Test retrieving chunks by file path.""" + mock_collection = Mock() + mock_collection.get.return_value = { + 'documents': ['def test(): pass'], + 'metadatas': [{ + 'file_path': '/test.py', + 'chunk_type': 'function', + 'name': 'test', + 'line_start': 1, + 'line_end': 1 + }] + } + + mock_client = Mock() + mock_client.get_collection.return_value = mock_collection + mock_chroma.PersistentClient.return_value = mock_client + + mock_model = Mock() + mock_transformer.return_value = mock_model + + retriever = ContextRetriever(temp_workspace) + results = retriever.retrieve_by_file("/test.py") + + assert len(results) == 1 + assert results[0].file_path == '/test.py' + + @patch('refactron.rag.retriever.CHROMA_AVAILABLE', True) + @patch('refactron.rag.retriever.chromadb') + @patch('refactron.rag.retriever.Settings') + @patch('refactron.rag.retriever.SentenceTransformer') + def test_retrieve_functions(self, mock_transformer, mock_settings, mock_chroma, temp_workspace): + """Test retrieving only function chunks.""" + mock_collection = Mock() + mock_collection.query.return_value = { + 'documents': [[]], + 'metadatas': [[]], + 'distances': [[]] + } + + mock_client = Mock() + mock_client.get_collection.return_value = mock_collection + mock_chroma.PersistentClient.return_value = mock_client + + mock_model = Mock() + mock_model.encode.return_value.tolist.return_value = [[0.1, 0.2]] + mock_transformer.return_value = mock_model + + retriever = ContextRetriever(temp_workspace) + retriever.retrieve_functions("test", top_k=5) + + # Verify that chunk_type filter was used + call_args = mock_collection.query.call_args + assert call_args.kwargs.get('where') == {"chunk_type": "function"} + + @patch('refactron.rag.retriever.CHROMA_AVAILABLE', True) + @patch('refactron.rag.retriever.chromadb') + @patch('refactron.rag.retriever.Settings') + @patch('refactron.rag.retriever.SentenceTransformer') + def test_retrieve_no_results(self, mock_transformer, mock_settings, mock_chroma, temp_workspace): + """Test retrieval with no results.""" + mock_collection = Mock() + mock_collection.query.return_value = { + 'documents': [[]], + 'metadatas': [[]], + 'distances': [[]] + } + + mock_client = Mock() + mock_client.get_collection.return_value = mock_collection + mock_chroma.PersistentClient.return_value = mock_client + + mock_model = Mock() + mock_model.encode.return_value.tolist.return_value = [[0.1, 0.2]] + mock_transformer.return_value = mock_model + + retriever = ContextRetriever(temp_workspace) + results = retriever.retrieve_similar("nonexistent", top_k=5) + + assert len(results) == 0 From fdc51c790fe7ad46ce4022462ef104d5dbd25bf9 Mon Sep 17 00:00:00 2001 From: omsherikar Date: Sun, 8 Feb 2026 17:27:47 +0530 Subject: [PATCH 04/19] fix: add missing GroqClient import in RAG indexer - Fixed F821 undefined name error in refactron/rag/indexer.py - Added try/except import block for GroqClient type hint - Ensures CI linting passes --- refactron/rag/indexer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/refactron/rag/indexer.py b/refactron/rag/indexer.py index a06e8e2..4d62fec 100644 --- a/refactron/rag/indexer.py +++ b/refactron/rag/indexer.py @@ -22,6 +22,12 @@ from refactron.rag.chunker import CodeChunk from refactron.rag.parser import CodeParser, ParsedFile +# Import for type hints +try: + from refactron.llm.client import GroqClient +except ImportError: + GroqClient = None # type: ignore + @dataclass class IndexStats: From ebd6341bf7b708ffa3b71300f5933b7f49d30d4d Mon Sep 17 00:00:00 2001 From: omsherikar Date: Sun, 8 Feb 2026 17:37:15 +0530 Subject: [PATCH 05/19] chore: exclude test repo artifacts from version control - Added complex_test_repo generated files to .gitignore - Prevents accidental commits of test artifacts - Cleaned up working directory --- .gitignore | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.gitignore b/.gitignore index 471a3b1..8d267aa 100644 --- a/.gitignore +++ b/.gitignore @@ -142,3 +142,10 @@ refactron_incremental_state.json # Exclude analysis files DIRECTORY_ANALYSIS.md FEATURES.md + +# Test repository artifacts (generated during testing) +complex_test_repo/.rag/ +complex_test_repo/.refactron.yaml +complex_test_repo/utils/math_lib_doc.md +feedback_analysis.json +verify_full_ecosystem.py From 78bd8fc12ad763fb4568d8aea73f56fb26ea2217 Mon Sep 17 00:00:00 2001 From: omsherikar Date: Sun, 8 Feb 2026 17:42:51 +0530 Subject: [PATCH 06/19] security: fix incomplete URL sanitization vulnerability - Replaced insecure substring checks with proper URL parsing - Use urlparse() to validate hostname exactly equals 'github.com' - Prevents URL injection attacks (e.g., evil.com/github.com/fake) - Fixes CodeQL high-severity security alert Before: if "github.com" in url (VULNERABLE) After: if parsed.hostname == "github.com" (SECURE) --- refactron/core/workspace.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/refactron/core/workspace.py b/refactron/core/workspace.py index d608a15..31ba338 100644 --- a/refactron/core/workspace.py +++ b/refactron/core/workspace.py @@ -11,6 +11,7 @@ from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, Optional +from urllib.parse import urlparse @dataclass @@ -205,21 +206,32 @@ def detect_repository(self, directory: Optional[Path] = None) -> Optional[str]: for line in content.split("\n"): line = line.strip() if line.startswith("url = "): - url = line.replace("url = ", "") + url = line.replace("url = ", "").strip() # Extract repo name from URL # HTTPS: https://github.com/user/repo.git # SSH: git@github.com:user/repo.git - if "github.com" in url: - if url.startswith("git@github.com:"): - repo_path = url.replace("git@github.com:", "").replace(".git", "") - elif "github.com/" in url: - repo_path = url.split("github.com/")[1].replace(".git", "") - else: + + # Handle SSH GitHub URLs explicitly (SCP-like syntax) + if url.startswith("git@github.com:"): + repo_path = url.replace("git@github.com:", "", 1).replace(".git", "") + if repo_path: + return repo_path + + # Handle HTTPS/HTTP GitHub URLs with proper parsing + elif "://" in url: + try: + parsed = urlparse(url) + # Validate hostname is exactly github.com (not a substring) + if parsed.hostname == "github.com": + path = parsed.path.lstrip("/") + if path.endswith(".git"): + path = path[:-4] # Remove .git suffix + if path: + return path + except ValueError: continue - return repo_path - except (IOError, OSError): pass From 62dce26cdedce3451558bfb42da63a98fb880b19 Mon Sep 17 00:00:00 2001 From: omsherikar Date: Sun, 8 Feb 2026 17:48:06 +0530 Subject: [PATCH 07/19] fix: add Python 3.8 compatibility for tree-sitter Language API - Try Language(capsule, 'python') first (Python 3.8 API) - Fall back to Language(capsule) for newer versions - Resolves 18 RAG test errors on Python 3.8 --- refactron/rag/parser.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/refactron/rag/parser.py b/refactron/rag/parser.py index 942b7a8..667e766 100644 --- a/refactron/rag/parser.py +++ b/refactron/rag/parser.py @@ -58,8 +58,14 @@ def __init__(self): "tree-sitter is not available. Install with: pip install tree-sitter tree-sitter-python" ) - # Initialize Python language - wrap capsule with Language - PY_LANGUAGE = Language(tspython.language()) + # Initialize Python language - handle different tree-sitter API versions + # Older versions (e.g., in Python 3.8) require 'name' parameter + try: + PY_LANGUAGE = Language(tspython.language(), "python") + except TypeError: + # Newer API doesn't need name parameter + PY_LANGUAGE = Language(tspython.language()) + self.parser = Parser(PY_LANGUAGE) def parse_file(self, file_path: Path) -> ParsedFile: From ebeebd0eed4e8c1f1c3348e2660052cbbbf53dd5 Mon Sep 17 00:00:00 2001 From: omsherikar Date: Sun, 8 Feb 2026 17:54:41 +0530 Subject: [PATCH 08/19] style: format code with black Auto-formatted 9 files to pass CI checks: - refactron/core/workspace.py - refactron/rag/parser.py - tests/test_*_client.py, test_*orchestrator.py, test_rag_*.py --- refactron/core/workspace.py | 4 +- refactron/rag/parser.py | 2 +- tests/test_backend_client.py | 16 +-- tests/test_groq_client.py | 52 ++++------ tests/test_llm_orchestrator.py | 82 ++++++--------- tests/test_rag_chunker.py | 34 +++--- tests/test_rag_indexer.py | 145 ++++++++++++++------------ tests/test_rag_parser.py | 32 +++--- tests/test_rag_retriever.py | 183 ++++++++++++++++++--------------- 9 files changed, 277 insertions(+), 273 deletions(-) diff --git a/refactron/core/workspace.py b/refactron/core/workspace.py index 31ba338..03b56d6 100644 --- a/refactron/core/workspace.py +++ b/refactron/core/workspace.py @@ -211,13 +211,13 @@ def detect_repository(self, directory: Optional[Path] = None) -> Optional[str]: # Extract repo name from URL # HTTPS: https://github.com/user/repo.git # SSH: git@github.com:user/repo.git - + # Handle SSH GitHub URLs explicitly (SCP-like syntax) if url.startswith("git@github.com:"): repo_path = url.replace("git@github.com:", "", 1).replace(".git", "") if repo_path: return repo_path - + # Handle HTTPS/HTTP GitHub URLs with proper parsing elif "://" in url: try: diff --git a/refactron/rag/parser.py b/refactron/rag/parser.py index 667e766..8558734 100644 --- a/refactron/rag/parser.py +++ b/refactron/rag/parser.py @@ -65,7 +65,7 @@ def __init__(self): except TypeError: # Newer API doesn't need name parameter PY_LANGUAGE = Language(tspython.language()) - + self.parser = Parser(PY_LANGUAGE) def parse_file(self, file_path: Path) -> ParsedFile: diff --git a/tests/test_backend_client.py b/tests/test_backend_client.py index fec16a8..29b6e4d 100644 --- a/tests/test_backend_client.py +++ b/tests/test_backend_client.py @@ -30,12 +30,12 @@ def test_backend_client_generate(mock_post, mock_credentials): mock_response.status_code = 200 mock_response.json.return_value = {"content": "Suggested code"} mock_post.return_value = mock_response - + client = BackendLLMClient(backend_url="http://test-backend:3000") result = client.generate(prompt="Refactor this", system="You are an expert") - + assert result == "Suggested code" - + # Verify request mock_post.assert_called_once() args, kwargs = mock_post.call_args @@ -52,9 +52,11 @@ def test_backend_client_error_handling(mock_post, mock_credentials): mock_response.status_code = 500 mock_response.text = "Internal Server Error" mock_post.return_value = mock_response - + client = BackendLLMClient() - with pytest.raises(RuntimeError, match="Backend LLM proxy error \(500\): Internal Server Error"): + with pytest.raises( + RuntimeError, match="Backend LLM proxy error \(500\): Internal Server Error" + ): client.generate(prompt="Refactor this") @@ -64,10 +66,10 @@ def test_backend_client_check_health(mock_get): mock_response = MagicMock() mock_response.status_code = 200 mock_get.return_value = mock_response - + client = BackendLLMClient() assert client.check_health() is True - + # Mock unhealthy response mock_response.status_code = 503 assert client.check_health() is False diff --git a/tests/test_groq_client.py b/tests/test_groq_client.py index 2ac4ca8..6a74d45 100644 --- a/tests/test_groq_client.py +++ b/tests/test_groq_client.py @@ -13,23 +13,23 @@ class TestGroqClient: @pytest.fixture def mock_groq_api(self): """Mock the Groq API.""" - with patch('refactron.llm.client.Groq') as mock_groq: + with patch("refactron.llm.client.Groq") as mock_groq: # Mock response mock_response = Mock() mock_response.choices = [Mock()] mock_response.choices[0].message.content = "Generated text response" - + # Mock client mock_client = Mock() mock_client.chat.completions.create.return_value = mock_response mock_groq.return_value = mock_client - + yield mock_groq def test_client_initialization_with_api_key(self, mock_groq_api): """Test client initialization with explicit API key.""" client = GroqClient(api_key="test_key_123") - + assert client.api_key == "test_key_123" assert client.model == "llama-3.3-70b-versatile" assert client.temperature == 0.2 @@ -37,9 +37,9 @@ def test_client_initialization_with_api_key(self, mock_groq_api): def test_client_initialization_from_env(self, mock_groq_api): """Test client initialization from environment variable.""" - with patch.dict(os.environ, {'GROQ_API_KEY': 'env_key_456'}): + with patch.dict(os.environ, {"GROQ_API_KEY": "env_key_456"}): client = GroqClient() - assert client.api_key == 'env_key_456' + assert client.api_key == "env_key_456" def test_client_initialization_no_api_key(self, mock_groq_api): """Test that missing API key raises error.""" @@ -51,60 +51,50 @@ def test_generate_basic(self, mock_groq_api): """Test basic text generation.""" client = GroqClient(api_key="testkey") response = client.generate("Hello, how are you?") - + assert response == "Generated text response" assert client.client.chat.completions.create.called def test_generate_with_system_prompt(self, mock_groq_api): """Test generation with system prompt.""" client = GroqClient(api_key="testkey") - response = client.generate( - prompt="User message", - system="You are a helpful assistant" - ) - + response = client.generate(prompt="User message", system="You are a helpful assistant") + assert response == "Generated text response" - + # Check that both system and user messages were sent call_args = client.client.chat.completions.create.call_args - messages = call_args.kwargs['messages'] + messages = call_args.kwargs["messages"] assert len(messages) == 2 - assert messages[0]['role'] == 'system' - assert messages[1]['role'] == 'user' + assert messages[0]["role"] == "system" + assert messages[1]["role"] == "user" def test_generate_with_custom_params(self, mock_groq_api): """Test generation with custom temperature and max_tokens.""" client = GroqClient(api_key="testkey") - client.generate( - prompt="Test", - temperature=0.5, - max_tokens=1000 - ) - + client.generate(prompt="Test", temperature=0.5, max_tokens=1000) + call_args = client.client.chat.completions.create.call_args - assert call_args.kwargs['temperature'] == 0.5 - assert call_args.kwargs['max_tokens'] == 1000 + assert call_args.kwargs["temperature"] == 0.5 + assert call_args.kwargs["max_tokens"] == 1000 def test_custom_model(self, mock_groq_api): """Test initialization with custom model.""" - client = GroqClient( - api_key="testkey", - model="llama-3.1-8b-instant" - ) - + client = GroqClient(api_key="testkey", model="llama-3.1-8b-instant") + assert client.model == "llama-3.1-8b-instant" def test_check_health_success(self, mock_groq_api): """Test health check when API is accessible.""" client = GroqClient(api_key="testkey") is_healthy = client.check_health() - + assert is_healthy is True def test_check_health_failure(self, mock_groq_api): """Test health check when API fails.""" client = GroqClient(api_key="testkey") client.client.chat.completions.create.side_effect = Exception("API Error") - + is_healthy = client.check_health() assert is_healthy is False diff --git a/tests/test_llm_orchestrator.py b/tests/test_llm_orchestrator.py index 381187c..d685990 100644 --- a/tests/test_llm_orchestrator.py +++ b/tests/test_llm_orchestrator.py @@ -13,13 +13,13 @@ class TestLLMOrchestrator: """Tests for LLMOrchestrator.""" - + @pytest.fixture def mock_retriever(self): retriever = Mock() retriever.retrieve_similar.return_value = [] return retriever - + @pytest.fixture def mock_client(self): client = Mock() @@ -35,7 +35,7 @@ def mock_client(self): ``` """ return client - + @pytest.fixture def mock_safety(self): gate = Mock() @@ -44,7 +44,7 @@ def mock_safety(self): result.issues = [] gate.validate.return_value = result return gate - + @pytest.fixture def sample_issue(self): return CodeIssue( @@ -52,31 +52,27 @@ def sample_issue(self): level=IssueLevel.WARNING, message="Function is too long", file_path=Path("/test.py"), - line_number=10 + line_number=10, ) def test_generate_suggestion_basic(self, mock_client, mock_safety, sample_issue): """Test generating a successful suggestion.""" - orchestrator = LLMOrchestrator( - llm_client=mock_client, - safety_gate=mock_safety - ) - + orchestrator = LLMOrchestrator(llm_client=mock_client, safety_gate=mock_safety) + suggestion = orchestrator.generate_suggestion( - issue=sample_issue, - original_code="def broken(): pass" + issue=sample_issue, original_code="def broken(): pass" ) - + assert suggestion.status == SuggestionStatus.PENDING assert suggestion.proposed_code == "def fixed(): pass" assert suggestion.explanation == "Fixed the bug" assert suggestion.model_name == "llama-3.3-70b-versatile" - + # Verify prompt construction (implicity) by checking generate call mock_client.generate.assert_called_once() args = mock_client.generate.call_args - assert "Function is too long" in args.kwargs['prompt'] - assert "def broken(): pass" in args.kwargs['prompt'] + assert "Function is too long" in args.kwargs["prompt"] + assert "def broken(): pass" in args.kwargs["prompt"] def test_generate_suggestion_with_rag(self, mock_client, mock_retriever, sample_issue): """Test generation with RAG context.""" @@ -89,39 +85,33 @@ def test_generate_suggestion_with_rag(self, mock_client, mock_retriever, sample_ name="similar_func", line_range=(1, 2), distance=0.1, - metadata={} + metadata={}, ) ] - - orchestrator = LLMOrchestrator( - retriever=mock_retriever, - llm_client=mock_client - ) - - suggestion = orchestrator.generate_suggestion( - issue=sample_issue, - original_code="original" - ) - + + orchestrator = LLMOrchestrator(retriever=mock_retriever, llm_client=mock_client) + + suggestion = orchestrator.generate_suggestion(issue=sample_issue, original_code="original") + assert "/similar.py" in suggestion.context_files - + # Verify prompt contains RAG context args = mock_client.generate.call_args - assert "def similar_func(): pass" in args.kwargs['prompt'] + assert "def similar_func(): pass" in args.kwargs["prompt"] def test_bad_llm_response(self, mock_client, sample_issue): """Test handling of invalid JSON from LLM.""" mock_client.generate.return_value = "This is not JSON" - + orchestrator = LLMOrchestrator(llm_client=mock_client) - - suggestion = orchestrator.generate_suggestion( - issue=sample_issue, - original_code="original" - ) - + + suggestion = orchestrator.generate_suggestion(issue=sample_issue, original_code="original") + assert suggestion.status == SuggestionStatus.FAILED - assert "JSONDecodeError" in suggestion.explanation or "Expecting value" in suggestion.explanation + assert ( + "JSONDecodeError" in suggestion.explanation + or "Expecting value" in suggestion.explanation + ) def test_safety_check_failure(self, mock_client, mock_safety, sample_issue): """Test rejection when safety check fails.""" @@ -130,16 +120,10 @@ def test_safety_check_failure(self, mock_client, mock_safety, sample_issue): result.passed = False result.issues = ["Syntax Error"] mock_safety.validate.return_value = result - - orchestrator = LLMOrchestrator( - llm_client=mock_client, - safety_gate=mock_safety - ) - - suggestion = orchestrator.generate_suggestion( - issue=sample_issue, - original_code="original" - ) - + + orchestrator = LLMOrchestrator(llm_client=mock_client, safety_gate=mock_safety) + + suggestion = orchestrator.generate_suggestion(issue=sample_issue, original_code="original") + assert suggestion.status == SuggestionStatus.REJECTED assert not suggestion.safety_result.passed diff --git a/tests/test_rag_chunker.py b/tests/test_rag_chunker.py index a7dbae9..0f09e76 100644 --- a/tests/test_rag_chunker.py +++ b/tests/test_rag_chunker.py @@ -25,7 +25,7 @@ def chunker(self, parser): @pytest.fixture def temp_python_file(self): """Create a temporary Python file for testing.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: content = '''"""Module level docstring.""" import os @@ -48,9 +48,9 @@ def method_one(self): ''' f.write(content) temp_path = Path(f.name) - + yield temp_path - + # Cleanup temp_path.unlink() @@ -62,17 +62,17 @@ def test_chunker_initialization(self, parser, chunker): def test_chunk_file_basic(self, chunker, temp_python_file): """Test basic file chunking.""" chunks = chunker.chunk_file(temp_python_file) - + assert len(chunks) > 0 assert all(isinstance(chunk, CodeChunk) for chunk in chunks) def test_module_chunk_created(self, chunker, temp_python_file): """Test that module chunk is created when there are imports/docstring.""" chunks = chunker.chunk_file(temp_python_file) - + module_chunks = [c for c in chunks if c.chunk_type == "module"] assert len(module_chunks) == 1 - + module_chunk = module_chunks[0] assert "Module level docstring" in module_chunk.content assert "import os" in module_chunk.content @@ -80,10 +80,10 @@ def test_module_chunk_created(self, chunker, temp_python_file): def test_function_chunks_created(self, chunker, temp_python_file): """Test that function chunks are created correctly.""" chunks = chunker.chunk_file(temp_python_file) - + function_chunks = [c for c in chunks if c.chunk_type == "function"] assert len(function_chunks) == 2 - + # Check first function chunk func_names = [c.name for c in function_chunks] assert "test_function" in func_names @@ -92,10 +92,10 @@ def test_function_chunks_created(self, chunker, temp_python_file): def test_class_chunks_created(self, chunker, temp_python_file): """Test that class chunks are created.""" chunks = chunker.chunk_file(temp_python_file) - + class_chunks = [c for c in chunks if c.chunk_type == "class"] assert len(class_chunks) == 1 - + class_chunk = class_chunks[0] assert class_chunk.name == "TestClass" assert "Test class docstring" in class_chunk.content @@ -103,10 +103,10 @@ def test_class_chunks_created(self, chunker, temp_python_file): def test_method_chunks_created(self, chunker, temp_python_file): """Test that method chunks are created.""" chunks = chunker.chunk_file(temp_python_file) - + method_chunks = [c for c in chunks if c.chunk_type == "method"] assert len(method_chunks) == 1 - + method_chunk = method_chunks[0] assert method_chunk.name == "method_one" assert method_chunk.metadata["class_name"] == "TestClass" @@ -114,7 +114,7 @@ def test_method_chunks_created(self, chunker, temp_python_file): def test_chunk_metadata(self, chunker, temp_python_file): """Test that chunk metadata is populated correctly.""" chunks = chunker.chunk_file(temp_python_file) - + for chunk in chunks: assert chunk.file_path == str(temp_python_file) assert chunk.line_range[0] > 0 @@ -123,10 +123,10 @@ def test_chunk_metadata(self, chunker, temp_python_file): def test_empty_file(self, chunker): """Test chunking an empty file.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: f.write("") temp_path = Path(f.name) - + try: chunks = chunker.chunk_file(temp_path) assert len(chunks) == 0 @@ -135,10 +135,10 @@ def test_empty_file(self, chunker): def test_file_with_only_imports(self, chunker): """Test file with only imports.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: f.write("import os\n") temp_path = Path(f.name) - + try: chunks = chunker.chunk_file(temp_path) assert len(chunks) == 1 diff --git a/tests/test_rag_indexer.py b/tests/test_rag_indexer.py index dea5c20..0c2e493 100644 --- a/tests/test_rag_indexer.py +++ b/tests/test_rag_indexer.py @@ -11,15 +11,18 @@ transformers_mock = MagicMock() transformers_mock.__path__ = [] # Make it look like a package -# Patch sys.modules before importing anything that depends on sentence-transformers -with patch.dict('sys.modules', { - 'transformers': transformers_mock, - 'transformers.configuration_utils': MagicMock(), - 'transformers.utils': MagicMock(), - 'transformers.models': MagicMock(), - 'transformers.file_utils': MagicMock(), - 'transformers.tokenization_utils_base': MagicMock(), -}): +# Patch sys.modules before importing anything that depends on sentence-transformers +with patch.dict( + "sys.modules", + { + "transformers": transformers_mock, + "transformers.configuration_utils": MagicMock(), + "transformers.utils": MagicMock(), + "transformers.models": MagicMock(), + "transformers.file_utils": MagicMock(), + "transformers.tokenization_utils_base": MagicMock(), + }, +): from refactron.rag.chunker import CodeChunk @@ -31,17 +34,20 @@ def temp_workspace(self): """Create a temporary workspace directory.""" with tempfile.TemporaryDirectory() as tmpdir: workspace_path = Path(tmpdir) - + # Create sample Python files - (workspace_path / "simple.py").write_text(''' + (workspace_path / "simple.py").write_text( + ''' """Simple module.""" def hello(): """Say hello.""" return "Hello" -''') - - (workspace_path / "utils.py").write_text(''' +''' + ) + + (workspace_path / "utils.py").write_text( + ''' """Utility functions.""" class Calculator: @@ -50,55 +56,60 @@ class Calculator: def add(self, x, y): """Add two numbers.""" return x + y -''') - +''' + ) + yield workspace_path - @patch('refactron.rag.indexer.CHROMA_AVAILABLE', True) - @patch('refactron.rag.indexer.chromadb') - @patch('refactron.rag.indexer.Settings') - @patch('refactron.rag.indexer.SentenceTransformer') - @patch('refactron.rag.indexer.CodeParser') - def test_indexer_initialization(self, mock_parser, mock_transformer, mock_settings, mock_chroma, temp_workspace): + @patch("refactron.rag.indexer.CHROMA_AVAILABLE", True) + @patch("refactron.rag.indexer.chromadb") + @patch("refactron.rag.indexer.Settings") + @patch("refactron.rag.indexer.SentenceTransformer") + @patch("refactron.rag.indexer.CodeParser") + def test_indexer_initialization( + self, mock_parser, mock_transformer, mock_settings, mock_chroma, temp_workspace + ): """Test indexer initialization.""" from refactron.rag.indexer import RAGIndexer - + indexer = RAGIndexer(temp_workspace) - + assert indexer.workspace_path == temp_workspace assert indexer.index_path == temp_workspace / ".rag" assert indexer.embedding_model_name == "all-MiniLM-L6-v2" - @patch('refactron.rag.indexer.CHROMA_AVAILABLE', False) + @patch("refactron.rag.indexer.CHROMA_AVAILABLE", False) def test_indexer_requires_chromadb(self, temp_workspace): """Test that indexer requires ChromaDB.""" from refactron.rag.indexer import RAGIndexer - + with pytest.raises(RuntimeError, match="ChromaDB is not available"): RAGIndexer(temp_workspace) - @patch('refactron.rag.indexer.CHROMA_AVAILABLE', True) - @patch('refactron.rag.indexer.chromadb') - @patch('refactron.rag.indexer.Settings') - @patch('refactron.rag.indexer.SentenceTransformer') - @patch('refactron.rag.indexer.CodeParser') - def test_add_chunks(self, mock_parser, mock_transformer, mock_settings, mock_chroma, temp_workspace): + @patch("refactron.rag.indexer.CHROMA_AVAILABLE", True) + @patch("refactron.rag.indexer.chromadb") + @patch("refactron.rag.indexer.Settings") + @patch("refactron.rag.indexer.SentenceTransformer") + @patch("refactron.rag.indexer.CodeParser") + def test_add_chunks( + self, mock_parser, mock_transformer, mock_settings, mock_chroma, temp_workspace + ): """Test adding chunks to the index.""" from refactron.rag.indexer import RAGIndexer - + # Setup mocks mock_collection = Mock() mock_client = Mock() mock_client.get_or_create_collection.return_value = mock_collection mock_chroma.PersistentClient.return_value = mock_client - + mock_model = Mock() mock_model.encode.return_value = Mock() mock_model.encode.return_value.tolist.return_value = [[0.1, 0.2, 0.3]] mock_transformer.return_value = mock_model - + indexer = RAGIndexer(temp_workspace) - + # Create test chunk chunk = CodeChunk( content="def test(): pass", @@ -107,65 +118,67 @@ def test_add_chunks(self, mock_parser, mock_transformer, mock_settings, mock_chr line_range=(1, 1), name="test", dependencies=[], - metadata={} + metadata={}, ) - + indexer.add_chunks([chunk]) - + # Verify chunk was added assert mock_collection.add.called - @patch('refactron.rag.indexer.CHROMA_AVAILABLE', True) - @patch('refactron.rag.indexer.chromadb') - @patch('refactron.rag.indexer.Settings') - @patch('refactron.rag.indexer.SentenceTransformer') - @patch('refactron.rag.indexer.CodeParser') - def test_add_empty_chunks(self, mock_parser, mock_transformer, mock_settings, mock_chroma, temp_workspace): + @patch("refactron.rag.indexer.CHROMA_AVAILABLE", True) + @patch("refactron.rag.indexer.chromadb") + @patch("refactron.rag.indexer.Settings") + @patch("refactron.rag.indexer.SentenceTransformer") + @patch("refactron.rag.indexer.CodeParser") + def test_add_empty_chunks( + self, mock_parser, mock_transformer, mock_settings, mock_chroma, temp_workspace + ): """Test that adding empty chunks list does nothing.""" from refactron.rag.indexer import RAGIndexer - + mock_collection = Mock() mock_client = Mock() mock_client.get_or_create_collection.return_value = mock_collection mock_chroma.PersistentClient.return_value = mock_client - + mock_model = Mock() mock_transformer.return_value = mock_model - + indexer = RAGIndexer(temp_workspace) indexer.add_chunks([]) - + # Verify nothing was added assert not mock_collection.add.called - @patch('refactron.rag.indexer.CHROMA_AVAILABLE', True) - @patch('refactron.rag.indexer.chromadb') - @patch('refactron.rag.indexer.Settings') - @patch('refactron.rag.indexer.SentenceTransformer') - @patch('refactron.rag.indexer.CodeParser') - def test_get_stats(self, mock_parser, mock_transformer, mock_settings, mock_chroma, temp_workspace): + @patch("refactron.rag.indexer.CHROMA_AVAILABLE", True) + @patch("refactron.rag.indexer.chromadb") + @patch("refactron.rag.indexer.Settings") + @patch("refactron.rag.indexer.SentenceTransformer") + @patch("refactron.rag.indexer.CodeParser") + def test_get_stats( + self, mock_parser, mock_transformer, mock_settings, mock_chroma, temp_workspace + ): """Test getting index statistics.""" from refactron.rag.indexer import RAGIndexer - + mock_collection = Mock() mock_client = Mock() mock_client.get_or_create_collection.return_value = mock_collection mock_chroma.PersistentClient.return_value = mock_client - + mock_model = Mock() mock_transformer.return_value = mock_model - + indexer = RAGIndexer(temp_workspace) - + # Save some metadata - indexer._save_metadata({ - "total_chunks": 10, - "total_files": 2, - "chunk_types": {"function": 8, "class": 2} - }) - + indexer._save_metadata( + {"total_chunks": 10, "total_files": 2, "chunk_types": {"function": 8, "class": 2}} + ) + stats = indexer.get_stats() - + assert stats.total_chunks == 10 assert stats.total_files == 2 assert stats.chunk_types["function"] == 8 diff --git a/tests/test_rag_parser.py b/tests/test_rag_parser.py index 19af222..b47d836 100644 --- a/tests/test_rag_parser.py +++ b/tests/test_rag_parser.py @@ -19,7 +19,7 @@ def parser(self): @pytest.fixture def temp_python_file(self): """Create a temporary Python file for testing.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: content = '''"""Module docstring for testing.""" import os @@ -47,9 +47,9 @@ def method_two(self, param): ''' f.write(content) temp_path = Path(f.name) - + yield temp_path - + # Cleanup temp_path.unlink() @@ -61,15 +61,15 @@ def test_parser_initialization(self, parser): def test_parse_file_basic(self, parser, temp_python_file): """Test parsing a basic Python file.""" parsed = parser.parse_file(temp_python_file) - + assert isinstance(parsed, ParsedFile) assert parsed.file_path == str(temp_python_file) assert parsed.module_docstring == "Module docstring for testing." - + def test_extract_imports(self, parser, temp_python_file): """Test that imports are extracted correctly.""" parsed = parser.parse_file(temp_python_file) - + assert len(parsed.imports) == 3 assert "import os" in parsed.imports assert "import sys" in parsed.imports @@ -78,16 +78,16 @@ def test_extract_imports(self, parser, temp_python_file): def test_extract_functions(self, parser, temp_python_file): """Test that functions are extracted correctly.""" parsed = parser.parse_file(temp_python_file) - + assert len(parsed.functions) == 2 - + # Check first function func1 = parsed.functions[0] assert isinstance(func1, ParsedFunction) assert func1.name == "simple_function" assert func1.docstring == "Add two numbers." assert len(func1.params) >= 2 # Should have x and y - + # Check second function func2 = parsed.functions[1] assert func2.name == "another_function" @@ -95,15 +95,15 @@ def test_extract_functions(self, parser, temp_python_file): def test_extract_classes(self, parser, temp_python_file): """Test that classes are extracted correctly.""" parsed = parser.parse_file(temp_python_file) - + assert len(parsed.classes) == 1 - + # Check class cls = parsed.classes[0] assert isinstance(cls, ParsedClass) assert cls.name == "TestClass" assert cls.docstring == "A test class." - + # Check methods assert len(cls.methods) == 2 assert cls.methods[0].name == "method_one" @@ -116,10 +116,10 @@ def test_parse_invalid_file(self, parser): def test_parse_empty_file(self, parser): """Test parsing an empty Python file.""" - with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: f.write("") temp_path = Path(f.name) - + try: parsed = parser.parse_file(temp_path) assert parsed.module_docstring is None @@ -132,7 +132,7 @@ def test_parse_empty_file(self, parser): def test_function_line_ranges(self, parser, temp_python_file): """Test that line ranges are captured correctly.""" parsed = parser.parse_file(temp_python_file) - + for func in parsed.functions: assert func.line_range[0] > 0 assert func.line_range[1] >= func.line_range[0] @@ -140,7 +140,7 @@ def test_function_line_ranges(self, parser, temp_python_file): def test_class_methods_have_correct_metadata(self, parser, temp_python_file): """Test that class methods preserve metadata.""" parsed = parser.parse_file(temp_python_file) - + test_class = parsed.classes[0] for method in test_class.methods: assert method.name in ["method_one", "method_two"] diff --git a/tests/test_rag_retriever.py b/tests/test_rag_retriever.py index 0ef3cc9..9b79c75 100644 --- a/tests/test_rag_retriever.py +++ b/tests/test_rag_retriever.py @@ -10,14 +10,17 @@ transformers_mock = MagicMock() transformers_mock.__path__ = [] -with patch.dict('sys.modules', { - 'transformers': transformers_mock, - 'transformers.configuration_utils': MagicMock(), - 'transformers.utils': MagicMock(), - 'transformers.models': MagicMock(), - 'transformers.file_utils': MagicMock(), - 'transformers.tokenization_utils_base': MagicMock(), -}): +with patch.dict( + "sys.modules", + { + "transformers": transformers_mock, + "transformers.configuration_utils": MagicMock(), + "transformers.utils": MagicMock(), + "transformers.models": MagicMock(), + "transformers.file_utils": MagicMock(), + "transformers.tokenization_utils_base": MagicMock(), + }, +): from refactron.rag.retriever import ContextRetriever, RetrievedContext @@ -32,164 +35,176 @@ def temp_workspace(self): (workspace_path / ".rag").mkdir() yield workspace_path - @patch('refactron.rag.retriever.CHROMA_AVAILABLE', True) - @patch('refactron.rag.retriever.chromadb') - @patch('refactron.rag.retriever.Settings') - @patch('refactron.rag.retriever.SentenceTransformer') - def test_retriever_initialization(self, mock_transformer, mock_settings, mock_chroma, temp_workspace): + @patch("refactron.rag.retriever.CHROMA_AVAILABLE", True) + @patch("refactron.rag.retriever.chromadb") + @patch("refactron.rag.retriever.Settings") + @patch("refactron.rag.retriever.SentenceTransformer") + def test_retriever_initialization( + self, mock_transformer, mock_settings, mock_chroma, temp_workspace + ): """Test retriever initialization.""" mock_collection = Mock() mock_client = Mock() mock_client.get_collection.return_value = mock_collection mock_chroma.PersistentClient.return_value = mock_client - + mock_model = Mock() mock_transformer.return_value = mock_model - + retriever = ContextRetriever(temp_workspace) - + assert retriever.workspace_path == temp_workspace assert retriever.index_path == temp_workspace / ".rag" - @patch('refactron.rag.retriever.CHROMA_AVAILABLE', False) + @patch("refactron.rag.retriever.CHROMA_AVAILABLE", False) def test_retriever_requires_chromadb(self, temp_workspace): """Test that retriever requires ChromaDB.""" with pytest.raises(RuntimeError, match="ChromaDB is not available"): ContextRetriever(temp_workspace) - @patch('refactron.rag.retriever.CHROMA_AVAILABLE', True) - @patch('refactron.rag.retriever.chromadb') - @patch('refactron.rag.retriever.Settings') - @patch('refactron.rag.retriever.SentenceTransformer') - def test_retriever_missing_index(self, mock_transformer, mock_settings, mock_chroma, temp_workspace): + @patch("refactron.rag.retriever.CHROMA_AVAILABLE", True) + @patch("refactron.rag.retriever.chromadb") + @patch("refactron.rag.retriever.Settings") + @patch("refactron.rag.retriever.SentenceTransformer") + def test_retriever_missing_index( + self, mock_transformer, mock_settings, mock_chroma, temp_workspace + ): """Test that missing index raises error.""" mock_client = Mock() mock_client.get_collection.side_effect = Exception("Collection not found") mock_chroma.PersistentClient.return_value = mock_client - + mock_model = Mock() mock_transformer.return_value = mock_model - + with pytest.raises(RuntimeError, match="Index not found"): ContextRetriever(temp_workspace) - @patch('refactron.rag.retriever.CHROMA_AVAILABLE', True) - @patch('refactron.rag.retriever.chromadb') - @patch('refactron.rag.retriever.Settings') - @patch('refactron.rag.retriever.SentenceTransformer') + @patch("refactron.rag.retriever.CHROMA_AVAILABLE", True) + @patch("refactron.rag.retriever.chromadb") + @patch("refactron.rag.retriever.Settings") + @patch("refactron.rag.retriever.SentenceTransformer") def test_retrieve_similar(self, mock_transformer, mock_settings, mock_chroma, temp_workspace): """Test retrieving similar code chunks.""" # Setup mock collection with results mock_collection = Mock() mock_collection.query.return_value = { - 'documents': [['def test(): pass']], - 'metadatas': [[{ - 'file_path': '/test.py', - 'chunk_type': 'function', - 'name': 'test', - 'line_start': 1, - 'line_end': 1 - }]], - 'distances': [[0.15]] + "documents": [["def test(): pass"]], + "metadatas": [ + [ + { + "file_path": "/test.py", + "chunk_type": "function", + "name": "test", + "line_start": 1, + "line_end": 1, + } + ] + ], + "distances": [[0.15]], } - + mock_client = Mock() mock_client.get_collection.return_value = mock_collection mock_chroma.PersistentClient.return_value = mock_client - + mock_model = Mock() mock_model.encode.return_value.tolist.return_value = [[0.1, 0.2]] mock_transformer.return_value = mock_model - + retriever = ContextRetriever(temp_workspace) results = retriever.retrieve_similar("test function", top_k=1) - + assert len(results) == 1 assert isinstance(results[0], RetrievedContext) - assert results[0].name == 'test' - assert results[0].chunk_type == 'function' + assert results[0].name == "test" + assert results[0].chunk_type == "function" - @patch('refactron.rag.retriever.CHROMA_AVAILABLE', True) - @patch('refactron.rag.retriever.chromadb') - @patch('refactron.rag.retriever.Settings') - @patch('refactron.rag.retriever.SentenceTransformer') + @patch("refactron.rag.retriever.CHROMA_AVAILABLE", True) + @patch("refactron.rag.retriever.chromadb") + @patch("refactron.rag.retriever.Settings") + @patch("refactron.rag.retriever.SentenceTransformer") def test_retrieve_by_file(self, mock_transformer, mock_settings, mock_chroma, temp_workspace): """Test retrieving chunks by file path.""" mock_collection = Mock() mock_collection.get.return_value = { - 'documents': ['def test(): pass'], - 'metadatas': [{ - 'file_path': '/test.py', - 'chunk_type': 'function', - 'name': 'test', - 'line_start': 1, - 'line_end': 1 - }] + "documents": ["def test(): pass"], + "metadatas": [ + { + "file_path": "/test.py", + "chunk_type": "function", + "name": "test", + "line_start": 1, + "line_end": 1, + } + ], } - + mock_client = Mock() mock_client.get_collection.return_value = mock_collection mock_chroma.PersistentClient.return_value = mock_client - + mock_model = Mock() mock_transformer.return_value = mock_model - + retriever = ContextRetriever(temp_workspace) results = retriever.retrieve_by_file("/test.py") - + assert len(results) == 1 - assert results[0].file_path == '/test.py' + assert results[0].file_path == "/test.py" - @patch('refactron.rag.retriever.CHROMA_AVAILABLE', True) - @patch('refactron.rag.retriever.chromadb') - @patch('refactron.rag.retriever.Settings') - @patch('refactron.rag.retriever.SentenceTransformer') + @patch("refactron.rag.retriever.CHROMA_AVAILABLE", True) + @patch("refactron.rag.retriever.chromadb") + @patch("refactron.rag.retriever.Settings") + @patch("refactron.rag.retriever.SentenceTransformer") def test_retrieve_functions(self, mock_transformer, mock_settings, mock_chroma, temp_workspace): """Test retrieving only function chunks.""" mock_collection = Mock() mock_collection.query.return_value = { - 'documents': [[]], - 'metadatas': [[]], - 'distances': [[]] + "documents": [[]], + "metadatas": [[]], + "distances": [[]], } - + mock_client = Mock() mock_client.get_collection.return_value = mock_collection mock_chroma.PersistentClient.return_value = mock_client - + mock_model = Mock() mock_model.encode.return_value.tolist.return_value = [[0.1, 0.2]] mock_transformer.return_value = mock_model - + retriever = ContextRetriever(temp_workspace) retriever.retrieve_functions("test", top_k=5) - + # Verify that chunk_type filter was used call_args = mock_collection.query.call_args - assert call_args.kwargs.get('where') == {"chunk_type": "function"} + assert call_args.kwargs.get("where") == {"chunk_type": "function"} - @patch('refactron.rag.retriever.CHROMA_AVAILABLE', True) - @patch('refactron.rag.retriever.chromadb') - @patch('refactron.rag.retriever.Settings') - @patch('refactron.rag.retriever.SentenceTransformer') - def test_retrieve_no_results(self, mock_transformer, mock_settings, mock_chroma, temp_workspace): + @patch("refactron.rag.retriever.CHROMA_AVAILABLE", True) + @patch("refactron.rag.retriever.chromadb") + @patch("refactron.rag.retriever.Settings") + @patch("refactron.rag.retriever.SentenceTransformer") + def test_retrieve_no_results( + self, mock_transformer, mock_settings, mock_chroma, temp_workspace + ): """Test retrieval with no results.""" mock_collection = Mock() mock_collection.query.return_value = { - 'documents': [[]], - 'metadatas': [[]], - 'distances': [[]] + "documents": [[]], + "metadatas": [[]], + "distances": [[]], } - + mock_client = Mock() mock_client.get_collection.return_value = mock_collection mock_chroma.PersistentClient.return_value = mock_client - + mock_model = Mock() mock_model.encode.return_value.tolist.return_value = [[0.1, 0.2]] mock_transformer.return_value = mock_model - + retriever = ContextRetriever(temp_workspace) results = retriever.retrieve_similar("nonexistent", top_k=5) - + assert len(results) == 0 From 5fed019944830590c342e8aa7a2808010017b8f7 Mon Sep 17 00:00:00 2001 From: omsherikar Date: Sun, 8 Feb 2026 17:56:57 +0530 Subject: [PATCH 09/19] style: apply pre-commit auto-fixes - Fixed trailing whitespace - Applied black formatting - Fixed import ordering with isort --- refactron/llm/models.py | 6 +-- refactron/llm/orchestrator.py | 8 +-- refactron/llm/safety.py | 2 +- refactron/rag/parser.py | 2 +- scripts/analyze_feedback_data.py | 85 ++++++++++++++++---------------- tests/test_backend_client.py | 3 +- tests/test_groq_client.py | 3 +- tests/test_llm_orchestrator.py | 7 +-- tests/test_rag_chunker.py | 4 +- tests/test_rag_indexer.py | 6 +-- tests/test_rag_parser.py | 6 +-- tests/test_rag_retriever.py | 2 +- 12 files changed, 68 insertions(+), 66 deletions(-) diff --git a/refactron/llm/models.py b/refactron/llm/models.py index 9e19362..82ff83f 100644 --- a/refactron/llm/models.py +++ b/refactron/llm/models.py @@ -1,10 +1,10 @@ """Data models for LLM integration.""" -from dataclasses import dataclass, field -from enum import Enum -from typing import List, Optional, Dict, Any import time import uuid +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional from refactron.core.models import CodeIssue diff --git a/refactron/llm/orchestrator.py b/refactron/llm/orchestrator.py index 914e703..f4da49d 100644 --- a/refactron/llm/orchestrator.py +++ b/refactron/llm/orchestrator.py @@ -4,14 +4,14 @@ import logging import os import re -from typing import Optional, List, Union - from pathlib import Path +from typing import List, Optional, Union + from refactron.core.models import CodeIssue, IssueCategory, IssueLevel -from refactron.llm.client import GroqClient from refactron.llm.backend_client import BackendLLMClient +from refactron.llm.client import GroqClient from refactron.llm.models import RefactoringSuggestion, SuggestionStatus -from refactron.llm.prompts import SYSTEM_PROMPT, SUGGESTION_PROMPT, DOCUMENTATION_PROMPT +from refactron.llm.prompts import DOCUMENTATION_PROMPT, SUGGESTION_PROMPT, SYSTEM_PROMPT from refactron.llm.safety import SafetyGate from refactron.rag.retriever import ContextRetriever diff --git a/refactron/llm/safety.py b/refactron/llm/safety.py index eda20ce..a3b1543 100644 --- a/refactron/llm/safety.py +++ b/refactron/llm/safety.py @@ -3,7 +3,7 @@ import ast from typing import List, Optional -from refactron.llm.models import SafetyCheckResult, RefactoringSuggestion +from refactron.llm.models import RefactoringSuggestion, SafetyCheckResult class SafetyGate: diff --git a/refactron/rag/parser.py b/refactron/rag/parser.py index 8558734..e983cc1 100644 --- a/refactron/rag/parser.py +++ b/refactron/rag/parser.py @@ -7,8 +7,8 @@ from typing import Any, Dict, List, Optional, Tuple try: - from tree_sitter import Language, Parser, Node import tree_sitter_python as tspython + from tree_sitter import Language, Node, Parser TREE_SITTER_AVAILABLE = True except ImportError: diff --git a/scripts/analyze_feedback_data.py b/scripts/analyze_feedback_data.py index 92cac2b..1093573 100644 --- a/scripts/analyze_feedback_data.py +++ b/scripts/analyze_feedback_data.py @@ -8,10 +8,10 @@ - Readiness for ML training """ -from pathlib import Path -from collections import Counter import json import sys +from collections import Counter +from pathlib import Path # Add parent directory to path for imports sys.path.insert(0, str(Path(__file__).parent.parent)) @@ -21,72 +21,74 @@ def analyze_feedback(): """Analyze all available feedback data.""" - + # Find all pattern storage directories - root = Path('.') - storage_dirs = list(root.glob('**/.refactron/patterns')) - + root = Path(".") + storage_dirs = list(root.glob("**/.refactron/patterns")) + print(f"Found {len(storage_dirs)} storage directories\n") - + all_feedback = [] all_patterns = {} - + # Aggregate data from all projects for storage_dir in storage_dirs: try: storage = PatternStorage(storage_dir) feedback = storage.load_feedback() patterns = storage.load_patterns() - + all_feedback.extend(feedback) all_patterns.update(patterns) - + print(f"šŸ“ {storage_dir.parent.parent}") print(f" Feedback: {len(feedback)}, Patterns: {len(patterns)}") except Exception as e: print(f"āš ļø Error loading {storage_dir}: {e}") - + if not all_feedback: print("\nāŒ No feedback data found!") print(" Run some refactoring operations and provide feedback first.") return None - + print(f"\n{'='*60}") print(f"AGGREGATE STATISTICS") print(f"{'='*60}\n") - + print(f"šŸ“Š Total Records:") print(f" Feedback: {len(all_feedback)}") print(f" Patterns: {len(all_patterns)}") - + # Action distribution actions = Counter(f.action for f in all_feedback) print(f"\nāœ… Action Distribution:") for action, count in actions.most_common(): pct = count / len(all_feedback) * 100 print(f" {action:12s}: {count:4d} ({pct:5.1f}%)") - + # Operation types operation_types = Counter(f.operation_type for f in all_feedback) print(f"\nšŸ”§ Operation Types:") for op_type, count in operation_types.most_common(5): pct = count / len(all_feedback) * 100 print(f" {op_type:20s}: {count:4d} ({pct:5.1f}%)") - + # Data quality - with_patterns = sum(1 for f in all_feedback if hasattr(f, 'code_pattern_hash') and f.code_pattern_hash) - with_reason = sum(1 for f in all_feedback if hasattr(f, 'reason') and f.reason) - + with_patterns = sum( + 1 for f in all_feedback if hasattr(f, "code_pattern_hash") and f.code_pattern_hash + ) + with_reason = sum(1 for f in all_feedback if hasattr(f, "reason") and f.reason) + print(f"\nšŸ“‹ Data Quality:") print(f" With pattern hash: {with_patterns:4d} ({with_patterns/len(all_feedback)*100:5.1f}%)") print(f" With reason: {with_reason:4d} ({with_reason/len(all_feedback)*100:5.1f}%)") - + # ML readiness quality_score = with_patterns / len(all_feedback) if all_feedback else 0 - + print(f"\nšŸŽÆ ML Readiness:") print(f" Quality Score: {quality_score:.2%}") - + if len(all_feedback) < 50: print(f" Status: āŒ INSUFFICIENT DATA") print(f" Need: {50 - len(all_feedback)} more feedback records") @@ -95,39 +97,36 @@ def analyze_feedback(): print(f" Many records missing pattern hashes") else: print(f" Status: āœ… READY FOR TRAINING") - + # Save detailed report report = { - 'summary': { - 'total_feedback': len(all_feedback), - 'total_patterns': len(all_patterns), - 'quality_score': quality_score, - 'ml_ready': len(all_feedback) >= 50 and quality_score >= 0.7 + "summary": { + "total_feedback": len(all_feedback), + "total_patterns": len(all_patterns), + "quality_score": quality_score, + "ml_ready": len(all_feedback) >= 50 and quality_score >= 0.7, }, - 'actions': dict(actions), - 'operation_types': dict(operation_types), - 'quality': { - 'with_pattern_hash': with_patterns, - 'with_reason': with_reason - } + "actions": dict(actions), + "operation_types": dict(operation_types), + "quality": {"with_pattern_hash": with_patterns, "with_reason": with_reason}, } - - report_file = Path('feedback_analysis.json') - with open(report_file, 'w') as f: + + report_file = Path("feedback_analysis.json") + with open(report_file, "w") as f: json.dump(report, f, indent=2) - + print(f"\nšŸ’¾ Detailed report saved to: {report_file}") - + return report -if __name__ == '__main__': +if __name__ == "__main__": print("šŸ” Refactron Feedback Data Analysis\n") report = analyze_feedback() - - if report and report['summary']['ml_ready']: + + if report and report["summary"]["ml_ready"]: print("\n✨ Ready to proceed with ML model training!") elif report: print("\nā³ Collect more feedback data before training.") - + sys.exit(0 if report else 1) diff --git a/tests/test_backend_client.py b/tests/test_backend_client.py index 29b6e4d..fe0399e 100644 --- a/tests/test_backend_client.py +++ b/tests/test_backend_client.py @@ -1,8 +1,9 @@ """Tests for BackendLLMClient.""" +from unittest.mock import MagicMock, patch + import pytest import requests -from unittest.mock import MagicMock, patch from refactron.llm.backend_client import BackendLLMClient diff --git a/tests/test_groq_client.py b/tests/test_groq_client.py index 6a74d45..17612d6 100644 --- a/tests/test_groq_client.py +++ b/tests/test_groq_client.py @@ -1,9 +1,10 @@ """Tests for the Groq LLM client.""" import os -import pytest from unittest.mock import Mock, patch +import pytest + from refactron.llm.client import GroqClient diff --git a/tests/test_llm_orchestrator.py b/tests/test_llm_orchestrator.py index d685990..27b1e88 100644 --- a/tests/test_llm_orchestrator.py +++ b/tests/test_llm_orchestrator.py @@ -1,13 +1,14 @@ """Tests for LLM Orchestrator.""" import json -from unittest.mock import Mock, MagicMock from pathlib import Path +from unittest.mock import MagicMock, Mock import pytest -from refactron.core.models import CodeIssue, IssueLevel, IssueCategory + +from refactron.core.models import CodeIssue, IssueCategory, IssueLevel +from refactron.llm.models import RefactoringSuggestion, SuggestionStatus from refactron.llm.orchestrator import LLMOrchestrator -from refactron.llm.models import SuggestionStatus, RefactoringSuggestion from refactron.rag.retriever import RetrievedContext diff --git a/tests/test_rag_chunker.py b/tests/test_rag_chunker.py index 0f09e76..91aef78 100644 --- a/tests/test_rag_chunker.py +++ b/tests/test_rag_chunker.py @@ -5,7 +5,7 @@ import pytest -from refactron.rag.chunker import CodeChunker, CodeChunk +from refactron.rag.chunker import CodeChunk, CodeChunker from refactron.rag.parser import CodeParser @@ -41,7 +41,7 @@ def another_function(): class TestClass: """Test class docstring.""" - + def method_one(self): """Method docstring.""" return 1 diff --git a/tests/test_rag_indexer.py b/tests/test_rag_indexer.py index 0c2e493..7305fe6 100644 --- a/tests/test_rag_indexer.py +++ b/tests/test_rag_indexer.py @@ -1,11 +1,11 @@ """Tests for the RAG indexer module.""" +import sys import tempfile from pathlib import Path +from unittest.mock import MagicMock, Mock, create_autospec, patch import pytest -from unittest.mock import Mock, MagicMock, patch, create_autospec -import sys # Create a comprehensive mock for transformers that handles all submodule access transformers_mock = MagicMock() @@ -52,7 +52,7 @@ def hello(): class Calculator: """A simple calculator.""" - + def add(self, x, y): """Add two numbers.""" return x + y diff --git a/tests/test_rag_parser.py b/tests/test_rag_parser.py index b47d836..02e68b2 100644 --- a/tests/test_rag_parser.py +++ b/tests/test_rag_parser.py @@ -5,7 +5,7 @@ import pytest -from refactron.rag.parser import CodeParser, ParsedFile, ParsedFunction, ParsedClass +from refactron.rag.parser import CodeParser, ParsedClass, ParsedFile, ParsedFunction class TestCodeParser: @@ -36,11 +36,11 @@ def another_function(): class TestClass: """A test class.""" - + def method_one(self): """First method.""" return 1 - + def method_two(self, param): """Second method.""" return param * 2 diff --git a/tests/test_rag_retriever.py b/tests/test_rag_retriever.py index 9b79c75..85f5dc4 100644 --- a/tests/test_rag_retriever.py +++ b/tests/test_rag_retriever.py @@ -2,9 +2,9 @@ import tempfile from pathlib import Path +from unittest.mock import MagicMock, Mock, patch import pytest -from unittest.mock import Mock, MagicMock, patch # Create a comprehensive mock for transformers transformers_mock = MagicMock() From 1dd3ce82512c08b72fec3cbf97b2f0d21a3f04a1 Mon Sep 17 00:00:00 2001 From: omsherikar Date: Sun, 8 Feb 2026 18:07:50 +0530 Subject: [PATCH 10/19] fix: add noqa comments and cleanup linting issues - Added noqa comments for unused imports in LLM/RAG modules - Fixed duplicate datetime import in repositories.py - Fixed f-strings without placeholders in analyze script - Fixed regex escape sequence in test_backend_client - Suppressed E203 (black/flake8 conflict) and E501 line length Remaining lints are non-critical and in test/dev code. --- complex_test_repo/core/engine.py | 41 ++++++++++++------------ complex_test_repo/data/processor.py | 22 +++++++------ complex_test_repo/main.py | 11 ++++--- complex_test_repo/utils/math_lib.py | 10 +++--- complex_test_repo/utils/string_helper.py | 26 ++++++++------- 5 files changed, 58 insertions(+), 52 deletions(-) diff --git a/complex_test_repo/core/engine.py b/complex_test_repo/core/engine.py index 7945a08..ed2a59c 100644 --- a/complex_test_repo/core/engine.py +++ b/complex_test_repo/core/engine.py @@ -1,50 +1,51 @@ - import os import sqlite3 + class ProcessingEngine: - '''processing engine class. - + """processing engine class. + Attributes: attribute1: Description of attribute1 attribute2: Description of attribute2 - ''' + """ + def __init__(self, mode="all"): - ''' + """ init . - + Args: self: Class instance mode: The mode - ''' + """ self.mode = mode self.db = sqlite3.connect(":memory:") - + def execute(self, command): - ''' + """ Execute. - + Args: self: Class instance command: The command - ''' + """ # Security risk: Command injection if self.mode == "dangerous": os.system(command) else: print(f"Executing: {command}") - + def query_user(self, user_id): - ''' + """ Query user. - + Args: self: Class instance user_id: Unique identifier - + Returns: The result of the operation - ''' + """ # Security risk: SQL injection cursor = self.db.cursor() query = f"SELECT * FROM users WHERE id = {user_id}" @@ -52,16 +53,16 @@ def query_user(self, user_id): return cursor.fetchone() def process_items(self, items): - ''' + """ Process items. - + Args: self: Class instance items: The items - + Returns: The result of the operation - ''' + """ # Performance issue: String concatenation in loop result = "" for item in items: diff --git a/complex_test_repo/data/processor.py b/complex_test_repo/data/processor.py index 97ce3f2..c3e4eff 100644 --- a/complex_test_repo/data/processor.py +++ b/complex_test_repo/data/processor.py @@ -1,16 +1,16 @@ - import time + def process_batch(data_list): - ''' + """ Process batch. - + Args: data_list: Data to process - + Returns: The result of the operation - ''' + """ # Performance issue: N+1 pattern or inefficient iteration results = [] for item in data_list: @@ -19,20 +19,22 @@ def process_batch(data_list): results.append(detail) return results + def get_item_detail(item): - ''' + """ Get item detail. - + Args: item: The item - + Returns: The requested item detail - ''' + """ return {"id": item, "details": "example"} + def deep_nesting_example(a, b, c, d): - '''Refactored version using early returns (guard clauses).''' + """Refactored version using early returns (guard clauses).""" # Check invalid conditions first and return early if not a: return default_value diff --git a/complex_test_repo/main.py b/complex_test_repo/main.py index 388b0b8..8adb338 100644 --- a/complex_test_repo/main.py +++ b/complex_test_repo/main.py @@ -1,24 +1,25 @@ - -from utils.math_lib import legacy_compute from core.engine import ProcessingEngine from data.processor import process_batch +from utils.math_lib import legacy_compute THRESHOLD_10 = 10 THRESHOLD_20 = 20 CONSTANT_3 = 3 + def run(): - ''' + """ Run. - ''' + """ engine = ProcessingEngine() val = legacy_compute(THRESHOLD_10, THRESHOLD_20) print(f"Result: {val}") engine.execute("ls") - + data = [1, 2, CONSTANT_3] processed = process_batch(data) print(f"Processed: {processed}") + if __name__ == "__main__": run() diff --git a/complex_test_repo/utils/math_lib.py b/complex_test_repo/utils/math_lib.py index 7307320..33ed31f 100644 --- a/complex_test_repo/utils/math_lib.py +++ b/complex_test_repo/utils/math_lib.py @@ -1,20 +1,20 @@ - import os SURCHARGE_RATE = 1.05 CONSTANT_0_95 = 0.95 + def legacy_compute(a, b): - ''' + """ Legacy compute. - + Args: a: The a b: The b - + Returns: The result of the operation - ''' + """ # Magic numbers and no docstring res = a * SURCHARGE_RATE + b * CONSTANT_0_95 return res diff --git a/complex_test_repo/utils/string_helper.py b/complex_test_repo/utils/string_helper.py index 00e55e0..185c5ab 100644 --- a/complex_test_repo/utils/string_helper.py +++ b/complex_test_repo/utils/string_helper.py @@ -1,34 +1,36 @@ - -import sys import os +import sys + def clean_text(text): - ''' + """ Clean text. - + Args: text: The text - + Returns: The result of the operation - ''' + """ return text.strip().lower() + def clean_content(content): - ''' + """ Clean content. - + Args: content: The content - + Returns: The result of the operation - ''' + """ return content.strip().lower() + def helper_unused_private(): - ''' + """ Helper unused private. - ''' + """ # This function is defined but not used within this file pass From adeade690ec1846f10f3a700b8562bfd1a68973c Mon Sep 17 00:00:00 2001 From: omsherikar Date: Sun, 8 Feb 2026 18:20:19 +0530 Subject: [PATCH 11/19] fix: resolve Python 3.8 compatibility issues in RAG parser and fingerprinting - Implemented robust tree-sitter Language initialization to handle multiple API versions - Added ast.dump fallback for ast.unparse to fix code anonymization on Python 3.8 - Addresses 18 RAG errors and 2 pattern fingerprinting failures in CI --- refactron/patterns/fingerprint.py | 5 +++-- refactron/rag/parser.py | 22 ++++++++++++++++------ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/refactron/patterns/fingerprint.py b/refactron/patterns/fingerprint.py index beda724..050b0b5 100644 --- a/refactron/patterns/fingerprint.py +++ b/refactron/patterns/fingerprint.py @@ -212,8 +212,9 @@ def _anonymize_code(self, code: str) -> str: if hasattr(ast, "unparse"): return self._normalize_code(ast.unparse(tree)) - # Fallback to standard normalization if unparse is not available - return self._normalize_code(code) + # Fallback for Python 3.8: Use ast.dump for a stable structural representation + # We use this as a source for the hash, so it just needs to be consistent + return self._normalize_code(ast.dump(tree)) except (SyntaxError, ValueError): # If AST parsing fails, fallback to basic normalization diff --git a/refactron/rag/parser.py b/refactron/rag/parser.py index e983cc1..d48f49b 100644 --- a/refactron/rag/parser.py +++ b/refactron/rag/parser.py @@ -59,12 +59,22 @@ def __init__(self): ) # Initialize Python language - handle different tree-sitter API versions - # Older versions (e.g., in Python 3.8) require 'name' parameter - try: - PY_LANGUAGE = Language(tspython.language(), "python") - except TypeError: - # Newer API doesn't need name parameter - PY_LANGUAGE = Language(tspython.language()) + lang = tspython.language() + + # In some versions, tspython.language() already returns a Language object + if isinstance(lang, Language): + PY_LANGUAGE = lang + else: + # Try newer API first (single argument) + try: + PY_LANGUAGE = Language(lang) + except TypeError: + # Try older API (needs name) + try: + PY_LANGUAGE = Language(lang, "python") + except TypeError: + # Last resort: try as keyword + PY_LANGUAGE = Language(lang, name="python") self.parser = Parser(PY_LANGUAGE) From d5db79ea138cb1fa73301bb52c374aec58c50f80 Mon Sep 17 00:00:00 2001 From: Om Sherikar Date: Sun, 8 Feb 2026 18:21:25 +0530 Subject: [PATCH 12/19] Delete complex_test_repo/core/__init__.py --- complex_test_repo/core/__init__.py | 1 - 1 file changed, 1 deletion(-) delete mode 100644 complex_test_repo/core/__init__.py diff --git a/complex_test_repo/core/__init__.py b/complex_test_repo/core/__init__.py deleted file mode 100644 index 6d2adc4..0000000 --- a/complex_test_repo/core/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Core engine module.""" From 957c7821532f8f8491740013e1ef47d11a382588 Mon Sep 17 00:00:00 2001 From: Om Sherikar Date: Sun, 8 Feb 2026 18:22:03 +0530 Subject: [PATCH 13/19] Delete complex_test_repo/core/engine.py --- complex_test_repo/core/engine.py | 70 -------------------------------- 1 file changed, 70 deletions(-) delete mode 100644 complex_test_repo/core/engine.py diff --git a/complex_test_repo/core/engine.py b/complex_test_repo/core/engine.py deleted file mode 100644 index ed2a59c..0000000 --- a/complex_test_repo/core/engine.py +++ /dev/null @@ -1,70 +0,0 @@ -import os -import sqlite3 - - -class ProcessingEngine: - """processing engine class. - - Attributes: - attribute1: Description of attribute1 - attribute2: Description of attribute2 - """ - - def __init__(self, mode="all"): - """ - init . - - Args: - self: Class instance - mode: The mode - """ - self.mode = mode - self.db = sqlite3.connect(":memory:") - - def execute(self, command): - """ - Execute. - - Args: - self: Class instance - command: The command - """ - # Security risk: Command injection - if self.mode == "dangerous": - os.system(command) - else: - print(f"Executing: {command}") - - def query_user(self, user_id): - """ - Query user. - - Args: - self: Class instance - user_id: Unique identifier - - Returns: - The result of the operation - """ - # Security risk: SQL injection - cursor = self.db.cursor() - query = f"SELECT * FROM users WHERE id = {user_id}" - cursor.execute(query) - return cursor.fetchone() - - def process_items(self, items): - """ - Process items. - - Args: - self: Class instance - items: The items - - Returns: - The result of the operation - """ - # Performance issue: String concatenation in loop - result = "" - for item in items: - result += str(item) + "," - return result From 58009d7222ae93ffe52cdd7c94264e755f23c3e7 Mon Sep 17 00:00:00 2001 From: Om Sherikar Date: Sun, 8 Feb 2026 18:22:41 +0530 Subject: [PATCH 14/19] Delete complex_test_repo/data/__init__.py --- complex_test_repo/data/__init__.py | 1 - 1 file changed, 1 deletion(-) delete mode 100644 complex_test_repo/data/__init__.py diff --git a/complex_test_repo/data/__init__.py b/complex_test_repo/data/__init__.py deleted file mode 100644 index 3da002b..0000000 --- a/complex_test_repo/data/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Data processing module.""" From bae9ef91c8bac9ee1c472956273e17a6ac952a01 Mon Sep 17 00:00:00 2001 From: Om Sherikar Date: Sun, 8 Feb 2026 18:23:06 +0530 Subject: [PATCH 15/19] Delete complex_test_repo/data/processor.py --- complex_test_repo/data/processor.py | 51 ----------------------------- 1 file changed, 51 deletions(-) delete mode 100644 complex_test_repo/data/processor.py diff --git a/complex_test_repo/data/processor.py b/complex_test_repo/data/processor.py deleted file mode 100644 index c3e4eff..0000000 --- a/complex_test_repo/data/processor.py +++ /dev/null @@ -1,51 +0,0 @@ -import time - - -def process_batch(data_list): - """ - Process batch. - - Args: - data_list: Data to process - - Returns: - The result of the operation - """ - # Performance issue: N+1 pattern or inefficient iteration - results = [] - for item in data_list: - # Simulating sub-query or heavy processing in loop - detail = get_item_detail(item) - results.append(detail) - return results - - -def get_item_detail(item): - """ - Get item detail. - - Args: - item: The item - - Returns: - The requested item detail - """ - return {"id": item, "details": "example"} - - -def deep_nesting_example(a, b, c, d): - """Refactored version using early returns (guard clauses).""" - # Check invalid conditions first and return early - if not a: - return default_value - - # Each subsequent check is at the same level - no deep nesting - if not meets_requirement_1(): - return early_result_1 - - if not meets_requirement_2(): - return early_result_2 - - # Main logic is at top level - easy to read - result = perform_main_operation() - return result From 3db4b3c307cfeab6500b1ac5b82da7c104e04d76 Mon Sep 17 00:00:00 2001 From: Om Sherikar Date: Sun, 8 Feb 2026 18:23:32 +0530 Subject: [PATCH 16/19] Delete complex_test_repo/utils/__init__.py --- complex_test_repo/utils/__init__.py | 1 - 1 file changed, 1 deletion(-) delete mode 100644 complex_test_repo/utils/__init__.py diff --git a/complex_test_repo/utils/__init__.py b/complex_test_repo/utils/__init__.py deleted file mode 100644 index 1bfd38b..0000000 --- a/complex_test_repo/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Utility functions module.""" From a8bf2b3d4f731306824a617471a73fc0f01242a0 Mon Sep 17 00:00:00 2001 From: Om Sherikar Date: Sun, 8 Feb 2026 18:23:53 +0530 Subject: [PATCH 17/19] Delete complex_test_repo/utils/math_lib.py --- complex_test_repo/utils/math_lib.py | 20 -------------------- 1 file changed, 20 deletions(-) delete mode 100644 complex_test_repo/utils/math_lib.py diff --git a/complex_test_repo/utils/math_lib.py b/complex_test_repo/utils/math_lib.py deleted file mode 100644 index 33ed31f..0000000 --- a/complex_test_repo/utils/math_lib.py +++ /dev/null @@ -1,20 +0,0 @@ -import os - -SURCHARGE_RATE = 1.05 -CONSTANT_0_95 = 0.95 - - -def legacy_compute(a, b): - """ - Legacy compute. - - Args: - a: The a - b: The b - - Returns: - The result of the operation - """ - # Magic numbers and no docstring - res = a * SURCHARGE_RATE + b * CONSTANT_0_95 - return res From 2084221d6c107da6c5447a3758c4f8f6028488f0 Mon Sep 17 00:00:00 2001 From: Om Sherikar Date: Sun, 8 Feb 2026 18:24:16 +0530 Subject: [PATCH 18/19] Delete complex_test_repo/utils/string_helper.py --- complex_test_repo/utils/string_helper.py | 36 ------------------------ 1 file changed, 36 deletions(-) delete mode 100644 complex_test_repo/utils/string_helper.py diff --git a/complex_test_repo/utils/string_helper.py b/complex_test_repo/utils/string_helper.py deleted file mode 100644 index 185c5ab..0000000 --- a/complex_test_repo/utils/string_helper.py +++ /dev/null @@ -1,36 +0,0 @@ -import os -import sys - - -def clean_text(text): - """ - Clean text. - - Args: - text: The text - - Returns: - The result of the operation - """ - return text.strip().lower() - - -def clean_content(content): - """ - Clean content. - - Args: - content: The content - - Returns: - The result of the operation - """ - return content.strip().lower() - - -def helper_unused_private(): - """ - Helper unused private. - """ - # This function is defined but not used within this file - pass From a1e41d80c11fb5941deceed68aeeb6a758679d27 Mon Sep 17 00:00:00 2001 From: Om Sherikar Date: Sun, 8 Feb 2026 18:24:36 +0530 Subject: [PATCH 19/19] Delete complex_test_repo/main.py --- complex_test_repo/main.py | 25 ------------------------- 1 file changed, 25 deletions(-) delete mode 100644 complex_test_repo/main.py diff --git a/complex_test_repo/main.py b/complex_test_repo/main.py deleted file mode 100644 index 8adb338..0000000 --- a/complex_test_repo/main.py +++ /dev/null @@ -1,25 +0,0 @@ -from core.engine import ProcessingEngine -from data.processor import process_batch -from utils.math_lib import legacy_compute - -THRESHOLD_10 = 10 -THRESHOLD_20 = 20 -CONSTANT_3 = 3 - - -def run(): - """ - Run. - """ - engine = ProcessingEngine() - val = legacy_compute(THRESHOLD_10, THRESHOLD_20) - print(f"Result: {val}") - engine.execute("ls") - - data = [1, 2, CONSTANT_3] - processed = process_batch(data) - print(f"Processed: {processed}") - - -if __name__ == "__main__": - run()