From bfe52857d89e3a7831ec39e5d81f85daeed4187a Mon Sep 17 00:00:00 2001 From: christophergs Date: Thu, 7 Mar 2024 11:24:12 +0000 Subject: [PATCH] update part 3 to use llama index v0.10.0 --- part_3_rag/rag.py | 31 +++++++++++++------------------ requirements.txt | 4 +++- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/part_3_rag/rag.py b/part_3_rag/rag.py index 95a98f6..107669a 100644 --- a/part_3_rag/rag.py +++ b/part_3_rag/rag.py @@ -3,20 +3,19 @@ import sys import textwrap from pathlib import Path - -from llama_index import ( - Response, - set_global_tokenizer, - ServiceContext, - VectorStoreIndex, +from llama_index.core import ( SimpleDirectoryReader, + VectorStoreIndex, StorageContext, load_index_from_storage, + set_global_tokenizer, set_global_handler, ) -from llama_index.core.llms.types import ChatMessage, MessageRole, ChatResponse -from llama_index.embeddings import HuggingFaceEmbedding -from llama_index.llms import LlamaCPP +from llama_index.core.base.llms.types import ChatResponse, ChatMessage, MessageRole +from llama_index.core.base.response.schema import Response +from llama_index.embeddings.huggingface import HuggingFaceEmbedding +from llama_index.llms.llama_cpp import LlamaCPP + from transformers import AutoTokenizer from shared.settings import DATA_DIR @@ -44,13 +43,13 @@ def load_embedding_model() -> HuggingFaceEmbedding: def save_or_load_index( - index_dir: Path, service_context: ServiceContext + index_dir: Path, embed_model: HuggingFaceEmbedding ) -> VectorStoreIndex: index_exists = any(item for item in index_dir.iterdir() if item.name != ".gitkeep") if index_exists: storage_context = StorageContext.from_defaults(persist_dir=index_dir) return load_index_from_storage( - storage_context=storage_context, service_context=service_context + storage_context=storage_context, embed_model=embed_model ) transcript_files = glob.glob(str(DATA_DIR / "**/*transcript*"), recursive=True) @@ -64,7 +63,7 @@ def save_or_load_index( ).load_data() index = VectorStoreIndex.from_documents( - documents, service_context=service_context, show_progress=True + documents, embed_model=embed_model, show_progress=True ) # persist the index index.storage_context.persist(persist_dir=index_dir) @@ -84,14 +83,10 @@ def run_inference( AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2") ) - service_context = ServiceContext.from_defaults( - llm=llm, embed_model=embedding_model, system_prompt=SYSTEM_PROMPT_TEXT - ) - index_dir = DATA_DIR / "indices" - index = save_or_load_index(index_dir=index_dir, service_context=service_context) + index = save_or_load_index(index_dir=index_dir, embed_model=embedding_model) - query_engine = index.as_query_engine() + query_engine = index.as_query_engine(llm=llm) return query_engine.query(messages[1].content) diff --git a/requirements.txt b/requirements.txt index f57fa60..9fac83f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,7 @@ llama-cpp-python==0.2.38 -llama-index==0.9.39 +llama-index>=0.10.0,<0.11.0 +llama-index-embeddings-huggingface>=0.1.4,<0.2.0 +llama-index-llms-llama-cpp>=0.1.3,<0.2.0 transformers==4.37.1 deepeval==0.20.55 pytest==8.0.0