Canonical example produces unexpected output #12

@jaraco

I've managed to get the example to run by tweaking the requirements, but the result still looks wrong. This is the example I ran:

import duckdb
from llama_cpp import Llama
from wurlitzer import pipes

from examples.utils import generate_sql

# Set up client with model path and context size
with pipes() as (out, err):
    client = Llama(
        model_path="examples/DuckDB-NSQL-7B-v0.1-q8_0.gguf",
        n_ctx=2048,
    )

# Connect to DuckDB database
con = duckdb.connect("examples/nyc.duckdb")

# Sample question for SQL generation
question = "get all columns from taxi table starting with 'a'"

# Generate SQL, check validity, and print
query = generate_sql(question, con, client)
print(query)

print(con.execute(query).fetchdf())

The output I get for the query doesn't seem to correspond to the intention:

 DuckDB-NSQL debt/just-install @ py -m example
 SELECT COLUMNS('a.*') FROM taxi;
 SELECT COLUMNS('a.*') FROM taxi;
    tpep_pickup_datetime tpep_dropoff_datetime  passenger_count  ...  total_amount  congestion_surcharge airport_fee
0    2022-11-04 00:51:52   2022-11-04 01:02:08              1.0  ...         15.80                   2.5        0.00
1    2022-11-04 00:25:29   2022-11-04 00:39:51              5.0  ...         19.56                   2.5        0.00
2    2022-11-04 00:43:21   2022-11-04 00:54:51              5.0  ...         18.36                   2.5        0.00
3    2022-11-04 00:05:49   2022-11-04 00:21:23              1.0  ...         18.96                   2.5        0.00
4    2022-11-04 00:35:49   2022-11-04 00:35:53              1.0  ...         -5.05                   0.0       -1.25
..                   ...                   ...              ...  ...           ...                   ...         ...
995  2022-11-04 00:40:37   2022-11-04 00:48:56              1.0  ...         13.30                   2.5        0.00
996  2022-11-04 00:57:24   2022-11-04 01:27:29              1.0  ...         32.80                   2.5        0.00
997  2022-11-04 01:29:40   2022-11-04 01:56:05              1.0  ...         25.70                   2.5        0.00
998  2022-11-04 01:44:59   2022-11-04 01:53:23              1.0  ...         12.80                   2.5        0.00
999  2022-11-04 01:06:47   2022-11-04 01:24:46              2.0  ...         25.55                   2.5        0.00

[1000 rows x 18 columns]

It's not outputting only the columns that start with 'a'. It's not even outputting rows for columns that start with 'a'. In fact, it seems to be emitting all the rows from the taxi table, across 18 columns.
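
One possible explanation, offered as a sketch rather than a confirmed diagnosis: if DuckDB's COLUMNS('regex') matches the pattern anywhere in a column name instead of only at the start (an assumption here, not something the output proves), then 'a.*' selects every column whose name merely contains an 'a'. The snippet below imitates that behavior with Python's re.search on the column names visible in the output above. The helper matching_columns is hypothetical and is not part of DuckDB or of examples.utils.

```python
import re

# Column names copied from the printed output above (only the visible ones).
columns = [
    "tpep_pickup_datetime",
    "tpep_dropoff_datetime",
    "passenger_count",
    "total_amount",
    "congestion_surcharge",
    "airport_fee",
]


def matching_columns(pattern, names):
    """Hypothetical helper: keep names where the regex matches anywhere.

    ASSUMPTION: this unanchored search stands in for how COLUMNS('regex')
    picks columns; it does not call DuckDB at all.
    """
    return [name for name in names if re.search(pattern, name)]


# Unanchored pattern, as generated by the model: matches any name containing "a".
unanchored = matching_columns("a.*", columns)

# Anchored pattern: matches only names that begin with "a".
anchored = matching_columns("^a", columns)

print(unanchored)
print(anchored)
```

If that assumption holds, an anchored pattern such as COLUMNS('^a') would be closer to the intent of the question. This has not been tested against the example database.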

If the LLM can't produce basic queries, I have little hope for it doing anything more useful.
