diff --git a/core/prompt.py b/core/prompt.py deleted file mode 100644 index f6cc554..0000000 --- a/core/prompt.py +++ /dev/null @@ -1,27 +0,0 @@ -import os -import getpass -from datetime import datetime -from colorama import Fore, Style -from .git import get_git_info - -def get_prompt(): - """Constructs a customized prompt with user, host, path, git and venv info.""" - user = getpass.getuser() - host = os.uname().nodename - path = os.getcwd() - if len(path) > 42: - path = "../" + os.path.basename(path) - git_info = get_git_info() - venv = os.environ.get('VIRTUAL_ENV', '').split('/')[-1] if os.environ.get('VIRTUAL_ENV') else "" - current_time = datetime.now().strftime('%H:%M:%S') - prompt_parts = [ - f"{Fore.RED}[SheLLM]{Style.RESET_ALL}", - f"{Fore.BLUE}[{current_time}]{Style.RESET_ALL}", - f"{Fore.GREEN}{user}{Style.RESET_ALL}@{host}:", - f"{Fore.GREEN}{path}{Style.RESET_ALL}" - ] - if git_info: - prompt_parts.append(f"{Fore.CYAN}({git_info}){Style.RESET_ALL}") - if venv: - prompt_parts.append(f"{Fore.MAGENTA}(venv:{venv}){Style.RESET_ALL}") - return ' '.join(prompt_parts) + "\n>" diff --git a/core/prompts.py b/core/prompts.py new file mode 100644 index 0000000..e1f259d --- /dev/null +++ b/core/prompts.py @@ -0,0 +1,65 @@ +import os +import getpass +from datetime import datetime +from colorama import Fore, Style + +from .git import get_git_info +from utils.schemas import Context + + +def get_prompt(): + """Constructs a customized prompt with user, host, path, git and venv info.""" + user = getpass.getuser() + host = os.uname().nodename + path = os.getcwd() + if len(path) > 42: + path = "../" + os.path.basename(path) + git_info = get_git_info() + venv = os.environ.get('VIRTUAL_ENV', '').split('/')[-1] if os.environ.get('VIRTUAL_ENV') else "" + current_time = datetime.now().strftime('%H:%M:%S') + prompt_parts = [ + f"{Fore.RED}[SheLLM]{Style.RESET_ALL}", + f"{Fore.BLUE}[{current_time}]{Style.RESET_ALL}", + f"{Fore.GREEN}{user}{Style.RESET_ALL}@{host}:", + f"{Fore.GREEN}{path}{Style.RESET_ALL}" + ] + if git_info: + prompt_parts.append(f"{Fore.CYAN}({git_info}){Style.RESET_ALL}") + if venv: + prompt_parts.append(f"{Fore.MAGENTA}(venv:{venv}){Style.RESET_ALL}") + return ' '.join(prompt_parts) + "\n>" + + +def generate_shell_system_prompt(context: Context) -> str: + """System prompt for when a shell command is to be generated.""" + return ( + "You are SheLLM, a shell command generator. Your task is to generate " + "accurate shell commands for a highly skilled Linux user. The user " + "expects precise, context-aware suggestions" + "The user's history of commands and their outputs from their current " + "linux terminal session are given to you " + f"below and should be analyzed to understand their patterns and " + f"goals:\n{context.session_history}\n\n" + "The user's most recent command and its output are given to you " + "below - prioritize them as the primary basis " + "for inference, while still considering the broader context of the " + "given Shell Session history for" + "additional insights." + f"Most prior command from user:\n{context.last_command}\n" + f"Response to the most prior command:\n{context.last_command}\n\n" + "Your output must consist solely of shell commands, with no " + "explanations, additional information, comments, " + "or symbols not part of the command syntax." + ) + + +def generate_qa_system_prompt(context: Context) -> str: + """System prompt for when a semantic question is asked.""" + return ( + "You are SheLLM, a shell command specialist. Your task is to not " + "discuss other topics, provide short, accurate," + " extremely concise, and context-aware shell commands and shell " + "scripting related topics knowledge to a highly " + f"skilled Linux user. Use for context the user's current terminal " + f"session history:\n{context.session_history}\n\n" + ) diff --git a/core/shellm.py b/core/shellm.py index 8816fd2..05423aa 100644 --- a/core/shellm.py +++ b/core/shellm.py @@ -2,17 +2,19 @@ import logging from datetime import datetime from colorama import Fore, Style + from .commands import change_directory, run_command_with_pty from .ssh import run_interactive_ssh -from .prompt import get_prompt +from utils.schemas import Context from models.openai_model import OpenAIModel from models.groq_model import GroqModel logger = logging.getLogger(__name__) + class SheLLM: def __init__(self, llm_api): - self.context = "" + self.context = Context() self.history = [] self.current_process_pid = None if llm_api == 'groq': @@ -22,17 +24,20 @@ def __init__(self, llm_api): self.ssh_session = None logger.info(f"SheLLM initialized with {llm_api} model.") - def update_context(self, output): - """Updates the context with new terminal output.""" - self.context += output + "\n" - logger.debug(f"Updated context: {self.context}") + def update_context(self, command, output) -> None: + """Updates the context object with the last command and its output.""" + self.context.last_command = command + self.context.last_output = output + self.context.update_session_history(command, output) + + logger.debug(f"Updated the context after the last command of >{self.context.last_command})") # noqa def execute_system_command(self, command): """Executes system commands and captures output.""" if not command.strip(): logger.info("No command entered. Please enter a valid command.") return - + tokens = command.split() if tokens[0] == 'cd': change_directory(tokens) @@ -42,14 +47,16 @@ def execute_system_command(self, command): run_interactive_ssh(tokens, self) else: output = run_command_with_pty(command) - self.update_context(output) + self.update_context(command, output) def handle_lm_command(self, command, remote=False): """Handles commands generated by the language model.""" while True: - suggestion = self.model.get_command_suggestion(self.context, command) + suggestion = self.model.get_command_suggestion( + context = self.context, + prompt = command + ) if suggestion: - current_time = datetime.now().strftime('%H:%M:%S') logger.info(f"Execute command: {Fore.RED}{suggestion}{Style.RESET_ALL}") response = input(f"{Fore.RED}[SheLLM]{Style.RESET_ALL} Confirm execution (Y/n/r)").lower() if response == 'y': diff --git a/core/ssh.py b/core/ssh.py index 5e2f0cb..215a4e3 100644 --- a/core/ssh.py +++ b/core/ssh.py @@ -3,10 +3,12 @@ import sys import select import logging -from .prompt import get_prompt + +from .prompts import get_prompt logger = logging.getLogger(__name__) + def run_interactive_ssh(tokens, shellm): """Runs an interactive SSH session.""" try: diff --git a/main.py b/main.py index 802d76f..202904e 100644 --- a/main.py +++ b/main.py @@ -6,7 +6,7 @@ from colorama import init, Fore, Style from config.logger_setup import setup_logging from core.shellm import SheLLM -from core.prompt import get_prompt +from core.prompts import get_prompt # Configure logging setup_logging() diff --git a/models/groq_model.py b/models/groq_model.py index 5d05587..6938ac6 100644 --- a/models/groq_model.py +++ b/models/groq_model.py @@ -3,7 +3,9 @@ from groq import Groq from dotenv import load_dotenv from config.logger_setup import setup_logging +from core import prompts from utils.sanitizer import remove_code_block +from utils.schemas import Context # Configure logging setup_logging() @@ -66,18 +68,18 @@ def validate_command(self, command): logger.error(f"Error fetching suggestion from Groq: {e}") return None - def get_command_suggestion(self, context, prompt): - """Generates shell commands based on the provided context and prompt.""" - logger.debug(f"Generating command suggestion for context: {context} and prompt: {prompt}") + def get_command_suggestion( + self, + context: Context, + prompt: str + ): + """Generates shell commands based on context and a prompt.""" + logger.debug(f"Generating command suggestion from {self.__class__.__name__} and for prompt: {prompt}") # noqa try: messages = [ { "role": "system", - "content": "You serve as a dedicated assistant that exclusively generates shell commands. Given the context provided, proactively discern the user's requirements and provide the most suitable command. Exclude any comments or flags in your output." - }, - { - "role": "user", - "content": context + "content": prompts.generate_shell_system_prompt(context) }, { "role": "user", @@ -100,18 +102,18 @@ def get_command_suggestion(self, context, prompt): logger.error(f"Error fetching suggestion from Groq: {e}") return None - def answer_question(self, context, question): - """Generates answers to questions based on the provided context and question.""" - logger.debug(f"Answering question for context: {context} and question: {question}") + def answer_question( + self, + context: Context, + question: str + ) -> str | None: + """Generates answers to semantic questions.""" + logger.debug(f"Answering question for context: {context.session_history} and question: {question}") # noqa try: messages = [ { "role": "system", - "content": "You are a knowledgeable assistant who provides detailed answers to questions." - }, - { - "role": "user", - "content": context + "content": prompts.generate_qa_system_prompt(context) }, { "role": "user", diff --git a/models/openai_model.py b/models/openai_model.py index 59895ed..0449f2a 100644 --- a/models/openai_model.py +++ b/models/openai_model.py @@ -2,13 +2,16 @@ import os import logging from dotenv import load_dotenv +from core import prompts from config.logger_setup import setup_logging +from utils.schemas import Context from utils.sanitizer import remove_code_block # Configure logging setup_logging() logger = logging.getLogger(__name__) + class OpenAIModel: def __init__(self): logger.debug("Initializing OpenAIModel...") @@ -67,24 +70,25 @@ def validate_command(self, command): logger.error(f"Error fetching suggestion from OpenAI: {e}") return None - def get_command_suggestion(self, context, prompt): - """Generates shell commands based on the provided context and prompt.""" - logger.debug(f"Generating command suggestion for context: {context} and prompt: {prompt}") + def get_command_suggestion( + self, + context: Context, + prompt: str + ) -> str | None: + """Generates shell commands based on context and a prompt.""" + logger.debug(f"Generating command suggestion from {self.__class__.__name__} and prompt: {prompt}") # noqa try: messages = [ { "role": "system", - "content": "You are a helpful assistant that must only output shell commands and nothing else. Anticipate the user's needs and provide the best possible solution. Do not include any comments or flags in the output." - }, - { - "role": "user", - "content": context + "content": prompts.generate_shell_system_prompt(context) }, { "role": "user", "content": prompt } ] + logger.debug(f"Messages sent to API: {messages}") response = self.client.chat.completions.create( model="gpt-4o", @@ -104,18 +108,18 @@ def get_command_suggestion(self, context, prompt): logger.error(f"Error fetching suggestion from OpenAI: {e}") return None - def answer_question(self, context, question): - """Generates answers to questions based on the provided context and question.""" - logger.debug(f"Answering question for context: {context} and question: {question}") + def answer_question( + self, + context: Context, + question: str + ) -> str | None: + """Generates answers to semantic questions.""" + logger.debug(f"Answering question for context: {context.session_history} and question: {question}") # noqa try: messages = [ { "role": "system", - "content": "You are a knowledgeable assistant who provides detailed answers to questions." - }, - { - "role": "user", - "content": context + "content": prompts.generate_qa_system_prompt(context) }, { "role": "user", diff --git a/requirements.txt b/requirements.txt index 37f4bcb..dad3771 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ openai==1.33.0 click==8.1.7 colorama==0.4.6 groq==0.8.0 +pydantic==2.9.2 diff --git a/utils/schemas.py b/utils/schemas.py new file mode 100644 index 0000000..c2a94e3 --- /dev/null +++ b/utils/schemas.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel + + +class Context(BaseModel): + last_command: str = "echo 'Hello, World!'" + last_output: str = "Hello, World!\n" + session_history: str = "" + + def update_session_history(self, command: str, output: str) -> None: + """Updates session_history with the last command and output.""" + self.session_history += f"> {command}\n{output}\n"