diff --git a/pyproject.toml b/pyproject.toml index 4267f65319..5b66f8ba32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,7 +107,6 @@ duckdb = [] fabric = ["pyodbc>=5.0.0"] gcppostgres = ["cloud-sql-python-connector[pg8000]>=1.8.0"] github = ["PyGithub>=2.6.0"] -llm = ["langchain", "openai"] motherduck = ["duckdb>=1.2.0"] mssql = ["pymssql"] mssql-odbc = ["pyodbc>=5.0.0"] @@ -213,7 +212,6 @@ module = [ "pymssql.*", "pyodbc.*", "psycopg2.*", - "langchain.*", "pytest_lazyfixture.*", "dbt.adapters.*", "slack_sdk.*", diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py index 2d8673405f..2f18c0a4b7 100644 --- a/sqlmesh/cli/main.py +++ b/sqlmesh/cli/main.py @@ -1079,50 +1079,6 @@ def rewrite(obj: Context, sql: str, read: str = "", write: str = "") -> None: ) -@cli.command("prompt") -@click.argument("prompt") -@click.option( - "-e", - "--evaluate", - is_flag=True, - help="Evaluate the generated SQL query and display the results.", -) -@click.option( - "-t", - "--temperature", - type=float, - help="Sampling temperature. 0.0 - precise and predictable, 0.5 - balanced, 1.0 - creative. Default: 0.7", - default=0.7, -) -@opt.verbose -@click.pass_context -@error_handler -@cli_analytics -def prompt( - ctx: click.Context, - prompt: str, - evaluate: bool, - temperature: float, - verbose: int, -) -> None: - """Uses LLM to generate a SQL query from a prompt.""" - from sqlmesh.integrations.llm import LLMIntegration - - context = ctx.obj - - llm_integration = LLMIntegration( - context.models.values(), - context.engine_adapter.dialect, - temperature=temperature, - verbosity=Verbosity(verbose), - ) - query = llm_integration.query(prompt) - - context.console.log_status_update(query) - if evaluate: - context.console.log_success(context.fetchdf(query)) - - @cli.command("clean") @click.pass_obj @error_handler diff --git a/sqlmesh/integrations/llm.py b/sqlmesh/integrations/llm.py deleted file mode 100644 index a44ec79997..0000000000 --- a/sqlmesh/integrations/llm.py +++ /dev/null @@ -1,56 +0,0 @@ -from __future__ import annotations - -import typing as t - -from langchain import LLMChain, PromptTemplate -from langchain.chat_models import ChatOpenAI - -from sqlmesh.utils import Verbosity -from sqlmesh.core.model import Model - -_QUERY_PROMPT_TEMPLATE = """Given an input request, create a syntactically correct {dialect} SQL query. -Use full table names. -Convert string operands to lowercase in the WHERE clause. -Reply with a SQL query and nothing else. - -Use the following tables and columns: - -{table_info} - -Request: {input}""" - - -class LLMIntegration: - def __init__( - self, - models: t.Iterable[Model], - dialect: str, - temperature: float = 0.7, - verbosity: Verbosity = Verbosity.DEFAULT, - ): - query_prompt_template = PromptTemplate.from_template(_QUERY_PROMPT_TEMPLATE).partial( - dialect=dialect, table_info=_to_table_info(models) - ) - llm = ChatOpenAI(temperature=temperature) # type: ignore - self._query_chain = LLMChain( - llm=llm, prompt=query_prompt_template, verbose=verbosity >= Verbosity.VERBOSE - ) - - def query(self, prompt: str) -> str: - result = self._query_chain.predict(input=prompt).strip() - select_pos = result.find("SELECT") - if select_pos >= 0: - return result[select_pos:] - return result - - -def _to_table_info(models: t.Iterable[Model]) -> str: - infos = [] - for model in models: - if not model.kind.is_materialized: - continue - - columns_csv = ", ".join(model.columns_to_types_or_raise) - infos.append(f"Table: {model.name}\nColumns: {columns_csv}\n") - - return "\n".join(infos)