diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md
new file mode 100644
index 0000000..210a041
--- /dev/null
+++ b/.github/copilot-instructions.md
@@ -0,0 +1,47 @@
+# Copilot Code Review Instructions
+
+## Project context
+
+This is a Streamlit-based image-embedding explorer that runs on HPC GPU
+clusters (Ohio Supercomputer Center, SLURM). It has an automatic backend
+fallback chain: cuML (GPU) → FAISS (CPU) → scikit-learn (CPU). Optional
+GPU dependencies (cuML, CuPy, PyTorch, FAISS-GPU) may or may not be
+installed — the app detects them at runtime and degrades gracefully.
+
+## Review focus
+
+Prioritise **logic bugs, security issues, and correctness problems** over
+style or lint. We run linters separately. A review comment should tell
+us something a linter cannot.
+
+## Patterns to accept (do NOT flag these)
+
+- **`except (ImportError, Exception): pass` with an inline comment** —
+  These are intentional graceful-degradation paths for optional GPU
+  dependencies. If the comment explains the intent, do not suggest adding
+  logging or replacing the bare pass.
+
+- **Self-referencing extras in `pyproject.toml`** — e.g.
+  `gpu = ["emb-explorer[gpu-cu12]"]`. This is a supported pip feature
+  for aliasing optional-dependency groups. It is not a circular dependency.
+
+- **`faiss-gpu-cu12` inside a `[gpu-cu13]` extra** — There is no
+  `faiss-gpu-cu13` package on PyPI. CUDA forward-compatibility means the
+  cu12 build works on CUDA 13 drivers. If a comment explains this, accept it.
+
+- **Streamlit `st.rerun(scope="app")`** — The `scope` parameter has been
+  available since Streamlit 1.33 (2024). `scope="app"` from inside a
+  `@st.fragment` triggers a full page rerun. This is intentional.
+
+- **PID-based temp files under `/dev/shm`** — Used for subprocess IPC in
+  cuML UMAP isolation. The subprocess is short-lived and files are cleaned
+  up in a `finally` block. This is acceptable for a single-user HPC app.
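The graceful-degradation pattern these instructions tell reviewers to accept can be sketched as follows. This is an illustration only: `pick_kmeans_backend` and its return convention are hypothetical, not emb-explorer's actual API.

```python
# Minimal sketch of the cuML -> FAISS -> scikit-learn fallback chain.
# `pick_kmeans_backend` is a hypothetical name, not the project's real API.

def pick_kmeans_backend():
    """Return (backend_name, implementation), preferring GPU when usable."""
    try:
        from cuml.cluster import KMeans  # optional GPU dependency
        return "cuml", KMeans
    except Exception:
        pass  # cuML missing or CUDA unusable: fall through to CPU

    try:
        import faiss  # optional CPU-accelerated dependency
        return "faiss", faiss.Kmeans
    except ImportError:
        pass  # FAISS not installed: fall through

    try:
        from sklearn.cluster import KMeans  # baseline CPU implementation
        return "sklearn", KMeans
    except ImportError:
        return "none", None  # nothing available; caller must handle


backend, impl = pick_kmeans_backend()
```

The point of the inline comments is exactly what the instructions above ask for: they tell a reviewer (human or Copilot) that the silent `pass` is a deliberate fallback, not a swallowed error.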
+
+## Things worth flagging
+
+- Version-specifier bugs in `pyproject.toml` (e.g. `<=X.Y.0` excluding
+  valid patch releases when the real constraint is `<X.(Y+1)`).
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
[screenshot table hunk: HTML markup not recoverable from this extract. Old panels: "Embedding Clusters", "Precalculated Embedding Filters", "Cluster Summary", "Precalculated Embedding Clusters", "Precalculated Embedding Taxon Tree". New layout: columns "📊 Embed & Explore Images" and "🔍 Explore Pre-calculated Embeddings" with image alts "Embed & Explore", "Precalculated Embedding Exploration", "Taxonomy Tree" and panels "Embedding Interface" (Embed your images using pre-trained models), "Smart Filtering" (Apply filters to pre-calculated embeddings), "Cluster Summary" (Analyze clustering results and representative images), "Interactive Exploration" (Explore clusters with interactive visualization), "Taxonomy Tree Navigation" (Browse hierarchical taxonomy structure).]
 
 ## Features
 
-### Embed & Explore Images from Upload
-
-* **Batch Image Embedding:**
-  Efficiently embed large collections of images using the pretrained model (e.g., CLIP, BioCLIP) on CPU or GPU (preferably), with customizable batch size and parallelism.
-* **Clustering:**
-  Reduces embedding vectors to 2D using PCA, T-SNE, and UMAP. Performs K-Means clustering and display result using a scatter plot. Explore clusters via interactive scatter plots. Click on data points to preview images and details.
-* **Cluster-Based Repartitioning:**
-  Copy/repartition images into cluster-specific folders with a single click. Generates a summary CSV for downstream use.
-* **Clustering Summary:**
-  Displays cluster sizes, variances, and representative images for each cluster, helping you evaluate clustering quality.
-
-### Explore Pre-computed Embeddings
+**Embed & Explore** - Embed images using pretrained models (CLIP, BioCLIP), cluster with K-Means, visualize with PCA/t-SNE/UMAP, and repartition images by cluster.
 
-* **Parquet File Support:**
-  Load precomputed embeddings with associated metadata from parquet files. Compatible with various embedding formats and metadata schemas.
-* **Advanced Filtering:**
-  Filter datasets by taxonomic hierarchy, source datasets, and custom metadata fields. Combine multiple filter criteria for precise data selection.
-* **Clustering:**
-  Reduce embedding vectors to 2D using PCA, UMAP, or t-SNE. Perform K-Means clustering and display result using a scatter plot. Explore clusters via interactive scatter plots. Click on points to preview images and explore metadata details.
-* **Taxonomy Tree Navigation:**
-  Browse hierarchical biological classifications with interactive tree view. Expand and collapse taxonomic nodes to explore at different classification levels.
+**Precalculated Embeddings** - Load parquet files (or directories of parquets) with precomputed embeddings, apply dynamic cascading filters, and explore clusters with taxonomy tree navigation. See [Data Format](docs/DATA_FORMAT.md) for the expected schema and [Backend Pipeline](docs/BACKEND_PIPELINE.md) for how embeddings flow through clustering and visualization.
 
 ## Installation
 
-[uv](https://docs.astral.sh/uv/) is a fast Python package installer and resolver. Install `uv` first if you haven't already:
-
 ```bash
-# Install uv (if not already installed)
-curl -LsSf https://astral.sh/uv/install.sh | sh
-```
-
-Then install the project:
-
-```bash
-# Clone the repository
 git clone https://github.com/Imageomics/emb-explorer.git
 cd emb-explorer
 
-# Create virtual environment and install dependencies
-uv venv
-source .venv/bin/activate  # On Windows: .venv\Scripts\activate
+# Using uv (recommended)
+uv venv && source .venv/bin/activate
 uv pip install -e .
 ```
 
-### GPU Support (Optional)
+### GPU Acceleration (optional)
 
-For GPU acceleration, you'll need CUDA 12.0+ installed on your system.
+A GPU is **not required** — everything works on CPU out of the box. But if you have an NVIDIA GPU with CUDA, clustering and dimensionality reduction (KMeans, t-SNE, UMAP) will be significantly faster via [cuML](https://docs.rapids.ai/api/cuml/stable/).
 
 ```bash
-# Full GPU support with RAPIDS (cuDF + cuML)
-uv pip install -e ".[gpu]"
+# CUDA 12.x
+uv pip install -e ".[gpu-cu12]"
 
-# Minimal GPU support (PyTorch + FAISS only)
-uv pip install -e ".[gpu-minimal]"
+# CUDA 13.x
+uv pip install -e ".[gpu-cu13]"
 ```
 
-### Development
-
-```bash
-# Install with development tools
-uv pip install -e ".[dev]"
-```
+The app auto-detects GPU availability at runtime and falls back to CPU if anything goes wrong — no configuration needed. You can also manually select backends (cuML, FAISS, sklearn) in the sidebar.
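The runtime auto-detection described above might look roughly like the following sketch. The function name mirrors `check_cuda_available`, which this PR imports from `shared.utils.backend`, but the body here is an assumption for illustration, not the project's actual implementation.

```python
def check_cuda_available():
    """Return (cuda_available, device_info) without ever raising.

    Sketch only: the real shared.utils.backend.check_cuda_available
    in this PR may be implemented differently.
    """
    try:
        import torch  # optional dependency; may not be installed
        if torch.cuda.is_available():
            return True, torch.cuda.get_device_name(0)
    except Exception:
        pass  # torch missing or the CUDA runtime is broken: treat as CPU-only
    return False, "CPU only"


available, device = check_cuda_available()
```

Because the probe never raises, callers can branch on the result and silently stay on the CPU path, which is the "no configuration needed" behavior the README promises.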
 ## Usage
 
-### Running the Application
+### Standalone Apps
 
 ```bash
-# Activate virtual environment (if not already activated)
-source .venv/bin/activate  # On Windows: .venv\Scripts\activate
+# Embed & Explore - Interactive image embedding and clustering
+streamlit run apps/embed_explore/app.py
 
-# Run the Streamlit app
-streamlit run app.py
+# Precalculated Embeddings - Explore precomputed embeddings from parquet
+streamlit run apps/precalculated/app.py
 ```
 
-An example dataset (`example_1k.parquet`) is provided in the `data/` folder for testing the pre-calculated embeddings features. This parquet contains metadata and the [BioCLIP 2](https://imageomics.github.io/bioclip-2/) embeddings for a one thousand-image subset of [TreeOfLife-200M](https://huggingface.co/datasets/imageomics/TreeOfLife-200M).
-
-### Command Line Tools
-
-The project also provides command-line utilities:
+### Entry Points (after pip install)
 
 ```bash
-# List all available models
-python list_models.py --format table
-
-# List models in JSON format
-python list_models.py --format json --pretty
-
-# List models as names only
-python list_models.py --format names
-
-# Get help for the list models command
-python list_models.py --help
+emb-embed-explore   # Launch Embed & Explore app
+emb-precalculated   # Launch Precalculated Embeddings app
+list-models         # List available embedding models
 ```
 
-### Running on Remote Compute Nodes
+### Example Data
 
-If running the app on a remote compute node (e.g., HPC cluster), you'll need to set up port forwarding to access the Streamlit interface from your local machine.
+An example dataset (`data/example_1k.parquet`) is provided with BioCLIP 2 embeddings for testing. Please see the [data README](data/README.md) for more information about this sample set.
 
-1. **Start the app on the compute node:**
-   ```bash
-   # On the remote compute node
-   streamlit run app.py
-   ```
-   Note the port number (default is 8501) and the compute node hostname.
+### Remote HPC Usage
 
-2. **Set up SSH port forwarding from your local machine:**
-   ```bash
-   # From your local machine
-   ssh -N -L 8501:<compute-node-hostname>:8501 <username>@<login-node>
-   ```
-
-   **Example:**
-   ```bash
-   ssh -N -L 8501:c0828.ten.osc.edu:8501 username@cardinal.osc.edu
-   ```
-
-   Replace:
-   - `<compute-node-hostname>` with the actual compute node hostname (e.g., `c0828.ten.osc.edu`)
-   - `<username>` with your username
-   - `<login-node>` with the login node address (e.g., `cardinal.osc.edu`)
-
-3. **Access the app:**
-   Open your web browser and navigate to `http://localhost:8501`
-
-The `-N` flag prevents SSH from executing remote commands, and `-L` sets up the local port forwarding.
+```bash
+# On compute node
+streamlit run apps/precalculated/app.py --server.port 8501
 
-### Notes on Implementation
+# On local machine (port forwarding)
+ssh -N -L 8501:<compute-node-hostname>:8501 <username>@<login-node>
 
-More notes on different implementation methods and approaches are available in the [implementation summary doc](docs/implementation_summary.md).
+# Access at http://localhost:8501
+```
 
 ## Acknowledgements
 
-* [OpenCLIP](https://github.com/mlfoundations/open_clip)
-* [Streamlit](https://streamlit.io/)
-* [Altair](https://altair-viz.github.io/)
-
----
+[OpenCLIP](https://github.com/mlfoundations/open_clip) | [Streamlit](https://streamlit.io/) | [Altair](https://altair-viz.github.io/)
diff --git a/app.py b/app.py
deleted file mode 100644
index 65bd728..0000000
--- a/app.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import streamlit as st
-
-def main():
-    """Main application entry point."""
-    st.set_page_config(
-        layout="wide",
-        page_title="emb-explorer",
-        page_icon="🔍"
-    )
-
-    # Welcome page content
-    st.title("🔍 emb-explorer")
-    st.markdown("**Visual exploration and clustering tool for image datasets and pre-calculated image embeddings**")
-
-    st.markdown("---")
-
-    # Two-column layout to match README structure
-    col1, col2 = st.columns(2)
-
-    with col1:
-        st.markdown("### 📊 Embed & Explore Images")
-        st.markdown("**Upload and process your own image datasets**")
-
-        st.markdown("""
-        **🔋 Key Features:**
-        - **Batch Image Embedding**: Process large image collections using pre-trained models (CLIP, BioCLIP, OpenCLIP)
-        - **Multi-Model Support**: Choose from various vision-language models optimized for different domains
-        - **K-Means Analysis**: Clustering with customizable KMeans parameters
-        - **Interactive Clustering**: Explore data with PCA, t-SNE, and UMAP dimensionality reduction
-        - **Cluster Repartitioning**: Organize images into cluster-specific folders with one click
-        - **Summary Statistics**: Analyze cluster quality with size, variance, and representative samples
-        """)
-
-    with col2:
-        st.markdown("### 📊 Explore Pre-calculated Embeddings")
-        st.markdown("**Work with existing embeddings and rich metadata**")
-
-        st.markdown("""
-        **🔍 Key Features:**
-        - **Parquet File Support**: Load precomputed embeddings with associated metadata
-        - **Advanced Filtering**: Filter by custom metadata
-        - **K-Means Analysis**: Clustering with customizable KMeans parameters
-        - **Interactive Clustering**: Explore data with PCA and UMAP dimensionality reduction
-        - **Taxonomy Tree Navigation**: Browse hierarchical taxonomy classifications with interactive tree view
-        """)
-
-    st.markdown("---")
-
-    # Getting started section
-    st.markdown("## 🚀 Getting Started")
-
-    col1, col2 = st.columns(2)
-
-    with col1:
-        st.markdown("""
-        **🎯 Choose Your Workflow:**
-        - **For New Images** → Use **Clustering** page
-          - Upload your image folder
-          - Select embedding model
-          - Generate embeddings and explore clusters
-        - **For Existing Data** → Use **Precalculated Embeddings** page
-          - Load your parquet file
-          - Apply filters and explore patterns
-          - Perform targeted clustering analysis
-        """)
-
-    with col2:
-        st.markdown("""
-        **⚡ Technical Capabilities:**
-
-        - **Models**: CLIP, BioCLIP-2, OpenCLIP variants
-        - **Acceleration**: CPU and GPU (CUDA) support
-        - **Formats**: Images (PNG, JPG, etc.), Parquet files
-        - **Clustering**: K-Means with multiple initialization methods
-        - **Visualization**: Interactive scatter plots with image preview
-        - **Export**: CSV summaries, folder organization, filtered datasets
-        """)
-
-    st.markdown("---")
-
-    # Navigation help
-    st.markdown("### 📋 Navigation")
-    st.markdown("""
-    Use the **sidebar navigation** to select your workflow:
-    - **🔍 Clustering**: Process and explore new image datasets
-    - **📊 Precalculated Embeddings**: Analyze existing embeddings with metadata filtering
-
-    Each page provides step-by-step guidance and real-time feedback for your analysis workflow.
-    """)
-
-    # Quick tips
-    with st.expander("💡 Pro Tips"):
-        st.markdown("""
-        - **GPU Acceleration**: Install with `uv pip install -e ".[gpu]"` for faster processing
-        - **Large Datasets**: Use batch processing and monitor memory usage in the sidebar
-        - **Custom Filtering**: Combine multiple filter criteria for precise data selection
-        - **Export Results**: Save cluster summaries and repartitioned images for downstream analysis
-        """)
-
-if __name__ == "__main__":
-    main()
diff --git a/apps/__init__.py b/apps/__init__.py
new file mode 100644
index 0000000..cb0762f
--- /dev/null
+++ b/apps/__init__.py
@@ -0,0 +1,7 @@
+"""
+emb-explorer standalone applications.
+
+Available apps:
+- embed_explore: Interactive image embedding explorer with clustering
+- precalculated: Precalculated embeddings explorer with dynamic filters
+"""
diff --git a/apps/embed_explore/__init__.py b/apps/embed_explore/__init__.py
new file mode 100644
index 0000000..b43b9e9
--- /dev/null
+++ b/apps/embed_explore/__init__.py
@@ -0,0 +1,3 @@
+"""
+BYO Images Embed & Explore application.
+"""
diff --git a/apps/embed_explore/app.py b/apps/embed_explore/app.py
new file mode 100644
index 0000000..941af45
--- /dev/null
+++ b/apps/embed_explore/app.py
@@ -0,0 +1,58 @@
+"""
+BYO Images Embed & Explore application.
+
+This application allows users to bring their own images, generate embeddings,
+cluster them, and explore the results visually.
+""" + +import streamlit as st + +from apps.embed_explore.components.sidebar import render_clustering_sidebar +from apps.embed_explore.components.image_preview import render_image_preview +from shared.components.summary import render_clustering_summary +from shared.components.visualization import render_scatter_plot + + +def main(): + """CLI entry point — launches the Streamlit server.""" + import sys + import os + from streamlit.web import cli as stcli + + sys.argv = ["streamlit", "run", os.path.abspath(__file__), "--server.headless", "true"] + stcli.main() + + +def app(): + """Streamlit application layout.""" + st.set_page_config( + layout="wide", + page_title="Embed & Explore", + page_icon="🔍" + ) + + st.title("🔍 Embed & Explore") + st.markdown("Generate embeddings from your images, cluster them, and explore the results.") + + # Create the main layout + col_settings, col_plot, col_preview = st.columns([2, 6, 3]) + + with col_settings: + # Render the sidebar with all controls + render_clustering_sidebar() + + with col_plot: + # Render the main scatter plot + render_scatter_plot() + + with col_preview: + # Render the image preview + render_image_preview() + + # Bottom section: Clustering summary + st.markdown("---") + render_clustering_summary() + + +if __name__ == "__main__": + app() diff --git a/apps/embed_explore/components/__init__.py b/apps/embed_explore/components/__init__.py new file mode 100644 index 0000000..30c5e79 --- /dev/null +++ b/apps/embed_explore/components/__init__.py @@ -0,0 +1,14 @@ +""" +UI components for the embed_explore application. 
+""" + +from apps.embed_explore.components.sidebar import render_clustering_sidebar +from apps.embed_explore.components.visualization import render_scatter_plot, render_image_preview +from apps.embed_explore.components.summary import render_clustering_summary + +__all__ = [ + "render_clustering_sidebar", + "render_scatter_plot", + "render_image_preview", + "render_clustering_summary" +] diff --git a/apps/embed_explore/components/image_preview.py b/apps/embed_explore/components/image_preview.py new file mode 100644 index 0000000..368483b --- /dev/null +++ b/apps/embed_explore/components/image_preview.py @@ -0,0 +1,42 @@ +""" +Image preview component for the embed_explore application. +""" + +import streamlit as st +import os + +from shared.utils.logging_config import get_logger + +logger = get_logger(__name__) + +# Track last displayed image to avoid duplicate logging +_last_displayed_path = None + + +def render_image_preview(): + """Render the image preview panel for local image files.""" + global _last_displayed_path + + valid_paths = st.session_state.get("valid_paths", None) + labels = st.session_state.get("labels", None) + selected_idx = st.session_state.get("selected_image_idx", 0) + + if ( + valid_paths is not None and + labels is not None and + selected_idx is not None and + 0 <= selected_idx < len(valid_paths) + ): + img_path = valid_paths[selected_idx] + cluster = labels[selected_idx] if labels is not None else "?" 
+
+        # Log only when image changes
+        if _last_displayed_path != img_path:
+            logger.info(f"[Image] Loading local file: {os.path.basename(img_path)} (cluster={cluster})")
+            _last_displayed_path = img_path
+
+        st.image(img_path, caption=f"Cluster {cluster}: {os.path.basename(img_path)}", width='stretch')
+        st.markdown(f"**File:** `{os.path.basename(img_path)}`")
+        st.markdown(f"**Cluster:** `{cluster}`")
+    else:
+        st.info("Image preview will appear here after you select a cluster point.")
diff --git a/components/clustering/sidebar.py b/apps/embed_explore/components/sidebar.py
similarity index 66%
rename from components/clustering/sidebar.py
rename to apps/embed_explore/components/sidebar.py
index c581301..129b2a2 100644
--- a/components/clustering/sidebar.py
+++ b/apps/embed_explore/components/sidebar.py
@@ -1,51 +1,55 @@
 """
-Sidebar components for the clustering page.
+Sidebar components for the embed_explore application.
 """
 
 import streamlit as st
 import os
 from typing import Tuple, List, Optional
 
-from services.embedding_service import EmbeddingService
-from services.clustering_service import ClusteringService
-from services.file_service import FileService
-from lib.progress import StreamlitProgressContext
-from components.shared.clustering_controls import render_clustering_backend_controls, render_basic_clustering_controls
+from shared.services.embedding_service import EmbeddingService
+from shared.services.clustering_service import ClusteringService
+from shared.services.file_service import FileService
+from shared.lib.progress import StreamlitProgressContext
+from shared.components.clustering_controls import render_clustering_backend_controls, render_basic_clustering_controls
+from shared.utils.backend import check_cuda_available, resolve_backend, is_oom_error
+from shared.utils.logging_config import get_logger
+
+logger = get_logger(__name__)
 
 
 def render_embedding_section() -> Tuple[bool, Optional[str], Optional[str], int, int]:
     """
     Render the embedding section of the sidebar.
-    
+
     Returns:
         Tuple of (embed_button_clicked, image_dir, model_name, n_workers, batch_size)
     """
     with st.expander("Embed", expanded=True):
         image_dir = st.text_input("Image folder path")
-        
+
        # Get available models dynamically
        available_models = EmbeddingService.get_model_options()
        model_name = st.selectbox("Model", available_models)
-        
+
        col1, col2 = st.columns(2)
        with col1:
            n_workers = st.number_input(
-                "N workers", 
-                min_value=1, 
-                max_value=64, 
-                value=16, 
+                "N workers",
+                min_value=1,
+                max_value=64,
+                value=16,
                step=1
            )
        with col2:
            batch_size = st.number_input(
-                "Batch size", 
-                min_value=1, 
-                max_value=2048, 
-                value=32, 
+                "Batch size",
+                min_value=1,
+                max_value=2048,
+                value=32,
                step=1
            )
 
        embed_button = st.button("Run Embedding")
-    
+
    # Handle embedding execution
    if embed_button and image_dir and os.path.isdir(image_dir):
        with StreamlitProgressContext(st.empty(), "Embedding complete!") as progress:
@@ -54,15 +58,17 @@ def render_embedding_section() -> Tuple[bool, Optional[str], Optional[str], int,
                image_dir, model_name, batch_size, n_workers, progress_callback=progress
            )
-            
+
            if embeddings.shape[0] == 0:
                st.error("No valid image embeddings found.")
+                logger.warning("Embedding generation returned 0 embeddings")
                st.session_state.embeddings = None
                st.session_state.valid_paths = None
                st.session_state.labels = None
                st.session_state.data = None
                st.session_state.selected_image_idx = None
            else:
+                logger.info(f"Embeddings stored: shape={embeddings.shape}, dtype={embeddings.dtype}")
                st.success(f"Generated {embeddings.shape[0]} image embeddings.")
                st.session_state.embeddings = embeddings
                st.session_state.valid_paths = valid_paths
@@ -72,62 +78,117 @@ def render_embedding_section() -> Tuple[bool, Optional[str], Optional[str], int,
                st.session_state.labels = None
                st.session_state.data = None
                st.session_state.selected_image_idx = 0
-            
+
        except Exception as e:
            st.error(f"Error during embedding: {e}")
-    
+            logger.exception("Embedding generation failed")
+
    elif embed_button:
        st.error("Please provide a valid image directory path.")
-    
+
    return embed_button, image_dir, model_name, n_workers, batch_size


def render_clustering_section(n_workers: int = 1) -> Tuple[bool, int, str]:
    """
    Render the clustering section of the sidebar.
-    
+
    Args:
        n_workers: Number of workers for parallel processing
-    
+
    Returns:
        Tuple of (cluster_button_clicked, n_clusters, reduction_method)
    """
    with st.expander("Cluster", expanded=False):
        # Basic clustering controls
        n_clusters, reduction_method = render_basic_clustering_controls()
-        
+
        # Backend and advanced controls
        dim_reduction_backend, clustering_backend, n_workers_clustering, seed = render_clustering_backend_controls()
-        
+
        cluster_button = st.button("Run Clustering", type="primary")
-    
+
    # Handle clustering execution
    if cluster_button:
        embeddings = st.session_state.get("embeddings", None)
        valid_paths = st.session_state.get("valid_paths", None)
-        
+
        if embeddings is not None and valid_paths is not None and len(valid_paths) > 1:
-            try:
-                with st.spinner("Running clustering..."):
-                    df_plot, labels = ClusteringService.run_clustering(
-                        embeddings, valid_paths, n_clusters, reduction_method, n_workers_clustering,
-                        dim_reduction_backend, clustering_backend, seed
-                    )
-
-                # Store everything in session state for reruns
-                st.session_state.data = df_plot
-                st.session_state.labels = labels
-                st.session_state.selected_image_idx = 0  # Reset selection
-                st.success(f"Clustering complete! Found {n_clusters} clusters.")
-                
-            except Exception as e:
-                st.error(f"Error during clustering: {e}")
+            run_clustering_with_fallback(
+                embeddings, valid_paths, n_clusters, reduction_method,
+                n_workers_clustering, dim_reduction_backend, clustering_backend, seed
+            )
        else:
            st.error("Please run embedding first.")
-    
+
    return cluster_button, n_clusters, reduction_method


+def run_clustering_with_fallback(
+    embeddings,
+    valid_paths,
+    n_clusters: int,
+    reduction_method: str,
+    n_workers: int,
+    dim_reduction_backend: str,
+    clustering_backend: str,
+    seed: Optional[int]
+):
+    """
+    Run clustering with robust error handling and automatic fallbacks.
+
+    Uses ClusteringService.run_clustering_safe() which transparently
+    handles GPU errors by falling back to CPU-based sklearn backends.
+    """
+    cuda_available, device_info = check_cuda_available()
+    actual_dim_backend = resolve_backend(dim_reduction_backend, "reduction")
+    actual_cluster_backend = resolve_backend(clustering_backend, "clustering")
+
+    logger.info(f"Starting clustering: samples={len(embeddings)}, clusters={n_clusters}, "
+                f"reduction={reduction_method}, device={device_info}")
+    logger.info(f"Backends: dim_reduction={actual_dim_backend}, clustering={actual_cluster_backend}")
+
+    try:
+        with st.spinner(f"Running {reduction_method} + KMeans ({actual_dim_backend}/{actual_cluster_backend})..."):
+            df_plot, labels = ClusteringService.run_clustering_safe(
+                embeddings, valid_paths, n_clusters, reduction_method,
+                n_workers, actual_dim_backend, actual_cluster_backend, seed
+            )
+
+        # Store results
+        st.session_state.data = df_plot
+        st.session_state.labels = labels
+        st.session_state.selected_image_idx = 0
+
+        # Compute and store clustering summary
+        logger.info("Computing clustering summary statistics...")
+        summary_df, representatives = ClusteringService.generate_clustering_summary(
+            embeddings, labels, df_plot
+        )
+        st.session_state.clustering_summary = summary_df
+        st.session_state.clustering_representatives = representatives
+        logger.info(f"Clustering summary computed: {len(summary_df)} clusters")
+
+        st.success(f"Clustering complete! Found {n_clusters} clusters.")
+
+    except (RuntimeError, OSError) as e:
+        if is_oom_error(e):
+            st.error("**GPU Out of Memory** - Dataset too large for GPU")
+            st.info("Try: Reduce dataset size, or select 'sklearn' backend")
+            logger.exception("GPU OOM error during clustering")
+        else:
+            st.error(f"Error during clustering: {e}")
+            logger.exception("Clustering error")
+
+    except MemoryError:
+        st.error("**System Out of Memory** - Reduce dataset size")
+        logger.exception("System memory exhausted during clustering")
+
+    except Exception as e:
+        st.error(f"Error during clustering: {e}")
+        logger.exception("Unexpected clustering error")
+
+
 def render_save_section():
    """Render the save operations section of the sidebar."""
    # --- Save images from a specific cluster utility ---
@@ -135,7 +196,7 @@ def render_save_section():
    with st.expander("Save Images from Specific Cluster", expanded=True):
        df_plot = st.session_state.get("data", None)
        labels = st.session_state.get("labels", None)
-        
+
        if df_plot is not None and labels is not None:
            available_clusters = sorted(df_plot['cluster'].unique(), key=lambda x: int(x))
            selected_clusters = st.multiselect(
@@ -150,14 +211,14 @@ def render_save_section():
                key="save_cluster_dir"
            )
            save_cluster_button = st.button("Save images", key="save_cluster_btn")
-            
+
            # Handle save execution
            if save_cluster_button and selected_clusters:
                cluster_rows = df_plot[df_plot['cluster'].isin(selected_clusters)]
                max_workers = st.session_state.get("num_threads", 8)
-                
+
                with StreamlitProgressContext(
-                    save_status_placeholder, 
+                    save_status_placeholder,
                    f"Images from cluster(s) {', '.join(map(str, selected_clusters))} saved successfully!"
                ) as progress:
                    try:
@@ -165,39 +226,39 @@ def render_save_section():
                            cluster_rows, save_dir, max_workers, progress_callback=progress
                        )
                        st.info(f"Summary CSV saved at {csv_path}")
-                        
+
                    except Exception as e:
                        save_status_placeholder.error(f"Error saving images: {e}")
-            
+
            elif save_cluster_button:
                save_status_placeholder.warning("Please select at least one cluster.")
-        
+
        else:
            st.info("Run clustering first to enable this utility.")
-    
+
    # --- Repartition expander and status ---
    repartition_status_placeholder = st.empty()
    with st.expander("Repartition Images by Cluster", expanded=False):
        st.markdown("**Target directory for repartitioned images (will be created):**")
        repartition_dir = st.text_input(
-            "Directory", 
+            "Directory",
            value="repartitioned_output",
            key="repartition_dir"
        )
        max_workers = st.number_input(
-            "Number of threads (higher = faster, try 8–32)", 
-            min_value=1, 
-            max_value=64, 
+            "Number of threads (higher = faster, try 8-32)",
+            min_value=1,
+            max_value=64,
            value=8,
            step=1,
            key="num_threads"
        )
        repartition_button = st.button("Repartition images by cluster", key="repartition_btn")
-        
+
        # Handle repartition execution
        if repartition_button:
            df_plot = st.session_state.get("data", None)
-            
+
            if df_plot is None or len(df_plot) < 1:
                repartition_status_placeholder.warning("Please run clustering first before repartitioning images.")
            else:
@@ -210,7 +271,7 @@ def render_save_section():
                    df_plot, repartition_dir, max_workers, progress_callback=progress
                )
                st.info(f"Summary CSV saved at {csv_path}")
-                
+
            except Exception as e:
                repartition_status_placeholder.error(f"Error repartitioning images: {e}")
 
@@ -218,14 +279,14 @@
 def render_clustering_sidebar():
    """Render the complete clustering sidebar with all sections."""
    tab_compute, tab_save = st.tabs(["Compute", "Save"])
-    
+
    with tab_compute:
        embed_button, image_dir, model_name, n_workers, batch_size = render_embedding_section()
        cluster_button, n_clusters, reduction_method = render_clustering_section(n_workers)
-    
+
    with tab_save:
        render_save_section()
-    
+
    return {
        'embed_button': embed_button,
        'image_dir': image_dir,
diff --git a/apps/embed_explore/components/summary.py b/apps/embed_explore/components/summary.py
new file mode 100644
index 0000000..86053cb
--- /dev/null
+++ b/apps/embed_explore/components/summary.py
@@ -0,0 +1,10 @@
+"""
+Clustering summary components for the embed_explore application.
+
+This module re-exports from shared for backwards compatibility.
+"""
+
+# Re-export from shared module
+from shared.components.summary import render_clustering_summary
+
+__all__ = ['render_clustering_summary']
diff --git a/apps/embed_explore/components/visualization.py b/apps/embed_explore/components/visualization.py
new file mode 100644
index 0000000..50c675c
--- /dev/null
+++ b/apps/embed_explore/components/visualization.py
@@ -0,0 +1,13 @@
+"""
+Visualization components for the embed_explore application.
+
+This module re-exports from shared for backwards compatibility.
+"""
+
+# Re-export scatter plot from shared module
+from shared.components.visualization import render_scatter_plot
+
+# Re-export image preview from local module
+from apps.embed_explore.components.image_preview import render_image_preview
+
+__all__ = ['render_scatter_plot', 'render_image_preview']
diff --git a/apps/precalculated/__init__.py b/apps/precalculated/__init__.py
new file mode 100644
index 0000000..9507dba
--- /dev/null
+++ b/apps/precalculated/__init__.py
@@ -0,0 +1,3 @@
+"""Precalculated embeddings explorer standalone application."""
+
+__version__ = "0.1.0"
diff --git a/apps/precalculated/app.py b/apps/precalculated/app.py
new file mode 100644
index 0000000..354efd3
--- /dev/null
+++ b/apps/precalculated/app.py
@@ -0,0 +1,78 @@
+"""
+Precalculated Embeddings Explorer - Standalone Application
+
+A Streamlit application for exploring precomputed embeddings stored in parquet files.
+Features dynamic filter generation based on available columns.
+""" + +import streamlit as st + +from apps.precalculated.components.sidebar import ( + render_file_section, + render_dynamic_filters, + render_clustering_section, +) +from apps.precalculated.components.data_preview import render_data_preview +from shared.components.visualization import render_scatter_plot +from shared.components.summary import render_clustering_summary + + +def main(): + """CLI entry point — launches the Streamlit server.""" + import sys + import os + from streamlit.web import cli as stcli + + sys.argv = ["streamlit", "run", os.path.abspath(__file__), "--server.headless", "true"] + stcli.main() + + +def app(): + """Streamlit application layout.""" + st.set_page_config( + layout="wide", + page_title="Precalculated Embeddings Explorer", + page_icon="📊" + ) + + # Initialize session state + if "page_type" not in st.session_state or st.session_state.page_type != "precalculated_app": + # Clear any stale state from other apps + keys_to_clear = ["embeddings", "valid_paths", "last_image_dir", "embedding_complete"] + for key in keys_to_clear: + if key in st.session_state: + del st.session_state[key] + st.session_state.page_type = "precalculated_app" + + # Header + st.title("📊 Precalculated Embeddings Explorer") + st.markdown( + "Load parquet files with embeddings, apply dynamic filters, and cluster for visualization. " + "Filters are automatically generated based on your data columns." 
+    )
+
+    # Row 1: File loading
+    render_file_section()
+
+    # Row 2: Dynamic filters
+    render_dynamic_filters()
+
+    # Row 3: Main content
+    col_settings, col_plot, col_preview = st.columns([2, 7, 3])
+
+    with col_settings:
+        render_clustering_section()
+
+    with col_plot:
+        render_scatter_plot()
+
+    with col_preview:
+        render_data_preview()
+
+    # Bottom: Clustering summary
+    st.markdown("---")
+    render_clustering_summary(show_taxonomy=True)
+
+
+if __name__ == "__main__":
+    app()
diff --git a/apps/precalculated/components/__init__.py b/apps/precalculated/components/__init__.py
new file mode 100644
index 0000000..b49685e
--- /dev/null
+++ b/apps/precalculated/components/__init__.py
@@ -0,0 +1,17 @@
+"""Components for the precalculated embeddings application."""
+
+from apps.precalculated.components.sidebar import (
+    render_file_section,
+    render_dynamic_filters,
+    render_clustering_section,
+)
+from apps.precalculated.components.data_preview import render_data_preview
+from apps.precalculated.components.visualization import render_scatter_plot
+
+__all__ = [
+    "render_file_section",
+    "render_dynamic_filters",
+    "render_clustering_section",
+    "render_data_preview",
+    "render_scatter_plot",
+]
diff --git a/apps/precalculated/components/data_preview.py b/apps/precalculated/components/data_preview.py
new file mode 100644
index 0000000..06a0d36
--- /dev/null
+++ b/apps/precalculated/components/data_preview.py
@@ -0,0 +1,188 @@
+"""
+Data preview components for the precalculated embeddings application.
+Dynamically displays all available metadata fields.
+"""
+
+import streamlit as st
+import pandas as pd
+import requests
+import time
+from typing import Optional
+from PIL import Image
+from io import BytesIO
+
+from shared.utils.logging_config import get_logger
+
+logger = get_logger(__name__)
+
+
+@st.cache_data(ttl=300, show_spinner=False)
+def _fetch_image_from_url_cached(url: str, timeout: int = 5) -> Optional[bytes]:
+    """Internal cached function to fetch image bytes."""
+    if not url or not isinstance(url, str):
+        return None
+
+    try:
+        if not url.startswith(('http://', 'https://')):
+            return None
+
+        response = requests.get(url, timeout=timeout, stream=True)
+        response.raise_for_status()
+
+        content_type = response.headers.get('content-type', '').lower()
+        if not content_type.startswith('image/'):
+            return None
+
+        return response.content
+
+    except Exception:
+        return None
+
+
+def fetch_image_from_url(url: str, timeout: int = 5) -> Optional[bytes]:
+    """
+    Fetch an image from a URL with logging.
+    Uses caching internally but logs the request.
+ """ + if not url or not isinstance(url, str): + return None + + if not url.startswith(('http://', 'https://')): + logger.warning(f"[Image] Invalid URL scheme: {url[:50]}...") + return None + + logger.info(f"[Image] Fetching: {url[:80]}...") + start_time = time.time() + + result = _fetch_image_from_url_cached(url, timeout) + + elapsed = time.time() - start_time + if result: + logger.info(f"[Image] Loaded: {len(result)/1024:.1f}KB in {elapsed:.3f}s") + else: + logger.warning(f"[Image] Failed to load: {url[:50]}...") + + return result + + +def get_image_from_url(url: str) -> Optional[Image.Image]: + """Get image from URL with caching and logging.""" + image_bytes = fetch_image_from_url(url) + if image_bytes: + try: + image = Image.open(BytesIO(image_bytes)) + logger.info(f"[Image] Opened: {image.size[0]}x{image.size[1]} {image.mode}") + return image + except Exception as e: + logger.error(f"[Image] Failed to open: {e}") + return None + return None + + +def render_data_preview(): + """Render the data preview panel with dynamic field display.""" + df_plot = st.session_state.get("data", None) + labels = st.session_state.get("labels", None) + selected_idx = st.session_state.get("selected_image_idx", None) # Default to None, not 0 + filtered_df = st.session_state.get("filtered_df_for_clustering", None) + + # Validate that selection matches current data version + current_data_version = st.session_state.get("data_version", None) + selection_data_version = st.session_state.get("selection_data_version", None) + selection_valid = ( + selected_idx is not None and + current_data_version is not None and + selection_data_version == current_data_version + ) + + if ( + df_plot is not None and + labels is not None and + selection_valid and + 0 <= selected_idx < len(df_plot) and + filtered_df is not None + ): + # Get the selected record + selected_uuid = df_plot.iloc[selected_idx]['uuid'] + cluster = labels[selected_idx] if labels is not None else "?" 
+ + # Use cluster_name if available + if 'cluster_name' in df_plot.columns: + cluster_display = df_plot.iloc[selected_idx]['cluster_name'] + else: + cluster_display = cluster + + # Find the full record + record = filtered_df[filtered_df['uuid'] == selected_uuid].iloc[0] + + st.markdown("### 📋 Record Details") + + # Try to display image if identifier/url column exists (cached to prevent re-fetch) + image_cols = ['identifier', 'image_url', 'url', 'img_url', 'image'] + for img_col in image_cols: + if img_col in record.index and pd.notna(record[img_col]): + url = record[img_col] + image = get_image_from_url(url) + if image is not None: + st.image(image, width=280) + break + + # Display Cluster and UUID prominently (not in table) + st.markdown(f"**Cluster:** `{cluster_display}`") + st.markdown(f"**UUID:** `{selected_uuid}`") + + # Build metadata table for remaining fields + skip_fields = {'emb', 'embedding', 'embeddings', 'vector', 'idx', 'uuid', 'cluster', 'cluster_name'} + + metadata_rows = [] + for field, value in record.items(): + if field.lower() in skip_fields or field in skip_fields: + continue + if pd.isna(value): + continue + + # Format value + if isinstance(value, float): + display_val = f"{value:.4f}" + elif isinstance(value, (list, tuple)): + display_val = f"[{len(value)} items]" + else: + display_val = str(value) + + metadata_rows.append({"Field": field, "Value": display_val}) + + # Display remaining metadata as table + if metadata_rows: + st.markdown("---") + st.markdown("**📊 Metadata**") + metadata_df = pd.DataFrame(metadata_rows) + st.dataframe( + metadata_df, + hide_index=True, + width="stretch", + column_config={ + "Field": st.column_config.TextColumn("Field", width="small"), + "Value": st.column_config.TextColumn("Value", width="large"), + } + ) + + else: + # Show appropriate message based on state + if df_plot is not None and labels is not None: + st.info("📋 Click a point in the scatter plot to view its details.") + else: + st.info("📋 Run clustering 
first, then click a point to view details.") + + # Show dataset summary + filtered_df = st.session_state.get("filtered_df", None) + if filtered_df is not None and len(filtered_df) > 0: + st.markdown("### 📈 Dataset Summary") + st.markdown(f"**Records:** {len(filtered_df):,}") + + # Show column stats + column_info = st.session_state.get("column_info", {}) + if column_info: + with st.expander("Column overview"): + for col, info in list(column_info.items())[:10]: + unique = len(info['unique_values']) if info['unique_values'] else "many" + st.caption(f"• **{col}** ({info['type']}): {unique} unique") diff --git a/apps/precalculated/components/sidebar.py b/apps/precalculated/components/sidebar.py new file mode 100644 index 0000000..7c7780c --- /dev/null +++ b/apps/precalculated/components/sidebar.py @@ -0,0 +1,747 @@ +""" +Sidebar components for the precalculated embeddings application. +Features dynamic cascading filter generation based on parquet columns. +""" + +import streamlit as st +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +import pyarrow.compute as pc +import numpy as np +import os +import time +import hashlib +from typing import Dict, Any, Optional, Tuple, List + +from shared.services.clustering_service import ClusteringService +from shared.components.clustering_controls import render_clustering_backend_controls +from shared.utils.backend import check_cuda_available, resolve_backend, is_oom_error +from shared.utils.logging_config import get_logger + +logger = get_logger(__name__) + + +# Technical columns that should never be shown as filters +EXCLUDED_COLUMNS = {'uuid', 'emb', 'embedding', 'embeddings', 'vector'} + + +def get_column_info_dynamic(table: pa.Table) -> Dict[str, Dict[str, Any]]: + """ + Dynamically analyze all columns in a PyArrow table for filtering. + + Args: + table: PyArrow Table to analyze + + Returns: + Dictionary mapping column names to their info (type, unique_values, etc.) 
+ """ + column_info = {} + + for col_name in table.column_names: + # Skip technical/excluded columns + if col_name.lower() in EXCLUDED_COLUMNS: + continue + + col_array = table.column(col_name) + + # Handle null values + non_null_mask = pc.is_valid(col_array) + non_null_count = pc.sum(non_null_mask).as_py() + total_count = len(col_array) + null_count = total_count - non_null_count + + if non_null_count == 0: + col_type = 'empty' + unique_values = [] + value_counts = {} + else: + # Check data type + arrow_type = col_array.type + + if (pa.types.is_integer(arrow_type) or + pa.types.is_floating(arrow_type) or + pa.types.is_decimal(arrow_type)): + col_type = 'numeric' + unique_values = None + value_counts = None + elif pa.types.is_list(arrow_type) or pa.types.is_large_list(arrow_type): + # Skip list/array columns (like embeddings) + continue + else: + # Get unique values for categorical determination + try: + unique_array = pc.unique(col_array) + unique_count = len(unique_array) + + if unique_count <= 100: # Categorical if <= 100 unique values + col_type = 'categorical' + unique_values = sorted([v.as_py() for v in unique_array if v.is_valid]) + + # Get value counts + value_counts_result = pc.value_counts(col_array) + value_counts = {} + for i in range(len(value_counts_result)): + struct = value_counts_result[i].as_py() + if struct['values'] is not None: + value_counts[struct['values']] = struct['counts'] + else: + col_type = 'text' + unique_values = None + value_counts = None + except Exception: + col_type = 'text' + unique_values = None + value_counts = None + + column_info[col_name] = { + 'type': col_type, + 'unique_values': unique_values, + 'value_counts': value_counts, + 'null_count': null_count, + 'total_count': total_count, + 'null_percentage': (null_count / total_count) * 100 if total_count > 0 else 0 + } + + return column_info + + +def get_cascading_options( + table: pa.Table, + target_column: str, + current_filters: Dict[str, Any], + column_info: Dict[str, 
Dict[str, Any]] +) -> List[str]: + """ + Get available options for a column based on other active filters. + This enables cascading/dependent filter behavior. + + Args: + table: Full PyArrow table + target_column: Column to get options for + current_filters: Currently selected filter values (excluding target_column) + column_info: Column metadata + + Returns: + List of unique values available for the target column given other filters + """ + # Build filters excluding the target column + other_filters = {k: v for k, v in current_filters.items() if k != target_column and v} + + if not other_filters: + # No other filters, return original unique values + info = column_info.get(target_column, {}) + return info.get('unique_values', []) or [] + + # Apply other filters to get subset + filtered_table = apply_filters_arrow(table, other_filters) + + if target_column not in filtered_table.column_names: + return [] + + # Get unique values from filtered subset + try: + col_array = filtered_table.column(target_column) + unique_array = pc.unique(col_array) + return sorted([v.as_py() for v in unique_array if v.is_valid]) + except Exception: + return column_info.get(target_column, {}).get('unique_values', []) or [] + + +def render_file_section() -> Tuple[bool, Optional[str]]: + """ + Render the file loading section. 
+ + Returns: + Tuple of (file_loaded, file_path) + """ + with st.expander("📁 Load Parquet", expanded=True): + file_path = st.text_input( + "Parquet file or directory path", + value=st.session_state.get("parquet_file_path", ""), + help="Path to a parquet file or directory of parquet files containing embeddings and metadata" + ) + + load_button = st.button("Load File", type="primary") + + if load_button and file_path and os.path.exists(file_path): + try: + logger.info(f"Loading parquet file: {file_path}") + with st.spinner("Loading parquet file..."): + table = pq.read_table(file_path) + df = table.to_pandas() + + logger.info(f"Loaded {len(df):,} records, {len(table.column_names)} columns, " + f"schema: {[f'{c.name}({c.type})' for c in table.schema]}") + + # Validate required columns + if 'uuid' not in table.column_names: + st.error("Missing required 'uuid' column") + logger.error("Parquet validation failed: missing 'uuid' column") + return False, file_path + if 'emb' not in table.column_names: + st.error("Missing required 'emb' column") + logger.error("Parquet validation failed: missing 'emb' column") + return False, file_path + + emb_dim = len(df['emb'].iloc[0]) + logger.info(f"Embedding dimension: {emb_dim}") + + # Dynamically analyze all columns + column_info = get_column_info_dynamic(table) + logger.info(f"Column analysis: {len(column_info)} filterable columns " + f"({sum(1 for v in column_info.values() if v['type'] == 'categorical')} categorical, " + f"{sum(1 for v in column_info.values() if v['type'] == 'numeric')} numeric, " + f"{sum(1 for v in column_info.values() if v['type'] == 'text')} text)") + + # Store in session state + st.session_state.parquet_table = table + st.session_state.parquet_df = df + st.session_state.parquet_file_path = file_path + st.session_state.column_info = column_info + + # Reset downstream state + st.session_state.filtered_df = None + st.session_state.embeddings = None + st.session_state.data = None + st.session_state.labels = None + 
st.session_state.selected_image_idx = None + st.session_state.active_filters = {} + st.session_state.pending_filters = {} + + st.success(f"Loaded {len(df):,} records with {len(column_info)} filterable columns") + st.info(f"Embedding dimension: {emb_dim}") + + return True, file_path + + except Exception as e: + st.error(f"Error loading file: {e}") + logger.exception(f"Failed to load parquet file: {file_path}") + return False, file_path + + elif load_button and file_path: + st.error(f"File not found: {file_path}") + return False, file_path + elif load_button: + st.error("Please provide a file path") + return False, None + + return False, file_path + + +def render_dynamic_filters() -> Dict[str, Any]: + """ + Render dynamically generated cascading filters based on parquet columns. + Filter options update based on other selected filters (AND logic). + + Returns: + Dictionary of applied filters + """ + with st.expander("🔍 Filter Data", expanded=True): + df = st.session_state.get("parquet_df", None) + table = st.session_state.get("parquet_table", None) + column_info = st.session_state.get("column_info", {}) + + if df is None or table is None: + st.info("Load a parquet file first to enable filtering.") + return {} + + st.markdown(f"**Total records:** {len(df):,}") + + # Separate columns by type for better organization + categorical_cols = [(k, v) for k, v in column_info.items() if v['type'] == 'categorical'] + numeric_cols = [(k, v) for k, v in column_info.items() if v['type'] == 'numeric'] + text_cols = [(k, v) for k, v in column_info.items() if v['type'] == 'text'] + + # Sort categorical columns by number of unique values (fewer first) + categorical_cols.sort(key=lambda x: len(x[1].get('unique_values', []) or [])) + + # Let user select which columns to filter on + all_filterable = [col for col, _ in categorical_cols + numeric_cols + text_cols] + + selected_columns = st.multiselect( + "Select columns to filter on", + options=all_filterable, + 
default=st.session_state.get("selected_filter_columns", []), + help="Choose columns for filtering. Options cascade based on selections (AND logic).", + key="filter_column_selector" + ) + st.session_state.selected_filter_columns = selected_columns + + if not selected_columns: + st.caption("Select columns above to create filters") + + # Show column summary with consistent string types to avoid Arrow errors + with st.expander("📊 Available columns", expanded=False): + col_summary = [] + for col, info in column_info.items(): + unique_count = len(info['unique_values']) if info['unique_values'] else -1 + col_summary.append({ + "Column": col, + "Type": info['type'], + "Unique": str(unique_count) if unique_count >= 0 else "many", + "Null %": f"{info['null_percentage']:.1f}%" + }) + st.dataframe(pd.DataFrame(col_summary), hide_index=True, width="stretch") + + return {} + + st.markdown("---") + st.markdown("**🎯 Cascading Filters** *(AND logic - options update based on selections)*") + + # Initialize pending filters from session state + pending_filters = st.session_state.get("pending_filters", {}) + + # Render filters for selected columns (max 4 per row) + cols_per_row = 4 + for row_start in range(0, len(selected_columns), cols_per_row): + row_cols = selected_columns[row_start:row_start + cols_per_row] + cols = st.columns(len(row_cols)) + + for i, col_name in enumerate(row_cols): + info = column_info.get(col_name, {}) + col_type = info.get('type', 'text') + + with cols[i]: + st.markdown(f"**{col_name}**") + + if col_type == 'categorical': + # Get cascading options based on other filters + available_options = get_cascading_options( + table, col_name, pending_filters, column_info + ) + + # Get current selection, filter to only valid options + current_selection = pending_filters.get(col_name, []) + if isinstance(current_selection, list): + current_selection = [v for v in current_selection if v in available_options] + + selected_values = st.multiselect( + f"Select values", + 
options=available_options, + default=current_selection, + key=f"filter_{col_name}", + help=f"{len(available_options)} options available" + ) + + # Update pending filters + if selected_values: + pending_filters[col_name] = selected_values + elif col_name in pending_filters: + del pending_filters[col_name] + + elif col_type == 'numeric': + # For numeric, apply other filters first to get valid range + other_filters = {k: v for k, v in pending_filters.items() if k != col_name and v} + if other_filters: + filtered_table = apply_filters_arrow(table, other_filters) + filtered_df = filtered_table.to_pandas() + else: + filtered_df = df + + col_data = filtered_df[col_name].dropna() + if len(col_data) > 0: + min_val, max_val = float(col_data.min()), float(col_data.max()) + if min_val != max_val: + # Get current range or use full range + current_range = pending_filters.get(col_name, {}) + default_min = current_range.get('min', min_val) if isinstance(current_range, dict) else min_val + default_max = current_range.get('max', max_val) if isinstance(current_range, dict) else max_val + + # Clamp to available range + default_min = max(min_val, min(default_min, max_val)) + default_max = min(max_val, max(default_max, min_val)) + + range_values = st.slider( + f"Range", + min_value=min_val, + max_value=max_val, + value=(default_min, default_max), + key=f"filter_{col_name}" + ) + if range_values != (min_val, max_val): + pending_filters[col_name] = {'min': range_values[0], 'max': range_values[1]} + elif col_name in pending_filters: + del pending_filters[col_name] + + elif col_type == 'text': + current_text = pending_filters.get(col_name, "") + if not isinstance(current_text, str): + current_text = "" + + search_text = st.text_input( + f"Search", + value=current_text, + key=f"filter_{col_name}", + help="Case-insensitive contains search" + ) + if search_text.strip(): + pending_filters[col_name] = search_text.strip() + elif col_name in pending_filters: + del pending_filters[col_name] + + # 
Store pending filters + st.session_state.pending_filters = pending_filters + + st.markdown("---") + + # Show preview of filtered count + if pending_filters: + try: + preview_table = apply_filters_arrow(table, pending_filters) + preview_count = len(preview_table) + st.info(f"📊 Preview: **{preview_count:,}** records match current filters") + except Exception: + logger.debug("Filter preview count failed", exc_info=True) + + # Apply filters button + col1, col2 = st.columns([1, 1]) + with col1: + apply_button = st.button("Apply Filters", type="primary") + with col2: + clear_button = st.button("Clear All") + + if clear_button: + st.session_state.filtered_df = df + st.session_state.active_filters = {} + st.session_state.pending_filters = {} + st.session_state.selected_filter_columns = [] + st.rerun() + + if apply_button: + if pending_filters: + with st.spinner("Applying filters..."): + logger.info(f"Applying filters: {list(pending_filters.keys())}") + filtered_table = apply_filters_arrow(table, pending_filters) + filtered_df = filtered_table.to_pandas() + + logger.info(f"Filter result: {len(df):,} -> {len(filtered_df):,} records " + f"({len(filtered_df)/len(df)*100:.1f}% retained)") + + st.session_state.filtered_df = filtered_df + st.session_state.active_filters = pending_filters.copy() + + # Reset downstream state + st.session_state.embeddings = None + st.session_state.data = None + st.session_state.labels = None + st.session_state.selected_image_idx = None + + st.success(f"Filtered to {len(filtered_df):,} records") + else: + st.session_state.filtered_df = df + st.session_state.active_filters = {} + st.info("No filters applied, using full dataset") + + # Show active filter summary + active_filters = st.session_state.get("active_filters", {}) + if active_filters: + with st.expander("📋 Applied filters", expanded=False): + for col, val in active_filters.items(): + if isinstance(val, list): + st.caption(f"• **{col}**: {', '.join(str(v) for v in val[:3])}{'...' 
if len(val) > 3 else ''}") + elif isinstance(val, dict): + st.caption(f"• **{col}**: {val['min']:.2f} to {val['max']:.2f}") + else: + st.caption(f"• **{col}**: contains '{val}'") + + return pending_filters + + +def apply_filters_arrow(table: pa.Table, filters: Dict[str, Any]) -> pa.Table: + """ + Apply filters to PyArrow Table with AND logic. + + Args: + table: PyArrow Table to filter + filters: Dictionary of column_name -> filter_value pairs + + Returns: + Filtered PyArrow Table + """ + filter_expressions = [] + + for col, filter_value in filters.items(): + if col not in table.column_names or filter_value is None: + continue + + col_ref = pc.field(col) + + if isinstance(filter_value, dict): + # Numeric range filter + if 'min' in filter_value and filter_value['min'] is not None: + filter_expressions.append(pc.greater_equal(col_ref, filter_value['min'])) + if 'max' in filter_value and filter_value['max'] is not None: + filter_expressions.append(pc.less_equal(col_ref, filter_value['max'])) + elif isinstance(filter_value, list): + # Categorical filter (multiple values) + if len(filter_value) > 0: + filter_expressions.append(pc.is_in(col_ref, pa.array(filter_value))) + elif isinstance(filter_value, str): + # Text filter (case-insensitive literal substring match) + if filter_value.strip(): + filter_expressions.append( + pc.match_substring(pc.utf8_lower(col_ref), filter_value.lower()) + ) + + # Combine all filters with AND + if filter_expressions: + from functools import reduce + try: + combined = reduce(pc.and_kleene, filter_expressions) + return table.filter(combined) + except AttributeError: + # Fallback for older PyArrow + result = table + for expr in filter_expressions: + result = result.filter(expr) + return result + + return table + + +def extract_embeddings_safe(df: pd.DataFrame) -> np.ndarray: + """ + Safely extract embeddings from DataFrame using zero-copy where possible. 
+ + Args: + df: DataFrame with 'emb' column + + Returns: + numpy array of embeddings (float32) + """ + if 'emb' not in df.columns: + raise ValueError("DataFrame does not contain 'emb' column") + + logger.info(f"Extracting embeddings from DataFrame: {len(df)} rows") + + # Use np.stack for efficient conversion + embeddings = np.stack(df['emb'].values) + + if embeddings.ndim != 2: + raise ValueError(f"Embeddings should be 2D, got shape {embeddings.shape}") + + embeddings = embeddings.astype(np.float32) + logger.info(f"Extracted embeddings: shape={embeddings.shape}, dtype={embeddings.dtype}") + + return embeddings + + +def render_clustering_section() -> Tuple[bool, int, str, str, str, int, Optional[int]]: + """ + Render the clustering section with VRAM error handling. + + Returns: + Tuple of (cluster_button_clicked, n_clusters, reduction_method, + dim_reduction_backend, clustering_backend, n_workers, seed) + """ + with st.expander("🎯 Cluster Embeddings", expanded=False): + filtered_df = st.session_state.get("filtered_df", None) + + if filtered_df is None or len(filtered_df) == 0: + st.info("Apply filters first to enable clustering.") + return False, 5, "TSNE", "auto", "auto", 8, None + + st.markdown(f"**Ready to cluster:** {len(filtered_df):,} records") + + # Estimate memory requirements + emb_dim = len(filtered_df['emb'].iloc[0]) + n_samples = len(filtered_df) + est_memory_mb = (n_samples * emb_dim * 4) / (1024 * 1024) # float32 + + if est_memory_mb > 1000: + st.warning(f"⚠️ Large dataset: ~{est_memory_mb:.0f} MB for embeddings. 
Consider filtering further if GPU memory is limited.")
+
+        # Cluster count options
+        cluster_method = st.radio(
+            "Cluster count method:",
+            ["Specify number", "Use column values"],
+            horizontal=True
+        )
+
+        if cluster_method == "Specify number":
+            # Clamp slider bounds so small datasets don't make the default (5)
+            # exceed the max, or the max fall below the min of 2.
+            max_clusters = max(3, min(100, len(filtered_df) // 2))
+            n_clusters = st.slider("Number of clusters", 2, max_clusters, min(5, max_clusters))
+            cluster_column = None
+        else:
+            # Get categorical columns for clustering
+            column_info = st.session_state.get("column_info", {})
+            categorical_cols = [k for k, v in column_info.items() if v['type'] == 'categorical']
+
+            if categorical_cols:
+                cluster_column = st.selectbox(
+                    "Use unique values from column:",
+                    categorical_cols,
+                    help="Number of clusters = unique values in selected column"
+                )
+                if cluster_column in filtered_df.columns:
+                    n_clusters = filtered_df[cluster_column].nunique()
+                    st.info(f"Using **{n_clusters}** clusters from {cluster_column}")
+                else:
+                    n_clusters = 5
+            else:
+                st.warning("No categorical columns available")
+                n_clusters = 5
+                cluster_column = None
+
+        reduction_method = st.selectbox(
+            "Dimensionality Reduction",
+            ["TSNE", "PCA", "UMAP"],
+            help="For 2D visualization only. Clustering uses full embeddings."
+ ) + + # Backend controls + dim_reduction_backend, clustering_backend, n_workers, seed = render_clustering_backend_controls() + + cluster_button = st.button("Run Clustering", type="primary") + + if cluster_button: + run_clustering_with_error_handling( + filtered_df, n_clusters, reduction_method, + dim_reduction_backend, clustering_backend, n_workers, seed, + cluster_column if cluster_method == "Use column values" else None + ) + + return cluster_button, n_clusters, reduction_method, dim_reduction_backend, clustering_backend, n_workers, seed + + +def run_clustering_with_error_handling( + filtered_df: pd.DataFrame, + n_clusters: int, + reduction_method: str, + dim_reduction_backend: str, + clustering_backend: str, + n_workers: int, + seed: Optional[int], + cluster_column: Optional[str] = None +): + """ + Run clustering with comprehensive error handling for VRAM and CUDA issues. + """ + try: + # Check CUDA availability + cuda_available, device_info = check_cuda_available() + + # Resolve auto backends + actual_dim_backend = resolve_backend(dim_reduction_backend, "reduction") + actual_cluster_backend = resolve_backend(clustering_backend, "clustering") + + # Log clustering start + logger.info("=" * 60) + logger.info("CLUSTERING START") + logger.info("=" * 60) + logger.info(f"Device: {device_info} (CUDA: {'Yes' if cuda_available else 'No'})") + logger.info(f"Dim Reduction Backend: {actual_dim_backend} (requested: {dim_reduction_backend})") + logger.info(f"Clustering Backend: {actual_cluster_backend} (requested: {clustering_backend})") + + # Extract embeddings + t_start = time.time() + with st.spinner("Extracting embeddings..."): + embeddings = extract_embeddings_safe(filtered_df) + st.session_state.embeddings = embeddings + t_extract = time.time() - t_start + + n_samples, emb_dim = embeddings.shape + mem_mb = (n_samples * emb_dim * 4) / (1024 * 1024) + + logger.info(f"Records: {n_samples:,} | Embedding dim: {emb_dim}") + logger.info(f"Memory: ~{mem_mb:.1f} MB | Clusters: 
{n_clusters}") + logger.info(f"Embeddings extracted ({t_extract:.2f}s)") + + # Run clustering with automatic GPU fallback + t_cluster_start = time.time() + with st.spinner(f"Running {reduction_method} + KMeans..."): + df_plot, labels = ClusteringService.run_clustering_safe( + embeddings, + filtered_df['uuid'].tolist(), + n_clusters, + reduction_method, + n_workers, + actual_dim_backend, + actual_cluster_backend, + seed + ) + + t_cluster = time.time() - t_cluster_start + t_total = time.time() - t_start + + # Log clustering completion to console + logger.info(f"{reduction_method} + KMeans completed ({t_cluster:.2f}s)") + logger.info(f"Total time: {t_total:.2f}s") + + # Create enhanced plot dataframe + df_plot = create_cluster_dataframe(filtered_df.reset_index(drop=True), df_plot[['x', 'y']].values, labels) + + # Handle column-based cluster names + if cluster_column and cluster_column in filtered_df.columns: + filtered_reset = filtered_df.reset_index(drop=True) + unique_taxa = sorted(filtered_df[cluster_column].dropna().unique()) + taxon_to_id = {taxon: str(i) for i, taxon in enumerate(unique_taxa)} + + taxonomic_names = [] + numeric_clusters = [] + + for idx in range(len(df_plot)): + taxon_value = filtered_reset.iloc[idx][cluster_column] + if pd.notna(taxon_value) and taxon_value in taxon_to_id: + taxonomic_names.append(str(taxon_value)) + numeric_clusters.append(taxon_to_id[taxon_value]) + else: + taxonomic_names.append("Unknown") + numeric_clusters.append(str(len(unique_taxa))) + + df_plot['cluster'] = numeric_clusters + df_plot['cluster_name'] = taxonomic_names + st.session_state.taxonomic_clustering = {'is_taxonomic': True, 'column': cluster_column} + else: + df_plot['cluster_name'] = df_plot['cluster'].copy() + st.session_state.taxonomic_clustering = {'is_taxonomic': False} + + # Store results with data version tracking + data_hash = hashlib.md5(f"{len(df_plot)}_{n_clusters}_{reduction_method}".encode()).hexdigest()[:8] + + st.session_state.data = df_plot + 
st.session_state.labels = labels + st.session_state.data_version = data_hash # Track data version for selection validation + st.session_state.selected_image_idx = None # User must click to select (not auto-select) + st.session_state.filtered_df_for_clustering = filtered_df.reset_index(drop=True) + + # Final log with success + logger.info(f"Clustering complete: {n_clusters} clusters found") + logger.info("=" * 60) + + st.success(f"Clustering complete! {n_clusters} clusters found.") + + except (RuntimeError, OSError) as e: + if is_oom_error(e): + st.error("**GPU Out of Memory**") + st.info("Try: Reduce dataset size with more filters, use 'sklearn' backend, or use PCA") + logger.exception("GPU OOM error during clustering") + else: + st.error(f"Error during clustering: {e}") + logger.exception("Clustering error") + + except MemoryError: + st.error("**System Out of Memory** - Reduce dataset size") + logger.exception("System memory exhausted during clustering") + + except Exception as e: + st.error(f"Error: {e}") + logger.exception("Unexpected clustering error") + + +def create_cluster_dataframe(df: pd.DataFrame, embeddings_2d: np.ndarray, labels: np.ndarray) -> pd.DataFrame: + """Create a dataframe for clustering visualization.""" + df_plot = pd.DataFrame({ + "x": embeddings_2d[:, 0], + "y": embeddings_2d[:, 1], + "cluster": labels.astype(str), + "uuid": df['uuid'].values, + "idx": range(len(df)) + }) + + # Add available metadata columns for tooltips + for col in df.columns: + if col not in ['uuid', 'emb', 'embedding', 'embeddings'] and col not in df_plot.columns: + df_plot[col] = df[col].values + + return df_plot diff --git a/apps/precalculated/components/visualization.py b/apps/precalculated/components/visualization.py new file mode 100644 index 0000000..e40ff14 --- /dev/null +++ b/apps/precalculated/components/visualization.py @@ -0,0 +1,10 @@ +""" +Visualization components for the precalculated embeddings application. 
+ +This module re-exports from shared for backwards compatibility. +""" + +# Re-export scatter plot from shared module +from shared.components.visualization import render_scatter_plot + +__all__ = ['render_scatter_plot'] diff --git a/components/clustering/__init__.py b/components/clustering/__init__.py deleted file mode 100644 index 264c68b..0000000 --- a/components/clustering/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -UI components for the clustering. -""" diff --git a/components/clustering/visualization.py b/components/clustering/visualization.py deleted file mode 100644 index 64f7416..0000000 --- a/components/clustering/visualization.py +++ /dev/null @@ -1,101 +0,0 @@ -""" -Visualization components for the clustering page. -""" - -import streamlit as st -import altair as alt -import os -from typing import Optional - - -def render_scatter_plot(): - """Render the main clustering scatter plot.""" - df_plot = st.session_state.get("data", None) - labels = st.session_state.get("labels", None) - selected_idx = st.session_state.get("selected_image_idx", 0) - - if df_plot is not None and len(df_plot) > 1: - point_selector = alt.selection_point(fields=["idx"], name="point_selection") - - # Determine tooltip fields based on available columns - tooltip_fields = [] - - # Use cluster_name for display if available (taxonomic clustering), otherwise use cluster - if 'cluster_name' in df_plot.columns: - tooltip_fields.append('cluster_name:N') - cluster_legend_field = 'cluster_name:N' - cluster_legend_title = "Cluster" - else: - tooltip_fields.append('cluster:N') - cluster_legend_field = 'cluster:N' - cluster_legend_title = "Cluster" - - # Add metadata fields if available (for precalculated embeddings) - metadata_fields = ['scientific_name', 'common_name', 'family', 'genus', 'species', 'uuid'] - for field in metadata_fields: - if field in df_plot.columns: - tooltip_fields.append(field) - - # Add file_name if available (for image clustering) - if 'file_name' in 
df_plot.columns: - tooltip_fields.append('file_name') - - # Determine title based on data type - if 'uuid' in df_plot.columns: - title = "Embedding Clusters (click a point to view details)" - else: - title = "Image Clusters (click a point to preview image)" - - scatter = ( - alt.Chart(df_plot) - .mark_circle(size=60) - .encode( - x=alt.X('x', scale=alt.Scale(zero=False)), - y=alt.Y('y', scale=alt.Scale(zero=False)), - color=alt.Color('cluster:N', legend=alt.Legend(title=cluster_legend_title)), - tooltip=tooltip_fields, - fillOpacity=alt.condition(point_selector, alt.value(1), alt.value(0.3)) - ) - .add_params(point_selector) - .properties( - width=800, - height=700, - title=title - ) - ) - event = st.altair_chart(scatter, key="alt_chart", on_select="rerun", use_container_width=True) - - # Handle updated event format - if ( - event - and "selection" in event - and "point_selection" in event["selection"] - and event["selection"]["point_selection"] - ): - new_idx = int(event["selection"]["point_selection"][0]["idx"]) - st.session_state["selected_image_idx"] = new_idx - - else: - st.info("Run clustering to see the cluster scatter plot.") - st.session_state['selected_image_idx'] = None - - -def render_image_preview(): - """Render the image preview panel.""" - valid_paths = st.session_state.get("valid_paths", None) - labels = st.session_state.get("labels", None) - selected_idx = st.session_state.get("selected_image_idx", 0) - - if ( - valid_paths is not None and - labels is not None and - selected_idx is not None and - 0 <= selected_idx < len(valid_paths) - ): - img_path = valid_paths[selected_idx] - cluster = labels[selected_idx] if labels is not None else "?" 
- st.image(img_path, caption=f"Cluster {cluster}: {os.path.basename(img_path)}", width='stretch') - st.markdown(f"**File:** `{os.path.basename(img_path)}`") - st.markdown(f"**Cluster:** `{cluster}`") - else: - st.info("Image preview will appear here after you select a cluster point.") diff --git a/components/precalculated/__init__.py b/components/precalculated/__init__.py deleted file mode 100644 index 09b6600..0000000 --- a/components/precalculated/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -UI components for the precalculated embeddings page. -""" diff --git a/components/precalculated/data_preview.py b/components/precalculated/data_preview.py deleted file mode 100644 index 6449669..0000000 --- a/components/precalculated/data_preview.py +++ /dev/null @@ -1,210 +0,0 @@ -""" -Data preview components for the precalculated embeddings page. -""" - -import streamlit as st -import pandas as pd -import requests -from typing import Optional -from PIL import Image -from io import BytesIO - - -def fetch_image_from_url(url: str, timeout: int = 5) -> Optional[Image.Image]: - """ - Try to fetch an image from a URL. 
- - Args: - url: The image URL - timeout: Request timeout in seconds - - Returns: - PIL Image object if successful, None otherwise - """ - if not url or not isinstance(url, str): - return None - - try: - # Add common image URL patterns if needed - if not url.startswith(('http://', 'https://')): - return None - - response = requests.get(url, timeout=timeout, stream=True) - response.raise_for_status() - - # Check if content type is an image - content_type = response.headers.get('content-type', '').lower() - if not content_type.startswith('image/'): - return None - - # Try to open as image - image = Image.open(BytesIO(response.content)) - return image - - except Exception: - return None - - -def render_data_preview(): - """Render the data preview panel (replaces image preview).""" - df_plot = st.session_state.get("data", None) - labels = st.session_state.get("labels", None) - selected_idx = st.session_state.get("selected_image_idx", 0) - filtered_df = st.session_state.get("filtered_df_for_clustering", None) - - if ( - df_plot is not None and - labels is not None and - selected_idx is not None and - 0 <= selected_idx < len(df_plot) and - filtered_df is not None - ): - # Get the selected record - selected_idx = st.session_state.get("selected_image_idx", 0) - selected_uuid = df_plot.iloc[selected_idx]['uuid'] - cluster = labels[selected_idx] if labels is not None else "?" 
- - # Use cluster_name if available (for taxonomic clustering) - if 'cluster_name' in df_plot.columns: - cluster_display = df_plot.iloc[selected_idx]['cluster_name'] - else: - cluster_display = cluster - - # Find the full record in the original filtered dataframe - record = filtered_df[filtered_df['uuid'] == selected_uuid].iloc[0] - - st.markdown(f"### 📋 Record Details") - - # Create tabs for different types of information - tab_overview, tab_details = st.tabs(["🔍 Overview", "📊 Details"]) - - with tab_overview: - # Basic information - st.markdown(f"**Cluster:** `{cluster_display}`") - st.markdown(f"**UUID:** `{selected_uuid}`") - - # Try to fetch and display image if identifier exists - if 'identifier' in record.index and pd.notna(record['identifier']): - identifier_url = record['identifier'] - st.markdown("**Image:**") - - with st.spinner("Fetching image..."): - image = fetch_image_from_url(identifier_url) - - if image is not None: - st.image(image, caption=f"Image from: {identifier_url}", width='stretch') - else: - st.info(f"Could not fetch image from: {identifier_url}") - with st.expander("🔗 Image URL"): - st.code(identifier_url) - - with tab_details: - # Taxonomy section - st.markdown("#### 🧬 Taxonomy") - - # Show scientific and common names first - id_fields = ['scientific_name', 'common_name'] - for field in id_fields: - if field in record.index and pd.notna(record[field]): - st.markdown(f"**{field.replace('_', ' ').title()}:** {record[field]}") - - # Show taxonomic hierarchy - taxonomic_fields = ['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'] - hierarchy_parts = [] - for field in taxonomic_fields: - if field in record.index and pd.notna(record[field]): - hierarchy_parts.append(f"{field.title()}: {record[field]}") - - if hierarchy_parts: - st.markdown("**Taxonomic Hierarchy:**") - hierarchy_text = "\n".join([f"• {part}" for part in hierarchy_parts]) - st.code(hierarchy_text, language="text") - - # Display source information - 
st.markdown("#### 📊 Source Information") - source_fields = ['source_dataset', 'publisher', 'basisOfRecord', 'img_type'] - for field in source_fields: - if field in record.index and pd.notna(record[field]): - value = record[field] - if len(str(value)) > 50: # Truncate long values - value = str(value)[:47] + "..." - st.markdown(f"**{field.replace('_', ' ').title()}:** {value}") - - # Display additional metadata in an expander - with st.expander("🔍 All Metadata"): - # Create a clean dataframe for display - display_data = [] - for field, value in record.items(): - if field not in ['uuid', 'emb']: # Skip technical fields - display_data.append({ - 'Field': field.replace('_', ' ').title(), - 'Value': value if pd.notna(value) else 'null' - }) - - if display_data: - metadata_df = pd.DataFrame(display_data) - st.dataframe(metadata_df, hide_index=True, width='stretch') - - else: - st.info("📋 Record details will appear here after you select a point in the cluster plot.") - - # Show dataset summary if we have filtered data - filtered_df = st.session_state.get("filtered_df", None) - if filtered_df is not None and len(filtered_df) > 0: - st.markdown("### 📈 Dataset Summary") - st.markdown(f"**Total records:** {len(filtered_df):,}") - - # Show distribution of key fields - summary_fields = ['kingdom', 'family', 'source_dataset', 'img_type'] - for field in summary_fields: - if field in filtered_df.columns: - non_null_count = filtered_df[field].notna().sum() - unique_count = filtered_df[field].nunique() - st.markdown(f"**{field.replace('_', ' ').title()}:** {unique_count} unique values ({non_null_count:,} non-null)") - - -def render_cluster_statistics(): - """Render cluster-level statistics.""" - df_plot = st.session_state.get("data", None) - labels = st.session_state.get("labels", None) - filtered_df = st.session_state.get("filtered_df_for_clustering", None) - - if df_plot is not None and labels is not None and filtered_df is not None: - st.markdown("### 📊 Cluster Statistics") - - # 
Create cluster summary - cluster_summary = [] - - # Check if we have taxonomic clustering with cluster names - if 'cluster_name' in df_plot.columns: - # Use cluster names for display, but group by cluster ID for consistency - unique_cluster_ids = sorted(df_plot['cluster'].unique(), key=lambda x: int(x)) - - for cluster_id in unique_cluster_ids: - cluster_mask = df_plot['cluster'] == cluster_id - cluster_size = cluster_mask.sum() - cluster_percentage = (cluster_size / len(df_plot)) * 100 - - # Get the cluster name for this cluster ID - cluster_name = df_plot[cluster_mask]['cluster_name'].iloc[0] if cluster_size > 0 else str(cluster_id) - - cluster_summary.append({ - 'Cluster': cluster_name, - 'Size': cluster_size, - 'Percentage': f"{cluster_percentage:.1f}%" - }) - else: - # Standard numeric clustering - for cluster_id in sorted(df_plot['cluster'].unique(), key=int): - cluster_mask = df_plot['cluster'] == cluster_id - cluster_size = cluster_mask.sum() - cluster_percentage = (cluster_size / len(df_plot)) * 100 - - cluster_summary.append({ - 'Cluster': int(cluster_id), - 'Size': cluster_size, - 'Percentage': f"{cluster_percentage:.1f}%" - }) - - summary_df = pd.DataFrame(cluster_summary) - st.dataframe(summary_df, hide_index=True, width='stretch') diff --git a/components/precalculated/sidebar.py b/components/precalculated/sidebar.py deleted file mode 100644 index 1acf955..0000000 --- a/components/precalculated/sidebar.py +++ /dev/null @@ -1,395 +0,0 @@ -""" -Sidebar components for the precalculated embeddings page. 
-""" - -import streamlit as st -import pandas as pd -import pyarrow as pa -import os -from typing import Dict, Any, Optional, Tuple - -from services.parquet_service import ParquetService -from services.clustering_service import ClusteringService -from components.shared.clustering_controls import render_clustering_backend_controls, render_basic_clustering_controls - - -def render_file_section() -> Tuple[bool, Optional[str]]: - """ - Render the file loading section. - - Returns: - Tuple of (file_loaded, file_path) - """ - with st.expander("📁 Load Parquet File", expanded=True): - file_path = st.text_input( - "Parquet file path", - help="Path to your parquet file containing embeddings and metadata. Large files are loaded efficiently." - ) - - - load_button = st.button("Load File") - - if load_button and file_path and os.path.exists(file_path): - try: - with st.spinner("Loading parquet file..."): - # Use the efficient PyArrow loader - table, df = ParquetService.load_and_filter_efficient(file_path) - - # Validate structure (works with both PyArrow table and pandas DataFrame) - is_valid, issues = ParquetService.validate_parquet_structure(table) - - if not is_valid: - st.error("File validation failed:") - for issue in issues: - st.error(f"• {issue}") - return False, file_path - - # Store both PyArrow table and DataFrame in session state - st.session_state.parquet_table = table # PyArrow table for efficient operations - st.session_state.parquet_df = df # pandas DataFrame for compatibility - st.session_state.parquet_file_path = file_path - st.session_state.column_info = ParquetService.get_column_info(table) # Use PyArrow for analysis - - # Reset downstream state - st.session_state.filtered_df = None - st.session_state.embeddings = None - st.session_state.data = None - st.session_state.labels = None - st.session_state.selected_image_idx = None - - st.success(f"✅ Loaded {len(df):,} records from parquet file") - st.info(f"Embedding dimension: {len(df['emb'].iloc[0])}") - - 
return True, file_path - - except Exception as e: - st.error(f"Error loading file: {e}") - return False, file_path - - elif load_button and file_path: - st.error(f"File not found: {file_path}") - return False, file_path - elif load_button: - st.error("Please provide a file path") - return False, None - - return False, file_path - - -def render_filter_section() -> Dict[str, Any]: - """ - Render the metadata filtering section. - - Returns: - Dictionary of applied filters - """ - with st.expander("🔍 Filter Data", expanded=True): - df = st.session_state.get("parquet_df", None) - column_info = st.session_state.get("column_info", {}) - - if df is None: - st.info("Load a parquet file first to enable filtering.") - return {} - - st.markdown(f"**Total records:** {len(df):,}") - - filters = {} - - # Define taxonomy columns in order - taxonomy_columns = ['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species', 'scientific_name', 'common_name'] - - # Separate taxonomy and other columns - taxonomy_filters = [] - other_filters = [] - - for col, info in column_info.items(): - # Skip technical columns and empty columns - if col in ['source_id', 'identifier', 'resolution_status', 'uuid', 'emb'] or info['type'] == 'empty': - continue - - if col in taxonomy_columns: - taxonomy_filters.append((col, info)) - else: - other_filters.append((col, info)) - - # Sort taxonomy filters by their order in taxonomy_columns - taxonomy_filters.sort(key=lambda x: taxonomy_columns.index(x[0])) - - # Row 1: Taxonomy filters (up to 7 columns) - if taxonomy_filters: - st.markdown("**🌿 Taxonomy Filters**") - cols = st.columns(len(taxonomy_filters)) - - for i, (col, info) in enumerate(taxonomy_filters): - with cols[i]: - st.markdown(f"**{col.title()}**") - - if info['type'] == 'categorical': - selected_values = st.multiselect( - f"Select {col}", - options=info['unique_values'], - key=f"filter_{col}", - help=f"{len(info['unique_values'])} unique values" - ) - if selected_values: - filters[col] = 
selected_values - elif info['type'] == 'text': - search_text = st.text_input( - f"Search {col}", - key=f"filter_{col}", - help="Case-insensitive search" - ) - if search_text.strip(): - filters[col] = search_text.strip() - - # Rows 2+: Other metadata filters (5-7 per row) - if other_filters: - st.markdown("**📋 Metadata Filters**") - - # Group other filters into rows of 6 - filters_per_row = 6 - for row_start in range(0, len(other_filters), filters_per_row): - row_filters = other_filters[row_start:row_start + filters_per_row] - cols = st.columns(len(row_filters)) - - for i, (col, info) in enumerate(row_filters): - with cols[i]: - st.markdown(f"**{col}**") - - if info['type'] == 'categorical': - selected_values = st.multiselect( - f"Select {col}", - options=info['unique_values'], - key=f"filter_{col}", - help=f"{len(info['unique_values'])} unique values" - ) - if selected_values: - filters[col] = selected_values - - elif info['type'] == 'numeric': - col_data = df[col].dropna() - if len(col_data) > 0: - min_val, max_val = float(col_data.min()), float(col_data.max()) - if min_val != max_val: - range_values = st.slider( - f"{col} range", - min_value=min_val, - max_value=max_val, - value=(min_val, max_val), - key=f"filter_{col}" - ) - if range_values != (min_val, max_val): - filters[col] = {'min': range_values[0], 'max': range_values[1]} - - elif info['type'] == 'text': - search_text = st.text_input( - f"Search {col}", - key=f"filter_{col}", - help="Case-insensitive search" - ) - if search_text.strip(): - filters[col] = search_text.strip() - - # Apply filters button and results - if st.button("Apply Filters", type="primary"): - if filters: - with st.spinner("Applying filters..."): - # Use PyArrow table for efficient filtering - parquet_table = st.session_state.get("parquet_table", None) - - if parquet_table is not None: - # Use efficient PyArrow filtering - filtered_table = ParquetService.apply_filters_arrow(parquet_table, filters) - filtered_df = 
filtered_table.to_pandas() - else: - # Convert pandas DataFrame to PyArrow table and filter - table = pa.Table.from_pandas(df) - filtered_table = ParquetService.apply_filters_arrow(table, filters) - filtered_df = filtered_table.to_pandas() - - st.session_state.filtered_df = filtered_df - st.session_state.current_filters = filters - - # Reset downstream state - st.session_state.embeddings = None - st.session_state.data = None - st.session_state.labels = None - st.session_state.selected_image_idx = None - - st.success(f"✅ Filtered to {len(filtered_df):,} records") - else: - # No filters applied, use full dataset - st.session_state.filtered_df = df - st.session_state.current_filters = {} - st.info("No filters applied, using full dataset") - - # Show current filter summary - current_filters = st.session_state.get("current_filters", {}) - if current_filters: - st.markdown("**Active filters:**") - for col, filter_val in current_filters.items(): - if isinstance(filter_val, list): - st.caption(f"• {col}: {len(filter_val)} values selected") - elif isinstance(filter_val, dict): - st.caption(f"• {col}: {filter_val['min']} - {filter_val['max']}") - else: - st.caption(f"• {col}: contains '{filter_val}'") - - return filters - - -def render_clustering_section() -> Tuple[bool, int, str, str, str, int, Optional[int]]: - """ - Render the clustering section. 
- - Returns: - Tuple of (cluster_button_clicked, n_clusters, reduction_method, dim_reduction_backend, clustering_backend, n_workers, seed) - """ - with st.expander("🎯 Cluster Embeddings", expanded=False): - filtered_df = st.session_state.get("filtered_df", None) - - if filtered_df is None or len(filtered_df) == 0: - st.info("Apply filters first to enable clustering.") - return False, 5, "TSNE", "auto", "auto", 8, None - - st.markdown(f"**Ready to cluster:** {len(filtered_df):,} records") - - # Two options for determining number of clusters - cluster_method = st.radio( - "How to determine number of clusters:", - ["Specify number", "Use taxonomic rank"], - horizontal=True - ) - - if cluster_method == "Specify number": - n_clusters = st.slider("Number of clusters", 2, min(100, len(filtered_df)//2), 5) - else: - # Option 2: Cluster by taxonomic rank - taxonomy_columns = ['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'] - selected_rank = st.selectbox( - "Select taxonomic rank:", - taxonomy_columns, - index=4, # Default to 'family' - help="Number of clusters will be determined by unique values in this taxonomic rank" - ) - - # Calculate number of unique values for the selected rank - if selected_rank in filtered_df.columns: - n_clusters = filtered_df[selected_rank].nunique() - st.info(f"Using **{n_clusters}** clusters based on unique {selected_rank} values") - else: - st.warning(f"Column '{selected_rank}' not found in data. Using default of 5 clusters.") - n_clusters = 5 - reduction_method = st.selectbox( - "Dimensionality Reduction (for visualization)", - ["TSNE", "PCA", "UMAP"], - help="Used only for 2D visualization. Clustering is performed on full high-dimensional embeddings for better quality." 
- ) - - # Backend and advanced controls - dim_reduction_backend, clustering_backend, n_workers, seed = render_clustering_backend_controls() - - cluster_button = st.button("Run Clustering", type="primary") - - if cluster_button: - try: - with st.spinner("Extracting embeddings..."): - embeddings = ParquetService.extract_embeddings(filtered_df) - st.session_state.embeddings = embeddings - - with st.spinner("Running clustering on full embeddings..."): - df_plot, labels = ClusteringService.run_clustering( - embeddings, - filtered_df['uuid'].tolist(), # Use UUIDs as "paths" - n_clusters, - reduction_method, - n_workers, # Pass the workers parameter - dim_reduction_backend, # Explicit dimensionality reduction backend - clustering_backend, # Explicit clustering backend - seed # Random seed - ) - - # Create enhanced plot dataframe with metadata - df_plot = ParquetService.create_cluster_dataframe( - filtered_df.reset_index(drop=True), - df_plot[['x', 'y']].values, - labels - ) - - # If using taxonomic clustering, enhance cluster names while preserving color mapping - if cluster_method == "Use taxonomic rank" and selected_rank in filtered_df.columns: - # Create mapping from cluster numbers to taxonomic names - filtered_df_reset = filtered_df.reset_index(drop=True) - - # Get unique taxonomic values and create consistent mapping - unique_taxa = sorted(filtered_df[selected_rank].dropna().unique()) - taxon_to_cluster_id = {taxon: str(i) for i, taxon in enumerate(unique_taxa)} - - # Create taxonomic cluster names while keeping numeric IDs for coloring - taxonomic_names = [] - numeric_clusters = [] - - for idx in range(len(df_plot)): - taxon_value = filtered_df_reset.iloc[idx][selected_rank] - if pd.notna(taxon_value) and taxon_value in taxon_to_cluster_id: - # Use the taxonomic name as display name - taxonomic_names.append(taxon_value) - # Keep numeric ID for consistent coloring - numeric_clusters.append(taxon_to_cluster_id[taxon_value]) - else: - # Handle missing values - 
unknown_name = f"Unknown {selected_rank}" - taxonomic_names.append(unknown_name) - # Assign a high numeric ID for unknowns - numeric_clusters.append(str(len(unique_taxa))) - - # Store both versions: display names and numeric IDs - df_plot['cluster'] = numeric_clusters # Keep numeric for consistent coloring - df_plot['cluster_name'] = taxonomic_names # Add taxonomic names for display - - # Store taxonomic clustering metadata - st.session_state.taxonomic_clustering = { - 'is_taxonomic': True, - 'rank': selected_rank, - 'taxon_to_id': taxon_to_cluster_id - } - else: - # Standard numeric clustering - use cluster IDs as names too - df_plot['cluster_name'] = df_plot['cluster'].copy() - st.session_state.taxonomic_clustering = {'is_taxonomic': False} - - # Store results - st.session_state.data = df_plot - st.session_state.labels = labels - st.session_state.selected_image_idx = 0 - st.session_state.filtered_df_for_clustering = filtered_df.reset_index(drop=True) - - st.success(f"✅ Clustering complete! 
Found {n_clusters} clusters.") - - except Exception as e: - st.error(f"Error during clustering: {e}") - - return cluster_button, n_clusters, reduction_method, dim_reduction_backend, clustering_backend, n_workers, seed - - -def render_precalculated_sidebar(): - """Render the complete precalculated embeddings sidebar.""" - # Load & Filter sections at the top (no tabs) - file_loaded, file_path = render_file_section() - filters = render_filter_section() - - # Clustering section below - cluster_button, n_clusters, reduction_method, dim_reduction_backend, clustering_backend, n_workers, seed = render_clustering_section() - - return { - 'file_loaded': file_loaded, - 'file_path': file_path, - 'filters': filters, - 'cluster_button': cluster_button, - 'n_clusters': n_clusters, - 'reduction_method': reduction_method, - 'dim_reduction_backend': dim_reduction_backend, - 'clustering_backend': clustering_backend, - 'n_workers': n_workers, - 'seed': seed, - } diff --git a/components/shared/__init__.py b/components/shared/__init__.py deleted file mode 100644 index e26e5ae..0000000 --- a/components/shared/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Shared components package. -""" diff --git a/data/README.md b/data/README.md new file mode 100644 index 0000000..2df38ec --- /dev/null +++ b/data/README.md @@ -0,0 +1,28 @@ +# Example Data + +`example_1k.parquet`: a small sample for trying out the Precalculated Embedding Exploration app. + +It contains [BioCLIP 2](https://huggingface.co/imageomics/bioclip-2) embeddings for 1,030 randomly sampled images from [TreeOfLife-200M](https://huggingface.co/datasets/imageomics/TreeOfLife-200M) (embeddings only, not the images themselves). Taxonomic information and other metadata comes from `catalog.parquet` in the TOL-200M repo. 
+ +## Schema + +``` +uuid: string +emb: list +source_dataset: string +source_id: string +kingdom: string +phylum: string +class: string +order: string +family: string +genus: string +species: string +scientific_name: string +common_name: string +resolution_status: string +publisher: string +basisOfRecord: string +identifier: string +img_type: string +``` diff --git a/docs/BACKEND_PIPELINE.md b/docs/BACKEND_PIPELINE.md new file mode 100644 index 0000000..43c1209 --- /dev/null +++ b/docs/BACKEND_PIPELINE.md @@ -0,0 +1,132 @@ +# Backend Pipeline + +A quick walkthrough of what happens to your embeddings from the moment you click +"Run Clustering" to the scatter plot on screen. + +## The Pipeline at a Glance + +``` +Raw Embeddings (from parquet or model) + │ + ├─ Validate: check for NaN/Inf, cast to float32 + ├─ L2 Normalize: project onto unit hypersphere + │ + ├─► Step 1: KMeans Clustering (high-dimensional) + │ Backend: cuML → FAISS → sklearn + │ + ├─► Step 2: Dimensionality Reduction to 2D + │ Method: PCA / t-SNE / UMAP + │ Backend: cuML → sklearn + │ + └─► Scatter Plot (Altair) + Color = cluster, position = 2D projection +``` + +## Step 0: Embedding Preparation + +Before any computation, every embedding goes through `_prepare_embeddings()`: + +1. **Cast to float32** — GPU backends require it; keeps memory predictable. +2. **NaN/Inf check** — replaces bad values with 0 and logs a warning. +3. **L2 normalization** — divides each vector by its magnitude so every point + sits on the unit hypersphere. This is critical for two reasons: + - Prevents cuML UMAP's NN-descent from crashing with SIGFPE on + large-magnitude vectors (see `investigation/cuml_umap_sigfpe/`). + - Appropriate for contrastive embeddings (CLIP, BioCLIP) whose training + objective is cosine-similarity based — magnitude isn't a learned signal. + +Input norms are logged so you can always verify what came in. 
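
The three preparation steps above can be sketched as a small function. This is an illustrative version only — the app's actual `_prepare_embeddings()` may differ in naming and logging details:

```python
import numpy as np

def prepare_embeddings(emb: np.ndarray) -> np.ndarray:
    """Illustrative sketch of the preparation step described above."""
    # 1. Cast to float32 (GPU backends require it; keeps memory predictable)
    emb = np.asarray(emb, dtype=np.float32)

    # 2. Replace NaN/Inf with 0, noting how many values were bad
    bad = ~np.isfinite(emb)
    if bad.any():
        print(f"warning: zeroed {bad.sum()} non-finite values")
        emb = np.where(bad, 0.0, emb)

    # 3. L2-normalize each row onto the unit hypersphere
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # avoid division by zero for all-zero rows
    return emb / norms
```

After this step every non-zero row has unit length, which is what makes cosine-trained embeddings (CLIP, BioCLIP) safe for both KMeans and cuML UMAP.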
+ +## Step 1: KMeans Clustering + +Clusters the full high-dimensional embeddings (e.g., 768-d for BioCLIP 2). +Runs *before* dimensionality reduction so clusters are based on the full +feature space, not a lossy 2D projection. + +| Backend | When It's Used | How It Works | +|---------|---------------|--------------| +| **cuML** | GPU available + >500 samples | GPU-accelerated KMeans via RAPIDS. Runs on CuPy arrays. Falls back to sklearn on any error. | +| **FAISS** | No GPU + >500 samples | Facebook's optimized CPU KMeans using L2 index. Fast for medium datasets. Falls back to sklearn on error. | +| **sklearn** | Small datasets or fallback | Standard scikit-learn KMeans. Always works, no special dependencies. | + +**Auto-selection priority:** cuML > FAISS > sklearn. You can override in the sidebar. + +## Step 2: Dimensionality Reduction + +Projects embeddings from high-dimensional space down to 2D for visualization. +This is purely for the scatter plot — clustering uses the full-dimensional data. + +### PCA (Principal Component Analysis) + +The fastest option. Linear projection onto the two directions of maximum variance. +Good for getting a quick overview; doesn't capture nonlinear structure. + +| Backend | Notes | +|---------|-------| +| **cuML** | GPU-accelerated, near-instant even on large datasets | +| **sklearn** | CPU-based, still fast since PCA is O(n) | + +### t-SNE + +Nonlinear method that preserves local neighborhoods. Good at revealing clusters +but slow on large datasets. Perplexity is auto-adjusted based on sample size. + +| Backend | Notes | +|---------|-------| +| **cuML** | GPU-accelerated, handles thousands of samples well | +| **sklearn** | CPU-based, can be slow above ~5k samples | + +### UMAP + +The recommended default. Nonlinear like t-SNE but faster and better at +preserving global structure. Neighbor count is auto-adjusted. 
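
The cuML → FAISS → sklearn priority for KMeans can be sketched as nested try/except fallbacks. The function below is illustrative, not the app's actual API — it only shows the shape of the chain, with the documented >500-sample threshold for the optimized backends:

```python
import numpy as np
from sklearn.cluster import KMeans

def kmeans_with_fallback(emb: np.ndarray, n_clusters: int, seed: int = 0) -> np.ndarray:
    """Illustrative cuML -> FAISS -> sklearn fallback returning cluster labels."""
    emb = np.ascontiguousarray(emb, dtype=np.float32)
    if len(emb) > 500:
        try:
            # GPU path: optional dependency, may not be installed
            from cuml.cluster import KMeans as CuKMeans
            return CuKMeans(n_clusters=n_clusters, random_state=seed).fit_predict(emb)
        except Exception:
            pass  # no GPU / cuML unavailable -> try FAISS
        try:
            # Optimized CPU path: optional dependency
            import faiss
            km = faiss.Kmeans(emb.shape[1], n_clusters, seed=seed)
            km.train(emb)
            _, labels = km.index.search(emb, 1)
            return labels.ravel()
        except Exception:
            pass  # FAISS unavailable -> sklearn
    # Baseline: always works, no special dependencies
    return KMeans(n_clusters=n_clusters, random_state=seed, n_init=10).fit_predict(emb)
```

Note the clustering happens here, in the full-dimensional space — the 2D reduction below is for display only.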
+ +| Backend | Notes | +|---------|-------| +| **cuML** | Runs in an **isolated subprocess** so a crash doesn't kill the app. The subprocess verifies L2 normalization as a safety net. Falls back to sklearn on failure. | +| **sklearn** | CPU-based `umap-learn`. Slower but numerically stable. | + +**Why the subprocess?** cuML UMAP's NN-descent algorithm can occasionally trigger +a SIGFPE (floating-point exception) that kills the process instantly — no Python +try/except can catch it. The subprocess isolates this risk. + +## Backend Selection + +When you select "auto" (the default), the app picks the fastest available backend: + +| Operation | Auto Logic | +|-----------|-----------| +| KMeans | cuML if GPU + >500 samples, else FAISS if available + >500 samples, else sklearn | +| Dim. Reduction | cuML if GPU + >5000 samples, else sklearn | + +Any GPU error (architecture mismatch, missing libraries, out of memory (OOM)) triggers an +automatic retry with sklearn. OOM errors are surfaced to the user with guidance. + +## Logging + +Every step is logged to `logs/emb_explorer.log` (DEBUG level) and console (INFO): + +- Embedding extraction: shape, dtype +- Preparation: input norms (min/max/mean), non-finite count, L2 normalization +- Backend selection: which backend was chosen and why +- KMeans: cluster count, sample count, elapsed time +- Reduction: method, sample count, elapsed time +- Fallbacks: what failed and what we fell back to +- Visualization: point selection events, density mode changes + +Check the log file for the full picture when debugging. + +## GPU Fallback Chain + +``` +cuML (GPU) + │ error? + ▼ +FAISS (CPU, optimized) ← KMeans only + │ error? + ▼ +sklearn (CPU, always works) +``` + +The app is designed to *always produce a result*. GPU acceleration is a +nice-to-have, never a hard requirement. 
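
The subprocess isolation for cuML UMAP can be sketched as follows. Everything here is illustrative — the worker script name `gpu_umap_worker.py` is hypothetical, and the app uses `/dev/shm` PID-based files rather than a temp directory — but it shows why a child-process crash (including an uncatchable SIGFPE) degrades into a clean `None` instead of killing the Streamlit app:

```python
import os
import subprocess
import sys
import tempfile
from typing import Optional

import numpy as np

def umap_in_subprocess(emb: np.ndarray) -> Optional[np.ndarray]:
    """Run a hypothetical GPU UMAP worker in a child process.

    Returns the 2D projection, or None if the child failed or crashed,
    so the caller can fall back to CPU umap-learn.
    """
    with tempfile.TemporaryDirectory() as tmp:
        in_path = os.path.join(tmp, "in.npy")
        out_path = os.path.join(tmp, "out.npy")
        np.save(in_path, emb)
        # Hypothetical worker: loads in.npy, runs cuML UMAP, saves out.npy
        proc = subprocess.run(
            [sys.executable, "gpu_umap_worker.py", in_path, out_path],
            capture_output=True,
        )
        # A signal kill (e.g. SIGFPE) surfaces as a negative returncode;
        # any nonzero exit or missing output means "fall back to CPU".
        if proc.returncode != 0 or not os.path.exists(out_path):
            return None
        return np.load(out_path)
```

A `try/except` around an in-process cuML call could never catch the SIGFPE; only the process boundary contains it.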
diff --git a/docs/DATA_FORMAT.md b/docs/DATA_FORMAT.md new file mode 100644 index 0000000..7ac3cdb --- /dev/null +++ b/docs/DATA_FORMAT.md @@ -0,0 +1,78 @@ +# Precalculated Embeddings: Expected Parquet Format + +The precalculated embeddings app loads a `.parquet` file **or a directory of +`.parquet` files** (Hive-partitioned or flat) containing precomputed embedding +vectors alongside arbitrary metadata columns. When a directory is provided, +all parquet files within it are read and concatenated automatically. + +## Column Requirements + +### Must Have + +| Column | Type | Description | +|--------|------|-------------| +| `uuid` | `string` | Unique identifier for each record. Used for filtering, selection, and cross-referencing between views. | +| `emb` | `list` | Precomputed embedding vector. All rows must have the same dimensionality. Used for KMeans clustering and dimensionality reduction (PCA/t-SNE/UMAP). | + +The app validates these two columns on load and will reject files missing either. + +### Good to Have + +These columns unlock additional features but are not required. + +| Column | Type | Feature Enabled | +|--------|------|-----------------| +| `identifier` or `image_url` or `url` or `img_url` or `image` | `string` (URL) | **Image preview** in the detail panel. The app tries these column names in order and displays the first valid HTTP(S) image URL found. | +| `kingdom`, `phylum`, `class`, `order`, `family`, `genus`, `species` | `string` | **Taxonomic tree** summary. Any subset works; missing levels default to "Unknown". At minimum `kingdom` must be present and non-null for a row to appear in the tree. | + +### Optional (Auto-Detected) + +All other columns are automatically analyzed on load: + +- **Categorical** (<=100 unique values): Rendered as multi-select dropdown filters with cascading AND logic. +- **Numeric** (int/float): Rendered as range slider filters. 
+- **Text** (>100 unique string values): Rendered as case-insensitive substring search filters. +- **List/array columns**: Skipped (assumed to be embeddings or similar). + +These columns also appear in the record detail panel when a scatter plot point is selected. + +### Excluded from Filters + +Columns named `uuid`, `emb`, `embedding`, `embeddings`, or `vector` are +automatically excluded from the filter UI and metadata display. + +## Minimal Example + +```python +import pandas as pd +import numpy as np + +df = pd.DataFrame({ + "uuid": ["a1", "a2", "a3"], + "emb": [np.random.randn(512).tolist() for _ in range(3)], +}) +df.to_parquet("minimal.parquet") +``` + +## Full Example (with taxonomy and images) + +```python +df = pd.DataFrame({ + "uuid": ["a1", "a2", "a3"], + "emb": [np.random.randn(512).tolist() for _ in range(3)], + "identifier": [ + "https://example.com/img1.jpg", + "https://example.com/img2.jpg", + "https://example.com/img3.jpg", + ], + "kingdom": ["Animalia", "Animalia", "Plantae"], + "phylum": ["Chordata", "Chordata", "Magnoliophyta"], + "class": ["Mammalia", "Aves", "Magnoliopsida"], + "order": ["Carnivora", "Passeriformes", "Rosales"], + "family": ["Felidae", "Corvidae", "Rosaceae"], + "genus": ["Panthera", "Corvus", "Rosa"], + "species": ["Panthera leo", "Corvus corax", "Rosa canina"], + "source": ["iNaturalist", "iNaturalist", "GBIF"], # auto-detected as categorical filter +}) +df.to_parquet("full.parquet") +``` diff --git a/lib/__init__.py b/lib/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/list_models.py b/list_models.py deleted file mode 100755 index 7d451c0..0000000 --- a/list_models.py +++ /dev/null @@ -1,69 +0,0 @@ -#!/usr/bin/env python3 -""" -Command-line script to list available models from the emb-explorer utils. 
-""" - -import json -import argparse -import sys -from pathlib import Path - -# Add the project root to the Python path -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - -from utils.models import list_available_models - - -def main(): - """Main function to list available models.""" - parser = argparse.ArgumentParser( - description="List all available models for the embedding explorer" - ) - parser.add_argument( - "--format", - choices=["json", "table", "names"], - default="json", - help="Output format (default: json)" - ) - parser.add_argument( - "--pretty", - action="store_true", - help="Pretty print JSON output" - ) - - args = parser.parse_args() - - try: - models = list_available_models() - - if args.format == "json": - if args.pretty: - print(json.dumps(models, indent=2)) - else: - print(json.dumps(models)) - - elif args.format == "table": - print(f"{'Model Name':<40} {'Pretrained':<30}") - print("-" * 70) - for model in models: - name = model['name'] - pretrained = model['pretrained'] or "None" - print(f"{name:<40} {pretrained:<30}") - - elif args.format == "names": - for model in models: - name = model['name'] - pretrained = model['pretrained'] - if pretrained: - print(f"{name} ({pretrained})") - else: - print(name) - - except Exception as e: - print(f"Error: {e}", file=sys.stderr) - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/pages/01_Clustering.py b/pages/01_Clustering.py deleted file mode 100644 index 6c8577d..0000000 --- a/pages/01_Clustering.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Clustering page for the embedding explorer. 
-""" - -import streamlit as st -import os - -from components.clustering.sidebar import render_clustering_sidebar -from components.clustering.visualization import render_scatter_plot, render_image_preview -from components.clustering.summary import render_clustering_summary - - -def main(): - """Main clustering page function.""" - st.set_page_config( - layout="wide", - page_title="Image Clustering", - page_icon="🔍" - ) - - # Clear precalculated embeddings data to prevent carry-over - if "page_type" not in st.session_state or st.session_state.page_type != "clustering": - # Clear precalculated data - precalc_keys = ["parquet_df", "parquet_file_path", "column_info", "filtered_df", - "current_filters", "filtered_df_for_clustering"] - for key in precalc_keys: - if key in st.session_state: - del st.session_state[key] - st.session_state.page_type = "clustering" - - st.title("🔍 Image Clustering") - - # Create the main layout - col_settings, col_plot, col_preview = st.columns([2, 6, 3]) - - with col_settings: - # Render the sidebar with all controls - sidebar_state = render_clustering_sidebar() - - with col_plot: - # Render the main scatter plot - render_scatter_plot() - - with col_preview: - # Render the image preview - render_image_preview() - - # Bottom section: Clustering summary - st.markdown("---") - render_clustering_summary() - - -if __name__ == "__main__": - main() diff --git a/pages/02_Precalculated_Embeddings.py b/pages/02_Precalculated_Embeddings.py deleted file mode 100644 index 0e4dd42..0000000 --- a/pages/02_Precalculated_Embeddings.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -Precalculated Embeddings page for the embedding explorer. -Works with parquet files containing precomputed embeddings and metadata. 
-""" - -import streamlit as st - -from components.precalculated.sidebar import ( - render_precalculated_sidebar, - render_file_section, - render_filter_section, - render_clustering_section -) -from components.clustering.visualization import render_scatter_plot -from components.precalculated.data_preview import render_data_preview -from components.clustering.summary import render_clustering_summary - - -def main(): - """Main precalculated embeddings page function.""" - st.set_page_config( - layout="wide", - page_title="Precalculated Embeddings", - page_icon="📊" - ) - - # Clear clustering page data to prevent carry-over - if "page_type" not in st.session_state or st.session_state.page_type != "precalculated": - # Clear regular clustering data - clustering_keys = ["embeddings", "valid_paths", "last_image_dir", "embedding_complete"] - for key in clustering_keys: - if key in st.session_state: - del st.session_state[key] - st.session_state.page_type = "precalculated" - - st.title("📊 Precalculated Embeddings") - st.markdown("Load and cluster precomputed embeddings from parquet files with metadata filtering.") - - # Row 1: Load Parquet File section - file_loaded, file_path = render_file_section() - - # Row 2: Filter Data section - filters = render_filter_section() - - # Row 3: Main content layout with clustering controls, plot, and preview - col_settings, col_plot, col_preview = st.columns([2, 7, 3]) - - with col_settings: - # Render only the clustering section in the sidebar - cluster_button, n_clusters, reduction_method, dim_reduction_backend, clustering_backend, n_workers, seed = render_clustering_section() - - with col_plot: - # Render the main scatter plot - render_scatter_plot() - - with col_preview: - # Render the data preview (metadata instead of images) - render_data_preview() - - # Bottom section: Clustering summary with taxonomy tree - st.markdown("---") - render_clustering_summary(show_taxonomy=True) - - -if __name__ == "__main__": - main() diff --git 
a/pyproject.toml b/pyproject.toml index 08c0f2a..a900b0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,9 +31,9 @@ classifiers = [ requires-python = ">=3.10" dependencies = [ # Core UI and web framework - "streamlit>=1.40.0", + "streamlit>=1.50.0", # Data processing and numerical computing - "numpy>=1.21.0", + "numpy<2.3", # capped: numba 0.61.x requires numpy <2.3 "pandas>=2.0.0", "pillow>=9.0.0", "pyarrow>=10.0.0", @@ -60,25 +60,26 @@ dev = [ "flake8>=4.0.0", "mypy>=0.950", ] -# GPU acceleration with cuDF and cuML +# GPU acceleration — pick the extra matching your system CUDA version: +# pip install -e ".[gpu-cu12]" # CUDA 12.x (e.g. OSC Pitzer) +# pip install -e ".[gpu-cu13]" # CUDA 13.x +# "gpu" is an alias for gpu-cu12 (most common HPC setup). gpu = [ - # PyTorch for CUDA detection and some models + "emb-explorer[gpu-cu12]", +] +gpu-cu12 = [ "torch>=2.0.0", - # NVIDIA CUDA runtime libraries (required for cuDF/cuML) - "nvidia-cublas-cu12", - "nvidia-cuda-runtime-cu12", - "nvidia-cudnn-cu12", - "nvidia-cufft-cu12", - "nvidia-curand-cu12", - "nvidia-cusolver-cu12", - "nvidia-cusparse-cu12", - # Essential RAPIDS packages - "cudf-cu12==25.6.*", - "cuml-cu12==25.6.*", - # Fast GPU clustering + "cuml-cu12>=25.6", "faiss-gpu-cu12>=1.11.0", + "pynvml>=11.0.0", +] +gpu-cu13 = [ + "torch>=2.0.0", + "cuml-cu13>=25.12", + "faiss-gpu-cu12>=1.11.0", # no cu13 build on PyPI; cu12 works via CUDA backward compat + "pynvml>=11.0.0", ] -# Minimal GPU support (just PyTorch + FAISS GPU) +# Minimal GPU support (just PyTorch + FAISS GPU, no RAPIDS) gpu-minimal = [ "torch>=2.0.0", "faiss-gpu-cu12>=1.11.0", @@ -93,28 +94,25 @@ Repository = "https://github.com/Imageomics/emb-explorer" Issues = "https://github.com/Imageomics/emb-explorer/issues" [project.scripts] -emb-explorer = "app:main" -list-models = "utils.models:list_available_models" +emb-embed-explore = "apps.embed_explore.app:main" +emb-precalculated = "apps.precalculated.app:main" +list-models = 
"shared.utils.models:print_available_models" [tool.hatch.version] -path = "utils/__init__.py" +path = "shared/__init__.py" [tool.hatch.metadata] allow-direct-references = true [tool.hatch.build.targets.wheel] -packages = ["utils"] +packages = ["shared", "apps"] [tool.hatch.build.targets.sdist] include = [ - "/utils", - "/app.py", - "/list_models.py", - "/setup.sh", + "/shared", + "/apps", "/README.md", "/LICENSE", - "/requirements.txt", - "/data", ] [tool.black] @@ -147,3 +145,4 @@ python_files = ["test_*.py", "*_test.py"] python_classes = ["Test*"] python_functions = ["test_*"] addopts = "-v --tb=short" +markers = ["gpu: requires GPU hardware (deselect with '-m not gpu')"] diff --git a/services/__init__.py b/services/__init__.py deleted file mode 100644 index b309a2e..0000000 --- a/services/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Server-side business logic for the embedding explorer. -""" diff --git a/services/clustering_service.py b/services/clustering_service.py deleted file mode 100644 index 6bef2fb..0000000 --- a/services/clustering_service.py +++ /dev/null @@ -1,113 +0,0 @@ -""" -Clustering service. -""" - -import numpy as np -import pandas as pd -import os -from typing import Tuple, Dict, List, Any - -from utils.clustering import run_kmeans, reduce_dim - - -class ClusteringService: - """Service for handling clustering workflows""" - - @staticmethod - def run_clustering( - embeddings: np.ndarray, - valid_paths: List[str], - n_clusters: int, - reduction_method: str, - n_workers: int = 1, - dim_reduction_backend: str = "auto", - clustering_backend: str = "auto", - seed: int = None - ) -> Tuple[pd.DataFrame, np.ndarray]: - """ - Run clustering on embeddings. 
- - Args: - embeddings: Input embeddings - valid_paths: List of image paths - n_clusters: Number of clusters - reduction_method: Dimensionality reduction method - n_workers: Number of workers for reduction - dim_reduction_backend: Backend for dimensionality reduction ("auto", "sklearn", "faiss", "cuml") - clustering_backend: Backend for clustering ("auto", "sklearn", "faiss", "cuml") - seed: Random seed for reproducibility (None for random) - - Returns: - Tuple of (cluster dataframe, cluster labels) - """ - # Step 1: Perform K-means clustering on full high-dimensional embeddings - kmeans, labels = run_kmeans( - embeddings, # Use original high-dimensional embeddings for clustering - int(n_clusters), - seed=seed, - n_workers=n_workers, - backend=clustering_backend - ) - - # Step 2: Reduce dimensionality to 2D for visualization only - reduced = reduce_dim( - embeddings, - reduction_method, - seed=seed, - n_workers=n_workers, - backend=dim_reduction_backend - ) - - df_plot = pd.DataFrame({ - "x": reduced[:, 0], - "y": reduced[:, 1], - "cluster": labels.astype(str), - "image_path": valid_paths, - "file_name": [os.path.basename(p) for p in valid_paths], - "idx": range(len(valid_paths)) - }) - - return df_plot, labels - - @staticmethod - def generate_clustering_summary( - embeddings: np.ndarray, - labels: np.ndarray, - df_plot: pd.DataFrame - ) -> Tuple[pd.DataFrame, Dict[int, List[int]]]: - """ - Generate clustering summary statistics and representative images. 
- - Args: - embeddings: Original embeddings - labels: Cluster labels - df_plot: Clustering dataframe - - Returns: - Tuple of (summary dataframe, representatives dict) - """ - cluster_ids = np.unique(labels) - summary_data = [] - representatives = {} - - for k in cluster_ids: - idxs = np.where(labels == k)[0] - cluster_embeds = embeddings[idxs] - centroid = cluster_embeds.mean(axis=0) - - # Internal variance - variance = np.mean(np.sum((cluster_embeds - centroid) ** 2, axis=1)) - - # Find 3 closest images - dists = np.sum((cluster_embeds - centroid) ** 2, axis=1) - closest_indices = idxs[np.argsort(dists)[:3]] - representatives[k] = closest_indices - - summary_data.append({ - "Cluster": int(k), - "Count": len(idxs), - "Variance": round(variance, 3), - }) - - summary_df = pd.DataFrame(summary_data) - return summary_df, representatives diff --git a/services/parquet_service.py b/services/parquet_service.py deleted file mode 100644 index 11528da..0000000 --- a/services/parquet_service.py +++ /dev/null @@ -1,372 +0,0 @@ -""" -Service for handling parquet file operations with embeddings and metadata. -""" - -import pyarrow as pa -import pyarrow.parquet as pq -import pyarrow.compute as pc -import pandas as pd # Keep for DataFrame output compatibility -import numpy as np -import streamlit as st -from typing import Dict, List, Tuple, Optional, Any, Union -from pathlib import Path - - -class ParquetService: - """Service for handling parquet file operations with embeddings and metadata""" - - # Define the expected taxonomic columns based on your schema - TAXONOMIC_COLUMNS = [ - 'kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species' - ] - - METADATA_COLUMNS = [ - 'source_dataset', 'scientific_name', 'common_name', - 'publisher', 'basisOfRecord', 'img_type' - ] + TAXONOMIC_COLUMNS - - @staticmethod - def load_parquet_table(file_path: str) -> pa.Table: - """ - Load a parquet file as PyArrow Table (zero-copy, memory efficient). 
- - Args: - file_path: Path to the parquet file - - Returns: - PyArrow Table with the parquet data - """ - try: - return pq.read_table(file_path) - except Exception as e: - raise ValueError(f"Error loading parquet file: {e}") - - @staticmethod - def validate_parquet_structure(df: Union[pd.DataFrame, pa.Table]) -> Tuple[bool, List[str]]: - """ - Validate that the parquet file has the expected structure. - - Args: - df: DataFrame or PyArrow Table to validate - - Returns: - Tuple of (is_valid, list_of_issues) - """ - issues = [] - - if isinstance(df, pa.Table): - # PyArrow Table validation - column_names = df.column_names - - # Check for required columns - if 'uuid' not in column_names: - issues.append("Missing required 'uuid' column") - if 'emb' not in column_names: - issues.append("Missing required 'emb' column") - - # Check for null values in critical columns - if 'uuid' in column_names: - uuid_col = df.column('uuid') - null_count = pc.sum(pc.is_null(uuid_col)).as_py() - if null_count > 0: - issues.append("Found null values in 'uuid' column") - - if 'emb' in column_names: - emb_col = df.column('emb') - null_count = pc.sum(pc.is_null(emb_col)).as_py() - if null_count > 0: - issues.append("Found null values in 'emb' column") - - # Check embedding format - try: - # Try to get first non-null embedding to check format - first_emb = None - for i in range(min(len(emb_col), 100)): # Check first 100 rows - if emb_col[i].is_valid: - first_emb = emb_col[i].as_py() - break - - if first_emb is not None: - if not isinstance(first_emb, (list, tuple)): - issues.append("Embedding column 'emb' does not contain arrays") - elif len(first_emb) == 0: - issues.append("Empty embeddings found") - else: - issues.append("No valid embeddings found") - except Exception as e: - issues.append(f"Error parsing embeddings: {e}") - else: - # pandas DataFrame validation (fallback for compatibility) - df = df.to_pandas() if isinstance(df, pa.Table) else df - - # Check for required columns - if 'uuid' 
not in df.columns: - issues.append("Missing required 'uuid' column") - if 'emb' not in df.columns: - issues.append("Missing required 'emb' column") - - # Check for null values in critical columns - if 'uuid' in df.columns and df['uuid'].isnull().any(): - issues.append("Found null values in 'uuid' column") - if 'emb' in df.columns and df['emb'].isnull().any(): - issues.append("Found null values in 'emb' column") - - # Check embedding format - if 'emb' in df.columns: - try: - # Try to convert first embedding to check format - first_emb = df['emb'].iloc[0] - if not isinstance(first_emb, (list, np.ndarray)): - issues.append("Embedding column 'emb' does not contain arrays") - elif len(first_emb) == 0: - issues.append("Empty embeddings found") - except Exception as e: - issues.append(f"Error parsing embeddings: {e}") - - return len(issues) == 0, issues - - @staticmethod - def extract_embeddings(df: Union[pd.DataFrame, pa.Table]) -> np.ndarray: - """ - Extract embeddings from the DataFrame or PyArrow Table. - - Args: - df: DataFrame or PyArrow Table containing 'emb' column - - Returns: - numpy array of embeddings with shape (n_samples, embedding_dim) - """ - if isinstance(df, pa.Table): - if 'emb' not in df.column_names: - raise ValueError("Table does not contain 'emb' column") - - # Extract embeddings column as PyArrow array - emb_column = df.column('emb') - # Convert to numpy - PyArrow list arrays need special handling - embeddings = emb_column.to_pylist() - embeddings = np.array(embeddings) - else: - # pandas DataFrame fallback - if 'emb' not in df.columns: - raise ValueError("DataFrame does not contain 'emb' column") - embeddings = np.array(df['emb'].tolist()) - - if embeddings.ndim != 2: - raise ValueError(f"Embeddings should be 2D, got shape {embeddings.shape}") - - return embeddings - - @staticmethod - def get_column_info(df: Union[pd.DataFrame, pa.Table]) -> Dict[str, Dict[str, Any]]: - """ - Get information about each column for filtering purposes. 
- - Args: - df: DataFrame or PyArrow Table to analyze - - Returns: - Dictionary mapping column names to their info (type, unique_values, etc.) - """ - column_info = {} - - # Convert to PyArrow table if pandas DataFrame - if isinstance(df, pd.DataFrame): - df = pa.Table.from_pandas(df) - - # PyArrow Table processing - for col_name in df.column_names: - if col_name in ['uuid', 'emb']: # Skip technical columns - continue - - col_array = df.column(col_name) - - # Handle null values - non_null_mask = pc.is_valid(col_array) - non_null_count = pc.sum(non_null_mask).as_py() - total_count = len(col_array) - null_count = total_count - non_null_count - - if non_null_count == 0: - col_type = 'empty' - unique_values = [] - value_counts = {} - else: - # Check data type - arrow_type = col_array.type - - if (pa.types.is_integer(arrow_type) or - pa.types.is_floating(arrow_type) or - pa.types.is_decimal(arrow_type)): - col_type = 'numeric' - unique_values = None - value_counts = None - else: - # Get unique values for categorical determination - try: - unique_array = pc.unique(col_array) - unique_count = len(unique_array) - - if unique_count <= 50: # Categorical if <= 50 unique values - col_type = 'categorical' - unique_values = sorted([v.as_py() for v in unique_array if v.is_valid]) - - # Get value counts - value_counts_result = pc.value_counts(col_array) - value_counts = {} - for i in range(len(value_counts_result)): - struct = value_counts_result[i].as_py() - if struct['values'] is not None: - value_counts[struct['values']] = struct['counts'] - else: - col_type = 'text' - unique_values = None - value_counts = None - except: - col_type = 'text' - unique_values = None - value_counts = None - - column_info[col_name] = { - 'type': col_type, - 'unique_values': unique_values, - 'value_counts': value_counts, - 'null_count': null_count, - 'total_count': total_count, - 'null_percentage': (null_count / total_count) * 100 if total_count > 0 else 0 - } - - return column_info - - @staticmethod 
- def apply_filters_arrow(table: pa.Table, filters: Dict[str, Any]) -> pa.Table: - """ - Apply filters to PyArrow Table (more memory efficient). - - Args: - table: PyArrow Table to filter - filters: Dictionary of column_name -> filter_value pairs - - Returns: - Filtered PyArrow Table - """ - filter_expressions = [] - - for col, filter_value in filters.items(): - if col not in table.column_names or filter_value is None: - continue - - col_ref = pc.field(col) - - if isinstance(filter_value, dict): - # Numeric range filter - if 'min' in filter_value and filter_value['min'] is not None: - filter_expressions.append(pc.greater_equal(col_ref, filter_value['min'])) - if 'max' in filter_value and filter_value['max'] is not None: - filter_expressions.append(pc.less_equal(col_ref, filter_value['max'])) - elif isinstance(filter_value, list): - # Categorical filter (multiple values) - if len(filter_value) > 0: - filter_expressions.append(pc.is_in(col_ref, pa.array(filter_value))) - elif isinstance(filter_value, str): - # Text filter (contains) - if filter_value.strip(): - # PyArrow string matching (case insensitive) - pattern = f"*{filter_value.lower()}*" - filter_expressions.append( - pc.match_substring_regex( - pc.utf8_lower(col_ref), - pattern.replace("*", ".*") - ) - ) - - # Combine all filter expressions with AND - if filter_expressions: - if len(filter_expressions) == 1: - combined_filter = filter_expressions[0] - else: - # Combine filters using reduce pattern - from functools import reduce - try: - # Try pc.and_kleene first (newer PyArrow versions) - combined_filter = reduce(lambda a, b: pc.and_kleene(a, b), filter_expressions) - except AttributeError: - # Fallback for older PyArrow versions - apply filters sequentially - filtered_table = table - for expr in filter_expressions: - filtered_table = filtered_table.filter(expr) - return filtered_table - - return table.filter(combined_filter) - - return table - - @staticmethod - def create_cluster_dataframe( - df: 
pd.DataFrame, - embeddings_2d: np.ndarray, - labels: np.ndarray - ) -> pd.DataFrame: - """ - Create a dataframe for clustering visualization. - - Args: - df: Original dataframe with metadata - embeddings_2d: 2D reduced embeddings - labels: Cluster labels - - Returns: - DataFrame suitable for plotting - """ - df_plot = pd.DataFrame({ - "x": embeddings_2d[:, 0], - "y": embeddings_2d[:, 1], - "cluster": labels.astype(str), - "uuid": df['uuid'].values, - "idx": range(len(df)) - }) - - # Add key metadata columns for tooltips - metadata_cols = ['scientific_name', 'common_name', 'family', 'genus', 'species'] - for col in metadata_cols: - if col in df.columns: - df_plot[col] = df[col].values - - return df_plot - - @staticmethod - def load_and_filter_efficient( - file_path: str, - filters: Optional[Dict[str, Any]] = None, - columns: Optional[List[str]] = None - ) -> Tuple[pa.Table, pd.DataFrame]: - """ - Load parquet file efficiently with PyArrow and apply filters. - Returns both PyArrow table (for efficient operations) and pandas DataFrame (for compatibility). 
- - Args: - file_path: Path to parquet file - filters: Optional filters to apply - columns: Optional list of columns to select - - Returns: - Tuple of (PyArrow Table, pandas DataFrame) - """ - # Load as PyArrow table - table = ParquetService.load_parquet_table(file_path) - - # Apply column selection if specified - if columns: - # Ensure required columns are included - required_cols = ['uuid', 'emb'] - all_columns = list(set(columns + required_cols)) - available_columns = [col for col in all_columns if col in table.column_names] - table = table.select(available_columns) - - # Apply filters efficiently with PyArrow - if filters: - table = ParquetService.apply_filters_arrow(table, filters) - - # Convert to pandas for compatibility (only the filtered data) - df = table.to_pandas() - - return table, df diff --git a/shared/__init__.py b/shared/__init__.py new file mode 100644 index 0000000..ae13b8e --- /dev/null +++ b/shared/__init__.py @@ -0,0 +1,5 @@ +""" +Shared utilities and services for the emb-explorer applications. +""" + +__version__ = "0.1.0" diff --git a/shared/components/__init__.py b/shared/components/__init__.py new file mode 100644 index 0000000..bc1e04b --- /dev/null +++ b/shared/components/__init__.py @@ -0,0 +1,8 @@ +""" +Shared UI components. 
+ +Import directly from submodules: + + from shared.components.clustering_controls import render_clustering_backend_controls + from shared.components.visualization import render_scatter_plot +""" diff --git a/components/shared/clustering_controls.py b/shared/components/clustering_controls.py similarity index 80% rename from components/shared/clustering_controls.py rename to shared/components/clustering_controls.py index 0aba28a..df971ae 100644 --- a/components/shared/clustering_controls.py +++ b/shared/components/clustering_controls.py @@ -5,41 +5,26 @@ import streamlit as st from typing import Tuple, Optional +from shared.utils.backend import HAS_FAISS_PACKAGE, HAS_CUML_PACKAGE, HAS_CUPY_PACKAGE + def render_clustering_backend_controls(): """ Render clustering backend selection controls. - + Returns: Tuple of (dim_reduction_backend, clustering_backend, n_workers, seed) """ - # Backend availability detection + # Backend availability detection — uses find_spec() flags (instant, no heavy imports) dim_reduction_options = ["auto", "sklearn"] clustering_options = ["auto", "sklearn"] - - has_faiss = False - has_cuml = False - has_cuda = False - - # Check for FAISS (clustering only) - try: - import faiss - has_faiss = True + + if HAS_FAISS_PACKAGE: clustering_options.append("faiss") - except ImportError: - pass - - # Check for cuML + CUDA (both dim reduction and clustering) - try: - import cuml - import cupy as cp - has_cuml = True - if cp.cuda.is_available(): - has_cuda = True - dim_reduction_options.append("cuml") - clustering_options.append("cuml") - except ImportError: - pass + + if HAS_CUML_PACKAGE and HAS_CUPY_PACKAGE: + dim_reduction_options.append("cuml") + clustering_options.append("cuml") # Show backend status use_seed = st.checkbox( diff --git a/components/clustering/summary.py b/shared/components/summary.py similarity index 88% rename from components/clustering/summary.py rename to shared/components/summary.py index fe63d64..36c4406 100644 --- 
a/components/clustering/summary.py +++ b/shared/components/summary.py @@ -1,12 +1,14 @@ """ -Clustering summary components. +Shared clustering summary components. """ import streamlit as st import os import pandas as pd -from services.clustering_service import ClusteringService -from utils.taxonomy_tree import build_taxonomic_tree, format_tree_string, get_tree_statistics +from shared.utils.taxonomy_tree import build_taxonomic_tree, format_tree_string, get_tree_statistics +from shared.utils.logging_config import get_logger + +logger = get_logger(__name__) def render_taxonomic_tree_summary(): @@ -14,13 +16,13 @@ def render_taxonomic_tree_summary(): df_plot = st.session_state.get("data", None) labels = st.session_state.get("labels", None) filtered_df = st.session_state.get("filtered_df_for_clustering", None) - + if df_plot is not None and labels is not None and filtered_df is not None: - st.markdown("### 🌳 Taxonomic Distribution") - + st.markdown("### Taxonomic Distribution") + # Add controls at the top of the taxonomy section col1, col2, col3 = st.columns([2, 1, 1]) - + with col1: # Get available clusters cluster_options = ["All"] @@ -28,7 +30,7 @@ def render_taxonomic_tree_summary(): # Check if we have taxonomic clustering with cluster names taxonomic_info = st.session_state.get("taxonomic_clustering", {}) is_taxonomic = taxonomic_info.get('is_taxonomic', False) - + if is_taxonomic and 'cluster_name' in df_plot.columns: # Use taxonomic names for display unique_cluster_names = sorted(df_plot["cluster_name"].unique()) @@ -37,7 +39,7 @@ def render_taxonomic_tree_summary(): # Standard numeric clustering unique_clusters = sorted(df_plot["cluster"].unique(), key=lambda x: int(x)) cluster_options.extend([f"Cluster {c}" for c in unique_clusters]) - + selected_cluster = st.selectbox( "Display taxonomy for:", options=cluster_options, @@ -45,7 +47,7 @@ def render_taxonomic_tree_summary(): key="taxonomy_cluster_selector", help="Select a specific cluster to show its taxonomy 
tree, or 'All' to show the entire dataset" ) - + with col2: min_count = st.number_input( "Minimum count", @@ -56,7 +58,7 @@ def render_taxonomic_tree_summary(): key="taxonomy_min_count", help="Minimum number of records for a taxon to appear in the tree" ) - + with col3: tree_depth = st.slider( "Tree depth", @@ -66,7 +68,7 @@ def render_taxonomic_tree_summary(): key="taxonomy_tree_depth", help="Maximum depth of the taxonomy tree to display" ) - + # Create a stable cache key based on the data characteristics and filter parameters # Use data length and a sample of UUIDs for a stable data identifier data_length = len(filtered_df) @@ -74,21 +76,21 @@ def render_taxonomic_tree_summary(): sample_uuids = filtered_df['uuid'].iloc[:min(10, len(filtered_df))].tolist() data_id = f"{data_length}_{len(sample_uuids)}_{sample_uuids[0] if sample_uuids else 'empty'}" cache_key = f"taxonomy_{data_id}_{selected_cluster}_{min_count}_{tree_depth}" - + # Check if we have cached results and they're still valid # Also ensure critical session state data hasn't changed unexpectedly current_cache_key = st.session_state.get("taxonomy_cache_key") cache_exists = cache_key in st.session_state - + if (not cache_exists or current_cache_key != cache_key): - + # Data or parameters changed, regenerate taxonomy tree with st.spinner("Building taxonomy tree..."): # Filter data based on selected cluster if selected_cluster != "All": taxonomic_info = st.session_state.get("taxonomic_clustering", {}) is_taxonomic = taxonomic_info.get('is_taxonomic', False) - + if is_taxonomic and 'cluster_name' in df_plot.columns: # For taxonomic clustering, filter by cluster_name cluster_mask = df_plot['cluster_name'] == selected_cluster @@ -111,12 +113,12 @@ def render_taxonomic_tree_summary(): else: tree_df = filtered_df display_title = "Taxonomic Tree for All Clusters" - + # Build taxonomic tree for the selected data (only when needed) tree = build_taxonomic_tree(tree_df) stats = get_tree_statistics(tree) tree_string = 
format_tree_string(tree, max_depth=tree_depth, min_count=min_count) - + # Cache the results st.session_state[cache_key] = { 'tree': tree, @@ -125,10 +127,10 @@ def render_taxonomic_tree_summary(): 'display_title': display_title } st.session_state["taxonomy_cache_key"] = cache_key - + # Use cached results (no regeneration) cached_data = st.session_state[cache_key] - + # Show statistics st.markdown(f"**{cached_data['display_title']}**") col1, col2, col3, col4 = st.columns(4) @@ -140,7 +142,7 @@ def render_taxonomic_tree_summary(): st.metric("Families", cached_data['stats']['families']) with col4: st.metric("Species", cached_data['stats']['species']) - + # Display the tree if cached_data['tree_string']: st.code(cached_data['tree_string'], language="text") @@ -149,24 +151,24 @@ def render_taxonomic_tree_summary(): def render_clustering_summary(show_taxonomy=False): - """Render the clustering summary panel.""" + """Render the clustering summary panel using cached results from clustering action.""" df_plot = st.session_state.get("data", None) labels = st.session_state.get("labels", None) - embeddings = st.session_state.get("embeddings", None) - if df_plot is not None and labels is not None and embeddings is not None: + # Get pre-computed summary from session state (computed when clustering was run) + summary_df = st.session_state.get("clustering_summary", None) + representatives = st.session_state.get("clustering_representatives", None) + + if df_plot is not None and labels is not None: # Check if this is image data or metadata-only data has_images = 'image_path' in df_plot.columns - + if has_images: # For image data, show the full clustering summary st.subheader("Clustering Summary") - - try: - summary_df, representatives = ClusteringService.generate_clustering_summary( - embeddings, labels, df_plot - ) - + + if summary_df is not None and representatives is not None: + logger.debug("Displaying cached clustering summary") st.dataframe(summary_df, hide_index=True, 
width='stretch') st.markdown("#### Representative Images") @@ -176,21 +178,21 @@ def render_clustering_summary(show_taxonomy=False): img_cols = st.columns(3) for i, img_idx in enumerate(representatives[k]): img_path = df_plot.iloc[img_idx]["image_path"] + logger.debug(f"Displaying representative image: {img_path}") img_cols[i].image( - img_path, - width='stretch', + img_path, + width='stretch', caption=os.path.basename(img_path) ) - - except Exception as e: - st.error(f"Error generating clustering summary: {e}") + else: + st.info("Clustering summary will be computed when you run clustering.") else: # For metadata-only data (precalculated embeddings), show taxonomic tree if requested if show_taxonomy: filtered_df = st.session_state.get("filtered_df_for_clustering", None) - + if filtered_df is not None: render_taxonomic_tree_summary() - + else: st.info("Clustering summary will appear here after clustering.") diff --git a/shared/components/visualization.py b/shared/components/visualization.py new file mode 100644 index 0000000..db9b243 --- /dev/null +++ b/shared/components/visualization.py @@ -0,0 +1,181 @@ +""" +Shared visualization components for scatter plots. +""" + +import streamlit as st +import altair as alt + +from shared.utils.logging_config import get_logger + +logger = get_logger(__name__) + + +def render_scatter_plot(): + """Render the main clustering scatter plot with dynamic tooltips. + + The chart is rendered inside a @st.fragment so that zoom/pan interactions + only rerun the chart itself — the rest of the page (data preview, summary) + stays untouched. A full page rerun is triggered explicitly only when the + user clicks a *different* point. 
+ """ + df_plot = st.session_state.get("data", None) + + if df_plot is not None and len(df_plot) > 1: + _render_chart_fragment(df_plot) + else: + st.info("Run clustering to see the cluster scatter plot.") + st.session_state['selected_image_idx'] = None + + +@st.fragment +def _render_chart_fragment(df_plot): + """Fragment-isolated chart rendering — zoom/pan do NOT rerun the page.""" + # Track previous density mode to detect changes + prev_density_mode = st.session_state.get("_prev_density_mode", None) + + # Plot options in columns for compact layout + opt_col1, opt_col2 = st.columns([2, 1]) + + with opt_col1: + density_mode = st.radio( + "Density visualization", + options=["Off", "Opacity", "Heatmap"], + index=0, + horizontal=True, + key="density_mode", + help="Off: normal view | Opacity: lower opacity to show overlap | Heatmap: 2D binned density (disables selection)" + ) + + # Log density mode change + if prev_density_mode != density_mode: + logger.info(f"[Visualization] Density mode changed: {prev_density_mode} -> {density_mode}") + st.session_state["_prev_density_mode"] = density_mode + + with opt_col2: + if density_mode == "Heatmap": + prev_bins = st.session_state.get("_prev_heatmap_bins", 40) + heatmap_bins = st.slider( + "Grid resolution", + min_value=10, + max_value=80, + value=40, + step=5, + key="heatmap_bins", + help="Number of bins for density grid (higher = finer detail)" + ) + if prev_bins != heatmap_bins: + logger.info(f"[Visualization] Heatmap bins changed: {prev_bins} -> {heatmap_bins}") + st.session_state["_prev_heatmap_bins"] = heatmap_bins + else: + heatmap_bins = 40 # Default, not used + + point_selector = alt.selection_point(fields=["idx"], name="point_selection") + + # Determine tooltip fields based on available columns + tooltip_fields = [] + + # Use cluster_name for display if available (taxonomic clustering), otherwise use cluster + if 'cluster_name' in df_plot.columns: + tooltip_fields.append('cluster_name:N') + cluster_legend_title = 
"Cluster" + else: + tooltip_fields.append('cluster:N') + cluster_legend_title = "Cluster" + + # Add other metadata columns dynamically (limit to prevent tooltip overflow) + skip_cols = {'x', 'y', 'cluster', 'cluster_name', 'idx', 'emb', 'embedding', 'embeddings', 'vector'} + metadata_cols = [c for c in df_plot.columns if c not in skip_cols][:8] + tooltip_fields.extend(metadata_cols) + + # Determine title based on data type + if 'uuid' in df_plot.columns: + title = "Embedding Clusters (click a point to view details)" + else: + title = "Image Clusters (click a point to preview image)" + + # Set opacity based on density mode + if density_mode == "Opacity": + point_opacity = 0.15 # Low opacity so overlaps show density + elif density_mode == "Heatmap": + point_opacity = 0.5 # Medium opacity when heatmap is behind + else: + point_opacity = 0.7 # Normal opacity + + # Create scatter plot + scatter = ( + alt.Chart(df_plot) + .mark_circle(size=60, opacity=point_opacity) + .encode( + x=alt.X('x:Q', scale=alt.Scale(zero=False)), + y=alt.Y('y:Q', scale=alt.Scale(zero=False)), + color=alt.Color('cluster:N', legend=alt.Legend(title=cluster_legend_title)), + tooltip=tooltip_fields, + fillOpacity=alt.condition(point_selector, alt.value(1), alt.value(0.3)) + ) + .add_params(point_selector) + ) + + if density_mode == "Heatmap": + # Create 2D density heatmap layer with configurable bins + density = ( + alt.Chart(df_plot) + .mark_rect(opacity=0.4) + .encode( + x=alt.X('x:Q', bin=alt.Bin(maxbins=heatmap_bins), scale=alt.Scale(zero=False)), + y=alt.Y('y:Q', bin=alt.Bin(maxbins=heatmap_bins), scale=alt.Scale(zero=False)), + color=alt.Color( + 'count():Q', + scale=alt.Scale(scheme='blues'), + legend=None + ) + ) + ) + # Layer density behind scatter + chart = alt.layer(density, scatter) + else: + chart = scatter + + # Apply common properties and interactivity + title_suffix = " (scroll to zoom, drag to pan)" + if density_mode != "Heatmap": + title_suffix += ", click to select" + + chart = ( 
+ chart + .properties( + width=800, + height=700, + title=title + title_suffix + ) + .interactive() # Enable zoom/pan + ) + + # Log chart render only at DEBUG to avoid noise from zoom/pan reruns + logger.debug(f"[Visualization] Rendering chart: {len(df_plot)} points, density={density_mode}, " + f"bins={heatmap_bins if density_mode == 'Heatmap' else 'N/A'}") + + # Streamlit doesn't support selections on layered charts, so only enable + # selection when not using heatmap mode + if density_mode == "Heatmap": + st.altair_chart(chart, key="alt_chart", width="stretch") + st.caption("Note: Point selection is disabled when heatmap is shown.") + else: + event = st.altair_chart(chart, key="alt_chart", on_select="rerun", width="stretch") + + # Handle point selection — only trigger full page rerun when + # the selected point actually changes (zoom/pan stay fragment-local) + if ( + event + and "selection" in event + and "point_selection" in event["selection"] + and event["selection"]["point_selection"] + ): + new_idx = int(event["selection"]["point_selection"][0]["idx"]) + prev_idx = st.session_state.get("selected_image_idx") + if prev_idx != new_idx: + cluster = df_plot.iloc[new_idx]['cluster'] if 'cluster' in df_plot.columns else '?' + logger.info(f"[Visualization] Point selected: idx={new_idx}, cluster={cluster}") + st.session_state["selected_image_idx"] = new_idx + st.session_state["selection_data_version"] = st.session_state.get("data_version", None) + # Trigger full page rerun so the preview panel updates + st.rerun(scope="app") diff --git a/shared/lib/__init__.py b/shared/lib/__init__.py new file mode 100644 index 0000000..6289b2d --- /dev/null +++ b/shared/lib/__init__.py @@ -0,0 +1,7 @@ +""" +Shared library utilities. 
+ +Import directly from submodules: + + from shared.lib.progress import StreamlitProgressContext +""" diff --git a/lib/progress.py b/shared/lib/progress.py similarity index 100% rename from lib/progress.py rename to shared/lib/progress.py diff --git a/shared/services/__init__.py b/shared/services/__init__.py new file mode 100644 index 0000000..1da3566 --- /dev/null +++ b/shared/services/__init__.py @@ -0,0 +1,8 @@ +""" +Shared services for embedding, clustering, and file operations. + +Import directly from submodules to avoid pulling in heavy dependencies: + + from shared.services.clustering_service import ClusteringService + from shared.services.embedding_service import EmbeddingService +""" diff --git a/shared/services/clustering_service.py b/shared/services/clustering_service.py new file mode 100644 index 0000000..26ce93d --- /dev/null +++ b/shared/services/clustering_service.py @@ -0,0 +1,190 @@ +""" +Clustering service. +""" + +import numpy as np +import pandas as pd +import os +import time +from typing import Tuple, Dict, List, Optional + +from shared.utils.clustering import run_kmeans, reduce_dim +from shared.utils.backend import is_oom_error, is_cuda_arch_error, is_gpu_error +from shared.utils.logging_config import get_logger + +logger = get_logger(__name__) + + +class ClusteringService: + """Service for handling clustering workflows""" + + @staticmethod + def run_clustering( + embeddings: np.ndarray, + valid_paths: List[str], + n_clusters: int, + reduction_method: str, + n_workers: int = 1, + dim_reduction_backend: str = "auto", + clustering_backend: str = "auto", + seed: Optional[int] = None + ) -> Tuple[pd.DataFrame, np.ndarray]: + """ + Run clustering on embeddings. 
+ + Args: + embeddings: Input embeddings + valid_paths: List of image paths + n_clusters: Number of clusters + reduction_method: Dimensionality reduction method + n_workers: Number of workers for reduction + dim_reduction_backend: Backend for dimensionality reduction ("auto", "sklearn", "faiss", "cuml") + clustering_backend: Backend for clustering ("auto", "sklearn", "faiss", "cuml") + seed: Random seed for reproducibility (None for random) + + Returns: + Tuple of (cluster dataframe, cluster labels) + """ + n_samples, n_features = embeddings.shape + logger.info(f"Starting clustering workflow: samples={n_samples}, features={n_features}, " + f"n_clusters={n_clusters}, reduction={reduction_method}, " + f"dim_backend={dim_reduction_backend}, cluster_backend={clustering_backend}, " + f"seed={seed}") + + total_start = time.time() + + # Step 1: Perform K-means clustering on full high-dimensional embeddings + # (embeddings are L2-normalized inside run_kmeans) + logger.info("Step 1/2: Running KMeans clustering on high-dimensional embeddings") + kmeans, labels = run_kmeans( + embeddings, + int(n_clusters), + seed=seed, + n_workers=n_workers, + backend=clustering_backend + ) + logger.info(f"Step 1/2 complete: {len(np.unique(labels))} clusters assigned") + + # Step 2: Reduce dimensionality to 2D for visualization only + # (embeddings are L2-normalized inside reduce_dim) + logger.info("Step 2/2: Reducing dimensionality to 2D for visualization") + reduced = reduce_dim( + embeddings, + reduction_method, + seed=seed, + n_workers=n_workers, + backend=dim_reduction_backend + ) + logger.info(f"Step 2/2 complete: reduced to shape {reduced.shape}") + + df_plot = pd.DataFrame({ + "x": reduced[:, 0], + "y": reduced[:, 1], + "cluster": labels.astype(str), + "image_path": valid_paths, + "file_name": [os.path.basename(p) for p in valid_paths], + "idx": range(len(valid_paths)) + }) + + total_elapsed = time.time() - total_start + logger.info(f"Clustering workflow completed in 
{total_elapsed:.2f}s") + + return df_plot, labels + + @staticmethod + def generate_clustering_summary( + embeddings: np.ndarray, + labels: np.ndarray, + df_plot: pd.DataFrame + ) -> Tuple[pd.DataFrame, Dict[int, List[int]]]: + """ + Generate clustering summary statistics and representative images. + + Args: + embeddings: Original embeddings + labels: Cluster labels + df_plot: Clustering dataframe + + Returns: + Tuple of (summary dataframe, representatives dict) + """ + logger.info("Generating clustering summary statistics") + cluster_ids = np.unique(labels) + logger.debug(f"Found {len(cluster_ids)} unique clusters") + summary_data = [] + representatives = {} + + for k in cluster_ids: + idxs = np.where(labels == k)[0] + cluster_embeds = embeddings[idxs] + centroid = cluster_embeds.mean(axis=0) + + # Internal variance + variance = np.mean(np.sum((cluster_embeds - centroid) ** 2, axis=1)) + + # Find 3 closest images + dists = np.sum((cluster_embeds - centroid) ** 2, axis=1) + closest_indices = idxs[np.argsort(dists)[:3]] + representatives[k] = closest_indices + + summary_data.append({ + "Cluster": int(k), + "Count": len(idxs), + "Variance": round(variance, 3), + }) + + summary_df = pd.DataFrame(summary_data) + return summary_df, representatives + + @staticmethod + def run_clustering_safe( + embeddings: np.ndarray, + valid_paths: List[str], + n_clusters: int, + reduction_method: str, + n_workers: int = 1, + dim_reduction_backend: str = "auto", + clustering_backend: str = "auto", + seed: Optional[int] = None + ) -> Tuple[pd.DataFrame, np.ndarray]: + """ + Run clustering with automatic GPU-to-CPU fallback on errors. + + Handles CUDA architecture mismatches, missing NVRTC libraries, and + other GPU errors by transparently retrying with sklearn backends. + + GPU OOM and system MemoryError are re-raised for the caller to + present appropriate UI. 
+ + Args: + embeddings: Input embeddings + valid_paths: List of identifiers (image paths or UUIDs) + n_clusters: Number of clusters + reduction_method: Dimensionality reduction method + n_workers: Number of workers for reduction + dim_reduction_backend: Backend for dimensionality reduction + clustering_backend: Backend for clustering + seed: Random seed for reproducibility + + Returns: + Tuple of (cluster dataframe, cluster labels) + + Raises: + MemoryError: System out of memory (unrecoverable) + RuntimeError: GPU OOM (caller should show user guidance) + """ + try: + return ClusteringService.run_clustering( + embeddings, valid_paths, n_clusters, reduction_method, + n_workers, dim_reduction_backend, clustering_backend, seed + ) + except (RuntimeError, OSError) as e: + if is_oom_error(e): + raise + if is_cuda_arch_error(e) or is_gpu_error(e): + logger.warning(f"GPU error ({e}), falling back to sklearn backends") + return ClusteringService.run_clustering( + embeddings, valid_paths, n_clusters, reduction_method, + n_workers, "sklearn", "sklearn", seed + ) + raise diff --git a/services/embedding_service.py b/shared/services/embedding_service.py similarity index 67% rename from services/embedding_service.py rename to shared/services/embedding_service.py index 987908e..3b82e28 100644 --- a/services/embedding_service.py +++ b/shared/services/embedding_service.py @@ -1,41 +1,45 @@ """ Embedding generation service. + +Heavy libraries (torch, open_clip) are imported lazily inside methods +to avoid slowing down app startup. 
""" -import torch import numpy as np -import open_clip import streamlit as st +import time from typing import Tuple, List, Optional, Callable -from utils.io import list_image_files -from utils.models import list_available_models -from hpc_inference.datasets.image_folder_dataset import ImageFolderDataset +from shared.utils.io import list_image_files +from shared.utils.models import list_available_models +from shared.utils.logging_config import get_logger + +logger = get_logger(__name__) class EmbeddingService: """Service for handling embedding generation workflows""" - + @staticmethod @st.cache_data def get_model_options() -> List[str]: """Get formatted model options for selectbox.""" models_data = list_available_models() options = [] - + # Add all models from list for model in models_data: name = model['name'] pretrained = model['pretrained'] - + if pretrained is None or pretrained == "": display_name = name else: display_name = f"{name} ({pretrained})" options.append(display_name) - + return options - + @staticmethod def parse_model_selection(selected_model: str) -> Tuple[str, Optional[str]]: """Parse the selected model string to extract model name and pretrained.""" @@ -46,22 +50,30 @@ def parse_model_selection(selected_model: str) -> Tuple[str, Optional[str]]: return name, pretrained else: return selected_model, None - + @staticmethod @st.cache_resource(show_spinner=True) def load_model_unified(selected_model: str, device: str = "cuda"): """Unified model loading function that handles all model types.""" + import torch + import open_clip + model_name, pretrained = EmbeddingService.parse_model_selection(selected_model) - + + logger.info(f"Loading model: {model_name} (pretrained={pretrained}) on device={device}") + start_time = time.time() + model, _, preprocess = open_clip.create_model_and_transforms( model_name, pretrained=pretrained, device=device ) - + model = torch.compile(model.to(device)) + + elapsed = time.time() - start_time + logger.info(f"Model loaded in 
{elapsed:.2f}s") return model, preprocess - + @staticmethod - @torch.no_grad() def generate_embeddings( image_dir: str, model_name: str, @@ -71,37 +83,46 @@ def generate_embeddings( ) -> Tuple[np.ndarray, List[str]]: """ Generate embeddings for images in a directory. - + Args: image_dir: Path to directory containing images model_name: Name of the model to use batch_size: Batch size for processing n_workers: Number of worker processes progress_callback: Optional callback for progress updates - + Returns: Tuple of (embeddings array, list of valid image paths) """ + import torch + from hpc_inference.datasets.image_folder_dataset import ImageFolderDataset + + logger.info(f"Starting embedding generation: dir={image_dir}, model={model_name}, " + f"batch_size={batch_size}, n_workers={n_workers}") + total_start = time.time() + if progress_callback: progress_callback(0.0, "Listing images...") - + image_paths = list_image_files(image_dir) - + logger.info(f"Found {len(image_paths)} images in {image_dir}") + if progress_callback: progress_callback(0.1, f"Found {len(image_paths)} images. 
Loading model...") - + torch_device = "cuda" if torch.cuda.is_available() else "cpu" + logger.info(f"Using device: {torch_device}") model, preprocess = EmbeddingService.load_model_unified(model_name, torch_device) - + if progress_callback: progress_callback(0.2, "Creating dataset...") - + # Create dataset & DataLoader dataset = ImageFolderDataset( image_dir=image_dir, preprocess=preprocess, uuid_mode="fullpath", - rank=0, + rank=0, world_size=1, evenly_distribute=True, validate=True @@ -117,18 +138,19 @@ def generate_embeddings( total = len(image_paths) valid_paths = [] embeddings = [] - + processed = 0 - for batch_paths, batch_imgs in dataloader: - batch_imgs = batch_imgs.to(torch_device, non_blocking=True) - batch_embeds = model.encode_image(batch_imgs).cpu().numpy() - embeddings.append(batch_embeds) - valid_paths.extend(batch_paths) - processed += len(batch_paths) - - if progress_callback: - progress = 0.2 + (processed / total) * 0.8 # Use 20% to 100% for actual processing - progress_callback(progress, f"Embedding {processed}/{total}") + with torch.no_grad(): + for batch_paths, batch_imgs in dataloader: + batch_imgs = batch_imgs.to(torch_device, non_blocking=True) + batch_embeds = model.encode_image(batch_imgs).cpu().numpy() + embeddings.append(batch_embeds) + valid_paths.extend(batch_paths) + processed += len(batch_paths) + + if progress_callback: + progress = 0.2 + (processed / total) * 0.8 # Use 20% to 100% for actual processing + progress_callback(progress, f"Embedding {processed}/{total}") # Stack embeddings if available if embeddings: @@ -139,4 +161,8 @@ def generate_embeddings( if progress_callback: progress_callback(1.0, f"Complete! 
Generated {embeddings.shape[0]} embeddings") + total_elapsed = time.time() - total_start + logger.info(f"Embedding generation completed: {embeddings.shape[0]} embeddings in {total_elapsed:.2f}s " + f"({embeddings.shape[0] / total_elapsed:.1f} images/sec)") + return embeddings, valid_paths diff --git a/services/file_service.py b/shared/services/file_service.py similarity index 85% rename from services/file_service.py rename to shared/services/file_service.py index d037519..2c01cbd 100644 --- a/services/file_service.py +++ b/shared/services/file_service.py @@ -3,16 +3,20 @@ """ import os +import time import pandas as pd import concurrent.futures from typing import List, Dict, Any, Optional, Callable, Tuple -from utils.io import copy_image +from shared.utils.io import copy_image +from shared.utils.logging_config import get_logger + +logger = get_logger(__name__) class FileService: """Service for handling file operations like saving and repartitioning""" - + @staticmethod def save_cluster_images( cluster_rows: pd.DataFrame, @@ -22,46 +26,52 @@ def save_cluster_images( ) -> Tuple[pd.DataFrame, str]: """ Save images from selected clusters. 
- + Args: cluster_rows: DataFrame containing cluster data to save save_dir: Directory to save images max_workers: Number of worker threads progress_callback: Optional callback for progress updates - + Returns: Tuple of (summary dataframe, csv path) """ + logger.info(f"Saving {len(cluster_rows)} cluster images to {save_dir}") + start_time = time.time() + os.makedirs(save_dir, exist_ok=True) save_rows = [] - + if progress_callback: progress_callback(0.0, "Copying images...") - + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [ executor.submit(copy_image, row, save_dir) for idx, row in cluster_rows.iterrows() ] total_files = len(futures) - + for i, future in enumerate(concurrent.futures.as_completed(futures), 1): result = future.result() if result is not None: save_rows.append(result) - + # Progress callback with same logic as before if i % 50 == 0 or i == total_files: if progress_callback: progress = i / total_files progress_callback(progress, f"Copied {i} / {total_files} images") - + save_summary_df = pd.DataFrame(save_rows) csv_path = os.path.join(save_dir, "saved_cluster_summary.csv") save_summary_df.to_csv(csv_path, index=False) - + + elapsed = time.time() - start_time + logger.info(f"Saved {len(save_rows)} images in {elapsed:.2f}s") + return save_summary_df, csv_path - + @staticmethod def repartition_images_by_cluster( df_plot: pd.DataFrame, @@ -71,41 +81,47 @@ def repartition_images_by_cluster( ) -> Tuple[pd.DataFrame, str]: """ Repartition all images by cluster. 
- + Args: df_plot: DataFrame containing all cluster data repartition_dir: Directory to repartition images max_workers: Number of worker threads progress_callback: Optional callback for progress updates - + Returns: Tuple of (summary dataframe, csv path) """ + logger.info(f"Repartitioning {len(df_plot)} images by cluster to {repartition_dir}") + start_time = time.time() + os.makedirs(repartition_dir, exist_ok=True) repartition_rows = [] - + if progress_callback: progress_callback(0.0, "Starting repartitioning...") - + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [ executor.submit(copy_image, row, repartition_dir) for idx, row in df_plot.iterrows() ] total_files = len(futures) - + for i, future in enumerate(concurrent.futures.as_completed(futures), 1): result = future.result() if result is not None: repartition_rows.append(result) - + if i % 100 == 0 or i == total_files: if progress_callback: progress = i / total_files progress_callback(progress, f"Repartitioned {i} / {total_files} images") - + repartition_summary_df = pd.DataFrame(repartition_rows) csv_path = os.path.join(repartition_dir, "cluster_summary.csv") repartition_summary_df.to_csv(csv_path, index=False) - + + elapsed = time.time() - start_time + logger.info(f"Repartitioned {len(repartition_rows)} images in {elapsed:.2f}s") + return repartition_summary_df, csv_path diff --git a/shared/utils/__init__.py b/shared/utils/__init__.py new file mode 100644 index 0000000..b305aa7 --- /dev/null +++ b/shared/utils/__init__.py @@ -0,0 +1,10 @@ +""" +Shared utilities for clustering, IO, models, and taxonomy. + +Modules are imported lazily to avoid pulling in heavy dependencies +(sklearn, umap, faiss, cuml, torch, open_clip) at startup. 
+Use direct imports instead: + + from shared.utils.clustering import reduce_dim, run_kmeans + from shared.utils.io import list_image_files +""" diff --git a/shared/utils/backend.py b/shared/utils/backend.py new file mode 100644 index 0000000..ed66cad --- /dev/null +++ b/shared/utils/backend.py @@ -0,0 +1,194 @@ +""" +Backend detection and resolution utilities. + +Provides consistent backend selection and CUDA availability checking +across all applications. + +Availability checks use importlib.find_spec() for instant package detection +without importing heavy libraries. Actual imports happen lazily when the +backend is first used. +""" + +import importlib.util +from typing import Tuple, Optional +from shared.utils.logging_config import get_logger + +logger = get_logger(__name__) + +# --- Lightweight availability checks (find_spec, no actual import) ---------- + +# These are safe to call at module-load / render time — they only check +# whether the package is installed, without executing it. + +HAS_FAISS_PACKAGE: bool = importlib.util.find_spec("faiss") is not None +HAS_CUML_PACKAGE: bool = importlib.util.find_spec("cuml") is not None +HAS_CUPY_PACKAGE: bool = importlib.util.find_spec("cupy") is not None +HAS_TORCH_PACKAGE: bool = importlib.util.find_spec("torch") is not None + +# --- Cached runtime checks (perform actual import, cached after first call) - + +# Cache CUDA availability to avoid repeated checks +_cuda_check_cache: Optional[Tuple[bool, str]] = None + + +def check_cuda_available() -> Tuple[bool, str]: + """ + Check if CUDA is available for GPU-accelerated backends. 
+ + Returns: + Tuple of (is_available, device_info_string) + """ + global _cuda_check_cache + + if _cuda_check_cache is not None: + return _cuda_check_cache + + # Try PyTorch first + if HAS_TORCH_PACKAGE: + try: + import torch + if torch.cuda.is_available(): + device_name = torch.cuda.get_device_name(0) + _cuda_check_cache = (True, device_name) + logger.info(f"CUDA available via PyTorch: {device_name}") + return _cuda_check_cache + except ImportError: + pass # PyTorch not installed, try CuPy next + + # Try CuPy + if HAS_CUPY_PACKAGE: + try: + import cupy as cp + if cp.cuda.is_available(): + device = cp.cuda.Device(0) + device_info = f"GPU {device.id}" + _cuda_check_cache = (True, device_info) + logger.info(f"CUDA available via CuPy: {device_info}") + return _cuda_check_cache + except ImportError: + pass # CuPy not installed, fall through to CPU-only + + _cuda_check_cache = (False, "CPU only") + logger.info("CUDA not available, using CPU") + return _cuda_check_cache + + +def check_cuml_available() -> bool: + """Check if cuML is available (actual import, for runtime use).""" + if not HAS_CUML_PACKAGE: + return False + try: + import cuml + return True + except ImportError: + return False + + +def check_faiss_available() -> bool: + """Check if FAISS is available (actual import, for runtime use).""" + if not HAS_FAISS_PACKAGE: + return False + try: + import faiss + return True + except ImportError: + return False + + +def resolve_backend(backend: str, operation: str = "general") -> str: + """ + Resolve 'auto' backend to actual backend based on available hardware. 
+ + Args: + backend: Requested backend ("auto", "sklearn", "cuml", "faiss") + operation: Operation type for logging ("clustering", "reduction", "general") + + Returns: + Resolved backend name + """ + if backend != "auto": + logger.debug(f"Using explicitly requested backend: {backend}") + return backend + + cuda_available, device_info = check_cuda_available() + has_cuml = check_cuml_available() + has_faiss = check_faiss_available() + + if cuda_available and has_cuml: + resolved = "cuml" + logger.info(f"Auto-resolved {operation} backend to cuML (GPU: {device_info})") + elif has_faiss: + resolved = "faiss" + logger.info(f"Auto-resolved {operation} backend to FAISS (CPU)") + else: + resolved = "sklearn" + logger.info(f"Auto-resolved {operation} backend to sklearn (CPU)") + + return resolved + + +def get_backend_info() -> dict: + """ + Get comprehensive backend availability information. + + Returns: + Dictionary with backend availability status + """ + cuda_available, device_info = check_cuda_available() + + return { + "cuda_available": cuda_available, + "device_info": device_info, + "cuml_available": check_cuml_available(), + "faiss_available": check_faiss_available(), + } + + +def is_gpu_error(error: Exception) -> bool: + """ + Check if an exception is a GPU-related error. 
+ + Args: + error: Exception to check + + Returns: + True if error is GPU-related + """ + error_msg = str(error).lower() + gpu_indicators = [ + "out of memory", + "oom", + "cuda", + "gpu", + "nvrtc", + "libnvrtc", + "no kernel image", + "cudaerror", + ] + return any(indicator in error_msg for indicator in gpu_indicators) + + +def is_oom_error(error: Exception) -> bool: + """Check if an exception is an out-of-memory error.""" + error_msg = str(error).lower() + oom_indicators = [ + "out of memory", + "cudaerroroutofmemory", + "oom", + "memory allocation failed", + "cudamalloc failed", + "failed to allocate", + ] + return any(indicator in error_msg for indicator in oom_indicators) + + +def is_cuda_arch_error(error: Exception) -> bool: + """Check if an exception is a CUDA architecture incompatibility error.""" + error_msg = str(error).lower() + arch_indicators = [ + "no kernel image", + "cudaerrornokernel", + "unsupported gpu", + "compute capability", + ] + return any(indicator in error_msg for indicator in arch_indicators) diff --git a/shared/utils/clustering.py b/shared/utils/clustering.py new file mode 100644 index 0000000..f144f4b --- /dev/null +++ b/shared/utils/clustering.py @@ -0,0 +1,497 @@ +from typing import Optional, Tuple +import os +import sys +import subprocess +import tempfile +import time +import numpy as np + +from shared.utils.logging_config import get_logger +from shared.utils.backend import ( + HAS_FAISS_PACKAGE, HAS_CUML_PACKAGE, HAS_CUPY_PACKAGE, + check_cuda_available, check_cuml_available, check_faiss_available, +) + +logger = get_logger(__name__) + +# Legacy module-level flags — now backed by lightweight find_spec() checks +# so importing this module no longer triggers heavy library loads. +# Functions that actually need the libraries import them locally. 
+HAS_FAISS: bool = HAS_FAISS_PACKAGE +HAS_CUML: bool = HAS_CUML_PACKAGE and HAS_CUPY_PACKAGE +HAS_CUDA: bool = False # resolved lazily via check_cuda_available() + + +def _check_cuda() -> bool: + """Check CUDA availability (cached after first call).""" + global HAS_CUDA + available, _ = check_cuda_available() + HAS_CUDA = available + return available + + +class VRAMExceededError(Exception): + """Raised when GPU VRAM is exceeded during computation.""" + pass + + +class GPUArchitectureError(Exception): + """Raised when GPU architecture is not supported.""" + pass + + +def get_gpu_memory_info() -> Optional[Tuple[int, int]]: + """ + Get GPU memory info (used, total) in MB. + + Returns: + Tuple of (used_mb, total_mb) or None if unavailable. + """ + try: + if HAS_CUML and _check_cuda(): + import cupy as cp + meminfo = cp.cuda.Device().mem_info + free_bytes, total_bytes = meminfo + used_bytes = total_bytes - free_bytes + return (used_bytes // (1024 * 1024), total_bytes // (1024 * 1024)) + except Exception: + pass # GPU memory query via CuPy failed; try PyTorch next + + try: + import torch + if torch.cuda.is_available(): + used = torch.cuda.memory_allocated() // (1024 * 1024) + total = torch.cuda.get_device_properties(0).total_memory // (1024 * 1024) + return (used, total) + except Exception: + pass # GPU memory query via PyTorch failed; return None + + return None + + +def estimate_memory_requirement(n_samples: int, n_features: int, method: str) -> int: + """ + Estimate memory requirement in MB for dimensionality reduction. 
+ + Args: + n_samples: Number of samples + n_features: Number of features + method: Reduction method (PCA, TSNE, UMAP) + + Returns: + Estimated memory in MB + """ + # Base memory for input data (float32) + base_mb = (n_samples * n_features * 4) / (1024 * 1024) + + # Method-specific multipliers (empirical estimates) + if method.upper() == "PCA": + return int(base_mb * 2) # Relatively low overhead + elif method.upper() == "TSNE": + return int(base_mb * 4 + (n_samples * n_samples * 4) / (1024 * 1024)) # Distance matrix + elif method.upper() == "UMAP": + return int(base_mb * 3 + (n_samples * 15 * 4) / (1024 * 1024)) # kNN graph + else: + return int(base_mb * 3) + +def _prepare_embeddings(embeddings: np.ndarray, operation: str) -> np.ndarray: + """Validate, cast to float32, and L2-normalize embeddings. + + L2 normalization projects vectors onto the unit hypersphere (magnitude 1). + This stabilises cuML's NN-descent (prevents SIGFPE from large magnitudes) + and is appropriate for contrastive-model embeddings (e.g. CLIP, BioCLIP) + whose training objective is cosine-similarity based. + + Args: + embeddings: Raw embedding matrix (n_samples, n_features). + operation: Label for log messages (e.g. "reduce_dim", "kmeans"). + + Returns: + L2-normalized float32 embedding matrix. 
+ """ + n_samples, n_features = embeddings.shape + + # Cast to float32 + embeddings = np.ascontiguousarray(embeddings, dtype=np.float32) + + # Check for non-finite values + n_nonfinite = (~np.isfinite(embeddings)).sum() + if n_nonfinite > 0: + logger.warning(f"[{operation}] {n_nonfinite} non-finite values found, replacing with 0") + embeddings = np.nan_to_num(embeddings, nan=0.0, posinf=0.0, neginf=0.0) + + # L2 normalize + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + n_zero = (norms.ravel() < 1e-10).sum() + if n_zero > 0: + logger.warning(f"[{operation}] {n_zero} near-zero-norm vectors found (will clamp to avoid division by zero)") + embeddings = embeddings / np.maximum(norms, 1e-10) + + logger.info(f"[{operation}] Prepared embeddings: {n_samples} samples, {n_features} features, " + f"dtype=float32, L2-normalized " + f"(input norms: min={norms.min():.2f}, max={norms.max():.2f}, mean={norms.mean():.2f})") + return embeddings + + +def reduce_dim(embeddings: np.ndarray, method: str = "PCA", seed: Optional[int] = None, n_workers: int = 1, backend: str = "auto"): + """ + Reduce the dimensionality of embeddings to 2D using PCA, t-SNE, or UMAP. + + Args: + embeddings (np.ndarray): The input feature embeddings of shape (n_samples, n_features). + method (str, optional): The dimensionality reduction method, "PCA", "TSNE", or "UMAP". Defaults to "PCA". + seed (int, optional): Random seed for reproducibility. Defaults to None (random). + n_workers (int, optional): Number of parallel workers for t-SNE/UMAP. Defaults to 1. + backend (str, optional): Backend to use - "auto", "sklearn", "cuml". Defaults to "auto". + + Returns: + np.ndarray: The 2D reduced embeddings of shape (n_samples, 2). + + Raises: + ValueError: If an unsupported method is provided. 
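+
+    Example (illustrative; forces the CPU backend):
+        emb = np.random.rand(500, 128).astype(np.float32)
+        coords = reduce_dim(emb, method="PCA", seed=0, backend="sklearn")
+        # coords has shape (500, 2)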
+ """ + n_samples, n_features = embeddings.shape + logger.info(f"Dimensionality reduction: method={method}, samples={n_samples}, features={n_features}, backend={backend}") + + # Validate, cast, and L2-normalize + embeddings = _prepare_embeddings(embeddings, "reduce_dim") + + # Determine which backend to use + cuda_available = _check_cuda() + use_cuml = False + if backend == "cuml" and HAS_CUML and cuda_available: + use_cuml = True + elif backend == "auto" and HAS_CUML and cuda_available and n_samples > 5000: + # Use cuML automatically for large datasets on GPU + use_cuml = True + + start_time = time.time() + if use_cuml: + logger.info(f"Using cuML backend for {method}") + result = _reduce_dim_cuml(embeddings, method, seed, n_workers) + else: + logger.info(f"Using sklearn backend for {method}") + result = _reduce_dim_sklearn(embeddings, method, seed, n_workers) + + elapsed = time.time() - start_time + logger.info(f"Dimensionality reduction completed in {elapsed:.2f}s") + return result + + +def _reduce_dim_sklearn(embeddings: np.ndarray, method: str, seed: Optional[int], n_workers: int): + """Dimensionality reduction using sklearn/umap backends.""" + from sklearn.decomposition import PCA + from sklearn.manifold import TSNE + + # Use -1 (all available cores) instead of specific values > 1 to avoid + # thread count restrictions on HPC clusters (OMP_NUM_THREADS, SLURM cgroups) + effective_workers = -1 if n_workers > 1 else n_workers + + if method.upper() == "PCA": + reducer = PCA(n_components=2) + elif method.upper() == "TSNE": + # Adjust perplexity to be valid for the sample size + n_samples = embeddings.shape[0] + perplexity = min(30, max(5, n_samples // 3)) # Ensure perplexity is reasonable + + if seed is not None: + reducer = TSNE(n_components=2, perplexity=perplexity, random_state=seed, n_jobs=effective_workers) + else: + reducer = TSNE(n_components=2, perplexity=perplexity, n_jobs=effective_workers) + elif method.upper() == "UMAP": + from umap import UMAP + # 
Adjust n_neighbors to be valid for the sample size + n_samples = embeddings.shape[0] + n_neighbors = min(15, max(2, n_samples - 1)) + + if seed is not None: + reducer = UMAP(n_components=2, n_neighbors=n_neighbors, random_state=seed, n_jobs=effective_workers) + else: + reducer = UMAP(n_components=2, n_neighbors=n_neighbors, n_jobs=effective_workers) + else: + raise ValueError("Unsupported method. Choose 'PCA', 'TSNE', or 'UMAP'.") + return reducer.fit_transform(embeddings) + + +def _reduce_dim_cuml(embeddings: np.ndarray, method: str, seed: Optional[int], n_workers: int): + """Dimensionality reduction using cuML GPU backends. + + Expects embeddings to already be L2-normalized float32 from _prepare_embeddings(). + """ + try: + import cupy as cp + + if method.upper() == "UMAP": + # cuML UMAP can crash with SIGFPE on certain data distributions + # (NN-descent numerical instability). SIGFPE is a signal, not a + # Python exception, so try/except cannot catch it. Run in an + # isolated subprocess so the main process (Streamlit) survives. + return _run_cuml_umap_subprocess(embeddings, seed) + + # PCA and TSNE are stable — run in-process + embeddings_gpu = cp.asarray(embeddings, dtype=cp.float32) + + if method.upper() == "PCA": + from cuml.decomposition import PCA as cuPCA + reducer = cuPCA(n_components=2) + elif method.upper() == "TSNE": + from cuml.manifold import TSNE as cuTSNE + n_samples = embeddings.shape[0] + perplexity = min(30, max(5, n_samples // 3)) + + if seed is not None: + reducer = cuTSNE(n_components=2, perplexity=perplexity, random_state=seed) + else: + reducer = cuTSNE(n_components=2, perplexity=perplexity) + else: + raise ValueError("Unsupported method. 
Choose 'PCA', 'TSNE', or 'UMAP'.") + + result_gpu = reducer.fit_transform(embeddings_gpu) + return cp.asnumpy(result_gpu) + + except RuntimeError as e: + error_msg = str(e).lower() + if "no kernel image" in error_msg or "cudaerrornokernel" in error_msg: + logger.warning(f"cuML {method} not supported on this GPU architecture, falling back to sklearn") + else: + logger.warning(f"cuML reduction failed ({e}), falling back to sklearn") + return _reduce_dim_sklearn(embeddings, method, seed, n_workers) + except Exception as e: + logger.warning(f"cuML reduction failed ({e}), falling back to sklearn") + return _reduce_dim_sklearn(embeddings, method, seed, n_workers) + + +# Standalone script executed in a subprocess for cuML UMAP. +# Kept minimal: only imports cuml/cupy/numpy, no project dependencies. +_CUML_UMAP_SCRIPT = """\ +import sys, numpy as np, cupy as cp +from cuml.manifold import UMAP as cuUMAP + +input_path, output_path = sys.argv[1], sys.argv[2] +n_neighbors = int(sys.argv[3]) +seed = int(sys.argv[4]) if sys.argv[4] else None + +embeddings = np.load(input_path) +emb_gpu = cp.asarray(embeddings, dtype=cp.float32) + +# Embeddings arrive L2-normalized from _prepare_embeddings(). +# Verify as a safety net — re-normalize if needed (prevents SIGFPE from NN-descent). +norms = cp.linalg.norm(emb_gpu, axis=1) +if cp.abs(norms.mean() - 1.0) > 0.01: + emb_gpu = emb_gpu / cp.maximum(norms.reshape(-1, 1), 1e-10) + +kw = dict(n_components=2, n_neighbors=n_neighbors) +if seed is not None: + kw["random_state"] = seed +reducer = cuUMAP(**kw) +result = reducer.fit_transform(emb_gpu) +np.save(output_path, cp.asnumpy(result)) +""" + + +def _run_cuml_umap_subprocess(embeddings: np.ndarray, seed: Optional[int]) -> np.ndarray: + """Run cuML UMAP in an isolated subprocess to survive SIGFPE crashes. + + cuML UMAP's NN-descent can trigger a floating-point exception (SIGFPE) on + certain data distributions, which kills the entire process. 
By running in + a child process, the parent (Streamlit) survives and can fall back to + sklearn UMAP. + """ + n_samples = embeddings.shape[0] + n_neighbors = min(15, max(2, n_samples - 1)) + + # Use /dev/shm for fast IPC when available, else /tmp + shm_dir = "/dev/shm" if os.path.isdir("/dev/shm") else tempfile.gettempdir() + input_path = os.path.join(shm_dir, f"cuml_umap_in_{os.getpid()}.npy") + output_path = os.path.join(shm_dir, f"cuml_umap_out_{os.getpid()}.npy") + + np.save(input_path, embeddings) + seed_arg = str(seed) if seed is not None else "" + + try: + logger.info(f"Running cuML UMAP in subprocess ({n_samples} samples, " + f"n_neighbors={n_neighbors})") + result = subprocess.run( + [sys.executable, "-c", _CUML_UMAP_SCRIPT, + input_path, output_path, str(n_neighbors), seed_arg], + capture_output=True, text=True, timeout=300, + ) + + if result.returncode == 0 and os.path.exists(output_path): + reduced = np.load(output_path) + logger.info("cuML UMAP subprocess completed successfully") + return reduced + + stderr = result.stderr.strip() + raise RuntimeError( + f"cuML UMAP subprocess failed (rc={result.returncode}): " + f"{stderr[-500:] if stderr else 'no stderr'}" + ) + finally: + for path in (input_path, output_path): + try: + os.unlink(path) + except OSError: + pass # Best-effort cleanup of temp IPC files + +def run_kmeans(embeddings: np.ndarray, n_clusters: int, seed: Optional[int] = None, n_workers: int = 1, backend: str = "auto"): + """ + Perform KMeans clustering on the given embeddings. + + Args: + embeddings (np.ndarray): The input feature embeddings of shape (n_samples, n_features). + n_clusters (int): The number of clusters to form. + seed (int, optional): Random seed for reproducibility. Defaults to None (random). + n_workers (int, optional): Number of parallel workers (used by FAISS and cuML if available). + backend (str, optional): Clustering backend - "auto", "sklearn", "faiss", or "cuml". Defaults to "auto". 
+ + Returns: + kmeans (KMeans or custom object): The fitted clustering object. + labels (np.ndarray): Cluster labels for each sample. + """ + n_samples = embeddings.shape[0] + logger.info(f"KMeans clustering: n_clusters={n_clusters}, samples={n_samples}, backend={backend}") + + # Validate, cast, and L2-normalize + embeddings = _prepare_embeddings(embeddings, "kmeans") + + start_time = time.time() + + # Determine which backend to use + cuda_available = _check_cuda() + if backend == "cuml" and HAS_CUML and cuda_available: + logger.info("Using cuML backend for KMeans") + result = _run_kmeans_cuml(embeddings, n_clusters, seed, n_workers) + elif backend == "faiss" and HAS_FAISS: + logger.info("Using FAISS backend for KMeans") + result = _run_kmeans_faiss(embeddings, n_clusters, seed, n_workers) + elif backend == "auto": + # Auto selection priority: cuML > FAISS > sklearn + if HAS_CUML and cuda_available and n_samples > 500: + logger.info("Auto-selected cuML backend for KMeans (GPU available, large dataset)") + result = _run_kmeans_cuml(embeddings, n_clusters, seed, n_workers) + elif HAS_FAISS and n_samples > 500: + logger.info("Auto-selected FAISS backend for KMeans (large dataset)") + result = _run_kmeans_faiss(embeddings, n_clusters, seed, n_workers) + else: + logger.info("Using sklearn backend for KMeans") + result = _run_kmeans_sklearn(embeddings, n_clusters, seed) + else: + logger.info("Using sklearn backend for KMeans") + result = _run_kmeans_sklearn(embeddings, n_clusters, seed) + + elapsed = time.time() - start_time + logger.info(f"KMeans clustering completed in {elapsed:.2f}s") + return result + + +def _run_kmeans_cuml(embeddings: np.ndarray, n_clusters: int, seed: Optional[int] = None, n_workers: int = 1): + """KMeans using cuML GPU backend.""" + try: + import cupy as cp + from cuml.cluster import KMeans as cuKMeans + + # Convert to cupy array for GPU processing + embeddings_gpu = cp.asarray(embeddings, dtype=cp.float32) + + # Create cuML KMeans object + if 
seed is not None: + kmeans = cuKMeans( + n_clusters=n_clusters, + random_state=seed, + max_iter=300, + init='k-means++', + tol=1e-4 + ) + else: + kmeans = cuKMeans( + n_clusters=n_clusters, + max_iter=300, + init='k-means++', + tol=1e-4 + ) + + # Fit and predict on GPU + labels_gpu = kmeans.fit_predict(embeddings_gpu) + + # Convert results back to numpy + labels = cp.asnumpy(labels_gpu) + centroids = cp.asnumpy(kmeans.cluster_centers_) + + # Create a simple object to mimic sklearn KMeans interface + class cuMLKMeans: + def __init__(self, centroids, labels): + self.cluster_centers_ = centroids + self.labels_ = labels + self.n_clusters = len(centroids) + + return cuMLKMeans(centroids, labels), labels + + except Exception as e: + logger.warning(f"cuML clustering failed ({e}), falling back to sklearn") + return _run_kmeans_sklearn(embeddings, n_clusters, seed) + + +def _run_kmeans_sklearn(embeddings: np.ndarray, n_clusters: int, seed: Optional[int] = None): + """KMeans using scikit-learn backend.""" + from sklearn.cluster import KMeans + if seed is not None: + kmeans = KMeans(n_clusters=n_clusters, random_state=seed) + else: + kmeans = KMeans(n_clusters=n_clusters) + labels = kmeans.fit_predict(embeddings) + return kmeans, labels + + +def _run_kmeans_faiss(embeddings: np.ndarray, n_clusters: int, seed: Optional[int] = None, n_workers: int = 1): + """KMeans using FAISS backend for faster clustering.""" + try: + import faiss + + # Ensure embeddings are float32 and C-contiguous (FAISS requirement) + embeddings = np.ascontiguousarray(embeddings.astype(np.float32)) + + n_samples, d = embeddings.shape + + # Set number of threads for FAISS + if n_workers > 1: + faiss.omp_set_num_threads(n_workers) + + # Create FAISS KMeans object + kmeans = faiss.Clustering(d, n_clusters) + + # Set clustering parameters + kmeans.verbose = False + kmeans.niter = 20 # Number of iterations + kmeans.nredo = 1 # Number of redos + if seed is not None: + kmeans.seed = seed + + # Use L2 distance 
(equivalent to sklearn's default) + index = faiss.IndexFlatL2(d) + + # Run clustering + kmeans.train(embeddings, index) + + # Get centroids + centroids = faiss.vector_to_array(kmeans.centroids).reshape(n_clusters, d) + + # Assign labels by finding nearest centroid for each point + _, labels = index.search(embeddings, 1) + labels = labels.flatten() + + # Create a simple object to mimic sklearn KMeans interface + class FAISSKMeans: + def __init__(self, centroids, labels): + self.cluster_centers_ = centroids + self.labels_ = labels + self.n_clusters = len(centroids) + + return FAISSKMeans(centroids, labels), labels + + except Exception as e: + # Fallback to sklearn if FAISS fails + logger.warning(f"FAISS clustering failed ({e}), falling back to sklearn") + return _run_kmeans_sklearn(embeddings, n_clusters, seed) + + diff --git a/utils/io.py b/shared/utils/io.py similarity index 100% rename from utils/io.py rename to shared/utils/io.py diff --git a/shared/utils/logging_config.py b/shared/utils/logging_config.py new file mode 100644 index 0000000..b2c186b --- /dev/null +++ b/shared/utils/logging_config.py @@ -0,0 +1,94 @@ +""" +Centralized logging configuration for emb-explorer. + +Usage: + from shared.utils.logging_config import get_logger + logger = get_logger(__name__) + logger.info("Message") +""" + +import logging +import os +import sys +from typing import Optional + + +# Module-level flag to track if logging has been configured +_logging_configured = False + +# Default log directory (relative to working directory) +_LOG_DIR = os.environ.get("EMB_EXPLORER_LOG_DIR", "logs") +_LOG_FILE = "emb_explorer.log" + + +def configure_logging( + level: int = logging.INFO, + log_format: Optional[str] = None, + log_to_file: bool = True, +): + """ + Configure the root logger for the application. 
+
+    Args:
+        level: Logging level (default: INFO)
+        log_format: Custom log format string (optional)
+        log_to_file: Whether to also write logs to a file (default: True)
+    """
+    global _logging_configured
+
+    if _logging_configured:
+        return
+
+    if log_format is None:
+        log_format = (
+            "[%(asctime)s] %(levelname)s "
+            "[%(name)s.%(funcName)s:%(lineno)d] %(message)s"
+        )
+
+    # Configure root logger
+    root_logger = logging.getLogger()
+    root_logger.setLevel(level)
+
+    # Remove existing handlers to avoid duplicates
+    for handler in root_logger.handlers[:]:
+        root_logger.removeHandler(handler)
+
+    formatter = logging.Formatter(log_format, datefmt="%Y-%m-%d %H:%M:%S")
+
+    # Console handler
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setLevel(level)
+    console_handler.setFormatter(formatter)
+    root_logger.addHandler(console_handler)
+
+    # File handler (append mode; a plain FileHandler, so there is no rotation
+    # and the file grows across runs until manually cleared)
+    if log_to_file:
+        try:
+            os.makedirs(_LOG_DIR, exist_ok=True)
+            file_handler = logging.FileHandler(
+                os.path.join(_LOG_DIR, _LOG_FILE), mode="a", encoding="utf-8"
+            )
+            # Records below the root logger's level never reach handlers, so
+            # DEBUG here only takes effect if the root level is lowered
+            file_handler.setLevel(logging.DEBUG)
+            file_handler.setFormatter(formatter)
+            root_logger.addHandler(file_handler)
+        except OSError:
+            # Non-fatal: skip file logging if directory can't be created
+            pass
+
+    _logging_configured = True
+
+
+def get_logger(name: str) -> logging.Logger:
+    """
+    Get a logger instance for the given module name.
+
+    Automatically configures logging if not already done.
+ + Args: + name: Logger name (typically __name__) + + Returns: + Logger instance + """ + configure_logging() + return logging.getLogger(name) diff --git a/utils/models.py b/shared/utils/models.py similarity index 65% rename from utils/models.py rename to shared/utils/models.py index 480ae2f..d958b84 100644 --- a/utils/models.py +++ b/shared/utils/models.py @@ -1,6 +1,3 @@ -import pandas as pd -import open_clip - def list_available_models(): """List all available models.""" @@ -14,6 +11,7 @@ def list_available_models(): ]) # OpenCLIP models + import open_clip openclip_models = open_clip.list_pretrained() for model_name, pretrained in openclip_models: models_data.append({ @@ -22,3 +20,13 @@ def list_available_models(): }) return models_data + + +def print_available_models(): + """CLI entry point: print all available models to stdout.""" + models = list_available_models() + for m in models: + if m["pretrained"]: + print(f"{m['name']} (pretrained: {m['pretrained']})") + else: + print(m["name"]) diff --git a/utils/taxonomy_tree.py b/shared/utils/taxonomy_tree.py similarity index 87% rename from utils/taxonomy_tree.py rename to shared/utils/taxonomy_tree.py index 291ff3d..69c513d 100644 --- a/utils/taxonomy_tree.py +++ b/shared/utils/taxonomy_tree.py @@ -4,61 +4,68 @@ import pandas as pd from typing import Dict, List, Any, Optional -from collections import defaultdict, Counter +from collections import defaultdict def build_taxonomic_tree(df: pd.DataFrame) -> Dict[str, Any]: """ Build a hierarchical taxonomic tree from a dataframe. 
- + Args: df: DataFrame containing taxonomic columns - + Returns: Nested dictionary representing the taxonomic tree with counts """ taxonomic_levels = ['kingdom', 'phylum', 'class', 'order', 'family', 'genus', 'species'] - + # Filter to only include rows that have at least kingdom df_clean = df[df['kingdom'].notna()].copy() - + tree = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict( lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(int))))))) - + + def _val(row, col): + """Get column value, replacing NaN/None/empty with 'Unknown'.""" + v = row.get(col, 'Unknown') + if pd.isna(v) or v == '': + return 'Unknown' + return v + for _, row in df_clean.iterrows(): # Get values for each taxonomic level, using 'Unknown' for nulls - kingdom = row.get('kingdom', 'Unknown') or 'Unknown' - phylum = row.get('phylum', 'Unknown') or 'Unknown' - class_name = row.get('class', 'Unknown') or 'Unknown' - order = row.get('order', 'Unknown') or 'Unknown' - family = row.get('family', 'Unknown') or 'Unknown' - genus = row.get('genus', 'Unknown') or 'Unknown' - species = row.get('species', 'Unknown') or 'Unknown' - + kingdom = _val(row, 'kingdom') + phylum = _val(row, 'phylum') + class_name = _val(row, 'class') + order = _val(row, 'order') + family = _val(row, 'family') + genus = _val(row, 'genus') + species = _val(row, 'species') + # Build the nested structure tree[kingdom][phylum][class_name][order][family][genus][species] += 1 - + return dict(tree) def format_tree_string(tree: Dict[str, Any], max_depth: int = 7, min_count: int = 1) -> str: """ Format the taxonomic tree as a string similar to the 'tree' command output. 
- + Args: tree: Taxonomic tree dictionary max_depth: Maximum depth to display min_count: Minimum count to include in the tree - + Returns: Formatted tree string """ lines = [] - + def format_level(node, level=0, prefix="", is_last=True, path=""): if level >= max_depth: return - + if isinstance(node, dict): items = list(node.items()) # Sort by count (descending) if we're at the species level @@ -67,7 +74,7 @@ def format_level(node, level=0, prefix="", is_last=True, path=""): else: # Sort by name for higher levels items = sorted(items, key=lambda x: x[0]) - + # Filter by minimum count items = [(k, v) for k, v in items if ( isinstance(v, int) and v >= min_count) or ( @@ -75,10 +82,10 @@ def format_level(node, level=0, prefix="", is_last=True, path=""): get_total_count(subv) >= min_count for subv in v.values() ) )] - + for i, (key, value) in enumerate(items): is_last_item = (i == len(items) - 1) - + # Create the tree characters if level == 0: connector = "" @@ -86,7 +93,7 @@ def format_level(node, level=0, prefix="", is_last=True, path=""): else: connector = "└── " if is_last_item else "├── " new_prefix = prefix + (" " if is_last_item else "│ ") - + # Get count for this node if isinstance(value, int): count = value @@ -94,14 +101,14 @@ def format_level(node, level=0, prefix="", is_last=True, path=""): else: count = get_total_count(value) count_str = f" ({count})" if count > 0 else "" - + # Add the line lines.append(f"{prefix}{connector}{key}{count_str}") - + # Recurse if it's a dictionary if isinstance(value, dict): format_level(value, level + 1, new_prefix, is_last_item, f"{path}/{key}") - + format_level(tree) return "\n".join(lines) @@ -109,10 +116,10 @@ def format_level(node, level=0, prefix="", is_last=True, path=""): def get_total_count(node: Any) -> int: """ Get the total count for a tree node. 
- + Args: node: Tree node (dict or int) - + Returns: Total count for this node and all children """ @@ -127,10 +134,10 @@ def get_total_count(node: Any) -> int: def get_tree_statistics(tree: Dict[str, Any]) -> Dict[str, int]: """ Get statistics about the taxonomic tree. - + Args: tree: Taxonomic tree dictionary - + Returns: Dictionary with statistics """ @@ -144,7 +151,7 @@ def get_tree_statistics(tree: Dict[str, Any]) -> Dict[str, int]: 'genera': 0, 'species': 0 } - + for kingdom, phyla in tree.items(): stats['phyla'] += len(phyla) for phylum, classes in phyla.items(): @@ -157,5 +164,5 @@ def get_tree_statistics(tree: Dict[str, Any]) -> Dict[str, int]: stats['genera'] += len(genera) for genus, species in genera.items(): stats['species'] += len(species) - + return stats diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..6f739e8 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,52 @@ +# Test Suite + +Hey! Welcome to the emb-explorer test suite. This doc is for humans *and* AI coding agents (hi Claude) — so it's kept concise and structured. + +## Quick Start + +Once your venv is activated: + +```bash +pytest tests/ -v # all tests +pytest tests/test_backend.py -v # specific file +pytest tests/ -m "not gpu" # skip GPU-marked tests +``` + +> **Heads up:** TSNE/UMAP tests are slow on CPU (~1 min). Everything else is fast. Much quicker on GPU nodes. 
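+
+For example, to iterate on a single test while skipping anything GPU-marked (swap in your own `-k` expression):
+
+```bash
+pytest tests/test_clustering.py -k "l2_normalized" -m "not gpu" -v
+```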
+ +## Test Organization + +| File | What It Covers | +|---|---| +| `test_backend.py` (29) | Error classifiers, backend resolution priority, CUDA cache | +| `test_clustering.py` (23) | L2 normalization, dim reduction, KMeans, GPU fallback (mocked) | +| `test_filters.py` (16) | PyArrow filter logic, column type detection, embedding extraction | +| `test_taxonomy_tree.py` (12) | Tree building, NaN handling, depth/count filtering | +| `test_clustering_service.py` (8) | Clustering summary, `run_clustering_safe()` fallback chain | +| `test_logging_config.py` (5) | Logger naming, handler setup, idempotency | +| `conftest.py` | Shared fixtures (embeddings, paths, PyArrow tables, reset helpers) | + +**98 tests total.** All pass on CPU-only machines — no GPU required. GPU fallback behavior is tested via mocking (`HAS_CUML`, `HAS_CUDA`, `subprocess.run`). The `@pytest.mark.gpu` marker is registered for future tests that exercise real GPU code paths. + +## Running on a SLURM Cluster + +Two batch scripts are provided in `tests/`. Before using them, edit the `#SBATCH` headers to match your cluster (account, partition names, venv path): + +```bash +sbatch tests/run_cpu_tests.sh # CPU partition — runs non-GPU tests +sbatch tests/run_gpu_tests.sh # GPU partition — runs full suite +sbatch tests/run_gpu_tests.sh --gpu # GPU partition — GPU-marked tests only +``` + +The GPU script sets `LD_LIBRARY_PATH` for cuML/CuPy nvidia libs automatically. + +## For AI Agents + +If you're adding new utility functions to `shared/utils/` or `shared/services/`: + +1. **Add tests.** Check if an existing test file covers the module, or create a new one. +2. **Use the fixtures** in `conftest.py` — `sample_embeddings`, `sample_embeddings_small`, `sample_arrow_table`, etc. +3. **Mock GPU code**, don't try to call it. Patch module-level flags like `HAS_CUML` or inject mock objects for `cp` (cupy). +4. **Run `pytest tests/ -v`** after changes to verify nothing broke. +5. 
The `reset_cuda_cache` and `reset_logging` fixtures exist because those modules use global state — use them when testing `backend.py` or `logging_config.py`. +6. **GPU tests** (future) use `@pytest.mark.gpu`. These only run on GPU nodes — don't expect them to pass on CPU-only nodes. diff --git a/components/__init__.py b/tests/__init__.py similarity index 100% rename from components/__init__.py rename to tests/__init__.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..a87411f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,79 @@ +"""Shared fixtures for emb-explorer test suite.""" + +import logging +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest + + +@pytest.fixture +def sample_embeddings(): + """Reproducible (100, 512) float32 embedding matrix.""" + rng = np.random.RandomState(42) + return rng.randn(100, 512).astype(np.float32) + + +@pytest.fixture +def sample_embeddings_small(): + """Small (10, 32) float32 embedding matrix for fast edge-case tests.""" + rng = np.random.RandomState(42) + return rng.randn(10, 32).astype(np.float32) + + +@pytest.fixture +def sample_paths(): + """Fake image paths matching sample_embeddings (100 items).""" + return [f"/images/img_{i:04d}.jpg" for i in range(100)] + + +@pytest.fixture +def sample_uuids(): + """Fake UUIDs matching sample_embeddings (100 items).""" + return [f"uuid-{i:04d}" for i in range(100)] + + +@pytest.fixture +def sample_labels(): + """Cluster labels for 100 samples across 5 clusters.""" + rng = np.random.RandomState(42) + return rng.randint(0, 5, size=100) + + +@pytest.fixture +def sample_arrow_table(): + """PyArrow table with mixed column types for filter testing.""" + return pa.table({ + "uuid": [f"id-{i}" for i in range(20)], + "species": ["cat", "dog", "cat", "bird", "dog"] * 4, + "family": ["felidae", "canidae", "felidae", "passeridae", "canidae"] * 4, + "weight": [4.5, 25.0, 3.8, 0.03, 30.0] * 4, + "notes": ["healthy", "large breed", 
"kitten", "sparrow", "retriever"] * 4, + "emb": [[0.1] * 8 for _ in range(20)], + }) + + +@pytest.fixture +def reset_cuda_cache(): + """Reset backend CUDA cache between tests.""" + import shared.utils.backend as backend_mod + original = backend_mod._cuda_check_cache + backend_mod._cuda_check_cache = None + yield + backend_mod._cuda_check_cache = original + + +@pytest.fixture +def reset_logging(): + """Reset logging configuration between tests.""" + import shared.utils.logging_config as log_mod + original = log_mod._logging_configured + log_mod._logging_configured = False + root = logging.getLogger() + old_handlers = root.handlers[:] + root.handlers.clear() + yield + root.handlers.clear() + for h in old_handlers: + root.addHandler(h) + log_mod._logging_configured = original diff --git a/tests/run_cpu_tests.sh b/tests/run_cpu_tests.sh new file mode 100755 index 0000000..f087d68 --- /dev/null +++ b/tests/run_cpu_tests.sh @@ -0,0 +1,36 @@ +#!/bin/bash +#SBATCH --account=PAS2136 +#SBATCH --partition=cpu +#SBATCH --cpus-per-task=4 +#SBATCH --time=00:30:00 +#SBATCH --job-name=emb-tests-cpu +#SBATCH --output=tests/cpu_test_results_%j.log + +# ------------------------------------------------------------------ +# CPU test runner for emb-explorer (OSC Pitzer) +# +# Usage: +# sbatch tests/run_cpu_tests.sh # all non-GPU tests +# sbatch tests/run_cpu_tests.sh tests/test_filters.py # specific file +# ------------------------------------------------------------------ + +set -euo pipefail + +PROJECT_DIR="${SLURM_SUBMIT_DIR:-$(cd "$(dirname "$0")/.." && pwd)}" +cd "$PROJECT_DIR" + +source /fs/scratch/PAS2136/netzissou/venv/emb_explorer_pitzer/bin/activate + +echo "=== CPU Test Run ===" +echo "Node: $(hostname)" +echo "Python: $(python --version)" +echo "Project: $PROJECT_DIR" +echo "====================" + +if [[ -n "${1:-}" ]]; then + echo "Running: pytest $* -m 'not gpu' -v" + pytest "$@" -m "not gpu" -v +else + echo "Running all CPU tests..." 
+ pytest tests/ -m "not gpu" -v +fi diff --git a/tests/run_gpu_tests.sh b/tests/run_gpu_tests.sh new file mode 100755 index 0000000..fec5d62 --- /dev/null +++ b/tests/run_gpu_tests.sh @@ -0,0 +1,43 @@ +#!/bin/bash +#SBATCH --account=PAS2136 +#SBATCH --partition=gpu +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=4 +#SBATCH --time=00:30:00 +#SBATCH --job-name=emb-tests-gpu +#SBATCH --output=tests/gpu_test_results_%j.log + +# ------------------------------------------------------------------ +# GPU test runner for emb-explorer (OSC Pitzer) +# +# Usage: +# sbatch tests/run_gpu_tests.sh # full suite on GPU node +# sbatch tests/run_gpu_tests.sh --gpu # GPU-marked tests only +# ------------------------------------------------------------------ + +set -euo pipefail + +PROJECT_DIR="${SLURM_SUBMIT_DIR:-$(cd "$(dirname "$0")/.." && pwd)}" +cd "$PROJECT_DIR" + +source /fs/scratch/PAS2136/netzissou/venv/emb_explorer_pitzer/bin/activate + +# cuML/CuPy need nvidia libs on LD_LIBRARY_PATH +NVIDIA_LIBS="$(python -c 'import nvidia.cublas.lib, nvidia.cusolver.lib, nvidia.cusparse.lib; \ + print(nvidia.cublas.lib.__path__[0]); print(nvidia.cusolver.lib.__path__[0]); print(nvidia.cusparse.lib.__path__[0])' 2>/dev/null | tr '\n' ':')" || true +export LD_LIBRARY_PATH="${NVIDIA_LIBS}${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}" + +echo "=== GPU Test Run ===" +echo "Node: $(hostname)" +echo "GPU: $(nvidia-smi --query-gpu=name --format=csv,noheader 2>/dev/null || echo 'N/A')" +echo "Python: $(python --version)" +echo "Project: $PROJECT_DIR" +echo "====================" + +if [[ "${1:-}" == "--gpu" ]]; then + echo "Running GPU-marked tests only..." + pytest tests/ -m gpu -v +else + echo "Running full test suite on GPU node..." 
+ pytest tests/ -v +fi diff --git a/tests/test_backend.py b/tests/test_backend.py new file mode 100644 index 0000000..bd4f1b6 --- /dev/null +++ b/tests/test_backend.py @@ -0,0 +1,126 @@ +"""Tests for shared/utils/backend.py.""" + +from unittest.mock import patch + +import pytest + +from shared.utils.backend import ( + is_gpu_error, + is_oom_error, + is_cuda_arch_error, + resolve_backend, + check_cuda_available, +) + + +# --------------------------------------------------------------------------- +# Error classifiers (pure — no mocking needed) +# --------------------------------------------------------------------------- + +class TestIsGpuError: + @pytest.mark.parametrize("msg", [ + "CUDA error: out of memory", + "RuntimeError: no kernel image is available", + "nvrtc compilation failed", + "libnvrtc.so not found", + "GPU memory allocation failed", + "cudaErrorNoKernel", + ]) + def test_gpu_errors_detected(self, msg): + assert is_gpu_error(RuntimeError(msg)) + + @pytest.mark.parametrize("msg", [ + "FileNotFoundError: /tmp/data.npy", + "ValueError: invalid literal", + "Connection refused", + ]) + def test_non_gpu_errors_rejected(self, msg): + assert not is_gpu_error(RuntimeError(msg)) + + def test_case_insensitive(self): + assert is_gpu_error(RuntimeError("CUDA ERROR: device not found")) + + +class TestIsOomError: + @pytest.mark.parametrize("msg", [ + "CUDA out of memory", + "cudaErrorOutOfMemory", + "OOM killer invoked", + "memory allocation failed", + "cudaMalloc failed", + "failed to allocate 1024 bytes", + ]) + def test_oom_errors_detected(self, msg): + assert is_oom_error(RuntimeError(msg)) + + def test_non_oom_rejected(self): + assert not is_oom_error(RuntimeError("invalid argument")) + + +class TestIsCudaArchError: + @pytest.mark.parametrize("msg", [ + "no kernel image is available for execution on the device", + "cudaErrorNoKernel", + "unsupported GPU architecture", + "compute capability 3.5 not supported", + ]) + def test_arch_errors_detected(self, msg): + 
assert is_cuda_arch_error(RuntimeError(msg)) + + def test_non_arch_rejected(self): + assert not is_cuda_arch_error(RuntimeError("out of memory")) + + +# --------------------------------------------------------------------------- +# resolve_backend (mock check_* functions) +# --------------------------------------------------------------------------- + +class TestResolveBackend: + def test_explicit_backend_passthrough(self): + assert resolve_backend("sklearn") == "sklearn" + assert resolve_backend("cuml") == "cuml" + assert resolve_backend("faiss") == "faiss" + + def test_auto_with_cuda_and_cuml(self): + with patch("shared.utils.backend.check_cuda_available", return_value=(True, "V100")), \ + patch("shared.utils.backend.check_cuml_available", return_value=True), \ + patch("shared.utils.backend.check_faiss_available", return_value=True): + assert resolve_backend("auto") == "cuml" + + def test_auto_without_cuda_with_faiss(self): + with patch("shared.utils.backend.check_cuda_available", return_value=(False, "CPU only")), \ + patch("shared.utils.backend.check_cuml_available", return_value=False), \ + patch("shared.utils.backend.check_faiss_available", return_value=True): + assert resolve_backend("auto") == "faiss" + + def test_auto_cpu_only(self): + with patch("shared.utils.backend.check_cuda_available", return_value=(False, "CPU only")), \ + patch("shared.utils.backend.check_cuml_available", return_value=False), \ + patch("shared.utils.backend.check_faiss_available", return_value=False): + assert resolve_backend("auto") == "sklearn" + + def test_auto_cuda_without_cuml_falls_to_faiss(self): + with patch("shared.utils.backend.check_cuda_available", return_value=(True, "V100")), \ + patch("shared.utils.backend.check_cuml_available", return_value=False), \ + patch("shared.utils.backend.check_faiss_available", return_value=True): + assert resolve_backend("auto") == "faiss" + + +# --------------------------------------------------------------------------- +# 
check_cuda_available (mock imports, test caching) +# --------------------------------------------------------------------------- + +class TestCheckCudaAvailable: + def test_returns_false_without_gpu(self, reset_cuda_cache): + """On a CPU-only node, should return (False, 'CPU only').""" + with patch("shared.utils.backend.HAS_TORCH_PACKAGE", False), \ + patch("shared.utils.backend.HAS_CUPY_PACKAGE", False): + result = check_cuda_available() + assert result == (False, "CPU only") + + def test_cache_prevents_reimport(self, reset_cuda_cache): + """Second call should return cached value.""" + import shared.utils.backend as backend_mod + backend_mod._cuda_check_cache = (True, "V100-test") + result = check_cuda_available() + assert result == (True, "V100-test") diff --git a/tests/test_clustering.py b/tests/test_clustering.py new file mode 100644 index 0000000..0a2f5e8 --- /dev/null +++ b/tests/test_clustering.py @@ -0,0 +1,194 @@ +"""Tests for shared/utils/clustering.py.""" + +import subprocess +import sys +from unittest.mock import patch, MagicMock + +import numpy as np +import pytest + +from shared.utils.clustering import ( + _prepare_embeddings, + estimate_memory_requirement, + reduce_dim, + run_kmeans, + _reduce_dim_sklearn, + _reduce_dim_cuml, + _run_kmeans_sklearn, + _run_cuml_umap_subprocess, +) + + +# --------------------------------------------------------------------------- +# _prepare_embeddings +# --------------------------------------------------------------------------- + +class TestPrepareEmbeddings: + def test_output_dtype_float32(self, sample_embeddings): + result = _prepare_embeddings(sample_embeddings, "test") + assert result.dtype == np.float32 + + def test_output_l2_normalized(self, sample_embeddings): + result = _prepare_embeddings(sample_embeddings, "test") + norms = np.linalg.norm(result, axis=1) + np.testing.assert_allclose(norms, 1.0, atol=1e-5) + + def test_shape_preserved(self, sample_embeddings): + result = 
_prepare_embeddings(sample_embeddings, "test") + assert result.shape == sample_embeddings.shape + + def test_nan_replaced(self): + emb = np.array([[1.0, np.nan, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) + result = _prepare_embeddings(emb, "test") + assert np.all(np.isfinite(result)) + + def test_inf_replaced(self): + emb = np.array([[1.0, np.inf, 3.0], [4.0, -np.inf, 6.0]], dtype=np.float32) + result = _prepare_embeddings(emb, "test") + assert np.all(np.isfinite(result)) + + def test_zero_norm_vector_clamped(self): + emb = np.array([[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]], dtype=np.float32) + result = _prepare_embeddings(emb, "test") + # Zero vector stays near-zero after clamped division, no crash + assert np.all(np.isfinite(result)) + + def test_float64_input_cast(self): + emb = np.random.RandomState(0).randn(5, 10).astype(np.float64) + result = _prepare_embeddings(emb, "test") + assert result.dtype == np.float32 + + +# --------------------------------------------------------------------------- +# estimate_memory_requirement +# --------------------------------------------------------------------------- + +class TestEstimateMemory: + def test_positive_for_all_methods(self): + for method in ("PCA", "TSNE", "UMAP"): + assert estimate_memory_requirement(1000, 512, method) > 0 + + def test_tsne_greater_than_pca(self): + pca = estimate_memory_requirement(1000, 512, "PCA") + tsne = estimate_memory_requirement(1000, 512, "TSNE") + assert tsne > pca + + def test_unknown_method_returns_positive(self): + assert estimate_memory_requirement(1000, 512, "UNKNOWN") > 0 + + +# --------------------------------------------------------------------------- +# reduce_dim — sklearn path +# --------------------------------------------------------------------------- + +class TestReduceDimSklearn: + def test_pca_output_shape(self, sample_embeddings_small): + result = _reduce_dim_sklearn(sample_embeddings_small, "PCA", seed=42, n_workers=1) + assert result.shape == (10, 2) + + def 
test_tsne_output_shape(self, sample_embeddings_small): + result = _reduce_dim_sklearn(sample_embeddings_small, "TSNE", seed=42, n_workers=1) + assert result.shape == (10, 2) + + def test_umap_output_shape(self, sample_embeddings_small): + result = _reduce_dim_sklearn(sample_embeddings_small, "UMAP", seed=42, n_workers=1) + assert result.shape == (10, 2) + + def test_deterministic_with_seed(self, sample_embeddings_small): + r1 = _reduce_dim_sklearn(sample_embeddings_small, "PCA", seed=42, n_workers=1) + r2 = _reduce_dim_sklearn(sample_embeddings_small, "PCA", seed=42, n_workers=1) + np.testing.assert_array_equal(r1, r2) + + def test_invalid_method_raises(self, sample_embeddings_small): + with pytest.raises(ValueError, match="Unsupported method"): + _reduce_dim_sklearn(sample_embeddings_small, "INVALID", seed=42, n_workers=1) + + +class TestReduceDim: + def test_sklearn_backend(self, sample_embeddings_small): + result = reduce_dim(sample_embeddings_small, "PCA", seed=42, backend="sklearn") + assert result.shape == (10, 2) + + def test_unknown_method_raises(self, sample_embeddings_small): + with pytest.raises(ValueError): + reduce_dim(sample_embeddings_small, "INVALID", seed=42, backend="sklearn") + + +# --------------------------------------------------------------------------- +# run_kmeans — sklearn path +# --------------------------------------------------------------------------- + +class TestRunKmeansSklearn: + def test_returns_labels_and_object(self, sample_embeddings_small): + kmeans, labels = _run_kmeans_sklearn( + sample_embeddings_small.astype(np.float32), n_clusters=3, seed=42 + ) + assert labels.shape == (10,) + assert hasattr(kmeans, "cluster_centers_") + + def test_labels_in_range(self, sample_embeddings_small): + _, labels = _run_kmeans_sklearn( + sample_embeddings_small.astype(np.float32), n_clusters=3, seed=42 + ) + assert set(labels).issubset(set(range(3))) + + def test_deterministic_with_seed(self, sample_embeddings_small): + _, l1 = 
_run_kmeans_sklearn(sample_embeddings_small.astype(np.float32), 3, seed=42) + _, l2 = _run_kmeans_sklearn(sample_embeddings_small.astype(np.float32), 3, seed=42) + np.testing.assert_array_equal(l1, l2) + + +class TestRunKmeans: + def test_sklearn_backend(self, sample_embeddings_small): + _, labels = run_kmeans(sample_embeddings_small, 3, seed=42, backend="sklearn") + assert labels.shape == (10,) + + def test_auto_backend_small_dataset(self, sample_embeddings_small): + # Small dataset (10 samples) should use sklearn even on auto + _, labels = run_kmeans(sample_embeddings_small, 3, seed=42, backend="auto") + assert labels.shape == (10,) + + +# --------------------------------------------------------------------------- +# GPU fallback (mocked) +# --------------------------------------------------------------------------- + +class TestGPUFallback: + def test_reduce_dim_cuml_fallback(self, sample_embeddings_small): + """When cuML cp.asarray raises RuntimeError, _reduce_dim_cuml falls back to sklearn.""" + # Mock cupy so the cuML code path can execute, then fail on cp.asarray + mock_cp = MagicMock() + mock_cp.asarray.side_effect = RuntimeError("CUDA error: no kernel image") + mock_cp.float32 = np.float32 + + # Patch the 'import cupy as cp' inside _reduce_dim_cuml + with patch.dict(sys.modules, {"cupy": mock_cp}): + emb = sample_embeddings_small.astype(np.float32) + result = _reduce_dim_cuml(emb, "PCA", seed=42, n_workers=1) + assert result.shape == (10, 2) + + def test_umap_subprocess_crash_raises(self, sample_embeddings_small): + """Subprocess returning non-zero should raise RuntimeError.""" + mock_result = MagicMock() + mock_result.returncode = 1 + mock_result.stderr = "Segmentation fault (SIGSEGV)" + + with patch("shared.utils.clustering.subprocess.run", return_value=mock_result), \ + patch("shared.utils.clustering.os.path.exists", return_value=False): + with pytest.raises(RuntimeError, match="subprocess failed"): + 
_run_cuml_umap_subprocess(sample_embeddings_small.astype(np.float32), seed=42) + + def test_umap_subprocess_cleans_temp_files(self, tmp_path, sample_embeddings_small): + """Temp files should be cleaned up even on failure.""" + mock_result = MagicMock() + mock_result.returncode = 1 + mock_result.stderr = "crash" + + with patch("shared.utils.clustering.subprocess.run", return_value=mock_result), \ + patch("shared.utils.clustering.os.path.exists", return_value=False), \ + patch("shared.utils.clustering.os.path.isdir", return_value=True), \ + patch("shared.utils.clustering.os.unlink") as mock_unlink: + with pytest.raises(RuntimeError): + _run_cuml_umap_subprocess(sample_embeddings_small.astype(np.float32), seed=42) + # unlink called for both input and output paths + assert mock_unlink.call_count == 2 diff --git a/tests/test_clustering_service.py b/tests/test_clustering_service.py new file mode 100644 index 0000000..9c3fb40 --- /dev/null +++ b/tests/test_clustering_service.py @@ -0,0 +1,108 @@ +"""Tests for shared/services/clustering_service.py.""" + +from unittest.mock import patch + +import numpy as np +import pandas as pd +import pytest + +from shared.services.clustering_service import ClusteringService + + +# --------------------------------------------------------------------------- +# generate_clustering_summary (pure — no mocking needed) +# --------------------------------------------------------------------------- + +class TestGenerateClusteringSummary: + def _make_inputs(self, n_samples=20, n_features=32, n_clusters=3): + rng = np.random.RandomState(42) + embeddings = rng.randn(n_samples, n_features).astype(np.float32) + labels = rng.randint(0, n_clusters, size=n_samples) + df_plot = pd.DataFrame({ + "x": rng.randn(n_samples), + "y": rng.randn(n_samples), + "cluster": labels.astype(str), + "image_path": [f"/img/{i}.jpg" for i in range(n_samples)], + "idx": range(n_samples), + }) + return embeddings, labels, df_plot + + def test_summary_columns(self): + emb, 
labels, df = self._make_inputs() + summary, _ = ClusteringService.generate_clustering_summary(emb, labels, df) + assert set(summary.columns) == {"Cluster", "Count", "Variance"} + + def test_counts_sum_to_total(self): + emb, labels, df = self._make_inputs(n_samples=50) + summary, _ = ClusteringService.generate_clustering_summary(emb, labels, df) + assert summary["Count"].sum() == 50 + + def test_representatives_per_cluster(self): + emb, labels, df = self._make_inputs(n_samples=30, n_clusters=3) + _, reps = ClusteringService.generate_clustering_summary(emb, labels, df) + for cluster_id, indices in reps.items(): + cluster_size = (labels == cluster_id).sum() + assert len(indices) <= min(3, cluster_size) + + def test_single_sample_cluster(self): + """Cluster with 1 sample should have variance 0.""" + embeddings = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + labels = np.array([0, 1, 2]) + df = pd.DataFrame({"x": [0, 0, 0], "y": [0, 0, 0], "cluster": ["0", "1", "2"], "idx": [0, 1, 2]}) + summary, reps = ClusteringService.generate_clustering_summary(embeddings, labels, df) + # Each cluster has 1 sample → variance = 0 + assert all(summary["Variance"] == 0.0) + assert all(len(v) == 1 for v in reps.values()) + + +# --------------------------------------------------------------------------- +# run_clustering_safe — fallback chain (mocked) +# --------------------------------------------------------------------------- + +class TestRunClusteringSafe: + def _dummy_args(self): + rng = np.random.RandomState(0) + emb = rng.randn(20, 32).astype(np.float32) + paths = [f"uuid-{i}" for i in range(20)] + return emb, paths, 3, "PCA", 1, "auto", "auto", 42 + + def test_success_passthrough(self): + emb, paths, *rest = self._dummy_args() + # Should succeed via sklearn on CPU + df, labels = ClusteringService.run_clustering_safe(emb, paths, *rest) + assert len(df) == 20 + assert labels.shape == (20,) + + def test_gpu_error_triggers_sklearn_fallback(self): + emb, paths, 
*rest = self._dummy_args() + call_count = {"n": 0} + + def mock_run_clustering(embeddings, valid_paths, n_clusters, method, + n_workers, dim_backend, cluster_backend, seed): + call_count["n"] += 1 + if call_count["n"] == 1: + raise RuntimeError("CUDA error: no kernel image") + # Second call (fallback) should use sklearn + assert dim_backend == "sklearn" + assert cluster_backend == "sklearn" + return pd.DataFrame({"x": [0]*20, "y": [0]*20, "cluster": ["0"]*20, + "image_path": valid_paths, "file_name": valid_paths, + "idx": range(20)}), np.zeros(20, dtype=int) + + with patch.object(ClusteringService, "run_clustering", side_effect=mock_run_clustering): + df, labels = ClusteringService.run_clustering_safe(emb, paths, *rest) + assert call_count["n"] == 2 + + def test_oom_error_reraised(self): + emb, paths, *rest = self._dummy_args() + with patch.object(ClusteringService, "run_clustering", + side_effect=RuntimeError("CUDA out of memory")): + with pytest.raises(RuntimeError, match="out of memory"): + ClusteringService.run_clustering_safe(emb, paths, *rest) + + def test_non_gpu_error_reraised(self): + emb, paths, *rest = self._dummy_args() + with patch.object(ClusteringService, "run_clustering", + side_effect=RuntimeError("unexpected error")): + with pytest.raises(RuntimeError, match="unexpected error"): + ClusteringService.run_clustering_safe(emb, paths, *rest) diff --git a/tests/test_filters.py b/tests/test_filters.py new file mode 100644 index 0000000..30e6b15 --- /dev/null +++ b/tests/test_filters.py @@ -0,0 +1,163 @@ +"""Tests for filter logic in apps/precalculated/components/sidebar.py. + +These functions are pure data transformations — no Streamlit dependency. 
+""" + +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest + +from apps.precalculated.components.sidebar import ( + apply_filters_arrow, + get_column_info_dynamic, + extract_embeddings_safe, + create_cluster_dataframe, +) + + +# --------------------------------------------------------------------------- +# apply_filters_arrow +# --------------------------------------------------------------------------- + +class TestApplyFiltersArrow: + def test_categorical_filter(self, sample_arrow_table): + result = apply_filters_arrow(sample_arrow_table, {"species": ["cat"]}) + species_vals = result.column("species").to_pylist() + assert all(v == "cat" for v in species_vals) + + def test_numeric_range_filter(self, sample_arrow_table): + result = apply_filters_arrow(sample_arrow_table, { + "weight": {"min": 1.0, "max": 10.0} + }) + weights = result.column("weight").to_pylist() + assert all(1.0 <= w <= 10.0 for w in weights) + + def test_text_filter(self, sample_arrow_table): + result = apply_filters_arrow(sample_arrow_table, {"notes": "kitten"}) + notes = result.column("notes").to_pylist() + assert all("kitten" in n.lower() for n in notes) + + def test_text_filter_case_insensitive(self, sample_arrow_table): + result = apply_filters_arrow(sample_arrow_table, {"notes": "HEALTHY"}) + assert len(result) > 0 + + def test_multiple_filters_and_logic(self, sample_arrow_table): + result = apply_filters_arrow(sample_arrow_table, { + "species": ["cat"], + "weight": {"min": 3.0, "max": 5.0}, + }) + for i in range(len(result)): + assert result.column("species")[i].as_py() == "cat" + assert 3.0 <= result.column("weight")[i].as_py() <= 5.0 + + def test_empty_filters_returns_original(self, sample_arrow_table): + result = apply_filters_arrow(sample_arrow_table, {}) + assert len(result) == len(sample_arrow_table) + + def test_unknown_column_skipped(self, sample_arrow_table): + result = apply_filters_arrow(sample_arrow_table, {"nonexistent": ["x"]}) + assert len(result) == 
len(sample_arrow_table) + + def test_empty_list_filter_skipped(self, sample_arrow_table): + result = apply_filters_arrow(sample_arrow_table, {"species": []}) + assert len(result) == len(sample_arrow_table) + + +# --------------------------------------------------------------------------- +# get_column_info_dynamic +# --------------------------------------------------------------------------- + +class TestGetColumnInfoDynamic: + def test_detects_categorical(self, sample_arrow_table): + info = get_column_info_dynamic(sample_arrow_table) + assert info["species"]["type"] == "categorical" + + def test_detects_numeric(self, sample_arrow_table): + info = get_column_info_dynamic(sample_arrow_table) + assert info["weight"]["type"] == "numeric" + + def test_skips_excluded_columns(self, sample_arrow_table): + info = get_column_info_dynamic(sample_arrow_table) + assert "uuid" not in info + assert "emb" not in info + + def test_null_counting(self): + table = pa.table({ + "col": [1, None, 3, None, 5], + }) + info = get_column_info_dynamic(table) + assert info["col"]["null_count"] == 2 + assert info["col"]["null_percentage"] == 40.0 + + def test_high_cardinality_becomes_text(self): + """Columns with >100 unique values should be classified as text.""" + table = pa.table({ + "many_unique": [f"val_{i}" for i in range(150)], + }) + info = get_column_info_dynamic(table) + assert info["many_unique"]["type"] == "text" + + +# --------------------------------------------------------------------------- +# extract_embeddings_safe +# --------------------------------------------------------------------------- + +class TestExtractEmbeddingsSafe: + def test_valid_extraction(self): + emb_data = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + df = pd.DataFrame({"emb": emb_data, "id": [1, 2]}) + result = extract_embeddings_safe(df) + assert result.shape == (2, 3) + assert result.dtype == np.float32 + + def test_missing_emb_column_raises(self): + df = pd.DataFrame({"id": [1, 2]}) + with 
pytest.raises(ValueError, match="emb"): + extract_embeddings_safe(df) + + +# --------------------------------------------------------------------------- +# create_cluster_dataframe +# --------------------------------------------------------------------------- + +class TestCreateClusterDataframe: + def test_required_columns(self): + df = pd.DataFrame({ + "uuid": ["a", "b", "c"], + "emb": [[1, 2], [3, 4], [5, 6]], + "species": ["cat", "dog", "bird"], + }) + emb_2d = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) + labels = np.array([0, 1, 0]) + + result = create_cluster_dataframe(df, emb_2d, labels) + assert "x" in result.columns + assert "y" in result.columns + assert "cluster" in result.columns + assert "uuid" in result.columns + assert "idx" in result.columns + + def test_metadata_columns_copied(self): + df = pd.DataFrame({ + "uuid": ["a", "b"], + "emb": [[1, 2], [3, 4]], + "species": ["cat", "dog"], + }) + emb_2d = np.array([[0.1, 0.2], [0.3, 0.4]]) + labels = np.array([0, 1]) + + result = create_cluster_dataframe(df, emb_2d, labels) + assert "species" in result.columns + + def test_embedding_columns_excluded(self): + df = pd.DataFrame({ + "uuid": ["a", "b"], + "emb": [[1, 2], [3, 4]], + "embedding": [[1, 2], [3, 4]], + }) + emb_2d = np.array([[0.1, 0.2], [0.3, 0.4]]) + labels = np.array([0, 1]) + + result = create_cluster_dataframe(df, emb_2d, labels) + assert "embedding" not in result.columns diff --git a/tests/test_logging_config.py b/tests/test_logging_config.py new file mode 100644 index 0000000..8e969ae --- /dev/null +++ b/tests/test_logging_config.py @@ -0,0 +1,44 @@ +"""Tests for shared/utils/logging_config.py.""" + +import logging +import os + +from shared.utils.logging_config import configure_logging, get_logger + + +class TestGetLogger: + def test_returns_logger_with_correct_name(self, reset_logging): + logger = get_logger("my.module") + assert logger.name == "my.module" + + def test_returns_logger_instance(self, reset_logging): + logger = 
get_logger("test") + assert isinstance(logger, logging.Logger) + + +class TestConfigureLogging: + def test_adds_console_handler(self, reset_logging): + configure_logging() + root = logging.getLogger() + stream_handlers = [h for h in root.handlers if isinstance(h, logging.StreamHandler) + and not isinstance(h, logging.FileHandler)] + assert len(stream_handlers) == 1 + + def test_idempotent(self, reset_logging): + configure_logging() + handler_count = len(logging.getLogger().handlers) + configure_logging() + assert len(logging.getLogger().handlers) == handler_count + + def test_file_handler_created(self, reset_logging, tmp_path): + import shared.utils.logging_config as log_mod + original_dir = log_mod._LOG_DIR + log_mod._LOG_DIR = str(tmp_path) + try: + configure_logging(log_to_file=True) + root = logging.getLogger() + file_handlers = [h for h in root.handlers if isinstance(h, logging.FileHandler)] + assert len(file_handlers) == 1 + assert os.path.exists(os.path.join(str(tmp_path), "emb_explorer.log")) + finally: + log_mod._LOG_DIR = original_dir diff --git a/tests/test_taxonomy_tree.py b/tests/test_taxonomy_tree.py new file mode 100644 index 0000000..37e56ed --- /dev/null +++ b/tests/test_taxonomy_tree.py @@ -0,0 +1,138 @@ +"""Tests for shared/utils/taxonomy_tree.py.""" + +import numpy as np +import pandas as pd +import pytest + +from shared.utils.taxonomy_tree import ( + build_taxonomic_tree, + format_tree_string, + get_total_count, + get_tree_statistics, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_taxonomy_df(rows): + """Create DataFrame from list of (kingdom, phylum, class, order, family, genus, species) tuples.""" + cols = ["kingdom", "phylum", "class", "order", "family", "genus", "species"] + return pd.DataFrame(rows, columns=cols) + + +# --------------------------------------------------------------------------- +# 
build_taxonomic_tree +# --------------------------------------------------------------------------- + +class TestBuildTaxonomicTree: + def test_basic_nesting(self): + df = _make_taxonomy_df([ + ("Animalia", "Chordata", "Mammalia", "Carnivora", "Felidae", "Felis", "F. catus"), + ("Animalia", "Chordata", "Mammalia", "Carnivora", "Felidae", "Felis", "F. catus"), + ("Animalia", "Chordata", "Aves", "Passeriformes", "Passeridae", "Passer", "P. domesticus"), + ]) + tree = build_taxonomic_tree(df) + assert "Animalia" in tree + assert tree["Animalia"]["Chordata"]["Mammalia"]["Carnivora"]["Felidae"]["Felis"]["F. catus"] == 2 + + def test_nan_kingdom_excluded(self): + df = _make_taxonomy_df([ + (np.nan, "Chordata", "Mammalia", "Carnivora", "Felidae", "Felis", "F. catus"), + ("Animalia", "Chordata", "Aves", "Passeriformes", "Passeridae", "Passer", "P. domesticus"), + ]) + tree = build_taxonomic_tree(df) + assert get_total_count(tree) == 1 + + def test_nan_lower_level_becomes_unknown(self): + df = _make_taxonomy_df([ + ("Animalia", "Chordata", np.nan, np.nan, np.nan, np.nan, np.nan), + ]) + tree = build_taxonomic_tree(df) + assert "Unknown" in tree["Animalia"]["Chordata"] + + def test_empty_dataframe(self): + df = _make_taxonomy_df([]) + tree = build_taxonomic_tree(df) + assert tree == {} + + +# --------------------------------------------------------------------------- +# get_total_count +# --------------------------------------------------------------------------- + +class TestGetTotalCount: + def test_int_leaf(self): + assert get_total_count(5) == 5 + + def test_nested_dict(self): + tree = {"a": {"b": 3, "c": 2}, "d": 1} + assert get_total_count(tree) == 6 + + def test_empty_dict(self): + assert get_total_count({}) == 0 + + def test_non_int_non_dict(self): + assert get_total_count("invalid") == 0 + + +# --------------------------------------------------------------------------- +# format_tree_string +# 
--------------------------------------------------------------------------- + +class TestFormatTreeString: + def test_max_depth_truncation(self): + df = _make_taxonomy_df([ + ("Animalia", "Chordata", "Mammalia", "Carnivora", "Felidae", "Felis", "F. catus"), + ]) + tree = build_taxonomic_tree(df) + output = format_tree_string(tree, max_depth=2) + # Should show kingdom and phylum but not deeper + assert "Animalia" in output + assert "Chordata" in output + assert "Mammalia" not in output + + def test_min_count_filtering(self): + df = _make_taxonomy_df([ + ("Animalia", "Chordata", "Mammalia", "Carnivora", "Felidae", "Felis", "F. catus"), + ("Animalia", "Chordata", "Mammalia", "Carnivora", "Felidae", "Felis", "F. catus"), + ("Plantae", "Tracheophyta", "Magnoliopsida", "Rosales", "Rosaceae", "Rosa", "R. gallica"), + ]) + tree = build_taxonomic_tree(df) + output = format_tree_string(tree, min_count=2) + assert "Animalia" in output + # Plantae has count 1, should be filtered out + assert "Plantae" not in output + + def test_tree_connector_chars(self): + df = _make_taxonomy_df([ + ("Animalia", "Chordata", "Mammalia", "Carnivora", "Felidae", "Felis", "F. catus"), + ("Animalia", "Chordata", "Aves", "Passeriformes", "Passeridae", "Passer", "P. domesticus"), + ]) + tree = build_taxonomic_tree(df) + output = format_tree_string(tree) + # Should contain tree-drawing characters + assert any(c in output for c in ["├──", "└──"]) + + +# --------------------------------------------------------------------------- +# get_tree_statistics +# --------------------------------------------------------------------------- + +class TestGetTreeStatistics: + def test_counts(self): + df = _make_taxonomy_df([ + ("Animalia", "Chordata", "Mammalia", "Carnivora", "Felidae", "Felis", "F. catus"), + ("Animalia", "Chordata", "Aves", "Passeriformes", "Passeridae", "Passer", "P. 
domesticus"), + ]) + tree = build_taxonomic_tree(df) + stats = get_tree_statistics(tree) + assert stats["total_records"] == 2 + assert stats["kingdoms"] == 1 + assert stats["species"] == 2 + + def test_empty_tree(self): + stats = get_tree_statistics({}) + assert stats["total_records"] == 0 + assert stats["kingdoms"] == 0 diff --git a/utils/__init__.py b/utils/__init__.py deleted file mode 100644 index f3f8770..0000000 --- a/utils/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -""" -Utilities for the embedding explorer project. -""" - -__version__ = "0.1.0" \ No newline at end of file diff --git a/utils/clustering.py b/utils/clustering.py deleted file mode 100644 index 9e7801e..0000000 --- a/utils/clustering.py +++ /dev/null @@ -1,264 +0,0 @@ -from typing import Optional -import numpy as np -from sklearn.cluster import KMeans -from sklearn.decomposition import PCA -from sklearn.manifold import TSNE -from umap import UMAP - -# Optional FAISS support for faster clustering -try: - import faiss - HAS_FAISS = True -except ImportError: - HAS_FAISS = False - -# Optional cuML support for GPU acceleration -try: - import cuml - from cuml.cluster import KMeans as cuKMeans - from cuml.decomposition import PCA as cuPCA - from cuml.manifold import TSNE as cuTSNE - from cuml.manifold import UMAP as cuUMAP - import cupy as cp - HAS_CUML = True -except ImportError: - HAS_CUML = False - -# Check for CUDA availability -try: - import torch - HAS_CUDA = torch.cuda.is_available() -except ImportError: - try: - import cupy as cp - HAS_CUDA = cp.cuda.is_available() - except ImportError: - HAS_CUDA = False - -def reduce_dim(embeddings: np.ndarray, method: str = "PCA", seed: Optional[int] = None, n_workers: int = 1, backend: str = "auto"): - """ - Reduce the dimensionality of embeddings to 2D using PCA, t-SNE, or UMAP. - - Args: - embeddings (np.ndarray): The input feature embeddings of shape (n_samples, n_features). 
-        method (str, optional): The dimensionality reduction method, "PCA", "TSNE", or "UMAP". Defaults to "PCA".
-        seed (int, optional): Random seed for reproducibility. Defaults to None (random).
-        n_workers (int, optional): Number of parallel workers for t-SNE/UMAP. Defaults to 1.
-        backend (str, optional): Backend to use - "auto", "sklearn", "cuml". Defaults to "auto".
-
-    Returns:
-        np.ndarray: The 2D reduced embeddings of shape (n_samples, 2).
-
-    Raises:
-        ValueError: If an unsupported method is provided.
-    """
-    # Determine which backend to use
-    use_cuml = False
-    if backend == "cuml" and HAS_CUML and HAS_CUDA:
-        use_cuml = True
-    elif backend == "auto" and HAS_CUML and HAS_CUDA and embeddings.shape[0] > 5000:
-        # Use cuML automatically for large datasets on GPU
-        use_cuml = True
-
-    if use_cuml:
-        return _reduce_dim_cuml(embeddings, method, seed, n_workers)
-    else:
-        return _reduce_dim_sklearn(embeddings, method, seed, n_workers)
-
-
-def _reduce_dim_sklearn(embeddings: np.ndarray, method: str, seed: Optional[int], n_workers: int):
-    """Dimensionality reduction using sklearn/umap backends."""
-    if method.upper() == "PCA":
-        reducer = PCA(n_components=2)
-    elif method.upper() == "TSNE":
-        # Adjust perplexity to be valid for the sample size
-        n_samples = embeddings.shape[0]
-        perplexity = min(30, max(5, n_samples // 3))  # Ensure perplexity is reasonable
-
-        if seed is not None:
-            reducer = TSNE(n_components=2, perplexity=perplexity, random_state=seed, n_jobs=n_workers)
-        else:
-            reducer = TSNE(n_components=2, perplexity=perplexity, n_jobs=n_workers)
-    elif method.upper() == "UMAP":
-        if seed is not None:
-            reducer = UMAP(n_components=2, random_state=seed, n_jobs=n_workers)
-        else:
-            reducer = UMAP(n_components=2, n_jobs=n_workers)
-    else:
-        raise ValueError("Unsupported method. Choose 'PCA', 'TSNE', or 'UMAP'.")
-    return reducer.fit_transform(embeddings)
-
-
-def _reduce_dim_cuml(embeddings: np.ndarray, method: str, seed: Optional[int], n_workers: int):
-    """Dimensionality reduction using cuML GPU backends."""
-    try:
-        # Convert to cupy array for GPU processing
-        embeddings_gpu = cp.asarray(embeddings, dtype=cp.float32)
-
-        if method.upper() == "PCA":
-            reducer = cuPCA(n_components=2)
-        elif method.upper() == "TSNE":
-            # Adjust perplexity to be valid for the sample size
-            n_samples = embeddings.shape[0]
-            perplexity = min(30, max(5, n_samples // 3))  # Ensure perplexity is reasonable
-
-            if seed is not None:
-                reducer = cuTSNE(n_components=2, perplexity=perplexity, random_state=seed)
-            else:
-                reducer = cuTSNE(n_components=2, perplexity=perplexity)
-        elif method.upper() == "UMAP":
-            if seed is not None:
-                reducer = cuUMAP(n_components=2, random_state=seed)
-            else:
-                reducer = cuUMAP(n_components=2)
-        else:
-            raise ValueError("Unsupported method. Choose 'PCA', 'TSNE', or 'UMAP'.")
-
-        # Fit and transform on GPU
-        result_gpu = reducer.fit_transform(embeddings_gpu)
-
-        # Convert back to numpy array
-        return cp.asnumpy(result_gpu)
-
-    except Exception as e:
-        print(f"cuML reduction failed ({e}), falling back to sklearn")
-        return _reduce_dim_sklearn(embeddings, method, seed, n_workers)
-
-
-def run_kmeans(embeddings: np.ndarray, n_clusters: int, seed: Optional[int] = None, n_workers: int = 1, backend: str = "auto"):
-    """
-    Perform KMeans clustering on the given embeddings.
-
-    Args:
-        embeddings (np.ndarray): The input feature embeddings of shape (n_samples, n_features).
-        n_clusters (int): The number of clusters to form.
-        seed (int, optional): Random seed for reproducibility. Defaults to None (random).
-        n_workers (int, optional): Number of parallel workers (used by FAISS and cuML if available).
-        backend (str, optional): Clustering backend - "auto", "sklearn", "faiss", or "cuml". Defaults to "auto".
-
-    Returns:
-        kmeans (KMeans or custom object): The fitted clustering object.
-        labels (np.ndarray): Cluster labels for each sample.
-    """
-    # Determine which backend to use
-    if backend == "cuml" and HAS_CUML and HAS_CUDA:
-        return _run_kmeans_cuml(embeddings, n_clusters, seed, n_workers)
-    elif backend == "faiss" and HAS_FAISS:
-        return _run_kmeans_faiss(embeddings, n_clusters, seed, n_workers)
-    elif backend == "auto":
-        # Auto selection priority: cuML > FAISS > sklearn
-        if HAS_CUML and HAS_CUDA and embeddings.shape[0] > 500:
-            return _run_kmeans_cuml(embeddings, n_clusters, seed, n_workers)
-        elif HAS_FAISS and embeddings.shape[0] > 500:
-            return _run_kmeans_faiss(embeddings, n_clusters, seed, n_workers)
-        else:
-            return _run_kmeans_sklearn(embeddings, n_clusters, seed)
-    else:
-        return _run_kmeans_sklearn(embeddings, n_clusters, seed)
-
-
-def _run_kmeans_cuml(embeddings: np.ndarray, n_clusters: int, seed: Optional[int] = None, n_workers: int = 1):
-    """KMeans using cuML GPU backend."""
-    try:
-        # Convert to cupy array for GPU processing
-        embeddings_gpu = cp.asarray(embeddings, dtype=cp.float32)
-
-        # Create cuML KMeans object
-        if seed is not None:
-            kmeans = cuKMeans(
-                n_clusters=n_clusters,
-                random_state=seed,
-                max_iter=300,
-                init='k-means++',
-                tol=1e-4
-            )
-        else:
-            kmeans = cuKMeans(
-                n_clusters=n_clusters,
-                max_iter=300,
-                init='k-means++',
-                tol=1e-4
-            )
-
-        # Fit and predict on GPU
-        labels_gpu = kmeans.fit_predict(embeddings_gpu)
-
-        # Convert results back to numpy
-        labels = cp.asnumpy(labels_gpu)
-        centroids = cp.asnumpy(kmeans.cluster_centers_)
-
-        # Create a simple object to mimic sklearn KMeans interface
-        class cuMLKMeans:
-            def __init__(self, centroids, labels):
-                self.cluster_centers_ = centroids
-                self.labels_ = labels
-                self.n_clusters = len(centroids)
-
-        return cuMLKMeans(centroids, labels), labels
-
-    except Exception as e:
-        print(f"cuML clustering failed ({e}), falling back to sklearn")
-        return _run_kmeans_sklearn(embeddings, n_clusters, seed)
-
-
-def _run_kmeans_sklearn(embeddings: np.ndarray, n_clusters: int, seed: Optional[int] = None):
-    """KMeans using scikit-learn backend."""
-    if seed is not None:
-        kmeans = KMeans(n_clusters=n_clusters, random_state=seed)
-    else:
-        kmeans = KMeans(n_clusters=n_clusters)
-    labels = kmeans.fit_predict(embeddings)
-    return kmeans, labels
-
-
-def _run_kmeans_faiss(embeddings: np.ndarray, n_clusters: int, seed: Optional[int] = None, n_workers: int = 1):
-    """KMeans using FAISS backend for faster clustering."""
-    try:
-        import faiss
-
-        # Ensure embeddings are float32 and C-contiguous (FAISS requirement)
-        embeddings = np.ascontiguousarray(embeddings.astype(np.float32))
-
-        n_samples, d = embeddings.shape
-
-        # Set number of threads for FAISS
-        if n_workers > 1:
-            faiss.omp_set_num_threads(n_workers)
-
-        # Create FAISS KMeans object
-        kmeans = faiss.Clustering(d, n_clusters)
-
-        # Set clustering parameters
-        kmeans.verbose = False
-        kmeans.niter = 20  # Number of iterations
-        kmeans.nredo = 1  # Number of redos
-        if seed is not None:
-            kmeans.seed = seed
-
-        # Use L2 distance (equivalent to sklearn's default)
-        index = faiss.IndexFlatL2(d)
-
-        # Run clustering
-        kmeans.train(embeddings, index)
-
-        # Get centroids
-        centroids = faiss.vector_to_array(kmeans.centroids).reshape(n_clusters, d)
-
-        # Assign labels by finding nearest centroid for each point
-        _, labels = index.search(embeddings, 1)
-        labels = labels.flatten()
-
-        # Create a simple object to mimic sklearn KMeans interface
-        class FAISSKMeans:
-            def __init__(self, centroids, labels):
-                self.cluster_centers_ = centroids
-                self.labels_ = labels
-                self.n_clusters = len(centroids)
-
-        return FAISSKMeans(centroids, labels), labels
-
-    except Exception as e:
-        # Fallback to sklearn if FAISS fails
-        print(f"FAISS clustering failed ({e}), falling back to sklearn")
-        return _run_kmeans_sklearn(embeddings, n_clusters, seed)
-