diff --git a/tutorials/load-wiki-to-milvus.ipynb b/tutorials/load-wiki-to-milvus.ipynb
new file mode 100644
index 0000000..9f4ca87
--- /dev/null
+++ b/tutorials/load-wiki-to-milvus.ipynb
@@ -0,0 +1,573 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "b65a7e72",
+ "metadata": {},
+ "source": [
+ "\n",
+ "# Load Data into Milvus for RAG\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "15258f81",
+ "metadata": {},
+ "source": [
+ " \n",
+ "\n",
+ "\n",
+ "\n",
+ "## 1. Set up the environment \n",
+ "\n",
+ "### Install Libraries\n",
+ "\n",
+ "We need to install the pymilvus package to the watsonx.ai Python environment."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "51677357",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install grpcio==1.60.0 \n",
+ "!pip install pymilvus"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "579e78ba",
+ "metadata": {},
+ "source": [
+ "## !!RESTART THE KERNEL AFTER pymilvus install!!\n",
+ "\n",
+ "Certain dependencies need to be persisted. Restarting the kernel allows this to occur. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "694eff33", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install ipython-sql==0.4.1\n", + "!pip install sqlalchemy==1.4.46\n", + "!pip install sqlalchemy==1.4.46 \"pyhive[presto]\"\n", + "!pip install python-dotenv\n", + "!pip install wikipedia\n", + "!pip install sentence_transformers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "356cf465", + "metadata": {}, + "outputs": [], + "source": [ + "# Create environment variables\n", + "\n", + "import os\n", + "from dotenv import load_dotenv\n", + "from ibm_cloud_sdk_core import IAMTokenManager\n", + "from ibm_watson_studio_lib import access_project_or_space\n", + "\n", + "wslib = access_project_or_space({\n", + " 'token': '',\n", + " 'project_id': ''\n", + "})\n", + "\n", + "wslib.download_file('config.env')\n", + "load_dotenv('config.env')\n", + "\n", + "\n", + "# Connection variables\n", + "api_key = os.getenv(\"API_KEY\", None)\n", + "ibm_cloud_url = os.getenv(\"IBM_CLOUD_URL\", None) \n", + "project_id = os.getenv(\"PROJECT_ID\", None)\n", + "\n", + "creds = {\n", + " \"url\": ibm_cloud_url,\n", + " \"apikey\": api_key \n", + "}\n", + "access_token = IAMTokenManager(\n", + " apikey = api_key,\n", + " url = \"https://iam.cloud.ibm.com/identity/token\"\n", + ").get_token()" + ] + }, + { + "cell_type": "markdown", + "id": "b747755a", + "metadata": {}, + "source": [ + "## Wikipedia Exploration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32f2b28a", + "metadata": {}, + "outputs": [], + "source": [ + "import wikipedia\n", + "\n", + "# search\n", + "search_results = wikipedia.search(\"\")\n", + "search_results\n", + "\n", + "print(search_results)\n", + "\n", + "# view article summary\n", + "article_summary = wikipedia.summary(search_results[0])\n", + "article_summary\n", + "\n", + "print(article_summary)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d89314fa", 
+ "metadata": {}, + "outputs": [], + "source": [ + "import wikipedia\n", + "\n", + "# fetch wikipedia articles\n", + "articles = {\n", + " #'IBM': None, \n", + " '': None\n", + "}\n", + "\n", + "for k,v in articles.items():\n", + " article = wikipedia.page(k)\n", + " articles[k] = article.content\n", + " print(f\"Successfully fetched {k}\")\n", + "\n", + "print(f\"Successfully fetched {len(articles)} articles \")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5f350d8", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "4353d5b0", + "metadata": {}, + "source": [ + "### Split Wikipedia Data into Chunks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18894c13", + "metadata": {}, + "outputs": [], + "source": [ + "# Chunk data\n", + "def split_into_chunks(text, chunk_size):\n", + " words = text.split()\n", + " return [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]\n", + "\n", + "split_articles = {}\n", + "for k,v in articles.items():\n", + " split_articles[k] = split_into_chunks(v, 225)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb6dc64d", + "metadata": {}, + "outputs": [], + "source": [ + "article_titles = list(split_articles.keys())\n", + "article_chunks = list(split_articles.values())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "947812bc", + "metadata": {}, + "outputs": [], + "source": [ + "## create titles_list for associates chunks to be loaded into milvus \n", + "\n", + "i = 0\n", + "for title in article_titles:\n", + " list_length = len(article_chunks[i])\n", + " article_titles[i] = [title] * list_length\n", + " i+=1\n", + " \n" + ] + }, + { + "cell_type": "markdown", + "id": "a58926d6", + "metadata": {}, + "source": [ + "## Insert Chunks with Embeddings into Milvus" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65714f49", + "metadata": {}, + "outputs": 
[],
+ "source": [
+ "from pymilvus import(\n",
+ "    Milvus,\n",
+ "    IndexType,\n",
+ "    Status,\n",
+ "    connections,\n",
+ "    FieldSchema,\n",
+ "    DataType,\n",
+ "    Collection,\n",
+ "    CollectionSchema,\n",
+ ")\n",
+ "\n",
+ "import os \n",
+ "\n",
+ "host = os.getenv(\"MILVUS_HOST\", None)\n",
+ "port = os.getenv(\"MILVUS_PORT\", None)\n",
+ "password = os.getenv(\"MILVUS_PASSWORD\", None)\n",
+ "user = 'ibmlhapikey'\n",
+ "\n",
+ "\n",
+ "# Connect with the values read above (the previous url/apiuser/apikey names were undefined)\n",
+ "connections.connect(alias=\"default\", \n",
+ "                    host=host, \n",
+ "                    port=port, \n",
+ "                    user=user, \n",
+ "                    password=password, \n",
+ "                    secure=True)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d4d98158",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create collection - define fields + schema\n",
+ "\n",
+ "fields = [\n",
+ "    FieldSchema(name=\"id\", dtype=DataType.INT64, is_primary=True, auto_id=True), # Primary key\n",
+ "    FieldSchema(name=\"article_text\", dtype=DataType.VARCHAR, max_length=2500,),\n",
+ "    FieldSchema(name=\"article_title\", dtype=DataType.VARCHAR, max_length=200,),\n",
+ "    FieldSchema(name=\"vector\", dtype=DataType.FLOAT_VECTOR, dim=384),\n",
+ "]\n",
+ "\n",
+ "schema = CollectionSchema(fields, \"\")\n",
+ "\n",
+ "wiki_collection = Collection(\"\", schema)\n",
+ "\n",
+ "# Create index\n",
+ "index_params = {\n",
+ "    'metric_type':'L2',\n",
+ "    'index_type':\"IVF_FLAT\",\n",
+ "    'params':{\"nlist\":2048}\n",
+ "}\n",
+ "\n",
+ "wiki_collection.create_index(field_name=\"vector\", index_params=index_params)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "efbefcbe",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# we can run a check to see the collections in our milvus instance and we see the new collection has been created \n",
+ "\n",
+ "from pymilvus import utility\n",
+ "utility.list_collections()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7e2b9046",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# load 
data into Milvus\n",
+ "import pandas as pd\n",
+ "from sentence_transformers import SentenceTransformer\n",
+ "from pymilvus import Collection, connections\n",
+ "import warnings\n",
+ "warnings.filterwarnings('ignore')\n",
+ "\n",
+ "\n",
+ "# Load the embedding model once, rather than re-instantiating it on every loop iteration\n",
+ "model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') # 384 dim\n",
+ "\n",
+ "for i in range(len(article_titles)):\n",
+ "    # Create vector embeddings + data\n",
+ "    passage_embeddings = model.encode(article_chunks[i])\n",
+ "\n",
+ "    basic_collection = Collection(\"\") \n",
+ "    data = [\n",
+ "        article_chunks[i],\n",
+ "        article_titles[i],\n",
+ "        passage_embeddings\n",
+ "    ]\n",
+ "    \n",
+ "    out = basic_collection.insert(data)\n",
+ "    basic_collection.flush() # Ensures data persistence\n",
+ "\n",
+ "    \n",
+ "    print(\"Wikipedia Article: \\'\" + article_titles[i][0] + \"\\' has been loaded.\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3547110e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "## check to ensure entities have been loaded into the collection\n",
+ "\n",
+ "basic_collection = Collection(\"\") \n",
+ "\n",
+ "basic_collection.num_entities "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0f8253da",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "34a915ff",
+ "metadata": {},
+ "source": [
+ "### Prompt LLM with Query Results\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "77e075a8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sentence_transformers import SentenceTransformer\n",
+ "from pymilvus import(\n",
+ "    Milvus,\n",
+ "    IndexType,\n",
+ "    Status,\n",
+ "    connections,\n",
+ "    FieldSchema,\n",
+ "    DataType,\n",
+ "    Collection,\n",
+ "    CollectionSchema,\n",
+ ")\n",
+ "\n",
+ "import os \n",
+ "\n",
+ "host = os.getenv(\"MILVUS_HOST\", None)\n",
+ "port = os.getenv(\"MILVUS_PORT\", None)\n",
+ "password = os.getenv(\"MILVUS_PASSWORD\", None)\n",
+ "user = 
'ibmlhapikey'\n",
+ "\n",
+ "\n",
+ "# Connect with the values read above (the previous url/apiuser/apikey names were undefined)\n",
+ "connections.connect(alias=\"default\", \n",
+ "                    host=host, \n",
+ "                    port=port, \n",
+ "                    user=user, \n",
+ "                    password=password, \n",
+ "                    secure=True)\n",
+ "\n",
+ "\n",
+ "# Load collection\n",
+ "\n",
+ "basic_collection = Collection(\"\") \n",
+ "basic_collection.load()\n",
+ "\n",
+ "# Query function\n",
+ "def query_milvus(query, num_results):\n",
+ "    \n",
+ "    # Vectorize query\n",
+ "    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') # 384 dim\n",
+ "    query_embeddings = model.encode([query])\n",
+ "\n",
+ "    # Search\n",
+ "    search_params = {\n",
+ "        \"metric_type\": \"L2\", \n",
+ "        \"params\": {\"nprobe\": 5}\n",
+ "    }\n",
+ "    results = basic_collection.search(\n",
+ "        data=query_embeddings, \n",
+ "        anns_field=\"vector\", \n",
+ "        param=search_params,\n",
+ "        limit=num_results,\n",
+ "        expr=None, \n",
+ "        output_fields=['article_text'],\n",
+ "    )\n",
+ "    return results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0f62a489",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "## Consider some questions to ask regarding the topic you have chosen \n",
+ "\n",
+ "#question_text = \"How does IBM treat their employees?\"\n",
+ "\n",
+ "question_text = \"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a8259b76",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Query Milvus \n",
+ "\n",
+ "num_results = 3\n",
+ "results = query_milvus(question_text, num_results)\n",
+ "\n",
+ "relevant_chunks = []\n",
+ "for i in range(num_results): \n",
+ "    #print(f\"id: {results[0].ids[i]}\")\n",
+ "    #print(f\"distance: {results[0].distances[i]}\")\n",
+ "    text = results[0][i].entity.get('article_text')\n",
+ "    relevant_chunks.append(text)\n",
+ "    \n",
+ "#print(relevant_chunks)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c72ab9a3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def make_prompt(context, question_text):\n",
+ "    
return (f\"{context}\\n\\nPlease answer a question using this text. \"\n", + " + f\"If the question is unanswerable, say \\\"unanswerable\\\".\"\n", + " + f\"\\n\\nQuestion: {question_text}\")\n", + "\n", + "\n", + "# Build prompt w/ Milvus results\n", + "# Embed retrieved passages(context) and user question into into prompt text\n", + "\n", + "context = \"\\n\\n\".join(relevant_chunks)\n", + "prompt = make_prompt(context, question_text)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aef20205", + "metadata": {}, + "outputs": [], + "source": [ + "print(prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47852e4f", + "metadata": {}, + "outputs": [], + "source": [ + "from ibm_watson_machine_learning.foundation_models import Model\n", + "from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams\n", + "\n", + "# Model Parameters\n", + "params = {\n", + " GenParams.DECODING_METHOD: \"greedy\",\n", + " GenParams.MIN_NEW_TOKENS: 1,\n", + " GenParams.MAX_NEW_TOKENS: 500,\n", + " GenParams.TEMPERATURE: 0,\n", + "}\n", + "model = Model(\n", + " model_id='meta-llama/llama-2-70b-chat', \n", + " params=params, credentials=creds, \n", + " project_id=project_id\n", + ")\n", + "\n", + "# Prompt LLM\n", + "response = model.generate_text(prompt)\n", + "print(f\"Question: {question_text}{response}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9503c73", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5faa6b69", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8dedcfc9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.10", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + 
"file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + }, + "vscode": { + "interpreter": { + "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}