diff --git a/.appignore b/.appignore new file mode 100644 index 0000000..76a5059 --- /dev/null +++ b/.appignore @@ -0,0 +1,17 @@ +.DS_Store +venv/ +.sidekickvenv/ +var/ +.git/ +.idea/ +*/__pycache__/ +scripts/ +setup_cythonize/ +.sh +build/ +dist/ +tests/ +ci/ +examples/sleep_eda/ +examples/telemetry/ +.log diff --git a/.github/ISSUE_TEMPLATE/documentation_issue.md b/.github/ISSUE_TEMPLATE/documentation_issue.md new file mode 100644 index 0000000..8775328 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/documentation_issue.md @@ -0,0 +1,15 @@ +--- +name: "\U0001F4C3 Documentation" +about: Create a documentation issue or request to help us improve the sql-sidekick repo +title: "[DOCS]" +labels: area/documentation +assignees: '5675sp' +--- + +### πŸ“ƒ Documentation issue/request + + + +### Documentation version + + diff --git a/.github/workflows/deploy-to-github-pages.yml b/.github/workflows/deploy-to-github-pages.yml new file mode 100644 index 0000000..8261135 --- /dev/null +++ b/.github/workflows/deploy-to-github-pages.yml @@ -0,0 +1,32 @@ +name: Deploy to GitHub pages + +on: + workflow_dispatch: + +jobs: + deploy: + name: Deploy to GitHub Pages + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-node@v3 + with: + always-auth: true + registry-url: https://npm.pkg.github.com/ + node-version: 18 + cache: npm + cache-dependency-path: documentation/package-lock.json + + - name: Install dependencies + run: cd documentation && npm install --frozen-lockfile + env: + NODE_AUTH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Build docs + run: cd documentation && npm run build + - name: Deploy to GitHub Pages + uses: peaceiris/actions-gh-pages@v3 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_dir: ./documentation/tmp/build + user_name: 5675sp ##swap username out with the username of someone with admin access to the repo + user_email: sergio.perez@h2o.ai ##swap email out with the email of someone with admin access to the repo diff --git a/.gitignore b/.gitignore index 68bc17f..2dc53ca 100644 --- a/.gitignore +++ b/.gitignore @@ -157,4 +157,4 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
-#.idea/ +.idea/ diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000..3a390d4 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,5 @@ +{ + "recommendations": [ + "ms-python.python" + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..c2f20fe --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,15 @@ +{ + "[python]": { + "editor.tabSize": 4, + "editor.defaultFormatter": "ms-python.black-formatter" + }, + "files.eol": "\n", + "files.insertFinalNewline": true, + "files.trimFinalNewlines": true, + "files.trimTrailingWhitespace": true, + "python.formatting.provider": "none", + "python.linting.enabled": true, + "python.linting.flake8Enabled": true, + "python.formatting.blackArgs": ["--line-length", "120"], + "python.linting.flake8Args": ["--max-line-length=120"], +} diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..b14d32d --- /dev/null +++ b/Makefile @@ -0,0 +1,41 @@ +sentence_transformer = s3cmd get --recursive --skip-existing s3://h2o-model-gym/models/nlp/sentence_trasnsformer/all-MiniLM-L6-v2/ ./models/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2 +demo_data = s3cmd get --recursive --skip-existing s3://h2o-sql-sidekick/demo/sleepEDA/ ./examples/demo/ + +.PHONY: download_demo_data + +all: download_demo_data + +setup: download_demo_data ## Setup + python3 -m venv .sidekickvenv + ./.sidekickvenv/bin/python3 -m pip install --upgrade pip + ./.sidekickvenv/bin/python3 -m pip install wheel + ./.sidekickvenv/bin/python3 -m pip install -r requirements.txt + mkdir -p ./db/sqlite + mkdir -p ./examples/demo/ + +download_models: + mkdir -p ./models/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2 + +download_demo_data: + mkdir -p ./examples/demo/ + $(demo_data) + +cloud_bundle: + h2o bundle -L debug 2>&1 | tee -a h2o-bundle.log + + +setup-doc: # Install documentation dependencies + cd documentation && npm install + +run-doc: # Run the doc locally + cd documentation && npm start + +update-documentation-infrastructure: + cd documentation && npm update @h2oai/makersaurus + cd documentation && npm ls + +build-doc-locally: # Bundles your website into static files for production + cd documentation && npm run build + +serve-doc-locally: # Serves the built website locally + cd documentation && npm run serve diff --git a/README.md b/README.md index 2ad6d58..af074ff 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,49 @@ # sql-sidekick -A simple sql assistant +A simple SQL assistant (WIP) + + +# Installation +## Dev +``` +1. git clone git@github.com:h2oai/sql-sidekick.git +2. cd sql-sidekick +3. make setup +4. source ./.sidekickvenv/bin/activate +5. python sidekick/prompter.py +``` +## Usage +``` +Dialect: postgres +- docker pull postgres (will pull the latest version) +- docker run --rm --name pgsql-dev -e POSTGRES_PASSWORD=abc -p 5432:5432 postgres + +Default: sqlite +Step: +- Download and install .whl --> s3://sql-sidekick/releases/sql_sidekick-0.0.3-py3-none-any.whl +- python3 -m venv .sidekickvenv +- source .sidekickvenv/bin/activate +- python3 -m pip install sql_sidekick-0.0.3-py3-none-any.whl +``` +## Start +``` +Welcome to the SQL Sidekick! I am an AI assistant that helps you with SQL +queries. I can help you with the following: + + 1. 
Configure a local database (for schema validation and syntax checking): + `sql-sidekick configure db-setup -t "/table_info.jsonl"` (e.g., format --> https://github.com/h2oai/sql-sidekick/blob/main/examples/telemetry/table_info.jsonl) + + 2. Ask a question: `sql-sidekick query -q "avg Gpus" -s "/samples.csv"` (e.g., format --> https://github.com/h2oai/sql-sidekick/blob/main/examples/telemetry/samples.csv) + + 3. Learn contextual query/answer pairs: `sql-sidekick learn add-samples` (optional) + + 4. Add context as key/value pairs: `sql-sidekick learn update-context` (optional) + +Options: + --version Show the version and exit. + --help Show this message and exit. + +Commands: + configure Helps in configuring local database. + learn Helps in learning and building memory. + query Asks question and returns SQL +```
diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/about.md b/about.md new file mode 100644 index 0000000..1b81978 --- /dev/null +++ b/about.md @@ -0,0 +1,12 @@ +**App Goal:** Demo-ware web client for SQL-Sidekick + +**Target Audience:** Data (Machine Learning) Scientists, Citizen Data Scientists, Data Engineers, Managers, and Business Analysts + +**Actively Being Maintained:** Yes (Demo release: _In active RnD_) + +**Last Updated:** September 2023 + +**Allows uploading and using new model and data:** Yes + +**Detailed Description:** +An experimental demo to evaluate text-to-SQL capabilities of large language models (LLMs) to enable QnA for tabular data
diff --git a/app.toml b/app.toml new file mode 100644 index 0000000..ec327a0 --- /dev/null +++ b/app.toml @@ -0,0 +1,23 @@ +[App] +name = "ai.h2o.wave.sql-sidekick" +title = "SQL-Sidekick" +description = "QnA with tabular data using NLQ" +LongDescription = "about.md" +Tags = ["DATA_SCIENCE", "MACHINE_LEARNING", "NLP"] +Version = "0.0.13" + +[Runtime] +MemoryLimit = "64Gi" +MemoryReservation = "16Gi" +module = "start" +VolumeMount = "/meta_data" +VolumeSize = "100Gi" +ResourceVolumeSize = "64Gi" +GPUCount = 1 +RuntimeVersion = "ub2004_cuda114_cudnn8_py38_wlatest_a10g" +RoutingMode = "BASE_URL" +EnableOIDC = true + +[[Env]] +Name = "H2O_WAVE_MAX_REQUEST_SIZE" +Value = "20M"
diff --git a/documentation/.gitignore b/documentation/.gitignore new file mode 100644 index 0000000..f236fdf --- /dev/null +++ b/documentation/.gitignore @@ -0,0 +1,17 @@ +node_modules +tmp + +# Generated files +.docusaurus +.cache-loader + +# Misc +.DS_Store +.env.local +.env.development.local +.env.test.local +.env.production.local + +npm-debug.log* +yarn-debug.log* +yarn-error.log*
diff --git a/documentation/README.md b/documentation/README.md new file mode 100644 index 0000000..fb4f68c --- /dev/null +++ b/documentation/README.md @@ -0,0 +1,20 @@ +# New Documentation Site + +What is the purpose of these docs? + +## Running this site + +This site was generated using Makersaurus, which is a very thin wrapper around Facebook's Docusaurus. You can write documentation in the typical way, using markdown files located in the `docs` folder and registering those files in `sidebars.js`. + +Use the following commands to generate the site and run it locally: + +``` +npx @h2oai/makersaurus@latest gen +cd gen +npm install +npm start +``` + +## More information + +Use the Makersaurus docs to learn how to edit docs, deploy the site, set up versioning and more.
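The subcommands documented in the main README compose into a minimal end-to-end smoke test. The sketch below is a hypothetical session that reuses only the commands shown above and the example files added in this PR; swap the paths for your own files:

```
sql-sidekick configure db-setup -t "examples/telemetry/table_info.jsonl"
sql-sidekick learn update-context
sql-sidekick query -q "Total number of CPUs used?" -s "examples/telemetry/samples.csv"
```

The question string is one of the query/answer pairs in `examples/telemetry/samples.csv`, so the learned sample should steer the generated SQL toward the known-good answer.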
diff --git a/documentation/docs/admin-guide/page-1.md b/documentation/docs/admin-guide/page-1.md new file mode 100644 index 0000000..9608001 --- /dev/null +++ b/documentation/docs/admin-guide/page-1.md @@ -0,0 +1 @@ +# Page 1 diff --git a/documentation/docs/api-reference-guide/page-1.md b/documentation/docs/api-reference-guide/page-1.md new file mode 100644 index 0000000..9608001 --- /dev/null +++ b/documentation/docs/api-reference-guide/page-1.md @@ -0,0 +1 @@ +# Page 1 diff --git a/documentation/docs/application-name-logo.png b/documentation/docs/application-name-logo.png new file mode 100644 index 0000000..21283f4 Binary files /dev/null and b/documentation/docs/application-name-logo.png differ diff --git a/documentation/docs/concepts.md b/documentation/docs/concepts.md new file mode 100644 index 0000000..b353170 --- /dev/null +++ b/documentation/docs/concepts.md @@ -0,0 +1 @@ +# Concepts diff --git a/documentation/docs/faqs.md b/documentation/docs/faqs.md new file mode 100644 index 0000000..4e302e5 --- /dev/null +++ b/documentation/docs/faqs.md @@ -0,0 +1,13 @@ +# FAQs + +[application_description] + + +--- + +The below sections provide answers to frequently asked questions. If you have additional questions, please send them to . + + +## General + + diff --git a/documentation/docs/get-started/access-application-name/access-application-name.md b/documentation/docs/get-started/access-application-name/access-application-name.md new file mode 100644 index 0000000..41f83c0 --- /dev/null +++ b/documentation/docs/get-started/access-application-name/access-application-name.md @@ -0,0 +1,117 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; +import Icon from "@material-ui/core/Icon"; + +# Access [application_name] + +You can access [application_name] through an instance that you can create on the H2O AI Cloud (**HAIC**). To access [application_name]: + +- [Step 1: Access HAIC](#step-1-access-haic) +- [Step 2: Search [application_name]](#step-2-search) +- [Step 3: Run [application_name]](#step-3-run-) +- [Step 4: [application_name] instance](#step-4-) + +## Step 1: Access HAIC + +Access your H2O AI Cloud (**HAIC**) account. + +## Step 2: Search [application_name] + +1. In HAIC, click **APP STORE**. +2. In the **HAIC** search bar, search `[application_name]`. + + +![search-bar](search-bar.png) + +Now, select the [application_name] tile. Details about [application_name] appear. + +## Step 3: Run [application_name] + +1. To start a [application_name] instance, click **Run**. + + ![h2o-sample-home-page](h2o-sample-home-page.png) + +## Step 4: [application_name] instance + +Now, the H2O AI Cloud is starting an instance of [application_name] for you. While you have a starting/running instance, the **Run** button will change its name to **Visit**. + +1. To open [application_name] in a new tab, click **Visit**. + +
+ +![go-to-instance](go-to-instance.png) + +
+ +:::info Note + - The latest version of [application_name] is preselected. + - In the [application_name] instance, several items will be installed. Right after, you will be able to use [application_name]. All items are automatically installed when you start an instance. +::: + +## Pause or terminate instance + +You can pause or terminate an instance of [application_name]. + +- **Pause**: Pausing an instance reduces computational resources (and is less expensive). In other words, the cost of having an instance decreases. +- **Terminate**: Terminating an instance deletes the instance permanently. + +:::info Note + Customers pay for H2O AI Cloud via AI Units so that as you consume more resources, you pay more. +::: + + + + You can Pause an instance in the app details page: +
1. In the app details page, click the **My instances** tab.
2. Locate the instance you want to pause.
3. Click **Pause**.

![Pause](pause-instance.png)
You can also Pause an instance in the My Instances page:
1. In the H2O AI Cloud menu, click **My Instances**.
2. Locate the instance you want to pause.
3. Click **Pause**.

![Pause](pause.png)
You can Terminate an instance in the app details page:
1. In the app details page, click the **My instances** tab.
2. Locate the instance you want to terminate.
3. Click <Icon>expand_more</Icon> **Expand**.
4. Click **Terminate**.

![Terminate](terminate-instance.png)
You can also Terminate an instance in the My Instances page:
1. In the H2O AI Cloud menu, click **My Instances**.
2. Locate the instance you want to terminate.
3. Click <Icon>expand_more</Icon> **Expand**.
4. Click **Terminate**.

![Terminate](terminate.png)
diff --git a/documentation/docs/get-started/access-application-name/go-to-instance.png b/documentation/docs/get-started/access-application-name/go-to-instance.png new file mode 100644 index 0000000..8958fa0 Binary files /dev/null and b/documentation/docs/get-started/access-application-name/go-to-instance.png differ diff --git a/documentation/docs/get-started/access-application-name/h2o-sample-home-page.png b/documentation/docs/get-started/access-application-name/h2o-sample-home-page.png new file mode 100644 index 0000000..c52a640 Binary files /dev/null and b/documentation/docs/get-started/access-application-name/h2o-sample-home-page.png differ diff --git a/documentation/docs/get-started/access-application-name/pause-instance.png b/documentation/docs/get-started/access-application-name/pause-instance.png new file mode 100644 index 0000000..3b7c563 Binary files /dev/null and b/documentation/docs/get-started/access-application-name/pause-instance.png differ diff --git a/documentation/docs/get-started/access-application-name/pause.png b/documentation/docs/get-started/access-application-name/pause.png new file mode 100644 index 0000000..78fee17 Binary files /dev/null and b/documentation/docs/get-started/access-application-name/pause.png differ diff --git a/documentation/docs/get-started/access-application-name/search-bar.png b/documentation/docs/get-started/access-application-name/search-bar.png new file mode 100644 index 0000000..83f27f4 Binary files /dev/null and b/documentation/docs/get-started/access-application-name/search-bar.png differ diff --git a/documentation/docs/get-started/access-application-name/terminate-instance.png b/documentation/docs/get-started/access-application-name/terminate-instance.png new file mode 100644 index 0000000..805d470 Binary files /dev/null and b/documentation/docs/get-started/access-application-name/terminate-instance.png differ diff --git a/documentation/docs/get-started/access-application-name/terminate.png b/documentation/docs/get-started/access-application-name/terminate.png new file mode 100644 index 0000000..efc6e13 Binary files /dev/null and b/documentation/docs/get-started/access-application-name/terminate.png differ diff --git a/documentation/docs/get-started/application-name-flow.md b/documentation/docs/get-started/application-name-flow.md new file mode 100644 index 0000000..ba53f9c --- /dev/null +++ b/documentation/docs/get-started/application-name-flow.md @@ -0,0 +1 @@ +# [application_name] flow diff --git a/documentation/docs/get-started/local_video.mp4 b/documentation/docs/get-started/local_video.mp4 new file mode 100644 index 0000000..99bde1a Binary files /dev/null and b/documentation/docs/get-started/local_video.mp4 differ diff --git a/documentation/docs/get-started/use-cases.md b/documentation/docs/get-started/use-cases.md new file mode 100644 index 0000000..c5cac30 --- /dev/null +++ b/documentation/docs/get-started/use-cases.md @@ -0,0 +1 @@ +# Use cases diff --git a/documentation/docs/get-started/videos.md b/documentation/docs/get-started/videos.md new file mode 100644 index 0000000..ac6a30c --- /dev/null +++ b/documentation/docs/get-started/videos.md @@ -0,0 +1,15 @@ +import ReactPlayer from 'react-player' +import local_video from './local_video.mp4'; + +# Videos + +

+ +

+ +## Local Videos + +

+ +

+ diff --git a/documentation/docs/get-started/what-is-application-name.md b/documentation/docs/get-started/what-is-application-name.md new file mode 100644 index 0000000..3a2c9f4 --- /dev/null +++ b/documentation/docs/get-started/what-is-application-name.md @@ -0,0 +1 @@ +# What is [application_name]? diff --git a/documentation/docs/index.md b/documentation/docs/index.md new file mode 100644 index 0000000..ca11785 --- /dev/null +++ b/documentation/docs/index.md @@ -0,0 +1,53 @@ +--- +slug: / +displayed_sidebar: defaultSidebar +title: Home +hide_table_of_contents: true +hide_title: true +--- + +import H2OHome from '@site/src/components/H2OHome'; + + diff --git a/documentation/docs/key-terms.md b/documentation/docs/key-terms.md new file mode 100644 index 0000000..1ed42d8 --- /dev/null +++ b/documentation/docs/key-terms.md @@ -0,0 +1 @@ +# Key terms diff --git a/documentation/docs/python-client-guide/page-1.md b/documentation/docs/python-client-guide/page-1.md new file mode 100644 index 0000000..9d8544c --- /dev/null +++ b/documentation/docs/python-client-guide/page-1.md @@ -0,0 +1,34 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# Page 1 + + + + ```mdx-code-block + + + ``` + + ```python title="StandardFlowExample.py" + ### Constants + MLOPS_API_URL = "https://api.mlops.my.domain" + TOKEN_ENDPOINT_URL="https://mlops.keycloak.domain/auth/realms/[fill-in-realm-name]" + REFRESH_TOKEN="" + ``` + + ```mdx-code-block + + + ``` + + ```python title="StandardFlowExample.py" + Deployment drift metrics: {'count_feature_drift': 13, 'feature_frequency': + [{'categorical': {'description': '','name': '','point': [], + ``` + + ```mdx-code-block + + + ``` + diff --git a/documentation/docs/release-notes.md b/documentation/docs/release-notes.md new file mode 100644 index 0000000..903c1cc --- /dev/null +++ b/documentation/docs/release-notes.md @@ -0,0 +1,4 @@ +# Release notes + +## v1.2.0 | Sep 2, 2022 + diff --git a/documentation/docs/third-party-licenses.md b/documentation/docs/third-party-licenses.md new file mode 100644 index 0000000..5ae6816 --- /dev/null +++ b/documentation/docs/third-party-licenses.md @@ -0,0 +1,18 @@ +# Third-party licenses + + +## Certif + + + + + + +## Cffi + + + + + + + diff --git a/documentation/docs/tutorials/datasets/tutorial-1a.md b/documentation/docs/tutorials/datasets/tutorial-1a.md new file mode 100644 index 0000000..d69b7b2 --- /dev/null +++ b/documentation/docs/tutorials/datasets/tutorial-1a.md @@ -0,0 +1,16 @@ +--- +title: Tutorial 1A +--- + +# Tutorial 1A: [name_of_tutorial] + + +## Prerequisites + +## Step 1: + +## Step 2: + +## Summary + +## Next diff --git a/documentation/docs/tutorials/datasets/tutorial-2a.md b/documentation/docs/tutorials/datasets/tutorial-2a.md new file mode 100644 index 0000000..1c7a164 --- /dev/null +++ b/documentation/docs/tutorials/datasets/tutorial-2a.md @@ -0,0 +1,15 @@ +--- +title: Tutorial 2A +--- + +# Tutorial 2A: [name_of_tutorial] + +## Prerequisites + +## Step 1: + +## Step 2: + +## Summary + +## Next diff --git a/documentation/docs/tutorials/datasets/tutorial-3a.md b/documentation/docs/tutorials/datasets/tutorial-3a.md new file mode 100644 index 0000000..5960c78 --- /dev/null +++ b/documentation/docs/tutorials/datasets/tutorial-3a.md @@ -0,0 +1,15 @@ +--- +title: Tutorial 3A +--- + +# Tutorial 3A: [name_of_tutorial] + +## Prerequisites + +## Step 1: + +## Step 2: + +## Summary + +## Next diff --git a/documentation/docs/tutorials/experiments/tutorial-1b.md 
b/documentation/docs/tutorials/experiments/tutorial-1b.md new file mode 100644 index 0000000..767ddff --- /dev/null +++ b/documentation/docs/tutorials/experiments/tutorial-1b.md @@ -0,0 +1,15 @@ +--- +title: Tutorial 1B +--- + +# Tutorial 1B: [name_of_tutorial] + +## Prerequisites + +## Step 1: + +## Step 2: + +## Summary + +## Next diff --git a/documentation/docs/tutorials/experiments/tutorial-2b.md b/documentation/docs/tutorials/experiments/tutorial-2b.md new file mode 100644 index 0000000..b8d9c6f --- /dev/null +++ b/documentation/docs/tutorials/experiments/tutorial-2b.md @@ -0,0 +1,17 @@ +--- +title: Tutorial 2B +--- + +# Tutorial 2B: [name_of_tutorial] + +## Prerequisites + +## Step 1: + +## Step 2: + +## Summary + +## Next + + diff --git a/documentation/docs/tutorials/experiments/tutorial-3b.md b/documentation/docs/tutorials/experiments/tutorial-3b.md new file mode 100644 index 0000000..673b647 --- /dev/null +++ b/documentation/docs/tutorials/experiments/tutorial-3b.md @@ -0,0 +1,15 @@ +--- +title: Tutorial 3B +--- + +# Tutorial 3B: [name_of_tutorial] + +## Prerequisites + +## Step 1: + +## Step 2: + +## Summary + +## Next diff --git a/documentation/docs/tutorials/predictions/tutorial-1c.md b/documentation/docs/tutorials/predictions/tutorial-1c.md new file mode 100644 index 0000000..e82a4bf --- /dev/null +++ b/documentation/docs/tutorials/predictions/tutorial-1c.md @@ -0,0 +1,15 @@ +--- +title: Tutorial 1C +--- + +# Tutorial 1C: [name_of_tutorial] + +## Prerequisites + +## Step 1: + +## Step 2: + +## Summary + +## Next diff --git a/documentation/docs/tutorials/predictions/tutorial-2c.md b/documentation/docs/tutorials/predictions/tutorial-2c.md new file mode 100644 index 0000000..472e944 --- /dev/null +++ b/documentation/docs/tutorials/predictions/tutorial-2c.md @@ -0,0 +1,17 @@ +--- +title: Tutorial 2C +--- + +# Tutorial 2C: [name_of_tutorial] + +## Prerequisites + +## Step 1: + +## Step 2: + +## Summary + +## Next + + diff --git a/documentation/docs/tutorials/predictions/tutorial-3c.md b/documentation/docs/tutorials/predictions/tutorial-3c.md new file mode 100644 index 0000000..2133200 --- /dev/null +++ b/documentation/docs/tutorials/predictions/tutorial-3c.md @@ -0,0 +1,16 @@ +--- +title: Tutorial 3C +--- + +# Tutorial 3C: [name_of_tutorial] + +## Prerequisites + +## Step 1: + +## Step 2: + +## Summary + +## Next + diff --git a/documentation/docs/tutorials/tutorials-overview.md b/documentation/docs/tutorials/tutorials-overview.md new file mode 100644 index 0000000..299ecb8 --- /dev/null +++ b/documentation/docs/tutorials/tutorials-overview.md @@ -0,0 +1 @@ +# Tutorials diff --git a/documentation/docs/user-guide/page-1.md b/documentation/docs/user-guide/page-1.md new file mode 100644 index 0000000..9608001 --- /dev/null +++ b/documentation/docs/user-guide/page-1.md @@ -0,0 +1 @@ +# Page 1 diff --git a/documentation/makersaurus.config.js b/documentation/makersaurus.config.js new file mode 100644 index 0000000..0b13dac --- /dev/null +++ b/documentation/makersaurus.config.js @@ -0,0 +1,17 @@ +module.exports = { + title: "New Makersaurus Project", + tagline: "This code for this site was generated by Makersaurus", + url: "https://example.h2o.ai", + baseUrl: "/", + projectName: "new-makersaurus-project", // Usually your repo name + feedbackAssignee: "5675sp", // Should be a github username + dependencies: { + "@emotion/react": "^11.10.5", + "@emotion/styled": "^11.10.5", + "@material-ui/core": "^4.12.4", + "@material/card": "^14.0.0", + "@mui/icons-material": "^5.10.16", + 
"@mui/material": "^5.10.16", + "react-player": "^2.11.0", + }, +}; diff --git a/documentation/package-lock.json b/documentation/package-lock.json new file mode 100644 index 0000000..e21a366 --- /dev/null +++ b/documentation/package-lock.json @@ -0,0 +1,461 @@ +{ + "name": "new-makersaurus-project", + "version": "0.0.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "new-makersaurus-project", + "version": "0.0.0", + "dependencies": { + "@h2oai/makersaurus": "^0.6.0" + } + }, + "node_modules/@babel/runtime": { + "version": "7.23.1", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.23.1.tgz", + "integrity": "sha512-hC2v6p8ZSI/W0HUzh3V8C5g+NwSKzKPtJwSpTjwl0o297GP9+ZLQSkdvHz46CM3LqyoXxq+5G9komY+eSqSO0g==", + "dependencies": { + "regenerator-runtime": "^0.14.0" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@h2oai/makersaurus": { + "version": "0.6.0", + "resolved": "https://npm.pkg.github.com/download/@h2oai/makersaurus/0.6.0/255967cb0799d8de5b106f2cef4a971dcbe55862", + "integrity": "sha512-UtDq9pCk6XFnHRBsyZWxn8bITf7wFft3jBmLD3S2K/H8GsV8U+8oxVu/Ric09dva1WTQukUUCgiYX2v/4qXFzg==", + "dependencies": { + "commander": "^9.4.1", + "handlebars": "^4.7.7", + "sync-directory": "^5.1.9", + "yup": "^0.32.11" + }, + "bin": { + "makersaurus": "src/bin.js" + } + }, + "node_modules/@types/lodash": { + "version": "4.14.199", + "resolved": "https://registry.npmjs.org/@types/lodash/-/lodash-4.14.199.tgz", + "integrity": "sha512-Vrjz5N5Ia4SEzWWgIVwnHNEnb1UE1XMkvY5DGXrAeOGE9imk0hgTHh5GyDjLDJi9OTCn9oo9dXH1uToK1VRfrg==" + }, + "node_modules/anymatch": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.3.tgz", + "integrity": "sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==", + "dependencies": { + "normalize-path": "^3.0.0", + "picomatch": "^2.0.4" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/binary-extensions": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.2.0.tgz", + "integrity": "sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA==", + "engines": { + "node": ">=8" + } + }, + "node_modules/braces": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", + "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "dependencies": { + "fill-range": "^7.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/call-me-maybe": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/call-me-maybe/-/call-me-maybe-1.0.2.tgz", + "integrity": "sha512-HpX65o1Hnr9HH25ojC1YGs7HCQLq0GCOibSaWER0eNpgJ/Z1MZv2mTc7+xh6WOPxbRVcmgbv4hGU+uSQ/2xFZQ==" + }, + "node_modules/chokidar": { + "version": "3.5.3", + "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz", + "integrity": "sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw==", + "funding": [ + { + "type": "individual", + "url": "https://paulmillr.com/funding/" + } + ], + "dependencies": { + "anymatch": "~3.1.2", + "braces": "~3.0.2", + "glob-parent": "~5.1.2", + "is-binary-path": "~2.1.0", + "is-glob": "~4.0.1", + "normalize-path": "~3.0.0", + "readdirp": "~3.6.0" + }, + "engines": { + "node": ">= 8.10.0" + }, + "optionalDependencies": { + "fsevents": "~2.3.2" + } + }, + "node_modules/commander": { + "version": "9.5.0", + "resolved": 
"https://registry.npmjs.org/commander/-/commander-9.5.0.tgz", + "integrity": "sha512-KRs7WVDKg86PWiuAqhDrAQnTXZKraVcCc6vFdL14qrZ/DcWwuRo7VoiYXalXO7S5GKpqYiVEwCbgFDfxNHKJBQ==", + "engines": { + "node": "^12.20.0 || >=14" + } + }, + "node_modules/es6-promise": { + "version": "4.2.8", + "resolved": "https://registry.npmjs.org/es6-promise/-/es6-promise-4.2.8.tgz", + "integrity": "sha512-HJDGx5daxeIvxdBxvG2cb9g4tEvwIk3i8+nhX0yGrYmZUzbkdg8QbDevheDB8gd0//uPj4c1EQua8Q+MViT0/w==" + }, + "node_modules/fill-range": { + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", + "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "dependencies": { + "to-regex-range": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/fs-extra": { + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/fs-extra/-/fs-extra-7.0.1.tgz", + "integrity": "sha512-YJDaCJZEnBmcbw13fvdAM9AwNOJwOzrE4pqMqBq5nFiEqXUqHwlK4B+3pUw6JNvfSPtX05xFHtYy/1ni01eGCw==", + "dependencies": { + "graceful-fs": "^4.1.2", + "jsonfile": "^4.0.0", + "universalify": "^0.1.0" + }, + "engines": { + "node": ">=6 <7 || >=8" + } + }, + "node_modules/fsevents": { + "version": "2.3.3", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", + "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", + "hasInstallScript": true, + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, + "node_modules/glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "dependencies": { + "is-glob": "^4.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/glob-to-regexp": { + "version": "0.3.0", + "resolved": "https://registry.npmjs.org/glob-to-regexp/-/glob-to-regexp-0.3.0.tgz", + "integrity": "sha512-Iozmtbqv0noj0uDDqoL0zNq0VBEfK2YFoMAZoxJe4cwphvLR+JskfF30QhXHOR4m3KrE6NLRYw+U9MRXvifyig==" + }, + "node_modules/graceful-fs": { + "version": "4.2.11", + "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz", + "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==" + }, + "node_modules/handlebars": { + "version": "4.7.8", + "resolved": "https://registry.npmjs.org/handlebars/-/handlebars-4.7.8.tgz", + "integrity": "sha512-vafaFqs8MZkRrSX7sFVUdo3ap/eNiLnb4IakshzvP56X5Nr1iGKAIqdX6tMlm6HcNRIkr6AxO5jFEoJzzpT8aQ==", + "dependencies": { + "minimist": "^1.2.5", + "neo-async": "^2.6.2", + "source-map": "^0.6.1", + "wordwrap": "^1.0.0" + }, + "bin": { + "handlebars": "bin/handlebars" + }, + "engines": { + "node": ">=0.4.7" + }, + "optionalDependencies": { + "uglify-js": "^3.1.4" + } + }, + "node_modules/is-absolute": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/is-absolute/-/is-absolute-1.0.0.tgz", + "integrity": "sha512-dOWoqflvcydARa360Gvv18DZ/gRuHKi2NU/wU5X1ZFzdYfH29nkiNZsF3mp4OJ3H4yo9Mx8A/uAGNzpzPN3yBA==", + "dependencies": { + "is-relative": "^1.0.0", + "is-windows": "^1.0.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-binary-path": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz", + "integrity": 
"sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==", + "dependencies": { + "binary-extensions": "^2.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/is-extglob": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", + "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-glob": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", + "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", + "dependencies": { + "is-extglob": "^2.1.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-number": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", + "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", + "engines": { + "node": ">=0.12.0" + } + }, + "node_modules/is-relative": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/is-relative/-/is-relative-1.0.0.tgz", + "integrity": "sha512-Kw/ReK0iqwKeu0MITLFuj0jbPAmEiOsIwyIXvvbfa6QfmN9pkD1M+8pdk7Rl/dTKbH34/XBFMbgD4iMJhLQbGA==", + "dependencies": { + "is-unc-path": "^1.0.0" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-unc-path": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/is-unc-path/-/is-unc-path-1.0.0.tgz", + "integrity": "sha512-mrGpVd0fs7WWLfVsStvgF6iEJnbjDFZh9/emhRDcGWTduTfNHd9CHeUwH3gYIjdbwo4On6hunkztwOaAw0yllQ==", + "dependencies": { + "unc-path-regex": "^0.1.2" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-windows": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/is-windows/-/is-windows-1.0.2.tgz", + "integrity": "sha512-eXK1UInq2bPmjyX6e3VHIzMLobc4J94i4AWn+Hpq3OU5KkrRC96OAcR3PRJ/pGu6m8TRnBHP9dkXQVsT/COVIA==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/jsonfile": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/jsonfile/-/jsonfile-4.0.0.tgz", + "integrity": "sha512-m6F1R3z8jjlf2imQHS2Qez5sjKWQzbuuhuJ/FKYFRZvPE3PuHcSMVZzfsLhGVOkfd20obL5SWEBew5ShlquNxg==", + "optionalDependencies": { + "graceful-fs": "^4.1.6" + } + }, + "node_modules/lodash": { + "version": "4.17.21", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", + "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==" + }, + "node_modules/lodash-es": { + "version": "4.17.21", + "resolved": "https://registry.npmjs.org/lodash-es/-/lodash-es-4.17.21.tgz", + "integrity": "sha512-mKnC+QJ9pWVzv+C4/U3rRsHapFfHvQFoFB92e52xeyGMcX6/OlIl78je1u8vePzYZSkkogMPJ2yjxxsb89cxyw==" + }, + "node_modules/minimist": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.8.tgz", + "integrity": "sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/nanoclone": { + "version": "0.2.1", + "resolved": "https://registry.npmjs.org/nanoclone/-/nanoclone-0.2.1.tgz", + "integrity": "sha512-wynEP02LmIbLpcYw8uBKpcfF6dmg2vcpKqxeH5UcoKEYdExslsdUA4ugFauuaeYdTB76ez6gJW8XAZ6CgkXYxA==" + }, + "node_modules/neo-async": { + "version": "2.6.2", + "resolved": 
"https://registry.npmjs.org/neo-async/-/neo-async-2.6.2.tgz", + "integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==" + }, + "node_modules/normalize-path": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/normalize-path/-/normalize-path-3.0.0.tgz", + "integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/picomatch": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", + "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", + "engines": { + "node": ">=8.6" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, + "node_modules/property-expr": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/property-expr/-/property-expr-2.0.5.tgz", + "integrity": "sha512-IJUkICM5dP5znhCckHSv30Q4b5/JA5enCtkRHYaOVOAocnH/1BQEYTC5NMfT3AVl/iXKdr3aqQbQn9DxyWknwA==" + }, + "node_modules/readdir-enhanced": { + "version": "1.5.2", + "resolved": "https://registry.npmjs.org/readdir-enhanced/-/readdir-enhanced-1.5.2.tgz", + "integrity": "sha512-oncAoS9LLjy/+DeZfSAdZBI/iFJGcPCOp44RPFI6FIMHuxt5CC5P0cUZ9mET+EZB9ONhcEvAids/lVRkj0sTHw==", + "dependencies": { + "call-me-maybe": "^1.0.1", + "es6-promise": "^4.1.0", + "glob-to-regexp": "^0.3.0" + } + }, + "node_modules/readdirp": { + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", + "integrity": "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==", + "dependencies": { + "picomatch": "^2.2.1" + }, + "engines": { + "node": ">=8.10.0" + } + }, + "node_modules/regenerator-runtime": { + "version": "0.14.0", + "resolved": "https://registry.npmjs.org/regenerator-runtime/-/regenerator-runtime-0.14.0.tgz", + "integrity": "sha512-srw17NI0TUWHuGa5CFGGmhfNIeja30WMBfbslPNhf6JrqQlLN5gcrvig1oqPxiVaXb0oW0XRKtH6Nngs5lKCIA==" + }, + "node_modules/source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/sync-directory": { + "version": "5.1.9", + "resolved": "https://registry.npmjs.org/sync-directory/-/sync-directory-5.1.9.tgz", + "integrity": "sha512-0942RssO+NrIjDcaNiXUH/NQoAamURT1zpzN/uB8fgyetDM8NtPPOQNax3+BuNUfw/2JcaEXrAz567DokNq0lw==", + "dependencies": { + "chokidar": "^3.3.1", + "commander": "^6.2.0", + "fs-extra": "^7.0.1", + "is-absolute": "^1.0.0", + "readdir-enhanced": "^1.5.2" + }, + "bin": { + "syncdir": "cmd.js" + } + }, + "node_modules/sync-directory/node_modules/commander": { + "version": "6.2.1", + "resolved": "https://registry.npmjs.org/commander/-/commander-6.2.1.tgz", + "integrity": "sha512-U7VdrJFnJgo4xjrHpTzu0yrHPGImdsmD95ZlgYSEajAn2JKzDhDTPG9kBTefmObL2w/ngeZnilk+OV9CG3d7UA==", + "engines": { + "node": ">= 6" + } + }, + "node_modules/to-regex-range": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", + "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", + "dependencies": { + "is-number": "^7.0.0" + }, + "engines": { + "node": ">=8.0" + } + }, + "node_modules/toposort": { + "version": "2.0.2", + "resolved": 
"https://registry.npmjs.org/toposort/-/toposort-2.0.2.tgz", + "integrity": "sha512-0a5EOkAUp8D4moMi2W8ZF8jcga7BgZd91O/yabJCFY8az+XSzeGyTKs0Aoo897iV1Nj6guFq8orWDS96z91oGg==" + }, + "node_modules/uglify-js": { + "version": "3.17.4", + "resolved": "https://registry.npmjs.org/uglify-js/-/uglify-js-3.17.4.tgz", + "integrity": "sha512-T9q82TJI9e/C1TAxYvfb16xO120tMVFZrGA3f9/P4424DNu6ypK103y0GPFVa17yotwSyZW5iYXgjYHkGrJW/g==", + "optional": true, + "bin": { + "uglifyjs": "bin/uglifyjs" + }, + "engines": { + "node": ">=0.8.0" + } + }, + "node_modules/unc-path-regex": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/unc-path-regex/-/unc-path-regex-0.1.2.tgz", + "integrity": "sha512-eXL4nmJT7oCpkZsHZUOJo8hcX3GbsiDOa0Qu9F646fi8dT3XuSVopVqAcEiVzSKKH7UoDti23wNX3qGFxcW5Qg==", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/universalify": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/universalify/-/universalify-0.1.2.tgz", + "integrity": "sha512-rBJeI5CXAlmy1pV+617WB9J63U6XcazHHF2f2dbJix4XzpUF0RS3Zbj0FGIOCAva5P/d/GBOYaACQ1w+0azUkg==", + "engines": { + "node": ">= 4.0.0" + } + }, + "node_modules/wordwrap": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/wordwrap/-/wordwrap-1.0.0.tgz", + "integrity": "sha512-gvVzJFlPycKc5dZN4yPkP8w7Dc37BtP1yczEneOb4uq34pXZcvrtRTmWV8W+Ume+XCxKgbjM+nevkyFPMybd4Q==" + }, + "node_modules/yup": { + "version": "0.32.11", + "resolved": "https://registry.npmjs.org/yup/-/yup-0.32.11.tgz", + "integrity": "sha512-Z2Fe1bn+eLstG8DRR6FTavGD+MeAwyfmouhHsIUgaADz8jvFKbO/fXc2trJKZg+5EBjh4gGm3iU/t3onKlXHIg==", + "dependencies": { + "@babel/runtime": "^7.15.4", + "@types/lodash": "^4.14.175", + "lodash": "^4.17.21", + "lodash-es": "^4.17.21", + "nanoclone": "^0.2.1", + "property-expr": "^2.0.4", + "toposort": "^2.0.2" + }, + "engines": { + "node": ">=10" + } + } + } +} diff --git a/documentation/package.json b/documentation/package.json new file mode 100644 index 0000000..c5e0620 --- /dev/null +++ b/documentation/package.json @@ -0,0 +1,14 @@ +{ + "name": "new-makersaurus-project", + "version": "0.0.0", + "scripts": { + "start": "makersaurus start", + "build": "makersaurus build", + "serve": "makersaurus serve", + "scaffold": "makersaurus scaffold", + "deploy": "makersaurus deploy" + }, + "dependencies": { + "@h2oai/makersaurus": "^0.6.0" + } + } \ No newline at end of file diff --git a/documentation/sidebars.js b/documentation/sidebars.js new file mode 100644 index 0000000..13d9228 --- /dev/null +++ b/documentation/sidebars.js @@ -0,0 +1,63 @@ +module.exports = { + defaultSidebar: [ + "index", + { + "Get started": [ + "get-started/what-is-application-name", + "get-started/access-application-name/access-application-name", + "get-started/application-name-flow", + "get-started/use-cases", + "get-started/videos", + ], + }, + { + type: "category", + label: "Tutorials", + link: { type: "doc", id: "tutorials/tutorials-overview" }, + items: [ + { + type: "category", + label: "Datasets", + items: [ + "tutorials/datasets/tutorial-1a", + "tutorials/datasets/tutorial-2a", + "tutorials/datasets/tutorial-3a", + ], + }, + { + type: "category", + label: "Experiments", + items: [ + "tutorials/experiments/tutorial-1b", + "tutorials/experiments/tutorial-2b", + "tutorials/experiments/tutorial-3b", + ], + }, + { + type: "category", + label: "Predictions", + items: [ + "tutorials/predictions/tutorial-1c", + "tutorials/predictions/tutorial-2c", + "tutorials/predictions/tutorial-3c", + ], + }, + ], + }, + "concepts", + { + "User guide": 
["user-guide/page-1"], + }, + { + "Admin guide": ["admin-guide/page-1"], + }, + { + "Python client guide": ["python-client-guide/page-1"], + }, + "key-terms", + "release-notes", + "third-party-licenses", + "faqs", + ], +}; + diff --git a/examples/sleep_eda/table_info.jsonl b/examples/sleep_eda/table_info.jsonl new file mode 100644 index 0000000..04fc5b7 --- /dev/null +++ b/examples/sleep_eda/table_info.jsonl @@ -0,0 +1,13 @@ +{"Column Name": "Person_ID", "Column Type": "uuid PRIMARY KEY"} +{"Column Name": "Gender", "Column Type": "TEXT", "Sample Values": ["Female", "Male"]} +{"Column Name": "Age", "Column Type": "NUMERIC"} +{"Column Name": "Occupation", "Column Type": "TEXT", "Sample Values": ["Accountant", "Doctor", "Engineer", "Lawyer","Manager", "Nurse", "Sales Representative", "Salesperson", "Scientist", "Software Engineer", "Teacher"]} +{"Column Name": "Sleep_Duration", "Column Type": "NUMERIC"} +{"Column Name": "Quality_of_Sleep", "Column Type": "NUMERIC"} +{"Column Name": "Physical_Activity_Level", "Column Type": "NUMERIC"} +{"Column Name": "Stress_Level", "Column Type": "NUMERIC"} +{"Column Name": "BMI_Category", "Column Type": "TEXT", "Sample Values": ["Normal", "Normal Weight", "Obese", "Overweight"]} +{"Column Name": "Blood_Pressure", "Column Type": "TEXT", "Sample Values": ["115/75", "115/78", "117/76", "118/75", "118/76", "119/77"]} +{"Column Name": "Heart_Rate", "Column Type": "NUMERIC"} +{"Column Name": "Daily_Steps", "Column Type": "NUMERIC"} +{"Column Name": "Sleep_Disorder", "Column Type": "TEXT", "Sample Values": ["Insomnia", "Sleep Apnea"]} diff --git a/examples/telemetry/samples.csv b/examples/telemetry/samples.csv new file mode 100644 index 0000000..6f2f598 --- /dev/null +++ b/examples/telemetry/samples.csv @@ -0,0 +1,27 @@ +query,answer +Total number of CPUs used?,SELECT sum((payload->'engineEvent'-> 'pausing' -> 'engine'->> 'cpu')::integer) AS total_cpus_used FROM telemetry WHERE payload->'engineEvent'-> 'pausing' -> 'engine'->> 'cpu' IS NOT NULL; +Find the number of AI units for each user using stream for each resource type (overall),"SELECT user_id, user_name, resource_type, date_trunc('day', ts) as start_day, + sum(AI_units) as AI_units FROM ( + SELECT user_id, user_name, resource_type, ts, + extract(epoch from ts - lag(ts) over (partition by user_id, resource_type order by ts)) / 3600 AS AI_units + FROM telemetry + WHERE stream = 'running' + ) sub GROUP BY user_id, user_name, resource_type, start_day +ORDER BY start_day DESC NULLS LAST;" +Compute global usage over time,"SELECT + ts AS time_interval, + GREATEST((GREATEST((ram_gi / 64.0), (cpu / 8.0)) - gpu), 0) + (gpu * 4.0) as ai_units +FROM ( + SELECT + -- This is a gauge stream, meaning multiple sources are exporting duplicate entries during the same hour interval + ts, + -- RAM usage in Gi + COALESCE(((payload->'usageGauge'->'billingResources'->>'paddedMemoryReservationBytes')::bigint/1024.0/1024.0/1024.0), 0) AS ram_gi, + -- CPU usage in vCPU + COALESCE(((payload->'usageGauge'->'billingResources'->'paddedCpuReservationMillicpu')::int/1000.0), 0) AS cpu, + -- GPU usage in number of GPUs + COALESCE(((payload->'usageGauge'->'billingResources'->'gpuCount')::int), 0) AS gpu + FROM telemetry + WHERE stream = 'gauage_resources' +) AS internal +ORDER BY 1, 2 DESC;" \ No newline at end of file diff --git a/examples/telemetry/table_info.jsonl b/examples/telemetry/table_info.jsonl new file mode 100644 index 0000000..9d957ba --- /dev/null +++ b/examples/telemetry/table_info.jsonl @@ -0,0 +1,10 @@ +{"Column Name": "id", 
"Column Type": "uuid PRIMARY KEY"} +{"Column Name": "ts", "Column Type": "TIMESTAMP WITH TIME ZONE NOT NULL"} +{"Column Name": "kind", "Column Type": "TEXT NOT NULL, -- or int?", "Sample Values": ["EVENT"]} +{"Column Name": "user_id", "Column Type": "TEXT"} +{"Column Name": "user_name", "Column Type": "TEXT"} +{"Column Name": "resource_type", "Column Type": "TEXT NOT NULL, -- or int?", "Sample Values": ["FEATURE_STORE", "PROJECT", "MLOPS_EXPERIMENT", "APP", "APP_INSTANCE", "MLOPS_DEPLOYMENT", "MLOPS_DATASET", "MLOPS_USER", "RESOURCE_TYPE_UNSPECIFIED", "SCORING", "DAI_ENGINE", "MLOPS_MODEL"]} +{"Column Name": "resource_id", "Column Type": "TEXT"} +{"Column Name": "stream", "Column Type": "TEXT NOT NULL", "Sample Values": ["air/h2o/cloud/mlops/deployment/created", "ai/h2o/cloud/appstore/instance/gauge/running", "ai/h2o/cloud/mlops/project/unshared", "ai/h2o/cloud/mlops/gauge/project", "ai/h2o/cloud/appstore/user/event/login", "ai/h2o/cloud/mlops/gauge/registered-model-version", "ai/h2o/cloud/appstore/instance/event/started", "ai/h2o/cloud/mlops/deployment/deleted", "ai/h2o/cloud/mlops/gauge/dataset", "ai/h2o/cloud/fs/job/running", "ai/h2o/engine/event/paused", "ai/h2o/cloud/mlops/project/deleted", "ai/h2o/engine/event/deleting", "ai/h2o/engine/event/pausing", "ai/h2o/cloud/mlops/gauge/deployment", "ai/h2o/cloud/usage/global/gauge/resources", "ai/h2o/cloud/mlops/gauge/registered-model", "ai/h2o/cloud/appstore/instance/event/suspended", "ai/h2o/cloud/usage/namespace/gauge/resources", "ai/h2o/cloud/mlops/registered-model-version/created", "ai/h2o/cloud/mlops/project/created", "ai/h2o/cloud/mlops/project/shared", "ai/h2o/cloud/mlops/experiment/created", "ai/h2o/cloud/mlops/dataset/created", "ai/h2o/cloud/appstore/app/event/created", "ai/h2o/cloud/appstore/instance/event/terminated", "ai/h2o/cloud/mlops/gauge/user", "ai/h2o/engine/event/starting", "ai/h2o/cloud/mlops/event/scoring-result/created", "ai/h2o/engine/event/running", "ai/h2o/cloud/fs/job/submitted", "ai/h2o/cloud/mlops/registered-model/created", "ai/h2o/cloud/mlops/gauge/experiment", "ai/h2o/document/ai/proxy", "ai/h2o/cloud/mlops/experiment/unlinked", "ai/h2o/cloud/fs/job/finished", "ai/h2o/cloud/appstore/app/event/deleted", "ai/h2o/cloud/appstore/instance/event/resumed"]} +{"Column Name": "source", "Column Type": "TEXT NOT NULL"} +{"Column Name": "payload", "Column Type": "jsonb NOT NULL", "Sample Values":[{"engineEvent": {"pausing": {"engine": {"cpu": "1", "memory": "1", "gpu": "0"}}}}]} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..048a16a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,83 @@ +[tool.poetry] +name = "sql-sidekick" +version = "0.0.13" +license = "Proprietary" +description = "An AI assistant for SQL" +authors = [ + "Pramit Choudhary ", + "Michal Malohlava " +] +readme = "README.md" +classifiers = [ + "Development Status :: Alpha", + "Environment :: CLI", + "Intended Audience :: Developers, Analysts", + "License :: Other/Proprietary License", + "Programming Language :: Python", + "Programming Language :: Python :: 3.8+", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence" +] +packages = [{include = "sidekick"}] + +[tool.poetry.dependencies] +python = ">=3.8.1,<=3.10" +pandas = "^1.3.3" +numpy = "^1.21.2" +click = "^8.0.1" +sqlalchemy = "^1.4.23" +psycopg2-binary = "^2.9.6" +colorama = "^0.4.6" +llama_index = "^0.5.27" +loguru = "^0.7.0" +toml = "^0.10.2" +sqlglot = "^12.2.0" +transformers = "^4.29.0" +sentence-transformers = "^2.2.2" 
+torch = "^2.0.1" +sqlalchemy-utils = "^0.41.1" +h2o-wave = "0.26.1" +pandasql = "0.7.3" +accelerate = "0.21.0" +bitsandbytes = "0.41.0" +InstructorEmbedding = "^1.0.1" + +[tool.poetry.scripts] +sql-sidekick = "sidekick.prompter:cli" + +[tool.poetry.dev-dependencies] +pylint = { version = "^2.12.2", allow-prereleases = true } +flake8 = { version = "^4.0.1", allow-prereleases = true } +black = { version = "21.12b0", allow-prereleases = true } + +[tool.black] +line-length = 120 +skip-string-normalization = true +target-version = ['py38.16'] +include = '\.pyi?$' +exclude = ''' +( + /( + \.eggs # exclude a few common directories in the + | \.git # root of the project + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + )/ + | notebooks + | local +) +''' + +[tool.isort] +line_length = 120 +multi_line_output = 3 + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0e043d7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,89 @@ +accelerate==0.21.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +aiohttp==3.8.5 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +aiosignal==1.3.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +ansicon==1.89.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" and platform_system == "Windows" +anyio==3.7.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +async-timeout==4.0.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +attrs==23.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +bitsandbytes==0.41.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +blessed==1.20.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +cachetools==5.3.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +certifi==2023.7.22 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +charset-normalizer==3.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +click==8.1.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +colorama==0.4.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +dataclasses-json==0.5.14 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +exceptiongroup==1.1.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +filelock==3.12.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +frozenlist==1.4.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +fsspec==2023.6.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +gptcache==0.1.39.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +greenlet==2.0.2 ; python_full_version >= "3.8.1" and platform_machine == "aarch64" and python_full_version <= "3.10.0" or python_full_version >= "3.8.1" and platform_machine == "ppc64le" and python_full_version <= "3.10.0" or python_full_version >= "3.8.1" and platform_machine == "x86_64" and python_full_version <= "3.10.0" or python_full_version >= "3.8.1" and platform_machine == "amd64" and python_full_version <= "3.10.0" or python_full_version >= "3.8.1" and platform_machine == "AMD64" and python_full_version <= "3.10.0" or python_full_version >= "3.8.1" and platform_machine == "win32" and python_full_version <= "3.10.0" or python_full_version >= 
"3.8.1" and platform_machine == "WIN32" and python_full_version <= "3.10.0" +h11==0.14.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +h2o-wave==0.26.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +httpcore==0.17.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +httpx==0.24.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +huggingface-hub==0.16.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +idna==3.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +inquirer==3.1.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +instructorembedding==1.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +jinja2==3.1.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +jinxed==1.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" and platform_system == "Windows" +joblib==1.3.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +langchain==0.0.142 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +llama-index==0.5.27 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +loguru==0.7.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +markupsafe==2.1.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +marshmallow==3.20.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +mpmath==1.3.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +multidict==6.0.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +mypy-extensions==1.0.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +networkx==3.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +nltk==3.8.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +numexpr==2.8.5 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +numpy==1.24.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +openai==0.27.8 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +openapi-schema-pydantic==1.2.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +packaging==23.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +pandas==1.5.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +pandasql==0.7.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +pillow==10.0.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +psutil==5.9.5 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +psycopg2-binary==2.9.7 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +pydantic==1.10.12 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +python-dateutil==2.8.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +python-editor==1.0.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +pytz==2023.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +pyyaml==6.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +readchar==4.0.5 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +regex==2023.8.8 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +requests==2.31.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +safetensors==0.3.2 ; python_full_version >= "3.8.1" 
and python_full_version <= "3.10.0" +scikit-learn==1.3.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +scipy==1.10.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +sentence-transformers==2.2.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +sentencepiece==0.1.99 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +setuptools==68.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +six==1.16.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +sniffio==1.3.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +sqlalchemy-utils==0.41.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +sqlalchemy==1.4.49 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +sqlglot==12.4.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +starlette==0.31.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +sympy==1.12 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +tenacity==8.2.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +threadpoolctl==3.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +tiktoken==0.4.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +tokenizers==0.13.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +toml==0.10.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +torch==2.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +torchvision==0.15.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +tqdm==4.66.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +transformers==4.31.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +typing-extensions==4.7.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +typing-inspect==0.9.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +urllib3==2.0.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +uvicorn==0.23.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +wcwidth==0.2.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +win32-setctime==1.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" and sys_platform == "win32" +yarl==1.9.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" diff --git a/sidekick/__init__.py b/sidekick/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sidekick/configs/__init__.py b/sidekick/configs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sidekick/configs/data_template.py b/sidekick/configs/data_template.py new file mode 100644 index 0000000..0f7f662 --- /dev/null +++ b/sidekick/configs/data_template.py @@ -0,0 +1,11 @@ +# Reference: https://github.com/openai/openai-cookbook/blob/main/examples/Backtranslation_of_SQL_queries.py +question_query_samples = """ +{ + "question": "{}", + "query": "{}" +} +""" + +schema_info_template = {"Column Name": "", "Column Type": "", "Sample Values": []} + +data_samples_template = "'{column_name}' contains values similar to {comma_separated_sample_values}." 
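To make the `data_template.py` placeholders concrete, here is a hypothetical rendering of `data_samples_template` against one row of the `table_info.jsonl` format shown in the examples above. The snippet is illustrative only and assumes the `sidekick` package is importable; it is not part of this diff, and the real call sites live elsewhere in the package:

```python
import json

from sidekick.configs.data_template import data_samples_template

# One row in the table_info.jsonl format (taken from examples/sleep_eda above).
row = json.loads('{"Column Name": "Gender", "Column Type": "TEXT", "Sample Values": ["Female", "Male"]}')

# Fill the template's {column_name} and {comma_separated_sample_values} slots.
sentence = data_samples_template.format(
    column_name=row["Column Name"],
    comma_separated_sample_values=", ".join(row["Sample Values"]),
)
print(sentence)  # 'Gender' contains values similar to Female, Male.
```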
diff --git a/sidekick/configs/env.toml b/sidekick/configs/env.toml new file mode 100644 index 0000000..2289673 --- /dev/null +++ b/sidekick/configs/env.toml @@ -0,0 +1,23 @@ +[MODEL_INFO] +OPENAI_API_KEY = "" # Needed only for openAI models +MODEL_NAME = "h2ogpt-sql" # Others: e.g. gpt-4, gpt-4-32k, text-davinci-003 +QUANT_TYPE = '4bit' + +[LOCAL_DB_CONFIG] +HOST_NAME = "localhost" +USER_NAME = "sqlite" +PASSWORD = "abc" +DB_NAME = "querydb" +PORT = "5432" + +[LOGGING] +LOG-LEVEL = "DEBUG" + +[DB-DIALECT] +DB_TYPE = "sqlite" + +[TABLE_INFO] +TABLE_INFO_PATH = "examples/demo/table_info.jsonl" +SAMPLE_QNA_PATH = "examples/demo/demo_qa.csv" +TABLE_SAMPLES_PATH = "examples/demo/demo_data.csv" +TABLE_NAME = "demo" diff --git a/sidekick/configs/prompt_template.py b/sidekick/configs/prompt_template.py new file mode 100644 index 0000000..f2224bd --- /dev/null +++ b/sidekick/configs/prompt_template.py @@ -0,0 +1,78 @@ +# Chain of thought for reasoning and task decomposition +# Reference: https://arxiv.org/pdf/2201.11903.pdf +TASK_PROMPT = { + "system_prompt": "Act as a Data Analyst", + "user_prompt": """ + ### For table {_table_name}, given an input *Question*, let's work it out in a detailed step by step way and only return specific, detailed and informative tasks as an ordered numeric list for SQL generation to be sure we have the right answer. + Use values that are explicitly mentioned in the *Question*. + Use the *History* and *Context* section for co-reference and to infer relationships and identify column names. *Context* contains entity mapping containing keys:values. + If the words in the *Question* do not match column names *Data* section; Search for them in *Context* section. + Always use *Context* with highest similarity score with the *Question*. + If words in the *Question* match more than one key, include both the values using "or" when forming step by step tasks. + If no information related to the *Question* is found; apply self reasoning and predict for possible tasks. + Infer the return type of the Question. + Do not generate SQL response, only return itemized tasks. + # *Data:* \nFor table {_table_name} schema info is mentioned below,\n{_data_info} + # *History*: \n{_sample_queries} + # *Question*: For table {_table_name}, {_question_str}, *Context*: {_context} + # Output: Tasks: ordered numeric list of tasks + """, +} + +# Few shot learning prompt +## Task Description +## Examples +## Prompt +# Reference: https://arxiv.org/pdf/2005.14165.pdf +QUERY_PROMPT = """ + ### System: Act as a SQL Expert + # For table {_table_name}, given an input *Question*, only generate syntactically correct SQL queries. + # Let's work it out in a detailed step by step way using the reasoning from *Tasks* section. + # Pick the SQL query which has the highest average log probability if more than one result is likely to answer the + candidate *Question*. + ### {_dialect} SQL tables + ### *Data:* \nFor table {_table_name} schema info is mentioned below,\n{_data_info} + ### *History*:\n{_sample_queries} + ### *Question*: For table {_table_name}, {_question} + # SELECT 1 + ### *Tasks for table {_table_name}*:\n{_tasks} + ### *Policies for SQL generation*: + # Avoid overly complex SQL queries + # Use values that are explicitly mentioned in the question. 
+ # Don't use aggregate and window function together + # Use COUNT(1) instead of COUNT(*) + # Return with LIMIT 100 + # Prefer NOT EXISTS to LEFT JOIN ON null id + # Avoid using the WITH statement + # When using DESC keep NULLs at the end + # If JSONB format found in Table schema, do pattern matching on keywords from the question and use SQL functions such as ->> or -> + # Add explanation and reasoning for each SQL query + """ + +DEBUGGING_PROMPT = { + "system_prompt": "Act as a SQL expert for {_dialect} code", + "user_prompt": """ + ### Fix syntax errors for provided incorrect SQL Query. + # Add ``` as prefix and ``` as suffix to generated SQL + # Error: {ex_traceback} + # Add explanation and reasoning for each SQL query + Query:\n {qry_txt} + """, +} + +NSQL_QUERY_PROMPT = """ +For SQL TABLE '{table_name}' sample question/answer pairs,\n({sample_queries}) + +CREATE TABLE '{table_name}'({column_info} +) + +Table '{table_name}' has sample values ({data_info_detailed}) + + + +-- Using valid SQLite, answer the following questions with the information for '{table_name}' provided above; for final SQL only use column names from the CREATE TABLE. + + +-- Using reference for TABLES '{table_name}' {context}; {question_txt}? + +SELECT""" diff --git a/sidekick/db_config.py b/sidekick/db_config.py new file mode 100644 index 0000000..41d6866 --- /dev/null +++ b/sidekick/db_config.py @@ -0,0 +1,237 @@ +# create db with supplied info +import json +from pathlib import Path + +import pandas as pd +import psycopg2 as pg +import sqlalchemy +from psycopg2.extras import Json +from sidekick.configs.data_template import data_samples_template +from sidekick.logger import logger +from sqlalchemy import create_engine, text +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy_utils import database_exists + + +class DBConfig: + def __init__( + self, + db_name, + hostname, + user_name, + password, + port, + base_path, + schema_info_path=None, + schema_info=None, + dialect="sqlite", + ) -> None: + self.db_name = db_name + self.hostname = hostname + self.user_name = user_name + self.password = password + self.port = port + self._table_name = None + self.schema_info_path = schema_info_path + self.schema_info = schema_info + self._engine = None + self.dialect = dialect + self.base_path = base_path + self.column_names = [] + if dialect == "sqlite": + self._url = f"sqlite:///{base_path}/db/sqlite/{db_name}.db" + else: + self._url = f"{self.dialect}://{self.user_name}:{self.password}@{self.hostname}:{self.port}/" + + @property + def table_name(self): + return self._table_name + + @table_name.setter + def table_name(self, val): + self._table_name = val.lower().replace(" ", "_") + + @property + def engine(self): + return self._engine + + def db_exists(self): + if self.dialect == "sqlite": + engine = create_engine(f"{self._url}", echo=True) + else: + engine = create_engine(f"{self._url}{self.db_name}", echo=True) + return database_exists(f"{engine.url}") + + def create_db(self): + engine = create_engine(self._url) + self._engine = engine + try: + with engine.connect() as conn: + # conn.execute("commit") + # Do not substitute user-supplied database names here. 
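+                # db_name comes from local configuration; a hypothetical guard such as
+                # `if not self.db_name.isidentifier(): raise ValueError(...)` (not in the
+                # original code) would keep unsafe names out of the DDL issued below.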
+                if self.dialect != "sqlite":
+                    conn.execute("commit")
+                    res = conn.execute(f"CREATE DATABASE {self.db_name}")
+                    self._url = f"{self._url}{self.db_name}"
+                    return res, None
+                else:
+                    logger.debug("SQLite DB is created when 'engine.connect()' is called")
+
+                return True, None
+        except SQLAlchemyError as sqla_error:
+            logger.debug(f"SQLAlchemy error: {sqla_error}")
+            return None, sqla_error
+        except Exception as error:
+            logger.debug(f"Error occurred: {error}")
+            return None, error
+
+    def _extract_schema_info(self, schema_info_path=None):
+        # From jsonl format
+        # E.g. {"Column Name": "id", "Column Type": "uuid PRIMARY KEY"}
+        if schema_info_path is None:
+            table_info_file = f"{self.base_path}/var/lib/tmp/data/table_context.json"
+            if Path(table_info_file).exists():
+                # Open the cached table context for reading ("w" would truncate it)
+                with open(table_info_file, "r") as in_file:
+                    schema_info_path = json.load(in_file)["schema_info_path"]
+        res = []
+        sample_values = []
+        try:
+            if Path(schema_info_path).exists():
+                with open(schema_info_path, "r") as in_file:
+                    for line in in_file:
+                        if line.strip():
+                            data = json.loads(line)
+                            if "Column Name" in data and "Column Type" in data:
+                                col_name = data["Column Name"]
+                                self.column_names.append(col_name)
+                                col_type = data["Column Type"]
+                                if col_type.lower() == "text":
+                                    col_type = col_type + " COLLATE NOCASE"
+                                # If the column has sample values, save them in cache for future use.
+                                if "Sample Values" in data:
+                                    _sample_values = data["Sample Values"]
+                                    _ds = data_samples_template.format(
+                                        column_name=col_name,
+                                        comma_separated_sample_values=",".join(
+                                            str(_sample_val) for _sample_val in _sample_values
+                                        ),
+                                    )
+                                    sample_values.append(_ds)
+                                _new_samples = f"{col_name} {col_type}"
+                                res.append(_new_samples)
+                if len(sample_values) > 0:
+                    # Cache the sample values for future use
+                    with open(
+                        f"{self.base_path}/var/lib/tmp/data/{self._table_name}_column_values.json", "w"
+                    ) as outfile:
+                        json.dump(sample_values, outfile, indent=2, sort_keys=False)
+        except ValueError as ve:
+            logger.error(f"Error in reading table context file: {ve}")
+            pass
+        return res
+
+    def create_table(self, schema_info_path=None, schema_info=None):
+        try:
+            engine = create_engine(self._url, isolation_level="AUTOCOMMIT")
+            self._engine = engine
+            if self.schema_info is None:
+                if schema_info is not None:
+                    self.schema_info = schema_info
+                else:
+                    # If schema information is not provided, extract it from the schema info file.
+ self.schema_info = """,\n""".join(self._extract_schema_info(schema_info_path)).strip() + logger.debug(f"Schema info used for creating table:\n {self.schema_info}") + create_syntax = f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + {self.schema_info} + ) + """ + with engine.connect() as conn: + if self.dialect != "sqlite": + conn.execute("commit") + conn.execute(create_syntax) + + return self.table_name, None + except SQLAlchemyError as sqla_error: + logger.debug("SQLAlchemy error:", sqla_error) + return None, sqla_error + except Exception as error: + logger.debug("Error Occurred:", error) + return None, error + + def has_table(self): + engine = create_engine(self._url) + return sqlalchemy.inspect(engine).has_table(self.table_name) + + def add_samples(self, data_csv_path=None): + conn_str = self._url + try: + df_chunks = pd.read_csv(data_csv_path, chunksize=10000) + engine = create_engine(conn_str, isolation_level="AUTOCOMMIT") + + sample_query = f"SELECT COUNT(*) AS ROWS FROM {self.table_name} LIMIT 1" + for idx, chunk in enumerate(df_chunks): + # Write rows to database + logger.debug(f"Inserting chunk: {idx}") + chunk.columns = self.column_names + # Make sure column names in the data-frame match the schema + chunk.to_sql(self.table_name, engine, if_exists="append", index=False, method="multi") + + logger.info(f"Data inserted into table: {self.table_name}") + # Fetch the number of rows from the table + num_rows = pd.read_sql_query(sample_query, engine) + logger.info(f"Number of rows inserted: {num_rows.values[0][0]}") + engine.dispose() + return num_rows, None + except SQLAlchemyError as sqla_error: + logger.debug("SQLAlchemy error:", sqla_error) + return None, sqla_error + except Exception as error: + logger.debug("Error Occurred:", error) + return None, error + finally: + if engine: + engine.dispose() + + def execute_query_db(self, query=None, n_rows=100): + output = [] + if self.dialect != "sqlite": + conn_str = f"{self._url}{self.db_name}" + else: + conn_str = self._url + + try: + if query: + # Create an engine + engine = create_engine(conn_str) + + # Create a connection + connection = engine.connect() + logger.debug(f"Executing query:\n {query}") + _query = text(query) + result = connection.execute(_query) + + # Process the query results + cnt = 0 + for row in result: + if cnt <= n_rows: + # Access row data using row[column_name] + output.append(row) + cnt += 1 + else: + break + # Close the connection + connection.close() + + # Close the engine + engine.dispose() + else: + logger.info("Query Empty or None!") + return output, None + except Exception as e: + err = f"Error occurred : {format(e)}" + logger.info(err) + return None, err + finally: + connection.close() + engine.dispose() diff --git a/sidekick/logger.py b/sidekick/logger.py new file mode 100644 index 0000000..b7a9999 --- /dev/null +++ b/sidekick/logger.py @@ -0,0 +1,9 @@ +from loguru import logger +import sys +import toml +from pathlib import Path + +logger.remove() +base_path = (Path(__file__).parent / "../").resolve() +env_settings = toml.load(f"{base_path}/sidekick/configs/env.toml") +logger.add(sys.stderr, level=env_settings["LOGGING"]["LOG-LEVEL"]) diff --git a/sidekick/memory.py b/sidekick/memory.py new file mode 100644 index 0000000..0c23458 --- /dev/null +++ b/sidekick/memory.py @@ -0,0 +1,81 @@ +import json +import re +from pathlib import Path +from typing import Dict, List, Tuple + + +# Reference: https://python.langchain.com/en/latest/modules/memory/examples/custom_memory.html +class EntityMemory: + def 
__init__(self, k, path: str = None):
+        self.k = k
+        self.track_history: List = []
+        self.track_entity: List = []
+        self.path = path
+
+    def extract_entity(self, question: str, answer: str) -> Tuple[List, List]:
+        # Currently, anything tagged between the below tags is extracted respectively,
+        # 1. From input text: <key> some key </key>
+        # 2. From output text: <value> some value </value>
+        # TODO Chat mode for auto extraction of entities
+        c_k = re.findall(r"<key>(.+?)</key>", question)
+        c_val = re.findall(r"<value>(.+?)</value>", answer)
+        return (c_k, c_val)
+
+    def save_context(self, info: str, extract_context: bool = True) -> Dict:
+        # Construct a dictionary to record history:
+        # {
+        #   'Query':
+        #   'Answer':
+        # }
+        # Extract info from the supplied text
+        split_token = ";"
+        query = " ".join(info.partition(":")[2].split(split_token)[0].strip().split())
+        response = " ".join(info.partition(":")[2].split(split_token)[1].partition(":")[2].strip().split())
+        # TODO add additional guardrails to check whether the response is valid,
+        # i.e. at least syntactically correct SQL.
+
+        # Check if entity extraction is enabled
+        # Add logic for entity extraction
+        extracted_entity = None
+        if extract_context:
+            _k, _v = self.extract_entity(query, response)
+            k_v = " ".join(_k)
+            c_v = ", ".join(_v)
+            extracted_entity = {k_v: c_v}
+            self.track_entity.append(extracted_entity)
+
+        chat_history = {}
+        if query.strip() and "select" in response.lower():
+            # Remove <key> and <value> tags from the query/response before persisting
+            query = (
+                query.lower().replace("<key>", "").replace("</key>", "").replace("<value>", "").replace("</value>", "")
+            )
+            response = (
+                response.lower()
+                .replace("<key>", "")
+                .replace("</key>", "")
+                .replace("<value>", "")
+                .replace("</value>", "")
+            )
+            chat_history = {"Query": query, "Answer": response, "Entity": extracted_entity}
+            self.track_history.append(chat_history)
+        else:
+            raise ValueError("Response not valid.
Please try again.") + # persist the information for future use + res = {"history": self.track_history, "entity": self.track_entity} + + # Persist added information locally + if chat_history: + with open(f"{self.path}/var/lib/tmp/data/history.jsonl", "a") as outfile: + json.dump(chat_history, outfile) + outfile.write("\n") + if extract_context: + # Update context.json file for tracking entities + content_file_path = f"{self.path}/var/lib/tmp/data/context.json" + context_dict = extracted_entity + if Path(content_file_path).exists(): + context_dict = json.load(open(content_file_path, "r")) + context_dict.update(extracted_entity) + with open(content_file_path, "w") as outfile: + json.dump(context_dict, outfile, indent=4, sort_keys=False) + return res diff --git a/sidekick/prompter.py b/sidekick/prompter.py new file mode 100644 index 0000000..cb566de --- /dev/null +++ b/sidekick/prompter.py @@ -0,0 +1,558 @@ +import gc +import json +import os +import string +from pathlib import Path + +import click +import openai +import toml +import torch +from colorama import Back as B +from colorama import Fore as F +from colorama import Style +from loguru import logger +from pandasql import sqldf +from sidekick.db_config import DBConfig +from sidekick.memory import EntityMemory +from sidekick.query import SQLGenerator +from sidekick.schema_generator import generate_schema +from sidekick.utils import (_execute_sql, check_vulnerability, + execute_query_pd, extract_table_names, save_query, + setup_dir) + +# Load the config file and initialize required paths +base_path = (Path(__file__).parent / "../").resolve() +env_settings = toml.load(f"{base_path}/sidekick/configs/env.toml") +db_dialect = env_settings["DB-DIALECT"]["DB_TYPE"] +model_name = env_settings["MODEL_INFO"]["MODEL_NAME"] +os.environ["TOKENIZERS_PARALLELISM"] = "False" +__version__ = "0.0.4" + + +def color(fore="", back="", text=None): + return f"{fore}{back}{text}{Style.RESET_ALL}" + + +msg = """Welcome to the SQL Sidekick!\nI am an AI assistant that helps you with SQL queries. +I can help you with the following:\n +1. Configure a local database(for schema validation and syntax checking): `sql-sidekick configure db-setup`.\n +2. Learn contextual query/answer pairs: `sql-sidekick learn add-samples`.\n +3. Simply add context: `sql-sidekick learn update-context`.\n +4. Ask a question: `sql-sidekick query`. +""" + + +@click.group(help=msg) +@click.version_option("-V", "--version", message=f"sql-sidekick - {__version__}") +def cli(): + # Book-keeping + setup_dir(base_path) + + +@cli.group("configure") +def configure(): + """Helps in configuring local database.""" + + +def enter_table_name(): + val = input(color(F.GREEN, "", "Would you like to create a table for the database? (y/n): ")) + return val + + +def enter_file_path(table: str): + val = input(color(F.GREEN, "", f"Please input the CSV file path to table {table} : ")) + return val + + +@configure.command("log", help="Adjust log settings") +@click.option("--set_level", "-l", help="Set log level (Default: INFO)") +def set_loglevel(set_level): + env_settings["LOGGING"]["LOG-LEVEL"] = set_level + # Update settings file for future use. 
+ f = open(f"{base_path}/sidekick/configs/env.toml", "w") + toml.dump(env_settings, f) + f.close() + + +def _get_table_info(cache_path: str): + # Search for the file in the default current path, if not present ask user to enter the path + if Path(f"{cache_path}/table_info.jsonl").exists(): + table_info_path = f"{cache_path}/table_info.jsonl" + else: + # Check in table cache before requesting + if Path(f"{cache_path}/table_context.json").exists(): + f = open(f"{cache_path}/table_context.json", "r") + table_metadata = json.load(f) + if "schema_info_path" in table_metadata: + table_info_path = table_metadata["schema_info_path"] + else: + table_info_path = click.prompt("Enter table info path") + table_metadata["schema_info_path"] = table_info_path + with open(f"{cache_path}/table_context.json", "w") as outfile: + json.dump(table_metadata, outfile, indent=4, sort_keys=False) + else: + table_info_path = click.prompt("Enter table info path") + table_metadata = {"schema_info_path": table_info_path} + with open(f"{cache_path}/table_context.json", "w") as outfile: + json.dump(table_metadata, outfile, indent=4, sort_keys=False) + return table_info_path + + +def update_table_info(cache_path: str, table_info_path: str = None, table_name: str = None): + if Path(f"{cache_path}/table_context.json").exists(): + f = open(f"{cache_path}/table_context.json", "r") + table_metadata = json.load(f) + if table_name: + table_metadata["tables_in_use"] = [table_name] + if table_info_path: + table_metadata["schema_info_path"] = table_info_path + else: + table_metadata = dict() + if table_name: + table_metadata["tables_in_use"] = [table_name] + if table_info_path: + table_metadata["schema_info_path"] = table_info_path + + table_metadata["data_table_map"] = {} + with open(f"{cache_path}/table_context.json", "w") as outfile: + json.dump(table_metadata, outfile, indent=4, sort_keys=False) + + +@configure.command( + "generate_schema", help=f"Helps generate default schema for the selected Database dialect: {db_dialect}" +) +@click.option("--data_path", default="data.csv", help="Enter the path of csv", type=str) +@click.option("--output_path", default="table_info.jsonl", help="Enter the path of generated schema in jsonl", type=str) +def generate_input_schema(data_path, output_path): + o_path = generate_schema(data_path, output_path) + click.echo(f"Schema generated for the input data at {o_path}") + + +@configure.command("db-setup", help=f"Enter information to configure {db_dialect} database locally") +@click.option("--db_name", "-n", default="querydb", help="Database name", prompt="Enter Database name") +@click.option("--hostname", "-h", default="localhost", help="Database hostname", prompt="Enter hostname name") +@click.option("--user_name", "-u", default=f"{db_dialect}", help="Database username", prompt="Enter username name") +@click.option( + "--password", + "-p", + default="abc", + hide_input=True, + help="Database password", + prompt="Enter password", +) +@click.option("--port", "-P", default=5432, help="Database port", prompt="Enter port (default 5432)") +@click.option("--table-info-path", "-t", help="Table info path", default=None) +def db_setup(db_name: str, hostname: str, user_name: str, password: str, port: int, table_info_path: str): + db_setup_api( + db_name=db_name, + hostname=hostname, + user_name=user_name, + password=password, + port=port, + table_info_path=table_info_path, + table_samples_path=None, + table_name=None, + is_command=True, + ) + + +def db_setup_api( + db_name: str, + hostname: str, + user_name: 
str, + password: str, + port: int, + table_info_path: str, + table_samples_path: str, + table_name: str, + is_command: bool = False, +): + """Creates context for the new Database""" + click.echo(f" Information supplied:\n {db_name}, {hostname}, {user_name}, {password}, {port}") + try: + res = err = None + env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"] = hostname + env_settings["LOCAL_DB_CONFIG"]["USER_NAME"] = user_name + env_settings["LOCAL_DB_CONFIG"]["PASSWORD"] = password + env_settings["LOCAL_DB_CONFIG"]["PORT"] = port + env_settings["LOCAL_DB_CONFIG"]["DB_NAME"] = db_name + + # To-DO + # --- Need to remove the below keys from ENV toml --- # + # env_settings["TABLE_INFO"]["TABLE_INFO_PATH"] = table_info_path + # env_settings["TABLE_INFO"]["TABLE_SAMPLES_PATH"] = table_samples_path + + # Update settings file for future use. + f = open(f"{base_path}/sidekick/configs/env.toml", "w") + toml.dump(env_settings, f) + f.close() + path = f"{base_path}/var/lib/tmp/data" + # For current session + db_obj = DBConfig(db_name, hostname, user_name, password, port, base_path=base_path, dialect=db_dialect) + + # Create Database + if db_obj.dialect == "sqlite" and not os.path.isfile(f"{base_path}/db/sqlite/{db_name}.db"): + res, err = db_obj.create_db() + click.echo("Database created successfully!") + elif not db_obj.db_exists(): + res, err = db_obj.create_db() + click.echo("Database created successfully!") + else: + click.echo("Database already exists!") + + # Create Table in DB + val = enter_table_name() if is_command else "y" + while True: + if val.lower() != "y" and val.lower() != "n": + click.echo("In-correct values. Enter Yes(y) or no(n)") + val = enter_table_name() + else: + break + + if table_info_path is None: + table_info_path = _get_table_info(path) + + if val.lower() == "y" or val.lower() == "yes": + table_value = input("Enter table name: ") if is_command else table_name + click.echo(f"Table name: {table_value}") + # set table name + db_obj.table_name = table_value.lower().replace(" ", "_") + res, err = db_obj.create_table(table_info_path) + + update_table_info(path, table_info_path, db_obj.table_name) + # Check if table exists; pending --> and doesn't have any rows + # Add rows to table + if db_obj.has_table(): + click.echo(f"Checked table {db_obj.table_name} exists in the DB.") + val = ( + input(color(F.GREEN, "", "Would you like to add few sample rows (at-least 3)? (y/n):")) + if is_command + else "y" + ) + if val.lower().strip() == "y" or val.lower().strip() == "yes": + val = input("Path to a CSV file to insert data from:") if is_command else table_samples_path + res, err = db_obj.add_samples(val) + else: + click.echo("Exiting...") + return + else: + echo_msg = "Job done. Ask a question now!" + click.echo(echo_msg) + + if err is None: + return ( + f"Created a Database {db_name}. Inserted sample values from {table_samples_path} into table {table_name}, please ask questions!", + None, + ) + else: + return None, err + except Exception as e: + echo_msg = f"Error creating database. 
Check configuration parameters.\n{e}"
+        click.echo(echo_msg)
+        if not is_command:
+            return echo_msg
+
+
+@cli.group("learn")
+def learn():
+    """Helps in learning and building memory."""
+
+
+def _add_context(entity_memory: EntityMemory):
+    _FORMAT = '''# Add input Query and Response \n\n
+"Query": "<query>";\n
+"Response": """<response>"""
+'''
+    res = click.edit(_FORMAT.replace("\t", ""))
+    # Check if the user has entered any value
+    if res:
+        try:
+            _ = entity_memory.save_context(res)
+        except ValueError as ve:
+            logger.info("Not a valid input. Try again.")
+
+
+@learn.command("add-samples", help="Helps add contextual query/answer pairs.")
+def add_query_response():
+    em = EntityMemory(k=5, path=base_path)
+    _add_context(em)
+    _more = "y"
+    while _more.lower() != "n" and _more.lower() != "no":
+        _more = click.prompt("Would you like to add more samples? (y/n)")
+        if _more.lower() == "y":
+            _add_context(em)
+        else:
+            break
+
+
+@learn.command("update-context", help="Update context in memory for future use")
+def update_context():
+    """Helps learn context for generation."""
+    # Book-keeping
+    setup_dir(base_path)
+
+    context_dict = """{\n"<context_key>": "<context_value>"\n}
+    """
+    content_file_path = f"{base_path}/var/lib/tmp/data/context.json"
+    context_str = context_dict
+    if Path(f"{base_path}/var/lib/tmp/data/context.json").exists():
+        context_dict = json.load(open(content_file_path, "r"))
+        context_dict["<context_key>"] = "<context_value>"
+        context_str = json.dumps(context_dict, indent=2)
+
+    updated_context = click.edit(context_str)
+    if updated_context:
+        context_dict = json.loads(updated_context)
+        if "<context_key>" in context_dict:
+            del context_dict["<context_key>"]
+        path = f"{base_path}/var/lib/tmp/data/"
+        with open(f"{path}/context.json", "w") as outfile:
+            json.dump(context_dict, outfile, indent=4, sort_keys=False)
+    else:
+        logger.debug("No content updated ...")
+
+
+@cli.command()
+@click.option("--question", "-q", help="Question to ask", prompt="Ask a question")
+@click.option("--table-info-path", "-t", help="Table info path", default=None)
+@click.option("--sample_qna_path", "-s", help="Samples path", default=None)
+def query(question: str, table_info_path: str, sample_qna_path: str):
+    """Asks question and returns SQL."""
+    query_api(
+        question=question,
+        table_info_path=table_info_path,
+        sample_queries_path=sample_qna_path,
+        table_name=None,
+        is_command=True,
+    )
+
+
+def query_api(
+    question: str,
+    table_info_path: str,
+    sample_queries_path: str,
+    table_name: str,
+    is_regenerate: bool = False,
+    is_regen_with_options: bool = False,
+    is_command: bool = False,
+):
+    """Asks question and returns SQL."""
+    results = []
+    err = None  # TODO - Need to handle errors if occurred
+    # Book-keeping
+    setup_dir(base_path)
+
+    # Check if table exists
+    path = f"{base_path}/var/lib/tmp/data"
+    table_context_file = f"{path}/table_context.json"
+    table_context = json.load(open(table_context_file, "r")) if Path(table_context_file).exists() else {}
+    table_names = []
+
+    if table_name is not None:
+        table_names = [table_name.lower().replace(" ", "_")]
+    elif table_context and "tables_in_use" in table_context:
+        _tables = table_context["tables_in_use"]
+        table_names = [_t.lower().replace(" ", "_") for _t in _tables]
+    else:
+        # Ask for the table name when none is recorded in the table context.
+ table_names = [click.prompt("Which table to use?")] + table_context["tables_in_use"] = [_t.lower().replace(" ", "_") for _t in table_names] + with open(f"{path}/table_context.json", "w") as outfile: + json.dump(table_context, outfile, indent=4, sort_keys=False) + logger.info(f"Table in use: {table_names}") + # Check if env.toml file exists + api_key = None + if model_name != "h2ogpt-sql": + api_key = env_settings["MODEL_INFO"]["OPENAI_API_KEY"] + if api_key is None or api_key == "": + if os.getenv("OPENAI_API_KEY") is None or os.getenv("OPENAI_API_KEY") == "": + if is_command: + val = input( + color( + F.GREEN, "", "Looks like API key is not set, would you like to set OPENAI_API_KEY? (y/n):" + ) + ) + if val.lower() == "y": + api_key = input(color(F.GREEN, "", "Enter OPENAI_API_KEY :")) + + if api_key is None and is_command: + return ["Looks like API key is not set, please set OPENAI_API_KEY!"], err + + os.environ["OPENAI_API_KEY"] = api_key + env_settings["MODEL_INFO"]["OPENAI_API_KEY"] = api_key + + # Update settings file for future use. + f = open(f"{base_path}/sidekick/configs/env.toml", "w") + toml.dump(env_settings, f) + f.close() + openai.api_key = api_key + + try: + # Set context + logger.info("Setting context...") + logger.info(f"Question: {question}") + # Get updated info from env.toml + host_name = env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"] + user_name = env_settings["LOCAL_DB_CONFIG"]["USER_NAME"] + passwd = env_settings["LOCAL_DB_CONFIG"]["PASSWORD"] + db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"] + + if db_dialect == "sqlite": + db_url = f"sqlite:///{base_path}/db/sqlite/{db_name}.db" + else: + db_url = f"{db_dialect}+psycopg2://{user_name}:{passwd}@{host_name}/{db_name}".format( + user_name, passwd, host_name, db_name + ) + + if table_info_path is None: + table_info_path = _get_table_info(path) + + sql_g = SQLGenerator( + db_url, + api_key, + job_path=base_path, + data_input_path=table_info_path, + sample_queries_path=sample_queries_path, + is_regenerate_with_options=is_regen_with_options, + is_regenerate=is_regenerate, + ) + if "h2ogpt-sql" not in model_name: + sql_g._tasks = sql_g.generate_tasks(table_names, question) + results.extend(["List of Actions Generated: \n", sql_g._tasks, "\n"]) + click.echo(sql_g._tasks) + + updated_tasks = None + if sql_g._tasks is not None and is_command: + edit_val = click.prompt("Would you like to edit the tasks? 
(y/n)") + if edit_val.lower() == "y": + updated_tasks = click.edit(sql_g._tasks) + click.echo(f"Tasks:\n {updated_tasks}") + else: + click.echo("Skipping edit...") + if updated_tasks is not None: + sql_g._tasks = updated_tasks + alt_res = None + # The interface could also be used to simply execute user provided SQL + # Keyword: "Execute SQL: " + if ( + question is not None + and "select" in question.lower() + and (question.lower().startswith("question:") or question.lower().startswith("q:")) + ): + _q = question.lower().split("q:")[1].split("r:")[0].strip() + res = question.lower().split("r:")[1].strip() + question = _q + elif _execute_sql(question): + logger.info("Executing user provided SQL without re-generation...") + res = question.strip().lower().split("execute sql:")[1].strip() + else: + res, alt_res = sql_g.generate_sql(table_names, question, model_name=model_name, _dialect=db_dialect) + logger.info(f"Input query: {question}") + logger.info(f"Generated response:\n\n{res}") + + if res is not None: + updated_sql = None + res_val = "e" + if is_command: + while res_val.lower() in ["e", "edit", "r", "regenerate"]: + res_val = click.prompt( + "Would you like to 'edit' or 'regenerate' the SQL? Use 'e' to edit or 'r' to regenerate. " + "To skip, enter 's' or 'skip'" + ) + if res_val.lower() == "e" or res_val.lower() == "edit": + updated_sql = click.edit(res) + click.echo(f"Updated SQL:\n {updated_sql}") + elif res_val.lower() == "r" or res_val.lower() == "regenerate": + click.echo("Attempting to regenerate...") + res, alt_res = sql_g.generate_sql( + table_names, question, model_name=model_name, _dialect=db_dialect + ) + logger.info(f"Input query: {question}") + logger.info(f"Generated response:\n\n{res}") + + results.extend([f"**Generated response for question,**\n{question}\n", res, "\n"]) + logger.info(f"Alternate responses:\n\n{alt_res}") + + exe_sql = click.prompt("Would you like to execute the generated SQL (y/n)?") if is_command else "y" + if exe_sql.lower() == "y" or exe_sql.lower() == "yes": + # For the time being, the default option is Pandas, but the user can be asked to select Database or pandas DF later. + q_res = None + option = "DB" # or DB + _val = updated_sql if updated_sql else res + if option == "DB": + hostname = env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"] + user_name = env_settings["LOCAL_DB_CONFIG"]["USER_NAME"] + password = env_settings["LOCAL_DB_CONFIG"]["PASSWORD"] + port = env_settings["LOCAL_DB_CONFIG"]["PORT"] + db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"] + + db_obj = DBConfig( + db_name, hostname, user_name, password, port, base_path=base_path, dialect=db_dialect + ) + + # Before executing, check if known vulnerabilities exist in the generated SQL code. + _val = _val.replace("β€œ", '"').replace("”", '"') + [_val := _val.replace(s, "'") for s in "β€˜`" if s in _val] + r, m = check_vulnerability(_val) + if not r: + q_res, err = db_obj.execute_query_db(query=_val) + else: + q_res = m + + elif option == "pandas": + tables = extract_table_names(_val) + tables_path = dict() + if Path(f"{path}/table_context.json").exists(): + f = open(f"{path}/table_context.json", "r") + table_metadata = json.load(f) + for table in tables: + # Check if the local table_path exists in the cache + if table not in table_metadata["data_table_map"].keys(): + val = enter_file_path(table) + if not os.path.isfile(val): + click.echo("In-correct Path. Please enter again! 
Yes(y) or no(n)") + else: + tables_path[table] = val + table_metadata["data_table_map"][table] = val + break + else: + tables_path[table] = table_metadata["data_table_map"][table] + assert len(tables) == len(tables_path) + with open(f"{path}/table_context.json", "w") as outfile: + json.dump(table_metadata, outfile, indent=4, sort_keys=False) + try: + q_res = execute_query_pd(query=_val, tables_path=tables_path, n_rows=100) + click.echo(f"The query results are:\n {q_res}") + except sqldf.PandaSQLException as e: + logger.error(f"Error in executing the query: {e}") + click.echo("Error in executing the query. Validate generated SQL and try again.") + click.echo("No result to display.") + + results.append("**Result:** \n") + if q_res: + click.echo(f"The query results are:\n {q_res}") + results.extend([str(q_res), "\n"]) + else: + click.echo(f"While executing query:\n {err}") + results.extend([str(err), "\n"]) + + save_sql = click.prompt("Would you like to save the generated SQL (y/n)?") if is_command else "n" + if save_sql.lower() == "y" or save_sql.lower() == "yes": + # Persist for future use + _val = updated_sql if updated_sql else res + save_query(base_path, query=question, response=_val) + else: + click.echo("Exiting...") + except (MemoryError, RuntimeError) as e: + logger.error(f"Something went wrong while generating response: {e}") + del sql_g + gc.collect() + torch.cuda.empty_cache() + alt_res, err = None, None + results = ["Something went wrong while generating response. Please try again."] + return results, alt_res, err + + +if __name__ == "__main__": + cli() diff --git a/sidekick/query.py b/sidekick/query.py new file mode 100644 index 0000000..5ac42ed --- /dev/null +++ b/sidekick/query.py @@ -0,0 +1,587 @@ +import gc +import json +import os +import random +import sys +from pathlib import Path + +import numpy as np +import openai +import sqlglot +import torch +import torch.nn.functional as F +from langchain import OpenAI +from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex, + LLMPredictor, ServiceContext, SQLDatabase) +from llama_index.indices.struct_store import SQLContextContainerBuilder +from sidekick.configs.prompt_template import (DEBUGGING_PROMPT, + NSQL_QUERY_PROMPT, QUERY_PROMPT, + TASK_PROMPT) +from sidekick.logger import logger +from sidekick.utils import (_check_file_info, filter_samples, is_resource_low, + load_causal_lm_model, load_embedding_model, + read_sample_pairs, remove_duplicates) +from sqlalchemy import create_engine + + +class SQLGenerator: + _instance = None + + def __new__( + cls, + db_url: str, + openai_key: str = None, + model_name="NumbersStation/nsql-llama-2-7B", + data_input_path: str = "./table_info.jsonl", + sample_queries_path: str = "./samples.csv", + job_path: str = "./", + device: str = "auto", + is_regenerate: bool = False, + is_regenerate_with_options: bool = False, + ): + offloading = is_resource_low() + if offloading and is_regenerate_with_options: + del cls._instance + cls._instance = None + gc.collect() + torch.cuda.empty_cache() + logger.info(f"Low memory: {offloading}/ Model re-initialization: True") + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance.model, cls._instance.tokenizer = load_causal_lm_model( + model_name, + cache_path=f"{job_path}/models/", + device=device, + off_load=offloading, + re_generate=is_regenerate_with_options, + ) + model_embed_path = f"{job_path}/models/sentence_transformers" + device = "cuda" if torch.cuda.is_available() else "cpu" if device == "auto" else device + 
cls._instance.similarity_model = load_embedding_model(model_path=model_embed_path, device=device) + return cls._instance + + def __init__( + self, + db_url: str, + openai_key: str = None, + model_name="NumbersStation/nsql-llama-2-7B", + data_input_path: str = "./table_info.jsonl", + sample_queries_path: str = "./samples.csv", + job_path: str = "./", + device: str = "cpu", + is_regenerate: bool = False, + is_regenerate_with_options: bool = False, + ): + self.db_url = db_url + self.engine = create_engine(db_url) + self.sql_database = SQLDatabase(self.engine) + self.context_builder = None + self.data_input_path = _check_file_info(data_input_path) + self.sample_queries_path = sample_queries_path + self.path = job_path + self._data_info = None + self._tasks = None + self.model_name = model_name + self.openai_key = openai_key + self.content_queries = None + self.is_regenerate_with_options = is_regenerate_with_options + self.is_regenerate = is_regenerate + self.device = device + + def clear(self): + del SQLGenerator._instance + SQLGenerator._instance = None + + def load_column_samples(self, tables: list): + # TODO: Maybe we add table name as a member variable + # Load column values if they exists + examples = {} + for _t in tables: + f_p = f"{self.path}/var/lib/tmp/data/{_t}_column_values.json" + with open(f_p, "r") as f: + examples[_t] = json.load(f) + return examples + + def build_index(self, persist: bool = True): + # Below re-assignment of the OPENAI API key is weird but without that, it throws an error. + if self.openai_key: + os.environ["OPENAI_API_KEY"] = self.openai_key + openai.api_key = self.openai_key + + table_schema_index = self.context_builder.derive_index_from_context( + GPTSimpleVectorIndex, + ) + if persist: + table_schema_index.save_to_disk(f"{self.path}/sql_index_check.json") + return table_schema_index + + def update_context_queries(self): + # Check if seed samples were provided + new_context_queries = [] + if self.sample_queries_path is not None and Path(self.sample_queries_path).exists(): + logger.info(f"Using QnA samples from path {self.sample_queries_path}") + new_context_queries = read_sample_pairs(self.sample_queries_path, "h2ogpt-sql") + # cache the samples for future use + with open(f"{self.path}/var/lib/tmp/data/queries_cache.json", "w") as f: + json.dump(new_context_queries, f, indent=2) + elif self.sample_queries_path is None and Path(f"{self.path}/var/lib/tmp/data/queries_cache.json").exists(): + logger.info(f"Using samples from cache") + with open(f"{self.path}/var/lib/tmp/data/queries_cache.json", "r") as f: + new_context_queries = json.load(f) + # Read the history file and update the context queries + history_file = f"{self.path}/var/lib/tmp/data/history.jsonl" + try: + if Path(history_file).exists(): + with open(history_file, "r") as in_file: + for line in in_file: + # Format: + # """ + # # query: + # # answer: + # """ + if line.strip(): + data = json.loads(line) + if "Query" in data and "Answer" in data: + query = data["Query"] + response = data["Answer"] + _new_samples = f"""# query: {query}\n# answer: {response}""" + new_context_queries.append(_new_samples) + except ValueError as ve: + logger.error(f"Error in reading history file: {ve}") + pass + return new_context_queries + + def _query_tasks(self, question_str, data_info, sample_queries, table_name: list): + try: + context_file = f"{self.path}/var/lib/tmp/data/context.json" + additional_context = json.load(open(context_file, "r")) if Path(context_file).exists() else {} + + system_prompt = 
TASK_PROMPT["system_prompt"] + user_prompt = TASK_PROMPT["user_prompt"].format( + _table_name=",".join(table_name), + _data_info=data_info, + _sample_queries=sample_queries, + _context=str(additional_context).lower(), + _question_str=question_str, + ) + # Role and content + query_txt = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}] + logger.debug(f"Query Text:\n {query_txt}") + + # TODO ADD local model + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo-0301", + messages=query_txt, + ) + res = completion.choices[0].message["content"] + return res + except Exception as se: + _, ex_value, _ = sys.exc_info() + res = ex_value.statement if ex_value.statement else None + return res + + def generate_response( + self, context_container, sql_index, input_prompt, attempt_fix_on_error: bool = True, _dialect: str = "sqlite" + ): + try: + response = sql_index.query(input_prompt, sql_context_container=context_container) + res = response.extra_info["sql_query"] + return res + except Exception as se: + # Take the SQL and make an attempt for correction + _, ex_value, ex_traceback = sys.exc_info() + qry_txt = ex_value.statement + if attempt_fix_on_error: + try: + # Attempt to heal with simple feedback + # Reference: Teaching Large Language Models to Self-Debug, https://arxiv.org/abs/2304.05128 + logger.info(f"Attempting to fix syntax error ...,\n {se}") + system_prompt = DEBUGGING_PROMPT["system_prompt"].format(_dialect=_dialect) + user_prompt = DEBUGGING_PROMPT["user_prompt"].format(ex_traceback=ex_traceback, qry_txt=qry_txt) + # Role and content + query_msg = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}] + + completion = openai.ChatCompletion.create( + model="gpt-3.5-turbo-0301", + messages=query_msg, + ) + res = completion.choices[0].message["content"] + if "SELECT" not in res: + res = qry_txt + return res + except Exception as se: + # Another exception occurred, return the original SQL + res = qry_txt + return res + + def generate_tasks(self, table_names: list, input_question: str): + try: + # Step 1: Given a question, generate tasks to possibly answer the question and persist the result -> tasks.txt + # Step 2: Append task list to 'query_prompt_template', generate SQL code to answer the question and persist the result -> sql.txt + context_queries: list = self.update_context_queries() + logger.info(f"Number of context queries found: {len(context_queries)}") + + # Remove duplicates from the context queries + m_path = f"{self.path}/models/sentence_transformers/" + duplicates_idx = remove_duplicates(context_queries, m_path) + updated_context = np.delete(np.array(context_queries), duplicates_idx).tolist() + + # Filter closest samples to the input question, threshold = 0.45 + filtered_context = ( + filter_samples( + input_question, + updated_context, + m_path, + threshold=0.9, + is_regenerate=True if (self.is_regenerate and not self.is_regenerate_with_options) else False, + ) + if len(updated_context) > 1 + else updated_context + ) + logger.info(f"Number of possible contextual queries to question: {len(filtered_context)}") + _queries = "\n".join(filtered_context) + self.content_queries = _queries + + # data info + input_file = self.data_input_path + data_info = "" + with open(input_file, "r") as in_file: + for line in in_file: + if line.strip(): + data = json.loads(line) + data_info += "\n" + json.dumps(data) + self._data_info = data_info + task_list = self._query_tasks(input_question, data_info, _queries, 
table_names) + with open(f"{self.path}/var/lib/tmp/data/tasks.txt", "w") as f: + f.write(task_list) + return task_list + except Exception as se: + raise se + + def generate_sql( + self, table_names: list, input_question: str, _dialect: str = "sqlite", model_name: str = "h2ogpt-sql" + ): + # TODO: Update needed to support multiple tables + table_name = str(table_names[0].replace(" ", "_")).lower() + alternate_queries = [] + describe_keywords = ["describe table", "describe", "describe table schema", "describe data"] + enable_describe_qry = any([True for _dk in describe_keywords if _dk in input_question.lower()]) + if input_question is not None and enable_describe_qry: + result = f"""SELECT "name" from PRAGMA_TABLE_INFO("{table_name}")""" + else: + context_file = f"{self.path}/var/lib/tmp/data/context.json" + additional_context = json.load(open(context_file, "r")) if Path(context_file).exists() else {} + table_context_dict = {table_name: str(additional_context).lower()} + context_queries = self.content_queries + self.context_builder = SQLContextContainerBuilder(self.sql_database, context_dict=table_context_dict) + + if model_name != "h2ogpt-sql": + _tasks = self.task_formatter(self._tasks) + + # TODO: The need to pass data info again could be eliminated if Task generation becomes more consistent and accurate. + query_str = QUERY_PROMPT.format( + _dialect=_dialect, + _data_info=self._data_info, + _question=input_question, + _table_name=table_names, + _sample_queries=context_queries, + _tasks=_tasks, + ) + + # Reference: https://github.com/jerryjliu/llama_index/issues/987 + llm_predictor_gpt3 = LLMPredictor(llm=OpenAI(temperature=0.5, model_name=model_name)) + service_context_gpt3 = ServiceContext.from_defaults( + llm_predictor=llm_predictor_gpt3, chunk_size_limit=512 + ) + + table_schema_index = self.build_index(persist=False) + self.context_builder.query_index_for_context(table_schema_index, query_str, store_context_str=True) + context_container = self.context_builder.build_context_container() + + index = GPTSQLStructStoreIndex( + [], sql_database=self.sql_database, table_name=table_names, service_context=service_context_gpt3 + ) + res = self.generate_response(context_container, sql_index=index, input_prompt=query_str) + try: + # Check if `SQL` is formatted ---> ``` SQL_text ``` + if "```" in str(res): + res = ( + str(res) + .split("```", 1)[1] + .split(";", 1)[0] + .strip() + .replace("```", "") + .replace("sql\n", "") + .strip() + ) + else: + res = str(res).split("Explanation:", 1)[0].strip() + sqlglot.transpile(res) + except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e: + logger.info("We did the best we could, there might be still be some error:\n") + logger.info(f"Realized query so far:\n {res}") + else: + # TODO Update needed for multiple tables + columns_w_type = ( + self.context_builder.full_context_dict[table_name].split(":")[2].split("and")[0].strip() + ) + + data_samples_list = self.load_column_samples(table_names) + + _context = { + "if patterns like 'current time' or 'now' occurs in question": "always use NOW() - INTERVAL", + "if patterns like 'total number', or 'List' occurs in question": "always use DISTINCT", + "detailed summary": "include min, avg, max", + "summary": "include min, avg, max", + } + + m_path = f"{self.path}/models/sentence_transformers/" + filtered_context = filter_samples( + model_obj=self.similarity_model, + input_q=input_question, + probable_qs=list(_context.keys()), + model_path=m_path, + threshold=0.90, + ) + logger.debug(f"Filter Context: 
{filtered_context}") + + contextual_context = [] + for _item in filtered_context: + _val = _context.get(_item, None) + if _val: + contextual_context.append(f"{_item}: {_val}") + + logger.info("Filtering Question/Query pairs ...") + context_queries: list = self.update_context_queries() + logger.info(f"Number of context queries found: {len(context_queries)}") + + # Remove duplicates from the context queries + m_path = f"{self.path}/models/sentence_transformers/" + # duplicates_idx = remove_duplicates(context_queries, m_path, similarity_model=self.similarity_model) + # updated_context = np.delete(np.array(context_queries), duplicates_idx).tolist() + + # Filter closest samples to the input question, threshold = 0.9 + filtered_context = ( + filter_samples( + input_q=input_question, + probable_qs=context_queries, + model_path=m_path, + model_obj=self.similarity_model, + threshold=0.9, + is_regenerate=True if (self.is_regenerate and not self.is_regenerate_with_options) else False, + ) + if len(context_queries) > 1 + else context_queries + ) + logger.info(f"Number of possible contextual queries to question: {len(filtered_context)}") + # If QnA pairs > 5, we keep top 5 for focused context + _samples = filtered_context + if len(filtered_context) > 5: + _samples = filtered_context[0:5][::-1] + + qna_samples = "\n".join(_samples) + + contextual_context_val = ", ".join(contextual_context) + column_names = columns_w_type.strip().split(",") + clmn_names = [i.split("(")[0].strip() for i in column_names] + + context_columns = [] + if len(_samples) > 2: + # Check for the columns in the QnA samples provided, if exists keep them + context_columns = [_c for _c in clmn_names if _c.lower().strip() in qna_samples.lower()] + + # To be safe, when we have more than 2 samples, we check for the column names in the question as well + first_pass = [_c for _c in clmn_names if _c.lower().strip() in input_question.lower().strip()] + _input = input_question.lower().split(" ") + for _c in clmn_names: + for _f in _c.lower().split("_"): + res = _f in _input + if res: + first_pass.append(_c) + context_columns = set(context_columns + first_pass) + if len(context_columns) > 0: + contextual_data_samples = [ + _d + for _cc in context_columns + for _d in data_samples_list[table_name] + if _cc.lower() in _d.lower() + ] + data_samples_list = contextual_data_samples + + relevant_columns = context_columns if len(context_columns) > 0 else clmn_names + _column_info = ", ".join(relevant_columns) + + logger.debug(f"Relevant sample column values: {data_samples_list}") + _table_name = ", ".join(table_names) + + query = NSQL_QUERY_PROMPT.format( + table_name=_table_name, + column_info=_column_info, + data_info_detailed=data_samples_list, + sample_queries=qna_samples, + context=contextual_context_val, + question_txt=input_question, + ) + + logger.debug(f"Query Text:\n {query}") + inputs = self.tokenizer([query], return_tensors="pt") + input_length = 1 if self.model.config.is_encoder_decoder else inputs.input_ids.shape[1] + logger.info(f"Context length: {input_length}") + + # Handle limited context length + # Currently, conservative approach: remove column description from the prompt, if input_length > (2048-300) + # Others to try: + # 1. Move to a model with larger context length + # 2. Possibly use a different tokenizer for chunking + # 3. 
Maybe positional interpolation --> https://arxiv.org/abs/2306.15595 + if int(input_length) > 4000: + logger.info("Input length is greater than 1748, removing column description from the prompt") + query = NSQL_QUERY_PROMPT.format( + table_name=_table_name, + column_info=_column_info, + data_info_detailed="", + sample_queries=qna_samples, + context=contextual_context_val, + question_txt=input_question, + ) + logger.debug(f"Adjusted query Text:\n {query}") + inputs = self.tokenizer([query], return_tensors="pt") + input_length = 1 if self.model.config.is_encoder_decoder else inputs.input_ids.shape[1] + logger.info(f"Adjusted context length: {input_length}") + # Generate SQL + random_seed = random.randint(0, 50) + torch.manual_seed(random_seed) + + # Greedy search for quick response + self.model.eval() + device_type = "cuda" if torch.cuda.is_available() else "cpu" + + if not self.is_regenerate_with_options and not self.is_regenerate: + # Greedy decoding + output = self.model.generate( + **inputs.to(device_type), + max_new_tokens=300, + temperature=0.5, + output_scores=True, + do_sample=True, + return_dict_in_generate=True, + ) + + generated_tokens = output.sequences[:, input_length:][0] + elif self.is_regenerate and not self.is_regenerate_with_options: + # throttle temperature for different result + logger.info("Regeneration requested on previous query ...") + random_seed = random.randint(0, 50) + torch.manual_seed(random_seed) + possible_temp_choice = [0.1, 0.2, 0.3, 0.6, 0.75, 0.9, 1.0] + random_temperature = np.random.choice(possible_temp_choice, 1)[0] + logger.debug(f"Selected temperature for fast regeneration : {random_temperature}") + output = self.model.generate( + **inputs.to(device_type), + max_new_tokens=300, + temperature=random_temperature, + output_scores=True, + do_sample=True, + return_dict_in_generate=True, + ) + generated_tokens = output.sequences[:, input_length:][0] + else: + logger.info("Regeneration with options requested on previous query ...") + # Diverse beam search decoding to explore more options + random_seed = random.randint(0, 50) + torch.manual_seed(random_seed) + possible_temp_choice = [0.1, 0.3, 0.5, 0.6, 0.75, 0.9, 1.0] + random_temperature = np.random.choice(possible_temp_choice, 1)[0] + logger.debug(f"Selected temperature for diverse beam search: {random_temperature}") + output_re = self.model.generate( + **inputs.to(device_type), + max_new_tokens=300, + temperature=random_temperature, + top_k=5, + top_p=0.9, + num_beams=5, + num_beam_groups=5, + num_return_sequences=5, + output_scores=True, + do_sample=False, + diversity_penalty=2.0, + return_dict_in_generate=True, + ) + + transition_scores = self.model.compute_transition_scores( + output_re.sequences, output_re.scores, output_re.beam_indices, normalize_logits=False + ) + + # Create a boolean tensor where elements are True if the corresponding element in transition_scores is less than 0 + mask = transition_scores < 0 + # Sum the True values along axis 1 + counts = torch.sum(mask, dim=1) + output_length = inputs.input_ids.shape[1] + counts + length_penalty = self.model.generation_config.length_penalty + reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty) + + # Converting logit scores to prob scores + probabilities_scores = F.softmax(reconstructed_scores, dim=-1) + out_idx = torch.argmax(probabilities_scores) + # Final output + output = output_re.sequences[out_idx] + generated_tokens = output[input_length:] + + logger.info(f"Generated options:\n") + prob_sorted_idxs = sorted( 
+ range(len(probabilities_scores)), key=lambda k: probabilities_scores[k], reverse=True + ) + for idx, sorted_idx in enumerate(prob_sorted_idxs): + _out = output_re.sequences[sorted_idx] + res = self.tokenizer.decode(_out[input_length:], skip_special_tokens=True) + result = res.replace("table_name", _table_name) + if "LIMIT".lower() not in result.lower(): + res = "SELECT " + result.strip() + " LIMIT 100;" + else: + res = "SELECT " + result.strip() + ";" + alt_res = f"Option {idx+1}: (_probability_: {probabilities_scores[sorted_idx]})\n{res}\n" + alternate_queries.append(alt_res) + logger.info(alt_res) + + _res = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) + # Below is a pre-caution in-case of an error in table name during generation + # COLLATE NOCASE is used to ignore case sensitivity, this might be specific to sqlite + _temp = _res.replace("table_name", table_name).split(";")[0] + + if "LIMIT".lower() not in _temp.lower(): + res = "SELECT " + _temp.strip() + " LIMIT 100;" + else: + res = "SELECT " + _temp.strip() + ";" + + # Validate the generate SQL for parsing errors, along with dialect specific validation + # Note: Doesn't do well with handling date-time conversions + # e.g. + # sqlite: SELECT DATETIME(MAX(timestamp), '-5 minute') FROM demo WHERE isin_id = 'VM88109EGG92' + # postgres: SELECT MAX(timestamp) - INTERVAL '5 minutes' FROM demo where isin_id='VM88109EGG92' + # Reference ticket: https://github.com/tobymao/sqlglot/issues/2011 + result = res + try: + result = sqlglot.transpile(res, identify=True, write="sqlite")[0] + except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e: + logger.info("We did the best we could, there might be still be some error:\n") + logger.info(f"Realized query so far:\n {res}") + return result, alternate_queries + + def task_formatter(self, input_task: str): + # Generated format + """ + Tasks: + 1. Generate a SELECT query to display all columns of the {selected tables}. + 2. Infer the return type of the question as a description of the table schema. + 3. Final output: Return the table schema for the selected table. + """ + + # Converted format + """ + # 1. Generate a SELECT query to display all columns of the {selected tables}. + # 2. Infer the return type of the question as a description of the table schema. 
+ """ + _res = input_task.split("\n") + start_index = 1 if "Tasks" in _res[0] else 0 + res = "\n".join([f"# {i}" for i in _res[start_index:]]) # Skip the first line + return res diff --git a/sidekick/schema_generator.py b/sidekick/schema_generator.py new file mode 100644 index 0000000..cdd8a4c --- /dev/null +++ b/sidekick/schema_generator.py @@ -0,0 +1,32 @@ +import json +import re +import pandas as pd + + +def generate_schema(data_path, output_path): + df = pd.read_csv(data_path) + # Extract the schema information + schema = df.dtypes.to_dict() + schema_list = [] + special_characters = {" ": "_", ":": "_", "/": "_", "-": "_"} + + for key, value in schema.items(): + new_key = "".join(special_characters[s] if s in special_characters.keys() else s for s in key) + + if value == "object": + value = "TEXT" + unique_values = df[key].dropna().unique().tolist() + if not bool(re.search(r"[A-Za-z]", unique_values[0])): + schema_list.append({"Column Name": new_key, "Column Type": value}) + else: + schema_list.append({"Column Name": new_key, "Column Type": value, "Sample Values": unique_values}) + else: + value = "NUMERIC" + schema_list.append({"Column Name": new_key, "Column Type": value}) + + # Save the schema information to a JSONL format + with open(output_path, "w") as f: + for item in schema_list: + json.dump(item, f) + f.write("\n") + return output_path diff --git a/sidekick/utils.py b/sidekick/utils.py new file mode 100644 index 0000000..b91f29f --- /dev/null +++ b/sidekick/utils.py @@ -0,0 +1,425 @@ +import glob +import json +import os +import re +from pathlib import Path +from typing import Optional + +import numpy as np +import pandas as pd +import torch +from accelerate import infer_auto_device_map, init_empty_weights +from InstructorEmbedding import INSTRUCTOR +from pandasql import sqldf +from sentence_transformers import SentenceTransformer +from sidekick.logger import logger +from sklearn.metrics.pairwise import cosine_similarity +from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig) + + +def generate_sentence_embeddings(model_path: str, x, batch_size: int = 32, device: Optional[str] = None): + # Reference: + # 1. https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models + # 2. Evaluation result: https://www.sbert.net/_static/html/models_en_sentence_embeddings.html + # 3. Model Card: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 + # 4. Reference: https://huggingface.co/spaces/mteb/leaderboard + # Maps sentence & paragraphs to a 384 dimensional dense vector space. 
+ model_name_path = f"{model_path}/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2/"
+ current_torch_home = os.environ.get("TORCH_HOME", "")
+ if Path(model_name_path).is_dir():
+ is_empty = not any(Path(model_name_path).iterdir())
+ if is_empty:
+ # Download and cache at the specified location
+ # https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/all-MiniLM-L6-v2.zip
+ os.environ["TORCH_HOME"] = model_path
+ model_name_path = "sentence-transformers/all-MiniLM-L6-v2"
+ sentence_model = SentenceTransformer(model_name_path, device=device)
+ all_res = np.zeros(shape=(len(x), 0))
+ res = sentence_model.encode(x, batch_size=batch_size, show_progress_bar=True)
+ all_res = np.hstack((all_res, res))
+ del sentence_model
+ os.environ["TORCH_HOME"] = current_torch_home
+ return all_res
+
+
+def load_embedding_model(model_path: str, device: str):
+ model_name_path = glob.glob(f"{model_path}/models--BAAI--bge-base-en/snapshots/*/")[0]
+
+ sentence_model = SentenceTransformer(model_name_path, cache_folder=model_path, device=device)
+ if "cuda" not in device:
+ # Issue https://github.com/pytorch/pytorch/issues/69364
+ # In initial experimentation, the quantized model generated slightly better results
+ logger.debug("Sentence embedding model is quantized ...")
+ model_obj = torch.quantization.quantize_dynamic(sentence_model, {torch.nn.Linear}, dtype=torch.qint8)
+ else:
+ model_obj = sentence_model
+ return model_obj
+
+
+def generate_text_embeddings(model_path: str, x, model_obj=None, batch_size: int = 32, device: Optional[str] = "cpu"):
+ # Reference:
+ # 1. https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models
+ # Maps sentences & paragraphs to a dense vector space (bge-base-en produces 768-dimensional embeddings).
+ if model_obj is None:
+ model_obj = load_embedding_model(model_path, device)
+
+ _sentences = [["Represent the Financial question for retrieving duplicate examples: ", _item] for _item in x]
+
+ res = model_obj.encode(_sentences)
+ return res
+
+
+def filter_samples(
+ input_q: str,
+ probable_qs: list,
+ model_path: str,
+ model_obj=None,
+ threshold: float = 0.80,
+ device="auto",
+ is_regenerate: bool = False,
+):
+ # Only consider the questions; note: this might change in the future.
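+ # Example (illustrative): filter_samples("how many rows are there?",
+ # ["# query: total row count\n# answer: SELECT COUNT(*) FROM demo"], model_path="./models")
+ # keeps only the candidate pairs whose cosine similarity with the input question exceeds `threshold`.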
+ _inq = ("# query: " + input_q).strip().lower()
+ logger.debug(f"Input question: {_inq}")
+ _device = ("cuda" if torch.cuda.is_available() else "cpu") if device == "auto" else device
+ question_embeddings = generate_text_embeddings(model_path, x=[_inq], model_obj=model_obj, device=_device)
+
+ input_pqs = [_se.split("# answer")[0].strip().lower() for _se in probable_qs]
+ logger.debug(f"Probable context: {input_pqs}")
+ embeddings = generate_text_embeddings(model_path, x=input_pqs, model_obj=model_obj, device=_device)
+ res = {}
+ _scores = {}
+ for idx, _se in enumerate(embeddings):
+ similarities_score = cosine_similarity(
+ [_se.astype(float).tolist()], [question_embeddings.astype(float).tolist()[0]]
+ )
+ logger.debug(f"Similarity score for: {input_pqs[idx]}: {similarities_score[0][0]}")
+ _scores[idx] = similarities_score[0][0]
+ if similarities_score[0][0] > threshold:
+ res[str(probable_qs[idx])] = similarities_score[0][0]
+
+ # Get the top N context queries if the user requested regeneration, regardless of scores
+ if len(res) == 0 and is_regenerate and len(_scores) > 0:
+ top_n = min(len(_scores), 2)
+ sorted_res = dict()
+ sorted_scores = sorted(_scores, key=_scores.get, reverse=True)
+ top_idxs = sorted_scores[:top_n]
+ for idx in top_idxs:
+ sorted_res[str(probable_qs[idx])] = _scores[idx]
+ else:
+ sorted_res = sorted(res.items(), key=lambda x: x[1], reverse=True)
+
+ logger.debug(f"Sorted context: {sorted_res}")
+ return list(dict(sorted_res).keys())
+
+
+def remove_duplicates(
+ input_x: list, model_path: str, similarity_model=None, threshold: float = 0.989, device: str = "cpu"
+):
+ # Remove duplicate pairs
+ if input_x is None or len(input_x) < 2:
+ res = []
+ else:
+ embeddings = generate_text_embeddings(model_path, x=input_x, model_obj=similarity_model, device=device)
+ similarity_scores = cosine_similarity(embeddings)
+ similar_indices = [(x, y) for (x, y) in np.argwhere(similarity_scores > threshold) if x != y]
+
+ # Remove identical pairs, e.g. [(0, 3), (3, 0)] -> [(0, 3)]
+ si = [similarity_scores[tpl] for tpl in similar_indices]
+ dup_pairs_idx = np.where(pd.Series(si).duplicated())[0].tolist()
+ remove_vals = [similar_indices[_itm] for _itm in dup_pairs_idx]
+ for _itm in remove_vals:
+ similar_indices.remove(_itm)
+ res = list(set([item[0] for item in similar_indices]))
+ return res
+
+
+def save_query(output_path: str, query, response, extracted_entity: Optional[dict] = None):
+ _response = response
+ # Probably need to find a better way to extract the info rather than depending on key phrases
+ if response and "Generated response for question,".lower() in response.lower():
+ _response = response.split("**Generated response for question,**")[1].split("\n")[3].strip()
+ chat_history = {"Query": query, "Answer": _response, "Entity": extracted_entity}
+
+ with open(f"{output_path}/var/lib/tmp/data/history.jsonl", "a") as outfile:
+ json.dump(chat_history, outfile)
+ outfile.write("\n")
+
+
+def setup_dir(base_path: str):
+ dir_list = ["var/lib/tmp/data", "var/lib/tmp/jobs", "var/lib/tmp/.cache", "models/weights"]
+ for _dl in dir_list:
+ p = Path(f"{base_path}/{_dl}")
+ if not p.is_dir():
+ p.mkdir(parents=True, exist_ok=True)
+
+
+def update_tables(json_file_path: str, new_data: dict):
+ # Check if the JSON file exists
+ existing_data = {}
+ if os.path.exists(json_file_path):
+ try:
+ # Read the existing content from the JSON file
+ with open(json_file_path, "r") as json_file:
+ existing_data = json.load(json_file)
+ logger.debug(f"Existing data: {existing_data}")
+ except Exception as e:
+ logger.debug(f"An error occurred while reading: {e}")
+ else:
+ logger.debug("JSON file doesn't exist. Creating a new one.")
+
+ # Append new data to the existing content
+ existing_data.update(new_data)
+
+ # Write the updated content back to the JSON file
+ try:
+ with open(json_file_path, "w") as json_file:
+ json.dump(existing_data, json_file, indent=4)
+ logger.debug("Data appended and file updated.")
+ except Exception as e:
+ logger.debug(f"An error occurred while writing: {e}")
+
+
+def read_sample_pairs(input_path: str, model_name: str = "h2ogpt-sql"):
+ df = pd.read_csv(input_path)
+ df = df.dropna()
+ df = df.drop_duplicates()
+ df = df.reset_index(drop=True)
+
+ if model_name != "h2ogpt-sql":
+ # OpenAI format
+ # Convert frame to the format below:
+ # [
+ # "# query": ""
+ # "# answer": ""
+ # ]
+ res = df.apply(lambda row: f"# query: {row['query']}\n# answer: {row['answer']}", axis=1).to_list()
+ else:
+ # NSQL format
+ # Convert frame to the format below:
+ # [
+ # "Question":
+ # "Answer":
+ #
+ # ]
+ res = df.apply(lambda row: f"Question: {row['query']}\nAnswer:\n{row['answer']}", axis=1).to_list()
+ return res
+
+
+def extract_table_names(query: str):
+ """
+ Extracts table names from a SQL query.
+
+ Parameters:
+ query (str): The SQL query to extract table names from.
+
+ Returns:
+ list: A list of table names.
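+
+ Example (illustrative):
+ >>> extract_table_names("SELECT * FROM demo JOIN prices ON demo.id = prices.id")
+ ['demo', 'prices']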
+ """ + table_names = re.findall(r"\bFROM\s+(\w+)", query, re.IGNORECASE) + table_names += re.findall(r"\bJOIN\s+(\w+)", query, re.IGNORECASE) + table_names += re.findall(r"\bUPDATE\s+(\w+)", query, re.IGNORECASE) + table_names += re.findall(r"\bINTO\s+(\w+)", query, re.IGNORECASE) + + # Below keywords may not be relevant for the project but adding for sake for completeness + table_names += re.findall(r"\bINSERT\s+INTO\s+(\w+)", query, re.IGNORECASE) + table_names += re.findall(r"\bDELETE\s+FROM\s+(\w+)", query, re.IGNORECASE) + return np.unique(table_names).tolist() + + +def execute_query_pd(query=None, tables_path=None, n_rows=100): + """ + Runs an SQL query on a pandas DataFrame. + + Parameters: + df (pandas DataFrame): The DataFrame to query. + query (str): The SQL query to execute. + + Returns: + pandas DataFrame: The result of the SQL query. + """ + for table in tables_path: + if not table in locals(): + # Update the local namespace with the table name, pandas object + locals()[f"{table}"] = pd.read_csv(tables_path[table]) + + res_df = sqldf(query, locals()) + return res_df + + +def get_table_keys(file_path: str, table_key: str): + res = [] + if not os.path.exists(file_path): + logger.debug(f"File '{file_path}' does not exist.") + return res, dict() + + with open(file_path, "r") as json_file: + data = json.load(json_file) + if isinstance(data, dict): + res = list(data.keys()) + if table_key: + return None, data[table_key] + else: + return res, data + + +def is_resource_low(): + free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3) + total_memory = int(torch.cuda.get_device_properties(0).total_memory / 1024**3) + logger.info(f"Total Memory: {total_memory}GB") + logger.info(f"Free GPU memory: {free_in_GB}GB") + off_load = True + if (int(free_in_GB) - 2) >= int(0.5 * total_memory): + off_load = False + return off_load + + +def load_causal_lm_model( + model_name: str, + cache_path: str, + device: str, + load_in_8bit: bool = False, + load_in_4bit=True, + off_load: bool = False, + re_generate: bool = False, +): + try: + # Load h2oGPT.NSQL model + device = {"": 0} if torch.cuda.is_available() else "cpu" if device == "auto" else device + total_memory = int(torch.cuda.get_device_properties(0).total_memory / 1024**3) + free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3) + logger.info(f"Free GPU memory: {free_in_GB}GB") + n_gpus = torch.cuda.device_count() + _load_in_8bit = load_in_8bit + + # 22GB (Least requirement on GPU) is a magic number for the current model size. + if off_load and re_generate and total_memory < 22: + # To prevent the system from crashing in-case memory runs low. + # TODO: Performance when offloading to CPU. + max_memory = f"{4}GB" + max_memory = {i: max_memory for i in range(n_gpus)} + logger.info(f"Max Memory: {max_memory}, offloading to CPU") + with init_empty_weights(): + config = AutoConfig.from_pretrained(model_name, cache_dir=cache_path, offload_folder=cache_path) + # A blank model with desired config. 
+ model = AutoModelForCausalLM.from_config(config)
+ device = infer_auto_device_map(model, max_memory=max_memory)
+ device["lm_head"] = 0
+ _offload_state_dict = True
+ _llm_int8_enable_fp32_cpu_offload = True
+ _load_in_8bit = True
+ load_in_4bit = False
+ else:
+ max_memory = f"{int(free_in_GB)-2}GB"
+ max_memory = {i: max_memory for i in range(n_gpus)}
+ _offload_state_dict = False
+ _llm_int8_enable_fp32_cpu_offload = False
+
+ if _load_in_8bit and _offload_state_dict and not load_in_4bit:
+ _load_in_8bit = "cpu" not in device
+ logger.debug(
+ f"Loading in 8 bit mode: {_load_in_8bit} with offloading state: {_llm_int8_enable_fp32_cpu_offload}"
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ cache_dir=cache_path,
+ device_map=device,
+ load_in_8bit=_load_in_8bit,
+ llm_int8_enable_fp32_cpu_offload=_llm_int8_enable_fp32_cpu_offload,
+ offload_state_dict=_offload_state_dict,
+ max_memory=max_memory,
+ offload_folder=f"{cache_path}/weights/",
+ )
+
+ else:
+ logger.debug(f"Loading in 4 bit mode: {load_in_4bit} with device {device}")
+ nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_compute_dtype=torch.bfloat16,
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(
+ model_name, cache_dir=cache_path, device_map=device, quantization_config=nf4_config
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_path, device_map=device, use_fast=True)
+
+ return model, tokenizer
+ except Exception as e:
+ logger.error(f"An error occurred while loading the model: {e}")
+ return None, None
+
+
+def _check_file_info(file_path: str):
+ if file_path is not None and Path(file_path).exists():
+ logger.info(f"Using table information from path {file_path}")
+ return file_path
+ else:
+ logger.info("Required info not found, provide a path for table information and try again")
+ raise FileNotFoundError(f"Table info not found at {file_path}")
+
+
+def _execute_sql(query: str):
+ # Check for,
+ # 1. Keyword: "Execute SQL: "
+ # 2. Query starts with SQL statement
+ # TODO vulnerability check for possible SELECT SQL injection via source code.
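+ # Examples (illustrative): "Execute SQL: SELECT * FROM demo" -> truthy;
+ # "Describe the data" -> falsy (no "Execute SQL:" keyword / no leading SELECT).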
+ _cond1 = _cond2 = False
+ _cond1 = re.findall(r"Execute SQL:\s+(.*)", query, re.IGNORECASE)
+ _temp_cond = query.strip().lower().split("execute sql:")
+ if len(_temp_cond) > 1:
+ _cond2 = _temp_cond[1].strip().startswith("select")
+ return _cond1 and _cond2
+
+
+def make_dir(path: str):
+ try:
+ os.makedirs(path, exist_ok=True)
+ except OSError:
+ raise Exception("Error reported while creating default directory path.")
+
+
+def flatten_list(_list: list):
+ return [item for sublist in _list for item in sublist]
+
+
+def check_vulnerability(input_query: str):
+ # Common SQL injection patterns checklist
+ # Reference: https://github.com/payloadbox/sql-injection-payload-list#generic-sql-injection-payloads
+ sql_injection_patterns = [
+ r"\b(UNION\s+ALL\s+SELECT|OR\s+\d+\s*=\s*\d+|1\s*=\s*1|--\s+)",
+ r'\b(SELECT\s+\*\s+FROM\s+\w+\s+WHERE\s+\w+\s*=\s*[\'"].*?[\'"]\s*;?\s*--)',
+ r'\b(INSERT\s+INTO\s+\w+\s+\(\s*\w+\s*,\s*\w+\s*\)\s+VALUES\s*\(\s*[\'"].*?[\'"]\s*,\s*[\'"].*?[\'"]\s*\)\s*;?\s*--)',
+ r"\b(DROP\s+TABLE|ALTER\s+TABLE|admin\'--)", # DROP TABLE/ALTER TABLE
+ r"(?:'|\”|--|#|‘\s*OR\s*‘1|‘\s*OR\s*\d+\s*--\s*-|\"\s*OR\s*\"\" = \"|\"\s*OR\s*\d+\s*=\s*\d+\s*--\s*-|’\s*OR\s*''\s*=\s*‘|‘=‘|'=0--+|OR\s*\d+\s*=\s*\d+|‘\s*OR\s*‘x’=‘x’|AND\s*id\s*IS\s*NULL;\s*--|‘’’’’’’’’’’’’UNION\s*SELECT\s*‘\d+|%00|/\*.*?\*/|\+|\|\||%|@\w+|@@\w+)",
+ r"AND\s[01]|AND\s(true|false)|[01]-((true|false))",
+ r"\d+'\s*ORDER\s*BY\s*\d+--\+|\d+'\s*GROUP\s*BY\s*(\d+,)*\d+--\+|'\s*GROUP\s*BY\s*columnnames\s*having\s*1=1\s*--",
+ r"\bUNION\b\s+\b(?:ALL\s+)?\bSELECT\b\s+[A-Za-z0-9]+", # Union based
+ r'\b(OR|AND|HAVING|AS|WHERE)\s+\d+=\d+(\s+AND\s+[\'"]\w+[\'"]\s*=\s*[\'"]\w+[\'"])?(\s*--|\s*#)?\b',
+ r"\b(?:RLIKE|IF)\s*\(\s*SELECT\s*\(\s*CASE\s*WHEN\s*\(\s*[\d=]+\s*\)\s*THEN\s*0x[0-9a-fA-F]+\s*ELSE\s*0x[0-9a-fA-F]+\s*END\s*\)\s*\)\s*AND\s*'\w+'=\w+\s*|\b%\s*AND\s*[\d=]+\s*AND\s*'\w+'=\w+\s*|and\s*\(\s*select\s*substring\s*\(\s*@@version,\d+,\d+\)\s*\)=\s*'[\w]'\b",
+ r"('|\")?\s*(or|\|\|)\s*sleep\(.*?\)\s*(\#|--)?\s*(;waitfor\s+delay\s+'[0-9:]+')?\s*;?(\s+AND\s+)?\s*\w+\s*=\s*\w+\s*", # Time based
+ r"(ORDER BY \d+,\s*)*(ORDER BY \d+,?)*SLEEP\(\d+\),?(BENCHMARK\(\d+,\s*MD5\('[A-Z]'\)\),?)*\d*,?", # Time based (SLEEP/BENCHMARK)
+ ]
+
+ # Check for SQL injection patterns in the SQL code
+ res = False
+ _msg = None
+ p_detected = []
+ for pattern in sql_injection_patterns:
+ matches = re.findall(pattern, input_query, re.IGNORECASE)
+ if matches:
+ if all(v == "'" for v in matches) or all(v == "''" for v in matches):
+ matches = []
+ else:
+ res = True
+ p_detected.append(matches)
+ _pd = set(flatten_list(p_detected))
+ if res:
+ _detected_patterns = ", ".join([str(elem) for elem in _pd])
+ _msg = f"The input question has malicious patterns, **{_detected_patterns}**, that could lead to SQL injection.\nSorry, I will not be able to provide an answer.\nPlease try rephrasing the question."
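+ # Example (illustrative): check_vulnerability("SELECT * FROM users WHERE 1=1 --")
+ # -> (True, "The input question has malicious patterns, **...**, that could lead to SQL injection...")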
+ return res, _msg diff --git a/start.py b/start.py new file mode 100644 index 0000000..b80d5f1 --- /dev/null +++ b/start.py @@ -0,0 +1,19 @@ +import os +import shlex +import subprocess +from pathlib import Path + +from huggingface_hub import snapshot_download +from loguru import logger as logging + +logging.info(f"Download model...") +base_path = (Path(__file__).parent).resolve() +snapshot_download(repo_id="NumbersStation/nsql-llama-2-7B", cache_dir=f"{base_path}/models/") +logging.info(f"Download embedding model...") +snapshot_download(repo_id="BAAI/bge-base-en", cache_dir=f"{base_path}/models/sentence_transformers/") + +logging.info("Starting SQL-Sidekick.") +DAEMON_PATH = "./.sidekickvenv/bin/uvicorn" if os.path.isdir("./.sidekickvenv/bin/") else "/resources/venv/bin/uvicorn" + +cmd = f"{DAEMON_PATH} ui.app:main" +subprocess.check_output(shlex.split(cmd)) diff --git a/ui/app.py b/ui/app.py new file mode 100644 index 0000000..dc9ee02 --- /dev/null +++ b/ui/app.py @@ -0,0 +1,635 @@ +import gc +import json +import logging +import os +from pathlib import Path +from typing import List, Optional + +import openai +import toml +import torch +from h2o_wave import Q, app, data, handle_on, main, on, ui +from h2o_wave.core import expando_to_dict +from sidekick.prompter import db_setup_api, query_api +from sidekick.utils import get_table_keys, save_query, setup_dir, update_tables + +# Load the config file and initialize required paths +base_path = (Path(__file__).parent / "../").resolve() +env_settings = toml.load(f"{base_path}/ui/app_config.toml") +tmp_path = f"{base_path}/var/lib/tmp" + +ui_title = env_settings["WAVE_UI"]["TITLE"] +ui_description = env_settings["WAVE_UI"]["SUB_TITLE"] + + +async def user_variable(q: Q): + db_settings = toml.load(f"{base_path}/sidekick/configs/env.toml") + + q.user.db_dialect = db_settings["DB-DIALECT"]["DB_TYPE"] + q.user.host_name = db_settings["LOCAL_DB_CONFIG"]["HOST_NAME"] + q.user.user_name = db_settings["LOCAL_DB_CONFIG"]["USER_NAME"] + q.user.password = db_settings["LOCAL_DB_CONFIG"]["PASSWORD"] + q.user.db_name = db_settings["LOCAL_DB_CONFIG"]["DB_NAME"] + q.user.port = db_settings["LOCAL_DB_CONFIG"]["PORT"] + + tables, tables_info = get_table_keys(f"{tmp_path}/data/tables.json", None) + table_info = tables_info[tables[0]] if len(tables) > 0 else None + + q.user.table_info_path = table_info["schema_info_path"] if len(tables) > 0 else None + q.user.table_samples_path = table_info["samples_path"] if len(tables) > 0 else None + q.user.sample_qna_path = table_info["samples_qa"] if len(tables) > 0 else None + q.user.table_name = tables[0] if len(tables) > 0 else None + + logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s") + + +async def client_variable(q: Q): + q.client.query = None + + +# Use for page cards that should be removed when navigating away. +# For pages that should be always present on screen use q.page[key] = ... +def add_card(q, name, card) -> None: + q.client.cards.add(name) + q.page[name] = card + + +# Remove all the cards related to navigation. +def clear_cards(q, ignore: Optional[List[str]] = []) -> None: + if not q.client.cards: + return + + for name in q.client.cards.copy(): + if name not in ignore: + del q.page[name] + q.client.cards.remove(name) + + +@on("#chat") +async def chat(q: Q): + q.page["sidebar"].value = "#chat" + + if q.args.table_dropdown: + # If a table is selected, the trigger causes refresh of the page + # so we update chat history with table name selection and return + # avoiding re-drawing. 
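+ # (q.args.chatbot already holds the "Table <name> selected" message set in on_event.)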
+ q.page["chat_card"].data += [q.args.chatbot, False] + return + + clear_cards(q) # When routing, drop all the cards except of the main ones (header, sidebar, meta). + table_names = [] + tables, _ = get_table_keys(f"{tmp_path}/data/tables.json", None) + if len(tables) > 0: + with open(f"{tmp_path}/data/tables.json", "r") as json_file: + meta_data = json.load(json_file) + for table in tables: + original_name = meta_data[table].get("original_name", q.user.original_name) + table_names.append(ui.choice(table, f"{original_name}")) + + add_card( + q, + "background_card", + ui.form_card( + box="horizontal", + items=[ + ui.text("Ask your questions:"), + ui.inline(items=[ui.toggle(name="demo_mode", label="Demo", trigger=True)], justify="end"), + ], + ), + ) + + add_card( + q, + "select_tables", + ui.form_card( + box="vertical", + items=[ + ui.dropdown( + name="table_dropdown", + label="Table", + required=True, + choices=table_names, + value=q.user.table_name if q.user.table_name else None, + trigger=True, + ) + ], + ), + ) + add_card( + q, + "chat_card", + ui.chatbot_card( + box=ui.box("vertical", height="500px"), + name="chatbot", + data=data(fields="content from_user", t="list", size=-50), + ), + ) + add_card( + q, + "additional_actions", + ui.form_card( + box=ui.box("vertical", height="120px"), + items=[ + ui.buttons( + [ + ui.button( + name="regenerate", + icon="RepeatOne", + caption="Attempts regeneration", + label="Regenerate", + primary=True, + ), + ui.button( + name="regenerate_with_options", + icon="RepeatAll", + caption="Regenerates with options", + label="Try Harder", + ), + ui.button( + name="save_conversation", + caption="Saves the conversation for future reference/to improve response", + label="Save", + icon="Save", + ), + ], + justify="center", + ) + ], + ), + ) + + if q.args.chatbot is None or q.args.chatbot.strip() == "": + _msg = """Welcome to the SQL Sidekick!\nI am an AI assistant, i am here to help you find answers to questions on structured data. +To get started, please select a table from the dropdown and ask your question. +One could start by learning about the dataset by asking questions like: +- Describe data.""" + + q.args.chatbot = _msg + q.page["chat_card"].data += [q.args.chatbot, False] + logging.info(f"Chatbot response: {q.args.chatbot}") + + +@on("chatbot") +async def chatbot(q: Q): + q.page["sidebar"].value = "#chat" + + # Append user message. + q.page["chat_card"].data += [q.args.chatbot, True] + + if q.page["select_tables"].table_dropdown.value is None or q.user.table_name is None: + q.page["chat_card"].data += ["Please select a table to continue!", False] + return + + # Append bot response. + question = f"{q.args.chatbot}" + logging.info(f"Question: {question}") + + # For regeneration, currently there are 2 modes + # 1. Quick fast approach by throttling the temperature + # 2. 
"Try harder mode (THM)" Slow approach by using the diverse beam search + llm_response = None + try: + if q.args.chatbot and q.args.chatbot.lower() == "db setup": + llm_response, err = db_setup_api( + db_name=q.user.db_name, + hostname=q.user.host_name, + user_name=q.user.user_name, + password=q.user.password, + port=q.user.port, + table_info_path=q.user.table_info_path, + table_samples_path=q.user.table_samples_path, + table_name=q.user.table_name, + ) + elif q.args.chatbot and q.args.chatbot.lower() == "regenerate" or q.args.regenerate: + # Attempts to regenerate response on the last supplied query + logging.info(f"Attempt for regeneration") + if q.client.query is not None and q.client.query.strip() != "": + llm_response, alt_response, err = query_api( + question=q.client.query, + sample_queries_path=q.user.sample_qna_path, + table_info_path=q.user.table_info_path, + table_name=q.user.table_name, + is_regenerate=True, + is_regen_with_options=False, + ) + llm_response = "\n".join(llm_response) + else: + llm_response = ( + "Sure, I can generate a new response for you. " + "However, in order to assist you effectively could you please provide me with your question?" + ) + elif q.args.chatbot and q.args.chatbot.lower() == "try harder" or q.args.regenerate_with_options: + # Attempts to regenerate response on the last supplied query + logging.info(f"Attempt for regeneration with options.") + if q.client.query is not None and q.client.query.strip() != "": + llm_response, alt_response, err = query_api( + question=q.client.query, + sample_queries_path=q.user.sample_qna_path, + table_info_path=q.user.table_info_path, + table_name=q.user.table_name, + is_regenerate=False, + is_regen_with_options=True, + ) + response = "\n".join(llm_response) + if alt_response: + llm_response = response + "\n\n" + "**Alternate options:**\n" + "\n".join(alt_response) + logging.info(f"Regenerate response: {llm_response}") + else: + llm_response = response + else: + llm_response = ( + "Sure, I can generate a new response for you. " + "However, in order to assist you effectively could you please provide me with your question?" + ) + else: + q.client.query = question + llm_response, alt_response, err = query_api( + question=q.client.query, + sample_queries_path=q.user.sample_qna_path, + table_info_path=q.user.table_info_path, + table_name=q.user.table_name, + ) + llm_response = "\n".join(llm_response) + except (MemoryError, RuntimeError) as e: + logging.error(f"Something went wrong while generating response: {e}") + gc.collect() + torch.cuda.empty_cache() + llm_response = "Something went wrong, try executing the query again!" 
+ q.client.llm_response = llm_response + q.page["chat_card"].data += [llm_response, False] + + +@on("file_upload") +async def fileupload(q: Q): + q.page["dataset"].error_bar.visible = False + q.page["dataset"].success_bar.visible = False + q.page["dataset"].progress_bar.visible = True + + await q.page.save() + + q.page["sidebar"].value = "#datasets" + usr_info_path = None + usr_samples_path = None + usr_sample_qa = None + + sample_data = q.args.sample_data + sample_schema = q.args.data_schema + sample_qa = q.args.sample_qa + org_table_name = q.args.table_name + usr_table_name = q.args.table_name.strip().lower().replace(" ", "_") + + if sample_data is None or sample_schema is None or usr_table_name is None or usr_table_name.strip() == "": + q.page["dataset"].error_bar.visible = True + q.page["dataset"].progress_bar.visible = False + else: + if sample_data: + usr_samples_path = await q.site.download( + sample_data[0], f"{tmp_path}/jobs/{usr_table_name}_table_samples.csv" + ) + if sample_schema: + usr_info_path = await q.site.download( + sample_schema[0], f"{tmp_path}/jobs/{usr_table_name}_table_info.jsonl" + ) + if sample_qa: + usr_sample_qa = await q.site.download(sample_qa[0], f"{tmp_path}/jobs/{usr_table_name}_sample_qa.csv") + + q.page["dataset"].error_bar.visible = False + + table_metadata = dict() + table_metadata[usr_table_name] = { + "original_name": org_table_name, + "schema_info_path": usr_info_path, + "samples_path": usr_samples_path, + "samples_qa": usr_sample_qa, + } + update_tables(f"{tmp_path}/data/tables.json", table_metadata) + + q.user.table_name = usr_table_name + q.user.table_samples_path = usr_samples_path + q.user.table_info_path = usr_info_path + q.user.sample_qna_path = usr_sample_qa + + db_resp = db_setup_api( + db_name=q.user.db_name, + hostname=q.user.host_name, + user_name=q.user.user_name, + password=q.user.password, + port=q.user.port, + table_info_path=q.user.table_info_path, + table_samples_path=q.user.table_samples_path, + table_name=q.user.table_name, + ) + logging.info(f"DB updates: \n {db_resp}") + q.page["dataset"].progress_bar.visible = False + q.page["dataset"].success_bar.visible = True + + +@on("#datasets") +async def datasets(q: Q): + q.page["sidebar"].value = "#datasets" + clear_cards(q) # When routing, drop all the cards except of the main ones (header, sidebar, meta). + add_card(q, "data_header", ui.form_card(box="horizontal", title="Dataset", items=[])) + + add_card( + q, + "dataset", + ui.form_card( + box="vertical", + items=[ + ui.message_bar( + name="error_bar", + type="error", + text="Please input table name, data & schema files to upload!", + visible=False, + ), + ui.message_bar(name="success_bar", type="success", text="Files Uploaded Successfully!", visible=False), + ui.textbox(name="table_name", label="Table Name", required=True), + ui.file_upload( + name="data_schema", + label="Data Schema", + multiple=False, + compact=True, + file_extensions=["jsonl"], + required=True, + max_file_size=5000, # Specified in MB. + tooltip="The data describing table schema and sample values, formats allowed are JSONL & CSV respectively!", + ), + ui.file_upload( + name="sample_qa", + label="Sample Q&A", + multiple=False, + compact=True, + file_extensions=["csv"], + required=False, + max_file_size=5000, # Specified in MB. 
+ tooltip="The data describing table schema and sample values, formats allowed are JSONL & CSV respectively!", + ), + ui.file_upload( + name="sample_data", + label="Data Samples", + multiple=False, + compact=True, + file_extensions=["csv"], + required=True, + max_file_size=5000, # Specified in MB. + tooltip="The data describing table schema and sample values, formats allowed are JSONL & CSV respectively!", + ), + ui.progress( + name="progress_bar", width="100%", label="Uploading datasets and creating tables!", visible=False + ), + ui.button(name="file_upload", label="Upload", primary=True), + ], + ), + ) + + +@on("#about") +async def about(q: Q): + q.page["sidebar"].value = "#about" + clear_cards(q) + + +@on("#support") +async def handle_page4(q: Q): + q.page["sidebar"].value = "#support" + # When routing, drop all the cards except of the main ones (header, sidebar, meta). + # Since this page is interactive, we want to update its card instead of recreating it every time, so ignore 'form' card on drop. + clear_cards(q, ["form"]) + + +@on("submit_table") +async def submit_table(q: Q): + table_key = q.args.table_dropdown + if table_key: + table_name = table_key.lower().replace(" ", "_") + _, table_info = get_table_keys(f"{tmp_path}/data/tables.json", table_name) + + q.user.table_info_path = table_info["schema_info_path"] + q.user.table_samples_path = table_info["samples_path"] + q.user.sample_qna_path = table_info["samples_qa"] + q.user.table_name = table_key.replace(" ", "_") + q.user.original_name = table_info["original_name"] + q.page["select_tables"].table_dropdown.value = table_name + else: + q.page["select_tables"].table_dropdown.value = q.user.table_name + + +async def init(q: Q) -> None: + q.client.timezone = "UTC" + username, profile_pic = q.auth.username, q.app.persona_path + q.page["meta"] = ui.meta_card( + box="", + layouts=[ + ui.layout( + breakpoint="xs", + min_height="100vh", + zones=[ + ui.zone( + "main", + size="1", + direction=ui.ZoneDirection.ROW, + zones=[ + ui.zone("sidebar", size="250px"), + ui.zone( + "body", + zones=[ + ui.zone( + "content", + zones=[ + # Specify various zones and use the one that is currently needed. Empty zones are ignored. + ui.zone("horizontal", direction=ui.ZoneDirection.ROW), + ui.zone("vertical"), + ui.zone( + "grid", direction=ui.ZoneDirection.ROW, wrap="stretch", justify="center" + ), + ], + ), + ], + ), + ], + ) + ], + ) + ], + ) + q.page["sidebar"] = ui.nav_card( + box="sidebar", + color="primary", + title="QnA Assistant", + subtitle="Get answers to your questions.", + value=f'#{q.args["#"]}' if q.args["#"] else "#chat", + image="https://wave.h2o.ai/img/h2o-logo.svg", + items=[ + ui.nav_group( + "Menu", + items=[ui.nav_item(name="#datasets", label="Upload Dataset"), ui.nav_item(name="#chat", label="Chat")], + ), + ui.nav_group( + "Help", + items=[ + ui.nav_item(name="#about", label="About"), + ui.nav_item(name="#support", label="Support"), + ], + ), + ], + secondary_items=[ + ui.persona( + title=username, + size="xs", + image=profile_pic, + ), + ], + ) + + # Connect to LLM + openai.api_key = "" + + await user_variable(q) + await client_variable(q) + # If no active hash present, render chat. + if q.args["#"] is None: + await chat(q) + + +def on_shutdown(): + logging.info("App stopped. 
Goodbye!") + + +# Preload sample data for the app +def upload_demo_examples(q: Q): + upload_action = True + cur_dir = os.getcwd() + sample_data_path = f"{cur_dir}/examples/demo/" + org_table_name = "Sleep health and lifestyle study" + usr_table_name = org_table_name.lower().replace(" ", "_") + + table_metadata_path = f"{tmp_path}/data/tables.json" + # Do not upload dataset if user had any tables uploaded previously. This check avoids re-uploading sample dataset. + if os.path.exists(table_metadata_path): + # Read the existing content from the JSON file + with open(table_metadata_path, "r") as json_file: + existing_data = json.load(json_file) + if usr_table_name in existing_data: + upload_action = False + logging.info(f"Dataset already uploaded, skipping upload!") + + if upload_action: + table_metadata = dict() + table_metadata[usr_table_name] = { + "original_name": org_table_name, + "schema_info_path": f"{sample_data_path}/table_info.jsonl", + "samples_path": f"{sample_data_path}/sleep_health_and_lifestyle_dataset.csv", + "samples_qa": None, + } + update_tables(f"{tmp_path}/data/tables.json", table_metadata) + + q.user.org_table_name = org_table_name + q.user.table_name = usr_table_name + q.user.table_samples_path = f"{sample_data_path}/sleep_health_and_lifestyle_dataset.csv" + q.user.table_info_path = f"{sample_data_path}/table_info.jsonl" + q.user.sample_qna_path = None + + db_resp = db_setup_api( + db_name=q.user.db_name, + hostname=q.user.host_name, + user_name=q.user.user_name, + password=q.user.password, + port=q.user.port, + table_info_path=q.user.table_info_path, + table_samples_path=q.user.table_samples_path, + table_name=q.user.table_name, + ) + logging.info(f"DB updated with demo examples: \n {db_resp}") + q.args.table_dropdown = usr_table_name + + +async def on_event(q: Q): + event_handled = False + args_dict = expando_to_dict(q.args) + logging.info(f"Args dict {args_dict}") + if q.args.regenerate_with_options: + q.args.chatbot = "try harder" + elif q.args.regenerate: + q.args.chatbot = "regenerate" + + if q.args.table_dropdown and not q.args.chatbot: + logging.info(f"User selected table: {q.args.table_dropdown}") + await submit_table(q) + q.args.chatbot = f"Table {q.args.table_dropdown} selected" + # Refresh response is triggered when user selects a table via dropdown + event_handled = True + + if q.args.save_conversation or (q.args.chatbot and "save the qna pair:" in q.args.chatbot.lower()): + question = q.client.query + _val = q.client.llm_response + # Currently, any manual input by the user is a Question by default + if ( + question is not None + and "SELECT" in question + and (question.lower().startswith("question:") or question.lower().startswith("q:")) + ): + _q = question.lower().split("q:")[1].split("r:")[0].strip() + _r = question.lower().split("r:")[1].strip() + logging.info(f"Saving conversation for question: {_q} and response: {_r}") + save_query(base_path, query=_q, response=_r) + _msg = "Conversation saved successfully!" + elif question is not None and _val is not None and _val.strip() != "": + logging.info(f"Saving conversation for question: {question} and response: {_val}") + save_query(base_path, query=question, response=_val) + _msg = "Conversation saved successfully!" + else: + _msg = "Sorry, try generating a conversation to save." 
+ q.page["chat_card"].data += [_msg, False] + event_handled = True + elif q.args.regenerate or q.args.regenerate_with_options: + await chatbot(q) + event_handled = True + elif q.args.demo_mode: + logging.info(f"Switching to demo mode!") + # If demo datasets are not present, register them. + upload_demo_examples(q) + logging.info(f"Demo dataset selected: {q.user.table_name}") + await submit_table(q) + sample_qs = """ + Data description: The Sleep Health and Lifestyle Dataset comprises 400 rows and 13 columns, + covering a wide range of variables related to sleep and daily habits. + It includes details such as gender, age, occupation, sleep duration, quality of sleep, + physical activity level, stress levels, BMI category, blood pressure, heart rate, daily steps, + and the presence or absence of sleep disorders\n + Reference: https://www.kaggle.com/datasets/uom190346a/sleep-health-and-lifestyle-dataset \n + Example questions:\n + 1. Describe data. Tip: For more detailed insights on the data try AutoInsights on the Cloud marketplace. + 2. What is the average sleep duration for each gender? + 3. How does average sleep duration vary across different age groups? + 4. What are the most common occupations among individuals in the dataset? + 5. What is the average sleep duration for each occupation? + 6. What is the average sleep duration for each age group? + 7. What is the effect of Physical Activity Level on Quality of Sleep? + """ + q.args.chatbot = ( + f"Demo mode is enabled.\nTry below example questions for the selected data to get started,\n{sample_qs}" + ) + q.page["chat_card"].data += [q.args.chatbot, True] + q.page["meta"].redirect = "#chat" + event_handled = True + else: # default chatbot event + await handle_on(q) + event_handled = True + logging.info(f"Event handled: {event_handled} ... ") + return event_handled + + +@app("/", on_shutdown=on_shutdown) +async def serve(q: Q): + # Run only once per client connection. + if not q.client.initialized: + q.client.cards = set() + setup_dir(base_path) + await init(q) + q.client.initialized = True + logging.info("App initialized.") + + # Handle routing. + if await on_event(q): + await q.page.save() + return diff --git a/ui/app_config.toml b/ui/app_config.toml new file mode 100644 index 0000000..db786dc --- /dev/null +++ b/ui/app_config.toml @@ -0,0 +1,3 @@ +[WAVE_UI] +TITLE = "SideKick Assistant UI" +SUB_TITLE = "Get answers to your questions"