diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000..99cac44
Binary files /dev/null and b/.DS_Store differ
diff --git a/.gitignore b/.gitignore
index 39c7f30..5bff079 100644
--- a/.gitignore
+++ b/.gitignore
@@ -165,4 +165,4 @@ nodedatabase.db
 data/*
 upload_all.py
 eval/*
-logs/*
\ No newline at end of file
+logs/*.DS_Store
diff --git a/README.md b/README.md
index ec5e0c1..418c9e9 100644
--- a/README.md
+++ b/README.md
@@ -1,21 +1,11 @@
 # MMIF Graph Visualizer
-This repository hosts the code for the Graph Visualizer, a collection-level visualizer for [MMIF](https://mmif.clams.ai/) files which renders MMIF files as nodes in a D3 force-directed graph.
+This repository uses the Gemma3 model from Ollama to summarize transcripts in [MMIF](https://mmif.clams.ai/) files.
-![screenshot](https://github.com/haydenmccormick/graph-visualizer/assets/74222796/a32f5379-e463-4af9-8dc9-d78206f79aa2)
-## Quick Start
+![screenshot](https://github.com/haydenmccormick/graph-visualizer/assets/74222796/a32f5379-e463-4af9-8dc9-d78206f79aa2)
-Currently, you can run the server in two ways:
-1. Manually, with Python:
-    * Install requirements: `pip install -r requirements.txt`
-    * Unzip `data/topic_newshour.zip` in the `data` directory
-    * Run `python app.py` to start the server. It will be accessible at `localhost:5555`
-    * Run the mmif visualizer in parallel for access to visualization. **The MMIF visualizer should be exposed to port 5000**
-2. Using Docker/Podman
-* docker-compose up will spin up the Graph Visualizer and the MMIF visualizer, and connect them via a network.
-* **WARNING**: Because the project contains a significant amount of modeling requirements and networking, building the container may take a while, and on my hardware has consistently crashed before completing. I have not been able to debug this -- running the files locally using your own distribution of Python is likely the most efficient and accessible way to start the service.
 ## Directory Structure
@@ -50,8 +40,8 @@ This project is heavily centered around client-side Javascript code, with Python
   - date.py [Date scraping]
   - get_descriptions.py [Description scraping from AAPB API]
   - ner.py [Spacy named entity extraction]
-  - summarize.py [Abstractive summarization using BART]
-  - topic_model.py [Topic modelling using BERTopic]
+  - summarize.py [Abstractive summarization using Gemma3]
+  - topic_model.py [Topic modelling using Gemma3]
   - preprocessing/preprocess.py [functions for building description dataset]
 - templates
   - index.html
@@ -61,6 +51,47 @@ This project is heavily centered around client-side Javascript code, with Python
 - tmp [Directory for storing intermediate MMIF files before they are passed to the visualizer]
+
+## Running the models
+
+### summarize.py
+
+Features:
+* Two summarization methods:
+  1. Transformer-based summarization using BART
+  2. LLM-based summarization using Gemma3 via Ollama
+* Automatic handling of long transcripts through chunking and hierarchical summarization
+* Support for MMIF-formatted files and raw transcript text files
+* Configurable summary length
+
+### Installation
+
+Prerequisites:
+1. Python 3
+2. A CUDA-compatible GPU is recommended for the transformer model (it also runs on CPU)
+
+### Setup
+
+1. Clone this repository and enter it:
+   * `git clone https://github.com/clamsproject/graph-visualizer`
+   * `cd graph-visualizer`
+2. Install the required dependencies: `pip install -r requirements.txt`
+3. If using the LLM method, install and set up Ollama:
+   * Download and install Ollama
+   * Start the Ollama service: `ollama serve`
+   * Pull the Gemma3 model: `ollama pull gemma3`
+
+### Usage
+
+The script can be run from the command line with the following arguments:
+
+```bash
+python3 summarize.py [--llm | --transformer] input_file.json
+```
+
+Command-line options:
+* `--llm`: Use the LLM-based summarization method (requires Ollama with Gemma3)
+* `--transformer`: Use the transformer-based summarization method (using BART)
+* `input_file`: Path to the input file (MMIF JSON or raw transcript)
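+
+You can also sanity-check the Ollama setup directly from Python before running the full script. The snippet below is a minimal sketch of the same `/api/generate` request that `summarize.py` sends (the endpoint, model name, and options mirror the script; the prompt text is only an illustration):
+
+```python
+import requests
+
+# Minimal request against the local Ollama server, mirroring summarize.py's call.
+response = requests.post(
+    "http://localhost:11434/api/generate",
+    json={
+        "model": "gemma3",
+        "prompt": "Summarize the following transcript in about 150 words:\n\n<transcript text>",
+        "stream": False,
+        "options": {"temperature": 0.1, "top_p": 0.9},
+    },
+)
+response.raise_for_status()
+print(response.json()["response"].strip())
+```
+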
 ## Visualizations
diff --git a/modeling/comparison.py b/modeling/comparison.py
new file mode 100644
index 0000000..f605410
--- /dev/null
+++ b/modeling/comparison.py
@@ -0,0 +1,324 @@
+"""Simplified Transcript Summarizer Comparison Tool
+
+This script creates a focused visual comparison between two summarizer systems.
+
+Usage:
+    comparison.py --summaries1 <dir1> --summaries2 <dir2> [--output_dir <dir>]
+"""
+
+import os
+import argparse
+import json
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+from pathlib import Path
+from tqdm import tqdm
+from rouge_score import rouge_scorer
+import nltk
+from nltk.tokenize import word_tokenize, sent_tokenize
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+import re
+
+# Ensure NLTK data is downloaded
+try:
+    nltk.data.find('tokenizers/punkt')
+except LookupError:
+    nltk.download('punkt')
+
+def load_summaries(dir_path):
+    """Load summaries from a directory with robust error handling"""
+    summaries = {}
+    try:
+        files = list(Path(dir_path).glob("*.txt"))
+        if not files:
+            print(f"Warning: No .txt files found in {dir_path}")
+            return summaries
+
+        for file_path in tqdm(files, desc=f"Loading from {dir_path}"):
+            file_id = file_path.stem
+            try:
+                with open(file_path, 'r', encoding='utf-8') as f:
+                    text = f.read().strip()
+                    # Extract just the summary if it follows a pattern like "Summary: [content]"
+                    summary_match = re.search(r"Summary:\s*(.*?)(?:\n\n|\Z)", text, re.DOTALL)
+                    if summary_match:
+                        summaries[file_id] = summary_match.group(1).strip()
+                    else:
+                        summaries[file_id] = text
+            except Exception as e:
+                print(f"Error reading file {file_path}: {e}")
+    except Exception as e:
+        print(f"Error accessing directory {dir_path}: {e}")
+
+    return summaries
+
+def create_comparison_radar_chart(system1_summaries, system2_summaries, output_path):
+    """Create a single radar chart comparing the two systems"""
+    # Initialize metrics
+    rouge_scorer_obj = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
+
+    # Find common file IDs
+    common_ids = set(system1_summaries.keys()) & set(system2_summaries.keys())
+    print(f"Found {len(common_ids)} common transcripts between both systems.")
+
+    if len(common_ids) == 0:
+        print("Error: No common files found between the two systems.")
+        return None
+
+    # Calculate metrics for each file
+    metrics = {
+        "word_count": {"System 1": [], "System 2": []},
+        "sentence_count": {"System 1": [], "System 2": []},
+        "vocabulary_richness": {"System 1": [], "System 2": []},
+        "rouge1_f": [],
+        "rouge2_f": [],
+        "rougeL_f": [],
+        "semantic_similarity": []
+    }
+
+    for file_id in tqdm(common_ids, desc="Calculating metrics"):
+        summary1 = system1_summaries[file_id]
+        summary2 = system2_summaries[file_id]
+
+        # Skip empty summaries
+        if not summary1 or not summary2:
+            print(f"Warning: Empty summary found for {file_id}, skipping.")
+            continue
+
+ # Calculate basic metrics + metrics["word_count"]["System 1"].append(len(summary1.split())) + metrics["word_count"]["System 2"].append(len(summary2.split())) + + metrics["sentence_count"]["System 1"].append(len(sent_tokenize(summary1))) + metrics["sentence_count"]["System 2"].append(len(sent_tokenize(summary2))) + + # Vocabulary richness + tokens1 = word_tokenize(summary1.lower()) + tokens2 = word_tokenize(summary2.lower()) + + if tokens1: + metrics["vocabulary_richness"]["System 1"].append(len(set(tokens1)) / len(tokens1)) + else: + metrics["vocabulary_richness"]["System 1"].append(0) + + if tokens2: + metrics["vocabulary_richness"]["System 2"].append(len(set(tokens2)) / len(tokens2)) + else: + metrics["vocabulary_richness"]["System 2"].append(0) + + # ROUGE scores + try: + rouge_scores = rouge_scorer_obj.score(summary1, summary2) + metrics["rouge1_f"].append(rouge_scores["rouge1"].fmeasure) + metrics["rouge2_f"].append(rouge_scores["rouge2"].fmeasure) + metrics["rougeL_f"].append(rouge_scores["rougeL"].fmeasure) + except Exception as e: + print(f"Error calculating ROUGE scores for {file_id}: {e}") + metrics["rouge1_f"].append(0) + metrics["rouge2_f"].append(0) + metrics["rougeL_f"].append(0) + + # Semantic similarity + try: + vectorizer = TfidfVectorizer() + tfidf_matrix = vectorizer.fit_transform([summary1, summary2]) + metrics["semantic_similarity"].append(cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0]) + except Exception as e: + print(f"Error calculating semantic similarity for {file_id}: {e}") + metrics["semantic_similarity"].append(0) + + # Check if we have any valid metrics + if not metrics["word_count"]["System 1"]: + print("Error: No valid metrics could be calculated after filtering empty summaries.") + return None + + # Calculate means for each metric + summary = { + "word_count": { + "System 1": np.mean(metrics["word_count"]["System 1"]), + "System 2": np.mean(metrics["word_count"]["System 2"]) + }, + "sentence_count": { + "System 1": np.mean(metrics["sentence_count"]["System 1"]), + "System 2": np.mean(metrics["sentence_count"]["System 2"]) + }, + "vocabulary_richness": { + "System 1": np.mean(metrics["vocabulary_richness"]["System 1"]), + "System 2": np.mean(metrics["vocabulary_richness"]["System 2"]) + }, + "rouge1_f": np.mean(metrics["rouge1_f"]), + "rouge2_f": np.mean(metrics["rouge2_f"]), + "rougeL_f": np.mean(metrics["rougeL_f"]), + "semantic_similarity": np.mean(metrics["semantic_similarity"]) + } + + # Create radar chart + categories = [ + "Word Count", + "Sentence Count", + "Vocabulary Richness", + "ROUGE-1", + "ROUGE-2", + "ROUGE-L", + "Semantic Similarity" + ] + + # Normalize values for radar chart with reasonable caps + max_word_count = max(summary["word_count"]["System 1"], summary["word_count"]["System 2"]) + max_word_count = min(max_word_count, 500) # Cap at reasonable maximum + + max_sentence_count = max(summary["sentence_count"]["System 1"], summary["sentence_count"]["System 2"]) + max_sentence_count = min(max_sentence_count, 30) # Cap at reasonable maximum + + # Get values for System 1 + system1_values = [ + min(summary["word_count"]["System 1"] / max_word_count, 1.0), + min(summary["sentence_count"]["System 1"] / max_sentence_count, 1.0), + summary["vocabulary_richness"]["System 1"], + summary["rouge1_f"], + summary["rouge2_f"], + summary["rougeL_f"], + summary["semantic_similarity"] + ] + + # Get values for System 2 + system2_values = [ + min(summary["word_count"]["System 2"] / max_word_count, 1.0), + min(summary["sentence_count"]["System 
2"] / max_sentence_count, 1.0), + summary["vocabulary_richness"]["System 2"], + summary["rouge1_f"], + summary["rouge2_f"], + summary["rougeL_f"], + summary["semantic_similarity"] + ] + + # Ensure values are in range [0,1] for radar chart + system1_values = [min(max(0, v), 1) for v in system1_values] + system2_values = [min(max(0, v), 1) for v in system2_values] + + # Number of variables + N = len(categories) + + # Create angles for each metric + angles = [n / float(N) * 2 * np.pi for n in range(N)] + angles += angles[:1] # Close the loop + + # Add values for the loop closure + system1_values += system1_values[:1] + system2_values += system2_values[:1] + + # Create radar chart + plt.figure(figsize=(14, 12)) + ax = plt.subplot(111, polar=True) + + # Plot System 1 + ax.plot(angles, system1_values, 'o-', linewidth=2, label="System 1", color="#3498db") + ax.fill(angles, system1_values, alpha=0.25, color="#3498db") + + # Plot System 2 + ax.plot(angles, system2_values, 'o-', linewidth=2, label="System 2", color="#e74c3c") + ax.fill(angles, system2_values, alpha=0.25, color="#e74c3c") + + # Set labels and formatting + plt.xticks(angles[:-1], categories, size=14) + + # Improve label positioning to avoid overlap + for label, angle in zip(ax.get_xticklabels(), angles[:-1]): + if angle < np.pi/2 or angle > 3*np.pi/2: + label.set_horizontalalignment('left') + else: + label.set_horizontalalignment('right') + + ax.set_title("Transcript Summarizer Comparison", size=20, pad=20) + + # Add axis labels with actual values + for i, angle in enumerate(angles[:-1]): + if i == 0: # Word count + ax.text(angle, 1.1, f"Max: {int(max_word_count)} words", + ha='center', va='center', size=10) + elif i == 1: # Sentence count + ax.text(angle, 1.1, f"Max: {int(max_sentence_count)} sentences", + ha='center', va='center', size=10) + + # Add legend with metrics + legend = plt.legend(loc="upper right", bbox_to_anchor=(0.1, 0.1)) + + # Add a text box with key statistics + textstr = '\n'.join(( + f"Word Count: System 1 = {int(summary['word_count']['System 1'])}, System 2 = {int(summary['word_count']['System 2'])}", + f"Sentence Count: System 1 = {summary['sentence_count']['System 1']:.1f}, System 2 = {summary['sentence_count']['System 2']:.1f}", + f"Vocabulary Richness: System 1 = {summary['vocabulary_richness']['System 1']:.3f}, System 2 = {summary['vocabulary_richness']['System 2']:.3f}", + f"Semantic Similarity: {summary['semantic_similarity']:.3f}", + f"ROUGE-1: {summary['rouge1_f']:.3f}", + f"ROUGE-2: {summary['rouge2_f']:.3f}", + f"ROUGE-L: {summary['rougeL_f']:.3f}" + )) + + # Create a text box at the bottom + plt.figtext(0.5, 0.01, textstr, ha="center", fontsize=12, + bbox={"facecolor":"white", "alpha":0.8, "pad":5, "boxstyle":"round,pad=0.5"}) + + plt.tight_layout() + + # Save the visualization + try: + os.makedirs(os.path.dirname(output_path), exist_ok=True) + plt.savefig(output_path, dpi=300, bbox_inches='tight') + print(f"Radar chart saved to {output_path}") + except Exception as e: + print(f"Error saving radar chart: {e}") + finally: + plt.close() + + return summary + +def main(): + parser = argparse.ArgumentParser(description="Compare two summarizer systems with a single visualization") + parser.add_argument("--summaries1", required=True, help="Directory containing summaries from system 1") + parser.add_argument("--summaries2", required=True, help="Directory containing summaries from system 2") + parser.add_argument("--output_dir", default="./comparison_results", help="Directory to save comparison results") + + 
args = parser.parse_args() + + # Create output directory + output_dir = Path(args.output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + + # Load summaries + print("Loading summaries...") + system1_summaries = load_summaries(args.summaries1) + system2_summaries = load_summaries(args.summaries2) + + if not system1_summaries: + print(f"Error: No valid summaries found in {args.summaries1}") + return + + if not system2_summaries: + print(f"Error: No valid summaries found in {args.summaries2}") + return + + # Create the comparison chart + print("Creating comparison visualization...") + summary = create_comparison_radar_chart( + system1_summaries, + system2_summaries, + output_dir / "summarizer_comparison_radar.png" + ) + + if summary: + # Save the summary data + try: + with open(output_dir / "comparison_summary.json", "w") as f: + json.dump(summary, f, indent=2) + print(f"Comparison complete! Results saved to {output_dir}") + except Exception as e: + print(f"Error saving summary data: {e}") + else: + print("Comparison failed. Please check the error messages above.") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/modeling/summarize.py b/modeling/summarize.py index 2deac8d..3d86e08 100644 --- a/modeling/summarize.py +++ b/modeling/summarize.py @@ -1,32 +1,74 @@ from mmif import Mmif, View, AnnotationTypes, DocumentTypes -from transformers import pipeline from tqdm import tqdm +import json import torch import math import pandas as pd from sklearn.model_selection import train_test_split -from transformers import TrainingArguments, Trainer import nltk from nltk.corpus import stopwords from nltk.cluster.util import cosine_distance from nltk.tokenize import sent_tokenize import numpy as np import networkx as nx -from summarizer import Summarizer from torch.utils.data import Dataset -from transformers import BartTokenizer, BartForConditionalGeneration +import requests +import json +import sys +import os +import argparse + tqdm.pandas() -# TEST_DOCUMENT = "../mmif_files/whisper3.mmif" -MAX_LEN = 1024 -# Summarizer needs at least 4GB of VRAM -device = torch.device("cuda:0" if torch.cuda.is_available() and torch.cuda.mem_get_info()[1] > 4000000000 - else "cpu") -# device = torch.device("cpu") -print(f"Using {device}") -summarizer = pipeline( - "summarization", model="facebook/bart-large-cnn", device=device) +MAX_LEN = 1024 # Summarizer needs at least 4GB + +device = torch.device("cuda:0" if torch.cuda.is_available() and torch.cuda.mem_get_info()[1] > 4000000000 else "cpu") +print(f"Using device: {device}") + +# Transformer-based Summarizer +def load_transformer(): + from transformers import pipeline + return pipeline("summarization", model="facebook/bart-large-cnn", device=device) + +def generate_abstractive_summary(asr_text: str, summarizer, max_len=150): + min_len = 30 if max_len > 30 else int(max_len/2) + return summarizer(asr_text, max_length=max_len, min_length=min_len, do_sample=False)[0]['summary_text'] + +def summarize_transformer(asr_text: str): + summarizer = load_transformer() + if len(asr_text) > MAX_LEN: + chunks = [asr_text[i:i+MAX_LEN] for i in range(0, len(asr_text), MAX_LEN)] + summaries = [generate_abstractive_summary(chunk, summarizer, max_len=int(math.floor(MAX_LEN/len(chunks)))) for chunk in tqdm(chunks)] + asr_text = " ".join(summaries) + return generate_abstractive_summary(asr_text, summarizer), asr_text + +# LLM Summarizer +def generate_llm_summary(text, max_len=150): + prompt = f"""Summarize the transcript in about {max_len} words. 
+ +{text} + +Summary:""" + + response = requests.post('http://localhost:11434/api/generate', + json={ + 'model': 'gemma3', + 'prompt': prompt, + 'stream': False, + 'options': { + 'temperature': 0.1, + 'top_p': 0.9 + } + }) + + if response.status_code == 200: + result = response.json() + return result['response'].strip() + else: + print(f"Error calling Ollama API: {response.status_code}") + return "Error generating summary." + def url2posix(path): @@ -53,93 +95,128 @@ def get_asr_views(mmif: Mmif): def get_asr_text(asr_view: View): for annotation in asr_view.annotations: if annotation.at_type.shortname == "TextDocument": - return annotation.properties.get("text").value + text = annotation.properties.get("text") + return text if isinstance(text, str) else text.value + -def summarize_from_text(asr_text: View): +def summarize_from_text(asr_text: str): if len(asr_text) > MAX_LEN: - chunks = [asr_text[i:i+MAX_LEN] - for i in range(0, len(asr_text), MAX_LEN)] - summaries = [generate_abstractive_summary(chunk, max_len=int( - math.floor(MAX_LEN/len(chunks)))) for chunk in tqdm(chunks)] - asr_text = " ".join(summaries) - return generate_abstractive_summary(asr_text), asr_text + chunks = [asr_text[i:i+MAX_LEN] for i in range(0, len(asr_text), MAX_LEN)] + chunk_summaries = [] + + for chunk in tqdm(chunks): + chunk_summary = generate_llm_summary(chunk, max_len=int(math.floor(150/len(chunks)))) + chunk_summaries.append(chunk_summary) + + if len(chunk_summaries) > 1: + combined_summaries = " ".join(chunk_summaries) + final_summary = generate_llm_summary(combined_summaries, max_len=150) + return final_summary, combined_summaries + else: + return chunk_summaries[0], asr_text + else: + summary = generate_llm_summary(asr_text, max_len=150) + return summary, asr_text -def generate_abstractive_summary(asr_text: str, max_len=150): - min_len = 30 if max_len > 30 else int(max_len/2) - return summarizer(asr_text, max_length=max_len, min_length=min_len, do_sample=False)[0]['summary_text'] -def summarize_file(mmif: Mmif): + +def summarize_file(mmif: Mmif, method: str): gold_transcript = get_transcript(mmif) if gold_transcript: asr_text = gold_transcript else: asr_views = get_asr_views(mmif) + if not asr_views: + return "No ASR views found in the MMIF file", "", "" asr_text = get_asr_text(asr_views[0]) - summary, long_summary = summarize_from_text(asr_text) - return summary, long_summary, asr_text - - -def fine_tune(): - df = pd.read_csv("../data/descriptions.csv") - print("Performing extractive summarization") - extractive_model = Summarizer() - df["description"] = df["description"].progress_apply(lambda x: extractive_model(x)) - - train, test = train_test_split(df) - print("Tokenizing summaries...") - train_summaries = summarizer.tokenizer(train["description"].tolist()) - test_summaries = summarizer.tokenizer(test["description"].tolist()) - print("Tokenizing transcripts...") - train_transcripts = summarizer.tokenizer(train["transcript"].tolist()) - test_transcripts = summarizer.tokenizer(test["transcript"].tolist()) - - class AAPBDataset(torch.utils.data.Dataset): - def __init__(self, transcripts, summaries): - self.transcripts = transcripts - self.summaries = summaries - - def __getitem__(self, idx): - return {"input_ids": torch.tensor(self.transcripts.input_ids[idx], dtype=torch.long), - "attention_mask": torch.tensor(self.transcripts.attention_mask[idx], dtype=torch.long), - "decoder_input_ids": torch.tensor(self.summaries.input_ids[idx], dtype=torch.long), - "decoder_attention_mask": 
torch.tensor(self.summaries.attention_mask[idx], dtype=torch.long)} - - def __len__(self): - return len(self.transcripts.input_ids) + if not asr_text: + return "No text found in ASR view", "", "" + + if method == "llm": + summary, long_summary = summarize_from_text(asr_text) + return summary, long_summary, asr_text + elif method == "transformer": + summary, long_summary = summarize_transformer(asr_text) + return summary, long_summary, asr_text + else: + raise ValueError("Invalid summarization method") + + +def process_dataset_for_examples(): + """ + Instead of traditional fine-tuning, prepare examples for few-shot learning + """ + try: + df = pd.read_csv("../data/descriptions.csv") + print("Processing dataset to create few-shot examples") + + sample_df = df.sample(n=5) + examples = [] + + for _, row in sample_df.iterrows(): + example = { + "transcript": row["transcript"], + "summary": row["description"] + } + examples.append(example) - train_encodings = AAPBDataset(train_transcripts, train_summaries) - test_encodings = AAPBDataset(test_transcripts, test_summaries) - - training_args = TrainingArguments( - output_dir="../models/aapb_summarizer", - num_train_epochs=3, - per_device_train_batch_size=2, - per_device_eval_batch_size=2, - warmup_steps=500, - weight_decay=0.01, - logging_dir="../logs", - logging_steps=10, - save_steps=10, - eval_steps=10, - evaluation_strategy="steps", - report_to="tensorboard", - run_name="aapb_summarizer" - ) - - trainer = Trainer( - model=summarizer.model, - args=training_args, - train_dataset=train_encodings, - eval_dataset=test_encodings - ) - - print("Training...") - trainer.train() - trainer.save_model("../models/aapb_summarizer") + # Save examples for later use with the LLM + with open("../data/few_shot_examples.json", "w") as f: + json.dump(examples, f) + + print(f"Saved {len(examples)} examples for few-shot learning") + return examples + except Exception as e: + print(f"Error processing dataset: {e}") + return [] + + +def is_mmif(content): + try: + obj = json.loads(content) + return "@type" in obj and "MMIF" in obj["@type"] + except json.JSONDecodeError: + return False if __name__ == "__main__": - fine_tune() + parser = argparse.ArgumentParser(description="Summarize MMIF transcript using --llm or --transformer.") + parser.add_argument("--llm", action="store_true", help="Use LLM (Gemma3) summarizer") + parser.add_argument("--transformer", action="store_true", help="Use Transformer (BART) summarizer") + parser.add_argument("input_file", type=str, help="Path to MMIF JSON file") + args = parser.parse_args() + + if not os.path.exists(args.input_file): + print(f"Error: Input file '{args.input_file}' not found.") + sys.exit(1) + + if args.llm: + method = "llm" + elif args.transformer: + method = "transformer" + else: + print("You must specify either --llm or --transformer") + sys.exit(1) + + with open(args.input_file, "r") as f: + content = f.read() + + print("Generating summary...") + + if is_mmif(content): + mmif = Mmif(content) + summary, long_summary, asr_text = summarize_file(mmif, method) + else: + + if method == "llm": + summary, long_summary = summarize_from_text(content) + elif method == "transformer": + summary, long_summary = summarize_transformer(content) + else: + raise ValueError("Invalid summarization method") + asr_text = content + + print("\nSummary:\n", summary) \ No newline at end of file diff --git a/modeling/topic_model.py b/modeling/topic_model.py index 3d85310..2f65164 100644 --- a/modeling/topic_model.py +++ b/modeling/topic_model.py @@ 
-1,29 +1,28 @@ -from bertopic import BERTopic import pandas as pd from nltk.corpus import stopwords from nltk.tokenize import word_tokenize, sent_tokenize import nltk -from bertopic.representation import KeyBERTInspired -from hdbscan import HDBSCAN import os import json from tqdm import tqdm -from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer -from bertopic.vectorizers import ClassTfidfTransformer -from bertopic.representation import KeyBERTInspired -from bertopic.representation import MaximalMarginalRelevance +import numpy as np +import requests +from collections import Counter +import math +from sklearn.metrics.pairwise import cosine_similarity +import re from scipy.special import softmax import torch -from bertopic import BERTopic -from bertopic.representation import MaximalMarginalRelevance -from .ner import get_entities # If you have less than 4GB of VRAM, your computer will have a bad time running BERTopic device = torch.device("cuda:0" if torch.cuda.is_available() and torch.cuda.mem_get_info()[1] > 4000000000 else "cpu") +print(f"Using {device}") +# Constants +MAX_LEN = 1024 # try: # topic_model = BERTopic.load(os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/base_topic"))) # print("Loaded pretrained topic model.") @@ -38,153 +37,261 @@ def preprocess(text, entities): entities = [entity.lower() for entity in entities] # Split multi-word entities entities = [entity for expression in entities for entity in expression.split()] - return " ".join([word for word in word_tokenize(text) \ - if word.lower() not in entities \ - and word.lower() not in stopwords.words("english")]) - - -def get_topics(docs, entities=[], zeroshot_topics=[]): - print("Getting topics...") - # Flatten list and split into sliding window n-grams. - # This is so the topic model has more information to train on. This works - # in this case, since the long summaries generally contain many different - # topics that vary from sentence to sentence. - flattened_docs = "".join([doc for sublist in docs for doc in sublist]) + tokens = word_tokenize(text) + stop_words = set(stopwords.words("english")) + + # Filter out stop words and entities + filtered_tokens = [word for word in tokens + if word.lower() not in entities + and word.lower() not in stop_words] + + return " ".join(filtered_tokens) + +def generate_llm_topics(text, num_topics=5, max_words_per_topic=5): + """Generate topics using Ollama's LLM""" + if not text or len(text.strip()) == 0: + print("Warning: Empty text provided to topic generator") + return [] + + prompt = f"""Extract {num_topics} distinct topics from the following text. +For each topic, provide a short descriptive name (2-3 words max) and up to {max_words_per_topic} keywords. 
+Format your response as a valid JSON like this: +{{ + "topics": [ + {{"name": "topic1_name", "keywords": ["keyword1", "keyword2", "keyword3"]}}, + {{"name": "topic2_name", "keywords": ["keyword1", "keyword2", "keyword3"]}} + ] +}} + +Here is the text: +{text} +""" + + # Make API call to Ollama + try: + response = requests.post(OLLAMA_API_URL, + json={ + 'model': LLM_MODEL, + 'prompt': prompt, + 'stream': False, + 'options': { + 'temperature': 0.1 + } + }, timeout=30) + + if response.status_code == 200: + result = response.json() + # Try to parse JSON from the response + json_str = result['response'] + + # Extract JSON if it's wrapped in code blocks + json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', json_str) + if json_match: + json_str = json_match.group(1) + + # Find the JSON object within text + json_pattern = r'({[\s\S]*})' + json_match = re.search(json_pattern, json_str) + if json_match: + json_str = json_match.group(1) + + try: + topics_data = json.loads(json_str) + # Handle both direct topics or nested "topics" key + if "topics" in topics_data: + return topics_data["topics"] + else: + # If the model returned a list directly + if isinstance(topics_data, list): + return topics_data + # Handle unexpected structure + print("Warning: Unexpected JSON structure in LLM response") + return [] + except json.JSONDecodeError as e: + print(f"Failed to parse JSON from LLM response: {e}") + print(f"Raw response: {json_str}") + return [] + else: + print(f"Error calling Ollama API: {response.status_code}") + return [] + except requests.RequestException as e: + print(f"Request failed: {e}") + return [] + except Exception as e: + print(f"Error processing LLM response: {e}") + return [] + + +def chunk_text(text, max_len=MAX_LEN): + """Split text into manageable chunks for LLM processing""" + if not text: + return [] + + sentences = sent_tokenize(text) + chunks = [] + current_chunk = [] + current_length = 0 + + for sentence in sentences: + sentence_length = len(sentence) + if current_length + sentence_length <= max_len: + current_chunk.append(sentence) + current_length += sentence_length + else: + if current_chunk: + chunks.append(" ".join(current_chunk)) + current_chunk = [sentence] + current_length = sentence_length + + if current_chunk: + chunks.append(" ".join(current_chunk)) + + return chunks + + +def process_documents_for_topics(docs, num_topics=10, entities=None): + """Process all documents to extract topics""" + entities = entities or [] + print("Processing documents for topic modeling...") + + if not docs or all(not doc for doc in docs): + print("Warning: No document content to process") + return {}, [] + + # Flatten list and preprocess + flattened_docs = " ".join([doc for sublist in docs for doc in sublist if doc]) flattened_docs = preprocess(flattened_docs, entities) - sentences = sent_tokenize(flattened_docs) - three_sent_ngrams = [" ".join(sentences[i:i+3]) for i in range(len(sentences)-2)] - print(f"Training on {len(three_sent_ngrams)} n-grams.") - # We want the zero-shot topics to show up in the graph view, even if very few - # documents are classified under those topics. To avoid having to lower the - # zeroshot_min_similarity threshold too much while still keeping zeroshot - # topics, we add temp "dummy" documents guranteed to be classified under the - # zeroshot topics during training (in practice this seems to work very well). 
- if zeroshot_topics: - three_sent_ngrams += [(topic + " ") * 10 for topic in zeroshot_topics] - - has_zeroshot = len(zeroshot_topics) > 0 - topic_model, _ = train_topic_model(docs=three_sent_ngrams, zeroshot_topics=zeroshot_topics) - probs, _ = topic_model.approximate_distribution(docs, use_embedding_model=has_zeroshot) - - # Normalize for better visualization - print("Normalizing...") - probs = (probs - probs.min(axis=0)) / (probs.max(axis=0) - probs.min(axis=0)) - print("Removing NaNs...") - probs[probs != probs] = 0 - topic_info = topic_model.get_topic_info() - topic_names = {topic: name for topic, name in zip(topic_info["Topic"], topic_info["Name"])} - return topic_names, probs.tolist() - - -def train_topic_model(docs, zeroshot_topics = []): - """ - Train zero-shot topic model. If no zero-shot topics are specified, this is a standard topic model. - """ - # max_df to filter out extremely common words, and min_df to filter out rare but influential words - # like names - # vectorizer_model = TfidfVectorizer(stop_words='english') - vectorizer_model = TfidfVectorizer(max_df=0.99, min_df=0.6, stop_words='english') - ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True) - representation_model = MaximalMarginalRelevance(diversity=0.015) - - hdbscan_model = HDBSCAN(min_cluster_size=10, metric='euclidean', - cluster_selection_method='eom', prediction_data=True, min_samples=5) - - # Enable low_memory if you run out of memory - topic_model = BERTopic( - # vectorizer_model=vectorizer_model, - # ctfidf_model=ctfidf_model, - # hdbscan_model=hdbscan_model, - zeroshot_topic_list=zeroshot_topics if zeroshot_topics else None, - zeroshot_min_similarity=.65 if zeroshot_topics else None, - # low_memory=True - ) - - print("Training topic model...") - topic_model.fit(docs) - - print("Trained topic model.") - - print(topic_model.get_topic_info()["Name"]) - - return topic_model, docs - - -def grid_search_topic_model(zeroshot_topics=[]): - """ - Perform a grid search to optimize the topic model hyperparameters. - Metric is the square of coherence and sparseness. 
- """ - from sklearn.model_selection import ParameterGrid - tqdm.pandas() - data = pd.read_csv(os.path.join(os.path.dirname(__file__), "../data/transcripts.csv")) - data = data.dropna() - print("removing speaker names...") - data["transcript"] = data["transcript"].progress_apply(preprocess) - - # Define the parameter grid - param_grid = { - 'vectorizer_model': [TfidfVectorizer(stop_words="english", max_df=0.98), CountVectorizer(stop_words="english", max_df=0.98), None], - 'ctfidf_model': [ClassTfidfTransformer(reduce_frequent_words=True), None], - 'representation_model': [MaximalMarginalRelevance(diversity=0.015), MaximalMarginalRelevance(diversity=0.02), KeyBERTInspired()], - 'n_gram_range': [(1, 1), (1, 2)], - 'min_topic_size': [10, 15, 20, 25], - } - - # Generate all combinations of parameters - param_combinations = list(ParameterGrid(param_grid)) - - best_metric = -float('inf') - best_topic_model = None - - for params in param_combinations: - print(params) - topic_model = BERTopic(**params) - topic_model.fit(data['transcript']) - coherence = get_coherence(topic_model, data['transcript'][:1000] if len(data) > 1000 else data['transcript']) - sparseness = topic_model.approximate_distribution(data['transcript'])[0] - sparseness = (sparseness == 0).sum() + + # Split into manageable chunks + chunks = chunk_text(flattened_docs) + print(f"Processing {len(chunks)} text chunks") + + # Extract topics from each chunk + all_topics = [] + for chunk in tqdm(chunks): + chunk_topics = generate_llm_topics(chunk, num_topics=min(5, num_topics)) + all_topics.extend(chunk_topics) + + # Merge similar topics + merged_topics = merge_similar_topics(all_topics, threshold=0.7) + + # Select top topics by frequency + top_topics = select_top_topics(merged_topics, num_topics) + + # Create topic dictionary with IDs + topic_names = {i: topic["name"] for i, topic in enumerate(top_topics)} + + # Calculate document-topic distribution + topic_distributions = calculate_topic_distributions(docs, top_topics) + + return topic_names, topic_distributions + + +def merge_similar_topics(topics, threshold=0.7): + """Merge similar topics based on keyword overlap""" + if not topics: + return [] - metric = (coherence*sparseness) - print(metric) - print("----------") - if metric > best_metric: - best_metric = metric - best_topic_model = topic_model - - best_topic_model.save(os.path.join(os.path.dirname(__file__), "../data/best_topic_newshour")) - return best_topic_model, data - - -if __name__ == "__main__": - # from eval.topic import get_coherence - - # # best_model, data = grid_search_topic_model() - # # print(best_model.get_topic_info()["Name"]) - - # data = pd.read_csv(os.path.join(os.path.dirname(__file__), "../data/transcripts.csv")) - # data = data.dropna() - - # topic_model, data = train_topic_model() - # coherence = get_coherence(topic_model, data['transcript'][:1000] if len(data) > 1000 else data['transcript']) - # print(topic_model.get_topic_info()["Name"]) + merged = [] + + for topic in topics: + if not isinstance(topic, dict) or "keywords" not in topic or "name" not in topic: + print(f"Warning: Invalid topic format: {topic}") + continue + + name = topic["name"] + keywords = set(topic["keywords"]) + + # Check if this topic should be merged with an existing one + merged_with_existing = False + for existing in merged: + existing_keywords = set(existing["keywords"]) + # Calculate Jaccard similarity + overlap = len(keywords.intersection(existing_keywords)) + union_size = len(keywords.union(existing_keywords)) + + if union_size > 
0: # Avoid division by zero + similarity = overlap / union_size + + if similarity >= threshold: + # Merge keywords + existing["keywords"] = list(existing_keywords.union(keywords)) + # Keep frequency count + existing["count"] = existing.get("count", 1) + 1 + merged_with_existing = True + break + + if not merged_with_existing: + topic["count"] = 1 + merged.append(topic) + + return merged - # print(f"Topic Model Coherence: {coherence}") - from datasets import load_dataset +def select_top_topics(topics, num_topics): + """Select top topics based on frequency""" + if not topics: + return [] + + # Sort by count + sorted_topics = sorted(topics, key=lambda x: x.get("count", 0), reverse=True) + return sorted_topics[:num_topics] - dataset = load_dataset("CShorten/ML-ArXiv-Papers")["train"] - docs = dataset["abstract"][:5_000] - # We define a number of topics that we know are in the documents - zeroshot_topic_list = ["Clustering", "Topic Modeling", "Large Language Models"] +def calculate_topic_distributions(docs, topics): + """Calculate document-topic distribution using keyword presence""" + if not topics: + return [[] for _ in docs] + + # Extract all keywords + all_keywords = {} + for i, topic in enumerate(topics): + for keyword in topic.get("keywords", []): + if keyword in all_keywords: + all_keywords[keyword].append(i) + else: + all_keywords[keyword] = [i] + + # Calculate distribution for each document + distributions = [] + + for doc_list in docs: + doc_text = " ".join([d for d in doc_list if d]).lower() + topic_counts = [0] * len(topics) + + # Count keyword occurrences + for keyword, topic_ids in all_keywords.items(): + count = doc_text.count(keyword.lower()) + if count > 0: + for topic_id in topic_ids: + topic_counts[topic_id] += count + + # Normalize to create a probability distribution + total = sum(topic_counts) + if total > 0: + distribution = [count/total for count in topic_counts] + else: + distribution = [1.0/len(topics)] * len(topics) # Uniform if no matches + + distributions.append(distribution) + + return distributions + +def calculate_zeroshot_topic_distributions(docs, topics): + """Calculate topic distribution using zero-shot classification with LLM""" + if not topics: + return [[] for _ in docs] + + distributions = [] + + for doc_list in tqdm(docs, desc="Calculating zero-shot distributions"): + doc_text = " ".join([d for d in doc_list if d]) + + if len(doc_text) > MAX_LEN: + doc_text = doc_text[:MAX_LEN] + - topic_model = BERTopic( - embedding_model="thenlper/gte-small", - min_topic_size=15, - zeroshot_topic_list=zeroshot_topic_list, - zeroshot_min_similarity=.85, - representation_model=KeyBERTInspired() - ) - topic_model.fit(docs) - probs, _ = topic_model.approximate_distribution(docs, use_embedding_model=True) \ No newline at end of file + distribution = generate_llm_topic_distribution(doc_text, topics) + distributions.append(distribution) + + return distributions
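+
+# NOTE: calculate_zeroshot_topic_distributions() above calls generate_llm_topic_distribution(),
+# which is not shown in this hunk, and the module also references OLLAMA_API_URL / LLM_MODEL,
+# which are not visible here either. The definitions below are only a sketch of what such a
+# helper could look like, written to be consistent with generate_llm_topics(); the constant
+# values are assumptions about a default local Ollama setup.
+OLLAMA_API_URL = "http://localhost:11434/api/generate"  # assumed default Ollama endpoint
+LLM_MODEL = "gemma3"                                     # assumed model name
+
+def generate_llm_topic_distribution(doc_text, topics):
+    """Sketch: ask the LLM to score each topic for one document, then normalize the scores."""
+    topic_names = [topic.get("name", "") for topic in topics]
+    prompt = (
+        "Rate how strongly each of the following topics is present in the text, on a 0-10 scale. "
+        f"Topics: {', '.join(topic_names)}. "
+        "Respond with a JSON list of numbers, one per topic, in the same order.\n\n"
+        f"Text:\n{doc_text}"
+    )
+    try:
+        response = requests.post(
+            OLLAMA_API_URL,
+            json={"model": LLM_MODEL, "prompt": prompt, "stream": False,
+                  "options": {"temperature": 0.1}},
+            timeout=30,
+        )
+        response.raise_for_status()
+        # Pull the first JSON-style list of numbers out of the model's reply.
+        match = re.search(r"\[[\d\s.,]*\]", response.json()["response"])
+        scores = json.loads(match.group(0)) if match else []
+        total = sum(scores)
+        if len(scores) == len(topics) and total > 0:
+            return [score / total for score in scores]
+    except Exception as e:
+        print(f"Error generating zero-shot distribution: {e}")
+    # Fall back to a uniform distribution, mirroring calculate_topic_distributions()
+    return [1.0 / len(topics)] * len(topics) if topics else []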