diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0c487f4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,49 @@ +# Python
+__pycache__/
+*.py[cod]
+*$py.class
+
+# Virtual environments
+venv/
+.venv/
+env/
+
+# Environment files (keep .env.example, ignore .env)
+.env
+!.env.example
+
+# FAISS indexes (generated at runtime)
+faiss_index/
+legal_faiss_index/
+papers_faiss_index/
+text_faiss_index/
+image_faiss_index/
+table_faiss_index/
+kb_faiss_index/
+*_faiss_index/
+
+# Generated reports
+research_report_*.md
+
+# Extracted content (generated at runtime)
+data/extracted/images/*.png
+data/extracted/images/*.jpg
+data/extracted/tables/*.csv
+
+# Jupyter
+.ipynb_checkpoints/
+*.ipynb
+
+# macOS
+.DS_Store
+
+# IDEs
+.vscode/
+.idea/
+*.swp
+*.swo
+
+# Distribution / packaging
+dist/
+build/
+*.egg-info/ diff --git a/01-rag-from-scratch/.env.example b/01-rag-from-scratch/.env.example new file mode 100644 index 0000000..f4f0169 --- /dev/null +++ b/01-rag-from-scratch/.env.example @@ -0,0 +1,5 @@ +# OpenAI API Key (required if using OpenAI as LLM)
+OPENAI_API_KEY=your_openai_api_key_here
+
+# Optional: Ollama base URL (if running locally, no API key needed)
+# OLLAMA_BASE_URL=http://localhost:11434 diff --git a/01-rag-from-scratch/README.md b/01-rag-from-scratch/README.md new file mode 100644 index 0000000..fbaa37b --- /dev/null +++ b/01-rag-from-scratch/README.md @@ -0,0 +1,307 @@ +# RAG from Scratch πŸ”
+
+A beginner-friendly implementation of Retrieval-Augmented Generation (RAG) built step-by-step using LangChain, FAISS, and HuggingFace embeddings. Every file is heavily commented to explain *why* each piece exists, not just *what* it does.
+
+---
+
+## What is RAG and Why Does It Matter?
+
+ +**The problem with plain LLMs:** Large Language Models like GPT-4 are trained on data up to a certain cutoff date, and they have no knowledge of *your* private documents β€” your company's policy manuals, your research papers, your product documentation. If you ask GPT-4 "What is the refund policy in our internal handbook?", it simply doesn't know. + +**What RAG does:** RAG (Retrieval-Augmented Generation) solves this by giving the LLM access to your documents *at query time*. Instead of retraining the model (expensive, slow), you store your documents in a searchable vector database. When a user asks a question, you retrieve the most relevant passages and include them in the LLM's prompt. The LLM reads those passages and answers *based on your documents*. + +**Why it matters:** RAG is currently the dominant architecture for production AI Q&A systems. It's cost-effective (no retraining), updatable (just add documents to the database), and auditable (you can see exactly which document chunks informed each answer). Understanding RAG from scratch gives you the foundation to build everything from customer support bots to internal knowledge assistants. + +--- + +## Architecture + +``` +YOUR DOCUMENTS (PDF / TXT / DOCX) + β”‚ + β–Ό + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ 1. LOAD β”‚ Read files from disk into LangChain Document objects + β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β–Ό + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ 2. CHUNK β”‚ Split large docs into ~500-char overlapping pieces + β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β–Ό + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ 3. EMBED β”‚ Convert each chunk β†’ 384-dim vector (all-MiniLM-L6-v2) + β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β–Ό + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ 4. 
INDEX β”‚ Store vectors in FAISS (saved to disk for reuse) + β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β”‚ USER QUESTION + β”‚ β”‚ + β”‚ β–Ό + β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ β”‚ 5. EMBED β”‚ Embed question β†’ vector + β”‚ β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜ + β”‚ β”‚ + β–Ό β–Ό + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ FAISS SIMILARITY SEARCH β”‚ Find top-k most similar chunks + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β–Ό + TOP-k RELEVANT CHUNKS + β”‚ + β–Ό + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ 6. GENERATE (LLM + Prompt) β”‚ LLM reads chunks + question + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β–Ό + GROUNDED ANSWER βœ… +``` + +--- + +## Tech Stack + +| Component | Library / Tool | Purpose | +|-------------------|-----------------------------------------|------------------------------------------| +| Document loading | `langchain-community` loaders | Read PDF, TXT, DOCX files | +| Text splitting | `langchain` RecursiveCharacterTextSplitter | Split docs into overlapping chunks | +| Embeddings | `sentence-transformers` (HuggingFace) | Convert text β†’ vectors (free, local) | +| Vector database | `faiss-cpu` | Fast similarity search over embeddings | +| LLM | OpenAI GPT-3.5/4 or local Ollama | Generate answers from retrieved context | +| Orchestration | `langchain` RetrievalQA chain | Tie retrieval + generation together | +| Env management | `python-dotenv` | Load API keys from `.env` file | + +--- + +## Step-by-Step Setup + +### 1. 
Create and activate a virtual environment
+
+```bash
+python -m venv venv
+source venv/bin/activate   # macOS / Linux
+# venv\Scripts\activate    # Windows
+```
+
+### 2. Install dependencies
+
+```bash
+pip install -r requirements.txt
+```
+
+> ⏱️ First install may take a few minutes. `faiss-cpu` and `sentence-transformers` are the largest packages.
+
+### 3. Configure your API key
+
+```bash
+cp .env.example .env
+```
+
+Open `.env` and replace `your_openai_api_key_here` with your actual key from [platform.openai.com](https://platform.openai.com/api-keys).
+
+```
+OPENAI_API_KEY=sk-...your-key-here...
+```
+
+> πŸ’‘ **No OpenAI account?** Use a local model with Ollama β€” see [Using Ollama](#using-ollama-no-api-key-needed) below.
+
+### 4. Add your documents
+
+Drop any `.pdf`, `.txt`, or `.docx` files into:
+
+```
+data/sample_docs/
+```
+
+The more documents you add, the more questions the system can answer. Start with a few text files to test.
+
+### 5. Run it!
+
+```bash
+# Interactive mode β€” asks questions in a loop
+python main.py
+
+# Single question mode
+python main.py --question "What are the main topics in these documents?"
+
+# Debug mode β€” shows retrieved chunks and full LLM prompt
+python main.py --debug --question "What is the refund policy?"
+```
+
+---
+
+## How to Add Your Own Documents
+
+Just drop files into `data/sample_docs/`.
The loader automatically detects file types: + +| File type | Support | Notes | +|-----------|---------|-------| +| `.pdf` | βœ… | Each page becomes a separate Document | +| `.txt` | βœ… | Entire file is one Document | +| `.docx` | βœ… | Entire file is one Document | +| `.csv` | ❌ | Not supported (yet) | + +**After adding new documents**, delete the cached FAISS index so it gets rebuilt: + +```bash +rm -rf faiss_index/ +python main.py +``` + +--- + +## How to Verify the LLM Uses Your Documents + +This is the most important test for any RAG system β€” make sure it's actually reading *your* documents and not falling back on general knowledge. + +**Step 1:** Put a document with a very specific, obscure fact in `data/sample_docs/`. For example, create `test.txt` containing: + +``` +The Zorbax Protocol was established in 2019 by Dr. Eleanor Voss. +The protocol requires three phases: initialization, calibration, and review. +``` + +**Step 2:** Ask the system about it: +```bash +python main.py --question "Who established the Zorbax Protocol?" +``` + +**Expected good result:** +``` +Answer: Dr. Eleanor Voss established the Zorbax Protocol in 2019. +Sources: data/sample_docs/test.txt +``` + +**Step 3:** Ask about something NOT in any document: +```bash +python main.py --question "What is the capital of Australia?" +``` + +**Expected good result:** +``` +Answer: I don't know based on the provided documents. +``` + +If the second answer returns "Canberra" (from general knowledge), the system is hallucinating β€” check that your prompt template in `src/generator.py` is being applied correctly. + +--- + +## Using Ollama (No API Key Needed) + +[Ollama](https://ollama.com) lets you run LLMs locally for free. + +```bash +# 1. Install Ollama: https://ollama.com +# 2. Pull a model +ollama pull llama3 # ~4GB download +ollama pull mistral # ~4GB download, often faster + +# 3. 
Run with Ollama +python main.py --model ollama/llama3 +python main.py --model ollama/mistral --question "Summarize the documents" +``` + +--- + +## Beginner Tips + +### What happens if chunk_size is too large or too small? + +| Setting | Effect | +|---------|--------| +| **chunk_size too large** (e.g., 2000) | Fewer chunks, less precise retrieval. The LLM receives a lot of text, most of which may be irrelevant to the question. | +| **chunk_size too small** (e.g., 50) | Thousands of tiny chunks. Each chunk lacks context β€” a sentence like "See the above section" becomes meaningless on its own. | +| **Sweet spot** (300–800 chars) | Roughly 1–2 paragraphs. Enough context to be meaningful, small enough to be precise. | + +### Why cosine similarity beats keyword search + +Traditional search (e.g., `grep`, SQL `LIKE`) requires exact word matches. Search for "car" and you won't find documents that say "automobile" or "vehicle". + +Semantic search (cosine similarity over embeddings) understands *meaning*: +- "car", "automobile", "vehicle", "sedan" β†’ all have very similar embeddings +- You can ask "What's the fastest way to travel?" and find chunks about "high-speed rail" or "airplane travel" β€” no exact keyword overlap needed + +### What does k mean in top-k retrieval? + +`k` is the number of document chunks retrieved per question. + +- **k=1**: Only the single best match. Very precise but may miss relevant context. +- **k=3** (default): A good balance. Captures the primary answer + nearby supporting text. +- **k=10**: Comprehensive but may include loosely related chunks that dilute the LLM's focus. + +Use `--k 5` on the command line to experiment. If the LLM keeps saying "I don't know" on questions you know are in the docs, try increasing k. + +--- + +## Troubleshooting + +### `OPENAI_API_KEY is not set` +```bash +cp .env.example .env +# edit .env and add your key +``` + +### `No documents were loaded` +Make sure you have files in `data/sample_docs/`. 
Only `.pdf`, `.txt`, and `.docx` are supported. + +### `FileNotFoundError: data/sample_docs does not exist` +```bash +mkdir -p data/sample_docs +# then add your files +``` + +### `Error: Connection refused` (Ollama) +Make sure Ollama is running: +```bash +ollama serve +``` + +### `Model not found` (Ollama) +Pull the model first: +```bash +ollama pull llama3 +``` + +### Answers seem wrong or generic +1. Run with `--debug` to see which chunks are being retrieved +2. Check the sources printed after each answer β€” are they the right files? +3. Try deleting `faiss_index/` and rebuilding β€” you may have stale embeddings +4. Try increasing `--k` to retrieve more context + +### `pip install` fails on `faiss-cpu` +On some systems you may need to install build tools: +```bash +# Ubuntu/Debian +sudo apt-get install build-essential + +# macOS +xcode-select --install +``` + +--- + +## Project Structure + +``` +01-rag-from-scratch/ +β”œβ”€β”€ README.md ← You are here +β”œβ”€β”€ requirements.txt ← Python dependencies +β”œβ”€β”€ .env.example ← Template for your API keys +β”œβ”€β”€ main.py ← Entry point β€” ties all 6 steps together +β”œβ”€β”€ data/ +β”‚ └── sample_docs/ ← Drop your .pdf/.txt/.docx files here +└── src/ + β”œβ”€β”€ __init__.py ← Makes src/ a Python package + β”œβ”€β”€ document_loader.py ← Step 1: Load documents from disk + β”œβ”€β”€ chunker.py ← Step 2: Split documents into chunks + β”œβ”€β”€ embedder.py ← Step 3: Convert text to vectors + β”œβ”€β”€ vector_store.py ← Step 4: Store/search vectors with FAISS + β”œβ”€β”€ retriever.py ← Step 5: Retrieve relevant chunks + └── generator.py ← Step 6: Generate answers with LLM +``` diff --git a/01-rag-from-scratch/data/sample_docs/.gitkeep b/01-rag-from-scratch/data/sample_docs/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/01-rag-from-scratch/main.py b/01-rag-from-scratch/main.py new file mode 100644 index 0000000..1baac5e --- /dev/null +++ b/01-rag-from-scratch/main.py @@ -0,0 +1,327 @@ 
+#!/usr/bin/env python3 +# main.py +# +# RAG FROM SCRATCH β€” COMPLETE PIPELINE +# ====================================== +# This file ties together all 6 steps of the RAG (Retrieval-Augmented Generation) +# pipeline into a single runnable script. +# +# THE 6 STEPS: +# 1. LOAD β†’ Read .pdf/.txt/.docx files from disk into LangChain Documents +# 2. CHUNK β†’ Split large documents into smaller overlapping chunks +# 3. EMBED β†’ Convert each chunk to a vector using a HuggingFace model +# 4. INDEX β†’ Store all vectors in a FAISS index (saved to disk for reuse) +# 5. RETRIEVE β†’ Given a user question, find the top-k most relevant chunks +# 6. GENERATE β†’ Pass the question + retrieved chunks to an LLM for a grounded answer +# +# USAGE: +# # Single question mode: +# python main.py --question "What are the main topics in these documents?" +# +# # Interactive mode (loops until you type 'quit'): +# python main.py +# +# # Use a local Ollama model instead of OpenAI: +# python main.py --model ollama/llama3 +# +# # Debug mode (shows full prompt sent to LLM and retrieved chunks): +# python main.py --debug --question "What is the refund policy?" +# +# # Specify a different data folder or index location: +# python main.py --data-dir my_docs/ --index-path my_index/ + +import os +import argparse + +# python-dotenv loads KEY=VALUE pairs from your .env file into os.environ. +# This is the standard way to manage API keys without hardcoding them in source code. +from dotenv import load_dotenv + +# Import each step of our pipeline from the src/ package +from src.document_loader import load_documents +from src.chunker import chunk_documents +from src.embedder import get_embedding_model, embed_text +from src.vector_store import get_or_create_vector_store +from src.retriever import get_retriever, retrieve_chunks +from src.generator import build_qa_chain + + +def parse_args(): + """ + Parse command-line arguments. + + argparse is Python's built-in library for CLI argument handling. 
+ It automatically generates --help text from the descriptions below. + """ + parser = argparse.ArgumentParser( + description="RAG from Scratch β€” Ask questions about your documents using AI.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python main.py + python main.py --question "What are the main topics?" + python main.py --model ollama/llama3 --question "Summarize the documents" + python main.py --debug --question "What is the refund policy?" + python main.py --data-dir /path/to/docs --index-path /path/to/index + """, + ) + + parser.add_argument( + "--data-dir", + default="data/sample_docs", + help="Path to folder containing .pdf, .txt, or .docx files. " + "Default: data/sample_docs", + ) + + parser.add_argument( + "--index-path", + default="faiss_index", + help="Path to save/load the FAISS vector index. " + "Default: faiss_index (created automatically on first run).", + ) + + parser.add_argument( + "--model", + default="gpt-3.5-turbo", + help="LLM to use for answer generation. " + "Options: gpt-3.5-turbo, gpt-4, ollama/llama3, ollama/mistral. " + "Default: gpt-3.5-turbo", + ) + + parser.add_argument( + "--question", + default=None, + help="A single question to answer and exit. " + "If omitted, starts an interactive Q&A loop.", + ) + + parser.add_argument( + "--debug", + action="store_true", + help="Enable debug mode: print the full prompt sent to LLM and " + "detailed chain steps.", + ) + + parser.add_argument( + "--k", + type=int, + default=3, + help="Number of chunks to retrieve per query (top-k). Default: 3.", + ) + + return parser.parse_args() + + +def run_pipeline(args): + """ + Execute the full RAG pipeline end-to-end. + + This function orchestrates all 6 steps, printing clear separators between + each phase so you can follow along and understand what's happening. 
+ """ + + print("=" * 60) + print(" RAG FROM SCRATCH β€” PIPELINE STARTING") + print("=" * 60) + + # ------------------------------------------------------------------------- + # LOAD ENVIRONMENT VARIABLES + # ------------------------------------------------------------------------- + # .env is NOT committed to git (see .gitignore). Copy .env.example β†’ .env + # and fill in your OPENAI_API_KEY before running with an OpenAI model. + load_dotenv() + + # Warn early if using OpenAI but the API key is missing + if not args.model.startswith("ollama/") and not os.getenv("OPENAI_API_KEY"): + print( + "\n⚠️ WARNING: OPENAI_API_KEY is not set in your environment.\n" + " Either:\n" + " 1. Copy .env.example to .env and add your API key, OR\n" + " 2. Use a local model with --model ollama/llama3\n" + ) + + # ------------------------------------------------------------------------- + # STEP 1: LOAD DOCUMENTS + # ------------------------------------------------------------------------- + print("\n" + "─" * 60) + print("STEP 1/6: Loading documents") + print("─" * 60) + print(f" Source directory: {args.data_dir}") + + documents = load_documents(args.data_dir) + + # If no documents were found, we can't continue β€” tell the user what to do + if not documents: + print( + "\n❌ No documents loaded. 
Please add .pdf, .txt, or .docx files to:\n" + f" {args.data_dir}\n" + "\nThen re-run: python main.py" + ) + return + + # ------------------------------------------------------------------------- + # STEP 2: CHUNK DOCUMENTS + # ------------------------------------------------------------------------- + print("\n" + "─" * 60) + print("STEP 2/6: Chunking documents") + print("─" * 60) + + chunks = chunk_documents( + documents, + chunk_size=500, # ~1-2 short paragraphs per chunk + chunk_overlap=50, # 50 chars of overlap to preserve context at boundaries + ) + + # ------------------------------------------------------------------------- + # STEP 3: LOAD EMBEDDING MODEL + # ------------------------------------------------------------------------- + print("\n" + "─" * 60) + print("STEP 3/6: Loading embedding model") + print("─" * 60) + print(" Model: all-MiniLM-L6-v2 (free, local, no API key needed)") + + embedding_model = get_embedding_model("all-MiniLM-L6-v2") + + # DEMO: Show what an embedding vector looks like (educational, not required) + if args.debug and chunks: + embed_text(chunks[0].page_content[:100], embedding_model) + + # ------------------------------------------------------------------------- + # STEP 4: BUILD OR LOAD VECTOR STORE + # ------------------------------------------------------------------------- + print("\n" + "─" * 60) + print("STEP 4/6: Building / loading FAISS vector store") + print("─" * 60) + print(f" Index location: {args.index_path}/") + print(f" Tip: Delete '{args.index_path}/' to force a full rebuild.") + + vector_store = get_or_create_vector_store( + chunks=chunks, + embedding_model=embedding_model, + path=args.index_path, + ) + + # ------------------------------------------------------------------------- + # STEP 5: SET UP RETRIEVER + # ------------------------------------------------------------------------- + print("\n" + "─" * 60) + print("STEP 5/6: Configuring retriever") + print("─" * 60) + print(f" Retrieval strategy: cosine 
similarity, top-k={args.k}") + + retriever = get_retriever(vector_store, k=args.k) + + # ------------------------------------------------------------------------- + # STEP 6: BUILD QA CHAIN (LLM + RETRIEVER) + # ------------------------------------------------------------------------- + print("\n" + "─" * 60) + print("STEP 6/6: Building QA chain (LLM + Retriever)") + print("─" * 60) + + qa_chain = build_qa_chain( + retriever=retriever, + model_name=args.model, + debug=args.debug, + ) + + print("\n" + "=" * 60) + print(" PIPELINE READY β€” Let's ask some questions!") + print("=" * 60) + + # ------------------------------------------------------------------------- + # Q&A PHASE: Single question or interactive loop + # ------------------------------------------------------------------------- + + if args.question: + # Single question mode β€” answer it and exit + ask_question(qa_chain, args.question, args.debug) + else: + # Interactive mode β€” keep asking until the user types 'quit' or 'exit' + print("\nπŸ’¬ Interactive Q&A Mode") + print(" Type your question and press Enter.") + print(" Type 'quit' or 'exit' to stop.\n") + + # Sample question to get the user started + sample_question = "What are the main topics covered in these documents?" + print(f" πŸ’‘ Sample question: {sample_question}\n") + + while True: + try: + question = input("Your question: ").strip() + except (KeyboardInterrupt, EOFError): + # Handle Ctrl+C gracefully + print("\n\nGoodbye! πŸ‘‹") + break + + if not question: + print(" (Please type a question, or 'quit' to exit)") + continue + + if question.lower() in ("quit", "exit", "q"): + print("Goodbye! πŸ‘‹") + break + + ask_question(qa_chain, question, args.debug) + + +def ask_question(qa_chain, question: str, debug: bool = False): + """ + Ask a single question and print the answer with source attribution. + + Args: + qa_chain: The assembled RetrievalQA chain. + question (str): The question to ask. 
+ debug (bool): If True, print source document details. + """ + + print(f"\n❓ Question: {question}") + print(" (Retrieving relevant chunks and generating answer...)\n") + + try: + # .invoke() runs the full chain: + # question β†’ embed β†’ FAISS search β†’ retrieve chunks β†’ fill prompt β†’ LLM β†’ answer + result = qa_chain.invoke({"query": question}) + + # The result dict has: + # result["result"] β†’ the LLM's answer string + # result["source_documents"] β†’ list of Document objects used as context + answer = result["result"] + source_docs = result.get("source_documents", []) + + print(f"πŸ’‘ Answer:\n{answer}") + + # Show which source documents contributed to this answer + if source_docs: + print("\nπŸ“š Sources used:") + seen_sources = set() + for doc in source_docs: + source = doc.metadata.get("source", "unknown") + page = doc.metadata.get("page", "") + page_info = f", page {page}" if page != "" else "" + source_key = f"{source}{page_info}" + + # Deduplicate β€” a source file may appear multiple times (different chunks) + if source_key not in seen_sources: + print(f" β€’ {source_key}") + seen_sources.add(source_key) + + # In debug mode, show the actual chunk text used + if debug: + print(f" Context: {doc.page_content[:150]}...") + + except Exception as e: + print(f"\n❌ Error generating answer: {e}") + print( + "\nCommon causes:\n" + " β€’ Missing OPENAI_API_KEY (check your .env file)\n" + " β€’ Ollama not running (start with: ollama serve)\n" + " β€’ Model not pulled (run: ollama pull llama3)\n" + " β€’ Network connectivity issues\n" + ) + + print() # blank line for readability between questions + + +if __name__ == "__main__": + args = parse_args() + run_pipeline(args) diff --git a/01-rag-from-scratch/requirements.txt b/01-rag-from-scratch/requirements.txt new file mode 100644 index 0000000..89a9dc3 --- /dev/null +++ b/01-rag-from-scratch/requirements.txt @@ -0,0 +1,9 @@ +langchain==0.1.20 +langchain-community==0.0.38 +langchain-openai==0.1.6 
+faiss-cpu==1.8.0 +sentence-transformers==2.7.0 +pypdf==4.2.0 +python-docx==1.1.2 +openai==1.30.1 +python-dotenv==1.0.1 diff --git a/01-rag-from-scratch/src/__init__.py b/01-rag-from-scratch/src/__init__.py new file mode 100644 index 0000000..369aae7 --- /dev/null +++ b/01-rag-from-scratch/src/__init__.py @@ -0,0 +1,5 @@ +# src/__init__.py +# Makes the src/ directory a Python package so we can do: +# from src.document_loader import load_documents +# from src.chunker import chunk_documents +# etc. from main.py diff --git a/01-rag-from-scratch/src/chunker.py b/01-rag-from-scratch/src/chunker.py new file mode 100644 index 0000000..bf4933d --- /dev/null +++ b/01-rag-from-scratch/src/chunker.py @@ -0,0 +1,107 @@ +# src/chunker.py +# +# STEP 2 OF THE RAG PIPELINE: CHUNKING DOCUMENTS +# +# WHY DO WE SPLIT DOCUMENTS INTO CHUNKS? +# ---------------------------------------- +# Large Language Models (LLMs) have a "context window" β€” a hard limit on how much +# text they can receive in a single prompt. For example, GPT-3.5-turbo has a ~16k +# token limit (roughly 12,000 words). If your document is a 200-page PDF, you +# CANNOT send the whole thing to the LLM at once. +# +# Even if you could, it's wasteful: most of the document is irrelevant to any +# given question. We only want to send the 2-3 paragraphs that are actually useful. +# +# The solution: split documents into small "chunks", embed each chunk as a vector, +# store them in a vector database, and at query time retrieve ONLY the most relevant +# chunks to include in the LLM prompt. +# +# WHAT IS chunk_overlap AND WHY DOES IT MATTER? +# ----------------------------------------------- +# Imagine a document with this text: +# "...the policy expires on December 31st. Renewal must be submitted 30 days..." +# +# Without overlap, if the split happens between "December 31st." and "Renewal", +# one chunk ends with an incomplete thought and the other starts mid-context. 
+# With overlap (e.g., 50 characters), the second chunk will start a bit before +# "Renewal", capturing "...policy expires on December 31st. Renewal..." β€” giving +# the LLM enough context to understand the sentence properly. +# +# HOW DOES RecursiveCharacterTextSplitter WORK? +# ----------------------------------------------- +# It tries to split text in a smart order of preference: +# 1. Split on paragraph breaks ("\n\n") first β€” preserves paragraph structure +# 2. If still too long, split on newlines ("\n") β€” preserves line structure +# 3. If still too long, split on spaces (" ") β€” preserves word boundaries +# 4. Last resort: split on individual characters β€” avoids going over limit +# +# This is smarter than a naive "split every N characters" approach because it +# tries to keep semantically coherent units together. +# +# TUNING chunk_size: +# ------------------- +# Too LARGE (e.g., 2000): Fewer chunks, retrieval is less precise. +# The LLM may receive a lot of irrelevant text. +# +# Too SMALL (e.g., 50): Many chunks, each missing context. +# A sentence like "See section above" becomes meaningless alone. +# +# Sweet spot: 300–800 characters (roughly 1-2 short paragraphs). +# Default here is 500 with 50-character overlap β€” a good starting point. + +from langchain.text_splitter import RecursiveCharacterTextSplitter + + +def chunk_documents( + documents: list, + chunk_size: int = 500, + chunk_overlap: int = 50, +) -> list: + """ + Split a list of LangChain Documents into smaller chunks. + + Each chunk is itself a LangChain Document object, inheriting the metadata + of the original document (so we still know which file each chunk came from). + + Args: + documents (list): List of LangChain Document objects (from document_loader). + chunk_size (int): Maximum number of characters per chunk. Default: 500. + chunk_overlap (int): Number of characters to overlap between consecutive chunks. + Helps preserve context at boundaries. Default: 50. 
+ + Returns: + list: A (usually much longer) list of smaller LangChain Document objects. + + Example: + chunks = chunk_documents(documents, chunk_size=500, chunk_overlap=50) + print(f"Created {len(chunks)} chunks") + print(chunks[0].page_content) # first chunk's text + print(chunks[0].metadata) # same metadata as parent document + """ + + print(f"\nπŸ“ Chunking {len(documents)} document(s)...") + print(f" Settings: chunk_size={chunk_size}, chunk_overlap={chunk_overlap}") + + # Create the splitter with our chosen settings. + # separators: the list of strings it will try to split on, in order of preference. + splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + # Try these separators in order β€” paragraph breaks β†’ newlines β†’ spaces β†’ characters + separators=["\n\n", "\n", " ", ""], + # length_function: how to measure "size". len() counts characters. + # You could swap this for a token counter if you want chunk_size in tokens. + length_function=len, + ) + + # split_documents() handles the full list at once and preserves metadata. + # It returns a new list of Document objects β€” one per chunk. + chunks = splitter.split_documents(documents) + + print(f"βœ… Created {len(chunks)} chunks from {len(documents)} document(s)") + print( + f" Average chunk size: " + f"~{sum(len(c.page_content) for c in chunks) // max(len(chunks), 1)} characters" + ) + + return chunks diff --git a/01-rag-from-scratch/src/document_loader.py b/01-rag-from-scratch/src/document_loader.py new file mode 100644 index 0000000..2123be4 --- /dev/null +++ b/01-rag-from-scratch/src/document_loader.py @@ -0,0 +1,125 @@ +# src/document_loader.py +# +# STEP 1 OF THE RAG PIPELINE: LOADING DOCUMENTS +# +# Before we can answer questions about your documents, we first need to READ them. +# This module handles loading different file types (.pdf, .txt, .docx) into a +# common format that LangChain can work with. +# +# What is a LangChain "Document" object? 
+# ---------------------------------------- +# LangChain uses a Document object to represent a piece of text. It has two fields: +# +# document.page_content β†’ the actual text string (e.g., "The capital of France is Paris...") +# document.metadata β†’ a dict with info about where the text came from +# e.g., {"source": "data/sample_docs/report.pdf", "page": 2} +# +# Why use Documents instead of plain strings? +# Because we want to keep track of WHERE each piece of text came from. +# When the LLM answers a question, we can tell the user "this answer came from page 3 of report.pdf" +# β€” that's only possible if we preserve the metadata through the pipeline. + +import os +from pathlib import Path + +# LangChain community loaders for different file types. +# These loaders know how to read each format and return a list of Document objects. +from langchain_community.document_loaders import ( + PyPDFLoader, # Reads PDF files β€” returns one Document per page + TextLoader, # Reads plain .txt files β€” returns one Document per file + Docx2txtLoader, # Reads .docx (Word) files β€” returns one Document per file +) + + +def load_documents(data_dir: str) -> list: + """ + Load all supported documents from a directory. + + Walks through the given directory, finds all .pdf, .txt, and .docx files, + loads each one using the appropriate loader, and returns a flat list of + LangChain Document objects. + + Args: + data_dir (str): Path to the folder containing your documents. + e.g., "data/sample_docs" + + Returns: + list: A list of LangChain Document objects. 
Each document has: + - page_content: the text extracted from the file + - metadata: dict containing at minimum {"source": } + + Example: + documents = load_documents("data/sample_docs") + print(documents[0].page_content) # prints raw text + print(documents[0].metadata) # prints {"source": "data/sample_docs/report.pdf", "page": 0} + """ + + # Convert to a Path object for easier cross-platform file handling + data_path = Path(data_dir) + + # Make sure the directory actually exists before trying to read it + if not data_path.exists(): + raise FileNotFoundError( + f"Data directory '{data_dir}' does not exist. " + f"Please create it and add some .pdf, .txt, or .docx files." + ) + + all_documents = [] # We'll collect all Document objects here + + # Map each file extension to its corresponding LangChain loader class. + # This makes it easy to add new file types later β€” just add an entry here. + loader_map = { + ".pdf": PyPDFLoader, + ".txt": TextLoader, + ".docx": Docx2txtLoader, + } + + # Walk through every file in the directory (and subdirectories) + for file_path in sorted(data_path.rglob("*")): + + # Skip directories β€” we only want files + if not file_path.is_file(): + continue + + # Skip hidden files (e.g., .gitkeep, .DS_Store) + if file_path.name.startswith("."): + continue + + file_ext = file_path.suffix.lower() # e.g., ".pdf", ".txt", ".docx" + + # Check if we have a loader for this file type + if file_ext not in loader_map: + # Unsupported file type β€” skip it with a warning + print(f" ⚠️ Skipping unsupported file type: {file_path.name} ({file_ext})") + continue + + print(f" πŸ“„ Loading: {file_path.name}") + + try: + # Instantiate the appropriate loader with the file path + loader_class = loader_map[file_ext] + loader = loader_class(str(file_path)) + + # .load() returns a list of Document objects. + # For PDFs, each page becomes its own Document. + # For TXT/DOCX, the whole file is usually one Document. 
+ documents = loader.load() + + print(f" β†’ Loaded {len(documents)} document chunk(s)") + all_documents.extend(documents) + + except Exception as e: + # Don't crash the whole pipeline if one file fails to load. + # Print the error and continue with the remaining files. + print(f" ❌ Failed to load {file_path.name}: {e}") + continue + + if len(all_documents) == 0: + print( + f"\n⚠️ No documents were loaded from '{data_dir}'.\n" + f" Add some .pdf, .txt, or .docx files and try again." + ) + else: + print(f"\nβœ… Total documents loaded: {len(all_documents)}") + + return all_documents diff --git a/01-rag-from-scratch/src/embedder.py b/01-rag-from-scratch/src/embedder.py new file mode 100644 index 0000000..30532aa --- /dev/null +++ b/01-rag-from-scratch/src/embedder.py @@ -0,0 +1,114 @@ +# src/embedder.py +# +# STEP 3 OF THE RAG PIPELINE: EMBEDDING TEXT INTO VECTORS +# +# WHAT ARE EMBEDDINGS? +# ---------------------- +# An "embedding" is a way to represent text as a list of numbers (a vector). +# The key insight is that semantically similar text produces numerically similar vectors. +# +# For example, these two sentences will have very similar vectors: +# "The cat sat on the mat." +# "A feline rested on the rug." +# +# Even though they share no keywords, an embedding model understands they mean +# the same thing. This is the magic that makes semantic search work! +# +# WHY all-MiniLM-L6-v2? +# ----------------------- +# We use the "all-MiniLM-L6-v2" model from HuggingFace for several reasons: +# +# βœ… FREE β€” no API key required, runs entirely on your local machine +# βœ… FAST β€” it's a small, distilled model (only ~80MB to download) +# βœ… GOOD QUALITY β€” despite its size, it scores well on semantic benchmarks +# βœ… 384 DIMENSIONS β€” each piece of text becomes a list of 384 numbers +# +# Alternative: OpenAI's text-embedding-ada-002 is more powerful but costs money +# and requires an API key. For learning, the free HuggingFace model is perfect. 
+# +# WHAT DOES "384 DIMENSIONS" MEAN? +# ---------------------------------- +# Each text string gets converted to a list of 384 floating-point numbers. +# Think of it as a point in 384-dimensional space. Similar texts are "close" +# to each other in this space; unrelated texts are "far apart." +# +# WHY COSINE SIMILARITY? +# ----------------------- +# To find which chunks are most relevant to a query, we compare their vectors. +# We use "cosine similarity" which measures the angle between two vectors: +# - Score of 1.0 = identical direction = very similar meaning +# - Score of 0.0 = perpendicular = unrelated +# - Score of -1.0 = opposite direction = opposite meaning (rare in practice) +# +# Cosine similarity is preferred over Euclidean distance because it's insensitive +# to the magnitude of the vectors β€” only the direction matters. + +from langchain_community.embeddings import HuggingFaceEmbeddings + + +def get_embedding_model(model_name: str = "all-MiniLM-L6-v2") -> HuggingFaceEmbeddings: + """ + Load a HuggingFace sentence-transformer embedding model. + + The first time you call this, it will download the model (~80MB) from + HuggingFace Hub and cache it locally. Subsequent calls use the cache. + + Args: + model_name (str): HuggingFace model name. Default: "all-MiniLM-L6-v2" + Other options: "all-mpnet-base-v2" (higher quality, slower) + + Returns: + HuggingFaceEmbeddings: A LangChain-compatible embedding model object. 
+ Call model.embed_documents([...]) or model.embed_query("...") + """ + + print(f"\nπŸ”’ Loading embedding model: '{model_name}'") + print(f" (First run will download ~80MB β€” subsequent runs use cache)") + + # model_kwargs: passed directly to the underlying sentence-transformers library + # device="cpu" means we run on CPU β€” change to "cuda" if you have a GPU + embedding_model = HuggingFaceEmbeddings( + model_name=model_name, + model_kwargs={"device": "cpu"}, + # encode_kwargs: controls how the model encodes text into vectors + # normalize_embeddings=True ensures vectors have length 1.0, + # which makes cosine similarity equivalent to dot product (faster computation) + encode_kwargs={"normalize_embeddings": True}, + ) + + print(f"βœ… Embedding model loaded successfully") + return embedding_model + + +def embed_text(text: str, model) -> list: + """ + Embed a single string and show what the resulting vector looks like. + + This is a teaching/demo function β€” it helps beginners see that "embedding" + just means converting text into a list of numbers. + + Args: + text (str): Any string to embed. + model: A loaded HuggingFaceEmbeddings (or compatible) model. + + Returns: + list: A list of 384 floats representing the text's meaning as a vector. + + Example: + model = get_embedding_model() + vector = embed_text("Hello world", model) + # Prints: Vector shape: 384 dimensions + # Prints: First 5 values: [0.023, -0.041, 0.118, ...] + """ + + # embed_query() is the LangChain method for embedding a single string. + # (embed_documents() is for embedding a list of strings all at once β€” more efficient.) + vector = model.embed_query(text) + + # Show the learner what a vector actually looks like + print(f"\nπŸ” Embedding demo for: '{text[:60]}{'...' 
if len(text) > 60 else ''}'") + print(f" Vector shape: {len(vector)} dimensions") + print(f" First 5 values: {[round(v, 4) for v in vector[:5]]}") + print(f" (Each number encodes a tiny aspect of the text's meaning)") + + return vector diff --git a/01-rag-from-scratch/src/generator.py b/01-rag-from-scratch/src/generator.py new file mode 100644 index 0000000..a917be0 --- /dev/null +++ b/01-rag-from-scratch/src/generator.py @@ -0,0 +1,167 @@ +# src/generator.py +# +# STEP 6 OF THE RAG PIPELINE: GENERATING THE ANSWER WITH AN LLM +# +# WHAT DOES RetrievalQA DO? +# -------------------------- +# RetrievalQA is a LangChain "chain" that combines two things: +# 1. A retriever (which fetches relevant chunks from FAISS) +# 2. An LLM (which reads those chunks and generates an answer) +# +# It handles the "stuffing" step: it takes the retrieved Document objects, +# extracts their page_content, concatenates them into a {context} block, +# and injects that into our prompt template before calling the LLM. +# +# WHY THE "ONLY USE CONTEXT" INSTRUCTION PREVENTS HALLUCINATION: +# --------------------------------------------------------------- +# LLMs are trained on massive datasets and have general knowledge baked in. +# Without explicit instructions, an LLM might answer from its training data +# instead of your documents β€” which defeats the entire purpose of RAG. +# +# The system instruction "answer ONLY based on the following context" tells +# the LLM to restrict itself to what we provide. The fallback phrase +# "I don't know based on the provided documents" prevents the LLM from +# making things up when the answer truly isn't in the documents. +# +# This is the most important prompt engineering technique in RAG systems. +# +# WHAT IS THE SYSTEM PROMPT PATTERN? +# ------------------------------------ +# A "prompt template" is a string with placeholder variables (like {context} +# and {question}) that get filled in at runtime. 
This lets us: +# - Set the LLM's behavior with clear instructions at the top +# - Inject the retrieved context dynamically for each query +# - Ask the user's question at the end +# +# The resulting filled-in prompt is what actually gets sent to the LLM API. +# With debug=True, you can print this full prompt to see exactly what the LLM receives. + +from langchain_openai import ChatOpenAI +from langchain.chains import RetrievalQA +from langchain.prompts import PromptTemplate + + +# The prompt template instructs the LLM to stay grounded in the provided context. +# {context} will be replaced by the retrieved chunks (as a single text block). +# {question} will be replaced by the user's question. +RAG_PROMPT_TEMPLATE = """You are a helpful assistant. Answer the question based ONLY on the following context. +If the answer is not in the context, say "I don't know based on the provided documents." +Do not use your general knowledge. + +Context: +{context} + +Question: {question} + +Answer:""" + + +def build_qa_chain( + retriever, + model_name: str = "gpt-3.5-turbo", + debug: bool = False, +): + """ + Build a RetrievalQA chain that combines document retrieval with LLM generation. + + This is the final assembly step of the RAG pipeline: + User question + β†’ retriever fetches top-k relevant chunks from FAISS + β†’ chunks are injected into the prompt template as {context} + β†’ LLM reads context + question and generates a grounded answer + + Args: + retriever: A LangChain retriever (from retriever.py). + model_name (str): LLM to use. Options: + - "gpt-3.5-turbo" (OpenAI, requires OPENAI_API_KEY) + - "gpt-4" (OpenAI, more powerful, costs more) + - "ollama/llama3" (local Ollama, no API key needed) + - "ollama/mistral" (local Ollama, no API key needed) + debug (bool): If True, prints the full prompt sent to the LLM. + Useful for understanding what the LLM actually receives. + + Returns: + RetrievalQA: A runnable chain. 
Call chain.invoke({"query": "your question"})
+        to get an answer dict with keys "query", "result", "source_documents".
+
+    Example:
+        chain = build_qa_chain(retriever, model_name="gpt-3.5-turbo", debug=True)
+        result = chain.invoke({"query": "What is the refund policy?"})
+        print(result["result"])
+    """
+
+    print(f"\nπŸ€– Building QA chain with model: '{model_name}'")
+
+    # -------------------------------------------------------------------------
+    # SELECT THE LLM BASED ON model_name
+    # -------------------------------------------------------------------------
+    if model_name.startswith("ollama/"):
+        # Ollama runs LLMs locally on your machine β€” no API key, no cost.
+        # Install Ollama from https://ollama.com and pull a model:
+        #   ollama pull llama3
+        #   ollama pull mistral
+        #
+        # The model_name format is "ollama/<model>", e.g. "ollama/llama3"
+        import os
+        from langchain_community.llms import Ollama
+
+        # Extract the model tag after the "ollama/" prefix
+        ollama_model = model_name.split("/", 1)[1]
+        base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
+
+        print(f"   Using local Ollama model '{ollama_model}' at {base_url}")
+        llm = Ollama(model=ollama_model, base_url=base_url)
+
+    else:
+        # OpenAI models (gpt-3.5-turbo, gpt-4, gpt-4o, etc.)
+        # Requires OPENAI_API_KEY to be set in your .env file.
+        #
+        # temperature=0 means "deterministic" β€” the LLM always picks the highest
+        # probability token. For Q&A this is ideal; you want consistent, factual
+        # answers rather than creative variation.
+ llm = ChatOpenAI( + model_name=model_name, + temperature=0, # 0 = deterministic/factual, 1 = more creative/varied + ) + print(f" Using OpenAI model '{model_name}' (ensure OPENAI_API_KEY is set)") + + # ------------------------------------------------------------------------- + # BUILD THE PROMPT TEMPLATE + # ------------------------------------------------------------------------- + prompt = PromptTemplate( + template=RAG_PROMPT_TEMPLATE, + input_variables=["context", "question"], # placeholders to fill at runtime + ) + + # If debug mode is on, show the template so learners can see the structure + if debug: + print("\nπŸ› DEBUG: Prompt template being used:") + print("-" * 60) + print(RAG_PROMPT_TEMPLATE) + print("-" * 60) + + # ------------------------------------------------------------------------- + # ASSEMBLE THE RetrievalQA CHAIN + # ------------------------------------------------------------------------- + # chain_type="stuff" means: take all retrieved chunks, "stuff" them all into + # the context at once. This works well for small k values (k=3 to k=5). + # + # Other chain_type options: + # "map_reduce" β€” summarize each chunk separately, then combine (handles many chunks) + # "refine" β€” iteratively refine the answer chunk by chunk (slower but thorough) + # "map_rerank" β€” score each chunk separately and pick the best answer + # + # For most use cases with k<=5, "stuff" is the simplest and most effective. 
+ qa_chain = RetrievalQA.from_chain_type( + llm=llm, + chain_type="stuff", + retriever=retriever, + return_source_documents=True, # include source docs in the result dict + chain_type_kwargs={ + "prompt": prompt, + "verbose": debug, # if debug=True, LangChain will print internal chain steps + }, + ) + + print(f"βœ… QA chain ready") + return qa_chain diff --git a/01-rag-from-scratch/src/retriever.py b/01-rag-from-scratch/src/retriever.py new file mode 100644 index 0000000..e673204 --- /dev/null +++ b/01-rag-from-scratch/src/retriever.py @@ -0,0 +1,117 @@ +# src/retriever.py +# +# STEP 5 OF THE RAG PIPELINE: RETRIEVING RELEVANT CHUNKS +# +# WHAT DOES "RETRIEVAL" DO IN THE RAG PIPELINE? +# ----------------------------------------------- +# At this point we have: +# - All our document chunks stored as vectors in FAISS +# - A user's question +# +# The retriever's job is to: +# 1. Embed the question using the SAME embedding model we used for the chunks +# 2. Search FAISS for the k most similar chunk vectors to the question vector +# 3. Return those k chunks (as Document objects) so the LLM can read them +# +# The LLM never reads the whole document database β€” it only reads these k chunks. +# This is what makes RAG efficient and precise. +# +# WHAT IS k IN TOP-k RETRIEVAL? +# -------------------------------- +# k is the number of chunks we retrieve. Think of it as: +# "Give me the top 3 most relevant paragraphs from my documents." +# +# k=1: Very focused. Only the single best match. May miss related info. +# k=3: A good balance. Captures the main answer + nearby context. (default) +# k=10: Comprehensive but may include loosely related chunks that confuse the LLM. +# +# Rule of thumb: Start with k=3 and increase if the LLM says "I don't know" +# on questions you KNOW are in your documents. +# +# WHY COSINE SIMILARITY BEATS KEYWORD SEARCH: +# -------------------------------------------- +# Traditional search (like grep or SQL LIKE) requires exact keyword matches. 
+# If your document says "automobile" and you search for "car", you get nothing. +# +# Semantic (vector) search understands meaning: +# "car" β†’ very similar vector to "automobile", "vehicle", "sedan" +# +# This means you can ask questions in natural language and still find relevant +# chunks even when the exact words don't match. This is crucial for Q&A systems +# where users phrase questions differently than how documents are written. + + +def get_retriever(vector_store, k: int = 3): + """ + Create a LangChain retriever from a FAISS vector store. + + A LangChain "retriever" is a standardized interface that wraps the vector store + and exposes a simple .invoke(query) method. This makes it easy to plug into + LangChain chains (like RetrievalQA in generator.py). + + Args: + vector_store: A FAISS vector store (from vector_store.py). + k (int): How many chunks to retrieve per query. Default: 3. + Increase if answers are missing info; decrease if too noisy. + + Returns: + A LangChain VectorStoreRetriever object. + + Example: + retriever = get_retriever(vector_store, k=3) + docs = retriever.invoke("What is the refund policy?") + """ + + # as_retriever() wraps the FAISS store in a Retriever interface. + # search_type="similarity" uses cosine similarity (since we normalized embeddings). + # Other options: "mmr" (Maximal Marginal Relevance β€” reduces redundancy among results) + retriever = vector_store.as_retriever( + search_type="similarity", + search_kwargs={"k": k}, # retrieve top-k most similar chunks + ) + + print(f"\nπŸ”Ž Retriever configured (top-k={k}, search_type=similarity)") + return retriever + + +def retrieve_chunks(question: str, retriever) -> list: + """ + Retrieve the most relevant document chunks for a given question. + + Also prints the retrieved chunks so learners can inspect what gets passed + to the LLM. This transparency is key to understanding and debugging RAG. + + Args: + question (str): The user's question in natural language. 
+ retriever: A LangChain retriever (from get_retriever()). + + Returns: + list: A list of LangChain Document objects β€” the most relevant chunks. + Each has .page_content (the text) and .metadata (source file, page, etc.) + + Example: + chunks = retrieve_chunks("What is the refund policy?", retriever) + for chunk in chunks: + print(chunk.page_content) + print(chunk.metadata["source"]) + """ + + print(f"\nπŸ” Retrieving relevant chunks for: '{question}'") + + # .invoke() embeds the question and runs the similarity search + relevant_chunks = retriever.invoke(question) + + print(f"\nπŸ“‹ Top {len(relevant_chunks)} retrieved chunk(s):") + print("-" * 60) + + for i, chunk in enumerate(relevant_chunks, 1): + source = chunk.metadata.get("source", "unknown") + page = chunk.metadata.get("page", "") + page_info = f" (page {page})" if page != "" else "" + + print(f"\n[Chunk {i}] Source: {source}{page_info}") + print(f"Content preview: {chunk.page_content[:200]}...") + + print("-" * 60) + + return relevant_chunks diff --git a/01-rag-from-scratch/src/vector_store.py b/01-rag-from-scratch/src/vector_store.py new file mode 100644 index 0000000..472a3a0 --- /dev/null +++ b/01-rag-from-scratch/src/vector_store.py @@ -0,0 +1,159 @@ +# src/vector_store.py +# +# STEP 4 OF THE RAG PIPELINE: STORING VECTORS IN FAISS +# +# WHAT IS FAISS? +# --------------- +# FAISS (Facebook AI Similarity Search) is an open-source library developed by +# Meta (Facebook) Research. It is specifically designed for one task: +# +# Given a query vector, quickly find the most similar vectors in a large collection. +# +# This is called "Approximate Nearest Neighbor" (ANN) search. Doing this naively +# (comparing the query against every stored vector one by one) would be too slow +# at scale. FAISS builds an *index* β€” a special data structure that lets it find +# the top-k similar vectors in milliseconds, even across millions of documents. 
+# +# HOW FAISS WORKS (CONCEPTUALLY): +# --------------------------------- +# 1. During indexing: FAISS takes all your chunk vectors and organizes them into +# a spatial data structure (e.g., an inverted file index or HNSW graph). +# 2. During search: Given a query vector, FAISS navigates the data structure to +# find the nearest neighbors without checking every single vector. +# +# For our use case (hundreds to thousands of chunks), FAISS is near-instant. +# It really shines at millions of vectors, but it's a great habit to use from day one. +# +# WHY SAVE THE INDEX TO DISK? +# ----------------------------- +# Embedding documents takes time (each chunk must be processed by the neural network). +# If we re-ran embedding every time we started the app, we'd waste seconds/minutes +# on every run even when the documents haven't changed. +# +# By saving the FAISS index to disk, we only embed once. On subsequent runs, we +# load the pre-built index from disk in milliseconds. +# +# The saved index consists of two files: +# faiss_index/index.faiss β†’ the actual vector index (binary) +# faiss_index/index.pkl β†’ metadata mapping (which chunk belongs to which vector) + +import os +from langchain_community.vectorstores import FAISS + + +def create_vector_store(chunks: list, embedding_model) -> FAISS: + """ + Embed all chunks and build a FAISS vector store from them. + + This is the "indexing" phase β€” it calls the embedding model once per chunk + (or in batches) and stores all resulting vectors in a FAISS index. + + Args: + chunks (list): List of LangChain Document objects (from chunker.py). + embedding_model: A loaded HuggingFaceEmbeddings model (from embedder.py). + + Returns: + FAISS: An in-memory FAISS vector store ready for similarity search. + """ + + print(f"\nπŸ—„οΈ Building FAISS vector store from {len(chunks)} chunks...") + print(f" (Embedding each chunk β€” this may take a moment on first run)") + + # FAISS.from_documents() does two things in one call: + # 1. 
Calls embedding_model.embed_documents() on all chunks + # 2. Builds the FAISS index from the resulting vectors + vector_store = FAISS.from_documents( + documents=chunks, + embedding=embedding_model, + ) + + print(f"βœ… Vector store created with {len(chunks)} vectors") + return vector_store + + +def save_vector_store(vector_store: FAISS, path: str = "faiss_index") -> None: + """ + Persist the FAISS index to disk so we don't have to re-embed next time. + + Saves two files: + {path}/index.faiss β€” the binary vector index + {path}/index.pkl β€” the document metadata mapping + + Args: + vector_store (FAISS): The in-memory FAISS vector store to save. + path (str): Directory path where index files will be written. + Default: "faiss_index" + """ + + # Create the directory if it doesn't exist yet + os.makedirs(path, exist_ok=True) + + # LangChain's FAISS wrapper handles the actual serialization + vector_store.save_local(path) + + print(f"πŸ’Ύ Vector store saved to '{path}/'") + print(f" Files: {path}/index.faiss, {path}/index.pkl") + + +def load_vector_store(path: str, embedding_model) -> FAISS: + """ + Load a previously saved FAISS index from disk. + + Args: + path (str): Directory path where index files are stored. + embedding_model: The SAME embedding model used when the index was created. + IMPORTANT: If you use a different model, the vectors won't + match and search results will be nonsense. + + Returns: + FAISS: The loaded vector store, ready for similarity search. + """ + + print(f"\nπŸ“‚ Loading existing FAISS index from '{path}/'...") + + # allow_dangerous_deserialization=True is required because FAISS uses pickle + # under the hood. This is safe as long as you trust the source of the index file + # (which you do, since you created it yourself). 
+ vector_store = FAISS.load_local( + folder_path=path, + embeddings=embedding_model, + allow_dangerous_deserialization=True, + ) + + print(f"βœ… Vector store loaded from disk") + return vector_store + + +def get_or_create_vector_store( + chunks: list, + embedding_model, + path: str = "faiss_index", +) -> FAISS: + """ + Convenience function: load index from disk if it exists, otherwise build it. + + This is the function you'll call in main.py. It implements a simple cache: + - If '{path}/index.faiss' exists β†’ load it (fast, skips re-embedding) + - Otherwise β†’ embed all chunks and build a new index, then save it + + Args: + chunks (list): List of LangChain Document chunks. + embedding_model: Loaded embedding model. + path (str): Path to save/load the FAISS index. Default: "faiss_index" + + Returns: + FAISS: Ready-to-use vector store. + """ + + # Check if a saved index already exists on disk + index_file = os.path.join(path, "index.faiss") + + if os.path.exists(index_file): + print(f"\n♻️ Found existing FAISS index at '{path}/' β€” loading from disk") + print(f" (Skipping re-embedding. 
Delete '{path}/' to force rebuild.)") + return load_vector_store(path, embedding_model) + else: + print(f"\nπŸ†• No existing index found at '{path}/' β€” building from scratch") + vector_store = create_vector_store(chunks, embedding_model) + save_vector_store(vector_store, path) + return vector_store diff --git a/02-legal-ai-assistant/.env.example b/02-legal-ai-assistant/.env.example new file mode 100644 index 0000000..d28df29 --- /dev/null +++ b/02-legal-ai-assistant/.env.example @@ -0,0 +1,8 @@ +# OpenAI API Key (required - GPT-4 recommended for legal analysis accuracy) +OPENAI_API_KEY=your_openai_api_key_here + +# Model to use (gpt-4 recommended for accuracy, gpt-3.5-turbo for cost savings) +OPENAI_MODEL=gpt-4 + +# Optional: Anthropic Claude API (alternative to OpenAI) +# ANTHROPIC_API_KEY=your_anthropic_key_here diff --git a/02-legal-ai-assistant/README.md b/02-legal-ai-assistant/README.md new file mode 100644 index 0000000..3580048 --- /dev/null +++ b/02-legal-ai-assistant/README.md @@ -0,0 +1,225 @@ +# Legal AI Assistant + +> ⚠️ **DISCLAIMER: This tool is for educational purposes only. It does NOT constitute legal advice. Always consult a qualified attorney before making any legal or business decisions.** + +A Retrieval-Augmented Generation (RAG) pipeline that helps you understand contracts by extracting key clauses, flagging risks, detecting internal conflicts, and answering natural-language questions β€” all grounded in the actual document text. 
+ +--- + +## What the Tool Does + +- **Parses** PDF and DOCX contract files into structured text with section detection +- **Indexes** the document into a FAISS vector store for semantic search +- **Summarises** the contract: parties, type, effective date, duration, key obligations +- **Extracts clauses**: indemnification, limitation of liability, termination, governing law, IP ownership, confidentiality +- **Analyses risks**: flags HIGH / MEDIUM / LOW risk patterns with plain-English explanations +- **Detects conflicts**: surfaces internal contradictions between clauses +- **Answers questions**: grounded Q&A with mandatory section citations + +--- + +## Supported Document Types + +| Format | Extension | Notes | +|--------|-----------|-------| +| PDF | `.pdf` | Text-based PDFs only. Scanned PDFs require OCR pre-processing. | +| Word | `.docx` | Supports Heading styles for better section detection. | + +--- + +## Architecture + +``` +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ main.py (CLI) β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ document_parser.py β”‚ PDF / DOCX β†’ full_text + sections + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ indexer.py β”‚ text β†’ chunks β†’ HuggingFace embeddings β†’ FAISS + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β”Œβ”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ OpenAI GPT-4 β”‚ (all LLM calls below use this) + β””β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + 
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ summarizer.py β”‚ β”‚ clause_extractor.py β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ risk_analyzer.py β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ conflict_detector.py β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ qa_chain.py β”‚ FAISS retriever + custom legal prompt + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +--- + +## Setup + +### 1. Clone / navigate to the project + +```bash +cd 02-legal-ai-assistant +``` + +### 2. Create a virtual environment + +```bash +python -m venv venv +source venv/bin/activate # macOS/Linux +venv\Scripts\activate # Windows +``` + +### 3. Install dependencies + +```bash +pip install -r requirements.txt +``` + +### 4. Configure environment variables + +```bash +cp .env.example .env +# Edit .env and add your OpenAI API key +``` + +```env +OPENAI_API_KEY=sk-... +OPENAI_MODEL=gpt-4 # or gpt-3.5-turbo for lower cost +``` + +### 5. Add a contract file + +Place a PDF or DOCX contract in `data/sample_contracts/` or any other path. 
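The `.env` file from step 4 is plain `KEY=value` lines; `main.py` loads it via `python-dotenv`'s `load_dotenv()`. If you want a dependency-free sanity check that your key will be picked up, here is a rough sketch of that parsing (the `read_env` helper name is illustrative, not part of this project):

```python
# Minimal re-implementation of what load_dotenv() reads from a .env file:
# KEY=value lines, with blank lines and '#' comments ignored.
from pathlib import Path

def read_env(path: str = ".env") -> dict:
    env = {}
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blanks and comments
        key, _, value = line.partition("=")
        env[key.strip()] = value.strip()
    return env
```

Calling `read_env(".env")` from the project root should show your `OPENAI_API_KEY` and `OPENAI_MODEL`; if it does not, `main.py` will not see them either. Note the real `python-dotenv` also handles quoting and `export` prefixes, so use it for anything beyond a quick check.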
+ +--- + +## How to Run + +### Full analysis (default) + +```bash +python main.py --file data/sample_contracts/service_agreement.pdf +``` + +### Use a cheaper model (faster, less accurate) + +```bash +python main.py --file contract.pdf --model gpt-3.5-turbo +``` + +### Skip risk analysis and conflict detection + +```bash +python main.py --file contract.pdf --skip-risks --skip-conflicts +``` + +### Ask a single question and exit + +```bash +python main.py --file contract.pdf --question "What are my termination rights?" +``` + +### Interactive Q&A after analysis + +```bash +python main.py --file contract.pdf --interactive +``` + +--- + +## Sample Questions to Ask + +``` +What are my termination rights? +Who owns IP I create during the contract? +What is the liability cap? +How does auto-renewal work? +What information must I keep confidential and for how long? +Which court has jurisdiction over disputes? +Can the company change the terms without my consent? +What happens to my work if the contract is terminated early? +``` + +--- + +## Output Sections Explained + +| Section | What it shows | +|---------|---------------| +| **Executive Summary** | Parties, contract type, effective date, duration, key obligations, plain-English overview | +| **Key Clauses** | Table of named clause types with their section references and plain-English translations | +| **Risk Analysis** | πŸ”΄ HIGH / 🟑 MEDIUM / 🟒 LOW risks with explanations and fair alternatives | +| **Conflict Detection** | Internal contradictions between clauses (e.g. mismatched notice periods) | +| **Q&A** | Grounded answers with mandatory section citations | + +--- + +## Limitations + +> These are not bugs β€” they are inherent limitations of the technology. + +1. **Cannot reliably detect all conflicts.** The LLM may miss conflicts requiring deep legal expertise or flag false positives. Every flagged conflict must be manually verified. + +2. 
**PDF extraction may miss some formatting.** Tables lose column alignment, scanned PDFs produce no text, and footnotes may appear mid-sentence. Complex formatting in PDFs will degrade extraction quality. + +3. **LLM can misinterpret complex legal language.** Highly technical, jurisdiction-specific, or archaic legal terms may be interpreted incorrectly. The model is not a lawyer. + +4. **Context window limits truncate long contracts.** Summary and clause extraction are capped at 8 000–12 000 characters. Very long contracts (100+ pages) will have their later sections underweighted. + +5. **Embeddings may not capture domain-specific meaning.** The `all-MiniLM-L6-v2` model was not trained on legal text specifically; niche legal terms may not retrieve optimally. + +6. **Always verify with a qualified attorney.** This tool helps you know WHAT to look for and WHERE to look. It does not replace professional legal review. + +--- + +## Project Structure + +``` +02-legal-ai-assistant/ +β”œβ”€β”€ README.md ← this file +β”œβ”€β”€ requirements.txt +β”œβ”€β”€ .env.example +β”œβ”€β”€ data/ +β”‚ └── sample_contracts/ ← place your PDF/DOCX files here +β”œβ”€β”€ src/ +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ document_parser.py ← PDF/DOCX β†’ structured text +β”‚ β”œβ”€β”€ indexer.py ← text β†’ FAISS vector index +β”‚ β”œβ”€β”€ summarizer.py ← executive summary generation +β”‚ β”œβ”€β”€ clause_extractor.py ← named clause extraction +β”‚ β”œβ”€β”€ risk_analyzer.py ← HIGH/MEDIUM/LOW risk scoring +β”‚ β”œβ”€β”€ conflict_detector.py ← internal contradiction detection +β”‚ └── qa_chain.py ← RAG Q&A chain +β”œβ”€β”€ prompts/ +β”‚ β”œβ”€β”€ summary_prompt.txt +β”‚ β”œβ”€β”€ clause_prompt.txt +β”‚ └── risk_prompt.txt +└── main.py ← CLI entry point +``` + +--- + +## Dependencies + +| Package | Purpose | +|---------|---------| +| `langchain` + `langchain-community` + `langchain-openai` | LLM orchestration and RAG chains | +| `faiss-cpu` | Local vector similarity search | +| 
`sentence-transformers` | HuggingFace embedding model (runs locally) | +| `pypdf` | PDF text extraction | +| `python-docx` | DOCX parsing | +| `openai` | OpenAI API client | +| `python-dotenv` | `.env` file loading | +| `pydantic` | Data validation | +| `rich` | Formatted terminal output | diff --git a/02-legal-ai-assistant/data/sample_contracts/.gitkeep b/02-legal-ai-assistant/data/sample_contracts/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/02-legal-ai-assistant/main.py b/02-legal-ai-assistant/main.py new file mode 100644 index 0000000..40a9d10 --- /dev/null +++ b/02-legal-ai-assistant/main.py @@ -0,0 +1,377 @@ +""" +main.py β€” Legal AI Assistant Entry Point + +Full analysis pipeline for legal contracts: + 1. Parse document β†’ structured text + sections + 2. Index for RAG β†’ FAISS vector store + 3. Summarize β†’ executive summary (parties, type, obligations) + 4. Extract clauses β†’ indemnification, IP, termination, etc. + 5. Analyze risks β†’ HIGH/MEDIUM/LOW risk flags + 6. Detect conflictsβ†’ internal contradictions + 7. Q&A β†’ answer specific questions or enter interactive mode + +Usage examples: + python main.py --file data/sample_contracts/service_agreement.pdf + python main.py --file contract.pdf --model gpt-3.5-turbo --skip-conflicts + python main.py --file contract.pdf --question "What are my termination rights?" 
+ python main.py --file contract.pdf --interactive +""" + +import argparse +import os +import sys + +from dotenv import load_dotenv +from rich.console import Console +from rich.panel import Panel +from rich.rule import Rule +from rich.table import Table +from rich import box + +# --------------------------------------------------------------------------- +# Load environment variables from .env file +# --------------------------------------------------------------------------- +load_dotenv() + +console = Console() + + +# --------------------------------------------------------------------------- +# Argument parsing +# --------------------------------------------------------------------------- + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Legal AI Assistant β€” contract analysis powered by LLMs + RAG", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python main.py --file contract.pdf + python main.py --file contract.pdf --model gpt-3.5-turbo --skip-risks + python main.py --file contract.pdf --question "Who owns the IP?" + python main.py --file contract.pdf --interactive + """, + ) + parser.add_argument( + "--file", + type=str, + help="Path to the contract file (PDF or DOCX). 
Required for all modes, including --interactive.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default=os.getenv("OPENAI_MODEL", "gpt-4"),
+        help="OpenAI model to use (default: $OPENAI_MODEL if set, else gpt-4).",
+    )
+    parser.add_argument(
+        "--skip-risks",
+        action="store_true",
+        help="Skip the risk analysis step.",
+    )
+    parser.add_argument(
+        "--skip-conflicts",
+        action="store_true",
+        help="Skip the conflict detection step.",
+    )
+    parser.add_argument(
+        "--question",
+        type=str,
+        default=None,
+        help="Ask a single question about the contract and exit.",
+    )
+    parser.add_argument(
+        "--interactive",
+        action="store_true",
+        help="Start an interactive Q&A loop after analysis.",
+    )
+    return parser
+
+
+# ---------------------------------------------------------------------------
+# Rich display helpers
+# ---------------------------------------------------------------------------
+
+def print_disclaimer() -> None:
+    """Print the mandatory legal disclaimer prominently."""
+    console.print(
+        Panel(
+            "⚠️ [bold yellow]DISCLAIMER[/bold yellow]\n\n"
+            "This tool is for [bold]educational purposes only[/bold]. 
" + "It does [bold red]NOT[/bold red] constitute legal advice.\n" + "Always consult a qualified attorney before making any legal or business decisions.", + title="[bold red]LEGAL NOTICE[/bold red]", + border_style="red", + padding=(1, 4), + ) + ) + + +def print_section(title: str) -> None: + console.print(Rule(f"[bold cyan]{title}[/bold cyan]", style="cyan")) + + +def print_summary(summary: dict) -> None: + from src.summarizer import format_summary_output + console.print( + Panel( + format_summary_output(summary), + title="[bold green]Executive Summary[/bold green]", + border_style="green", + padding=(1, 2), + ) + ) + + +def print_clauses(clauses: list) -> None: + if not clauses: + console.print("[dim]No clauses extracted.[/dim]") + return + + table = Table( + title="Extracted Clauses", + box=box.ROUNDED, + show_lines=True, + style="blue", + ) + table.add_column("Type", style="bold cyan", no_wrap=True, min_width=20) + table.add_column("Section", style="dim", min_width=10) + table.add_column("Plain English", style="white") + + for clause in clauses: + table.add_row( + clause.get("clause_type", "").replace("_", " ").title(), + clause.get("section_reference", "Unknown"), + clause.get("plain_english", ""), + ) + + console.print(table) + + +def print_risks(risks: list) -> None: + from src.risk_analyzer import format_risk_output + console.print( + Panel( + format_risk_output(risks), + title="[bold red]Risk Analysis[/bold red]", + border_style="red", + padding=(1, 2), + ) + ) + + +def print_conflicts(conflicts: list) -> None: + from src.conflict_detector import format_conflicts_output + console.print( + Panel( + format_conflicts_output(conflicts), + title="[bold yellow]Conflict Detection[/bold yellow]", + border_style="yellow", + padding=(1, 2), + ) + ) + + +# --------------------------------------------------------------------------- +# Interactive Q&A loop +# --------------------------------------------------------------------------- + +def run_interactive_qa(qa_chain) 
-> None: + """ + Enter a REPL-style loop so the user can ask multiple questions about + the contract without re-running the full analysis each time. + """ + console.print( + Panel( + "Type your question and press [bold]Enter[/bold].\n" + "Type [bold]'exit'[/bold] or [bold]'quit'[/bold] to stop.\n\n" + "Sample questions:\n" + " β€’ What are my termination rights?\n" + " β€’ Who owns the IP I create?\n" + " β€’ What is the liability cap?\n" + " β€’ How does auto-renewal work?\n" + " β€’ What information must I keep confidential?", + title="[bold cyan]Interactive Q&A Mode[/bold cyan]", + border_style="cyan", + padding=(1, 2), + ) + ) + + from src.qa_chain import ask_question + + while True: + try: + question = console.input("\n[bold cyan]Question >[/bold cyan] ").strip() + except (KeyboardInterrupt, EOFError): + console.print("\n[dim]Exiting Q&A mode.[/dim]") + break + + if not question: + continue + if question.lower() in ("exit", "quit", "q"): + console.print("[dim]Exiting Q&A mode.[/dim]") + break + + with console.status("[bold green]Thinking...[/bold green]"): + answer = ask_question(question, qa_chain) + + console.print( + Panel( + answer, + title="[bold green]Answer[/bold green]", + border_style="green", + padding=(1, 2), + ) + ) + + +# --------------------------------------------------------------------------- +# Main pipeline +# --------------------------------------------------------------------------- + +def main() -> None: + parser = build_parser() + args = parser.parse_args() + + # Validate API key + api_key = os.getenv("OPENAI_API_KEY") + if not api_key or api_key == "your_openai_api_key_here": + console.print( + "[bold red]ERROR:[/bold red] OPENAI_API_KEY not set. " + "Copy .env.example to .env and add your key." + ) + sys.exit(1) + + # Validate file argument + if not args.file: + console.print( + "[bold red]ERROR:[/bold red] --file is required. " + "Provide a path to a PDF or DOCX contract." 
+ ) + parser.print_help() + sys.exit(1) + + if not os.path.exists(args.file): + console.print(f"[bold red]ERROR:[/bold red] File not found: {args.file}") + sys.exit(1) + + # ── Startup banner ────────────────────────────────────────────────────── + console.print( + Panel( + "[bold white]Legal AI Assistant[/bold white]\n" + f"[dim]Contract:[/dim] {os.path.basename(args.file)}\n" + f"[dim]Model :[/dim] {args.model}", + border_style="white", + padding=(1, 4), + ) + ) + print_disclaimer() + + # ── Import modules here to keep startup fast for --help ───────────────── + from langchain_openai import ChatOpenAI + from src.document_parser import parse_legal_document + from src.indexer import index_document, get_retriever + from src.summarizer import generate_summary + from src.clause_extractor import extract_clauses + from src.risk_analyzer import analyze_risks + from src.conflict_detector import detect_conflicts + from src.qa_chain import build_qa_chain, ask_question + + # Initialise LLM + llm = ChatOpenAI( + model=args.model, + temperature=0, # deterministic output for legal analysis + openai_api_key=api_key, + ) + + # ── Step 1: Parse document ─────────────────────────────────────────────── + print_section("Step 1 β€” Parsing Document") + with console.status("[bold green]Parsing document...[/bold green]"): + doc = parse_legal_document(args.file) + + console.print( + f" βœ… Parsed [bold]{doc['file_name']}[/bold] β€” " + f"{doc['page_count']} page(s), {len(doc['sections'])} section(s) detected" + ) + + # ── Step 2: Index for RAG ──────────────────────────────────────────────── + print_section("Step 2 β€” Building Vector Index") + index_path = f"legal_index_{os.path.splitext(doc['file_name'])[0]}" + with console.status("[bold green]Indexing document...[/bold green]"): + vector_store = index_document(args.file, index_path=index_path) + retriever = get_retriever(vector_store, k=4) + console.print(f" βœ… Index built at [bold]{index_path}/[/bold]") + + # ── Step 3: Executive 
Summary ──────────────────────────────────────────── + print_section("Step 3 β€” Executive Summary") + with console.status("[bold green]Generating summary...[/bold green]"): + summary = generate_summary(doc["full_text"], llm) + print_summary(summary) + + # ── Step 4: Clause Extraction ──────────────────────────────────────────── + print_section("Step 4 β€” Key Clause Extraction") + with console.status("[bold green]Extracting clauses...[/bold green]"): + clauses = extract_clauses(doc["full_text"], llm) + console.print(f" βœ… {len(clauses)} clause(s) extracted") + print_clauses(clauses) + + # ── Step 5: Risk Analysis ──────────────────────────────────────────────── + if not args.skip_risks: + print_section("Step 5 β€” Risk Analysis") + with console.status("[bold green]Analyzing risks...[/bold green]"): + risks = analyze_risks(clauses, llm) + console.print(f" βœ… {len(risks)} risk(s) identified") + print_risks(risks) + else: + console.print("[dim]Risk analysis skipped (--skip-risks).[/dim]") + risks = [] + + # ── Step 6: Conflict Detection ─────────────────────────────────────────── + if not args.skip_conflicts: + print_section("Step 6 β€” Conflict Detection") + with console.status("[bold green]Detecting conflicts...[/bold green]"): + conflicts = detect_conflicts(clauses, llm) + console.print(f" βœ… {len(conflicts)} potential conflict(s) found") + print_conflicts(conflicts) + else: + console.print("[dim]Conflict detection skipped (--skip-conflicts).[/dim]") + + # ── Step 7: Q&A ────────────────────────────────────────────────────────── + qa_chain = build_qa_chain(retriever, llm) + + if args.question: + # Single question mode β€” answer and exit + print_section("Q&A β€” Single Question") + with console.status("[bold green]Thinking...[/bold green]"): + answer = ask_question(args.question, qa_chain) + console.print(f"\n[bold cyan]Q:[/bold cyan] {args.question}") + console.print( + Panel( + answer, + title="[bold green]Answer[/bold green]", + border_style="green", + 
padding=(1, 2), + ) + ) + + elif args.interactive: + print_section("Step 7 β€” Interactive Q&A") + run_interactive_qa(qa_chain) + + else: + console.print( + "\n[dim]Tip: run with [bold]--interactive[/bold] to ask follow-up questions, " + "or [bold]--question \"...[/bold]\" for a single query.[/dim]" + ) + + console.print( + Panel( + "βœ… Analysis complete.\n\n" + "[bold yellow]Reminder:[/bold yellow] Always verify findings with a qualified attorney.", + border_style="green", + padding=(1, 2), + ) + ) + + +if __name__ == "__main__": + main() diff --git a/02-legal-ai-assistant/prompts/clause_prompt.txt b/02-legal-ai-assistant/prompts/clause_prompt.txt new file mode 100644 index 0000000..a808ac6 --- /dev/null +++ b/02-legal-ai-assistant/prompts/clause_prompt.txt @@ -0,0 +1,24 @@ +You are a legal analyst specializing in contract review. Extract specific clause types from the following contract text. + +Contract text: +{contract_text} + +Identify and extract the following clause types if present: +- indemnification: Who must protect whom from losses +- limitation_of_liability: Caps on damages one party can claim +- termination: Conditions under which the contract can be ended +- governing_law: Which jurisdiction's laws govern the contract +- ip_ownership: Who owns intellectual property created under the contract +- confidentiality: Obligations to keep information secret + +For each clause found, respond with a JSON array: +[ + {{ + "clause_type": "indemnification", + "original_text": "The exact text from the contract", + "plain_english": "What this means in simple terms", + "section_reference": "Section number if identifiable, else 'Unknown'" + }} +] + +If a clause type is not found, omit it from the array. diff --git a/02-legal-ai-assistant/prompts/risk_prompt.txt b/02-legal-ai-assistant/prompts/risk_prompt.txt new file mode 100644 index 0000000..d382cd7 --- /dev/null +++ b/02-legal-ai-assistant/prompts/risk_prompt.txt @@ -0,0 +1,24 @@ +You are a legal risk analyst. 
Review the following contract clauses and identify risks. + +Contract clauses: +{clauses_text} + +For each risky clause, provide a risk assessment. Common risk patterns to look for: +- Unlimited liability (one party bears all risk with no cap) +- One-sided termination rights (only one party can terminate) +- Vague language (ambiguous terms that could be interpreted broadly) +- Auto-renewal traps (contracts that automatically renew without notice) +- Broad IP assignment (company owns everything you create, even on personal time) +- Non-compete clauses (restrictions on future employment) +- Unilateral modification rights (one party can change terms without consent) + +Respond with a JSON array: +[ + {{ + "clause_summary": "Brief description of the clause", + "risk_level": "HIGH | MEDIUM | LOW", + "risk_type": "Type of risk (e.g., unlimited_liability, one_sided_termination)", + "explanation": "Why this is risky and what a fair version would look like", + "original_text_excerpt": "The specific text that raises the concern" + }} +] diff --git a/02-legal-ai-assistant/prompts/summary_prompt.txt b/02-legal-ai-assistant/prompts/summary_prompt.txt new file mode 100644 index 0000000..4dfd70c --- /dev/null +++ b/02-legal-ai-assistant/prompts/summary_prompt.txt @@ -0,0 +1,14 @@ +You are a legal analyst. Read the following contract text and produce a structured executive summary. 
+ +Contract text: +{contract_text} + +Respond ONLY with a JSON object in this exact format: +{{ + "parties": ["Party 1 name and role", "Party 2 name and role"], + "contract_type": "Type of contract (e.g., NDA, Service Agreement, Employment Contract)", + "effective_date": "Date the contract takes effect, or 'Not specified'", + "duration": "Contract duration or expiration, or 'Not specified'", + "key_obligations": ["Obligation 1", "Obligation 2", "Obligation 3"], + "summary": "Plain-English summary of the contract in 2-3 sentences, max 100 words" +}} diff --git a/02-legal-ai-assistant/requirements.txt b/02-legal-ai-assistant/requirements.txt new file mode 100644 index 0000000..6f52047 --- /dev/null +++ b/02-legal-ai-assistant/requirements.txt @@ -0,0 +1,11 @@ +langchain==0.1.20 +langchain-community==0.0.38 +langchain-openai==0.1.6 +faiss-cpu==1.8.0 +sentence-transformers==2.7.0 +pypdf==4.2.0 +python-docx==1.1.2 +openai==1.30.1 +python-dotenv==1.0.1 +pydantic==2.7.1 +rich==13.7.1 diff --git a/02-legal-ai-assistant/src/__init__.py b/02-legal-ai-assistant/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/02-legal-ai-assistant/src/clause_extractor.py b/02-legal-ai-assistant/src/clause_extractor.py new file mode 100644 index 0000000..d1a120e --- /dev/null +++ b/02-legal-ai-assistant/src/clause_extractor.py @@ -0,0 +1,150 @@ +""" +clause_extractor.py β€” Legal Clause Extraction + +Identifies and extracts specific, named clause types from a contract and +translates them into plain English. This is the "translation layer" between +dense legal prose and actionable information. + +Clause type reference: + indemnification β€” Party A must pay for losses caused to Party B. + e.g. "Vendor shall indemnify Client against all claims + arising from Vendor's performance of the Services." + β†’ "The vendor must cover any lawsuits or losses the + client suffers because of the vendor's work." + + limitation_of_liability β€” Maximum damages one party can claim. 
+ e.g. "In no event shall either party's liability + exceed the fees paid in the prior 3 months." + β†’ "Neither side can sue for more than 3 months of + contract payments." + + termination β€” How and when the contract can be ended early. + e.g. "Either party may terminate with 30 days notice." + β†’ "Either side can cancel with a month's warning." + + governing_law β€” Which state/country's courts have jurisdiction. + e.g. "This Agreement shall be governed by the laws + of the State of New York." + β†’ "Disputes go to New York courts." + + ip_ownership β€” Who owns code, inventions, or designs created during + the contract. + e.g. "All work product created by Contractor shall + be deemed works made for hire owned by Company." + β†’ "Everything you build belongs to the company." + + confidentiality β€” What information must be kept secret and for how long. + e.g. "Each party agrees to keep Confidential + Information secret for 5 years after termination." + β†’ "Both sides must keep secrets for 5 years after + the contract ends." +""" + +import json +import re +from pathlib import Path + +from langchain.schema import HumanMessage + + +# --------------------------------------------------------------------------- +# Prompt loading +# --------------------------------------------------------------------------- + +def _load_clause_prompt() -> str: + """Load the clause extraction prompt from prompts/clause_prompt.txt.""" + prompt_path = Path(__file__).parent.parent / "prompts" / "clause_prompt.txt" + with open(prompt_path, "r", encoding="utf-8") as f: + return f.read() + + +# --------------------------------------------------------------------------- +# Core function +# --------------------------------------------------------------------------- + +def extract_clauses(contract_text: str, llm) -> list[dict]: + """ + Extract named clause types from a contract and provide plain-English + translations. 
+ + Parameters + ---------- + contract_text : str β€” full or truncated contract text + llm : LLM β€” any LangChain-compatible chat model + + Returns + ------- + list of dicts, each with: + clause_type : str β€” one of the six named types above + original_text : str β€” verbatim text from the contract + plain_english : str β€” plain-language explanation + section_reference: str β€” section number or "Unknown" + + Returns an empty list if extraction fails or no clauses are found. + """ + prompt_template = _load_clause_prompt() + + # Use up to 12 000 chars β€” clause extraction needs more context than summary + # because clauses may appear anywhere across a long document. + truncated_text = contract_text[:12000] + if len(contract_text) > 12000: + truncated_text += "\n\n[... document truncated ...]" + + prompt = prompt_template.format(contract_text=truncated_text) + response = llm.invoke([HumanMessage(content=prompt)]) + raw_content = response.content if hasattr(response, "content") else str(response) + + # Strip markdown code fences if present + clean = raw_content.strip() + if clean.startswith("```"): + clean = re.sub(r"^```(?:json)?\s*", "", clean, flags=re.MULTILINE) + clean = re.sub(r"```\s*$", "", clean, flags=re.MULTILINE) + clean = clean.strip() + + try: + clauses = json.loads(clean) + # Ensure we always return a list + if isinstance(clauses, dict): + clauses = [clauses] + return clauses + except json.JSONDecodeError: + # Return empty list rather than crashing β€” downstream code handles this + print(f"[ClauseExtractor] WARNING: Could not parse LLM response as JSON.\n{raw_content[:300]}") + return [] + + +# --------------------------------------------------------------------------- +# Display formatting +# --------------------------------------------------------------------------- + +def format_clauses_output(clauses: list) -> str: + """ + Format extracted clauses as a human-readable string for terminal output. 
+ + Parameters + ---------- + clauses : list β€” result from extract_clauses() + + Returns + ------- + str β€” multi-line formatted clause list + """ + if not clauses: + return "No clauses extracted (or extraction failed)." + + lines = [] + for i, clause in enumerate(clauses, start=1): + clause_type = clause.get("clause_type", "unknown").replace("_", " ").title() + section = clause.get("section_reference", "Unknown") + plain = clause.get("plain_english", "") + original = clause.get("original_text", "") + + lines.append(f"[{i}] {clause_type} (Section: {section})") + lines.append(f" Plain English: {plain}") + # Truncate long original text for display purposes + if len(original) > 200: + original = original[:200] + "..." + lines.append(f" Original Text: {original}") + lines.append("") # blank line between clauses + + return "\n".join(lines) diff --git a/02-legal-ai-assistant/src/conflict_detector.py b/02-legal-ai-assistant/src/conflict_detector.py new file mode 100644 index 0000000..97d15fb --- /dev/null +++ b/02-legal-ai-assistant/src/conflict_detector.py @@ -0,0 +1,174 @@ +""" +conflict_detector.py β€” Contract Clause Conflict Detection + +Compares extracted clauses against each other to surface internal +contradictions β€” places where one part of the contract conflicts with +another part. These inconsistencies are a common source of disputes. + +⚠️ IMPORTANT DISCLAIMER ⚠️ +──────────────────────────────────────────────────────────────────────────── +LLM-based conflict detection is NOT 100% reliable. The model may: + β€’ Miss conflicts that require deep legal domain expertise to spot. + β€’ Flag "conflicts" that are actually intentional or legally complementary. + β€’ Fail on highly technical or jurisdiction-specific language. + +This tool helps you KNOW WHAT TO LOOK FOR and directs your attention to +potentially problematic areas. It is NOT a substitute for a qualified +attorney's review. 
Always verify flagged conflicts with a licensed lawyer +before making any legal or business decisions. +──────────────────────────────────────────────────────────────────────────── + +Common conflict patterns this module targets: + + 1. NOTICE PERIOD MISMATCH + Termination clause: "30 days written notice required." + Payment clause: "Invoices are due 60 days after notice of termination." + β†’ You're legally required to pay for 60 days but can only terminate in 30. + + 2. CONFIDENTIALITY vs DEFINITION CONFLICT + Definition section: "Confidential Information means only written materials + marked CONFIDENTIAL." + Confidentiality clause: "All information disclosed, including oral + communications, is confidential." + β†’ The definition is narrower than what the clause protects. + + 3. TERMINATION vs AUTO-RENEWAL + Termination clause: "Either party may terminate on 30 days notice." + Auto-renewal clause: "This Agreement auto-renews annually unless notice + is given 90 days before expiry." + β†’ You need 90 days notice for auto-renewal but only 30 for termination β€” + which governs if the contract expires and auto-renews in 35 days? + + 4. IP OWNERSHIP vs CONFIDENTIALITY + IP clause: "All work product is owned by Company and may be used freely." + Confidentiality clause: "All work product is Confidential Information + and must not be disclosed." + β†’ Company claims ownership AND confidentiality β€” can they publish your work? +""" + +import json +import re + +from langchain.schema import HumanMessage + + +# --------------------------------------------------------------------------- +# LLM prompt (inline β€” short enough not to warrant a separate .txt file) +# --------------------------------------------------------------------------- + +_CONFLICT_PROMPT = """ +You are a legal contract analyst. Review the following extracted contract clauses +and identify any internal conflicts or contradictions between them. 
+ +Extracted clauses: +{clauses_json} + +Look for conflicts such as: +- Different notice periods for the same event mentioned in two different clauses +- A definition that contradicts how a term is used elsewhere +- A termination clause that conflicts with an auto-renewal clause +- An IP ownership clause that contradicts a confidentiality clause +- Different liability caps stated in different sections +- Inconsistent governing law references + +Respond with a JSON array. If no conflicts are found, return an empty array []. +[ + {{ + "conflict_type": "Short name for the type of conflict (e.g. notice_period_mismatch)", + "clause_a": "Description or quote from the first clause", + "clause_b": "Description or quote from the conflicting clause", + "description": "Plain-English explanation of why these clauses conflict and the practical impact" + }} +] +""".strip() + + +# --------------------------------------------------------------------------- +# Core function +# --------------------------------------------------------------------------- + +def detect_conflicts(clauses: list[dict], llm) -> list[dict]: + """ + Use an LLM to compare extracted clauses for internal contradictions. + + Parameters + ---------- + clauses : list[dict] β€” output from clause_extractor.extract_clauses() + llm : LLM β€” any LangChain-compatible chat model + + Returns + ------- + list of dicts, each with: + conflict_type : str β€” short category label + clause_a : str β€” description/quote from first clause + clause_b : str β€” description/quote from conflicting clause + description : str β€” plain-English explanation of the conflict + + Returns an empty list if no conflicts are found or detection fails. + + ⚠️ See module docstring for reliability limitations. 
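+
+    Example
+    -------
+    A minimal sketch of the intended call pattern (``llm`` is any
+    LangChain-compatible chat model, e.g. the ChatOpenAI instance built in
+    main.py, and ``clauses`` is the output of
+    clause_extractor.extract_clauses):
+
+        conflicts = detect_conflicts(clauses, llm)
+        for c in conflicts:
+            print(f"[{c['conflict_type']}] {c['description']}")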
+ """ + if not clauses: + return [] + + clauses_json = json.dumps(clauses, indent=2) + prompt = _CONFLICT_PROMPT.format(clauses_json=clauses_json) + + response = llm.invoke([HumanMessage(content=prompt)]) + raw_content = response.content if hasattr(response, "content") else str(response) + + # Strip markdown code fences + clean = raw_content.strip() + if clean.startswith("```"): + clean = re.sub(r"^```(?:json)?\s*", "", clean, flags=re.MULTILINE) + clean = re.sub(r"```\s*$", "", clean, flags=re.MULTILINE) + clean = clean.strip() + + try: + conflicts = json.loads(clean) + if isinstance(conflicts, dict): + conflicts = [conflicts] + return conflicts + except json.JSONDecodeError: + print(f"[ConflictDetector] WARNING: Could not parse LLM response as JSON.\n{raw_content[:300]}") + return [] + + +# --------------------------------------------------------------------------- +# Display formatting +# --------------------------------------------------------------------------- + +def format_conflicts_output(conflicts: list) -> str: + """ + Format detected conflicts as a human-readable terminal string. + + Parameters + ---------- + conflicts : list β€” result from detect_conflicts() + + Returns + ------- + str β€” formatted conflict report + """ + if not conflicts: + return "βšͺ No internal conflicts detected." 
+ + lines = [ + "⚠️ The following potential conflicts were detected.", + " Verify each finding with a qualified attorney before acting on it.", + "", + ] + + for i, conflict in enumerate(conflicts, start=1): + conflict_type = conflict.get("conflict_type", "unknown").replace("_", " ").title() + clause_a = conflict.get("clause_a", "") + clause_b = conflict.get("clause_b", "") + description = conflict.get("description", "") + + lines.append(f"⚑ [{i}] {conflict_type}") + lines.append(f" Clause A : {clause_a}") + lines.append(f" Clause B : {clause_b}") + lines.append(f" Impact : {description}") + lines.append("") + + return "\n".join(lines) diff --git a/02-legal-ai-assistant/src/document_parser.py b/02-legal-ai-assistant/src/document_parser.py new file mode 100644 index 0000000..5ead578 --- /dev/null +++ b/02-legal-ai-assistant/src/document_parser.py @@ -0,0 +1,216 @@ +""" +document_parser.py β€” Legal Document Parser + +Handles loading and parsing of PDF and DOCX contract files into structured text. + +Why section structure matters for legal documents: + Legal contracts are highly cross-referential. A clause in Section 8 may say + "subject to Section 4.2", so knowing WHICH section a piece of text belongs to + is critical for accurate clause extraction and conflict detection. If we just + dumped all text together we'd lose those structural anchors. + +Limitations of PDF text extraction: + PyPDF (and most PDF parsers) extract raw text by reading the PDF's character + stream. This means: + - Tables lose their column alignment and become garbled rows of text. + - Headers/footers repeat on every page, creating noise. + - Some scanned PDFs produce no text at all (OCR required separately). + - Footnotes sometimes appear inline mid-sentence rather than at the bottom. + These limitations mean downstream analysis must tolerate imperfect text. 
+ +Why legal document structure is important for accurate clause extraction: + LLMs asked to find "the termination clause" perform significantly better when + the prompt includes section headings because headings act as semantic anchors. + Without them, the model may conflate a termination clause buried in an exhibit + with the main termination provisions, leading to inaccurate summaries. +""" + +import os +import re +from typing import Optional + +# PyPDFLoader uses pypdf under the hood β€” handles multi-page PDFs gracefully +from langchain_community.document_loaders import PyPDFLoader + +# python-docx for Microsoft Word (.docx) files +from docx import Document as DocxDocument + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _detect_heading(line: str) -> bool: + """ + Heuristic: a line is treated as a section heading if it matches any of: + 1. Numbered section β€” "1.", "1.1", "2.3.4", etc. + 2. ALL-CAPS line β€” "INDEMNIFICATION", "LIMITATION OF LIABILITY" + 3. Trailing colon β€” "Governing Law:", "Notice:" + These patterns cover the vast majority of standard contract heading styles. + """ + line = line.strip() + if not line: + return False + # Pattern 1: numbered section (e.g. "1.", "2.1", "10.3.2") + if re.match(r"^\d+(\.\d+)*\.?\s+\S", line): + return True + # Pattern 2: all-caps (ignoring punctuation/spaces, at least 3 chars of alpha) + alpha_only = re.sub(r"[^A-Za-z]", "", line) + if len(alpha_only) >= 3 and alpha_only == alpha_only.upper(): + return True + # Pattern 3: ends with colon + if line.endswith(":"): + return True + return False + + +def _split_into_sections(full_text: str) -> list[dict]: + """ + Walk through lines of text and group them into sections based on headings. + Returns a list of section dicts. 
Each dict has: + heading β€” the heading text (or "Preamble" for leading content) + content β€” the body text under that heading + page_num β€” approximate page number (estimated by form-feed character '\x0c') + """ + sections = [] + current_heading = "Preamble" + current_lines: list[str] = [] + page_num = 1 + + for line in full_text.splitlines(): + # pypdf uses form-feed (\x0c) as a page separator + if "\x0c" in line: + page_num += line.count("\x0c") + line = line.replace("\x0c", "") + + if _detect_heading(line): + # Save the previous section before starting a new one + if current_lines: + sections.append({ + "heading": current_heading, + "content": "\n".join(current_lines).strip(), + "page_num": page_num, + }) + current_heading = line.strip() + current_lines = [] + else: + current_lines.append(line) + + # Don't forget the last section + if current_lines: + sections.append({ + "heading": current_heading, + "content": "\n".join(current_lines).strip(), + "page_num": page_num, + }) + + return sections + + +# --------------------------------------------------------------------------- +# PDF parsing +# --------------------------------------------------------------------------- + +def _parse_pdf(file_path: str) -> dict: + """ + Load a PDF using LangChain's PyPDFLoader (backed by pypdf). + Each LangChain Document corresponds to one PDF page. + We concatenate all pages and then split by detected headings. 
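+
+    Note: _split_into_sections estimates page numbers by counting form-feed
+    ('\x0c') characters, but joining pages with "\n" (as done below) never
+    produces that separator, so every section reports page 1. If accurate
+    per-section page numbers matter, a sketch of a fix is to join pages with
+    the form feed instead:
+
+        full_text = "\x0c".join(page.page_content for page in pages)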
+ """ + loader = PyPDFLoader(file_path) + pages = loader.load() # list of langchain Document objects, one per page + + full_text = "\n".join(page.page_content for page in pages) + page_count = len(pages) + + sections = _split_into_sections(full_text) + + return { + "full_text": full_text, + "sections": sections, + "file_name": os.path.basename(file_path), + "page_count": page_count, + } + + +# --------------------------------------------------------------------------- +# DOCX parsing +# --------------------------------------------------------------------------- + +def _parse_docx(file_path: str) -> dict: + """ + Load a DOCX file using python-docx. + Word documents store paragraphs with an explicit style name; we use + "Heading" styles as section delimiters when present, falling back to + the same heuristic used for PDFs. + """ + doc = DocxDocument(file_path) + + full_lines: list[str] = [] + for para in doc.paragraphs: + full_lines.append(para.text) + + full_text = "\n".join(full_lines) + + # Approximate page count: Word doesn't expose pages easily via python-docx; + # we use a rough heuristic (every ~40 paragraphs β‰ˆ 1 page). + page_count = max(1, len(doc.paragraphs) // 40) + + sections = _split_into_sections(full_text) + + return { + "full_text": full_text, + "sections": sections, + "file_name": os.path.basename(file_path), + "page_count": page_count, + } + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def parse_legal_document(file_path: str) -> dict: + """ + Parse a legal document (PDF or DOCX) into structured text. + + Parameters + ---------- + file_path : str + Absolute or relative path to the contract file. 
+ + Returns + ------- + dict with keys: + full_text : str β€” complete raw text of the document + sections : list[dict] β€” list of {heading, content, page_num} dicts + file_name : str β€” basename of the file + page_count : int β€” number of pages (PDF) or estimated pages (DOCX) + + Raises + ------ + ValueError if the file type is not supported. + FileNotFoundError if the file does not exist. + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"Contract file not found: {file_path}") + + ext = os.path.splitext(file_path)[1].lower() + + if ext == ".pdf": + return _parse_pdf(file_path) + elif ext in (".docx", ".doc"): + return _parse_docx(file_path) + else: + raise ValueError( + f"Unsupported file type '{ext}'. Supported types: .pdf, .docx" + ) + + +def extract_full_text(file_path: str) -> str: + """ + Convenience wrapper: parse a document and return only the full text string. + Useful when callers don't need the structured section breakdown. + """ + result = parse_legal_document(file_path) + return result["full_text"] diff --git a/02-legal-ai-assistant/src/indexer.py b/02-legal-ai-assistant/src/indexer.py new file mode 100644 index 0000000..1172b33 --- /dev/null +++ b/02-legal-ai-assistant/src/indexer.py @@ -0,0 +1,158 @@ +""" +indexer.py β€” Vector-Store Indexing for Legal Documents + +This module is intentionally similar to the RAG indexer from Project 1 +(01-rag-from-scratch). The same pattern β€” chunk β†’ embed β†’ store in FAISS β€” +works equally well for legal documents. The only difference is that legal +chunks benefit from slightly larger sizes because legal sentences are long +and context-dependent (a 512-token chunk mid-clause may miss the crucial +subject from the sentence above). + +Reuse note: + If you already built a FAISS index in Project 1, the load_index() and + get_retriever() helpers here are identical. The legal domain just requires + a different source document and potentially different chunk sizes. 
+""" + +import os + +# HuggingFaceEmbeddings runs locally β€” no API key needed for embedding. +# We default to "all-MiniLM-L6-v2" which is fast and good for semantic search. +from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_community.vectorstores import FAISS + +# RecursiveCharacterTextSplitter tries to split on paragraphs, then sentences, +# then words β€” preserving as much semantic context as possible per chunk. +from langchain.text_splitter import RecursiveCharacterTextSplitter + +from langchain_community.document_loaders import PyPDFLoader +from langchain.schema import Document + + +# --------------------------------------------------------------------------- +# Chunking configuration +# --------------------------------------------------------------------------- + +# Legal sentences are verbose; 1200-char chunks with 200-char overlap keeps +# clauses intact while still providing sufficient retrieval granularity. +CHUNK_SIZE = 1200 +CHUNK_OVERLAP = 200 + +# Embedding model β€” same as Project 1, works well for legal text +EMBEDDING_MODEL = "all-MiniLM-L6-v2" + + +def _get_embeddings() -> HuggingFaceEmbeddings: + """Return a HuggingFaceEmbeddings instance (downloaded on first call).""" + return HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL) + + +def _chunk_text(full_text: str) -> list[Document]: + """ + Split raw contract text into overlapping chunks suitable for embedding. + + The RecursiveCharacterTextSplitter cascades through separators: + ["\n\n", "\n", " ", ""] β€” so it prefers to break at paragraph boundaries, + then line breaks, then spaces. This keeps sentences from being split in + the middle of a legal obligation where the subject is at the start and + the verb is at the end. + """ + splitter = RecursiveCharacterTextSplitter( + chunk_size=CHUNK_SIZE, + chunk_overlap=CHUNK_OVERLAP, + separators=["\n\n", "\n", ". 
", " ", ""], + ) + # Wrap in a single Document so the splitter returns Document objects + docs = splitter.create_documents([full_text]) + return docs + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def index_document(file_path: str, index_path: str = "legal_faiss_index") -> FAISS: + """ + Parse, chunk, embed, and persist a contract document to a FAISS index. + + Parameters + ---------- + file_path : str β€” path to the PDF/DOCX contract + index_path : str β€” directory where the FAISS index will be saved + + Returns + ------- + FAISS vector store ready for similarity search. + """ + # --- Step 1: Load raw text --- + # Use PyPDFLoader for PDFs; for DOCX we read via document_parser then wrap + ext = os.path.splitext(file_path)[1].lower() + if ext == ".pdf": + loader = PyPDFLoader(file_path) + raw_docs = loader.load() + full_text = "\n".join(d.page_content for d in raw_docs) + else: + # For non-PDF files, fall back to document_parser's full_text + from src.document_parser import extract_full_text + full_text = extract_full_text(file_path) + + # --- Step 2: Chunk --- + chunks = _chunk_text(full_text) + print(f"[Indexer] Created {len(chunks)} chunks from '{os.path.basename(file_path)}'") + + # --- Step 3: Embed & build FAISS index --- + embeddings = _get_embeddings() + vector_store = FAISS.from_documents(chunks, embeddings) + + # --- Step 4: Persist to disk --- + os.makedirs(index_path, exist_ok=True) + vector_store.save_local(index_path) + print(f"[Indexer] Index saved to '{index_path}'") + + return vector_store + + +def load_index(index_path: str = "legal_faiss_index") -> FAISS: + """ + Load a previously saved FAISS index from disk. + + Parameters + ---------- + index_path : str β€” directory containing the saved FAISS index files + + Returns + ------- + FAISS vector store ready for similarity search. 
+ """ + if not os.path.exists(index_path): + raise FileNotFoundError( + f"No FAISS index found at '{index_path}'. " + "Run index_document() first to create the index." + ) + embeddings = _get_embeddings() + vector_store = FAISS.load_local( + index_path, + embeddings, + allow_dangerous_deserialization=True, # required by newer LangChain versions + ) + print(f"[Indexer] Loaded index from '{index_path}'") + return vector_store + + +def get_retriever(vector_store: FAISS, k: int = 4): + """ + Wrap a FAISS vector store as a LangChain retriever. + + Parameters + ---------- + vector_store : FAISS β€” the in-memory or loaded vector store + k : int β€” number of chunks to retrieve per query (default 4) + + Returns + ------- + A LangChain BaseRetriever that can be plugged into any chain. + + Note: k=4 is a good balance for legal Q&A β€” enough context to answer most + clause-level questions without exceeding typical context-window limits. + """ + return vector_store.as_retriever(search_kwargs={"k": k}) diff --git a/02-legal-ai-assistant/src/qa_chain.py b/02-legal-ai-assistant/src/qa_chain.py new file mode 100644 index 0000000..1c8cdb6 --- /dev/null +++ b/02-legal-ai-assistant/src/qa_chain.py @@ -0,0 +1,121 @@ +""" +qa_chain.py β€” Retrieval-Augmented Q&A for Legal Documents + +Builds a RAG Q&A chain that answers questions about a contract by: + 1. Retrieving the most relevant chunks from the FAISS index + 2. Sending those chunks + the user's question to the LLM + 3. Requiring the model to cite the specific section it's referencing + +Why source citation is critical in legal Q&A: + Unlike a general knowledge chatbot, a legal assistant's answers directly + influence decisions with real financial and legal consequences. If a user + asks "Can I terminate in 30 days?" and the model answers "Yes" without + citing the source clause, the user cannot verify whether that answer is + based on the actual contract or a hallucination. 
Forcing the model to + cite sections: + β€’ Lets the user cross-check against the original document. + β€’ Makes hallucinations easier to spot (the cited section won't exist). + β€’ Builds appropriate trust β€” the user knows WHAT to verify, not just + WHETHER to trust. + + This is fundamentally different from Q&A over, say, a technical manual, + where a wrong answer is inconvenient. A wrong answer about a contract + clause can result in a breach of contract, lawsuit, or financial loss. +""" + +from langchain.chains import RetrievalQA +from langchain.prompts import PromptTemplate + + +# --------------------------------------------------------------------------- +# Custom legal Q&A prompt +# --------------------------------------------------------------------------- + +# The prompt explicitly: +# 1. Restricts answers to provided context (reduces hallucination) +# 2. Requires section citations (enables user verification) +# 3. Provides a safe fallback for out-of-scope questions +# 4. Includes the "not legal advice" disclaimer in every answer +_LEGAL_QA_TEMPLATE = """You are a legal document assistant. Answer questions about the contract based ONLY on the provided context. +Always cite the specific section or clause you are referencing. +If the answer is not in the provided context, say "This information is not found in the provided contract." + +Important: This is for informational purposes only, not legal advice. + +Context from contract: +{context} + +Question: {question} + +Answer (include section references):""" + +_QA_PROMPT = PromptTemplate( + template=_LEGAL_QA_TEMPLATE, + input_variables=["context", "question"], +) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def build_qa_chain(retriever, llm) -> RetrievalQA: + """ + Construct a RetrievalQA chain grounded in the indexed contract. 
+ + Parameters + ---------- + retriever : BaseRetriever β€” from indexer.get_retriever() + llm : LLM β€” any LangChain-compatible chat model + + Returns + ------- + RetrievalQA chain ready to accept questions via .invoke() or .run() + + The chain uses "stuff" document combination strategy β€” it concatenates + retrieved chunks into a single context block. For very long contracts + "map_reduce" or "refine" strategies may be preferable, but "stuff" is + the most reliable for faithfully citing specific text. + """ + qa_chain = RetrievalQA.from_chain_type( + llm=llm, + chain_type="stuff", + retriever=retriever, + return_source_documents=True, # lets callers show which chunks were used + chain_type_kwargs={"prompt": _QA_PROMPT}, + ) + return qa_chain + + +def ask_question(question: str, qa_chain) -> str: + """ + Ask a natural-language question about the indexed contract. + + Parameters + ---------- + question : str β€” the user's question (e.g. "What are my termination rights?") + qa_chain : RetrievalQA β€” built by build_qa_chain() + + Returns + ------- + str β€” the model's answer with section citations + + The returned string always includes the LLM's answer. If source documents + were returned they are appended as a "Sources" footer so users can quickly + locate the referenced passage in the original document. + """ + result = qa_chain.invoke({"query": question}) + + answer = result.get("result", "No answer returned.") + + # Append source chunk references if available β€” helps users locate the + # exact passage that was used to generate the answer. + source_docs = result.get("source_documents", []) + if source_docs: + answer += "\n\n─── Sources (retrieved chunks) ───" + for i, doc in enumerate(source_docs, start=1): + # Show the first 150 chars of each source chunk as a reference hint + snippet = doc.page_content[:150].replace("\n", " ").strip() + answer += f"\n [{i}] ...{snippet}..." 
+ + return answer diff --git a/02-legal-ai-assistant/src/risk_analyzer.py b/02-legal-ai-assistant/src/risk_analyzer.py new file mode 100644 index 0000000..2c2f0b2 --- /dev/null +++ b/02-legal-ai-assistant/src/risk_analyzer.py @@ -0,0 +1,161 @@ +""" +risk_analyzer.py β€” Contract Risk Analysis + +Scores extracted clauses for potential risks and explains WHY each clause +is risky and what a fair alternative would look like. + +Example of HIGH-RISK vs STANDARD clause language: + + HIGH RISK (indemnification): + "Employee agrees to indemnify, defend, and hold harmless Company and + its officers, directors, and employees from any and all claims, + losses, or damages, including those arising from Company's own + negligence or intentional misconduct." + β†’ This is dangerous: the employee bears the cost of the company's + OWN mistakes. + + STANDARD (indemnification): + "Each party shall indemnify and hold harmless the other party for + losses arising directly from that party's own negligence or + willful misconduct." + β†’ Fair: each side is responsible only for their own actions. + + HIGH RISK (IP ownership): + "Employee hereby assigns to Company all inventions, discoveries, + and works of authorship conceived or reduced to practice at any + time during employment, whether or not related to Company's business + and whether or not made during working hours." + β†’ "At any time" + "whether or not related" = company owns your + weekend side projects. + + STANDARD (IP ownership): + "Employee assigns to Company inventions that relate to Company's + business or are developed using Company resources or during + working hours." + β†’ Scoped to actual work-related output. + +Risk levels: + HIGH πŸ”΄ β€” Potential for significant financial or legal harm; seek + attorney review before signing. + MEDIUM 🟑 β€” Unusual or one-sided term; negotiate if possible. + LOW 🟒 β€” Minor concern; worth noting but unlikely to cause harm. 
+""" + +import json +import re +from pathlib import Path + +from langchain.schema import HumanMessage + + +# --------------------------------------------------------------------------- +# Prompt loading +# --------------------------------------------------------------------------- + +def _load_risk_prompt() -> str: + """Load the risk analysis prompt from prompts/risk_prompt.txt.""" + prompt_path = Path(__file__).parent.parent / "prompts" / "risk_prompt.txt" + with open(prompt_path, "r", encoding="utf-8") as f: + return f.read() + + +# --------------------------------------------------------------------------- +# Core function +# --------------------------------------------------------------------------- + +def analyze_risks(clauses: list[dict], llm) -> list[dict]: + """ + Analyze extracted clauses for legal and financial risks. + + Parameters + ---------- + clauses : list[dict] β€” output from clause_extractor.extract_clauses() + llm : LLM β€” any LangChain-compatible chat model + + Returns + ------- + list of dicts, each with: + clause_summary : str β€” brief description of the risky clause + risk_level : str β€” "HIGH", "MEDIUM", or "LOW" + risk_type : str β€” category (e.g. "unlimited_liability") + explanation : str β€” why it's risky + what fair looks like + original_text_excerpt: str β€” the specific concerning text + + Returns an empty list if analysis fails or no risks are found. 
+ """ + if not clauses: + return [] + + prompt_template = _load_risk_prompt() + + # Serialize clauses to a readable text block for the prompt + clauses_text = json.dumps(clauses, indent=2) + + prompt = prompt_template.format(clauses_text=clauses_text) + response = llm.invoke([HumanMessage(content=prompt)]) + raw_content = response.content if hasattr(response, "content") else str(response) + + # Strip markdown fences + clean = raw_content.strip() + if clean.startswith("```"): + clean = re.sub(r"^```(?:json)?\s*", "", clean, flags=re.MULTILINE) + clean = re.sub(r"```\s*$", "", clean, flags=re.MULTILINE) + clean = clean.strip() + + try: + risks = json.loads(clean) + if isinstance(risks, dict): + risks = [risks] + return risks + except json.JSONDecodeError: + print(f"[RiskAnalyzer] WARNING: Could not parse LLM response as JSON.\n{raw_content[:300]}") + return [] + + +# --------------------------------------------------------------------------- +# Display formatting +# --------------------------------------------------------------------------- + +# Emoji indicators for risk levels β€” visible at a glance in terminal output +_RISK_EMOJI = { + "HIGH": "πŸ”΄", + "MEDIUM": "🟑", + "LOW": "🟒", +} + + +def format_risk_output(risks: list) -> str: + """ + Format risk analysis results as a human-readable terminal string. + + Parameters + ---------- + risks : list β€” result from analyze_risks() + + Returns + ------- + str β€” multi-line formatted risk report with emoji indicators + """ + if not risks: + return "No risks identified (or risk analysis failed)." 
+ + lines = [] + for i, risk in enumerate(risks, start=1): + level = risk.get("risk_level", "UNKNOWN").upper() + emoji = _RISK_EMOJI.get(level, "βšͺ") + risk_type = risk.get("risk_type", "unknown").replace("_", " ").title() + summary = risk.get("clause_summary", "") + explanation = risk.get("explanation", "") + excerpt = risk.get("original_text_excerpt", "") + + lines.append(f"{emoji} [{i}] {level} RISK β€” {risk_type}") + lines.append(f" Clause : {summary}") + lines.append(f" Why : {explanation}") + if excerpt: + # Truncate long excerpts for readability + if len(excerpt) > 200: + excerpt = excerpt[:200] + "..." + lines.append(f" Text : \"{excerpt}\"") + lines.append("") + + return "\n".join(lines) diff --git a/02-legal-ai-assistant/src/summarizer.py b/02-legal-ai-assistant/src/summarizer.py new file mode 100644 index 0000000..57b1343 --- /dev/null +++ b/02-legal-ai-assistant/src/summarizer.py @@ -0,0 +1,146 @@ +""" +summarizer.py β€” Contract Executive Summary Generator + +Sends contract text to an LLM with a structured prompt and parses the JSON +response into a Python dict for downstream use and display. + +Example transformation: + Before (raw contract language): + "This Agreement shall commence on the Effective Date and shall continue + for a period of one (1) year unless sooner terminated..." + + After (plain-English summary field): + "This is a one-year service agreement between Acme Corp and Beta LLC. + Acme will provide software development services in exchange for monthly + payments. Either party may terminate with 30 days written notice." + +The structured JSON output (parties, contract_type, key_obligations, etc.) +makes it easy to build dashboards, comparison tools, or automated alerts on +top of this module without re-parsing free-form text. 
+""" + +import json +import os +import re +from pathlib import Path + +from langchain.schema import HumanMessage + + +# --------------------------------------------------------------------------- +# Prompt loading +# --------------------------------------------------------------------------- + +def _load_summary_prompt() -> str: + """Load the summary prompt template from prompts/summary_prompt.txt.""" + prompt_path = Path(__file__).parent.parent / "prompts" / "summary_prompt.txt" + with open(prompt_path, "r", encoding="utf-8") as f: + return f.read() + + +# --------------------------------------------------------------------------- +# Core function +# --------------------------------------------------------------------------- + +def generate_summary(contract_text: str, llm) -> dict: + """ + Generate a structured executive summary of a contract. + + Parameters + ---------- + contract_text : str β€” full text (or a representative excerpt) of the contract + llm : LLM β€” any LangChain-compatible chat model (e.g. ChatOpenAI) + + Returns + ------- + dict with keys: + parties : list[str] + contract_type : str + effective_date : str + duration : str + key_obligations : list[str] + summary : str + + Falls back to {"raw_response": } if JSON parsing fails, so callers + always receive a dict even when the model returns malformed output. + + Note: We cap input at 8000 characters. Most consumer LLMs have a ~4k-token + context window for GPT-3.5 or ~8k for GPT-4. 8000 chars β‰ˆ 2000 tokens, + leaving room for the prompt itself and the response. + """ + prompt_template = _load_summary_prompt() + + # Truncate to avoid exceeding model context limits + truncated_text = contract_text[:8000] + if len(contract_text) > 8000: + truncated_text += "\n\n[... 
document truncated for summary ...]" + + prompt = prompt_template.format(contract_text=truncated_text) + + # Invoke the LLM β€” works with both .invoke() (newer LangChain) and direct call + response = llm.invoke([HumanMessage(content=prompt)]) + raw_content = response.content if hasattr(response, "content") else str(response) + + # --- Parse JSON response --- + # The prompt instructs the model to return ONLY JSON, but it occasionally + # wraps it in ```json ... ``` markdown fences β€” strip those first. + clean = raw_content.strip() + if clean.startswith("```"): + # Remove opening fence (```json or ```) + clean = re.sub(r"^```(?:json)?\s*", "", clean, flags=re.MULTILINE) + # Remove closing fence + clean = re.sub(r"```\s*$", "", clean, flags=re.MULTILINE) + clean = clean.strip() + + try: + return json.loads(clean) + except json.JSONDecodeError: + # Graceful fallback: return raw text so the caller can still display something + return {"raw_response": raw_content} + + +# --------------------------------------------------------------------------- +# Display formatting +# --------------------------------------------------------------------------- + +def format_summary_output(summary: dict) -> str: + """ + Format a summary dict as a human-readable string for terminal output. 
+ + Parameters + ---------- + summary : dict β€” result from generate_summary() + + Returns + ------- + str β€” multi-line formatted summary + """ + if "raw_response" in summary: + return f"[Raw LLM response β€” JSON parsing failed]\n\n{summary['raw_response']}" + + lines = [] + + contract_type = summary.get("contract_type", "Unknown") + lines.append(f"Contract Type : {contract_type}") + + parties = summary.get("parties", []) + if parties: + lines.append("Parties :") + for p in parties: + lines.append(f" β€’ {p}") + + lines.append(f"Effective Date: {summary.get('effective_date', 'Not specified')}") + lines.append(f"Duration : {summary.get('duration', 'Not specified')}") + + obligations = summary.get("key_obligations", []) + if obligations: + lines.append("Key Obligations:") + for ob in obligations: + lines.append(f" β€’ {ob}") + + plain_summary = summary.get("summary", "") + if plain_summary: + lines.append(f"\nSummary:\n {plain_summary}") + + return "\n".join(lines) + diff --git a/03-research-agent/.env.example b/03-research-agent/.env.example new file mode 100644 index 0000000..c5a27fa --- /dev/null +++ b/03-research-agent/.env.example @@ -0,0 +1,11 @@ +# OpenAI API Key (required) +OPENAI_API_KEY=your_openai_api_key_here + +# Model to use (gpt-4 recommended for research synthesis) +OPENAI_MODEL=gpt-4 + +# Optional: Anthropic Claude (alternative) +# ANTHROPIC_API_KEY=your_anthropic_key_here + +# Papers directory +PAPERS_DIR=data/papers diff --git a/03-research-agent/README.md b/03-research-agent/README.md new file mode 100644 index 0000000..0624998 --- /dev/null +++ b/03-research-agent/README.md @@ -0,0 +1,191 @@ +# 03 β€” Research Agent + +> **"Like a research assistant who can look up papers, take notes, and compare findings β€” rather than just answering one question."** + +## What is an AI Agent? + +A regular LLM call is a single prompt β†’ single response. You hand the model some text and it writes back. That's it. 
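In code, that single round trip is literally one function call. The sketch below uses a stub `llm` so it runs without an API key; with LangChain the real call would be `ChatOpenAI(model=...).invoke(prompt)`:

```python
# A plain LLM call: one prompt in, one reply out. The model cannot consult
# your documents or call tools; it answers from its training data alone.
def llm(prompt: str) -> str:
    # Stub standing in for ChatOpenAI(model="gpt-4").invoke(prompt)
    return f"Best guess from training data for: {prompt}"

reply = llm("What does Paper A say about transformers?")
print(reply)
```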
+ +An **AI agent** is different: it has access to **tools** β€” functions it can call to look things up, compute things, or take actions β€” and it decides *dynamically* which tools to use based on each new sub-goal. + +Think of the difference between: +- πŸ€– **Simple LLM**: You ask "what does Paper A say about transformers?" and the model guesses from its training data. +- πŸ•΅οΈ **Research Agent**: You ask the same question, and the agent *looks it up*, reads the relevant sections, possibly *compares* them to Paper B, and synthesises an answer with citations. + +## How This Differs from Simple RAG + +| Simple RAG | Research Agent | +|---|---| +| Embed documents β†’ vector DB | Same | +| User query β†’ nearest chunks β†’ LLM answer | **Agent plans which tools to call** | +| Single retrieval step | **Multi-step: search β†’ summarise β†’ compare** | +| No memory between steps | **Observations from each step feed the next** | +| Good for Q&A | Good for synthesis, comparison, gap analysis | + +In simple RAG, the pipeline is fixed: retrieve then answer. In an agent, the LLM itself decides the pipeline at runtime. + +## The ReAct Loop Explained + +**ReAct = Reason + Act**. The agent alternates between thinking and doing: + +``` +Thought : I need to find papers about attention mechanisms. +Action : search_papers +Input : attention mechanism self-attention +Observation: [Result 1] Paper: "Attention Is All You Need" … + +Thought : I found the relevant paper. Now I'll get its full summary. +Action : summarize_paper +Input : Attention Is All You Need +Observation: Title: Attention Is All You Need, Authors: Vaswani et al. … + +Thought : I have enough to answer the question. +Final Answer: The paper "Attention Is All You Need" introduced … +``` + +Each **Observation** is the tool's output, appended to the agent's context. The agent re-reads the growing context at each step to decide what to do next. 
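The loop above can be sketched in a few lines of plain Python. Both `fake_llm` and the tool registry are scripted stand-ins, not the real LangChain agent executor; the point is how the Thought/Action/Observation text accumulates in `context`:

```python
# Minimal ReAct driver: alternate "ask the model" / "run a tool" until the
# model emits a Final Answer, capping iterations to avoid infinite loops.
def react_loop(llm, tools, question, max_iterations=8):
    context = f"Question: {question}\n"
    for _ in range(max_iterations):
        step = llm(context)                 # Thought + Action, or Final Answer
        context += step + "\n"
        if "Final Answer:" in step:
            return step.split("Final Answer:", 1)[1].strip()
        if "Action:" in step:
            action = step.split("Action:", 1)[1].splitlines()[0].strip()
            tool_input = step.split("Input:", 1)[1].splitlines()[0].strip()
            observation = tools[action](tool_input)     # run the chosen tool
            context += f"Observation: {observation}\n"  # feed the result back in
    return "Stopped: hit max_iterations."

def fake_llm(context):
    if "Observation:" not in context:       # first pass: decide to search
        return "Thought: I should search.\nAction: search_papers\nInput: attention"
    return "Thought: I have enough.\nFinal Answer: See 'Attention Is All You Need'."

tools = {"search_papers": lambda q: f"[Result 1] matched query '{q}'"}
print(react_loop(fake_llm, tools, "Which paper introduced self-attention?"))
# -> See 'Attention Is All You Need'.
```

The real agent works the same way, except `llm` is a chat model and `tools` are the search, summary, and comparison tools wired up in `src/agent.py`.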
+ +## Architecture + +``` +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ main.py β”‚ +β”‚ (CLI: --query / --report / --interactive) β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ Research Agent β”‚ ← agent.py + β”‚ (ReAct loop + LLM) β”‚ + β””β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ β”‚ + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β–Όβ”€β”€β” β”Œβ”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚search_tool β”‚ β”‚ summary_tool β”‚ β”‚ compare_tool β”‚ + β”‚(FAISS β”‚ β”‚ (PaperMetadata β”‚ β”‚ (LLM comparisonβ”‚ + β”‚ semantic β”‚ β”‚ lookup) β”‚ β”‚ of two papers)β”‚ + β”‚ search) β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β””β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ + β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β” + β”‚ β”‚ PaperMetadata objects β”‚ + β”‚ β”‚ (from paper_parser.py) β”‚ + β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”Œβ”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β” + β”‚ FAISS indexβ”‚ ← paper_indexer.py + β”‚ (chunked β”‚ + β”‚ PDFs + β”‚ + β”‚ metadata) β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + + Gap Analysis (--report): + paper_metadata β†’ gap_analyzer.py β†’ LLM synthesis β†’ report_generator.py β†’ .md file +``` + +## Setup + +```bash +# 1. Clone / navigate to the project +cd 03-research-agent + +# 2. 
Create and activate a virtual environment +python -m venv venv +source venv/bin/activate # Windows: venv\Scripts\activate + +# 3. Install dependencies +pip install -r requirements.txt + +# 4. Configure environment +cp .env.example .env +# Edit .env and add your OPENAI_API_KEY + +# 5. Add research papers +# Copy your .pdf files into data/papers/ +``` + +## How to Add Papers + +Place any number of **.pdf** files into `data/papers/`. The pipeline will: +1. Extract text and LLM-parse metadata (title, authors, abstract, methodology, findings, limitations). +2. Chunk the full text and embed it into a FAISS vector index. +3. Make both the metadata and the full text available to the agent's tools. + +**Tips:** +- Use papers that are topically related for better gap analysis. +- 3–10 papers is the sweet spot. More than 20 may hit the LLM's context limit during gap analysis. +- Scanned PDFs without OCR will produce empty or garbled text β€” use PDFs with selectable text. + +## Running the Agent + +```bash +# Ask a single question and exit +python main.py --query "What methodologies are used across these papers?" + +# Start an interactive Q&A session +python main.py --interactive + +# Generate a gap analysis report +python main.py --topic "transformer models" --report + +# All options +python main.py --papers-dir data/papers \ + --topic "BERT fine-tuning" \ + --model gpt-4 \ + --report \ + --output reports/bert_gaps.md +``` + +## Sample Queries + +These questions showcase the agent's multi-step reasoning: + +``` +"What methodologies are used across these papers?" +"Which papers agree on X, and which contradict each other?" +"What are the main gaps in this research area?" +"Summarise the paper on [topic] and compare it to [other paper]." +"Which paper has the strongest experimental design?" +"What datasets are most commonly used?" +"Are there any contradictions between the papers' findings?" 
+``` + +## How to Interpret the Gap Analysis + +The gap analysis report has six sections: + +| Section | What it means | +|---|---| +| **Common Themes** | Topics / findings that appear in multiple papers β€” the consensus view | +| **Contradictions** | Where papers disagree β€” potential areas of ongoing debate | +| **Missing Experiments** | Experiments that logically follow from the existing work but haven't been done | +| **Missing Populations** | Groups, languages, contexts, or demographics not yet studied | +| **Methodological Gaps** | Approaches not used in any paper (e.g., "no longitudinal study exists") | +| **Suggested Next Steps** | Concrete research directions derived from all of the above | + +> ⚠️ **Always verify the output.** LLMs can hallucinate contradictions or invent plausible-sounding but non-existent gaps. Treat the gap analysis as a *first draft* to refine with domain expertise. + +## Limitations + +1. **LLMs can hallucinate citations** β€” the agent might confidently say "Paper X found Y" when it did not. Always check claims against the original PDF. + +2. **Gap analysis may miss domain-specific context** β€” a gap that is obvious to a domain expert ("nobody used technique Z") requires domain knowledge the LLM may not have. + +3. **Works best with 3–10 papers on the same topic** β€” fewer papers means less to synthesise; more papers risks exceeding the context window during gap analysis. + +4. **PDF extraction quality varies** β€” scanned PDFs, multi-column layouts, and heavy use of figures degrade text extraction. The LLM falls back gracefully but metadata may be incomplete. + +5. **The agent may loop or over-call tools** β€” the `max_iterations=8` safety cap prevents infinite loops but may cut off complex multi-paper comparisons. + +## How to Extend + +### Adding a new tool + +1. Create `src/tools/my_tool.py` with a `create_my_tool(…) -> Tool` function. +2. Import and instantiate it in `src/agent.py` inside `create_research_agent`. +3. 
Add it to the `tools` list passed to `initialize_agent`. + +The agent will automatically start using the new tool based on its description β€” no other changes needed. + +### Ideas for new tools + +- **`cite_tool`** β€” generate a BibTeX entry for a paper from its metadata. +- **`timeline_tool`** β€” order papers chronologically and show how the field evolved. +- **`keyword_tool`** β€” extract and rank keywords across all papers. +- **`arxiv_tool`** β€” search arXiv for papers related to the indexed collection. diff --git a/03-research-agent/data/papers/.gitkeep b/03-research-agent/data/papers/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/03-research-agent/main.py b/03-research-agent/main.py new file mode 100644 index 0000000..4ed8d8b --- /dev/null +++ b/03-research-agent/main.py @@ -0,0 +1,222 @@ +""" +main.py +------- +Entry point for the 03-research-agent pipeline. + +Usage examples +-------------- +# Parse all PDFs and start an interactive research Q&A session +python main.py --papers-dir data/papers --interactive + +# Ask the agent a single question and exit +python main.py --papers-dir data/papers --query "What methodologies are used across these papers?" 
+ +# Generate a full gap analysis report +python main.py --papers-dir data/papers --topic "transformer models" --report + +# Combine: generate a report and also run an interactive session +python main.py --papers-dir data/papers --topic "NLP" --report --interactive + +# Save the report to a specific file +python main.py --papers-dir data/papers --topic "BERT fine-tuning" --report --output reports/bert.md +""" + +import argparse +import os +import sys +from pathlib import Path + +from dotenv import load_dotenv + +# --------------------------------------------------------------------------- +# Load environment variables from .env before any other imports that might +# need OPENAI_API_KEY (e.g., langchain_openai) +# --------------------------------------------------------------------------- +load_dotenv() + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="main.py", + description="AI Research Agent β€” analyse a collection of research PDFs.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "--papers-dir", + default=os.getenv("PAPERS_DIR", "data/papers"), + metavar="DIR", + help="Directory containing *.pdf files (default: data/papers)", + ) + parser.add_argument( + "--topic", + default="Research Analysis", + metavar="TOPIC", + help="Research topic label used in the report title (default: 'Research Analysis')", + ) + parser.add_argument( + "--model", + default=os.getenv("OPENAI_MODEL", "gpt-4"), + metavar="MODEL", + help="OpenAI model name (default: gpt-4)", + ) + parser.add_argument( + "--query", + default=None, + metavar="QUESTION", + help="Ask the agent a single question and exit.", + ) + parser.add_argument( + "--report", + action="store_true", + help="Run gap analysis and generate a Markdown report.", + ) + parser.add_argument( + "--output", + default=None, + metavar="PATH", + help="Output file path for the Markdown report (only used with --report).", + ) + 
parser.add_argument( + "--interactive", + action="store_true", + help="Start an interactive Q&A session with the agent.", + ) + return parser + + +def _check_api_key() -> None: + """Exit early with a clear error if the OpenAI key is missing.""" + if not os.getenv("OPENAI_API_KEY"): + print( + "[main] ERROR: OPENAI_API_KEY environment variable is not set.\n" + " Copy .env.example to .env and add your key.", + file=sys.stderr, + ) + sys.exit(1) + + +def main() -> None: + parser = _build_parser() + args = parser.parse_args() + + _check_api_key() + + # ------------------------------------------------------------------ + # Lazy imports so startup is fast when there are argument errors + # ------------------------------------------------------------------ + from langchain_openai import ChatOpenAI + + from src.agent import create_research_agent, run_agent + from src.gap_analyzer import analyze_gaps, format_gap_analysis + from src.paper_indexer import index_papers + from src.paper_parser import parse_all_papers + from src.report_generator import generate_report + + # ------------------------------------------------------------------ + # Step 1: Validate papers directory + # ------------------------------------------------------------------ + papers_dir = Path(args.papers_dir) + if not papers_dir.exists(): + print(f"[main] ERROR: Papers directory '{papers_dir}' does not exist.", file=sys.stderr) + sys.exit(1) + + pdf_count = len(list(papers_dir.glob("*.pdf"))) + if pdf_count == 0: + print( + f"[main] ERROR: No PDF files found in '{papers_dir}'.\n" + " Add research papers as .pdf files and try again.", + file=sys.stderr, + ) + sys.exit(1) + + print(f"[main] Found {pdf_count} PDF file(s) in '{papers_dir}'.") + + # ------------------------------------------------------------------ + # Step 2: Initialise LLM + # ------------------------------------------------------------------ + print(f"[main] Using model: {args.model}") + llm = ChatOpenAI( + model=args.model, + temperature=0, 
# deterministic output for research tasks + openai_api_key=os.environ["OPENAI_API_KEY"], + ) + + # ------------------------------------------------------------------ + # Step 3: Parse all papers with LLM + # ------------------------------------------------------------------ + print("\n[main] === Step 1/3: Parsing papers ===") + paper_metadata = parse_all_papers(args.papers_dir, llm) + + if not paper_metadata: + print("[main] ERROR: No papers were successfully parsed.", file=sys.stderr) + sys.exit(1) + + print(f"[main] Parsed {len(paper_metadata)} paper(s).") + + # ------------------------------------------------------------------ + # Step 4: Index papers in FAISS + # ------------------------------------------------------------------ + print("\n[main] === Step 2/3: Indexing papers in FAISS ===") + vector_store = index_papers(args.papers_dir) + + # ------------------------------------------------------------------ + # Step 5: Create the research agent + # ------------------------------------------------------------------ + print("\n[main] === Step 3/3: Building research agent ===") + agent = create_research_agent(vector_store, paper_metadata, llm) + print("[main] Agent ready.\n") + + # ------------------------------------------------------------------ + # Step 6a: Generate report (--report) + # ------------------------------------------------------------------ + if args.report: + print("[main] Running gap analysis…") + gaps = analyze_gaps(paper_metadata, llm) + + print(format_gap_analysis(gaps)) + + report = generate_report( + paper_metadata_list=paper_metadata, + gap_analysis=gaps, + topic=args.topic, + output_path=args.output, + ) + print(f"[main] Report generated ({len(report)} characters).") + + # ------------------------------------------------------------------ + # Step 6b: Single query (--query) + # ------------------------------------------------------------------ + if args.query: + run_agent(args.query, agent) + + # 
------------------------------------------------------------------ + # Step 6c: Interactive session (--interactive) + # ------------------------------------------------------------------ + if args.interactive: + print("\n[main] Entering interactive mode. Type 'exit' or 'quit' to stop.\n") + while True: + try: + user_input = input("You: ").strip() + except (EOFError, KeyboardInterrupt): + print("\n[main] Exiting.") + break + + if not user_input: + continue + if user_input.lower() in {"exit", "quit", "q"}: + print("[main] Goodbye!") + break + + run_agent(user_input, agent) + + # If no action flag was given, print help + if not args.report and not args.query and not args.interactive: + parser.print_help() + print( + "\n[main] No action specified. Use --query, --report, or --interactive." + ) + + +if __name__ == "__main__": + main() diff --git a/03-research-agent/requirements.txt b/03-research-agent/requirements.txt new file mode 100644 index 0000000..4b9fb18 --- /dev/null +++ b/03-research-agent/requirements.txt @@ -0,0 +1,10 @@ +langchain==0.1.20 +langchain-community==0.0.38 +langchain-openai==0.1.6 +faiss-cpu==1.8.0 +sentence-transformers==2.7.0 +pypdf==4.2.0 +openai==1.30.1 +python-dotenv==1.0.1 +pydantic==2.7.1 +arxiv==2.1.0 diff --git a/03-research-agent/src/__init__.py b/03-research-agent/src/__init__.py new file mode 100644 index 0000000..2c12f6c --- /dev/null +++ b/03-research-agent/src/__init__.py @@ -0,0 +1,2 @@ +# src/__init__.py +# Makes 'src' a Python package so modules can be imported as src.paper_parser, etc. diff --git a/03-research-agent/src/agent.py b/03-research-agent/src/agent.py new file mode 100644 index 0000000..beb7326 --- /dev/null +++ b/03-research-agent/src/agent.py @@ -0,0 +1,145 @@ +""" +src/agent.py +------------ +Wires together the tools and LLM into a LangChain ReAct agent. + +WHAT IS THE REACT LOOP? 
+------------------------ +ReAct (Reason + Act) is a prompting strategy where the LLM alternates between: + + Thought – the model reasons about what to do next + Action – the model picks a tool and writes an input for it + Observation – the tool runs and its output is appended to the prompt + … repeat until … + Final Answer – the model decides it has enough information + +Example: + Thought : I need to find papers about transformers. I'll search. + Action : search_papers + Action Input: transformer self-attention mechanism + Observation : [Result 1] Paper: "Attention Is All You Need" … + Thought : I found a relevant paper. Now I'll summarize it. + Action : summarize_paper + Action Input: Attention Is All You Need + Observation : Title: Attention Is All You Need … + Final Answer: The paper "Attention Is All You Need" introduced … + +HOW THE AGENT SEES THE TOOLS +------------------------------ +The agent receives a text-formatted list of tool names and descriptions in its +system prompt. It never sees function signatures or source code. This is why +precise tool descriptions are critical: they are the agent's entire API docs. + +WHY verbose=True IS IMPORTANT FOR LEARNING +------------------------------------------- +With verbose=True LangChain prints every Thought / Action / Observation to +stdout. You can watch the agent's reasoning unfold in real time. This is +invaluable for understanding why the agent chose a particular tool, and for +debugging when it makes the wrong choice. + +THE DIFFERENCE BETWEEN AN AGENT AND A SIMPLE LLM CALL +-------------------------------------------------------- +A simple LLM call is a single prompt β†’ single response. The LLM cannot fetch +new information mid-response. 
An agent can: + - Decide which tool to call based on intermediate results + - Retry with a different query if the first search returns nothing + - Chain multiple tool calls (search β†’ summarize β†’ compare) + - Stop early if the first observation already answers the question + +WHAT "ZERO SHOT" MEANS +------------------------ +ZERO_SHOT_REACT_DESCRIPTION means the agent needs zero examples (shots) in its +prompt. It figures out when and how to use each tool purely from the tool +description. This keeps the prompt short and avoids the need to curate +few-shot examples for every new tool. +""" + +from langchain.agents import AgentExecutor, AgentType, initialize_agent +from langchain_community.vectorstores import FAISS + +from src.tools.compare_tool import create_compare_tool +from src.tools.search_tool import create_search_tool +from src.tools.summary_tool import create_summary_tool + +# System prompt injected as the agent's persona and behavioural guidelines. +# The prefix is prepended to the auto-generated ReAct prompt that lists tools. +_AGENT_PREFIX = """You are an AI research assistant. You have access to a collection of research papers. +Use the available tools to answer questions about the research literature. +Always cite your sources by mentioning which paper a piece of information comes from. +Think step by step about which tools to use.""" + + +def create_research_agent( + vector_store: FAISS, + paper_metadata: list, + llm, +) -> AgentExecutor: + """Build and return a fully configured ReAct research agent. + + Parameters + ---------- + vector_store : FAISS + Populated FAISS index (from paper_indexer.index_papers). + paper_metadata : list[PaperMetadata] + List of parsed paper metadata objects. + llm : + Any LangChain chat model (e.g., ChatOpenAI). + + Returns + ------- + AgentExecutor + The runnable agent. Call agent.run(query) to use it. 
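Example (illustration only — `SimpleNamespace` objects stand in for `PaperMetadata`). The first thing the function does is build a title → metadata lookup for the summary and compare tools; note that duplicate titles silently overwrite earlier entries, so two PDFs that parse to the same title collapse into one:

```python
from types import SimpleNamespace

# Stand-ins for PaperMetadata (only .title and .year matter for this sketch)
papers = [
    SimpleNamespace(title="Attention Is All You Need", year="2017"),
    SimpleNamespace(title="BERT", year="2018"),
    SimpleNamespace(title="BERT", year="2019"),  # duplicate title, e.g. two copies of a PDF
]

paper_metadata_dict = {pm.title: pm for pm in papers}

# The later entry silently wins — only two keys survive
print(len(paper_metadata_dict), paper_metadata_dict["BERT"].year)  # → 2 2019
```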
+ """ + # Build a title β†’ metadata dict for the summary and compare tools + paper_metadata_dict = {pm.title: pm for pm in paper_metadata} + + # Instantiate each tool + search_tool = create_search_tool(vector_store) + summary_tool = create_summary_tool(paper_metadata_dict, llm) + compare_tool = create_compare_tool(paper_metadata_dict, llm) + + tools = [search_tool, summary_tool, compare_tool] + + # initialize_agent wraps the LLM + tools in a ReAct prompt loop. + # ZERO_SHOT_REACT_DESCRIPTION: no few-shot examples, tool selection driven + # entirely by the description strings we provided above. + agent = initialize_agent( + tools=tools, + llm=llm, + agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, + verbose=True, # print Thought/Action/Observation to stdout + handle_parsing_errors=True, # recover gracefully from malformed tool calls + agent_kwargs={"prefix": _AGENT_PREFIX}, + max_iterations=8, # safety cap to prevent infinite loops + ) + + return agent + + +def run_agent(query: str, agent: AgentExecutor) -> str: + """Run a single query through the research agent. + + Parameters + ---------- + query : str + The user's question or instruction. + agent : AgentExecutor + The agent built by :func:`create_research_agent`. + + Returns + ------- + str + The agent's final answer. + """ + print(f"\n{'='*60}") + print(f"Query: {query}") + print(f"{'='*60}\n") + + result = agent.run(query) + + print(f"\n{'='*60}") + print("Final Answer:") + print(result) + print(f"{'='*60}\n") + + return result diff --git a/03-research-agent/src/gap_analyzer.py b/03-research-agent/src/gap_analyzer.py new file mode 100644 index 0000000..fe98de7 --- /dev/null +++ b/03-research-agent/src/gap_analyzer.py @@ -0,0 +1,185 @@ +""" +src/gap_analyzer.py +-------------------- +Uses an LLM to synthesise research gaps across a collection of papers. 
+ +THIS IS PROMPTED REASONING, NOT DATABASE LOGIC +------------------------------------------------ +Traditional literature review tools use citation graphs, keyword co-occurrence +matrices, or statistical topic models to find gaps. We use a different +approach: we feed all paper summaries to an LLM and ask it to reason about +what is missing. + +Advantages: + - No need for structured metadata like citation counts or MeSH terms. + - Can catch conceptual gaps ("nobody studied X in context Y") that keyword + matching would miss. + - Works on any research domain without domain-specific pre-processing. + +Disadvantages (see LIMITATIONS below): + - The LLM may hallucinate gaps or themes that are not actually present. + - Subtle contradictions buried in technical detail may be missed. + - The quality of the output is bounded by the quality of the summaries. + +WHY SYNTHESIS REQUIRES READING ALL PAPERS, NOT JUST SEARCHING +--------------------------------------------------------------- +Semantic search retrieves chunks relevant to a specific query. Gap analysis +is a meta-level task: it needs to observe the DISTRIBUTION of topics across +the entire corpus. If only 3 of 8 papers mention "dataset bias" we cannot +detect that gap by searching for "bias" β€” we need to compare absence vs +presence across all papers simultaneously. + +LIMITATIONS +----------- +1. The LLM may invent plausible-sounding but false contradictions. +2. Gaps requiring deep domain expertise (e.g., specific biochemical pathways) + may be missed or mischaracterised. +3. This prompt works best with 3–10 papers on the same topic. With 20+ + papers the concatenated summaries may exceed the context window. +4. Results should always be reviewed by a domain expert before acting on them. 
+""" + +import json + + +# --------------------------------------------------------------------------- +# Synthesis prompt +# --------------------------------------------------------------------------- + +_GAP_ANALYSIS_PROMPT = """You are analyzing a collection of research papers on a topic. + +Here are summaries of all the papers: +{all_summaries} + +Based on these papers, provide a research gap analysis with: +1. common_themes: What topics/findings appear across multiple papers? +2. contradictions: Where do papers disagree or contradict each other? +3. missing_experiments: What experiments have NOT been done that would be valuable? +4. missing_populations: What groups or contexts haven't been studied? +5. methodological_gaps: What methodological approaches are missing? +6. suggested_next_steps: 3-5 specific research directions worth pursuing + +Respond with JSON only.""" + + +def analyze_gaps(paper_metadata_list: list, llm) -> dict: + """Run a cross-paper synthesis prompt and return structured gap analysis. + + Parameters + ---------- + paper_metadata_list : list[PaperMetadata] + All parsed papers to analyse. + llm : + Any LangChain chat model. 
+ + Returns + ------- + dict with keys: common_themes, contradictions, missing_experiments, + missing_populations, methodological_gaps, suggested_next_steps + """ + if not paper_metadata_list: + return { + "common_themes": [], + "contradictions": [], + "missing_experiments": [], + "missing_populations": [], + "methodological_gaps": [], + "suggested_next_steps": [], + "error": "No papers provided for gap analysis.", + } + + # Build a human-readable summary block for each paper + summary_blocks = [] + for pm in paper_metadata_list: + authors_str = ", ".join(pm.authors) if pm.authors else "Unknown" + findings_str = ( + "\n ".join(f"β€’ {f}" for f in pm.key_findings) + if pm.key_findings + else "(not extracted)" + ) + limitations_str = ( + "\n ".join(f"β€’ {l}" for l in pm.limitations) + if pm.limitations + else "(not extracted)" + ) + block = ( + f"--- Paper: {pm.title} ---\n" + f"Authors : {authors_str}\n" + f"Year : {pm.year or 'Unknown'}\n" + f"Methodology: {pm.methodology or 'Not extracted'}\n" + f"Key Findings:\n {findings_str}\n" + f"Limitations:\n {limitations_str}" + ) + summary_blocks.append(block) + + all_summaries = "\n\n".join(summary_blocks) + prompt = _GAP_ANALYSIS_PROMPT.format(all_summaries=all_summaries) + + print("[gap_analyzer] Running synthesis prompt across all papers…") + response = llm.invoke(prompt) + raw = response.content if hasattr(response, "content") else str(response) + + # Strip markdown code fences if present + raw = raw.strip() + if raw.startswith("```"): + raw = raw.split("```", 2)[1] + if raw.startswith("json"): + raw = raw[4:] + raw = raw.rsplit("```", 1)[0].strip() + + try: + gaps = json.loads(raw) + except json.JSONDecodeError as exc: + print(f"[gap_analyzer] Warning: could not parse JSON response: {exc}") + # Return the raw text under a fallback key so nothing is lost + gaps = { + "common_themes": [], + "contradictions": [], + "missing_experiments": [], + "missing_populations": [], + "methodological_gaps": [], + 
"suggested_next_steps": [], + "raw_response": raw, + } + + return gaps + + +def format_gap_analysis(gaps: dict) -> str: + """Format a gap analysis dict as a human-readable string for console display. + + Parameters + ---------- + gaps : dict + Output of :func:`analyze_gaps`. + + Returns + ------- + str + """ + def _fmt_list(items) -> str: + if not items: + return " (none identified)" + if isinstance(items, list): + return "\n".join(f" β€’ {item}" for item in items) + return f" {items}" + + sections = [ + ("Common Themes", gaps.get("common_themes", [])), + ("Contradictions", gaps.get("contradictions", [])), + ("Missing Experiments", gaps.get("missing_experiments", [])), + ("Missing Populations", gaps.get("missing_populations", [])), + ("Methodological Gaps", gaps.get("methodological_gaps", [])), + ("Suggested Next Steps", gaps.get("suggested_next_steps", [])), + ] + + lines = ["=" * 60, "RESEARCH GAP ANALYSIS", "=" * 60] + for heading, items in sections: + lines.append(f"\n{heading}:") + lines.append(_fmt_list(items)) + + if "raw_response" in gaps: + lines.append("\n[Raw LLM response β€” JSON parsing failed]") + lines.append(gaps["raw_response"]) + + return "\n".join(lines) diff --git a/03-research-agent/src/paper_indexer.py b/03-research-agent/src/paper_indexer.py new file mode 100644 index 0000000..57d508d --- /dev/null +++ b/03-research-agent/src/paper_indexer.py @@ -0,0 +1,180 @@ +""" +src/paper_indexer.py +-------------------- +Embeds research papers and stores them in a FAISS vector index. + +HOW METADATA TAGGING WORKS IN FAISS (via LangChain) +----------------------------------------------------- +LangChain's FAISS wrapper stores a Python dict alongside each embedded chunk. +When you call `FAISS.from_documents(docs)`, each Document's `.metadata` dict +is persisted verbatim next to its vector. At search time, every returned +Document carries its original metadata, so you can read `doc.metadata["source"]` +to know which paper a chunk came from. 
+ +METADATA FILTERING: "SEARCH ONLY WITHIN PAPER X" +-------------------------------------------------- +FAISS itself does not support SQL-style WHERE clauses β€” it returns the k nearest +vectors globally. We implement per-paper filtering post-hoc: run the query +across the full index, then discard results whose `doc.metadata["source"]` does +not match the requested paper title. This is simple and correct for small +collections (< a few thousand chunks). For larger collections a dedicated +vector DB (Pinecone, Weaviate, Qdrant) with native metadata filters is better. + +WHY INDEX ALL PAPERS TOGETHER? +------------------------------- +A single shared index enables cross-paper queries like "which papers discuss +attention mechanisms?". If each paper had its own index you would have to +query N indexes and merge results manually. The trade-off is that per-paper +filtering requires a post-search step, but that cost is negligible at the +scale of a typical research collection (3-50 papers). +""" + +import os +from pathlib import Path + +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_community.document_loaders import PyPDFLoader +from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_community.vectorstores import FAISS + +# --------------------------------------------------------------------------- +# Embedding model +# --------------------------------------------------------------------------- +# "all-MiniLM-L6-v2" is a fast, lightweight model (80 MB) that works well for +# semantic similarity on academic text and runs entirely locally (no API key). +_EMBEDDING_MODEL = "all-MiniLM-L6-v2" + +# Chunk parameters: 1 000 chars with 200-char overlap. +# Research paragraphs average ~500-800 chars, so a 1 000-char window usually +# captures a complete idea. Overlap prevents a sentence at a chunk boundary +# from being split across two embeddings. 
+_CHUNK_SIZE = 1000 +_CHUNK_OVERLAP = 200 + + +def index_papers( + papers_dir: str, + index_path: str = "papers_faiss_index", +) -> FAISS: + """Load all PDFs in *papers_dir*, chunk them, embed them, and build a FAISS index. + + Each chunk's metadata contains: + source – paper title derived from filename (used for filtering) + file_path – absolute path to the source PDF + chunk_id – sequential integer within that paper + + Parameters + ---------- + papers_dir : str + Directory containing *.pdf files. + index_path : str + Directory where the FAISS index will be saved to disk. + + Returns + ------- + FAISS + The populated vector store. + """ + papers_path = Path(papers_dir) + pdf_files = sorted(papers_path.glob("*.pdf")) + + if not pdf_files: + raise FileNotFoundError(f"No PDF files found in '{papers_dir}'.") + + splitter = RecursiveCharacterTextSplitter( + chunk_size=_CHUNK_SIZE, + chunk_overlap=_CHUNK_OVERLAP, + separators=["\n\n", "\n", " ", ""], # prefer paragraph β†’ line β†’ word splits + ) + + all_docs = [] + for pdf in pdf_files: + print(f"[paper_indexer] Loading: {pdf.name}") + loader = PyPDFLoader(str(pdf)) + pages = loader.load() + + # Derive a human-readable source label from the filename + paper_title = pdf.stem.replace("_", " ").replace("-", " ") + + chunks = splitter.split_documents(pages) + for i, chunk in enumerate(chunks): + # Enrich metadata β€” LangChain's PyPDFLoader already adds 'source' + # and 'page'; we add our own fields on top. 
+ chunk.metadata["source"] = paper_title + chunk.metadata["file_path"] = str(pdf) + chunk.metadata["chunk_id"] = i + + all_docs.extend(chunks) + print(f"[paper_indexer] β†’ {len(chunks)} chunk(s)") + + print(f"[paper_indexer] Embedding {len(all_docs)} total chunks…") + embeddings = HuggingFaceEmbeddings(model_name=_EMBEDDING_MODEL) + vector_store = FAISS.from_documents(all_docs, embeddings) + + # Persist to disk so we can reload without re-embedding + vector_store.save_local(index_path) + print(f"[paper_indexer] Index saved to '{index_path}'.") + + return vector_store + + +def load_index(index_path: str = "papers_faiss_index") -> FAISS: + """Load a previously saved FAISS index from disk. + + Parameters + ---------- + index_path : str + Directory where the index was saved by :func:`index_papers`. + + Returns + ------- + FAISS + """ + embeddings = HuggingFaceEmbeddings(model_name=_EMBEDDING_MODEL) + vector_store = FAISS.load_local( + index_path, + embeddings, + allow_dangerous_deserialization=True, # required by LangChain β‰₯ 0.1 + ) + print(f"[paper_indexer] Loaded index from '{index_path}'.") + return vector_store + + +def search_papers( + query: str, + vector_store: FAISS, + k: int = 5, + paper_filter: str = None, +) -> list: + """Semantic search over the FAISS index. + + Parameters + ---------- + query : str + Natural-language search query. + vector_store : FAISS + The populated vector store. + k : int + Number of results to return (before optional filtering). + paper_filter : str or None + If provided, only return chunks whose metadata["source"] contains + this string (case-insensitive). This implements per-paper search. + + Returns + ------- + list[Document] + Matching document chunks, each with .page_content and .metadata. 
+ """ + # Retrieve more candidates when filtering so we still get k results after + # dropping non-matching papers + fetch_k = k * 4 if paper_filter else k + results = vector_store.similarity_search(query, k=fetch_k) + + if paper_filter: + filter_lower = paper_filter.lower() + results = [ + doc for doc in results + if filter_lower in doc.metadata.get("source", "").lower() + ] + + return results[:k] diff --git a/03-research-agent/src/paper_parser.py b/03-research-agent/src/paper_parser.py new file mode 100644 index 0000000..6b3c2ad --- /dev/null +++ b/03-research-agent/src/paper_parser.py @@ -0,0 +1,188 @@ +""" +src/paper_parser.py +------------------- +Parses research PDFs into structured metadata using an LLM. + +WHY STRUCTURED EXTRACTION INSTEAD OF RAW TEXT? +------------------------------------------------ +Storing raw text is simple, but it makes downstream tasks hard: + - Comparing papers requires knowing WHERE the methodology lives. + - Gap analysis needs to see key_findings from every paper side-by-side. + - Fuzzy title search works better when the title is its own field. + +By asking the LLM to fill a fixed schema once (at index time), every later +operation (compare, summarise, gap-analyse) can just read Python attributes +instead of re-searching the raw text. + +WHY ONLY THE FIRST 3 PAGES FOR METADATA? +----------------------------------------- +Research papers place title, authors, and abstract on page 1, sometimes +spilling to page 2. Page 3 occasionally contains the introduction which +gives methodology context. Beyond page 3 we are in body / results / tables +territory β€” the LLM prompt would be dominated by noisy content and would +exceed the context window for no gain. + +HOW MESSY PDF FORMATTING AFFECTS EXTRACTION +--------------------------------------------- +PDFs are layout-first, not text-first. Common problems: + - Multi-column layouts produce garbled word order when extracted linearly. 
+ - Footnotes and headers are interspersed with body text. + - Figures and tables appear as blank space or gibberish characters. + - Hyphenated line-breaks split words across lines. + +We mitigate this by: + 1. Limiting extraction to the first 3 000 characters (header area). + 2. Using an LLM instead of regexes β€” LLMs are robust to mild formatting noise. + 3. Falling back to the filename as title when the LLM cannot parse the text. +""" + +import json +import os +from pathlib import Path +from typing import Optional + +from langchain_community.document_loaders import PyPDFLoader +from pydantic import BaseModel, Field + + +class PaperMetadata(BaseModel): + """Structured representation of a research paper's key metadata. + + Fields are intentionally coarse-grained (e.g., 'methodology' is a + 1-2 sentence description) so the LLM can fill them reliably even when + the PDF formatting is messy. + """ + + title: str = Field(description="Full title of the paper") + authors: list[str] = Field(default_factory=list, description="List of author names") + year: Optional[str] = Field(default=None, description="Publication year if found") + abstract: Optional[str] = Field(default=None, description="Full abstract text") + methodology: Optional[str] = Field( + default=None, + description="1-2 sentence description of the research methodology", + ) + key_findings: list[str] = Field( + default_factory=list, + description="3-5 main findings from the paper", + ) + limitations: list[str] = Field( + default_factory=list, + description="Limitations acknowledged by the authors", + ) + file_path: str = Field(description="Absolute or relative path to the source PDF") + + +# --------------------------------------------------------------------------- +# Extraction prompt +# --------------------------------------------------------------------------- + +_EXTRACTION_PROMPT = """Extract the following from this research paper text: +- title: Full paper title +- authors: List of author names +- 
year: Publication year (if found) +- abstract: Full abstract text +- methodology: Brief description of research methodology (1-2 sentences) +- key_findings: List of 3-5 main findings +- limitations: List of limitations mentioned by authors + +Paper text (first 3000 chars): +{text} + +Respond with JSON only.""" + + +def parse_paper(file_path: str, llm) -> PaperMetadata: + """Load a single PDF and extract structured metadata using an LLM. + + Steps + ----- + 1. Load all pages with PyPDFLoader. + 2. Concatenate text from the first 3 pages and truncate to 3 000 chars. + 3. Ask the LLM to fill the extraction schema (JSON response). + 4. Parse the JSON into a PaperMetadata object. + 5. On any failure, fall back to filename-derived title with empty fields. + + Parameters + ---------- + file_path : str + Path to the PDF file. + llm : + Any LangChain chat model (e.g., ChatOpenAI). + + Returns + ------- + PaperMetadata + """ + loader = PyPDFLoader(file_path) + pages = loader.load() + + # Combine text from first 3 pages only β€” metadata lives here + excerpt = "\n".join(p.page_content for p in pages[:3])[:3000] + + prompt = _EXTRACTION_PROMPT.format(text=excerpt) + + try: + response = llm.invoke(prompt) + # Handle both string responses and AIMessage objects + raw = response.content if hasattr(response, "content") else str(response) + + # Strip markdown code fences if the LLM wraps JSON in ```json ... 
``` + raw = raw.strip() + if raw.startswith("```"): + raw = raw.split("```", 2)[1] + if raw.startswith("json"): + raw = raw[4:] + raw = raw.rsplit("```", 1)[0].strip() + + data = json.loads(raw) + return PaperMetadata( + title=data.get("title", Path(file_path).stem), + authors=data.get("authors", []), + year=str(data.get("year")) if data.get("year") else None, + abstract=data.get("abstract"), + methodology=data.get("methodology"), + key_findings=data.get("key_findings", []), + limitations=data.get("limitations", []), + file_path=file_path, + ) + + except Exception as exc: + # Graceful degradation: use the filename as title, leave everything else blank. + # This means the paper can still be searched even if LLM extraction failed. + print(f"[paper_parser] Warning: could not extract metadata from '{file_path}': {exc}") + return PaperMetadata( + title=Path(file_path).stem, + file_path=file_path, + ) + + +def parse_all_papers(papers_dir: str, llm) -> list[PaperMetadata]: + """Parse every PDF found in *papers_dir* and return a list of PaperMetadata. + + Parameters + ---------- + papers_dir : str + Directory that contains *.pdf files (non-recursive). + llm : + Any LangChain chat model. + + Returns + ------- + list[PaperMetadata] + One entry per successfully located PDF. Empty list if no PDFs found. 
+    """
+    papers_path = Path(papers_dir)
+    pdf_files = sorted(papers_path.glob("*.pdf"))
+
+    if not pdf_files:
+        print(f"[paper_parser] No PDF files found in '{papers_dir}'.")
+        return []
+
+    results: list[PaperMetadata] = []
+    for pdf in pdf_files:
+        print(f"[paper_parser] Parsing: {pdf.name}")
+        metadata = parse_paper(str(pdf), llm)
+        results.append(metadata)
+
+    print(f"[paper_parser] Parsed {len(results)} paper(s).")
+    return results
diff --git a/03-research-agent/src/report_generator.py b/03-research-agent/src/report_generator.py
new file mode 100644
index 0000000..99bbddc
--- /dev/null
+++ b/03-research-agent/src/report_generator.py
@@ -0,0 +1,201 @@
+"""
+src/report_generator.py
+------------------------
+Generates a structured Markdown report from parsed paper metadata and gap analysis.
+
+HOW OUTPUT PARSERS CAN ENFORCE STRUCTURE
+-----------------------------------------
+Here we build the Markdown manually using Python f-strings. An alternative
+approach is to use LangChain's StructuredOutputParser or PydanticOutputParser:
+
+    1. Define a Pydantic model with all report sections as fields.
+    2. Attach the parser's format_instructions to the LLM prompt.
+    3. The LLM fills the model; the parser deserialises it.
+
+This is valuable when the *content* of each section needs to be LLM-generated
+(e.g., "write a paragraph summarising the common themes"). For our report the
+content comes from already-structured dicts (PaperMetadata, gap analysis dict),
+so simple string formatting is cleaner and faster — no extra LLM call needed.
+
+The general lesson: use output parsers when you need the LLM to produce
+structured data; use string formatting when you already have structured data
+and just need to render it. 
+""" + +import os +from datetime import datetime +from pathlib import Path + + +def generate_report( + paper_metadata_list: list, + gap_analysis: dict, + topic: str, + output_path: str = None, +) -> str: + """Generate a full Markdown research report and optionally save it to disk. + + Parameters + ---------- + paper_metadata_list : list[PaperMetadata] + All parsed papers. + gap_analysis : dict + Output of gap_analyzer.analyze_gaps(). + topic : str + Human-readable topic label used in the report title. + output_path : str or None + If provided, the report is written to this path. + If None, a timestamped filename is used automatically. + + Returns + ------- + str + The complete Markdown report as a string. + """ + now = datetime.now() + timestamp = now.strftime("%Y-%m-%d %H:%M") + file_timestamp = now.strftime("%Y%m%d_%H%M%S") + + # ------------------------------------------------------------------ + # Helper utilities + # ------------------------------------------------------------------ + + def _list_section(items) -> str: + """Render a list of strings as a Markdown bullet list.""" + if not items: + return "_None identified._\n" + return "\n".join(f"- {item}" for item in items) + "\n" + + def _numbered_section(items) -> str: + """Render a list of strings as a Markdown numbered list.""" + if not items: + return "_None identified._\n" + return "\n".join(f"{i}. 
{item}" for i, item in enumerate(items, 1)) + "\n"
+
+    # ------------------------------------------------------------------
+    # Section: Title & preamble
+    # ------------------------------------------------------------------
+    lines = [
+        f"# Research Literature Analysis: {topic}",
+        "",
+        f"**Generated:** {timestamp}  ",
+        f"**Papers analysed:** {len(paper_metadata_list)}",
+        "",
+    ]
+
+    # ------------------------------------------------------------------
+    # Section: Overview
+    # ------------------------------------------------------------------
+    lines += [
+        "## Overview",
+        "",
+        f"This report analyses **{len(paper_metadata_list)}** research paper(s) on the topic of **{topic}**.",
+        "",
+        "### Papers in this collection",
+        "",
+    ]
+    for pm in paper_metadata_list:
+        authors_str = ", ".join(pm.authors[:3]) if pm.authors else "Unknown"
+        if len(pm.authors) > 3:
+            authors_str += " et al."
+        year_str = f" ({pm.year})" if pm.year else ""
+        lines.append(f"- **{pm.title}**{year_str} — {authors_str}")
+    lines.append("")
+
+    # ------------------------------------------------------------------
+    # Section: Individual Paper Summaries
+    # ------------------------------------------------------------------
+    lines += ["## Individual Paper Summaries", ""]
+
+    for pm in paper_metadata_list:
+        authors_str = ", ".join(pm.authors) if pm.authors else "Unknown"
+        lines += [
+            f"### {pm.title}",
+            "",
+            f"**Authors:** {authors_str}  ",
+            f"**Year:** {pm.year or 'Unknown'}  ",
+            f"**File:** `{Path(pm.file_path).name}`",
+            "",
+        ]
+        if pm.abstract:
+            lines += ["**Abstract:**", "", pm.abstract, ""]
+        if pm.methodology:
+            lines += [f"**Methodology:** {pm.methodology}", ""]
+        if pm.key_findings:
+            lines += ["**Key Findings:**", ""]
+            lines += [f"- {f}" for f in pm.key_findings]
+            lines.append("")
+        if pm.limitations:
+            lines += ["**Limitations:**", ""]
+            lines += [f"- {l}" for l in pm.limitations]
+            lines.append("")
+        lines.append("---")
+        lines.append("")
+
+    # 
------------------------------------------------------------------ + # Section: Cross-Paper Analysis + # ------------------------------------------------------------------ + lines += ["## Cross-Paper Analysis", ""] + + lines += ["### Common Themes", ""] + lines.append(_list_section(gap_analysis.get("common_themes", []))) + + lines += ["### Contradictions Between Papers", ""] + lines.append(_list_section(gap_analysis.get("contradictions", []))) + + # ------------------------------------------------------------------ + # Section: Research Gaps + # ------------------------------------------------------------------ + lines += ["## Research Gaps", ""] + + lines += ["### Missing Experiments", ""] + lines.append(_list_section(gap_analysis.get("missing_experiments", []))) + + lines += ["### Under-Studied Populations or Contexts", ""] + lines.append(_list_section(gap_analysis.get("missing_populations", []))) + + lines += ["### Methodological Gaps", ""] + lines.append(_list_section(gap_analysis.get("methodological_gaps", []))) + + # ------------------------------------------------------------------ + # Section: Suggested Next Steps + # ------------------------------------------------------------------ + lines += ["## Suggested Next Steps", ""] + lines.append(_numbered_section(gap_analysis.get("suggested_next_steps", []))) + + # ------------------------------------------------------------------ + # Section: Paper Index + # ------------------------------------------------------------------ + lines += ["## Paper Index", ""] + lines += ["| # | Title | File |", "|---|-------|------|"] + for i, pm in enumerate(paper_metadata_list, 1): + fname = Path(pm.file_path).name + lines.append(f"| {i} | {pm.title} | `{fname}` |") + lines.append("") + + # ------------------------------------------------------------------ + # Footer + # ------------------------------------------------------------------ + lines += [ + "---", + "", + "_Report generated by the 03-research-agent pipeline. 
"
+        "Always verify findings against the original papers — "
+        "LLMs can hallucinate citations and misrepresent content._",
+        "",
+    ]
+
+    report = "\n".join(lines)
+
+    # ------------------------------------------------------------------
+    # Save to disk
+    # ------------------------------------------------------------------
+    if output_path is None:
+        output_path = f"research_report_{file_timestamp}.md"
+
+    out = Path(output_path)
+    out.parent.mkdir(parents=True, exist_ok=True)
+    out.write_text(report, encoding="utf-8")
+    print(f"[report_generator] Report saved to '{output_path}'.")
+
+    return report
diff --git a/03-research-agent/src/tools/__init__.py b/03-research-agent/src/tools/__init__.py
new file mode 100644
index 0000000..ba6e12c
--- /dev/null
+++ b/03-research-agent/src/tools/__init__.py
@@ -0,0 +1,2 @@
+# src/tools/__init__.py
+# Makes 'src/tools' a Python package.
diff --git a/03-research-agent/src/tools/compare_tool.py b/03-research-agent/src/tools/compare_tool.py
new file mode 100644
index 0000000..000dfed
--- /dev/null
+++ b/03-research-agent/src/tools/compare_tool.py
@@ -0,0 +1,133 @@
+"""
+src/tools/compare_tool.py
+--------------------------
+LangChain Tool that uses an LLM to compare two papers' methodologies and findings.
+
+WHERE AGENTS SHINE: MULTI-STEP REASONING ACROSS DOCUMENTS
+----------------------------------------------------------
+Simple RAG retrieves the nearest chunks to a query and returns them. A
+comparison task is inherently multi-step:
+    1. Identify paper A and paper B by name.
+    2. Retrieve full metadata for each.
+    3. Synthesise similarities, differences, and contradictions.
+    4. Produce a structured answer.
+
+An agent can chain these steps autonomously. Without the agent layer you
+would need to hard-code this pipeline. With an agent the user can simply ask
+"Compare Paper A and Paper B" and the agent decides to call this tool. 
+
+HOW THE REACT AGENT USES THIS TOOL
+------------------------------------
+A typical agent trace might look like:
+
+    Thought     : The user wants to compare "Transformer" and "BERT".
+                  I should use the compare_papers tool.
+    Action      : compare_papers
+    Action Input: Transformer vs BERT
+    Observation : [structured comparison returned by this tool]
+    Thought     : I now have the comparison. I can answer the user.
+    Final Answer: …
+
+The agent learns the input format ("A vs B") solely from the tool description —
+no examples are needed.
+
+INPUT FORMAT
+------------
+The tool expects: "<title of paper 1> vs <title of paper 2>"
+The separator " vs " (with spaces) is chosen because it is unambiguous and
+unlikely to appear in a paper title.
+"""
+
+from langchain.tools import Tool
+
+# Comparison prompt template
+_COMPARE_PROMPT = """Compare these two research papers:
+
+Paper 1: {title1}
+Methodology: {methodology1}
+Key Findings: {findings1}
+
+Paper 2: {title2}
+Methodology: {methodology2}
+Key Findings: {findings2}
+
+Provide a structured comparison covering:
+1. Methodological similarities and differences
+2. Agreements in findings
+3. Contradictions in findings
+4. Which paper's approach is stronger and why"""
+
+
+def create_compare_tool(paper_metadata_dict: dict, llm) -> Tool:
+    """Build a LangChain Tool that compares two papers using an LLM.
+
+    Parameters
+    ----------
+    paper_metadata_dict : dict
+        Maps paper title (str) → PaperMetadata object.
+    llm :
+        Any LangChain chat model used to generate the comparison narrative. 
+ + Returns + ------- + Tool + """ + + def _find_paper(query: str): + """Case-insensitive substring match against known paper titles.""" + q = query.strip().lower() + for title, meta in paper_metadata_dict.items(): + if q in title.lower(): + return meta + return None + + def _compare(input_str: str) -> str: + """Parse 'Paper A vs Paper B', retrieve metadata, call LLM to compare.""" + # Parse the two titles from the 'X vs Y' format + if " vs " not in input_str: + return ( + "Invalid input format. Please use: 'Paper Title A vs Paper Title B'. " + f"Got: '{input_str}'" + ) + + parts = input_str.split(" vs ", maxsplit=1) + title_a, title_b = parts[0].strip(), parts[1].strip() + + paper_a = _find_paper(title_a) + paper_b = _find_paper(title_b) + + # Report clearly which lookups failed so the agent can retry + if paper_a is None and paper_b is None: + return f"Could not find papers matching '{title_a}' or '{title_b}'." + if paper_a is None: + return f"Could not find a paper matching '{title_a}'." + if paper_b is None: + return f"Could not find a paper matching '{title_b}'." + + # Format findings lists as readable strings for the prompt + def fmt_findings(meta) -> str: + if not meta.key_findings: + return "Not extracted" + return "; ".join(meta.key_findings) + + prompt = _COMPARE_PROMPT.format( + title1=paper_a.title, + methodology1=paper_a.methodology or "Not extracted", + findings1=fmt_findings(paper_a), + title2=paper_b.title, + methodology2=paper_b.methodology or "Not extracted", + findings2=fmt_findings(paper_b), + ) + + response = llm.invoke(prompt) + return response.content if hasattr(response, "content") else str(response) + + return Tool( + name="compare_papers", + description=( + "Compare two research papers' methodologies and findings. 
"
+            "Input: two paper titles separated by ' vs ' "
+            "(e.g., 'Paper A vs Paper B')"
+        ),
+        func=_compare,
+    )
diff --git a/03-research-agent/src/tools/search_tool.py b/03-research-agent/src/tools/search_tool.py
new file mode 100644
index 0000000..ab1a985
--- /dev/null
+++ b/03-research-agent/src/tools/search_tool.py
@@ -0,0 +1,90 @@
+"""
+src/tools/search_tool.py
+------------------------
+LangChain Tool that lets the agent do semantic search over indexed papers.
+
+WHAT IS A LANGCHAIN TOOL?
+--------------------------
+A Tool is a Python function wrapped in a thin object that carries three things:
+    1. name        – a short identifier (e.g., "search_papers")
+    2. description – a plain-English explanation of WHEN and HOW to use the tool
+    3. func        – the actual callable that receives a string and returns a string
+
+The agent never inspects the function's source code. It only reads the
+name + description to decide whether to call the tool.
+
+HOW THE AGENT DECIDES WHEN TO USE THIS TOOL
+--------------------------------------------
+During each reasoning step the ReAct agent compares its current sub-goal
+(e.g., "I need to find papers about attention") to every tool's description.
+If "search_papers" says "Search across all indexed research papers …" that
+is an obvious match. Poor descriptions cause the agent to either skip a
+useful tool or call the wrong one.
+
+THE INPUT/OUTPUT CONTRACT
+--------------------------
+    Input  – a plain string (the search query).
+    Output – a plain string that the agent reads as an observation.
+
+LangChain enforces this contract: whatever your func returns is converted to
+str and injected into the agent's prompt as the "Observation:" line.
+
+WHY TOOL DESCRIPTIONS MUST BE PRECISE
+---------------------------------------
+The agent is stateless — it has no memory of tool internals. 
If the
+description says "search papers" without clarifying the expected input format,
+the agent might pass a JSON object or a question instead of a keyword query,
+producing poor results. Explicit examples in the description (like "Input: a
+search query string") dramatically improve reliability.
+"""
+
+from langchain.tools import Tool
+
+from src.paper_indexer import search_papers
+
+
+def create_search_tool(vector_store) -> Tool:
+    """Build and return a LangChain Tool that searches the FAISS index.
+
+    Parameters
+    ----------
+    vector_store : FAISS
+        The populated FAISS vector store built by paper_indexer.index_papers().
+
+    Returns
+    -------
+    Tool
+        Ready-to-use LangChain Tool instance.
+    """
+
+    def _search(query: str) -> str:
+        """Internal function called by the agent with a plain query string."""
+        docs = search_papers(query, vector_store, k=3)
+
+        if not docs:
+            return "No relevant passages found for that query."
+
+        parts = []
+        for i, doc in enumerate(docs, start=1):
+            source = doc.metadata.get("source", "Unknown paper")
+            # page is 0-indexed in PyPDFLoader; add 1 for human readability
+            page = doc.metadata.get("page", 0) + 1
+            snippet = doc.page_content.strip()[:400]  # keep response concise
+            parts.append(
+                f"[Result {i}]\n"
+                f"  Paper : {source}\n"
+                f"  Page  : {page}\n"
+                f"  Text  : {snippet}…"
+            )
+
+        return "\n\n".join(parts)
+
+    return Tool(
+        name="search_papers",
+        description=(
+            "Search across all indexed research papers. "
+            "Use this to find relevant information, methodologies, or findings. "
+            "Input: a search query string."
+        ),
+        func=_search,
+    )
diff --git a/03-research-agent/src/tools/summary_tool.py b/03-research-agent/src/tools/summary_tool.py
new file mode 100644
index 0000000..b4e4fe9
--- /dev/null
+++ b/03-research-agent/src/tools/summary_tool.py
@@ -0,0 +1,96 @@
+"""
+src/tools/summary_tool.py
+-------------------------
+LangChain Tool that returns a structured summary of a single named paper. 
+
+TOOL PATTERN
+------------
+This tool is a good example of the lookup pattern:
+    - Input  : a (possibly partial) paper title provided by the agent.
+    - Process: find the matching PaperMetadata object, format its fields.
+    - Output : a formatted string the agent can quote in its final answer.
+
+The agent uses this tool when it already knows the paper's name and wants
+detailed information about it, as opposed to the search_papers tool which
+is used when the agent is still looking for relevant papers.
+
+INPUT PARSING: FUZZY TITLE MATCHING
+-------------------------------------
+We do a case-insensitive substring match: a paper titled
+"Attention Is All You Need" will match inputs like "attention", "all you need",
+or the full title. This is intentional — the agent's input may be an
+approximation if it inferred the title from a previous search result.
+
+If multiple papers match, we return the first one (dict insertion order,
+i.e. the order in which the papers were parsed). For a production system you
+might use fuzzy matching (e.g., rapidfuzz) or ask the agent to be more specific.
+"""
+
+from langchain.tools import Tool
+
+
+def create_summary_tool(paper_metadata_dict: dict, llm) -> Tool:
+    """Build a LangChain Tool that summarises a specific paper by title.
+
+    Parameters
+    ----------
+    paper_metadata_dict : dict
+        Maps paper title (str) → PaperMetadata object.
+        Built in main.py as {pm.title: pm for pm in paper_metadata_list}.
+    llm :
+        Unused here but accepted for API consistency with other tool factories. 
+
+    Returns
+    -------
+    Tool
+    """
+
+    def _summarize(title_query: str) -> str:
+        """Find a paper by (partial) title and return a formatted summary."""
+        query_lower = title_query.strip().lower()
+
+        # Fuzzy match: find the first paper whose title contains the query
+        match = None
+        for title, meta in paper_metadata_dict.items():
+            if query_lower in title.lower():
+                match = meta
+                break
+
+        if match is None:
+            available = ", ".join(paper_metadata_dict.keys()) or "none"
+            return (
+                f"No paper found matching '{title_query}'. "
+                f"Available papers: {available}"
+            )
+
+        # Format the metadata as a readable summary
+        authors_str = ", ".join(match.authors) if match.authors else "Unknown"
+        findings_str = (
+            "\n".join(f"  • {f}" for f in match.key_findings)
+            if match.key_findings
+            else "  (not extracted)"
+        )
+        limitations_str = (
+            "\n".join(f"  • {l}" for l in match.limitations)
+            if match.limitations
+            else "  (not extracted)"
+        )
+
+        return (
+            f"Title      : {match.title}\n"
+            f"Authors    : {authors_str}\n"
+            f"Year       : {match.year or 'Unknown'}\n"
+            f"Abstract   : {match.abstract or 'Not available'}\n\n"
+            f"Methodology: {match.methodology or 'Not extracted'}\n\n"
+            f"Key Findings:\n{findings_str}\n\n"
+            f"Limitations:\n{limitations_str}"
+        )
+
+    return Tool(
+        name="summarize_paper",
+        description=(
+            "Get a structured summary of a specific research paper. "
+            "Input: the paper title (or a unique part of it)." 
+        ),
+        func=_summarize,
+    )
diff --git a/04-multimodal-rag/.env.example b/04-multimodal-rag/.env.example
new file mode 100644
index 0000000..ac45447
--- /dev/null
+++ b/04-multimodal-rag/.env.example
@@ -0,0 +1,12 @@
+# OpenAI API Key (required — GPT-4V used for image understanding)
+OPENAI_API_KEY=your_openai_api_key_here
+
+# Model for text generation
+OPENAI_MODEL=gpt-4
+
+# Vision model for image captioning (GPT-4V)
+VISION_MODEL=gpt-4-vision-preview
+
+# Paths for extracted content
+IMAGES_OUTPUT_DIR=data/extracted/images
+TABLES_OUTPUT_DIR=data/extracted/tables
diff --git a/04-multimodal-rag/README.md b/04-multimodal-rag/README.md
new file mode 100644
index 0000000..ab66956
--- /dev/null
+++ b/04-multimodal-rag/README.md
@@ -0,0 +1,235 @@
+# 04 — Multimodal RAG
+
+A Retrieval-Augmented Generation system that understands **text, images, and tables** inside PDF documents.
+
+> A text-only RAG can't answer *"What does the architecture diagram show?"* — but this system can.
+> It extracts every content type from the document, builds a dedicated search index per modality,
+> routes each query to the right index, and generates an answer that explicitly cites whether the
+> information came from a paragraph, a chart, or a data table.
+
+---
+
+## What "Multimodal" Means
+
+| Query | Text-only RAG | This System |
+|---|---|---|
+| "Explain the authentication flow" | ✅ Can answer | ✅ Can answer |
+| "What does the flowchart in section 3 show?" | ❌ Cannot answer | ✅ Captions image, answers from description |
+| "What was Q4 revenue?" 
| ⚠️ Only if table was also in prose | ✅ Extracts table, generates description |
+| "Summarise all key findings" | ✅ Partial | ✅ Draws from text + images + tables |
+
+---
+
+## Architecture
+
+```
+PDF Document
+     │
+     ▼
+┌────────────────────┐
+│ multimodal_parser  │ ── pdfplumber extracts text / tables / images
+└─────────┬──────────┘
+          │
+     ┌────┴───────────────────────────────┐
+     ▼               ▼                    ▼
+┌─────────┐    ┌────────────┐    ┌─────────────────┐
+│  Text   │    │   Images   │    │     Tables      │
+│ Blocks  │    │ (PNG files)│    │ (list of lists) │
+└────┬────┘    └─────┬──────┘    └────────┬────────┘
+     │               │                    │
+     │        GPT-4V caption       LLM description
+     │               │                    │
+     ▼               ▼                    ▼
+┌──────────┐   ┌────────────┐   ┌──────────────────┐
+│  FAISS   │   │   FAISS    │   │      FAISS       │
+│ Text Idx │   │ Image Idx  │   │    Table Idx     │
+└────┬─────┘   └─────┬──────┘   └─────────┬────────┘
+     └───────────────┴────────────────────┘
+                     │
+              ┌──────┴──────┐
+              │ Query Router│ ── classifies query → TEXT / IMAGE / TABLE / ALL
+              └──────┬──────┘
+                     │
+              ┌──────┴──────┐
+              │Multi-Retriev│ ── fetches top-k from relevant indexes, merges
+              └──────┬──────┘
+                     │
+              ┌──────┴──────┐
+              │  Generator  │ ── GPT-4 builds final answer from mixed context
+              └─────────────┘
+```
+
+---
+
+## Model 
Comparison
+
+| Modality | Model | Why |
+|---|---|---|
+| Text | all-MiniLM-L6-v2 | Fast, free, runs locally, strong retrieval quality |
+| Images | GPT-4V (`gpt-4-vision-preview`) | Understands visual content — charts, diagrams, photos |
+| Tables | GPT-3.5 / GPT-4 | Strong at structured data reasoning; converts rows to prose |
+| Generation | GPT-4 | Best reasoning across mixed text / image / table context |
+
+**Alternative (cost-free images):** [LLaVA](https://ollama.com/library/llava) via Ollama runs locally and produces comparable captions without API charges.
+
+---
+
+## Setup
+
+### 1. Clone and enter the project
+```bash
+cd 04-multimodal-rag
+```
+
+### 2. Create and activate a virtual environment
+```bash
+python -m venv .venv
+source .venv/bin/activate    # Windows: .venv\Scripts\activate
+```
+
+### 3. Install dependencies
+```bash
+pip install -r requirements.txt
+```
+
+### 4. Configure environment variables
+```bash
+cp .env.example .env
+# Edit .env and set OPENAI_API_KEY
+```
+
+### 5. Add a PDF document
+```bash
+cp /path/to/your/document.pdf data/sample_docs/
+```
+
+---
+
+## Usage
+
+### Ask a single question
+```bash
+python main.py --file data/sample_docs/annual_report.pdf \
+    --query "What was Q4 revenue?"
+```
+
+### Skip image captioning during development (saves GPT-4V cost)
+```bash
+python main.py --file data/sample_docs/annual_report.pdf \
+    --query "Summarise the key findings" \
+    --skip-images
+```
+
+### Skip both images and tables (fastest, text-only mode)
+```bash
+python main.py --file data/sample_docs/report.pdf \
+    --query "What is the company's strategy?" 
\
+    --skip-images --skip-tables
+```
+
+### Interactive Q&A loop
+```bash
+python main.py --file data/sample_docs/annual_report.pdf --interactive
+```
+
+### Use a different model
+```bash
+python main.py --file data/sample_docs/report.pdf \
+    --query "Describe the architecture diagram" \
+    --model gpt-4o \
+    --vision-model gpt-4o
+```
+
+### Full CLI reference
+```
+--file            Path to PDF document (required)
+--query           Question to answer
+--model           Text generation model (default: gpt-4)
+--vision-model    Vision model for image captioning (default: gpt-4-vision-preview)
+--skip-images     Skip GPT-4V image captioning
+--skip-tables     Skip LLM table description generation
+--interactive     Interactive Q&A loop after indexing
+```
+
+---
+
+## Cost Considerations
+
+> ⚠️ **GPT-4V calls cost more. Use `--skip-images` during development.**
+
+Approximate costs per document (GPT-4V "high" detail, ~1024×1024 images):
+- Each image ≈ 765 input tokens ≈ **$0.008–$0.01** at current pricing
+- A 50-page document with 20 images ≈ **$0.15–$0.20** in image captioning alone
+- Captions are generated once and cached; re-running queries does not re-caption
+
+**Cost optimisation tips:**
+1. `--skip-images` — bypass GPT-4V entirely during development
+2. Pre-generate captions once, save to JSON, reload on subsequent runs
+3. Use LLaVA locally (free) for development, GPT-4V for production
+4. Use `gpt-3.5-turbo` for table descriptions (cheaper, still good at structured data)
+
+---
+
+## What This Can vs Cannot Answer
+
+### CAN answer (multimodal RAG)
+- "What does the flowchart in section 3 show?" → image caption search
+- "What was Q3 revenue according to the table?" → table description search
+- "Describe the network architecture diagram" → image caption search
+- "What were the year-over-year growth percentages?" 
→ table search
+- "Explain the data pipeline shown in figure 2" → image + text combined
+
+### CANNOT answer (text-only RAG)
+- "What does the flowchart in section 3 show?" — no image content indexed
+- "What was Q3 revenue?" — only if the number also appeared in prose
+
+---
+
+## Comparison with Project 1 (Basic RAG)
+
+| Feature | Project 1 (Basic RAG) | Project 4 (Multimodal RAG) |
+|---|---|---|
+| Content types | Text only | Text + Images + Tables |
+| Indexes | 1 FAISS index | 3 FAISS indexes |
+| Embedding model | all-MiniLM-L6-v2 | all-MiniLM-L6-v2 (same) |
+| LLM calls | Generation only | Captioning + Table desc + Classification + Generation |
+| Query routing | None (always searches) | Router classifies query → selects relevant index(es) |
+| Image understanding | ❌ | ✅ GPT-4V captions |
+| Table understanding | ❌ | ✅ LLM-generated descriptions |
+| Cost | Low (local embeddings) | Medium–High (GPT-4V for images) |
+| Complexity | Low | High |
+
+**New concepts introduced in this project:**
+- Multimodal document parsing (pdfplumber)
+- Vision model integration (GPT-4V via base64 image encoding)
+- Multiple specialised FAISS indexes (one per modality)
+- Query routing / intent classification
+- Cross-modality result merging and ranking
+- Modality-aware generation prompts
+
+---
+
+## Project Structure
+
+```
+04-multimodal-rag/
+├── README.md
+├── requirements.txt
+├── .env.example
+├── data/
+│   ├── sample_docs/          ← Put your PDF files here
+│   └── extracted/
+│       ├── images/           ← PNG files extracted from PDFs
+│       └── tables/           ← CSV files extracted from PDFs
+├── src/
+│   ├── multimodal_parser.py  ← PDF → text + images + tables
+│   ├── text_indexer.py       ← FAISS index for text chunks
+│   ├── image_processor.py    ← GPT-4V image captioning
+│   ├── image_indexer.py      ← FAISS index for image captions
+│   ├── table_processor.py    ← LLM table → prose description
+│   ├── 
table_indexer.py      ← FAISS index for table descriptions
+│   ├── query_router.py       ← Classify query → modality(ies)
+│   ├── multi_retriever.py    ← Fetch + merge results from indexes
+│   └── generator.py          ← Build prompt + call LLM
+└── main.py                   ← CLI entry point
+```
diff --git a/04-multimodal-rag/data/extracted/images/.gitkeep b/04-multimodal-rag/data/extracted/images/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/04-multimodal-rag/data/extracted/tables/.gitkeep b/04-multimodal-rag/data/extracted/tables/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/04-multimodal-rag/data/sample_docs/.gitkeep b/04-multimodal-rag/data/sample_docs/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/04-multimodal-rag/main.py b/04-multimodal-rag/main.py
new file mode 100644
index 0000000..4d01078
--- /dev/null
+++ b/04-multimodal-rag/main.py
@@ -0,0 +1,214 @@
+"""
+main.py — Multimodal RAG Pipeline
+-----------------------------------
+Orchestrates the full multimodal retrieval-augmented generation pipeline:
+    1. Parse PDF → extract text, images, tables
+    2. Index each modality in its own FAISS vector store
+    3. Route a user query to the relevant index(es)
+    4. Retrieve top-k results
+    5. Generate a grounded answer
+
+Usage examples
+--------------
+# Full pipeline (index + query)
+python main.py --file data/sample_docs/annual_report.pdf --query "What was Q4 revenue?" 
+
+# Skip image captioning to save GPT-4V cost during development
+python main.py --file data/sample_docs/annual_report.pdf --query "Summarise the findings" --skip-images
+
+# Interactive mode — ask multiple questions after indexing once
+python main.py --file data/sample_docs/annual_report.pdf --interactive
+"""
+
+import argparse
+import os
+import sys
+
+from dotenv import load_dotenv
+from langchain_openai import ChatOpenAI
+from openai import OpenAI
+
+from src.multimodal_parser import parse_document
+from src.text_indexer import index_text_chunks
+from src.image_processor import process_all_images
+from src.image_indexer import index_image_captions
+from src.table_processor import process_all_tables
+from src.table_indexer import index_table_descriptions
+from src.query_router import classify_query
+from src.multi_retriever import retrieve_all, merge_and_rank_results
+from src.generator import generate_answer
+
+
+def build_arg_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        description="Multimodal RAG: answer questions over text, images, and tables in a PDF."
+    )
+    parser.add_argument(
+        "--file",
+        required=True,
+        help="Path to the PDF document to process.",
+    )
+    parser.add_argument(
+        "--query",
+        default=None,
+        help="Question to answer. 
Required unless --interactive is set.",
+    )
+    parser.add_argument(
+        "--model",
+        default=None,
+        help="OpenAI model for text generation (default: env OPENAI_MODEL or gpt-4).",
+    )
+    parser.add_argument(
+        "--vision-model",
+        default=None,
+        dest="vision_model",
+        help="OpenAI vision model for image captioning (default: gpt-4-vision-preview).",
+    )
+    parser.add_argument(
+        "--skip-images",
+        action="store_true",
+        dest="skip_images",
+        help="Skip GPT-4V image captioning (saves cost/time during development).",
+    )
+    parser.add_argument(
+        "--skip-tables",
+        action="store_true",
+        dest="skip_tables",
+        help="Skip LLM-based table description generation.",
+    )
+    parser.add_argument(
+        "--interactive",
+        action="store_true",
+        help="After indexing, enter an interactive Q&A loop.",
+    )
+    return parser
+
+
+def answer_query(
+    query: str,
+    llm,
+    text_index,
+    image_index,
+    table_index,
+) -> str:
+    """Route → retrieve → generate for a single query."""
+    print(f"\n[main] Query: {query}")
+
+    query_types = classify_query(query, llm)
+    print(f"[main] Router selected modalities: {[qt.value for qt in query_types]}")
+
+    raw_results = retrieve_all(
+        query=query,
+        query_types=query_types,
+        text_index=text_index,
+        image_index=image_index,
+        table_index=table_index,
+        k=3,
+    )
+
+    ranked_results = merge_and_rank_results(raw_results)
+    print(f"[main] Retrieved {len(ranked_results)} result(s) after merge/de-dup.")
+
+    answer = generate_answer(query, ranked_results, llm)
+    return answer
+
+
+def main() -> None:
+    load_dotenv()
+
+    parser = build_arg_parser()
+    args = parser.parse_args()
+
+    # ── Validate arguments ────────────────────────────────────────────────────
+    if not args.interactive and args.query is None:
+        parser.error("--query is required unless --interactive is set.")
+
+    if not os.path.isfile(args.file):
+        print(f"[main] ERROR: File not found: {args.file}")
+        sys.exit(1)
+
+    # ── Resolve model names ───────────────────────────────────────────────────
+    
text_model = args.model or os.getenv("OPENAI_MODEL", "gpt-4") + vision_model = args.vision_model or os.getenv("VISION_MODEL", "gpt-4-vision-preview") + images_dir = os.getenv("IMAGES_OUTPUT_DIR", "data/extracted/images") + tables_dir = os.getenv("TABLES_OUTPUT_DIR", "data/extracted/tables") + + # ── Initialise clients ──────────────────────────────────────────────────── + openai_api_key = os.getenv("OPENAI_API_KEY") + if not openai_api_key: + print("[main] ERROR: OPENAI_API_KEY is not set. Copy .env.example to .env and fill it in.") + sys.exit(1) + + llm = ChatOpenAI(model=text_model, openai_api_key=openai_api_key) + openai_client = OpenAI(api_key=openai_api_key) + + # ── Step 1: Parse document ──────────────────────────────────────────────── + print(f"\n[main] Parsing document: {args.file}") + doc = parse_document(args.file, images_dir=images_dir, tables_dir=tables_dir) + + print( + f"[main] Found {len(doc.text_blocks)} text blocks, " + f"{len(doc.image_paths)} images, " + f"{len(doc.tables)} tables." 
+ ) + + # ── Step 2: Index text ──────────────────────────────────────────────────── + text_index = None + if doc.text_blocks: + print(f"\n[main] Indexing {len(doc.text_blocks)} text blocks …") + text_index = index_text_chunks(doc.text_blocks, index_path="text_faiss_index") + else: + print("[main] No text blocks found β€” skipping text index.") + + # ── Step 3: Caption and index images ────────────────────────────────────── + image_index = None + if not args.skip_images and doc.image_paths: + print(f"\n[main] Captioning {len(doc.image_paths)} image(s) with {vision_model} …") + print(" ⚠️ GPT-4V calls cost more than text models.") + print(" Use --skip-images during development to avoid these charges.") + image_data = process_all_images(doc.image_paths, openai_client, vision_model) + print(f"\n[main] Indexing {len(image_data)} image caption(s) …") + image_index = index_image_captions(image_data, index_path="image_faiss_index") + elif args.skip_images: + print("\n[main] --skip-images set: skipping image captioning and indexing.") + else: + print("\n[main] No images found in document.") + + # ── Step 4: Process and index tables ───────────────────────────────────── + table_index = None + if not args.skip_tables and doc.tables: + print(f"\n[main] Processing {len(doc.tables)} table(s) …") + table_data = process_all_tables(doc.tables, llm, tables_dir=tables_dir) + print(f"[main] Indexing {len(table_data)} table description(s) …") + table_index = index_table_descriptions(table_data, index_path="table_faiss_index") + elif args.skip_tables: + print("\n[main] --skip-tables set: skipping table processing and indexing.") + else: + print("\n[main] No tables found in document.") + + # ── Step 5: Answer query / interactive loop ─────────────────────────────── + print("\n" + "─" * 60) + + if args.interactive: + print("[main] Interactive mode. 
Type 'quit' or 'exit' to stop.\n") + while True: + try: + query = input("Question: ").strip() + except (EOFError, KeyboardInterrupt): + print("\n[main] Exiting.") + break + if query.lower() in ("quit", "exit", "q"): + print("[main] Exiting.") + break + if not query: + continue + answer = answer_query(query, llm, text_index, image_index, table_index) + print(f"\nAnswer:\n{answer}\n") + print("─" * 60) + else: + answer = answer_query(args.query, llm, text_index, image_index, table_index) + print(f"\nAnswer:\n{answer}\n") + + +if __name__ == "__main__": + main() diff --git a/04-multimodal-rag/requirements.txt b/04-multimodal-rag/requirements.txt new file mode 100644 index 0000000..72c8ada --- /dev/null +++ b/04-multimodal-rag/requirements.txt @@ -0,0 +1,11 @@ +langchain==0.1.20 +langchain-community==0.0.38 +langchain-openai==0.1.6 +faiss-cpu==1.8.0 +sentence-transformers==2.7.0 +openai==1.30.1 +python-dotenv==1.0.1 +unstructured[pdf]==0.13.7 +pdfplumber==0.11.1 +Pillow==10.3.0 +pandas==2.2.2 diff --git a/04-multimodal-rag/src/__init__.py b/04-multimodal-rag/src/__init__.py new file mode 100644 index 0000000..5a953af --- /dev/null +++ b/04-multimodal-rag/src/__init__.py @@ -0,0 +1,2 @@ +# src/__init__.py +# Multimodal RAG package β€” handles text, image, and table modalities diff --git a/04-multimodal-rag/src/generator.py b/04-multimodal-rag/src/generator.py new file mode 100644 index 0000000..dd04e3d --- /dev/null +++ b/04-multimodal-rag/src/generator.py @@ -0,0 +1,98 @@ +""" +generator.py +------------ +Builds a structured prompt from multimodal retrieved results and calls the +LLM to produce the final answer. + +The prompt explicitly labels each piece of context by modality ([TEXT], +[IMAGE DESCRIPTIONS], [TABLE DATA]) so the model can reason about the +*source* of information β€” e.g. "the bar chart (image) shows Q4 was highest, +while the revenue table confirms $1.2M." 
The model is instructed to +acknowledge which modality informed its answer, which improves transparency +and helps users verify the response against the source document. +""" + + +def generate_answer( + query: str, + retrieved_results: list[dict], + llm, + include_image_refs: bool = True, +) -> str: + """ + Generate a natural-language answer from multimodal retrieved context. + + Parameters + ---------- + query : The user's original question. + retrieved_results : Combined, ranked list from multi_retriever.merge_and_rank_results(). + llm : LangChain LLM / chat model. + include_image_refs: When True, append "See image: <path>" lines for any + image results so the user knows where to look. + + Returns + ------- + Formatted answer string. + """ + # ── Separate results by modality ───────────────────────────────────────── + text_chunks: list[str] = [] + image_captions: list[str] = [] + table_descriptions: list[str] = [] + image_refs: list[str] = [] + + for result in retrieved_results: + modality = result.get("modality", "text") + content = result.get("content", "").strip() + + if modality == "text": + text_chunks.append(content) + elif modality == "image": + image_captions.append(content) + if include_image_refs: + img_path = result.get("metadata", {}).get("image_path", "") + if img_path: + image_refs.append(img_path) + elif modality == "table": + table_descriptions.append(content) + + # ── Build context sections ──────────────────────────────────────────────── + text_section = "\n\n".join(text_chunks) if text_chunks else "No text context available." + image_section = "\n\n".join(image_captions) if image_captions else "No image context available." + table_section = "\n\n".join(table_descriptions) if table_descriptions else "No table context available." + + # ── Assemble prompt ─────────────────────────────────────────────────────── + prompt = f"""\ +Answer the following question based on the provided context from a document. 
+The context includes text, image descriptions, and table data. + +Context: +[TEXT] +{text_section} + +[IMAGE DESCRIPTIONS] +{image_section} + +[TABLE DATA] +{table_section} + +Question: {query} + +Answer (mention which type of content informed your answer β€” text/image/table):""" + + # ── Call the LLM ────────────────────────────────────────────────────────── + try: + if hasattr(llm, "invoke"): + response = llm.invoke(prompt) + answer = response.content if hasattr(response, "content") else str(response) + else: + answer = llm.predict(prompt) + answer = answer.strip() + except Exception as exc: + answer = f"[generator] LLM call failed: {exc}" + + # ── Append image references if requested ───────────────────────────────── + if include_image_refs and image_refs: + refs_block = "\n".join(f"See image: {path}" for path in image_refs) + answer = f"{answer}\n\n{refs_block}" + + return answer diff --git a/04-multimodal-rag/src/image_indexer.py b/04-multimodal-rag/src/image_indexer.py new file mode 100644 index 0000000..4cb286b --- /dev/null +++ b/04-multimodal-rag/src/image_indexer.py @@ -0,0 +1,124 @@ +""" +image_indexer.py +---------------- +Builds and queries a FAISS vector index over image *captions*. + +Key insight: we are searching text (captions), but returning image references. +The caption is the searchable representation; the metadata carries the file +path so callers can retrieve or display the actual image. This pattern β€” +"index the description, store the reference" β€” is the standard approach for +making non-text assets semantically searchable without specialised multimodal +embedding models. 
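The "index the description, store the reference" pattern can be sketched without FAISS at all. A toy illustration (hypothetical captions and paths, with crude token-overlap scoring standing in for real embedding similarity):

```python
# Toy version of "index the description, store the reference":
# the searchable text is the caption; the payload is the image path.
records = [
    {"caption": "bar chart of quarterly revenue by region",
     "image_path": "img/page3_fig1.png"},
    {"caption": "architecture diagram of the data pipeline",
     "image_path": "img/page7_fig2.png"},
]

def search(query: str) -> dict:
    # Crude stand-in for embedding similarity: count shared lowercase tokens.
    q = set(query.lower().split())
    return max(records, key=lambda r: len(q & set(r["caption"].lower().split())))

hit = search("diagram of the pipeline")
print(hit["image_path"])  # img/page7_fig2.png (an image found via its caption)
```

Swap the toy scorer for sentence-transformer embeddings and you get exactly the behaviour this module implements with FAISS.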
+""" + +from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_community.vectorstores import FAISS +from langchain.schema import Document + + +_EMBED_MODEL_NAME = "all-MiniLM-L6-v2" + + +def _get_embeddings() -> HuggingFaceEmbeddings: + return HuggingFaceEmbeddings(model_name=_EMBED_MODEL_NAME) + + +def index_image_captions( + image_data: list[dict], + index_path: str = "image_faiss_index", +) -> FAISS: + """ + Embed image captions and save a FAISS index to disk. + + Parameters + ---------- + image_data : List of dicts with keys "image_path" and "caption" + (as returned by image_processor.process_all_images()). + index_path : Directory where FAISS index files are written. + + Returns + ------- + A LangChain FAISS vector store whose documents are captions with + image_path stored in metadata. + """ + if not image_data: + raise ValueError("image_data is empty β€” nothing to index.") + + # page_content is the caption text that will be embedded and searched. + # metadata carries the image_path so we can return the file reference + # when this document is retrieved. + docs = [ + Document( + page_content=item["caption"], + metadata={ + "image_path": item["image_path"], + "image_type": item.get("image_type", "figure"), + "modality": "image", + }, + ) + for item in image_data + ] + + embeddings = _get_embeddings() + vector_store = FAISS.from_documents(docs, embeddings) + vector_store.save_local(index_path) + + print(f"[image_indexer] Indexed {len(docs)} image captions β†’ '{index_path}'") + return vector_store + + +def load_image_index(index_path: str) -> FAISS: + """ + Load a previously saved FAISS image-caption index from disk. + + Parameters + ---------- + index_path : Directory path passed to index_image_captions(). + + Returns + ------- + A LangChain FAISS vector store. 
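Typical call pattern: build the index once, then load the cached copy on later runs. A self-contained sketch of the load-or-build decision, with a JSON file standing in for the FAISS index directory:

```python
import json
import os
import tempfile

def load_or_build(path: str, build_fn):
    # Load a cached artifact if present; otherwise build it, save it, return it.
    if os.path.exists(path):
        with open(path) as f:
            return json.load(f)
    artifact = build_fn()
    with open(path, "w") as f:
        json.dump(artifact, f)
    return artifact

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "index.json")
    first = load_or_build(path, lambda: {"built": 1})   # builds and saves
    second = load_or_build(path, lambda: {"built": 2})  # loads the cached copy
    print(first, second)
```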
+ """ + embeddings = _get_embeddings() + vector_store = FAISS.load_local( + index_path, embeddings, allow_dangerous_deserialization=True + ) + print(f"[image_indexer] Loaded image index from '{index_path}'") + return vector_store + + +def search_images( + query: str, + vector_store: FAISS, + k: int = 3, +) -> list[dict]: + """ + Retrieve the top-k image captions most relevant to a query. + + Parameters + ---------- + query : Natural language question or search string. + vector_store : A loaded or freshly-built FAISS image-caption index. + k : Number of results to return. + + Returns + ------- + List of dicts: + { + "caption" : str β€” the generated image description + "image_path" : str β€” path to the original image file + "image_type" : str β€” coarse type (chart, diagram, photo, …) + "score" : float β€” FAISS L2 distance (lower = more similar) + } + """ + raw_results = vector_store.similarity_search_with_score(query, k=k) + + return [ + { + "caption": doc.page_content, + "image_path": doc.metadata.get("image_path", ""), + "image_type": doc.metadata.get("image_type", "figure"), + "score": float(score), + } + for doc, score in raw_results + ] diff --git a/04-multimodal-rag/src/image_processor.py b/04-multimodal-rag/src/image_processor.py new file mode 100644 index 0000000..7b7b46d --- /dev/null +++ b/04-multimodal-rag/src/image_processor.py @@ -0,0 +1,168 @@ +""" +image_processor.py +------------------ +Converts image files into natural-language captions using GPT-4V (or any +compatible OpenAI vision model), making images semantically searchable. + +Why convert images to text captions? +-------------------------------------- +Semantic search engines (FAISS + sentence-transformers) operate in *text* +embedding space. A raw PNG file cannot be compared to a natural-language +query like "architecture diagram of the data pipeline." 
+ +By asking GPT-4V to *describe* an image in detail, we produce a text string +that captures the visual content β€” labels, shapes, data, layout β€” in a form +that a sentence-transformer can embed and a user query can match against. + +How GPT-4V works +----------------- +GPT-4V (gpt-4-vision-preview) is a multimodal large language model that +accepts *both* text and images in the same prompt. Images are supplied as +base64-encoded strings inside a message with role "user". + +The base64 encoding pattern: + 1. Read the image file in binary mode. + 2. Encode with base64.b64encode(raw_bytes).decode("utf-8"). + 3. Pass as {"type": "image_url", "image_url": {"url": "data:image/png;base64,<b64>"}} + inside the messages list. + +Cost consideration ⚠️ +---------------------- +GPT-4V is significantly more expensive than text-only GPT models: + * A 1024Γ—1024 image costs roughly 765 tokens at the "high" detail setting. + * Caption all images once, then **cache** the results to avoid re-captioning + on every run. The main pipeline serialises captions to disk for this reason. + +Alternative: LLaVA +------------------- +LLaVA (Large Language and Vision Assistant) is an open-source vision model +that runs locally with Ollama β€” zero API cost. Swap `caption_image` to call +`ollama.chat(model="llava", ...)` for a cost-free local alternative, at the +expense of some caption quality. +""" + +import base64 +import io + +from PIL import Image + + +def caption_image( + image_path: str, + openai_client, + vision_model: str = "gpt-4-vision-preview", +) -> dict: + """ + Generate a detailed text caption for a single image using GPT-4V. + + Parameters + ---------- + image_path : Path to the image file (PNG, JPEG, etc.). + openai_client : An initialised openai.OpenAI() client instance. + vision_model : OpenAI vision model identifier. 
+ + Returns + ------- + dict with keys: + "image_path" β€” the original path (used as a reference in search results) + "caption" β€” the generated natural-language description + "image_type" β€” coarse type extracted from the caption (e.g. "chart") + """ + # ── Step 1: read and base64-encode the image ───────────────────────────── + with open(image_path, "rb") as f: + raw_bytes = f.read() + + # Normalise to PNG via PIL to ensure a consistent MIME type. + pil_img = Image.open(io.BytesIO(raw_bytes)).convert("RGB") + png_buffer = io.BytesIO() + pil_img.save(png_buffer, format="PNG") + b64_image = base64.b64encode(png_buffer.getvalue()).decode("utf-8") + + # ── Step 2: build the GPT-4V prompt ────────────────────────────────────── + # The data-URI scheme embeds the image directly in the JSON payload. + data_uri = f"data:image/png;base64,{b64_image}" + + prompt_text = ( + "Describe this image in detail for a document search system. " + "Include: what the image shows, any text visible, any data or statistics shown, " + "the type of visualization (chart, diagram, photo, etc.)." + ) + + try: + response = openai_client.chat.completions.create( + model=vision_model, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt_text}, + {"type": "image_url", "image_url": {"url": data_uri}}, + ], + } + ], + max_tokens=512, + ) + caption = response.choices[0].message.content.strip() + + # Derive a coarse image_type by scanning the caption for keywords. + image_type = _infer_image_type(caption) + + except Exception as exc: + # Graceful degradation: if GPT-4V is unavailable (quota, model access, + # or network issue) we return a placeholder so the pipeline keeps running. + # The placeholder still gets indexed; it just won't match queries well. 
+ print(f" [image_processor] GPT-4V unavailable for '{image_path}': {exc}") + caption = f"[Image caption unavailable β€” {image_path}]" + image_type = "unknown" + + return { + "image_path": image_path, + "caption": caption, + "image_type": image_type, + } + + +def process_all_images( + image_paths: list[str], + openai_client, + vision_model: str = "gpt-4-vision-preview", +) -> list[dict]: + """ + Caption every image in the list and return combined results. + + Parameters + ---------- + image_paths : List of file paths returned by the multimodal parser. + openai_client : An initialised openai.OpenAI() client instance. + vision_model : OpenAI vision model identifier. + + Returns + ------- + List of caption dicts (same structure as caption_image() return value). + + Note: captioning is done sequentially to stay within rate limits. + For large document sets, consider batching with a short sleep between calls. + """ + results = [] + for idx, path in enumerate(image_paths, start=1): + print(f" [image_processor] Captioning image {idx}/{len(image_paths)}: {path}") + result = caption_image(path, openai_client, vision_model) + results.append(result) + return results + + +# ── Private helpers ────────────────────────────────────────────────────────── + + +def _infer_image_type(caption: str) -> str: + """Heuristically classify the image type from its caption text.""" + caption_lower = caption.lower() + if any(w in caption_lower for w in ("chart", "bar", "pie", "line graph", "plot")): + return "chart" + if any(w in caption_lower for w in ("diagram", "flowchart", "architecture", "uml")): + return "diagram" + if any(w in caption_lower for w in ("table", "matrix", "grid")): + return "table_image" + if any(w in caption_lower for w in ("photo", "photograph", "picture", "image of")): + return "photo" + return "figure" diff --git a/04-multimodal-rag/src/multi_retriever.py b/04-multimodal-rag/src/multi_retriever.py new file mode 100644 index 0000000..dac6c52 --- /dev/null +++ 
b/04-multimodal-rag/src/multi_retriever.py
@@ -0,0 +1,169 @@
+"""
+multi_retriever.py
+------------------
+Queries each relevant FAISS index in turn, based on the query types returned
+by the router, then merges the results into a single ranked list.
+
+The challenge of ranking across modalities
+------------------------------------------
+Each FAISS index returns an L2 distance score in the embedding space of
+all-MiniLM-L6-v2 (384 dimensions). Because all three indexes use the *same*
+embedding model, scores are theoretically comparable β€” but in practice:
+
+  * The distribution of scores differs by modality (short captions tend to
+    have lower variance than long text chunks).
+  * A "0.3" distance for a text chunk may not be semantically equivalent to a
+    "0.3" distance for an image caption.
+
+Two ranking strategies are discussed here:
+
+  Simple (implemented): interleave results β€” 1 text result, 1 image result,
+  1 table result β€” so every modality is represented in the context, regardless
+  of raw score magnitude. Easy to implement, transparent to the user.
+
+  Complex (alternative): normalise scores per modality using min-max scaling,
+  then sort globally. More precise, but it can still suppress a modality
+  entirely if its scores are consistently higher (worse) than the others.
+
+We use the simple interleaving approach and let the generator model weight
+results contextually via its attention mechanism.
+
+De-duplication
+--------------
+The same text snippet can theoretically appear in multiple indexes (e.g. a
+table that was also quoted verbatim in the text). We de-duplicate on the
+content string to avoid feeding the same information twice to the generator.
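The "complex alternative" above can be sketched in a few lines: per-modality min-max scaling, then a global sort. Illustrative only (toy scores); the shipped code uses interleaving instead:

```python
def normalise_and_rank(results: list[dict]) -> list[dict]:
    # Group by modality, rescale each group's L2 distances into [0, 1],
    # then sort everything on the normalised score (lower = more similar).
    buckets: dict[str, list[dict]] = {}
    for r in results:
        buckets.setdefault(r["modality"], []).append(r)
    for bucket in buckets.values():
        lo = min(r["score"] for r in bucket)
        hi = max(r["score"] for r in bucket)
        span = (hi - lo) or 1.0  # all-equal bucket: avoid division by zero
        for r in bucket:
            r["norm"] = (r["score"] - lo) / span
    return sorted(results, key=lambda r: r["norm"])

hits = [
    {"modality": "text", "score": 0.90},
    {"modality": "text", "score": 0.30},
    {"modality": "image", "score": 0.55},
    {"modality": "image", "score": 0.50},
]
ranked = normalise_and_rank(hits)
print([(r["modality"], r["norm"]) for r in ranked])
# [('text', 0.0), ('image', 0.0), ('text', 1.0), ('image', 1.0)]
```

Note how the image bucket's tiny 0.05 raw spread is stretched to the same [0, 1] range as the text bucket's 0.60 spread, which is both the strength and the distortion of this approach.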
+""" + +from langchain_community.vectorstores import FAISS + +from .query_router import QueryType +from .text_indexer import search_text +from .image_indexer import search_images +from .table_indexer import search_tables + + +def retrieve_all( + query: str, + query_types: list[QueryType], + text_index: FAISS | None, + image_index: FAISS | None, + table_index: FAISS | None, + k: int = 3, +) -> list[dict]: + """ + Retrieve top-k results from each relevant index and return a combined list. + + Parameters + ---------- + query : User's natural-language question. + query_types : List of QueryType values from the router. + text_index : Loaded FAISS text index (or None if not built). + image_index : Loaded FAISS image-caption index (or None if not built). + table_index : Loaded FAISS table-description index (or None if not built). + k : Number of results to fetch from each relevant index. + + Returns + ------- + List of result dicts: + { + "content" : str β€” the text content (chunk / caption / description) + "modality" : str β€” "text" | "image" | "table" + "metadata" : dict β€” index-specific metadata (image_path, csv_path, etc.) 
+ "source" : str β€” human-readable source label + "score" : float + } + """ + results: list[dict] = [] + + if QueryType.TEXT in query_types and text_index is not None: + for doc, score in search_text(query, text_index, k=k): + results.append( + { + "content": doc.page_content, + "modality": "text", + "metadata": doc.metadata, + "source": f"text_chunk_{doc.metadata.get('chunk_id', '?')}", + "score": float(score), + } + ) + + if QueryType.IMAGE in query_types and image_index is not None: + for item in search_images(query, image_index, k=k): + results.append( + { + "content": item["caption"], + "modality": "image", + "metadata": { + "image_path": item["image_path"], + "image_type": item["image_type"], + }, + "source": item["image_path"], + "score": item["score"], + } + ) + + if QueryType.TABLE in query_types and table_index is not None: + for item in search_tables(query, table_index, k=k): + results.append( + { + "content": item["description"], + "modality": "table", + "metadata": { + "table_id": item["table_id"], + "csv_path": item["csv_path"], + "page": item["page"], + }, + "source": item["table_id"], + "score": item["score"], + } + ) + + return results + + +def merge_and_rank_results(results: list[dict]) -> list[dict]: + """ + De-duplicate and interleave results across modalities. + + De-duplication is done on content string (exact match). Ranking uses a + simple modality-interleaving strategy: we pick results round-robin from + text β†’ image β†’ table buckets so every modality is represented early in + the context window. + + Parameters + ---------- + results : Combined list from retrieve_all(). + + Returns + ------- + De-duplicated, interleaved list of result dicts. + """ + # De-duplicate on content string. + seen_content: set[str] = set() + unique: list[dict] = [] + for r in results: + if r["content"] not in seen_content: + seen_content.add(r["content"]) + unique.append(r) + + # Separate into modality buckets. 
+ buckets: dict[str, list[dict]] = {"text": [], "image": [], "table": []} + for r in unique: + bucket_key = r["modality"] if r["modality"] in buckets else "text" + buckets[bucket_key].append(r) + + # Sort each bucket by ascending score (lower L2 = more similar). + for bucket in buckets.values(): + bucket.sort(key=lambda x: x["score"]) + + # Interleave: take one from each non-empty bucket in rotation. + merged: list[dict] = [] + order = ["text", "image", "table"] + max_len = max((len(b) for b in buckets.values()), default=0) + for i in range(max_len): + for modality in order: + if i < len(buckets[modality]): + merged.append(buckets[modality][i]) + + return merged diff --git a/04-multimodal-rag/src/multimodal_parser.py b/04-multimodal-rag/src/multimodal_parser.py new file mode 100644 index 0000000..c438816 --- /dev/null +++ b/04-multimodal-rag/src/multimodal_parser.py @@ -0,0 +1,163 @@ +""" +multimodal_parser.py +-------------------- +Parses a PDF document and extracts three distinct modalities: + 1. Text blocks β€” raw text per page, ready for embedding + 2. Images β€” saved as PNG files; need vision model captioning before embedding + 3. Tables β€” extracted as list-of-lists, converted to dict rows for downstream processing + +Why separate modalities before indexing? +----------------------------------------- +Each content type requires a completely different processing pipeline: + + Text β†’ can be chunked and embedded directly with a sentence-transformer. + + Images β†’ embedding raw pixel data is rarely useful for Q&A. Instead we use a + vision model (GPT-4V or LLaVA) to *describe* each image in plain English, + then embed that description. This bridges the "semantic gap" between a + pixel array and a natural-language query. + + Tables β†’ 2-D structured data doesn't embed well as a flat string of cell values. 
+ We convert each table into a short natural-language paragraph + ("Q1 revenue was $1 M, up 12 % year-over-year …") that a sentence- + transformer can compare against a user question. + +Limitations +----------- + * pdfplumber excels at text and table extraction from text-based PDFs. + * Image extraction relies on the PDF's internal XObject stream; quality varies. + Scanned PDFs with no embedded images will yield zero images here. + * Large tables spanning multiple pages may be split; downstream code should + handle partial tables gracefully. +""" + +import os +from dataclasses import dataclass, field +from pathlib import Path + +import pdfplumber +from PIL import Image + + +@dataclass +class ParsedDocument: + """Container for all content extracted from a single PDF.""" + + file_name: str + # One entry per page; each entry is the full text of that page. + text_blocks: list[str] = field(default_factory=list) + # Absolute/relative paths to saved PNG files extracted from the PDF. + image_paths: list[str] = field(default_factory=list) + # Each table is a dict with keys "rows" (list[list]) and "page" (int). + tables: list[dict] = field(default_factory=list) + + +def parse_document( + file_path: str, + images_dir: str = "data/extracted/images", + tables_dir: str = "data/extracted/tables", +) -> ParsedDocument: + """ + Open a PDF and extract text, images, and tables into a ParsedDocument. + + Parameters + ---------- + file_path : Path to the source PDF file. + images_dir : Directory where extracted PNG images are saved. + tables_dir : Directory where extracted tables are saved (CSV, handled downstream). + + Returns + ------- + ParsedDocument with text_blocks, image_paths, and tables populated. 
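Downstream processors usually reshape each returned table's "rows" list into dict rows keyed by the header row. A toy sketch (hypothetical data, applying the same None-to-empty-string cleanup this function performs):

```python
raw_table = [
    ["Quarter", "Revenue", "Growth"],
    ["Q1", "1000000", None],
    ["Q2", "1200000", "20%"],
]
header, *data = [[cell if cell is not None else "" for cell in row]
                 for row in raw_table]
dict_rows = [dict(zip(header, row)) for row in data]
print(dict_rows[0])  # {'Quarter': 'Q1', 'Revenue': '1000000', 'Growth': ''}
```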
+ """ + Path(images_dir).mkdir(parents=True, exist_ok=True) + Path(tables_dir).mkdir(parents=True, exist_ok=True) + + file_name = Path(file_path).stem + text_blocks: list[str] = [] + image_paths: list[str] = [] + tables: list[dict] = [] + + with pdfplumber.open(file_path) as pdf: + for page_num, page in enumerate(pdf.pages, start=1): + + # ── 1. TEXT ────────────────────────────────────────────────────────── + # extract_text() returns the full text of the page as a single string. + # We keep one block per page; callers can chunk further if needed. + page_text = page.extract_text() or "" + if page_text.strip(): + text_blocks.append(page_text.strip()) + + # ── 2. TABLES ──────────────────────────────────────────────────────── + # extract_tables() returns a list of tables; each table is a list of + # rows, and each row is a list of cell values (strings or None). + for table_idx, raw_table in enumerate(page.extract_tables()): + # Replace None cells with empty string to avoid downstream errors. + clean_rows = [ + [cell if cell is not None else "" for cell in row] + for row in raw_table + ] + tables.append( + { + "rows": clean_rows, # list[list[str]] + "page": page_num, + "table_index": table_idx, + } + ) + + # ── 3. IMAGES ──────────────────────────────────────────────────────── + # pdfplumber exposes raw image XObjects via page.images. + # Each entry is a dict with keys: "stream" (raw bytes), "x0", "y0", + # "x1", "y1", "width", "height", etc. + # We reconstruct a PIL Image from the raw stream and save as PNG. + for img_idx, img_meta in enumerate(page.images): + try: + raw_stream = img_meta.get("stream") + if raw_stream is None: + continue + + # The stream is a pdfplumber PDFStream object; get its raw data. + raw_data = ( + raw_stream.get_data() + if hasattr(raw_stream, "get_data") + else bytes(raw_stream) + ) + + # Attempt to open as a PIL Image (handles JPEG, PNG, etc.). 
+ import io + try: + pil_img = Image.open(io.BytesIO(raw_data)) + pil_img = pil_img.convert("RGB") # normalise colour mode + except Exception: + # The raw bytes may be raw pixel data rather than an encoded + # image format. Fall back to using width/height from metadata. + width = int(img_meta.get("width", 100)) + height = int(img_meta.get("height", 100)) + pil_img = Image.frombytes("RGB", (width, height), raw_data) + + img_filename = f"{file_name}_page{page_num}_img{img_idx}.png" + img_save_path = os.path.join(images_dir, img_filename) + pil_img.save(img_save_path, format="PNG") + image_paths.append(img_save_path) + + except Exception as exc: + # Non-fatal: log and continue β€” a single bad image shouldn't + # abort extraction of the rest of the document. + print( + f" [parser] Could not extract image {img_idx} " + f"on page {page_num}: {exc}" + ) + + print( + f"[parser] '{file_name}': " + f"{len(text_blocks)} text blocks, " + f"{len(image_paths)} images, " + f"{len(tables)} tables extracted." + ) + + return ParsedDocument( + file_name=file_name, + text_blocks=text_blocks, + image_paths=image_paths, + tables=tables, + ) diff --git a/04-multimodal-rag/src/query_router.py b/04-multimodal-rag/src/query_router.py new file mode 100644 index 0000000..8217fbe --- /dev/null +++ b/04-multimodal-rag/src/query_router.py @@ -0,0 +1,114 @@ +""" +query_router.py +--------------- +Classifies an incoming user query to determine which content modalities +(text, image, table) are most likely to contain the answer, then routes +retrieval to the appropriate FAISS indexes. + +Why routing matters +-------------------- +Without routing every query would hit all three indexes, which: + * Wastes embedding / similarity-search compute. + * Inflates cost when GPT-4V-captioned image indexes are large. + * Dilutes the final context with irrelevant cross-modal results. 
+ +By classifying upfront we retrieve *only* from relevant indexes, reducing +latency and cost while keeping the context focused. + +When to use ALL +---------------- +Complex questions (e.g. "Summarise the findings from section 2") often span +all content types. When the classifier is uncertain it returns ALL, which is +the safe default β€” it is better to over-search than to miss the answer. + +Parsing the LLM output +----------------------- +We ask the LLM to respond with a JSON object `{"types": [...]}` to make +parsing deterministic. If the response cannot be parsed as JSON we fall back +to ALL to maintain correctness at the cost of a broader search. +""" + +import json +import re +from enum import Enum + + +class QueryType(Enum): + TEXT = "TEXT" + IMAGE = "IMAGE" + TABLE = "TABLE" + ALL = "ALL" + + +_CLASSIFICATION_PROMPT = """\ +Classify this query to determine which type of document content would best answer it. + +Query: {query} + +Choose one or more from: +- TEXT: The answer is likely in text paragraphs +- IMAGE: The answer requires looking at a visual/diagram/photo +- TABLE: The answer requires numerical data from a table or chart +- ALL: Search all content types + +Common patterns: +- "show me", "what does X look like", "diagram of" β†’ IMAGE +- "how many", "revenue", "statistics", "percentage", "trend" β†’ TABLE +- "explain", "describe", "what is", "how does" β†’ TEXT +- Complex questions β†’ ALL + +Respond with JSON only: {{"types": ["TEXT", "TABLE"]}} +""" + + +def classify_query(query: str, llm) -> list[QueryType]: + """ + Ask the LLM to classify a user query by relevant content modality. + + Parameters + ---------- + query : The user's natural-language question. + llm : A LangChain LLM / chat model that supports .invoke() or .predict(). + + Returns + ------- + List of QueryType enum values indicating which indexes to search. + Falls back to [QueryType.ALL] on any parsing error. 
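The parsing contract can be exercised without an LLM. A toy mirror of the extraction logic in the function body (same regex and ALL fallback, simplified error handling):

```python
import json
import re

def parse_types(raw: str) -> list[str]:
    # Find the first {...} object in the reply; fall back to ALL on any failure.
    match = re.search(r"\{.*?\}", raw, re.DOTALL)
    if not match:
        return ["ALL"]
    try:
        return json.loads(match.group()).get("types", ["ALL"])
    except json.JSONDecodeError:
        return ["ALL"]

print(parse_types('Sure! {"types": ["TEXT", "TABLE"]}'))  # ['TEXT', 'TABLE']
print(parse_types("I am not sure."))                      # ['ALL']
```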
+ """ + prompt = _CLASSIFICATION_PROMPT.format(query=query) + + try: + if hasattr(llm, "invoke"): + response = llm.invoke(prompt) + raw = response.content if hasattr(response, "content") else str(response) + else: + raw = llm.predict(prompt) + + # Extract JSON from the response β€” the model may wrap it in markdown fences. + json_match = re.search(r"\{.*?\}", raw, re.DOTALL) + if not json_match: + raise ValueError("No JSON object found in LLM response.") + + parsed = json.loads(json_match.group()) + type_strings: list[str] = parsed.get("types", ["ALL"]) + + query_types = [] + for t in type_strings: + t_upper = t.upper() + if t_upper == "ALL": + # ALL expands to all three specific types. + return [QueryType.TEXT, QueryType.IMAGE, QueryType.TABLE] + try: + query_types.append(QueryType(t_upper)) + except ValueError: + pass # Unknown type string β€” skip. + + if not query_types: + raise ValueError("No valid QueryType values parsed.") + + return query_types + + except Exception as exc: + # Fallback: search everything rather than potentially missing the answer. + print(f" [query_router] Classification failed ({exc}) β€” defaulting to ALL.") + return [QueryType.TEXT, QueryType.IMAGE, QueryType.TABLE] diff --git a/04-multimodal-rag/src/table_indexer.py b/04-multimodal-rag/src/table_indexer.py new file mode 100644 index 0000000..172f898 --- /dev/null +++ b/04-multimodal-rag/src/table_indexer.py @@ -0,0 +1,123 @@ +""" +table_indexer.py +---------------- +Builds and queries a FAISS vector index over natural-language table +descriptions produced by table_processor.py. + +The descriptions are embedded with the same all-MiniLM-L6-v2 model used for +text and image captions, giving us a single, consistent semantic space across +all three modalities. The metadata carries the table_id and csv_path so +callers can retrieve the exact CSV data when needed. 
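Why keep csv_path in metadata? A retrieved description only tells you *which* table matters; the CSV then yields the exact figures. A sketch with an in-memory CSV standing in for the file a search hit would point at:

```python
import csv

# Stand-in for open(hit["csv_path"]): csv.DictReader accepts any iterable of lines.
csv_lines = ["Quarter,Revenue", "Q1,1000000", "Q2,1200000"]
rows = list(csv.DictReader(csv_lines))
q2 = next(row for row in rows if row["Quarter"] == "Q2")
print(q2["Revenue"])  # 1200000
```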
+""" + +from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_community.vectorstores import FAISS +from langchain.schema import Document + + +_EMBED_MODEL_NAME = "all-MiniLM-L6-v2" + + +def _get_embeddings() -> HuggingFaceEmbeddings: + return HuggingFaceEmbeddings(model_name=_EMBED_MODEL_NAME) + + +def index_table_descriptions( + table_data: list[dict], + index_path: str = "table_faiss_index", +) -> FAISS: + """ + Embed table descriptions and save a FAISS index to disk. + + Parameters + ---------- + table_data : List of dicts with keys "table_id", "csv_path", + "description" (as returned by table_processor.process_all_tables()). + index_path : Directory where FAISS index files are written. + + Returns + ------- + A LangChain FAISS vector store whose documents are descriptions with + table_id and csv_path stored in metadata. + """ + if not table_data: + raise ValueError("table_data is empty β€” nothing to index.") + + docs = [ + Document( + page_content=item["description"], + metadata={ + "table_id": item["table_id"], + "csv_path": item["csv_path"], + "page": item.get("page", 0), + "modality": "table", + }, + ) + for item in table_data + ] + + embeddings = _get_embeddings() + vector_store = FAISS.from_documents(docs, embeddings) + vector_store.save_local(index_path) + + print(f"[table_indexer] Indexed {len(docs)} table descriptions β†’ '{index_path}'") + return vector_store + + +def load_table_index(index_path: str) -> FAISS: + """ + Load a previously saved FAISS table-description index from disk. + + Parameters + ---------- + index_path : Directory path passed to index_table_descriptions(). + + Returns + ------- + A LangChain FAISS vector store. 
+ """ + embeddings = _get_embeddings() + vector_store = FAISS.load_local( + index_path, embeddings, allow_dangerous_deserialization=True + ) + print(f"[table_indexer] Loaded table index from '{index_path}'") + return vector_store + + +def search_tables( + query: str, + vector_store: FAISS, + k: int = 3, +) -> list[dict]: + """ + Retrieve the top-k table descriptions most relevant to a query. + + Parameters + ---------- + query : Natural language question or search string. + vector_store : A loaded or freshly-built FAISS table index. + k : Number of results to return. + + Returns + ------- + List of dicts: + { + "description" : str β€” natural-language summary of the table + "table_id" : str β€” unique table identifier + "csv_path" : str β€” path to the raw CSV file + "page" : int β€” source page in the original document + "score" : float β€” FAISS L2 distance (lower = more similar) + } + """ + raw_results = vector_store.similarity_search_with_score(query, k=k) + + return [ + { + "description": doc.page_content, + "table_id": doc.metadata.get("table_id", ""), + "csv_path": doc.metadata.get("csv_path", ""), + "page": doc.metadata.get("page", 0), + "score": float(score), + } + for doc, score in raw_results + ] diff --git a/04-multimodal-rag/src/table_processor.py b/04-multimodal-rag/src/table_processor.py new file mode 100644 index 0000000..2ef01ce --- /dev/null +++ b/04-multimodal-rag/src/table_processor.py @@ -0,0 +1,162 @@ +""" +table_processor.py +------------------ +Converts extracted tables into natural-language descriptions suitable for +semantic search, while also persisting the raw data as CSV files. + +The challenge of searching tabular data semantically +----------------------------------------------------- +Tables are inherently 2-D structured objects. A flat string representation +like "Q1 | 1000000 | Q2 | 1200000" is syntactically correct but semantically +opaque to a sentence-transformer trained on prose. + +Why generate natural-language descriptions? 
+-------------------------------------------- +Text embedding models (and LLMs used for generation) understand sentences +like "Q1 revenue was $1 M, representing 12 % growth quarter-over-quarter" +far better than a raw CSV row. By asking an LLM to paraphrase a table, we +convert the structured 2-D data into a format that: + 1. Embeds meaningfully with all-MiniLM-L6-v2. + 2. Matches natural-language queries ("What were the Q1 sales figures?"). + 3. Can be injected verbatim into a generation prompt for the final answer. + +Why keep the raw CSV too? +-------------------------- +Natural-language descriptions are lossy β€” they summarise, not enumerate. +For exact queries ("What was the exact revenue in row 7, column 3?") or for +programmatic downstream use (pandas, Excel), the CSV is the ground truth. +We store both and surface whichever is appropriate. +""" + +import csv +import os +from pathlib import Path + + +def table_to_description(table: list[list], llm) -> str: + """ + Use an LLM to convert a raw table (list of rows) into a prose description. + + Parameters + ---------- + table : 2-D list where table[0] is typically the header row and + subsequent rows are data rows. Cell values are strings. + llm : A LangChain chat/LLM object that supports .invoke() or .predict(). + + Returns + ------- + A natural-language string describing the table's content and structure. + """ + if not table: + return "Empty table." + + # Format the table as a plain-text grid so the LLM can parse it easily. + table_str = _format_table_as_text(table) + + prompt = ( + "Convert this table to a natural language description for search purposes. " + "Describe what data the table contains, its structure, and key values.\n\n" + f"Table:\n{table_str}" + ) + + try: + # Support both .invoke() (LangChain β‰₯ 0.1) and .predict() (legacy). + if hasattr(llm, "invoke"): + response = llm.invoke(prompt) + # .invoke() may return a string or an AIMessage depending on the model. 
+ description = response.content if hasattr(response, "content") else str(response) + else: + description = llm.predict(prompt) + return description.strip() + + except Exception as exc: + # Non-fatal fallback: return the raw text representation. + print(f" [table_processor] LLM unavailable for table description: {exc}") + return f"Table data:\n{table_str}" + + +def save_table_as_csv(table: list[list], output_path: str) -> None: + """ + Write a 2-D list to a CSV file. + + Parameters + ---------- + table : 2-D list of cell values. + output_path : Full file path for the output CSV (directory must exist). + """ + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", newline="", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerows(table) + + +def process_all_tables( + tables: list[dict], + llm, + tables_dir: str = "data/extracted/tables", +) -> list[dict]: + """ + Process every table extracted by the parser: save as CSV and generate a + natural-language description via the LLM. + + Parameters + ---------- + tables : List of table dicts as returned by multimodal_parser β€” + each has keys "rows" (list[list]) and "page" (int). + llm : LangChain LLM / chat model for description generation. + tables_dir : Directory where CSV files are written. + + Returns + ------- + List of dicts: + { + "table_id" : str β€” unique identifier, e.g. "table_p2_0" + "csv_path" : str β€” path to the saved CSV file + "description" : str β€” LLM-generated natural-language summary + "raw_table" : list[list] β€” original row data + "page" : int β€” source page number + } + """ + Path(tables_dir).mkdir(parents=True, exist_ok=True) + results = [] + + for idx, table_meta in enumerate(tables): + raw_rows = table_meta.get("rows", []) + page = table_meta.get("page", 0) + table_id = f"table_p{page}_{table_meta.get('table_index', idx)}" + + csv_filename = f"{table_id}.csv" + csv_path = os.path.join(tables_dir, csv_filename) + + # Persist raw data. 
+        save_table_as_csv(raw_rows, csv_path)
+
+        # Generate natural-language description.
+        print(
+            f"  [table_processor] Describing table {idx + 1}/{len(tables)} "
+            f"(page {page}) …"
+        )
+        description = table_to_description(raw_rows, llm)
+
+        results.append(
+            {
+                "table_id": table_id,
+                "csv_path": csv_path,
+                "description": description,
+                "raw_table": raw_rows,
+                "page": page,
+            }
+        )
+
+    return results
+
+
+# ── Private helpers ──────────────────────────────────────────────────────────
+
+
+def _format_table_as_text(table: list[list]) -> str:
+    """Render a 2-D list as a plain-text grid with | separators."""
+    lines = []
+    for row in table:
+        lines.append(" | ".join(str(cell) for cell in row))
+    return "\n".join(lines)
diff --git a/04-multimodal-rag/src/text_indexer.py b/04-multimodal-rag/src/text_indexer.py
new file mode 100644
index 0000000..88383e6
--- /dev/null
+++ b/04-multimodal-rag/src/text_indexer.py
@@ -0,0 +1,106 @@
+"""
+text_indexer.py
+---------------
+Builds and queries a FAISS vector index for plain-text chunks.
+
+Same embedding approach as Project 1 β€” this is the text modality index.
+
+We reuse the all-MiniLM-L6-v2 sentence-transformer model because:
+  * It is fast and runs fully locally (no API calls, no cost).
+  * Its 384-dimensional embeddings strike a good balance between quality
+    and memory / speed.
+  * It has proven strong retrieval performance on diverse Q&A benchmarks.
+
+The only difference from Project 1 is that this index is *one of three*
+indexes in the multimodal pipeline. The query router decides whether to
+hit this index, the image index, the table index, or all three.
+"""
+
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.vectorstores import FAISS
+from langchain.schema import Document
+
+
+# Embedding model name shared by the index and load helpers (same vector space).
+_EMBED_MODEL_NAME = "all-MiniLM-L6-v2" + + +def _get_embeddings() -> HuggingFaceEmbeddings: + """Return a HuggingFaceEmbeddings instance for all-MiniLM-L6-v2.""" + return HuggingFaceEmbeddings(model_name=_EMBED_MODEL_NAME) + + +def index_text_chunks( + text_blocks: list[str], + index_path: str = "text_faiss_index", +) -> FAISS: + """ + Embed a list of text strings and persist them as a FAISS index. + + Parameters + ---------- + text_blocks : Raw text strings (one per page, paragraph, or chunk). + index_path : Directory path where the FAISS index files are saved. + + Returns + ------- + A LangChain FAISS vector store ready for similarity search. + """ + if not text_blocks: + raise ValueError("text_blocks is empty β€” nothing to index.") + + # Wrap each string in a LangChain Document so we can store metadata. + # We record the chunk number so retrieved results can be traced back. + docs = [ + Document(page_content=block, metadata={"chunk_id": i, "modality": "text"}) + for i, block in enumerate(text_blocks) + ] + + embeddings = _get_embeddings() + + # FAISS.from_documents embeds all docs in a single batch and builds + # the index in memory, then we persist it to disk. + vector_store = FAISS.from_documents(docs, embeddings) + vector_store.save_local(index_path) + + print(f"[text_indexer] Indexed {len(docs)} text chunks β†’ '{index_path}'") + return vector_store + + +def load_text_index(index_path: str) -> FAISS: + """ + Load a previously saved FAISS text index from disk. + + Parameters + ---------- + index_path : Directory path that was passed to index_text_chunks(). + + Returns + ------- + A LangChain FAISS vector store. 
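+
+    Example
+    -------
+    Load-then-search sketch (assumes index_text_chunks() was run earlier
+    against the same directory; the query string is illustrative)::
+
+        store = load_text_index("text_faiss_index")
+        for doc, score in search_text("What is RAG?", store, k=3):
+            print(doc.metadata["chunk_id"], score)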
+ """ + embeddings = _get_embeddings() + vector_store = FAISS.load_local( + index_path, embeddings, allow_dangerous_deserialization=True + ) + print(f"[text_indexer] Loaded text index from '{index_path}'") + return vector_store + + +def search_text(query: str, vector_store: FAISS, k: int = 3) -> list: + """ + Retrieve the top-k most relevant text chunks for a query. + + Parameters + ---------- + query : Natural language question or search string. + vector_store : A loaded or freshly-built FAISS text index. + k : Number of results to return. + + Returns + ------- + List of (Document, score) tuples ordered by descending similarity. + Lower L2 distance = higher similarity in FAISS. + """ + results = vector_store.similarity_search_with_score(query, k=k) + return results diff --git a/05-agentic-rag-realtime/.env.example b/05-agentic-rag-realtime/.env.example new file mode 100644 index 0000000..1fa0393 --- /dev/null +++ b/05-agentic-rag-realtime/.env.example @@ -0,0 +1,19 @@ +# OpenAI API Key (required) +OPENAI_API_KEY=your_openai_api_key_here + +# Model to use (gpt-4 recommended for tool selection accuracy) +OPENAI_MODEL=gpt-4 + +# Web Search - Tavily API (free tier: 1000 requests/month) +# Sign up at: https://tavily.com +TAVILY_API_KEY=your_tavily_api_key_here + +# OR use SerpAPI instead +# SERPAPI_API_KEY=your_serpapi_key_here + +# OpenWeatherMap API (free tier: 60 calls/min) +# Sign up at: https://openweathermap.org/api +OPENWEATHERMAP_API_KEY=your_openweathermap_key_here + +# Knowledge base path +KNOWLEDGE_BASE_DIR=data/knowledge_base diff --git a/05-agentic-rag-realtime/README.md b/05-agentic-rag-realtime/README.md new file mode 100644 index 0000000..0f3171e --- /dev/null +++ b/05-agentic-rag-realtime/README.md @@ -0,0 +1,297 @@ +# Agentic RAG with Real-Time Tools + +A LangChain agent that decides **which tool to use** for every question β€” searching your internal documents, fetching live stock prices, current weather, Wikipedia articles, or the live web, 
depending on what the question needs. + +--- + +## Agentic RAG vs Standard RAG + +| Standard RAG | Agentic RAG | +|---|---| +| Always searches FAISS | Decides which tool(s) to use | +| One retrieval step | Multiple steps if needed | +| Only knows stored documents | Can fetch live data | +| Fast (single LLM call) | Slower (multi-step reasoning) | +| Deterministic path | Dynamic, question-driven path | + +**When to use Agentic RAG:** When users ask mixed questions that combine internal knowledge with live data (e.g. "How does today's AAPL price compare to our internal valuation model?"). + +**When to use Standard RAG:** High-volume, low-latency workloads where every question is about static documents. + +--- + +## Architecture + +``` +User Question + β”‚ + β–Ό +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ LLM Agent β”‚ ← reads tool descriptions to decide what to call +β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜ + β”‚ ReAct Loop: Reason β†’ Act β†’ Observe β†’ Repeat + β”‚ + β”Œβ”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ Tool Registry β”‚ + β”‚ β”‚ + β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ + β”‚ β”‚ search_knowledge β”‚ β”‚ get_stock_data β”‚ β”‚ + β”‚ β”‚ _base β”‚ β”‚ (yfinance) β”‚ β”‚ + β”‚ β”‚ (FAISS index) β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ + β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ + β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ + β”‚ β”‚ web_search β”‚ β”‚ get_weather β”‚ β”‚ + β”‚ β”‚ (Tavily API) β”‚ β”‚ (OpenWeather) β”‚ β”‚ + β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ + β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ + β”‚ 
β”‚ search_wikipedia β”‚ β”‚ + β”‚ β”‚ (Wikipedia API) β”‚ β”‚ + β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + β”‚ + β–Ό + Final Answer + Sources +``` + +--- + +## How the Agent Decides Which Tool to Use + +The agent's LLM reads every tool's `name` and `description` string before responding. Here is the decision process for a typical question: + +**Question:** *"What is AAPL's current price and how does it compare to our internal forecast?"* + +``` +Step 1 β€” REASON: + "This question needs current stock data AND internal documents. + I should call get_stock_data first, then search_knowledge_base." + +Step 2 β€” ACT: get_stock_data("AAPL") +Step 3 β€” OBSERVE: "Stock: AAPL | Price: $182.50 | ..." + +Step 4 β€” REASON: + "Now I have the live price. I need the internal forecast from the KB." + +Step 5 β€” ACT: search_knowledge_base("AAPL valuation forecast") +Step 6 β€” OBSERVE: "Found in knowledge base: 1. Q3 forecast values AAPL at..." + +Step 7 β€” REASON: + "I have both pieces of information. I can now compose a full answer." + +Step 8 β€” FINAL ANSWER (no more tool calls needed) +``` + +The key insight: **the tool description IS the routing logic**. A clear description like *"Use this for questions about internal policies"* routes the agent correctly without any if/else code. + +--- + +## Setup + +### 1. Clone and install + +```bash +cd 05-agentic-rag-realtime +pip install -r requirements.txt +``` + +### 2. Configure API keys + +```bash +cp .env.example .env +# Edit .env with your keys +``` + +### API Key Guide + +| Service | Required? 
| Free Tier | Sign-up Link | +|---|---|---|---| +| **OpenAI** | βœ… Yes | Pay-per-use | [platform.openai.com](https://platform.openai.com) | +| **Tavily** (web search) | ❌ Optional | 1,000 searches/month | [tavily.com](https://tavily.com) | +| **OpenWeatherMap** | ❌ Optional | 60 calls/min | [openweathermap.org/api](https://openweathermap.org/api) | +| **yfinance** (finance) | βœ… Built-in | Unlimited* | No key needed | +| **Wikipedia** | βœ… Built-in | Unlimited | No key needed | + +*yfinance scrapes Yahoo Finance; data may be delayed 15 minutes. + +### 3. Add documents to the knowledge base (optional) + +```bash +# Drop .pdf or .txt files here: +data/knowledge_base/ +``` + +### 4. Run + +```bash +# Single query +python main.py --query "What is AAPL's current stock price?" + +# Interactive session +python main.py --interactive + +# Without conversation memory +python main.py --interactive --no-memory + +# Hide the reasoning trace +python main.py --query "Weather in Tokyo" --no-verbose +``` + +--- + +## Example Multi-Tool Queries + +### Finance + RAG +**Query:** *"What is AAPL price and how does it compare to our internal valuation?"* + +``` +Agent calls: get_stock_data("AAPL") β†’ search_knowledge_base("AAPL valuation") +``` + +### Weather + RAG +**Query:** *"What's the weather in London and should we proceed per our event guidelines?"* + +``` +Agent calls: get_weather("London") β†’ search_knowledge_base("event guidelines weather policy") +``` + +### Web Search + RAG +**Query:** *"What are latest AI news stories relevant to our strategy?"* + +``` +Agent calls: web_search("latest AI news 2024") β†’ search_knowledge_base("AI strategy") +``` + +### Wikipedia + Finance +**Query:** *"What does Wikipedia say about transformer models and how is NVDA performing?"* + +``` +Agent calls: search_wikipedia("transformer neural network") β†’ get_stock_data("NVDA") +``` + +--- + +## Cost and Rate Limits + +| Tool | Cost | Rate Limit | +|---|---|---| +| `search_knowledge_base` | 
Free (local FAISS) | Unlimited | +| `get_stock_data` | Free (yfinance) | ~2,000 req/hour* | +| `search_wikipedia` | Free | Unlimited | +| `web_search` | Free tier: 1,000/month | 1 req/sec | +| `get_weather` | Free tier: 60 calls/min | 1,000,000/month | +| OpenAI GPT-4 | ~$0.03/1K tokens | Depends on tier | +| OpenAI GPT-3.5 | ~$0.002/1K tokens | Depends on tier | + +*Yahoo Finance has unofficial rate limits; excessive calls may trigger temporary blocks. + +--- + +## How to Add a Custom Tool + +Adding a new tool requires three steps: write the function, wrap it in a `Tool`, and register it. + +### Step 1 β€” Write the function + +Create `src/tools/my_tool.py`: + +```python +from langchain.tools import Tool + +def my_custom_function(input_str: str) -> str: + # Your logic here β€” always string in, string out + return "Result: ..." + +def create_my_tool() -> Tool: + return Tool( + name="my_tool_name", + func=my_custom_function, + description=( + "What this tool does and when to use it. " + "Input: what to provide (be specific about format)." + ), + ) +``` + +**Rules for a good tool:** +- Function signature: always `(input_str: str) -> str` +- Never raise exceptions β€” catch errors and return a message string +- Description must say WHAT the tool does, WHEN to use it, and WHAT input format it expects + +### Step 2 β€” Register it in `tool_registry.py` + +```python +from src.tools.my_tool import create_my_tool + +def build_tool_registry(vector_store, config): + tools = [...] # existing tools + tools.append(create_my_tool()) + return tools +``` + +### Step 3 β€” Test it + +```bash +python main.py --query "A question that should trigger your new tool" +``` + +With `--verbose` (default) you'll see whether the agent picked your tool and what it returned. + +--- + +## Troubleshooting + +| Problem | Solution | +|---|---| +| `OPENAI_API_KEY not set` | Add key to `.env` file | +| Agent always uses `search_knowledge_base` | Knowledge base is empty β€” agent defaults to it. 
Add docs to `data/knowledge_base/` | +| `yfinance` returns `None` for price | Market may be closed; try a major ticker like AAPL | +| Web search returns mock message | Add `TAVILY_API_KEY` to `.env` | +| Weather returns mock data | Add `OPENWEATHERMAP_API_KEY` to `.env` | +| Agent uses wrong tool | Check tool descriptions in `src/tools/` β€” make them more specific | +| `FAISS` index error on reload | Delete `data/knowledge_base/.faiss_index/` and re-run | +| Agent loops > 8 times | Increase `max_iterations` in `src/agent.py` or simplify your query | +| `sentence-transformers` slow on first run | It downloads the model (~80 MB) once; subsequent runs are fast | + +### Changing the LLM + +```bash +# Use GPT-3.5 instead of GPT-4 (cheaper, slightly less accurate tool selection) +python main.py --model gpt-3.5-turbo --interactive +``` + +### Viewing the reasoning trace + +The `--verbose` flag (on by default) prints every Thought β†’ Action β†’ Observation cycle. This is the best way to debug unexpected answers: + +``` +Thought: I need current stock data for AAPL. +Action: get_stock_data +Action Input: AAPL +Observation: Stock: AAPL | Price: $182.50 | ... +Thought: I now have the price. Let me check the knowledge base for the internal valuation. +... 
+``` + +--- + +## Project Structure + +``` +05-agentic-rag-realtime/ +β”œβ”€β”€ main.py # Entry point β€” pipeline + CLI +β”œβ”€β”€ requirements.txt +β”œβ”€β”€ .env.example # Copy to .env and fill in keys +β”œβ”€β”€ data/ +β”‚ └── knowledge_base/ # Drop .pdf and .txt files here +β”œβ”€β”€ src/ +β”‚ β”œβ”€β”€ knowledge_indexer.py # FAISS index builder (reused from Project 1) +β”‚ β”œβ”€β”€ tool_registry.py # Assembles all tools into a list +β”‚ β”œβ”€β”€ agent.py # LangChain agent with ReAct loop +β”‚ β”œβ”€β”€ response_formatter.py # Formats output with sources and trace +β”‚ └── tools/ +β”‚ β”œβ”€β”€ rag_tool.py # Wraps FAISS search as a Tool +β”‚ β”œβ”€β”€ finance_tool.py # yfinance stock data +β”‚ β”œβ”€β”€ weather_tool.py # OpenWeatherMap current weather +β”‚ β”œβ”€β”€ web_search_tool.py # Tavily live web search +β”‚ └── wiki_tool.py # Wikipedia summaries +``` diff --git a/05-agentic-rag-realtime/data/knowledge_base/.gitkeep b/05-agentic-rag-realtime/data/knowledge_base/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/05-agentic-rag-realtime/main.py b/05-agentic-rag-realtime/main.py new file mode 100644 index 0000000..ba2c4af --- /dev/null +++ b/05-agentic-rag-realtime/main.py @@ -0,0 +1,233 @@ +""" +main.py β€” Agentic RAG with Real-Time Tools + +Entry point for the 05-agentic-rag-realtime project. Assembles the full +pipeline: knowledge base indexing β†’ tool registry β†’ agent β†’ interactive Q&A. + +Usage examples: + # Single query + python main.py --query "What is AAPL's current stock price?" 
+ + # Interactive multi-turn session + python main.py --interactive + + # Use a specific knowledge base directory + python main.py --kb-dir /path/to/docs --interactive + + # Disable conversation memory (stateless mode) + python main.py --interactive --no-memory + + # Hide the agent's reasoning trace + python main.py --query "Weather in Tokyo" --no-verbose +""" + +import argparse +import os +import sys + +from dotenv import load_dotenv + +# Load .env before importing project modules (they may read env vars at import time). +load_dotenv() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _check_api_keys(config: dict) -> None: + """ + Print a startup banner showing which tools are ready / missing API keys. + This helps users quickly see what's available before running queries. + """ + openai_key = config.get("openai_api_key") + tavily_key = config.get("tavily_api_key") + owm_key = config.get("openweathermap_api_key") + + rag_status = "βœ… RAG Tool ready" + finance_status = "βœ… Finance Tool ready (yfinance β€” no key needed)" + wiki_status = "βœ… Wikipedia Tool ready (no key needed)" + web_status = "βœ… Web Search ready" if tavily_key else "❌ Web Search (no TAVILY_API_KEY)" + weather_status = "βœ… Weather Tool ready" if owm_key else "⚠️ Weather Tool (mock mode β€” no OPENWEATHERMAP_API_KEY)" + openai_status = "βœ… OpenAI connected" if openai_key else "❌ OpenAI (no OPENAI_API_KEY β€” required)" + + print("\n" + "=" * 60) + print(" Agentic RAG β€” Tool Availability") + print("=" * 60) + for status in [openai_status, rag_status, finance_status, wiki_status, web_status, weather_status]: + print(f" {status}") + print("=" * 60) + + if not openai_key: + print("\n[ERROR] OPENAI_API_KEY is required. 
Add it to your .env file.") + sys.exit(1) + + +def _print_example_queries() -> None: + """Print suggested example queries so new users know what to try.""" + print("\nExample queries:") + print(' β€’ "What is the current price of AAPL?"') + print(' β€’ "What\'s the weather in London today?"') + print(' β€’ "Search Wikipedia for transformer neural networks"') + print(' β€’ "What does our internal strategy document say about AI adoption?"') + print(' β€’ "What is AAPL price and how does it compare to our internal valuation?"') + print(' β€’ "What are latest AI news stories relevant to our strategy?"') + print() + + +def _build_config() -> dict: + """Read all configuration from environment variables and return as a dict.""" + return { + "openai_api_key": os.getenv("OPENAI_API_KEY", ""), + "openai_model": os.getenv("OPENAI_MODEL", "gpt-4"), + "tavily_api_key": os.getenv("TAVILY_API_KEY", ""), + "openweathermap_api_key": os.getenv("OPENWEATHERMAP_API_KEY", ""), + "domain_description": "internal company documents and knowledge base", + } + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Agentic RAG with real-time tools (finance, weather, web search, Wikipedia).", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "--kb-dir", + default=os.getenv("KNOWLEDGE_BASE_DIR", "data/knowledge_base"), + help="Directory containing .pdf and .txt files to index (default: data/knowledge_base)", + ) + parser.add_argument( + "--model", + default=None, + help="OpenAI model name to use (overrides OPENAI_MODEL env var, default: gpt-4)", + ) + parser.add_argument( + "--query", + default=None, + help="Run a single query and exit.", + ) + parser.add_argument( + "--interactive", + action="store_true", + help="Start an interactive multi-turn Q&A session.", + ) + parser.add_argument( + "--no-memory", + action="store_true", + help="Disable conversation memory (each query is independent).", + ) + 
verbose_group = parser.add_mutually_exclusive_group() + verbose_group.add_argument( + "--verbose", + dest="verbose", + action="store_true", + default=True, + help="Show the agent's reasoning trace (default: on).", + ) + verbose_group.add_argument( + "--no-verbose", + dest="verbose", + action="store_false", + help="Hide the agent's reasoning trace.", + ) + return parser.parse_args() + + +# --------------------------------------------------------------------------- +# Main pipeline +# --------------------------------------------------------------------------- + +def main() -> None: + args = _parse_args() + config = _build_config() + + # Allow --model to override the environment variable. + if args.model: + config["openai_model"] = args.model + + # --- Print startup banner --- + _check_api_keys(config) + + # --- Step 1: Index / load the knowledge base --- + print(f"\n[Setup] Indexing knowledge base from '{args.kb_dir}' …") + from src.knowledge_indexer import index_knowledge_base # noqa: PLC0415 + vector_store = index_knowledge_base( + kb_dir=args.kb_dir, + index_path=os.path.join(args.kb_dir, ".faiss_index"), + ) + + # --- Step 2: Build tool registry --- + print("[Setup] Building tool registry …") + from src.tool_registry import build_tool_registry, get_tool_descriptions # noqa: PLC0415 + tools = build_tool_registry(vector_store, config) + print(get_tool_descriptions(tools)) + + # --- Step 3: Instantiate the LLM --- + print(f"\n[Setup] Connecting to OpenAI model '{config['openai_model']}' …") + from langchain_openai import ChatOpenAI # noqa: PLC0415 + llm = ChatOpenAI( + model=config["openai_model"], + openai_api_key=config["openai_api_key"], + temperature=0, # deterministic tool selection + ) + + # --- Step 4: Create agent --- + use_memory = not args.no_memory + print(f"[Setup] Creating agent (memory={'on' if use_memory else 'off'}, verbose={args.verbose}) …") + from src.agent import create_agent, run_agent_query # noqa: PLC0415 + agent = create_agent(tools, llm, 
memory=use_memory, verbose=args.verbose) + + # --- Step 5: Run query/interactive loop --- + from src.response_formatter import ( # noqa: PLC0415 + format_response, + extract_tools_from_steps, + ) + + def _run_and_display(query: str) -> None: + """Run a single query and print the formatted response.""" + print(f"\n[Query] {query}\n") + try: + result = agent.invoke({"input": query}) + answer = result.get("output", str(result)) + steps = result.get("intermediate_steps", []) + tools_used = extract_tools_from_steps(steps) + except Exception as exc: + answer = f"Agent encountered an error: {exc}" + tools_used = [] + + print("\n" + format_response(answer, tools_used)) + + if args.query: + # Single-shot mode: run one query and exit. + _run_and_display(args.query) + + elif args.interactive: + # Interactive mode: loop until user types "quit" or "exit". + _print_example_queries() + print("Type 'quit' or 'exit' to end the session.\n") + + while True: + try: + user_input = input("You: ").strip() + except (EOFError, KeyboardInterrupt): + print("\nGoodbye!") + break + + if not user_input: + continue + if user_input.lower() in {"quit", "exit", "q"}: + print("Goodbye!") + break + + _run_and_display(user_input) + + else: + # No mode selected β€” show help and example queries. + print("\nNo query mode selected. 
Use --query or --interactive.") + _print_example_queries() + print("Run with --help for all options.") + + +if __name__ == "__main__": + main() diff --git a/05-agentic-rag-realtime/requirements.txt b/05-agentic-rag-realtime/requirements.txt new file mode 100644 index 0000000..08694ad --- /dev/null +++ b/05-agentic-rag-realtime/requirements.txt @@ -0,0 +1,11 @@ +langchain==0.1.20 +langchain-community==0.0.38 +langchain-openai==0.1.6 +faiss-cpu==1.8.0 +sentence-transformers==2.7.0 +openai==1.30.1 +python-dotenv==1.0.1 +yfinance==0.2.38 +wikipedia==1.4.0 +requests==2.31.0 +tavily-python==0.3.3 diff --git a/05-agentic-rag-realtime/src/__init__.py b/05-agentic-rag-realtime/src/__init__.py new file mode 100644 index 0000000..6b4afd1 --- /dev/null +++ b/05-agentic-rag-realtime/src/__init__.py @@ -0,0 +1,2 @@ +# src/__init__.py +# Makes src/ a Python package so imports like `from src.agent import create_agent` work. diff --git a/05-agentic-rag-realtime/src/agent.py b/05-agentic-rag-realtime/src/agent.py new file mode 100644 index 0000000..dc511a9 --- /dev/null +++ b/05-agentic-rag-realtime/src/agent.py @@ -0,0 +1,148 @@ +""" +src/agent.py + +Assembles the LangChain agent executor that ties together the LLM, all tools, +and optional conversation memory. + +THE ReAct LOOP (Reason + Act): + Every time the agent receives a question it goes through repeated cycles: + 1. REASON β€” "What do I need to answer this? Which tool should I call?" + 2. ACT β€” Calls a tool with a specific input string. + 3. OBSERVE β€” Reads the tool's output. + 4. REPEAT β€” Reasons again with the new information; stops when confident. + + This is fundamentally different from standard RAG which does a single + FAISS search every time regardless of the question type. + +AGENT TYPES: + β€’ OPENAI_FUNCTIONS (default when using GPT-3.5 / GPT-4): + Uses OpenAI's native function-calling API. The LLM is trained to emit + structured JSON for function calls, so tool invocation is very reliable. 
+ Requires an OpenAI model that supports function calling. + + β€’ ZERO_SHOT_REACT_DESCRIPTION (fallback): + Works with ANY LLM (Llama, Mistral, Claude, etc.). + The LLM reasons in plain text using a "Thought/Action/Observation" format. + Less reliable for tool selection but model-agnostic. + +MEMORY: + ConversationBufferWindowMemory(k=5) keeps the last 5 exchanges in context. + k=5 is a pragmatic choice: + β€’ Enough to handle follow-up questions ("And what about MSFT?") + β€’ Small enough not to overflow the context window on long conversations + Disable memory (--no-memory) for stateless single-query use cases. + +VERBOSE MODE: + verbose=True is essential for learning: you see every Thought β†’ Action β†’ + Observation cycle printed to stdout. In production set verbose=False. +""" + +from typing import List, Optional + +from langchain.agents import AgentExecutor, initialize_agent, AgentType +from langchain.memory import ConversationBufferWindowMemory +from langchain.schema import SystemMessage +from langchain.tools import Tool + + +# System prompt injected before every conversation. +# Specific instructions improve tool selection accuracy significantly. +_SYSTEM_PROMPT = """You are a knowledgeable assistant with access to multiple tools. +You can search internal documents, look up live data, and search the web. + +When answering: +1. First consider if you need real-time data (use web_search or get_stock_data) +2. Or if the question is about internal documents (use search_knowledge_base) +3. Or both (use multiple tools) + +Always cite which tools you used and where information came from. +Think step by step before deciding which tools to use.""" + + +def create_agent( + tools: List[Tool], + llm, + memory: bool = True, + verbose: bool = True, +) -> AgentExecutor: + """ + Build and return a LangChain AgentExecutor wired to the provided tools. + + Args: + tools: List of LangChain Tool objects from tool_registry. + llm: An instantiated LangChain LLM (e.g. 
ChatOpenAI).
+        memory: If True, adds a sliding-window conversation memory (k=5).
+        verbose: If True, prints the full reasoning trace to stdout.
+
+    Returns:
+        A configured AgentExecutor ready to accept queries.
+    """
+    # --- Memory ---
+    # ConversationBufferWindowMemory keeps only the last k exchanges so the
+    # context window doesn't grow unboundedly during long conversations.
+    mem: Optional[ConversationBufferWindowMemory] = None
+    if memory:
+        mem = ConversationBufferWindowMemory(
+            k=5,
+            memory_key="chat_history",
+            return_messages=True,
+        )
+
+    # --- Determine the best agent type ---
+    # OPENAI_FUNCTIONS is more reliable for tool selection because it uses
+    # OpenAI's native function-calling format instead of text-based reasoning.
+    # We detect whether we're talking to an OpenAI chat model by checking the
+    # class name — this avoids a hard dependency on langchain_openai at this level.
+    llm_class = type(llm).__name__
+    is_openai_chat = "ChatOpenAI" in llm_class or "AzureChatOpenAI" in llm_class
+
+    if is_openai_chat:
+        agent_type = AgentType.OPENAI_FUNCTIONS
+        # Inject the system message through agent_kwargs for OPENAI_FUNCTIONS agents.
+        agent_kwargs = {
+            "system_message": SystemMessage(content=_SYSTEM_PROMPT),
+        }
+        if mem:
+            # Memory messages are NOT injected automatically for OPENAI_FUNCTIONS
+            # agents: the prompt needs an explicit placeholder whose variable name
+            # matches the memory's memory_key, or past turns are silently dropped.
+            from langchain.prompts import MessagesPlaceholder  # noqa: PLC0415
+            agent_kwargs["extra_prompt_messages"] = [
+                MessagesPlaceholder(variable_name="chat_history"),
+            ]
+    else:
+        # ZERO_SHOT_REACT_DESCRIPTION works with any LLM via plain-text reasoning.
+        agent_type = AgentType.ZERO_SHOT_REACT_DESCRIPTION
+        agent_kwargs = {}
+
+    agent_executor = initialize_agent(
+        tools=tools,
+        llm=llm,
+        agent=agent_type,
+        memory=mem,
+        agent_kwargs=agent_kwargs,
+        verbose=verbose,
+        # handle_parsing_errors=True prevents the agent from crashing when the
+        # LLM produces a malformed tool call; it retries with an error message.
+        handle_parsing_errors=True,
+        # max_iterations caps runaway loops — agent stops after N tool calls.
+ max_iterations=8, + ) + + return agent_executor + + +def run_agent_query(query: str, agent: AgentExecutor) -> str: + """ + Submit a query to the agent and return the final answer string. + + Wraps the AgentExecutor.invoke() call with error handling so the main + loop doesn't crash on unexpected LLM failures. + + Args: + query: The user's natural-language question. + agent: A configured AgentExecutor from create_agent(). + + Returns: + The agent's final answer as a plain string. + """ + try: + result = agent.invoke({"input": query}) + # AgentExecutor returns a dict; the final answer is under "output". + return result.get("output", str(result)) + except Exception as exc: + return f"Agent encountered an error: {exc}" diff --git a/05-agentic-rag-realtime/src/knowledge_indexer.py b/05-agentic-rag-realtime/src/knowledge_indexer.py new file mode 100644 index 0000000..805200e --- /dev/null +++ b/05-agentic-rag-realtime/src/knowledge_indexer.py @@ -0,0 +1,157 @@ +""" +src/knowledge_indexer.py + +This is the same RAG indexing pattern from Project 1, reused here as one of the +agent's tools. The key difference: in Project 1 this was the ONLY retrieval path; +here it is just one tool the agent may or may not call depending on the question. +""" + +import os +from typing import List + +from langchain_community.vectorstores import FAISS +from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_community.document_loaders import ( + PyPDFLoader, + TextLoader, + DirectoryLoader, +) +from langchain.text_splitter import RecursiveCharacterTextSplitter + + +# --------------------------------------------------------------------------- +# Embedding model β€” same lightweight model used in Project 1. +# Using a local sentence-transformers model avoids OpenAI embedding API calls +# and keeps costs at zero for the indexing step. 
+# ---------------------------------------------------------------------------
+EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+
+
+def _get_embeddings() -> HuggingFaceEmbeddings:
+    """Return a HuggingFace embedding model instance."""
+    return HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
+
+
+def index_knowledge_base(
+    kb_dir: str,
+    index_path: str = "kb_faiss_index",
+) -> FAISS:
+    """
+    Load every .pdf and .txt file from kb_dir, chunk the text, build a FAISS
+    vector index, and persist it to disk.
+
+    If the index already exists on disk it is loaded directly — re-indexing is
+    skipped so the agent starts up fast after the first run.
+
+    Args:
+        kb_dir: Directory containing source documents (.pdf / .txt).
+        index_path: Path where the FAISS index folder will be saved.
+
+    Returns:
+        A ready-to-query FAISS vector store.
+    """
+    # If the index was already built, load it and return immediately.
+    if os.path.exists(index_path):
+        print(f"[Indexer] Loading existing FAISS index from '{index_path}'")
+        return load_index(index_path)
+
+    print(f"[Indexer] Building new FAISS index from '{kb_dir}'")
+
+    # --- Step 1: Load documents using two loaders — one for PDFs, one for TXTs ---
+    documents = []
+
+    pdf_loader = DirectoryLoader(
+        kb_dir,
+        glob="**/*.pdf",
+        loader_cls=PyPDFLoader,
+        silent_errors=True,
+    )
+    txt_loader = DirectoryLoader(
+        kb_dir,
+        glob="**/*.txt",
+        loader_cls=TextLoader,
+        loader_kwargs={"encoding": "utf-8"},
+        silent_errors=True,
+    )
+
+    for loader in (pdf_loader, txt_loader):
+        try:
+            docs = loader.load()
+            documents.extend(docs)
+            print(f"[Indexer] Loaded {len(docs)} pages/docs via {loader.__class__.__name__}")
+        except Exception as exc:
+            print(f"[Indexer] Warning: loader {loader.__class__.__name__} failed — {exc}")
+
+    if not documents:
+        print(f"[Indexer] No documents found in '{kb_dir}'. Index will be empty.")
+        # Create a minimal placeholder doc so FAISS doesn't crash on an empty list.
+        from langchain.schema import Document
+        documents = [
+            Document(
+                page_content="Knowledge base is empty. Add .pdf or .txt files to the knowledge_base/ folder.",
+                metadata={"source": "placeholder"},
+            )
+        ]
+
+    # --- Step 2: Split documents into overlapping chunks ---
+    # chunk_size=500 characters (RecursiveCharacterTextSplitter counts characters,
+    # not tokens) keeps each chunk small enough for a single context window slot.
+    # chunk_overlap=50 ensures sentences cut at a boundary don't lose context.
+    splitter = RecursiveCharacterTextSplitter(
+        chunk_size=500,
+        chunk_overlap=50,
+    )
+    chunks = splitter.split_documents(documents)
+    print(f"[Indexer] Split into {len(chunks)} chunks")
+
+    # --- Step 3: Embed and index ---
+    embeddings = _get_embeddings()
+    vector_store = FAISS.from_documents(chunks, embeddings)
+
+    # --- Step 4: Persist to disk ---
+    vector_store.save_local(index_path)
+    print(f"[Indexer] FAISS index saved to '{index_path}'")
+
+    return vector_store
+
+
+def load_index(index_path: str) -> FAISS:
+    """
+    Load a previously saved FAISS index from disk.
+
+    Args:
+        index_path: Path to the directory created by FAISS.save_local().
+
+    Returns:
+        Loaded FAISS vector store ready for similarity search.
+    """
+    embeddings = _get_embeddings()
+    vector_store = FAISS.load_local(
+        index_path,
+        embeddings,
+        allow_dangerous_deserialization=True,  # required since LangChain 0.1.x
+    )
+    print(f"[Indexer] Loaded FAISS index from '{index_path}'")
+    return vector_store
+
+
+def search_knowledge_base(
+    query: str,
+    vector_store: FAISS,
+    k: int = 3,
+) -> List[str]:
+    """
+    Perform a similarity search and return the top-k text chunks.
+
+    The agent calls this via rag_tool.py. Returning plain strings (not Documents)
+    keeps the tool output easy to format and display.
+
+    Args:
+        query: The natural-language search query.
+        vector_store: A loaded FAISS index.
+        k: Number of top results to return.
+
+    Returns:
+        List of text strings — the retrieved chunks, most relevant first.
+ """ + docs = vector_store.similarity_search(query, k=k) + return [doc.page_content for doc in docs] diff --git a/05-agentic-rag-realtime/src/response_formatter.py b/05-agentic-rag-realtime/src/response_formatter.py new file mode 100644 index 0000000..219e8df --- /dev/null +++ b/05-agentic-rag-realtime/src/response_formatter.py @@ -0,0 +1,138 @@ +""" +src/response_formatter.py + +Formats the agent's raw output into a structured, readable display. + +WHY FORMAT RESPONSES? + β€’ Transparency: users should know whether an answer came from live data + (could change in minutes) or stored documents (could be months old). + β€’ Trust: showing which tools were used lets users verify accuracy. + β€’ Debuggability: the reasoning trace (Thought β†’ Action β†’ Observation) is + an audit trail that reveals HOW the agent reached its conclusion. + +The box-drawing characters used here render well in any Unicode terminal. +""" + +from typing import List, Optional + + +# Width of the output box in characters. +_BOX_WIDTH = 56 + + +def format_response( + answer: str, + tools_used: List[str], + agent_steps: Optional[list] = None, +) -> str: + """ + Render the agent's answer inside a bordered box with a tools-used footer. + + Args: + answer: The final answer string from the agent. + tools_used: List of tool names that were called (e.g. ["get_stock_data"]). + agent_steps: Optional raw intermediate steps from AgentExecutor for the + full reasoning trace footer. + + Returns: + A multi-line formatted string ready to print to stdout. 
+ """ + lines: List[str] = [] + + # ── Answer box ──────────────────────────────────────────────────────────── + lines.append("β•”" + "═" * _BOX_WIDTH + "β•—") + lines.append("β•‘ ANSWER" + " " * (_BOX_WIDTH - 7) + "β•‘") + lines.append("β•š" + "═" * _BOX_WIDTH + "╝") + lines.append(answer) + lines.append("") + + # ── Tools / sources footer ───────────────────────────────────────────── + tools_str = ", ".join(tools_used) if tools_used else "none" + lines.append("β”Œ" + "─" * _BOX_WIDTH + "┐") + + # Truncate tool list if it overflows the box width. + tools_line = f" Tools Used: {tools_str}" + if len(tools_line) > _BOX_WIDTH - 1: + tools_line = tools_line[: _BOX_WIDTH - 4] + "…" + lines.append("β”‚" + tools_line.ljust(_BOX_WIDTH) + "β”‚") + lines.append("β””" + "─" * _BOX_WIDTH + "β”˜") + + return "\n".join(lines) + + +def extract_tools_from_steps(agent_steps: list) -> List[str]: + """ + Parse LangChain AgentExecutor intermediate steps to extract tool names. + + LangChain returns intermediate_steps as a list of (AgentAction, observation) + tuples. Each AgentAction has a `tool` attribute with the tool's name. + + Args: + agent_steps: The `intermediate_steps` value from AgentExecutor output. + + Returns: + Deduplicated list of tool names that were called, in call order. + """ + seen = set() + tools: List[str] = [] + + for step in agent_steps or []: + try: + # Each step is a (AgentAction, str) tuple. + action = step[0] + tool_name = getattr(action, "tool", None) + if tool_name and tool_name not in seen: + tools.append(tool_name) + seen.add(tool_name) + except (IndexError, TypeError): + # Malformed step β€” skip silently. + continue + + return tools + + +def format_agent_trace(agent_steps: list) -> str: + """ + Render the agent's full Thought β†’ Action β†’ Observation trace as text. + + This is the "reasoning audit trail": every decision the agent made is + visible here. Useful for debugging unexpected answers and for teaching + users how the agent works. 
+ + Args: + agent_steps: The `intermediate_steps` list from AgentExecutor output. + + Returns: + A formatted multi-line string showing each reasoning step. + """ + if not agent_steps: + return "(No intermediate steps recorded)" + + lines: List[str] = ["── Agent Reasoning Trace ──"] + + for i, step in enumerate(agent_steps, start=1): + try: + action, observation = step[0], step[1] + tool_name = getattr(action, "tool", "unknown_tool") + tool_input = getattr(action, "tool_input", "") + log = getattr(action, "log", "").strip() + + lines.append(f"\nStep {i}:") + + # The `log` field contains the model's "Thought:" text for ReAct agents. + if log: + # Show only the Thought portion (first line) to keep it concise. + thought_line = log.split("\n")[0] + lines.append(f" Thought : {thought_line}") + + lines.append(f" Action : {tool_name}({tool_input!r})") + # Truncate very long observations for readability. + obs_str = str(observation) + if len(obs_str) > 300: + obs_str = obs_str[:300] + "…" + lines.append(f" Observation: {obs_str}") + + except Exception: + lines.append(f"\nStep {i}: (could not parse step)") + + return "\n".join(lines) diff --git a/05-agentic-rag-realtime/src/tool_registry.py b/05-agentic-rag-realtime/src/tool_registry.py new file mode 100644 index 0000000..2ef8a22 --- /dev/null +++ b/05-agentic-rag-realtime/src/tool_registry.py @@ -0,0 +1,117 @@ +""" +src/tool_registry.py + +Central registry that assembles all agent tools from their factory functions. + +WHY A REGISTRY? + The agent's LLM reads every tool's `name` and `description` when deciding + what to call. By building all tools in one place we can: + β€’ Conditionally include/exclude tools based on available API keys. + β€’ Replace real tools with mock/disabled stubs without touching agent.py. + β€’ Easily add new tools in one location. + +GOOD TOOL DESCRIPTIONS: + Think of tool descriptions like function docstrings aimed at another LLM. + They should answer: + 1. WHAT the tool does. + 2. 
WHEN to use it (vs. other tools). + 3. WHAT INPUT FORMAT to provide. + Vague descriptions β†’ the agent picks the wrong tool. + Overlapping descriptions β†’ the agent gets confused about which to pick. + +NOTE: + The agent can ONLY use tools in this list β€” it cannot make up tools or call + functions not registered here. If a capability isn't listed, the agent will + either say it can't help or try to approximate with another tool. +""" + +from typing import List + +from langchain.tools import Tool +from langchain_community.vectorstores import FAISS + +from src.tools.rag_tool import create_rag_tool +from src.tools.web_search_tool import create_web_search_tool, create_mock_web_search_tool +from src.tools.finance_tool import create_finance_tool +from src.tools.weather_tool import create_weather_tool, create_mock_weather_tool +from src.tools.wiki_tool import create_wiki_tool + + +def build_tool_registry( + vector_store: FAISS, + config: dict, +) -> List[Tool]: + """ + Instantiate and return the full list of tools available to the agent. + + Each tool factory either creates a real (API-backed) tool or a mock/disabled + stub, depending on whether the relevant API key is present in `config`. + + Args: + vector_store: Loaded FAISS index for the RAG tool. + config: Dictionary with optional keys: + - tavily_api_key (str | None) + - openweathermap_api_key (str | None) + - domain_description (str) β€” what the KB contains + + Returns: + List of LangChain Tool objects, ordered roughly by expected call frequency. + """ + tools: List[Tool] = [] + + # --- 1. RAG / Knowledge Base Tool --- + # Always available β€” uses the local FAISS index, no external API needed. + domain = config.get("domain_description", "internal company documents") + rag = create_rag_tool(vector_store, domain_description=domain) + tools.append(rag) + + # --- 2. Finance Tool --- + # yfinance scrapes Yahoo Finance; no API key required. + finance = create_finance_tool() + tools.append(finance) + + # --- 3. 
Wikipedia Tool --- + # Free, no API key required. Useful for factual/encyclopaedic queries. + wiki = create_wiki_tool() + tools.append(wiki) + + # --- 4. Web Search Tool --- + # Requires a Tavily API key. Falls back to a disabled mock if not provided. + tavily_key = config.get("tavily_api_key") + if tavily_key: + web = create_web_search_tool(tavily_key) + else: + web = create_mock_web_search_tool() + tools.append(web) + + # --- 5. Weather Tool --- + # Requires an OpenWeatherMap API key. Falls back to mock data if not set. + owm_key = config.get("openweathermap_api_key") + if owm_key: + weather = create_weather_tool(owm_key) + else: + weather = create_mock_weather_tool() + tools.append(weather) + + return tools + + +def get_tool_descriptions(tools: List[Tool]) -> str: + """ + Return a formatted string listing every tool's name and description. + + Useful for displaying the agent's capabilities at startup or for debugging + which tools are available in the current session. + + Args: + tools: List of LangChain Tool objects from build_tool_registry(). + + Returns: + Multi-line string, one tool per line. + """ + lines = ["Available tools:"] + for tool in tools: + # Trim the description to one sentence for the summary display. + first_sentence = tool.description.split(".")[0] + "." + lines.append(f" β€’ {tool.name}: {first_sentence}") + return "\n".join(lines) diff --git a/05-agentic-rag-realtime/src/tools/__init__.py b/05-agentic-rag-realtime/src/tools/__init__.py new file mode 100644 index 0000000..88af137 --- /dev/null +++ b/05-agentic-rag-realtime/src/tools/__init__.py @@ -0,0 +1,2 @@ +# src/tools/__init__.py +# Makes tools/ a sub-package. Individual tool modules are imported explicitly in tool_registry.py. 
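The conditional-registration pattern in `build_tool_registry()` (always-on tools plus real-or-mock pairs gated on API keys) can be sketched without any LangChain dependency. In the sketch below, `SimpleTool` is a hypothetical stand-in for `langchain.tools.Tool`, and `build_registry` is an illustrative helper, not part of the project; only the config key name `tavily_api_key` mirrors the real `config` dict:

```python
# Dependency-free sketch of the conditional-registration pattern used by
# build_tool_registry(). SimpleTool stands in for langchain.tools.Tool;
# build_registry is a hypothetical helper for illustration only.
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional


@dataclass
class SimpleTool:
    name: str
    description: str  # read by the agent's LLM when choosing a tool
    func: Callable[[str], str]


def build_registry(config: Dict[str, Optional[str]]) -> List[SimpleTool]:
    tools: List[SimpleTool] = []

    # Always-on tool: the local KB search needs no API key.
    tools.append(SimpleTool(
        name="search_knowledge_base",
        description="Search internal documents. Input: a search query string.",
        func=lambda q: f"(kb results for {q!r})",
    ))

    # Key-gated tool: register the real implementation when the key is
    # present, otherwise a clearly-labelled mock so the agent still runs.
    if config.get("tavily_api_key"):
        tools.append(SimpleTool(
            name="web_search",
            description="Search the live web. Input: a search query string.",
            func=lambda q: f"(live web results for {q!r})",
        ))
    else:
        tools.append(SimpleTool(
            name="web_search",
            description="Web search (mock mode: no API key configured).",
            func=lambda q: "[MOCK] web search disabled; set TAVILY_API_KEY.",
        ))

    return tools


if __name__ == "__main__":
    for tool in build_registry({"tavily_api_key": None}):
        print(f"{tool.name}: {tool.func('test')}")
```

Registering a labelled mock instead of omitting the tool keeps the tool list (and therefore the agent's prompt) the same shape in every environment, which makes behaviour easier to compare with and without API keys.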
diff --git a/05-agentic-rag-realtime/src/tools/finance_tool.py b/05-agentic-rag-realtime/src/tools/finance_tool.py
new file mode 100644
index 0000000..e9f3b35
--- /dev/null
+++ b/05-agentic-rag-realtime/src/tools/finance_tool.py
@@ -0,0 +1,151 @@
+"""
+src/tools/finance_tool.py
+
+Fetches live stock and financial data via yfinance (Yahoo Finance).
+
+This is a great example tool because it has:
+    • A concrete, unambiguous input format: the ticker symbol (AAPL, MSFT, …)
+    • A concrete, structured output: price, range, P/E, market cap
+    • No API key required — yfinance scrapes Yahoo Finance directly
+
+LIMITATIONS:
+    • Data may be delayed up to 15 minutes (Yahoo Finance's standard delay).
+    • Market cap and P/E are sourced from Yahoo Finance's "info" dict, which
+      can occasionally be None for smaller or recently-listed companies.
+    • For company-name inputs (e.g. "Apple") we do a best-effort lookup
+      against the small local name→ticker table below; names outside that
+      table may not resolve correctly.
+"""
+
+from langchain.tools import Tool
+
+
+# Common company-name → ticker fallback mapping for the most-searched names.
+# yfinance doesn't have a built-in name→ticker resolver so we keep a small
+# local table for robustness. Users who pass valid tickers bypass this table.
+_NAME_TO_TICKER = {
+    "apple": "AAPL",
+    "microsoft": "MSFT",
+    "google": "GOOGL",
+    "alphabet": "GOOGL",
+    "amazon": "AMZN",
+    "meta": "META",
+    "facebook": "META",
+    "tesla": "TSLA",
+    "nvidia": "NVDA",
+    "netflix": "NFLX",
+    "berkshire": "BRK-B",
+    "visa": "V",
+    "jpmorgan": "JPM",
+    "walmart": "WMT",
+}
+
+
+def _resolve_ticker(input_str: str) -> str:
+    """
+    Best-effort conversion of a user's input to a valid ticker symbol.
+
+    Priority:
+      1. If the input looks like a ticker (short, uppercase, no spaces) use it.
+      2. Check the local name→ticker mapping.
+      3. Fall back to the input as-is (yfinance will fail if it's wrong).
+ """ + stripped = input_str.strip() + + # Heuristic: tickers are 1–5 uppercase letters (optionally with a dot/dash) + if len(stripped) <= 6 and stripped.replace("-", "").replace(".", "").isalpha(): + return stripped.upper() + + # Check local lookup table (case-insensitive) + lower = stripped.lower() + for name, ticker in _NAME_TO_TICKER.items(): + if name in lower: + return ticker + + # Last resort: return input uppercased and hope it's a valid ticker + return stripped.upper() + + +def create_finance_tool() -> Tool: + """ + Build a LangChain Tool that returns live stock data from Yahoo Finance. + + No API key is required. yfinance handles all HTTP communication internally. + + Returns: + A configured LangChain Tool for stock data lookups. + """ + + def get_stock_data(input_str: str) -> str: + """ + Fetch key financial metrics for a given ticker or company name. + + The agent passes the ticker or company name as a plain string. + We return a single formatted line so the agent can include it verbatim + in its response without further parsing. + """ + try: + import yfinance as yf # noqa: PLC0415 + + ticker_symbol = _resolve_ticker(input_str) + ticker = yf.Ticker(ticker_symbol) + + # fast_info is lighter-weight than the full .info dict and avoids + # some rate-limiting issues, but has fewer fields. + fast = ticker.fast_info + info = ticker.info # full metadata dict β€” may be slow on first call + + # Safely extract values; Yahoo Finance sometimes returns None. + price = fast.last_price + if price is None: + return ( + f"Could not find stock data for '{input_str}'. " + "Please provide a valid ticker symbol." 
+ ) + + high_52w = fast.year_high + low_52w = fast.year_low + market_cap = fast.market_cap + pe_ratio = info.get("trailingPE") + + # --- Format market cap as a human-readable string --- + def _fmt_cap(cap) -> str: + if cap is None: + return "N/A" + if cap >= 1e12: + return f"${cap / 1e12:.2f}T" + if cap >= 1e9: + return f"${cap / 1e9:.2f}B" + if cap >= 1e6: + return f"${cap / 1e6:.2f}M" + return f"${cap:,.0f}" + + pe_str = f"{pe_ratio:.1f}" if pe_ratio else "N/A" + high_str = f"${high_52w:.2f}" if high_52w else "N/A" + low_str = f"${low_52w:.2f}" if low_52w else "N/A" + + return ( + f"Stock: {ticker_symbol} | " + f"Price: ${price:.2f} | " + f"52W High: {high_str} | " + f"52W Low: {low_str} | " + f"P/E: {pe_str} | " + f"Market Cap: {_fmt_cap(market_cap)}" + ) + + except Exception as exc: + # Catch-all so a yfinance network error doesn't crash the agent loop. + return ( + f"Could not find stock data for '{input_str}'. " + f"Error: {exc}. " + "Please provide a valid ticker symbol (e.g., AAPL, MSFT, GOOGL)." + ) + + return Tool( + name="get_stock_data", + func=get_stock_data, + description=( + "Get current stock/financial data for a publicly traded company. " + "Input: a stock ticker symbol (e.g., AAPL, MSFT, GOOGL) or company name. " + "Returns current price, 52-week range, P/E ratio, and market cap." + ), + ) diff --git a/05-agentic-rag-realtime/src/tools/rag_tool.py b/05-agentic-rag-realtime/src/tools/rag_tool.py new file mode 100644 index 0000000..85a37cc --- /dev/null +++ b/05-agentic-rag-realtime/src/tools/rag_tool.py @@ -0,0 +1,85 @@ +""" +src/tools/rag_tool.py + +Wraps the FAISS knowledge base search as a LangChain Tool so the agent can +call it alongside web search, finance, and weather tools. + +KEY CONCEPT β€” Tool Description is Everything: + The agent's LLM reads the `description` field to decide WHEN to call this + tool. A vague description like "search documents" leads to the agent using + the tool for every question. 
A specific description that says what the + knowledge base CONTAINS helps the agent make the right routing decision: + internal docs β†’ RAG, live prices β†’ finance_tool, current events β†’ web_search. +""" + +from typing import List + +from langchain.tools import Tool +from langchain_community.vectorstores import FAISS + +from src.knowledge_indexer import search_knowledge_base + + +def create_rag_tool( + vector_store: FAISS, + domain_description: str = "internal company documents", +) -> Tool: + """ + Build a LangChain Tool that searches the FAISS knowledge base. + + Input/Output contract (required by LangChain Tools): + - Input: always a plain string β€” the search query the agent constructs. + - Output: always a plain string β€” the formatted retrieval results. + The agent cannot pass Python objects; everything is serialised to/from text. + + Args: + vector_store: Loaded FAISS index returned by knowledge_indexer. + domain_description: Short phrase describing WHAT is stored in the KB, + e.g. "Q3 financial forecasts and product roadmaps". + This is injected into the tool description so the + LLM knows exactly when to use it. + + Returns: + A configured LangChain Tool ready to add to the agent's tool list. + """ + + def _search(query: str) -> str: + """ + Inner function called by LangChain when the agent invokes this tool. + The agent provides `query` as a plain string. + """ + chunks: List[str] = search_knowledge_base(query, vector_store, k=3) + + if not chunks: + return "No relevant information found in knowledge base." + + # Format each retrieved chunk with a numbered label so the LLM can + # reference specific chunks in its final answer. + lines = ["Found in knowledge base:"] + for i, chunk in enumerate(chunks, start=1): + # Strip excess whitespace from the chunk to keep the context window tidy. + clean = " ".join(chunk.split()) + lines.append(f"{i}. 
{clean}") + + return "\n".join(lines) + + # --------------------------------------------------------------------------- + # The description is intentionally verbose: + # β€’ "internal policies, product documentation, or stored knowledge" signals + # the agent to use this for anything that would appear in static docs. + # β€’ Mentioning the domain_description further narrows the scope. + # β€’ Ending with "Input: a search query string" sets the input format + # expectation clearly so the agent passes a plain query, not JSON. + # --------------------------------------------------------------------------- + description = ( + f"Search {domain_description} for relevant information. " + "Use this for questions about internal policies, product documentation, " + "or stored knowledge. " + "Input: a search query string." + ) + + return Tool( + name="search_knowledge_base", + func=_search, + description=description, + ) diff --git a/05-agentic-rag-realtime/src/tools/weather_tool.py b/05-agentic-rag-realtime/src/tools/weather_tool.py new file mode 100644 index 0000000..762c38b --- /dev/null +++ b/05-agentic-rag-realtime/src/tools/weather_tool.py @@ -0,0 +1,131 @@ +""" +src/tools/weather_tool.py + +Fetches current weather data from the OpenWeatherMap API. + +HOW THE AGENT PASSES PARAMETERS: + The agent's LLM reads the tool description, extracts the relevant entity + from the user's question (e.g. "London" from "What's the weather in London?"), + and passes it as the `city` string to this tool. The tool's job is simply to + accept that string and return a human-readable result. + +UNITS: + We use metric (Β°C, km/h) by default because it is universally understood. + If your users are in the US you can change `units=metric` to `units=imperial` + in the API URL to receive Β°F and mph. 
+""" + +import requests +from langchain.tools import Tool + + +def create_weather_tool(openweathermap_api_key: str) -> Tool: + """ + Build a LangChain Tool that returns current weather from OpenWeatherMap. + + Free tier: 60 API calls/minute, no credit card required. + Sign up at https://openweathermap.org/api + + Args: + openweathermap_api_key: A valid OpenWeatherMap API key. + + Returns: + A configured LangChain Tool for weather lookups. + """ + + def get_weather(city: str) -> str: + """ + Call the OpenWeatherMap "current weather" endpoint for a given city. + + The agent passes the city name exactly as it understands it from the + user's question β€” it may be "London", "New York, US", "Paris, FR", etc. + OpenWeatherMap accepts most common city formats. + + Args: + city: City name string provided by the agent. + + Returns: + A single human-readable weather summary string. + """ + city = city.strip() + url = ( + "https://api.openweathermap.org/data/2.5/weather" + f"?q={city}&appid={openweathermap_api_key}&units=metric" + ) + + try: + response = requests.get(url, timeout=10) + + # OpenWeatherMap returns 404 when the city name isn't recognised. + if response.status_code == 404: + return ( + f"Could not find weather for '{city}'. " + "Please check the city name (try adding the country code, " + "e.g. 'Paris, FR')." + ) + + response.raise_for_status() + data = response.json() + + # Extract fields β€” the API always returns these keys on success. + temp = data["main"]["temp"] + feels_like = data["main"]["feels_like"] + humidity = data["main"]["humidity"] + description = data["weather"][0]["description"].capitalize() + wind_speed_ms = data["wind"]["speed"] + # Convert m/s β†’ km/h for a more intuitive display. 
+            wind_kmh = wind_speed_ms * 3.6
+            city_name = data.get("name", city)
+            country = data.get("sys", {}).get("country", "")
+            location = f"{city_name}, {country}" if country else city_name
+
+            return (
+                f"Weather in {location}: "
+                f"{temp:.1f}°C (feels like {feels_like:.1f}°C), "
+                f"{description}, "
+                f"Humidity: {humidity}%, "
+                f"Wind: {wind_kmh:.1f} km/h"
+            )
+
+        except requests.exceptions.Timeout:
+            return f"Weather service timed out for '{city}'. Please try again."
+        except Exception as exc:
+            return f"Could not retrieve weather for '{city}'. Error: {exc}"
+
+    return Tool(
+        name="get_weather",
+        func=get_weather,
+        description=(
+            "Get current weather for any city. "
+            "Input: city name (e.g., 'London' or 'New York, US'). "
+            "Returns temperature, weather conditions, humidity, and wind speed."
+        ),
+    )
+
+
+def create_mock_weather_tool() -> Tool:
+    """
+    Return a mock weather tool used when no OpenWeatherMap API key is set.
+
+    The mock returns plausible-looking data so that the full agent pipeline can
+    be tested without any API keys. The response clearly labels itself as mock
+    data so it is never confused with real weather information.
+    """
+
+    def mock_weather(city: str) -> str:
+        city = city.strip()
+        return (
+            f"[MOCK DATA — configure OPENWEATHERMAP_API_KEY for real weather] "
+            f"Weather in {city}: 18.0°C (feels like 17.0°C), "
+            f"Partly cloudy, Humidity: 65%, Wind: 14.0 km/h"
+        )
+
+    return Tool(
+        name="get_weather",
+        func=mock_weather,
+        description=(
+            "Get current weather for any city. "
+            "NOTE: Running in mock mode (no API key configured). "
+            "Input: city name (e.g., 'London' or 'New York, US')."
+ ), + ) diff --git a/05-agentic-rag-realtime/src/tools/web_search_tool.py b/05-agentic-rag-realtime/src/tools/web_search_tool.py new file mode 100644 index 0000000..0977b10 --- /dev/null +++ b/05-agentic-rag-realtime/src/tools/web_search_tool.py @@ -0,0 +1,138 @@ +""" +src/tools/web_search_tool.py + +Provides live web search capability via the Tavily API so the agent can answer +questions about current events, recent news, or anything not in the static +knowledge base. + +WHEN TO USE WEB SEARCH vs RAG: + β€’ Current events / breaking news β†’ web_search (RAG docs are static) + β€’ Live prices, exchange rates, scores β†’ dedicated tool (finance/weather) + β€’ Recent product releases, news stories β†’ web_search + β€’ Internal company policies or strategy β†’ search_knowledge_base (private data) + β€’ Historical or stable reference info β†’ search_knowledge_base (faster, no API cost) + +RATE LIMITS: + Tavily free tier allows 1,000 searches/month (~33/day). + Don't call this tool for every question β€” reserve it for questions that + genuinely require real-time or recent information. +""" + +from langchain.tools import Tool + + +def create_web_search_tool(tavily_api_key: str) -> Tool: + """ + Build a LangChain Tool that queries the Tavily web search API. + + Tavily is purpose-built for LLM agents β€” it returns clean text snippets + rather than raw HTML, which keeps the context window efficient. + + Args: + tavily_api_key: A valid Tavily API key from https://tavily.com. + + Returns: + A configured LangChain Tool for live web search. + """ + + def _web_search(query: str) -> str: + """ + Call the Tavily search API and format the top results as plain text. + + The agent provides `query` as whatever it decides to search for. + We return a structured string so the agent can extract the most + relevant details in its reasoning step. + """ + try: + # tavily-python provides a clean SDK around the REST API. 
+ from tavily import TavilyClient # type: ignore + + client = TavilyClient(api_key=tavily_api_key) + + # max_results=3 keeps the context manageable while still giving + # the agent enough breadth to triangulate information. + response = client.search(query, max_results=3) + + results = response.get("results", []) + if not results: + return "Web search returned no results for that query." + + lines = [] + for r in results: + title = r.get("title", "No title") + snippet = r.get("content", r.get("snippet", "No snippet available")) + url = r.get("url", "") + lines.append(f"Title: {title}\nSnippet: {snippet}\nURL: {url}\n---") + + return "\n".join(lines) + + except ImportError: + # Fall back to a direct HTTP call if the SDK isn't installed. + return _tavily_http_fallback(query, tavily_api_key) + + except Exception as exc: + # Graceful degradation: the agent will see this message and can + # note in its response that web search was unavailable. + return f"Web search unavailable: {exc}" + + def _tavily_http_fallback(query: str, api_key: str) -> str: + """Direct HTTP fallback when the tavily-python package is missing.""" + import requests # noqa: PLC0415 + + try: + resp = requests.post( + "https://api.tavily.com/search", + json={"api_key": api_key, "query": query, "max_results": 3}, + timeout=10, + ) + resp.raise_for_status() + data = resp.json() + results = data.get("results", []) + if not results: + return "Web search returned no results." + lines = [] + for r in results: + lines.append( + f"Title: {r.get('title', '')}\n" + f"Snippet: {r.get('content', '')}\n" + f"URL: {r.get('url', '')}\n---" + ) + return "\n".join(lines) + except Exception as exc: + return f"Web search unavailable: {exc}" + + return Tool( + name="web_search", + func=_web_search, + description=( + "Search the live web for current information, news, or recent events. " + "Use this when the question requires up-to-date information not available " + "in static documents. " + "Input: a search query string." 
+        ),
+    )
+
+
+def create_mock_web_search_tool() -> Tool:
+    """
+    Fallback tool returned when no Tavily API key is configured.
+
+    The agent will still receive a coherent message explaining why the tool
+    is unavailable, rather than raising an exception mid-reasoning.
+    """
+
+    def _mock_search(query: str) -> str:  # noqa: ARG001
+        return (
+            "Web search is not configured. "
+            "Please add TAVILY_API_KEY to .env to enable live web search."
+        )
+
+    return Tool(
+        name="web_search",
+        func=_mock_search,
+        description=(
+            "Search the live web for current information, news, or recent events. "
+            "NOTE: Web search is currently disabled (no API key configured). "
+            "Input: a search query string."
+        ),
+    )
diff --git a/05-agentic-rag-realtime/src/tools/wiki_tool.py b/05-agentic-rag-realtime/src/tools/wiki_tool.py
new file mode 100644
index 0000000..a2d07d9
--- /dev/null
+++ b/05-agentic-rag-realtime/src/tools/wiki_tool.py
@@ -0,0 +1,85 @@
+"""
+src/tools/wiki_tool.py
+
+Provides Wikipedia lookups as a lightweight alternative to full web search.
+Wikipedia is ideal for factual, encyclopaedic questions where live web crawling
+isn't needed but the answer isn't in the internal knowledge base either.
+
+Advantages over web_search:
+    • No API key required.
+    • Zero cost — Wikipedia has no rate limits for this use case.
+    • Results are well-structured and factual (not SEO-optimised content).
+
+Use this before falling back to web_search for definitional or historical queries.
+"""
+
+from langchain.tools import Tool
+
+
+def create_wiki_tool() -> Tool:
+    """
+    Build a LangChain Tool that searches Wikipedia for encyclopaedic information.
+
+    Uses the `wikipedia` PyPI package which wraps the Wikipedia REST API.
+
+    Returns:
+        A configured LangChain Tool for Wikipedia lookups.
+    """
+
+    def search_wikipedia(query: str) -> str:
+        """
+        Search Wikipedia for the query and return a short summary.
+
+        We return only the first 800 characters of the summary to keep the
+        context window usage low. The agent can always call the tool again
+        with a more specific query if it needs more detail.
+
+        Args:
+            query: Search term or topic provided by the agent.
+
+        Returns:
+            A plain-text summary from Wikipedia, or an error message.
+        """
+        try:
+            import wikipedia  # noqa: PLC0415
+
+            query = query.strip()
+
+            # wikipedia.summary() can raise DisambiguationError when the query
+            # maps to multiple articles (e.g. "Python"). We handle this by
+            # picking the first suggestion automatically.
+            try:
+                summary = wikipedia.summary(query, sentences=4, auto_suggest=True)
+            except wikipedia.exceptions.DisambiguationError as e:
+                # Try the first suggested page instead of failing.
+                if e.options:
+                    summary = wikipedia.summary(e.options[0], sentences=4)
+                else:
+                    return f"Wikipedia: '{query}' is ambiguous. Please be more specific."
+            except wikipedia.exceptions.PageError:
+                return f"Wikipedia: No article found for '{query}'. Try a different search term."
+
+            # Truncate to keep context window usage predictable.
+            if len(summary) > 800:
+                summary = summary[:800] + "…"
+
+            return f"Wikipedia summary for '{query}':\n{summary}"
+
+        except ImportError:
+            return (
+                "Wikipedia tool is unavailable. "
+                "Install the 'wikipedia' package: pip install wikipedia"
+            )
+        except Exception as exc:
+            return f"Wikipedia lookup failed for '{query}'. Error: {exc}"
+
+    return Tool(
+        name="search_wikipedia",
+        func=search_wikipedia,
+        description=(
+            "Search Wikipedia for factual, encyclopaedic information about any topic. "
+            "Use this for definitions, historical facts, scientific concepts, or "
+            "general knowledge questions. No API key required. "
+            "Input: a topic or search term string."
+        ),
+    )
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..5274624
--- /dev/null
+++ b/README.md
@@ -0,0 +1,168 @@
+# GenAI Beginner Projects
+
+A hands-on learning path for developers new to Generative AI. Five self-contained projects that take you from basic RAG to agentic systems with real-time data — each building on the previous.
+
+---
+
+## Why These 5 Projects?
+
+Most GenAI tutorials show you a hello-world demo and call it a day. These projects are different:
+
+- **Real code**, not toy examples — each project solves an actual use case
+- **Step-by-step comments** explain *why*, not just *what*
+- **Progressive complexity** — each project introduces exactly one new concept
+- **Works with OpenAI or Ollama** — you're not gated by API costs
+
+---
+
+## Prerequisites
+
+| Requirement | Notes |
+|-------------|-------|
+| Python 3.10+ | `python --version` to check |
+| OpenAI API key | Or run Ollama locally for free |
+| Git | For cloning the repo |
+| 8 GB RAM minimum | For running local embedding models |
+
+---
+
+## Project Map
+
+| # | Project | Difficulty | Key New Concept | One-Line Description |
+|---|---------|-----------|----------------|----------------------|
+| 1 | [RAG From Scratch](./01-rag-from-scratch/) | ⭐⭐ Beginner | Embeddings, vector search | Build a Q&A system over your own documents |
+| 2 | [Legal AI Assistant](./02-legal-ai-assistant/) | ⭐⭐⭐ Beginner+ | Domain prompting, structured output | Analyze contracts for risks, clauses, and conflicts |
+| 3 | [AI Research Agent](./03-research-agent/) | ⭐⭐⭐ Intermediate | Agents, multi-step reasoning | Synthesize multiple research papers and find gaps |
+| 4 | [Multimodal RAG](./04-multimodal-rag/) | ⭐⭐⭐⭐ Intermediate | Vision models, multi-index | RAG that understands text, images, and tables |
+| 5 | [Agentic RAG + Real-Time](./05-agentic-rag-realtime/) | ⭐⭐⭐⭐ Intermediate | Tool use, live data APIs | Agent that combines stored docs with live web/financial data |
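
If "embeddings" and "top-k vector search" (Project 1's key concepts) are new to you, the shape of the retrieval loop fits in a few lines of dependency-free Python. This is a deliberately toy sketch: word counts stand in for a neural embedding model like all-MiniLM-L6-v2, and a plain sort stands in for FAISS, but the pipeline (embed the question, score every chunk, keep the best k) is the same one Project 1 builds for real:

```python
import math
import re
from collections import Counter

def embed(text: str) -> Counter:
    # Toy "embedding": a bag-of-words vector keyed by word.
    # Real projects use a neural model (e.g. all-MiniLM-L6-v2) so that
    # similar *meanings*, not just shared words, land close together.
    return Counter(re.findall(r"\w+", text.lower()))

def cosine(a: Counter, b: Counter) -> float:
    # Cosine similarity: 1.0 means identical direction, 0.0 means no overlap.
    dot = sum(count * b[word] for word, count in a.items())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0

def top_k(question: str, chunks: list[str], k: int = 3) -> list[str]:
    # "Top-k retrieval": score every chunk against the question, keep the best k.
    # A vector store like FAISS does the same thing, fast enough for millions of vectors.
    q = embed(question)
    return sorted(chunks, key=lambda c: cosine(q, embed(c)), reverse=True)[:k]

chunks = [
    "The refund policy allows returns within 14 days of purchase.",
    "Our office is closed on public holidays.",
    "Refund requests must include the original receipt.",
]
print(top_k("what is the refund policy", chunks, k=2))
```

In a real RAG system the retrieved chunks are then pasted into the LLM prompt as context; that final step is what makes it RAG rather than plain search.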
+
+---
+
+## Learning Path
+
+Follow the projects in order — each one adds exactly one new layer:
+
+```
+Project 1: RAG From Scratch
+    ↓  (adds domain-specific prompting)
+Project 2: Legal AI Assistant
+    ↓  (adds agent framework + tools)
+Project 3: AI Research Agent
+    ↓  (adds vision models + multi-index)
+Project 4: Multimodal RAG
+    ↓  (adds live data tools + planning)
+Project 5: Agentic RAG + Real-Time
+```
+
+**Skill progression:**
+- After Project 1: You understand how RAG works and can build basic Q&A over documents
+- After Project 2: You can write domain-specific prompts and structure LLM output as JSON
+- After Project 3: You understand agents and can build multi-step reasoning systems
+- After Project 4: You can handle documents with images and tables, not just text
+- After Project 5: You can build production-grade agents that combine stored knowledge with live data
+
+---
+
+## Quick Setup
+
+```bash
+# 1. Clone the repo
+git clone https://github.com/your-org/genai-beginner-projects.git
+cd genai-beginner-projects
+
+# 2. Pick a project to start with
+cd 01-rag-from-scratch
+
+# 3. Create a virtual environment (recommended — keeps dependencies isolated)
+python -m venv venv
+source venv/bin/activate   # Mac/Linux
+# venv\Scripts\activate    # Windows
+
+# 4. Install dependencies
+pip install -r requirements.txt
+
+# 5. Set up environment variables
+cp .env.example .env
+# Edit .env and add your API keys
+
+# 6. Run the project
+python main.py --help
+```
+
+> **Tip:** Each project has its own `venv` and `requirements.txt`. You don't need to install everything at once.
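
Before step 6, a quick sanity check can save debugging time. This is a sketch, not part of any project: it only verifies the two things the projects assume (Python 3.10+ and, if you are using OpenAI, an `OPENAI_API_KEY`). Note it reads the shell environment, not `.env` itself; the projects load `.env` at startup:

```python
import os
import sys

# The projects assume Python 3.10+ (see Prerequisites above).
python_ok = sys.version_info >= (3, 10)
print(f"Python {sys.version_info.major}.{sys.version_info.minor}: {'OK' if python_ok else 'too old'}")

# An OpenAI key is only required if you are not running Ollama locally.
if os.getenv("OPENAI_API_KEY"):
    print("OPENAI_API_KEY is set")
else:
    print("OPENAI_API_KEY not set (fine if you plan to use Ollama)")
```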
+
+---
+
+## Glossary
+
+Plain-English definitions for terms you'll encounter in these projects:
+
+| Term | Plain-English Definition |
+|------|--------------------------|
+| **RAG** | Retrieval-Augmented Generation — feeding relevant documents to an LLM before asking it a question, so it answers based on your data instead of guessing |
+| **Embedding** | A list of numbers (a vector) that represents the meaning of a piece of text. Similar texts have similar vectors. |
+| **Vector store** | A database optimized for finding similar vectors quickly. FAISS is a popular local option. |
+| **FAISS** | Facebook AI Similarity Search — a library that stores vectors and finds the most similar ones very fast |
+| **Agent** | An LLM that can take actions (like calling tools) to complete a goal, rather than just answering a single question |
+| **Tool** | A function the agent can call — like searching the web, getting stock prices, or searching your documents |
+| **Chain** | A sequence of LLM calls or operations linked together. LangChain helps you build these. |
+| **Prompt template** | A reusable text structure with variables that gets filled in at runtime. Like a form letter for LLMs. |
+| **Hallucination** | When an LLM confidently states something false. RAG reduces this by grounding answers in real documents. |
+| **Chunk** | A small piece of a larger document (usually 300–1000 characters). Documents are split into chunks for embedding. |
+| **Top-k retrieval** | Finding the k most similar chunks to a question. k=3 means: "find the 3 most relevant passages." |
+| **ReAct** | Reason + Act — an agent pattern where the LLM thinks about what to do, does it, observes the result, and repeats |
+| **LangChain** | A Python framework for building LLM applications. Provides building blocks for RAG, agents, chains, and more. |
+
+---
+
+## Using Ollama (Free Local LLMs)
+
+Don't want to pay for OpenAI?
Use Ollama to run LLMs on your own machine:
+
+```bash
+# Install Ollama: https://ollama.com
+curl -fsSL https://ollama.com/install.sh | sh
+
+# Pull a model
+ollama pull llama3
+
+# In any project, use:
+python main.py --model ollama/llama3
+```
+
+> **Note:** Small local models need ~8 GB of RAM. They're slower than OpenAI but completely free.
+
+---
+
+## Repository Structure
+
+```
+genai-beginner-projects/
+│
+├── README.md                  ← You are here
+│
+├── 01-rag-from-scratch/       ← ⭐⭐ Build RAG from scratch
+├── 02-legal-ai-assistant/     ← ⭐⭐⭐ Legal document analysis
+├── 03-research-agent/         ← ⭐⭐⭐ AI research synthesis agent
+├── 04-multimodal-rag/         ← ⭐⭐⭐⭐ Text + images + tables
+└── 05-agentic-rag-realtime/   ← ⭐⭐⭐⭐ Live data + documents
+```
+
+Each project folder contains:
+- `README.md` — what the project does, how to run it, what you'll learn
+- `requirements.txt` — all Python dependencies pinned to specific versions
+- `.env.example` — copy this to `.env` and fill in your API keys
+- `main.py` — the entry point to run the project
+- `src/` — well-commented source files organized by feature
+
+---
+
+## Contributing
+
+Found a bug? Have an improvement? Open an issue or PR.
+
+When adding comments or documentation, remember the audience: developers with 3–4 years of experience who are new to GenAI. Explain the "why", not just the "what".
+
+---
+
+*All projects support both OpenAI API and local Ollama models. You don't need to pay for API access to learn from these projects.*