diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..37776aa --- /dev/null +++ b/.flake8 @@ -0,0 +1,9 @@ +[flake8] +max-line-length = 100 +extend-ignore = E203,W503,D100,D104 +per-file-ignores = + forte_api.py:D,F401,F841,E501 + forte_demo.py:D,F401,E501 + examples/*.py:D,E501 + tests/*.py:D,F401,F811,F841 + __init__.py:F401,D415 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..e39a0c4 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,134 @@ +name: CI + +on: + push: + branches: [ main, develop ] + pull_request: + branches: [ main, develop ] + +jobs: + test: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ['3.9', '3.10', '3.11', '3.12'] + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Cache pip dependencies + uses: actions/cache@v3 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('**/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-pip- + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Lint with flake8 + run: | + # Stop the build if there are Python syntax errors or undefined names + flake8 src/forte --count --select=E9,F63,F7,F82 --show-source --statistics + # Exit-zero treats all errors as warnings + flake8 src/forte --count --exit-zero --max-complexity=10 --max-line-length=100 --statistics + + - name: Format check with black + run: | + black --check src/forte tests + + - name: Run tests with pytest + run: | + pytest tests/ -v --cov=forte --cov-report=xml --cov-report=term -m "not slow" + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml + flags: unittests + name: codecov-umbrella + fail_ci_if_error: false + + integration-tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Run integration tests + run: | + pytest tests/ -v -m "integration and not slow" + + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install build dependencies + run: | + python -m pip install --upgrade pip + pip install build twine + + - name: Build package + run: python -m build + + - name: Check package with twine + run: twine check dist/* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: dist + path: dist/ + + docs: + runs-on: ubuntu-latest + if: github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/develop') + permissions: + contents: write + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[docs]" + + - name: Configure Git + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + - name: Deploy documentation + run: mkdocs gh-deploy --force diff --git a/.gitignore b/.gitignore index 1694133..f8b2959 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,12 @@ +# Project-specific data/* embeddings/* *.png +*.jpg +*.jpeg + +# Keep example images if needed +!docs/images/ # Byte-compiled / optimized / DLL files __pycache__/ @@ -149,6 +155,7 @@ venv.bak/ # mkdocs documentation /site +site/ # mypy .mypy_cache/ @@ -175,4 +182,4 @@ cython_debug/ .ruff_cache/ # PyPI configuration file -.pypirc \ No newline at end of file +.pypirc diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..1520611 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,69 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + exclude: ^mkdocs\.yml$ + - id: check-added-large-files + - id: check-json + - id: check-toml + - id: check-merge-conflict + - id: check-case-conflict + - id: detect-private-key + - id: mixed-line-ending + args: ['--fix=lf'] + - id: name-tests-test + args: ['--pytest-test-first'] + + - repo: https://github.com/psf/black + rev: 24.1.1 + hooks: + - id: black + language_version: python3.9 + + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (python) + + - repo: https://github.com/pycqa/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + args: ['--max-line-length=100', '--extend-ignore=E203,W503'] + additional_dependencies: [flake8-docstrings] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.8.0 + hooks: + - id: mypy + additional_dependencies: [ + torch, + torchvision, + transformers, + numpy, + scipy, + scikit-learn, + pillow, + tqdm, + ] + args: [--config-file=pyproject.toml, --ignore-missing-imports] + + - repo: https://github.com/PyCQA/bandit + rev: 1.7.6 + hooks: + - id: bandit + args: ['-c', 'pyproject.toml'] + additional_dependencies: ['bandit[toml]'] + + - repo: https://github.com/pycqa/pydocstyle + rev: 6.3.0 + hooks: + - id: pydocstyle + args: ['--convention=google', '--add-ignore=D212'] + exclude: '^(tests/|examples/|forte_api\.py|forte_demo\.py)' diff --git a/CHANGELOG.md b/CHANGELOG.md index 29ebcc6..e1da0de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,4 +12,4 @@ - Basic evaluation metrics: AUROC, FPR@95TPR, AUPRC, F1 ### Fixed -- None (initial release) \ No newline at end of file +- None (initial release) diff --git a/LICENSE b/LICENSE index 23b3df2..14f6a44 100644 --- a/LICENSE +++ b/LICENSE @@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file +SOFTWARE. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..f066133 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,28 @@ +# Include documentation +include README.md +include LICENSE +include CHANGELOG.md + +# Include package configuration +include pyproject.toml +include setup.py + +# Exclude development and build files +exclude .gitignore +exclude .git* +recursive-exclude * __pycache__ +recursive-exclude * *.py[co] +recursive-exclude * .DS_Store + +# Exclude tests, docs, and examples from distribution +recursive-exclude tests * +recursive-exclude docs * +recursive-exclude examples * +recursive-exclude env * +recursive-exclude embeddings * +recursive-exclude data * + +# Exclude build artifacts +recursive-exclude dist * +recursive-exclude build * +recursive-exclude *.egg-info * diff --git a/README.md b/README.md index a04d4e3..2dafa0c 100644 --- a/README.md +++ b/README.md @@ -1,199 +1,55 @@ -# Forte API Documentation +# Forte -## Overview +[![PyPI](https://badge.fury.io/py/forte-detector.svg)](https://pypi.org/project/forte-detector/) +[![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/) +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) +[![ICLR 2025](https://img.shields.io/badge/ICLR-2025-red.svg)](https://openreview.net/pdf?id=7XNgVPxCiA) -The Forte library provides robust out-of-distribution (OOD) detection capabilities through the `ForteOODDetector` class. The core algorithm is built on the principle of **F**inding **O**utliers using **R**epresentation **T**ypicality **E**stimation, which: +Out-of-distribution detection via per-point manifold estimation on self-supervised representations. -1. Uses self-supervised vision models to extract semantic features -2. Incorporates manifold estimation to account for local topology -3. Requires no class labels or exposure to OOD data during training +**Paper**: [PDF](https://openreview.net/pdf?id=7XNgVPxCiA) | [arXiv](https://arxiv.org/abs/2410.01322) -This makes Forte particularly useful for real-world applications where anomalous data may be unexpected or unknown at training time. Our goal is to provide a non-opinionated middleware for OOD detection that seamlessly integrates into your ML deployment pipelines. +**Documentation**: [debarghag.github.io/forte-api](https://debarghag.github.io/forte-api) -**Why use Forte?** -Forte OOD Detection serves as middleware between your data ingestion and ML inference systems, by preventing models from making predictions on data they weren't designed to handle. +## Installation -ICICLE Tag : Foundation-AI - - -## How-To Guide - -**Key Features inside Forte** - -- **Multiple feature extractors**: Leverages CLIP, ViT-MSN, and DINOv2 models for robust semantic representation -- **Topology-aware scoring**: Uses Precision, Recall, Density, and Coverage (PRDC) metrics to capture manifold structure -- **Multiple detection methods**: Supports Gaussian Mixture Models (GMM), Kernel Density Estimation (KDE), and One-Class SVM (OCSVM) -- **Automatic hyperparameter selection**: Optimizes model hyperparameters using validation data -- **Caching for efficiency**: Saves extracted features to avoid redundant computation - -## API Reference - -### `ForteOODDetector` - -The main class for OOD detection. - -```python -detector = ForteOODDetector( - batch_size=32, - device=None, - embedding_dir="./embeddings", - nearest_k=5, - method='gmm' -) -``` - -#### Parameters - -- **batch_size** (int, default=32): Batch size for processing images during feature extraction -- **device** (str, default=None): Device to use for computation (e.g., 'cuda:0', 'cpu'). If None, uses CUDA if available -- **embedding_dir** (str, default='./embeddings'): Directory to store extracted features for caching -- **nearest_k** (int, default=5): Number of nearest neighbors for PRDC computation -- **method** (str, default='gmm'): Method to use for OOD detection. Options: - - 'gmm': Gaussian Mixture Model (best for clustered data) - - 'kde': Kernel Density Estimation (best for smooth distributions) - - 'ocsvm': One-Class SVM (best for complex boundaries) - -### Methods - -#### `fit(id_image_paths, val_split=0.2, random_state=42)` - -Fits the OOD detector on in-distribution data. - -**Parameters:** -- **id_image_paths** (list): List of paths to in-distribution images -- **val_split** (float, default=0.2): Fraction of data to use for validation -- **random_state** (int, default=42): Random seed for reproducibility - -**Returns:** -- The fitted detector object - -**Process:** -1. Splits data into training and validation sets -2. Extracts features using pretrained models -3. Computes PRDC features -4. Trains the OOD detector (GMM, KDE, or OCSVM) - -```python -detector.fit(id_image_paths, val_split=0.2, random_state=42) -``` - -#### `predict(image_paths)` - -Predicts if samples are OOD. - -**Parameters:** -- **image_paths** (list): List of paths to images - -**Returns:** -- Binary array (1 for in-distribution, -1 for OOD) - -```python -predictions = detector.predict(test_image_paths) +```bash +pip install forte-detector ``` -#### `predict_proba(image_paths)` - -Returns normalized probability scores for OOD detection. - -**Parameters:** -- **image_paths** (list): List of paths to images - -**Returns:** -- Array of normalized scores (higher values indicate in-distribution) +## Usage ```python -scores = detector.predict_proba(test_image_paths) -``` - -#### `evaluate(id_image_paths, ood_image_paths)` - -Evaluates the OOD detector on in-distribution and out-of-distribution data. +from forte import ForteOODDetector -**Parameters:** -- **id_image_paths** (list): List of paths to in-distribution images -- **ood_image_paths** (list): List of paths to out-of-distribution images - -**Returns:** -- Dictionary of evaluation metrics: - - **AUROC**: Area Under the Receiver Operating Characteristic curve - - **FPR@95TPR**: False Positive Rate at 95% True Positive Rate - - **AUPRC**: Area Under the Precision-Recall Curve - - **F1**: Maximum F1 score - -```python -metrics = detector.evaluate(id_image_paths, ood_image_paths) -print(f"AUROC: {metrics['AUROC']:.4f}") +detector = ForteOODDetector(method='gmm', device='cuda:0') +detector.fit(train_paths) +predictions = detector.predict(test_paths) +metrics = detector.evaluate(id_test_paths, ood_test_paths) ``` -## Tutorial - -### Basic Usage - -```python -from forte_api import ForteOODDetector -import glob - -# Collect in-distribution images -id_images = glob.glob("data/normal_class/*.jpg") - -# Split for training and testing -train_images = id_images[:800] -test_id_images = id_images[800:] +## Method -# Collect OOD images -ood_images = glob.glob("data/anomalies/*.jpg") +Forte detects OOD samples by: +1. Extracting features from CLIP, ViT-MSN, and DINOv2 +2. Computing per-point PRDC metrics using k-NN manifold geometry +3. Fitting a density estimator (GMM, KDE, or OCSVM) on PRDC features +4. Scoring test samples by typicality under the learned density -# Create and train detector -detector = ForteOODDetector( - batch_size=32, - device="cuda:0", - method="gmm" -) +No class labels or OOD exposure required during training. -# Train the detector -detector.fit(train_images) +## Citation -# Evaluate performance -metrics = detector.evaluate(test_id_images, ood_images) -print(f"AUROC: {metrics['AUROC']:.4f}") -print(f"FPR@95TPR: {metrics['FPR@95TPR']:.4f}") - -# Get predictions -predictions = detector.predict(ood_images) +```bibtex +@inproceedings{ganguly2025forte, + title={Forte: Finding Outliers with Representation Typicality Estimation}, + author={Ganguly, Debargha and Morningstar, Warren Richard and Yu, Andrew Seohwan and Chaudhary, Vipin}, + booktitle={The Thirteenth International Conference on Learning Representations}, + year={2025}, + url={https://openreview.net/pdf?id=7XNgVPxCiA} +} ``` -### Complete Example with CIFAR-10/CIFAR-100 - -For a complete example using CIFAR-10 as in-distribution and CIFAR-100 as out-of-distribution data, see the [examples/cifar_demo.py](examples/cifar_demo.py) script in the repository. - -### Experimenting with Different Methods - -```python -# Try different detection methods -methods = ['gmm', 'kde', 'ocsvm'] -results = {} - -for method in methods: - detector = ForteOODDetector(method=method) - detector.fit(train_images) - results[method] = detector.evaluate(test_id_images, ood_images) - -# Compare results -for method, metrics in results.items(): - print(f"{method.upper()} - AUROC: {metrics['AUROC']:.4f}, FPR@95TPR: {metrics['FPR@95TPR']:.4f}") -``` - -## Model Details - -### Feature Extraction Models - -Forte uses three pretrained models for feature extraction: - -1. **CLIP** (Contrastive Language-Image Pretraining): Captures semantic information aligned with natural language concepts -2. **ViT-MSN** (Vision Transformer with Masked Self-supervised Network): Captures fine-grained visual patterns -3. **DINOv2** (Self-supervised Vision Transformer): Captures hierarchical visual representations - -You may modify the code to use your own encoder if you wish. This may be a CNN or a ViT. Anything you want. +## License -### Acknowledgements -National Science Foundation (NSF) funded AI institute for Intelligent Cyberinfrastructure with Computational Learning in the Environment (ICICLE) (OAC 2112606) \ No newline at end of file +MIT. Supported by NSF ICICLE (OAC 2112606). diff --git a/docs/api-reference.md b/docs/api-reference.md new file mode 100644 index 0000000..db2a08c --- /dev/null +++ b/docs/api-reference.md @@ -0,0 +1,281 @@ +# API Reference + +## ForteOODDetector + +Main class for out-of-distribution detection. + +### Constructor + +```python +ForteOODDetector( + batch_size: int = 32, + device: str = None, + embedding_dir: str = "./embeddings", + nearest_k: int = 5, + method: str = "gmm" +) +``` + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `batch_size` | int | 32 | Images per forward pass | +| `device` | str | None | `'cuda:N'`, `'mps'`, or `'cpu'`. Auto-detects if None. | +| `embedding_dir` | str | `'./embeddings'` | Directory for cached features | +| `nearest_k` | int | 5 | k for k-NN in PRDC computation | +| `method` | str | `'gmm'` | Detection backend: `'gmm'`, `'kde'`, `'ocsvm'` | + +### Methods + +#### fit + +```python +fit(id_image_paths: List[str], val_split: float = 0.2, random_state: int = 42) -> ForteOODDetector +``` + +Train detector on in-distribution images. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `id_image_paths` | List[str] | required | Paths to ID training images | +| `val_split` | float | 0.2 | Fraction for hyperparameter tuning | +| `random_state` | int | 42 | Random seed | + +Returns: `self` + +#### predict + +```python +predict(image_paths: List[str]) -> np.ndarray +``` + +Binary OOD classification. + +Returns: `np.ndarray` of shape `(n,)` with dtype `int64`. Values: `1` (in-distribution), `-1` (out-of-distribution). + +#### predict_proba + +```python +predict_proba(image_paths: List[str]) -> np.ndarray +``` + +Normalized OOD scores. + +Returns: `np.ndarray` of shape `(n,)` with dtype `float64`. Range `[0, 1]`. Higher values indicate in-distribution. + +#### evaluate + +```python +evaluate(id_image_paths: List[str], ood_image_paths: List[str]) -> Dict[str, float] +``` + +Compute evaluation metrics on labeled test data. + +Returns: `dict` with keys: +- `AUROC`: Area under ROC curve +- `FPR@95TPR`: False positive rate at 95% true positive rate +- `AUPRC`: Area under precision-recall curve +- `F1`: Maximum F1 score across thresholds + +--- + +## TorchGMM + +GPU-accelerated Gaussian Mixture Model. + +### Constructor + +```python +TorchGMM( + n_components: int = 1, + covariance_type: str = "full", + max_iter: int = 100, + tol: float = 1e-3, + reg_covar: float = 1e-6, + device: str = "cuda" +) +``` + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `n_components` | int | 1 | Number of mixture components | +| `covariance_type` | str | `"full"` | Only `"full"` supported | +| `max_iter` | int | 100 | Maximum EM iterations | +| `tol` | float | 1e-3 | Convergence threshold | +| `reg_covar` | float | 1e-6 | Covariance regularization | +| `device` | str | `"cuda"` | Computation device | + +### Methods + +#### fit + +```python +fit(X: torch.Tensor) -> TorchGMM +``` + +Fit GMM via EM algorithm. + +| Parameter | Type | Description | +|-----------|------|-------------| +| `X` | torch.Tensor | Shape `(n_samples, n_features)` | + +Returns: `self` + +#### score_samples + +```python +score_samples(X: torch.Tensor) -> torch.Tensor +``` + +Compute log-likelihood per sample. + +Returns: `torch.Tensor` of shape `(n_samples,)` + +#### bic + +```python +bic(X: torch.Tensor) -> float +``` + +Bayesian Information Criterion. + +Returns: `float`. Lower is better. + +--- + +## TorchKDE + +GPU-accelerated Kernel Density Estimation. + +### Constructor + +```python +TorchKDE( + dataset: torch.Tensor, + bw_method: Optional[Union[str, float, Callable]] = None, + weights: Optional[torch.Tensor] = None, + device: str = "cuda" +) +``` + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `dataset` | torch.Tensor | required | Shape `(d, n)` where d=dimension, n=samples | +| `bw_method` | str/float/Callable | None | `'scott'`, `'silverman'`, or scalar. None defaults to Scott. | +| `weights` | torch.Tensor | None | Sample weights of shape `(n,)` | +| `device` | str | `"cuda"` | Computation device | + +### Methods + +#### evaluate + +```python +evaluate(points: torch.Tensor) -> torch.Tensor +``` + +Evaluate density at given points. + +| Parameter | Type | Description | +|-----------|------|-------------| +| `points` | torch.Tensor | Shape `(d, m)` or `(m, d)` | + +Returns: `torch.Tensor` of shape `(m,)` + +#### logpdf + +```python +logpdf(points: torch.Tensor) -> torch.Tensor +``` + +Log probability density. + +Returns: `torch.Tensor` of shape `(m,)` + +#### scotts_factor + +```python +scotts_factor() -> float +``` + +Returns: Scott's bandwidth factor: $n_{\text{eff}}^{-1/(d+4)}$ + +#### silverman_factor + +```python +silverman_factor() -> float +``` + +Returns: Silverman's bandwidth factor: $(n_{\text{eff}}(d+2)/4)^{-1/(d+4)}$ + +--- + +## TorchOCSVM + +GPU-accelerated One-Class SVM. + +### Constructor + +```python +TorchOCSVM( + nu: float = 0.1, + n_iters: int = 1000, + lr: float = 1e-3, + device: str = "cuda" +) +``` + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `nu` | float | 0.1 | Upper bound on outlier fraction (0, 1) | +| `n_iters` | int | 1000 | Optimization iterations | +| `lr` | float | 1e-3 | Adam learning rate | +| `device` | str | `"cuda"` | Computation device | + +### Methods + +#### fit + +```python +fit(X: torch.Tensor) -> TorchOCSVM +``` + +Fit via gradient descent on primal objective. + +| Parameter | Type | Description | +|-----------|------|-------------| +| `X` | torch.Tensor | Shape `(n_samples, n_features)` | + +Returns: `self` + +#### decision_function + +```python +decision_function(X: torch.Tensor) -> torch.Tensor +``` + +Signed distance to decision boundary. + +Returns: `torch.Tensor` of shape `(n_samples,)`. Positive = inlier. + +#### predict + +```python +predict(X: torch.Tensor) -> torch.Tensor +``` + +Binary classification. + +Returns: `torch.Tensor` of shape `(n_samples,)`. Values: `1` (inlier), `-1` (outlier). + +--- + +## Module Exports + +```python +from forte import ( + ForteOODDetector, + TorchGMM, + TorchKDE, + TorchOCSVM, + __version__, +) +``` diff --git a/docs/citation.md b/docs/citation.md new file mode 100644 index 0000000..c2a868f --- /dev/null +++ b/docs/citation.md @@ -0,0 +1,45 @@ +# Citation + +## Paper + +```bibtex +@inproceedings{ganguly2025forte, + title={Forte: Finding Outliers with Representation Typicality Estimation}, + author={Ganguly, Debargha and Morningstar, Warren Richard and Yu, Andrew Seohwan and Chaudhary, Vipin}, + booktitle={The Thirteenth International Conference on Learning Representations}, + year={2025}, + url={https://openreview.net/pdf?id=7XNgVPxCiA} +} +``` + +**Links**: +- [PDF](https://openreview.net/pdf?id=7XNgVPxCiA) +- [arXiv](https://arxiv.org/abs/2410.01322) + +## Software + +```bibtex +@software{forte_detector, + author = {Ganguly, Debargha and Morningstar, Warren Richard and Yu, Andrew Seohwan and Chaudhary, Vipin}, + title = {Forte Detector}, + year = {2025}, + url = {https://github.com/debarghag/forte-detector} +} +``` + +## License + +MIT License. See [LICENSE](https://github.com/debarghag/forte-detector/blob/main/LICENSE). + +## Acknowledgements + +Supported by NSF ICICLE (OAC 2112606). + +Forte uses pretrained models from: +- [CLIP](https://github.com/openai/CLIP) (OpenAI) +- [ViT-MSN](https://github.com/facebookresearch/msn) (Meta AI) +- [DINOv2](https://github.com/facebookresearch/dinov2) (Meta AI) + +## Contact + +- GitHub Issues: [github.com/debarghag/forte-detector/issues](https://github.com/debarghag/forte-detector/issues) diff --git a/docs/examples.md b/docs/examples.md new file mode 100644 index 0000000..4cfdd3c --- /dev/null +++ b/docs/examples.md @@ -0,0 +1,65 @@ +# Examples + +## CIFAR-10 vs CIFAR-100 + +```python +import os +import torch +import torchvision +from torchvision import transforms +from forte import ForteOODDetector + +def save_images(dataset, path, n=1000): + os.makedirs(path, exist_ok=True) + paths = [] + for i in range(min(n, len(dataset))): + img, _ = dataset[i] + if isinstance(img, torch.Tensor): + img = transforms.ToPILImage()(img) + p = os.path.join(path, f"{i}.png") + img.save(p) + paths.append(p) + return paths + +cifar10_train = torchvision.datasets.CIFAR10('./data', train=True, download=True) +cifar10_test = torchvision.datasets.CIFAR10('./data', train=False, download=True) +cifar100_test = torchvision.datasets.CIFAR100('./data', train=False, download=True) + +id_train = save_images(cifar10_train, 'data/c10/train', 5000) +id_test = save_images(cifar10_test, 'data/c10/test', 1000) +ood_test = save_images(cifar100_test, 'data/c100/test', 1000) + +detector = ForteOODDetector(method='gmm', device='cuda:0') +detector.fit(id_train) +print(detector.evaluate(id_test, ood_test)) +``` + +## Custom Dataset + +```python +from pathlib import Path +from forte import ForteOODDetector + +id_train = list(Path('data/normal/train').glob('*.jpg')) +id_test = list(Path('data/normal/test').glob('*.jpg')) +ood_test = list(Path('data/anomaly').glob('*.jpg')) + +detector = ForteOODDetector(method='gmm') +detector.fit([str(p) for p in id_train]) +print(detector.evaluate([str(p) for p in id_test], [str(p) for p in ood_test])) +``` + +## Method Comparison + +```python +from forte import ForteOODDetector + +results = {} +for method in ['gmm', 'kde', 'ocsvm']: + det = ForteOODDetector(method=method, embedding_dir=f'./cache_{method}') + det.fit(train_paths) + results[method] = det.evaluate(id_test, ood_test) + +for m, r in results.items(): + print(f"{m}: AUROC={r['AUROC']:.4f} FPR@95={r['FPR@95TPR']:.4f}") +``` diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..671a636 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,59 @@ +# Forte + +Out-of-distribution detection via per-point manifold estimation on self-supervised representations. + +**Paper**: [ICLR 2025](https://openreview.net/pdf?id=7XNgVPxCiA) + +## Method + +Forte detects OOD samples by: + +1. Extracting features from CLIP, ViT-MSN, and DINOv2 +2. Computing per-point PRDC metrics using k-NN manifold geometry +3. Fitting a density estimator (GMM, KDE, or OCSVM) on the PRDC feature space +4. Scoring test samples by their typicality under the learned density + +No class labels or OOD exposure required during training. + +## Installation + +```bash +pip install forte-detector +``` + +## Example + +```python +from forte import ForteOODDetector + +detector = ForteOODDetector(method='gmm', device='cuda:0') +detector.fit(train_paths) +predictions = detector.predict(test_paths) +metrics = detector.evaluate(id_test_paths, ood_test_paths) +``` + +## Documentation + +- [Quickstart](quickstart.md) +- [Algorithm](methods.md) +- [API Reference](api-reference.md) +- [Configuration](user-guide.md) +- [Examples](examples.md) + +## Citation + +```bibtex +@inproceedings{ganguly2025forte, + title={Forte: Finding Outliers with Representation Typicality Estimation}, + author={Ganguly, Debargha and Morningstar, Warren Richard and Yu, Andrew Seohwan and Chaudhary, Vipin}, + booktitle={The Thirteenth International Conference on Learning Representations}, + year={2025}, + url={https://openreview.net/pdf?id=7XNgVPxCiA} +} +``` + +## License + +MIT. See [LICENSE](https://github.com/debarghag/forte-detector/blob/main/LICENSE). + +Supported by NSF ICICLE (OAC 2112606). diff --git a/docs/installation.md b/docs/installation.md new file mode 100644 index 0000000..595231f --- /dev/null +++ b/docs/installation.md @@ -0,0 +1,81 @@ +# Installation + +## Requirements + +- Python 3.9+ +- CUDA 11.8+ (optional, for GPU) + +## Dependencies + +Core (installed automatically): +- torch >= 2.0.0 +- torchvision >= 0.15.0 +- transformers >= 4.30.0 +- numpy >= 1.24.0 +- scipy >= 1.10.0 +- scikit-learn >= 1.3.0 +- pillow >= 9.0.0 +- tqdm >= 4.65.0 + +## PyPI + +```bash +pip install forte-detector +``` + +Optional extras: +```bash +pip install forte-detector[dev] # pytest, black, flake8, mypy +pip install forte-detector[docs] # mkdocs, mkdocs-material +pip install forte-detector[viz] # matplotlib +pip install forte-detector[all] # all optional dependencies +``` + +## From Source + +```bash +git clone https://github.com/debarghag/forte-detector.git +cd forte-detector +pip install -e ".[dev]" +``` + +## Verify + +```python +from forte import ForteOODDetector +print(ForteOODDetector.__module__) +``` + +## GPU Setup + +### CUDA + +```bash +# CUDA 11.8 +pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118 + +# CUDA 12.1 +pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 +``` + +Verify: +```python +import torch +print(torch.cuda.is_available()) +``` + +### MPS (Apple Silicon) + +```python +import torch +print(torch.backends.mps.is_available()) +``` + +## First Run + +First call to `fit()` downloads pretrained models (~2GB total): +- `openai/clip-vit-base-patch32` +- `facebook/vit-msn-base` +- `facebook/dinov2-base` + +Models cached to `~/.cache/huggingface/`. diff --git a/docs/javascripts/mathjax.js b/docs/javascripts/mathjax.js new file mode 100644 index 0000000..06dbf38 --- /dev/null +++ b/docs/javascripts/mathjax.js @@ -0,0 +1,16 @@ +window.MathJax = { + tex: { + inlineMath: [["\\(", "\\)"]], + displayMath: [["\\[", "\\]"]], + processEscapes: true, + processEnvironments: true + }, + options: { + ignoreHtmlClass: ".*|", + processHtmlClass: "arithmatex" + } +}; + +document$.subscribe(() => { + MathJax.typesetPromise() +}) diff --git a/docs/methods.md b/docs/methods.md new file mode 100644 index 0000000..eb56edb --- /dev/null +++ b/docs/methods.md @@ -0,0 +1,121 @@ +# Algorithm + +## Problem + +Given reference data $\mathbf{X}_{\text{ref}} = \{x_i^r\}_{i=1}^m \sim P$ and test data $\mathbf{X}_{\text{test}} = \{x_j^g\}_{j=1}^n \sim \alpha P + (1-\alpha) Q$ where $Q$ is an unknown OOD distribution and $\alpha \in [0,1]$ is unknown, determine which $x_j^g \notin \text{supp}(P)$. + +## Notation + +| Symbol | Definition | +|--------|------------| +| $\text{NND}_k(x)$ | Distance from $x$ to its $k$-th nearest neighbor | +| $B(x, r)$ | Closed ball $\{y : \|x - y\| \leq r\}$ | +| $S(\mathbf{X})$ | $\bigcup_{i} B(x_i, \text{NND}_k(x_i))$ | +| $\mathbf{1}[\cdot]$ | Indicator function | + +## Per-Point PRDC Metrics + +For each test point $x_j^g$, compute four statistics relative to $\mathbf{X}_{\text{ref}}$: + +**Precision** (binary): +$$\text{precision}_j = \mathbf{1}\left[x_j^g \in S(\mathbf{X}_{\text{ref}})\right]$$ + +**Recall** (continuous): +$$\text{recall}_j = \frac{1}{m} \sum_{i=1}^{m} \mathbf{1}\left[x_i^r \in B(x_j^g, \text{NND}_k(x_j^g))\right]$$ + +**Density** (continuous): +$$\text{density}_j = \frac{1}{km} \sum_{i=1}^{m} \mathbf{1}\left[x_j^g \in B(x_i^r, \text{NND}_k(x_i^r))\right]$$ + +**Coverage** (binary): +$$\text{coverage}_j = \mathbf{1}\left[\min_i \|x_j^g - x_i^r\| < \text{NND}_k(x_j^g)\right]$$ + +These metrics capture local manifold geometry. OOD samples fall outside high-density regions, yielding low metric values. See [paper](https://openreview.net/pdf?id=7XNgVPxCiA) Section 3 for theoretical analysis. + +## Feature Extraction + +| Model | Dim | HuggingFace ID | +|-------|-----|----------------| +| CLIP ViT-B/32 | 512 | `openai/clip-vit-base-patch32` | +| ViT-MSN | 768 | `facebook/vit-msn-base` | +| DINOv2 | 768 | `facebook/dinov2-base` | + +For each image, extract CLS token embeddings from all three models. PRDC computed independently per model, then concatenated: 4 metrics × 3 models = 12-dimensional feature vector. + +## Training Procedure + +``` +Input: ID image paths, method ∈ {gmm, kde, ocsvm}, k +Output: Fitted detector + +1. Extract features F_ref for all images +2. Split F_ref into F_train (50%) and F_val (50%) +3. For each model m ∈ {clip, vitmsn, dinov2}: + Compute NND_k radii on F_train[m] + Compute PRDC(F_train[m], F_val[m]) → 4-dim vector per sample +4. Concatenate PRDC vectors → Z ∈ R^{n×12} +5. Fit density estimator on Z: + GMM: Select components via BIC from {1,2,4,8,16,32,64} + KDE: Bandwidth via Scott's rule + OCSVM: Select ν from {0.01,0.05,0.1,0.2,0.5} by validation accuracy +``` + +## Inference + +``` +Input: Test image paths +Output: Scores (higher = more likely ID) + +1. Extract features F_test +2. For each model m: + Compute PRDC(F_train[m], F_test[m]) +3. Concatenate → Z_test ∈ R^{n×12} +4. Score: + GMM: log p(z) + KDE: log p(z) + OCSVM: decision function value +``` + +## Density Estimators + +### GMM + +Mixture of $K$ Gaussians: +$$p(z) = \sum_{k=1}^{K} \pi_k \mathcal{N}(z \mid \mu_k, \Sigma_k)$$ + +Component count selected by minimizing BIC: +$$\text{BIC} = -2 \log \mathcal{L} + p \log n$$ + +### KDE + +Non-parametric density with Gaussian kernel: +$$p(z) = \frac{1}{n} \sum_{i=1}^{n} K_h(z - z_i)$$ + +Bandwidth $h$ via Scott's rule: $h = n^{-1/(d+4)} \sigma$ + +### OCSVM + +Finds hyperplane separating origin from data: +$$\min_{w,\rho,\xi} \frac{1}{2}\|w\|^2 - \rho + \frac{1}{\nu n} \sum_i \xi_i$$ +subject to $w^\top z_i \geq \rho - \xi_i$, $\xi_i \geq 0$ + +Score: $w^\top z - \rho$ + +## Complexity + +| Operation | Time | Space | +|-----------|------|-------| +| Feature extraction | $O(n \cdot T_{\text{forward}})$ | $O(n \cdot d)$ | +| Pairwise distances | $O(n^2 \cdot d)$ | $O(n^2)$ | +| GMM training | $O(K \cdot I \cdot n \cdot d^2)$ | $O(K \cdot d^2)$ | +| KDE evaluation | $O(n_{\text{train}} \cdot n_{\text{test}})$ | $O(n_{\text{train}} \cdot d)$ | +| OCSVM training | $O(T_{\text{opt}} \cdot n \cdot d)$ | $O(d)$ | + +Where $K$ = GMM components, $I$ = EM iterations, $d$ = feature dimension. + +## Method Selection + +| Method | Use when | +|--------|----------| +| GMM | Default choice. Multi-modal ID distributions. | +| KDE | Small datasets (<1000). Smooth decision boundaries. | +| OCSVM | Large datasets. Fast inference required. | diff --git a/docs/quickstart.md b/docs/quickstart.md new file mode 100644 index 0000000..c91c4a5 --- /dev/null +++ b/docs/quickstart.md @@ -0,0 +1,58 @@ +# Quickstart + +## Install + +```bash +pip install forte-detector +``` + +## Train + +```python +from forte import ForteOODDetector + +detector = ForteOODDetector(method='gmm', device='cuda:0') +detector.fit(train_image_paths) +``` + +First run downloads ~2GB of pretrained models. + +## Predict + +```python +predictions = detector.predict(test_paths) # 1=ID, -1=OOD +scores = detector.predict_proba(test_paths) # [0,1], higher=ID +``` + +## Evaluate + +```python +metrics = detector.evaluate(id_test_paths, ood_test_paths) +print(f"AUROC: {metrics['AUROC']:.4f}") +print(f"FPR@95: {metrics['FPR@95TPR']:.4f}") +``` + +## Full Example + +```python +import glob +from forte import ForteOODDetector + +# Collect image paths +id_train = glob.glob("data/normal/train/*.jpg") +id_test = glob.glob("data/normal/test/*.jpg") +ood_test = glob.glob("data/anomaly/test/*.jpg") + +# Train and evaluate +detector = ForteOODDetector(method='gmm', device='cuda:0') +detector.fit(id_train) +metrics = detector.evaluate(id_test, ood_test) + +print(metrics) +``` + +## Next + +- [Algorithm](methods.md) - How it works +- [API Reference](api-reference.md) - Full API +- [Configuration](user-guide.md) - Parameters diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css new file mode 100644 index 0000000..c9a1f18 --- /dev/null +++ b/docs/stylesheets/extra.css @@ -0,0 +1,24 @@ +/* Custom CSS for Forte Detector documentation */ + +:root { + --forte-primary: #3f51b5; + --forte-accent: #ff4081; +} + +.md-typeset h1 { + font-weight: 700; +} + +.md-typeset code { + background-color: rgba(63, 81, 181, 0.1); +} + +/* Custom admonition for paper references */ +.md-typeset .admonition.paper { + border-left-color: var(--forte-primary); +} + +/* Improved table styling */ +.md-typeset table:not([class]) { + font-size: 0.85em; +} diff --git a/docs/user-guide.md b/docs/user-guide.md new file mode 100644 index 0000000..d405822 --- /dev/null +++ b/docs/user-guide.md @@ -0,0 +1,81 @@ +# Configuration + +## Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `method` | str | `'gmm'` | `'gmm'`, `'kde'`, or `'ocsvm'` | +| `nearest_k` | int | 5 | k for k-NN manifold estimation | +| `batch_size` | int | 32 | Images per GPU forward pass | +| `device` | str | auto | `'cuda:N'`, `'mps'`, `'cpu'` | +| `embedding_dir` | str | `'./embeddings'` | Feature cache directory | + +## Method Selection + +| Method | Best for | Hyperparameter tuning | +|--------|----------|----------------------| +| GMM | Multi-modal distributions | Components via BIC (1-64) | +| KDE | Small datasets, smooth boundaries | Bandwidth via Scott's rule | +| OCSVM | Large datasets, fast inference | nu via validation (0.01-0.5) | + +## Device Selection + +Auto-detection priority: CUDA > MPS > CPU + +```python +# Force specific device +detector = ForteOODDetector(device='cuda:1') +detector = ForteOODDetector(device='cpu') +``` + +## Caching + +Features cached to `{embedding_dir}/{name}_{model}_features.pt` + +Cache is reused if file exists and sample count matches. Delete to force recomputation: + +```bash +rm -rf ./embeddings +``` + +## Memory + +GPU memory usage: +- Models: ~2-3 GB (CLIP + ViT-MSN + DINOv2) +- Features: ~4 bytes × n_samples × 2048 (all model dims) +- PRDC distances: O(n²) temporary + +Reduce `batch_size` if OOM. + +## Hyperparameter Tuning + +### nearest_k + +Controls manifold resolution. Larger k = smoother estimates, less sensitive to noise. + +| Dataset size | Recommended k | +|--------------|---------------| +| <1000 | 3-5 | +| 1000-10000 | 5-10 | +| >10000 | 10-20 | + +### val_split + +Fraction of training data used for hyperparameter selection. + +```python +detector.fit(paths, val_split=0.1) # 90% train, 10% validation +``` + +## Reproducibility + +```python +import torch +import numpy as np + +np.random.seed(42) +torch.manual_seed(42) +torch.cuda.manual_seed(42) + +detector.fit(paths, random_state=42) +``` diff --git a/examples/cifar_demo.py b/examples/cifar_demo.py new file mode 100644 index 0000000..90f8a3a --- /dev/null +++ b/examples/cifar_demo.py @@ -0,0 +1,420 @@ +import argparse +import logging +import os +import time + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torchvision +import torchvision.transforms as transforms +from PIL import Image +from tqdm import tqdm + +from forte import ForteOODDetector + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger("ForteDemo") + + +def save_dataset_as_png(dataset, save_dir, num_images=1000): + """ + Save a subset of a dataset as PNG images. + + Args: + dataset: PyTorch dataset + save_dir (str): Directory to save images + num_images (int): Number of images to save + + Returns: + list: List of paths to saved images + """ + logger.info(f"Saving {min(num_images, len(dataset))} images to {save_dir}") + os.makedirs(save_dir, exist_ok=True) + paths = [] + + for i in tqdm(range(min(num_images, len(dataset))), desc=f"Saving images to {save_dir}"): + image, label = dataset[i] + # Convert tensor to PIL Image + if isinstance(image, torch.Tensor): + image = transforms.ToPILImage()(image) + + # Save the image + path = os.path.join(save_dir, f"{i}_label{label}.png") + image.save(path) + paths.append(path) + + return paths + + +def load_cifar_datasets(): + """ + Load CIFAR10 and CIFAR100 datasets. + + Returns: + tuple: CIFAR10 train and test sets, CIFAR100 test set + """ + logger.info("Loading CIFAR10 and CIFAR100 datasets...") + # Define transform + transform = transforms.Compose([transforms.ToTensor()]) + + # Load CIFAR10 train and test sets + cifar10_train = torchvision.datasets.CIFAR10( + root="./data", train=True, download=True, transform=transform + ) + + cifar10_test = torchvision.datasets.CIFAR10( + root="./data", train=False, download=True, transform=transform + ) + + # Load CIFAR100 test set + cifar100_test = torchvision.datasets.CIFAR100( + root="./data", train=False, download=True, transform=transform + ) + + logger.info( + f"Loaded datasets - CIFAR10 train: {len(cifar10_train)} images, " + + f"CIFAR10 test: {len(cifar10_test)} images, " + + f"CIFAR100 test: {len(cifar100_test)} images" + ) + + return cifar10_train, cifar10_test, cifar100_test + + +def print_training_phases(): + """Print information about the phases of the Forte training pipeline.""" + phases = [ + ("1. Data Preparation", "Convert datasets to image files and prepare directories"), + ( + "2. Feature Extraction", + "Extract semantic features using pretrained models (CLIP, ViTMSN, DINOv2)", + ), + ( + "3. PRDC Computation", + "Compute Precision, Recall, Density, Coverage metrics from extracted features", + ), + ("4. Detector Training", "Train OOD detector (GMM, KDE, or OCSVM) on PRDC features"), + ("5. Evaluation", "Compute scores and performance metrics on test datasets"), + ] + + logger.info("\n=== Forte OOD Detection Pipeline ===") + for i, (phase, desc) in enumerate(phases): + logger.info(f"{phase}: {desc}") + logger.info("=" * 40) + + +def main(args): + # Print pipeline phases information + print_training_phases() + + # Set random seed for reproducibility + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(args.seed) + + logger.info(f"Running with configuration: {args}") + + # Create directories + os.makedirs("data", exist_ok=True) + os.makedirs(args.embedding_dir, exist_ok=True) + + # Phase 1: Data Preparation + logger.info("\n=== Phase 1: Data Preparation ===") + cifar10_train, cifar10_test, cifar100_test = load_cifar_datasets() + + # Create directories for images + os.makedirs("data/cifar10/train", exist_ok=True) + os.makedirs("data/cifar10/test", exist_ok=True) + os.makedirs("data/cifar100/test", exist_ok=True) + + # Check if we need to save images + if not os.path.exists("data/cifar10/train/0_label0.png") or args.force_save: + logger.info("Converting datasets to PNG images...") + # Save CIFAR10 training images + cifar10_train_paths = save_dataset_as_png( + cifar10_train, "data/cifar10/train", num_images=args.num_train_images + ) + + # Save CIFAR10 test images + cifar10_test_paths = save_dataset_as_png( + cifar10_test, "data/cifar10/test", num_images=args.num_test_images + ) + + # Save CIFAR100 test images + cifar100_test_paths = save_dataset_as_png( + cifar100_test, "data/cifar100/test", num_images=args.num_test_images + ) + else: + logger.info("Using previously saved images...") + cifar10_train_paths = sorted( + [ + os.path.join("data/cifar10/train", f) + for f in os.listdir("data/cifar10/train") + if f.endswith(".png") + ] + )[: args.num_train_images] + + cifar10_test_paths = sorted( + [ + os.path.join("data/cifar10/test", f) + for f in os.listdir("data/cifar10/test") + if f.endswith(".png") + ] + )[: args.num_test_images] + + cifar100_test_paths = sorted( + [ + os.path.join("data/cifar100/test", f) + for f in os.listdir("data/cifar100/test") + if f.endswith(".png") + ] + )[: args.num_test_images] + + logger.info(f"Number of CIFAR10 training images: {len(cifar10_train_paths)}") + logger.info(f"Number of CIFAR10 test images: {len(cifar10_test_paths)}") + logger.info(f"Number of CIFAR100 test images: {len(cifar100_test_paths)}") + + # Phase 2-4: Feature Extraction, PRDC Computation, and Detector Training + logger.info("\n=== Phase 2-4: Feature Extraction, PRDC Computation, and Detector Training ===") + start_time = time.time() + logger.info( + f"Creating ForteOODDetector with method: {args.method}, nearest_k: {args.nearest_k}" + ) + detector = ForteOODDetector( + batch_size=args.batch_size, + device=args.device, + embedding_dir=args.embedding_dir, + method=args.method, + nearest_k=args.nearest_k, + ) + + # Fit the detector - this performs feature extraction, PRDC computation, and detector training + logger.info(f"Fitting ForteOODDetector on {len(cifar10_train_paths)} in-distribution images...") + detector.fit(cifar10_train_paths, val_split=0.2, random_state=args.seed) + training_time = time.time() - start_time + logger.info(f"Training completed in {training_time:.2f} seconds") + + # Phase 5: Evaluation + logger.info("\n=== Phase 5: Evaluation ===") + + # Benchmark on ID data (CIFAR10 test) + logger.info("Benchmarking detector on CIFAR10 (in-distribution)...") + start_time = time.time() + id_scores = detector._get_ood_scores(cifar10_test_paths, cache_name="id_benchmark") + id_prediction_time = time.time() - start_time + logger.info( + f"ID prediction time for {len(cifar10_test_paths)} images: {id_prediction_time:.2f} seconds " + + f"({id_prediction_time/len(cifar10_test_paths):.4f} sec/image)" + ) + + # Benchmark on OOD data (CIFAR100 test) + logger.info("Benchmarking detector on CIFAR100 (out-of-distribution)...") + start_time = time.time() + ood_scores = detector._get_ood_scores(cifar100_test_paths, cache_name="ood_benchmark") + ood_prediction_time = time.time() - start_time + logger.info( + f"OOD prediction time for {len(cifar100_test_paths)} images: {ood_prediction_time:.2f} seconds " + + f"({ood_prediction_time/len(cifar100_test_paths):.4f} sec/image)" + ) + + # Score statistics + logger.info("\nScore Statistics:") + logger.info( + f"CIFAR10 (ID) - Mean: {np.mean(id_scores):.4f}, Std: {np.std(id_scores):.4f}, " + + f"Min: {np.min(id_scores):.4f}, Max: {np.max(id_scores):.4f}" + ) + logger.info( + f"CIFAR100 (OOD) - Mean: {np.mean(ood_scores):.4f}, Std: {np.std(ood_scores):.4f}, " + + f"Min: {np.min(ood_scores):.4f}, Max: {np.max(ood_scores):.4f}" + ) + + # Calculate threshold based on ID scores + threshold = np.percentile(id_scores, 5) # 5th percentile + logger.info(f"Suggested decision threshold (5th percentile of ID scores): {threshold:.4f}") + + # Calculate detection accuracy + id_correct = (id_scores > threshold).mean() + ood_correct = (ood_scores <= threshold).mean() + overall_acc = (id_correct * len(id_scores) + ood_correct * len(ood_scores)) / ( + len(id_scores) + len(ood_scores) + ) + logger.info(f"ID Detection Rate: {id_correct:.4f}, OOD Detection Rate: {ood_correct:.4f}") + logger.info(f"Overall Accuracy: {overall_acc:.4f}") + + # Full evaluation on mixed test set + logger.info("\nPerforming full evaluation on CIFAR10/CIFAR100 test sets...") + evaluation_start_time = time.time() + results = detector.evaluate(cifar10_test_paths, cifar100_test_paths) + evaluation_time = time.time() - evaluation_start_time + + # Print performance metrics + logger.info("\n=== OOD Detection Performance ===") + logger.info(f"Method: {args.method}, Nearest_k: {args.nearest_k}") + logger.info(f"AUROC: {results['AUROC']:.4f}") + logger.info(f"FPR@95TPR: {results['FPR@95TPR']:.4f}") + logger.info(f"AUPRC: {results['AUPRC']:.4f}") + logger.info(f"F1 Score: {results['F1']:.4f}") + logger.info(f"Evaluation time: {evaluation_time:.2f} seconds") + + # Visualize results + if args.visualize: + logger.info("\nGenerating visualizations...") + + # Plot score distributions + plt.figure(figsize=(10, 6)) + bins = np.linspace( + min(np.min(id_scores), np.min(ood_scores)), + max(np.max(id_scores), np.max(ood_scores)), + 30, + ) + + plt.hist(id_scores, bins=bins, alpha=0.7, label="CIFAR10 (In-Distribution)", density=True) + plt.hist( + ood_scores, bins=bins, alpha=0.7, label="CIFAR100 (Out-of-Distribution)", density=True + ) + + # Add threshold line + plt.axvline( + x=threshold, color="r", linestyle="--", alpha=0.7, label=f"Threshold ({threshold:.4f})" + ) + + plt.legend() + plt.title(f"ForteOODDetector Scores ({args.method}, nearest_k={args.nearest_k})") + plt.xlabel("OOD Score (higher = more in-distribution like)") + plt.ylabel("Density") + plt.grid(True, alpha=0.3) + + # Save figure + plt.savefig(f"forte_{args.method}_results.png") + logger.info(f"Score distribution saved to forte_{args.method}_results.png") + + # Show examples with predictions + num_examples = min(5, len(cifar10_test_paths), len(cifar100_test_paths)) + + fig, axes = plt.subplots(2, num_examples, figsize=(15, 6)) + + # CIFAR10 examples (should be classified as in-distribution) + for i in range(num_examples): + img = Image.open(cifar10_test_paths[i]) + axes[0, i].imshow(img) + + score = id_scores[i] + is_id = score > threshold + correct = is_id # For ID samples, prediction is correct if classified as ID + + color = "green" if correct else "red" + pred = "ID" if is_id else "OOD" + axes[0, i].set_title( + f"CIFAR10 (true=ID)\nPred: {pred}\nScore: {score:.2f}", color=color + ) + axes[0, i].axis("off") + + # CIFAR100 examples (should be classified as out-of-distribution) + for i in range(num_examples): + img = Image.open(cifar100_test_paths[i]) + axes[1, i].imshow(img) + + score = ood_scores[i] + is_id = score > threshold + correct = not is_id # For OOD samples, prediction is correct if classified as OOD + + color = "green" if correct else "red" + pred = "ID" if is_id else "OOD" + axes[1, i].set_title( + f"CIFAR100 (true=OOD)\nPred: {pred}\nScore: {score:.2f}", color=color + ) + axes[1, i].axis("off") + + plt.tight_layout() + plt.savefig("forte_examples.png") + logger.info("Example predictions saved to forte_examples.png") + + # ROC curve + plt.figure(figsize=(8, 6)) + + # Create labels (1 for ID, 0 for OOD) + labels = np.concatenate([np.ones(len(id_scores)), np.zeros(len(ood_scores))]) + scores_combined = np.concatenate([id_scores, ood_scores]) + + # Calculate ROC curve + from sklearn.metrics import auc, roc_curve + + fpr, tpr, _ = roc_curve(labels, scores_combined) + roc_auc = auc(fpr, tpr) + + plt.plot(fpr, tpr, lw=2, label=f"ROC curve (area = {roc_auc:.2f})") + plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--", label="Random") + + # Mark the FPR at 95% TPR + idx_95tpr = np.argmin(np.abs(tpr - 0.95)) + fpr_at_95tpr = fpr[idx_95tpr] + plt.scatter( + fpr_at_95tpr, 0.95, color="red", label=f"FPR@95TPR = {fpr_at_95tpr:.4f}", zorder=5 + ) + + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel("False Positive Rate") + plt.ylabel("True Positive Rate") + plt.title(f"ROC Curve - {args.method.upper()}") + plt.legend(loc="lower right") + plt.grid(alpha=0.3) + + plt.savefig(f"forte_{args.method}_roc.png") + logger.info(f"ROC curve saved to forte_{args.method}_roc.png") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Forte OOD Detection Demo") + parser.add_argument("--batch_size", type=int, default=32, help="Batch size for processing") + parser.add_argument( + "--device", + type=str, + default="cuda:0" if torch.cuda.is_available() else "mps", + help="Device to use", + ) + parser.add_argument( + "--method", + type=str, + default="ocsvm", + choices=["gmm", "kde", "ocsvm"], + help="OOD detection method", + ) + parser.add_argument( + "--nearest_k", type=int, default=5, help="Number of nearest neighbors for PRDC" + ) + parser.add_argument( + "--num_train_images", type=int, default=1000, help="Number of training images" + ) + parser.add_argument("--num_test_images", type=int, default=500, help="Number of test images") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument("--visualize", action="store_true", help="Visualize results") + parser.add_argument( + "--force_save", action="store_true", help="Force save images even if they exist" + ) + parser.add_argument( + "--embedding_dir", type=str, default="embeddings", help="Directory to store embeddings" + ) + parser.add_argument( + "--log_level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Logging level", + ) + + args = parser.parse_args() + + # Set logging level + numeric_level = getattr(logging, args.log_level.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError(f"Invalid log level: {args.log_level}") + logging.getLogger().setLevel(numeric_level) + + main(args) diff --git a/forte_api.py b/forte_api.py index 73a1080..4ad63ae 100644 --- a/forte_api.py +++ b/forte_api.py @@ -1,18 +1,31 @@ +import math import os import time -import math + import numpy as np import torch import torch.nn.functional as F -from sklearn.model_selection import train_test_split -from transformers import CLIPModel, CLIPProcessor, ViTMSNModel, AutoFeatureExtractor, AutoModel, AutoImageProcessor from PIL import Image -from tqdm import tqdm -from sklearn.metrics import roc_auc_score, precision_recall_curve, average_precision_score, roc_curve from scipy.stats import gaussian_kde +from sklearn.metrics import ( + average_precision_score, + pairwise_distances, + precision_recall_curve, + roc_auc_score, + roc_curve, +) from sklearn.mixture import GaussianMixture +from sklearn.model_selection import train_test_split from sklearn.svm import OneClassSVM -from sklearn.metrics import pairwise_distances +from tqdm import tqdm +from transformers import ( + AutoFeatureExtractor, + AutoImageProcessor, + AutoModel, + CLIPModel, + CLIPProcessor, + ViTMSNModel, +) ############################################# # ForteOODDetector Class @@ -23,19 +36,16 @@ class ForteOODDetector: """ Forte OOD Detector: Finding Outliers Using Representation Typicality Estimation. - This class implements the Forte method for OOD detection. It extracts features using + This class implements the Forte method for OOD detection. It extracts features using pretrained models and computes PRDC features using PyTorch tensors on GPU. - Detector training can use either a custom GPU-based implementation + Detector training can use either a custom GPU-based implementation or fall back to CPU-based detectors from scikit-learn/SciPy. """ - def __init__(self, - batch_size=32, - device=None, - embedding_dir="./embeddings", - nearest_k=5, - method='gmm'): + def __init__( + self, batch_size=32, device=None, embedding_dir="./embeddings", nearest_k=5, method="gmm" + ): """ Initialize the ForteOODDetector. @@ -45,7 +55,7 @@ def __init__(self, embedding_dir (str): Directory to store embeddings. nearest_k (int): Number of nearest neighbors for PRDC computation. method (str): Detector method ('gmm', 'kde', or 'ocsvm'). - custom_detector (bool): If True, use our custom GPU-based implementations + custom_detector (bool): If True, use our custom GPU-based implementations (TorchGMM, TorchKDE, TorchOCSVM). If False, use CPU-based detectors. """ self.batch_size = batch_size @@ -60,13 +70,13 @@ def __init__(self, self.embedding_dir = embedding_dir self.nearest_k = nearest_k self.method = method - self.custom_detector = (self.device != "cpu") + self.custom_detector = self.device != "cpu" self.models = None self.is_fitted = False # These will be set during fit - self.id_train_features = None # GPU tensors for feature extraction - self.id_train_prdc = None # Combined PRDC features (GPU tensor) + self.id_train_features = None # GPU tensors for feature extraction + self.id_train_prdc = None # Combined PRDC features (GPU tensor) self.detector = None os.makedirs(self.embedding_dir, exist_ok=True) @@ -83,12 +93,21 @@ def _init_models(self): print(f"Initializing models on {self.device}...") device = self.device models = [ - ("clip", CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device), - CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")), - ("vitmsn", ViTMSNModel.from_pretrained("facebook/vit-msn-base").to(device), - AutoFeatureExtractor.from_pretrained("facebook/vit-msn-base")), - ("dinov2", AutoModel.from_pretrained('facebook/dinov2-base').to(device), - AutoImageProcessor.from_pretrained('facebook/dinov2-base')) + ( + "clip", + CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device), + CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32"), + ), + ( + "vitmsn", + ViTMSNModel.from_pretrained("facebook/vit-msn-base").to(device), + AutoFeatureExtractor.from_pretrained("facebook/vit-msn-base"), + ), + ( + "dinov2", + AutoModel.from_pretrained("facebook/dinov2-base").to(device), + AutoImageProcessor.from_pretrained("facebook/dinov2-base"), + ), ] return models @@ -108,13 +127,14 @@ def _extract_features_batch(self, image_paths, batch_idx=0): images = [img for img in images if img is not None] if not images: - return {model_name: torch.empty(0, device=self.device) for model_name, _, _ in self.models} + return { + model_name: torch.empty(0, device=self.device) for model_name, _, _ in self.models + } all_features = {} # Process each model using its corresponding processor for model_name, model, processor in self.models: - inputs = processor( - images=images, return_tensors="pt", padding=True).to(self.device) + inputs = processor(images=images, return_tensors="pt", padding=True).to(self.device) try: with torch.no_grad(): if model_name == "clip": @@ -147,15 +167,15 @@ def _extract_features(self, image_paths, name="tmp"): models_to_process = [] for model_name, _, _ in self.models: - embedding_file = os.path.join( - self.embedding_dir, f"{name}_{model_name}_features.pt") + embedding_file = os.path.join(self.embedding_dir, f"{name}_{model_name}_features.pt") if os.path.exists(embedding_file): print(f"Loading pre-computed features from {embedding_file}") loaded = torch.load(embedding_file, map_location=self.device) all_features[model_name] = loaded if loaded.size(0) != len(image_paths): print( - f"Warning: Cached features count ({loaded.size(0)}) doesn't match image count ({len(image_paths)}). Recomputing for {model_name}.") + f"Warning: Cached features count ({loaded.size(0)}) doesn't match image count ({len(image_paths)}). Recomputing for {model_name}." + ) all_features[model_name] = [] models_to_process.append(model_name) else: @@ -167,22 +187,22 @@ def _extract_features(self, image_paths, name="tmp"): return all_features for i in tqdm(range(0, len(image_paths), self.batch_size), desc="Extracting features"): - batch_paths = image_paths[i:i+self.batch_size] - batch_features = self._extract_features_batch( - batch_paths, i//self.batch_size) + batch_paths = image_paths[i : i + self.batch_size] + batch_features = self._extract_features_batch(batch_paths, i // self.batch_size) for model_name, features in batch_features.items(): if features.numel() > 0 and model_name in models_to_process: all_features[model_name].append(features) for model_name in models_to_process: if all_features[model_name]: - all_features[model_name] = torch.cat( - all_features[model_name], dim=0) + all_features[model_name] = torch.cat(all_features[model_name], dim=0) embedding_file = os.path.join( - self.embedding_dir, f"{name}_{model_name}_features.pt") + self.embedding_dir, f"{name}_{model_name}_features.pt" + ) torch.save(all_features[model_name], embedding_file) print( - f"Saved {model_name} features with shape {all_features[model_name].shape} to {embedding_file}") + f"Saved {model_name} features with shape {all_features[model_name].shape} to {embedding_file}" + ) else: all_features[model_name] = torch.empty(0, device=self.device) @@ -245,19 +265,15 @@ def _compute_prdc_features(self, real_features, fake_features): torch.Tensor: PRDC features (recall, density, precision, coverage). """ num_real = real_features.size(0) - real_distances = self._compute_nearest_neighbour_distances( - real_features, self.nearest_k) - fake_distances = self._compute_nearest_neighbour_distances( - fake_features, self.nearest_k) - distance_matrix = self._compute_pairwise_distance( - real_features, fake_features) - - precision = (distance_matrix < real_distances.unsqueeze(1) - ).any(dim=0).float() - recall = (distance_matrix < fake_distances).sum( - dim=0).float() / num_real - density = (1. / float(self.nearest_k)) * (distance_matrix < - real_distances.unsqueeze(1)).sum(dim=0).float() + real_distances = self._compute_nearest_neighbour_distances(real_features, self.nearest_k) + fake_distances = self._compute_nearest_neighbour_distances(fake_features, self.nearest_k) + distance_matrix = self._compute_pairwise_distance(real_features, fake_features) + + precision = (distance_matrix < real_distances.unsqueeze(1)).any(dim=0).float() + recall = (distance_matrix < fake_distances).sum(dim=0).float() / num_real + density = (1.0 / float(self.nearest_k)) * ( + distance_matrix < real_distances.unsqueeze(1) + ).sum(dim=0).float() coverage = (distance_matrix.min(dim=0).values < fake_distances).float() return torch.stack((recall, density, precision, coverage), dim=1) @@ -279,15 +295,13 @@ def fit(self, id_image_paths, val_split=0.2, random_state=42): # Split paths into training and validation id_train_paths, id_val_paths = train_test_split( - id_image_paths, test_size=val_split, random_state=random_state) + id_image_paths, test_size=val_split, random_state=random_state + ) - print( - f"Extracting features from {len(id_train_paths)} training images...") - self.id_train_features = self._extract_features( - id_train_paths, name="id_train") + print(f"Extracting features from {len(id_train_paths)} training images...") + self.id_train_features = self._extract_features(id_train_paths, name="id_train") - print( - f"Extracting features from {len(id_val_paths)} validation images...") + print(f"Extracting features from {len(id_val_paths)} validation images...") id_val_features = self._extract_features(id_val_paths, name="id_val") # Compute PRDC features for each model using GPU tensor operations @@ -303,71 +317,74 @@ def fit(self, id_image_paths, val_split=0.2, random_state=42): id_train_part1 = features[train_idx[:split]] id_train_part2 = features[train_idx[split:]] - print( - f" Training PRDC: {id_train_part1.shape} vs {id_train_part2.shape}") - train_prdc = self._compute_prdc_features( - id_train_part1, id_train_part2) + print(f" Training PRDC: {id_train_part1.shape} vs {id_train_part2.shape}") + train_prdc = self._compute_prdc_features(id_train_part1, id_train_part2) X_id_train_prdc.append(train_prdc) val_feats = id_val_features[model_name] - print( - f" Validation PRDC: {id_train_part1.shape} vs {val_feats.shape}") + print(f" Validation PRDC: {id_train_part1.shape} vs {val_feats.shape}") val_prdc = self._compute_prdc_features(id_train_part1, val_feats) X_id_val_prdc.append(val_prdc) self.id_train_prdc = torch.cat(X_id_train_prdc, dim=1) # still on GPU id_val_prdc = torch.cat(X_id_val_prdc, dim=1) print( - f"Combined PRDC features - Training: {self.id_train_prdc.shape}, Validation: {id_val_prdc.shape}") + f"Combined PRDC features - Training: {self.id_train_prdc.shape}, Validation: {id_val_prdc.shape}" + ) - print( - f"Training detector ({self.method}) with custom_detector={self.custom_detector}...") - if self.method == 'gmm': + print(f"Training detector ({self.method}) with custom_detector={self.custom_detector}...") + if self.method == "gmm": best_bic = np.inf best_n_components = 1 best_model = None for n_components in [1, 2, 4, 8, 16, 32, 64]: if self.custom_detector: - gmm = TorchGMM(n_components=n_components, - max_iter=100, tol=1e-3, device=self.device) + gmm = TorchGMM( + n_components=n_components, max_iter=100, tol=1e-3, device=self.device + ) gmm.fit(self.id_train_prdc) bic_val = gmm.bic(self.id_train_prdc) else: id_train_prdc_cpu = self.id_train_prdc.cpu().numpy() gmm = GaussianMixture( - n_components=n_components, covariance_type='full', random_state=random_state, max_iter=100) + n_components=n_components, + covariance_type="full", + random_state=random_state, + max_iter=100, + ) gmm.fit(id_train_prdc_cpu) bic_val = gmm.bic(id_train_prdc_cpu) if bic_val < best_bic: best_bic = bic_val best_n_components = n_components best_gmm = gmm - print( - f"Selected {best_n_components} components for GMM with BIC={best_bic:.2f}") + print(f"Selected {best_n_components} components for GMM with BIC={best_bic:.2f}") self.detector = best_gmm - elif self.method == 'kde': - self.detector = TorchKDE(self.id_train_prdc.T, bw_method='scott', device=self.device) if self.custom_detector else gaussian_kde( - self.id_train_prdc.cpu().numpy().T, bw_method='scott') + elif self.method == "kde": + self.detector = ( + TorchKDE(self.id_train_prdc.T, bw_method="scott", device=self.device) + if self.custom_detector + else gaussian_kde(self.id_train_prdc.cpu().numpy().T, bw_method="scott") + ) - elif self.method == 'ocsvm': + elif self.method == "ocsvm": if self.custom_detector: best_accuracy = 0 best_nu = 0.01 best_model = None for nu in [0.01, 0.05, 0.1, 0.2, 0.5]: - model = TorchOCSVM(nu=nu, n_iters=1000, - lr=1e-3, device=self.device) + model = TorchOCSVM(nu=nu, n_iters=1000, lr=1e-3, device=self.device) model.fit(self.id_train_prdc) decision = model.decision_function(self.id_train_prdc) - accuracy = (torch.where(decision.detach() >= 0, - 1, -1).float().mean().item() + 1) / 2.0 + accuracy = ( + torch.where(decision.detach() >= 0, 1, -1).float().mean().item() + 1 + ) / 2.0 if accuracy > best_accuracy: best_accuracy = accuracy best_nu = nu best_model = model - print( - f"Selected nu={best_nu} for TorchOCSVM with accuracy {best_accuracy:.4f}") + print(f"Selected nu={best_nu} for TorchOCSVM with accuracy {best_accuracy:.4f}") self.detector = best_model else: best_accuracy = 0 @@ -375,7 +392,7 @@ def fit(self, id_image_paths, val_split=0.2, random_state=42): for nu in [0.01, 0.05, 0.1, 0.2, 0.5]: try: id_train_prdc_cpu = self.id_train_prdc.cpu().numpy() - ocsvm = OneClassSVM(kernel='rbf', gamma='scale', nu=nu) + ocsvm = OneClassSVM(kernel="rbf", gamma="scale", nu=nu) ocsvm.fit(id_train_prdc_cpu) val_pred = ocsvm.predict(id_train_prdc_cpu) accuracy = np.mean(val_pred == 1) @@ -385,11 +402,9 @@ def fit(self, id_image_paths, val_split=0.2, random_state=42): except Exception as e: print(f"Error with nu={nu}: {e}") continue - print( - f"Selected nu={best_nu} for OCSVM with accuracy {best_accuracy:.4f}") + print(f"Selected nu={best_nu} for OCSVM with accuracy {best_accuracy:.4f}") id_train_prdc_cpu = self.id_train_prdc.cpu().numpy() - self.detector = OneClassSVM( - kernel='rbf', gamma='scale', nu=best_nu) + self.detector = OneClassSVM(kernel="rbf", gamma="scale", nu=best_nu) self.detector.fit(id_train_prdc_cpu) self.is_fitted = True @@ -415,15 +430,14 @@ def _get_ood_scores(self, image_paths, cache_name="test"): X_test_prdc = [] for model_name in test_features: ref_features = self.id_train_features[model_name] - train_idx = torch.randperm( - ref_features.size(0), device=self.device) + train_idx = torch.randperm(ref_features.size(0), device=self.device) split = int(ref_features.size(0) * 0.5) id_train_part1 = ref_features[train_idx[:split]] test_tensor = test_features[model_name] print( - f"Computing test PRDC for {model_name}: {id_train_part1.shape} vs {test_tensor.shape}") - test_prdc = self._compute_prdc_features( - id_train_part1, test_tensor) + f"Computing test PRDC for {model_name}: {id_train_part1.shape} vs {test_tensor.shape}" + ) + test_prdc = self._compute_prdc_features(id_train_part1, test_tensor) X_test_prdc.append(test_prdc) X_test_prdc = torch.cat(X_test_prdc, dim=1) @@ -431,22 +445,22 @@ def _get_ood_scores(self, image_paths, cache_name="test"): # For custom (GPU-based) detectors, use torch outputs; then convert to numpy if needed. if self.custom_detector: - if self.method == 'gmm': + if self.method == "gmm": scores = self.detector.score_samples(X_test_prdc) scores = scores.cpu().numpy() - elif self.method == 'kde': + elif self.method == "kde": scores = self.detector.logpdf(X_test_prdc) scores = scores.cpu().numpy() - elif self.method == 'ocsvm': + elif self.method == "ocsvm": scores = self.detector.decision_function(X_test_prdc) scores = scores.detach().cpu().numpy() else: X_test_prdc_cpu = X_test_prdc.cpu().numpy() - if self.method == 'gmm': + if self.method == "gmm": scores = self.detector.score_samples(X_test_prdc_cpu) - elif self.method == 'kde': + elif self.method == "kde": scores = self.detector.logpdf(X_test_prdc_cpu.T) - elif self.method == 'ocsvm': + elif self.method == "ocsvm": scores = self.detector.decision_function(X_test_prdc_cpu) return scores @@ -461,28 +475,26 @@ def predict(self, image_paths): np.ndarray: Binary predictions (1 for in-distribution, -1 for OOD). """ scores = self._get_ood_scores(image_paths) - if self.method == 'ocsvm': + if self.method == "ocsvm": threshold = 0 else: if self.custom_detector: ref_features = self.id_train_prdc # Use a simple split for threshold estimation - train_idx = torch.randperm( - ref_features.size(0), device=self.device) + train_idx = torch.randperm(ref_features.size(0), device=self.device) split = int(ref_features.size(0) * 0.5) id_train_part1 = ref_features[train_idx[:split]] - if self.method == 'gmm': - id_scores = self.detector.score_samples( - id_train_part1).cpu().numpy() - elif self.method == 'kde': - id_scores = self.detector.score_samples( - id_train_part1).cpu().numpy() + if self.method == "gmm": + id_scores = self.detector.score_samples(id_train_part1).cpu().numpy() + elif self.method == "kde": + id_scores = self.detector.score_samples(id_train_part1).cpu().numpy() else: id_train_part1_np, _ = train_test_split( - self.id_train_prdc.cpu().numpy(), test_size=0.5, random_state=42) - if self.method == 'gmm': + self.id_train_prdc.cpu().numpy(), test_size=0.5, random_state=42 + ) + if self.method == "gmm": id_scores = self.detector.score_samples(id_train_part1_np) - elif self.method == 'kde': + elif self.method == "kde": id_scores = self.detector.logpdf(id_train_part1_np.T) threshold = np.percentile(id_scores, 5) return np.where(scores > threshold, 1, -1) @@ -520,50 +532,52 @@ def evaluate(self, id_image_paths, ood_image_paths): if not self.is_fitted: raise RuntimeError("Detector must be fitted before evaluation") - print( - f"Evaluating on {len(id_image_paths)} ID and {len(ood_image_paths)} OOD images...") - + print(f"Evaluating on {len(id_image_paths)} ID and {len(ood_image_paths)} OOD images...") + # Fuse ID and OOD samples for processing together all_image_paths = id_image_paths + ood_image_paths all_scores = self._get_ood_scores(all_image_paths, cache_name="eval_fused") - + # Split the scores back to ID and OOD - id_scores = all_scores[:len(id_image_paths)] - ood_scores = all_scores[len(id_image_paths):] + id_scores = all_scores[: len(id_image_paths)] + ood_scores = all_scores[len(id_image_paths) :] print("\nScore Statistics:") print( - f"ID - Mean: {np.mean(id_scores):.4f}, Std: {np.std(id_scores):.4f}, Min: {np.min(id_scores):.4f}, Max: {np.max(id_scores):.4f}") + f"ID - Mean: {np.mean(id_scores):.4f}, Std: {np.std(id_scores):.4f}, Min: {np.min(id_scores):.4f}, Max: {np.max(id_scores):.4f}" + ) print( - f"OOD - Mean: {np.mean(ood_scores):.4f}, Std: {np.std(ood_scores):.4f}, Min: {np.min(ood_scores):.4f}, Max: {np.max(ood_scores):.4f}") + f"OOD - Mean: {np.mean(ood_scores):.4f}, Std: {np.std(ood_scores):.4f}, Min: {np.min(ood_scores):.4f}, Max: {np.max(ood_scores):.4f}" + ) - labels = np.concatenate( - [np.ones(len(id_scores)), np.zeros(len(ood_scores))]) + labels = np.concatenate([np.ones(len(id_scores)), np.zeros(len(ood_scores))]) scores_all = np.concatenate([id_scores, ood_scores]) auroc = roc_auc_score(labels, scores_all) fpr, tpr, _ = roc_curve(labels, scores_all) idx = np.argmin(np.abs(tpr - 0.95)) fpr95 = fpr[idx] if idx < len(fpr) else 1.0 - precision_vals, recall_vals, _ = precision_recall_curve( - labels, scores_all) + precision_vals, recall_vals, _ = precision_recall_curve(labels, scores_all) auprc = average_precision_score(labels, scores_all) - f1_scores = 2 * (precision_vals * recall_vals) / \ - (precision_vals + recall_vals + 1e-10) + f1_scores = 2 * (precision_vals * recall_vals) / (precision_vals + recall_vals + 1e-10) f1_score = np.max(f1_scores) - return { - "AUROC": auroc, - "FPR@95TPR": fpr95, - "AUPRC": auprc, - "F1": f1_score - } + return {"AUROC": auroc, "FPR@95TPR": fpr95, "AUPRC": auprc, "F1": f1_score} ################################################### # Custom Detectors: TorchGMM, TorchKDE, TorchOCSVM ################################################### + class TorchGMM: - def __init__(self, n_components=1, covariance_type='full', max_iter=100, tol=1e-3, reg_covar=1e-6, device='cuda'): + def __init__( + self, + n_components=1, + covariance_type="full", + max_iter=100, + tol=1e-3, + reg_covar=1e-6, + device="cuda", + ): """ A PyTorch implementation of a Gaussian Mixture Model that closely follows scikit-learn's GaussianMixture (for the 'full' covariance case). @@ -576,7 +590,7 @@ def __init__(self, n_components=1, covariance_type='full', max_iter=100, tol=1e- reg_covar (float): Non-negative regularization added to the diagonal of covariance matrices. device (str): 'cuda' or 'cpu'. """ - if covariance_type != 'full': + if covariance_type != "full": raise NotImplementedError("Only 'full' covariance is implemented.") self.n_components = n_components self.covariance_type = covariance_type @@ -586,8 +600,8 @@ def __init__(self, n_components=1, covariance_type='full', max_iter=100, tol=1e- self.device = device # Parameters to be learned - self.weights_ = None # shape: (n_components,) - self.means_ = None # shape: (n_components, n_features) + self.weights_ = None # shape: (n_components,) + self.means_ = None # shape: (n_components, n_features) # shape: (n_components, n_features, n_features) self.covariances_ = None self.converged_ = False @@ -603,8 +617,7 @@ def _initialize_parameters(self, X): self.means_ = X[indices].clone() # Initialize covariances as diagonal matrices based on sample variance variance = torch.var(X, dim=0) + self.reg_covar - self.covariances_ = torch.stack( - [torch.diag(variance) for _ in range(K)], dim=0) + self.covariances_ = torch.stack([torch.diag(variance) for _ in range(K)], dim=0) def _estimate_log_gaussian_prob(self, X): # X: (n_samples, n_features) @@ -612,8 +625,8 @@ def _estimate_log_gaussian_prob(self, X): # Create a batched MultivariateNormal distribution for each component mvn = torch.distributions.MultivariateNormal( self.means_, - covariance_matrix=self.covariances_ + self.reg_covar * - torch.eye(n_features, device=self.device) + covariance_matrix=self.covariances_ + + self.reg_covar * torch.eye(n_features, device=self.device), ) # X has shape (n_samples, n_features); unsqueeze to (n_samples, 1, n_features) to broadcast over components # Expected shape: (n_samples, n_components) @@ -622,8 +635,7 @@ def _estimate_log_gaussian_prob(self, X): def _e_step(self, X): # Compute log probabilities for each sample and each component - log_prob = self._estimate_log_gaussian_prob( - X) # shape: (n_samples, n_components) + log_prob = self._estimate_log_gaussian_prob(X) # shape: (n_samples, n_components) # Add log weights weighted_log_prob = log_prob + torch.log(self.weights_ + 1e-10) # Compute log-sum-exp for each sample @@ -647,8 +659,7 @@ def _m_step(self, X, resp): weighted_diff = diff * resp[:, k].unsqueeze(1) cov_k = (weighted_diff.t() @ diff) / (Nk[k] + 1e-10) # Add regularization for numerical stability - cov_k = cov_k + self.reg_covar * \ - torch.eye(n_features, device=self.device) + cov_k = cov_k + self.reg_covar * torch.eye(n_features, device=self.device) covariances.append(cov_k) self.covariances_ = torch.stack(covariances, dim=0) @@ -704,14 +715,17 @@ def bic(self, X): float: BIC value. """ n_samples, n_features = X.shape - p = (self.n_components - 1) + self.n_components * n_features + \ - self.n_components * n_features * (n_features + 1) / 2 + p = ( + (self.n_components - 1) + + self.n_components * n_features + + self.n_components * n_features * (n_features + 1) / 2 + ) log_likelihood = self.score_samples(X).sum().item() return -2 * log_likelihood + p * np.log(n_samples) class TorchKDE: - def __init__(self, dataset, bw_method=None, weights=None, device='cuda'): + def __init__(self, dataset, bw_method=None, weights=None, device="cuda"): # Use float32 for MPS devices, otherwise float64. dtype = torch.float32 if "mps" in device.lower() else torch.float64 self.device = device @@ -721,16 +735,15 @@ def __init__(self, dataset, bw_method=None, weights=None, device='cuda'): # Process weights (assumed to be a torch.Tensor on device if provided). if weights is not None: self.weights = (weights / weights.sum()).to(dtype=torch.float32) - self.neff = (self.weights.sum() ** 2) / (self.weights ** 2).sum() + self.neff = (self.weights.sum() ** 2) / (self.weights**2).sum() # Weighted covariance: cov = sum_i w_i (x_i - mean)(x_i - mean)^T / (1 - sum(w_i^2)) - weighted_mean = ( - self.dataset * self.weights.unsqueeze(0)).sum(dim=1, keepdim=True) + weighted_mean = (self.dataset * self.weights.unsqueeze(0)).sum(dim=1, keepdim=True) diff = self.dataset - weighted_mean - cov = (diff * self.weights.unsqueeze(0)) @ diff.T / \ - (1 - (self.weights**2).sum()) + cov = (diff * self.weights.unsqueeze(0)) @ diff.T / (1 - (self.weights**2).sum()) else: self.weights = torch.full( - (self.n,), 1.0 / self.n, dtype=torch.float32, device=self.device) + (self.n,), 1.0 / self.n, dtype=torch.float32, device=self.device + ) self.neff = self.n weighted_mean = self.dataset.mean(dim=1, keepdim=True) diff = self.dataset - weighted_mean @@ -747,9 +760,9 @@ def silverman_factor(self): return (self.neff * (self.d + 2) / 4.0) ** (-1.0 / (self.d + 4)) def set_bandwidth(self, bw_method=None): - if bw_method is None or bw_method == 'scott': + if bw_method is None or bw_method == "scott": self.factor = self.scotts_factor() - elif bw_method == 'silverman': + elif bw_method == "silverman": self.factor = self.silverman_factor() elif isinstance(bw_method, (int, float)): self.factor = float(bw_method) @@ -761,14 +774,13 @@ def set_bandwidth(self, bw_method=None): def _compute_covariance(self): # Scale the data covariance by the bandwidth factor squared. - self.covariance = self._data_covariance * (self.factor ** 2) + self.covariance = self._data_covariance * (self.factor**2) # Increase regularization to ensure positive definiteness. reg = 1e-6 self.cho_cov = torch.linalg.cholesky( - self.covariance + reg * - torch.eye(self.d, device=self.device, dtype=self.dataset.dtype) + self.covariance + reg * torch.eye(self.d, device=self.device, dtype=self.dataset.dtype) ) - self.log_det = 2. * torch.log(torch.diag(self.cho_cov)).sum() + self.log_det = 2.0 * torch.log(torch.diag(self.cho_cov)).sum() def evaluate(self, points): # Assume points is already a torch.Tensor on the proper device. @@ -779,7 +791,8 @@ def evaluate(self, points): points = points.T if points.shape[0] != self.d: raise ValueError( - f"Expected input with one dimension = {self.d}, but got shape {points.shape}") + f"Expected input with one dimension = {self.d}, but got shape {points.shape}" + ) # Compute differences: shape (d, n, m) diff = self.dataset.unsqueeze(2) - points.unsqueeze(1) # Flatten differences for cholesky_solve: (d, n*m) @@ -798,7 +811,7 @@ def logpdf(self, points): class TorchOCSVM: - def __init__(self, nu=0.1, n_iters=1000, lr=1e-3, device='cuda'): + def __init__(self, nu=0.1, n_iters=1000, lr=1e-3, device="cuda"): self.nu = nu self.n_iters = n_iters self.lr = lr @@ -821,18 +834,16 @@ def fit(self, X): # Compute slack = max(0, rho - w^T x) for each sample. # apply a smooth approximation? slack = torch.clamp(self.rho - scores, min=0) - loss = 0.5 * torch.norm(self.w) ** 2 - \ - self.rho + (1 / (self.nu * n)) * slack.sum() + loss = 0.5 * torch.norm(self.w) ** 2 - self.rho + (1 / (self.nu * n)) * slack.sum() loss.backward() optimizer.step() if (i + 1) % 200 == 0: - print( - f"OCSVM iter {i+1}/{self.n_iters}, loss: {loss.item():.4f}") + print(f"OCSVM iter {i+1}/{self.n_iters}, loss: {loss.item():.4f}") return self def decision_function(self, X): X = X.to(self.device) - return (X @ self.w - self.rho) + return X @ self.w - self.rho def predict(self, X): decision = self.decision_function(X) diff --git a/forte_demo.py b/forte_demo.py index 0dbc5ef..6b43ba9 100644 --- a/forte_demo.py +++ b/forte_demo.py @@ -1,135 +1,138 @@ +import argparse +import logging import os +import time + +import matplotlib.pyplot as plt import numpy as np import torch import torchvision import torchvision.transforms as transforms from PIL import Image -import matplotlib.pyplot as plt from tqdm import tqdm -import time -import argparse -import logging + from forte_api import ForteOODDetector # Configure logging logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' + format="%(asctime)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", ) logger = logging.getLogger("ForteDemo") + def save_dataset_as_png(dataset, save_dir, num_images=1000): """ Save a subset of a dataset as PNG images. - + Args: dataset: PyTorch dataset save_dir (str): Directory to save images num_images (int): Number of images to save - + Returns: list: List of paths to saved images """ logger.info(f"Saving {min(num_images, len(dataset))} images to {save_dir}") os.makedirs(save_dir, exist_ok=True) paths = [] - + for i in tqdm(range(min(num_images, len(dataset))), desc=f"Saving images to {save_dir}"): image, label = dataset[i] # Convert tensor to PIL Image if isinstance(image, torch.Tensor): image = transforms.ToPILImage()(image) - + # Save the image path = os.path.join(save_dir, f"{i}_label{label}.png") image.save(path) paths.append(path) - + return paths + def load_cifar_datasets(): """ Load CIFAR10 and CIFAR100 datasets. - + Returns: tuple: CIFAR10 train and test sets, CIFAR100 test set """ logger.info("Loading CIFAR10 and CIFAR100 datasets...") # Define transform - transform = transforms.Compose([ - transforms.ToTensor() - ]) - + transform = transforms.Compose([transforms.ToTensor()]) + # Load CIFAR10 train and test sets cifar10_train = torchvision.datasets.CIFAR10( - root='./data', train=True, download=True, transform=transform + root="./data", train=True, download=True, transform=transform ) - + cifar10_test = torchvision.datasets.CIFAR10( - root='./data', train=False, download=True, transform=transform + root="./data", train=False, download=True, transform=transform ) - + # Load CIFAR100 test set cifar100_test = torchvision.datasets.CIFAR100( - root='./data', train=False, download=True, transform=transform + root="./data", train=False, download=True, transform=transform + ) + + logger.info( + f"Loaded datasets - CIFAR10 train: {len(cifar10_train)} images, " + + f"CIFAR10 test: {len(cifar10_test)} images, " + + f"CIFAR100 test: {len(cifar100_test)} images" ) - - logger.info(f"Loaded datasets - CIFAR10 train: {len(cifar10_train)} images, " + - f"CIFAR10 test: {len(cifar10_test)} images, " + - f"CIFAR100 test: {len(cifar100_test)} images") - + return cifar10_train, cifar10_test, cifar100_test + def print_training_phases(): """Print information about the phases of the Forte training pipeline.""" phases = [ - ("1. Data Preparation", - "Convert datasets to image files and prepare directories"), - - ("2. Feature Extraction", - "Extract semantic features using pretrained models (CLIP, ViTMSN, DINOv2)"), - - ("3. PRDC Computation", - "Compute Precision, Recall, Density, Coverage metrics from extracted features"), - - ("4. Detector Training", - "Train OOD detector (GMM, KDE, or OCSVM) on PRDC features"), - - ("5. Evaluation", - "Compute scores and performance metrics on test datasets") + ("1. Data Preparation", "Convert datasets to image files and prepare directories"), + ( + "2. Feature Extraction", + "Extract semantic features using pretrained models (CLIP, ViTMSN, DINOv2)", + ), + ( + "3. PRDC Computation", + "Compute Precision, Recall, Density, Coverage metrics from extracted features", + ), + ("4. Detector Training", "Train OOD detector (GMM, KDE, or OCSVM) on PRDC features"), + ("5. Evaluation", "Compute scores and performance metrics on test datasets"), ] - + logger.info("\n=== Forte OOD Detection Pipeline ===") for i, (phase, desc) in enumerate(phases): logger.info(f"{phase}: {desc}") - logger.info("="*40) + logger.info("=" * 40) + def main(args): # Print pipeline phases information print_training_phases() - + # Set random seed for reproducibility np.random.seed(args.seed) torch.manual_seed(args.seed) if torch.cuda.is_available(): torch.cuda.manual_seed(args.seed) - + logger.info(f"Running with configuration: {args}") - + # Create directories os.makedirs("data", exist_ok=True) os.makedirs(args.embedding_dir, exist_ok=True) - + # Phase 1: Data Preparation logger.info("\n=== Phase 1: Data Preparation ===") cifar10_train, cifar10_test, cifar100_test = load_cifar_datasets() - + # Create directories for images os.makedirs("data/cifar10/train", exist_ok=True) os.makedirs("data/cifar10/test", exist_ok=True) os.makedirs("data/cifar100/test", exist_ok=True) - + # Check if we need to save images if not os.path.exists("data/cifar10/train/0_label0.png") or args.force_save: logger.info("Converting datasets to PNG images...") @@ -137,95 +140,119 @@ def main(args): cifar10_train_paths = save_dataset_as_png( cifar10_train, "data/cifar10/train", num_images=args.num_train_images ) - + # Save CIFAR10 test images cifar10_test_paths = save_dataset_as_png( cifar10_test, "data/cifar10/test", num_images=args.num_test_images ) - + # Save CIFAR100 test images cifar100_test_paths = save_dataset_as_png( cifar100_test, "data/cifar100/test", num_images=args.num_test_images ) else: logger.info("Using previously saved images...") - cifar10_train_paths = sorted([os.path.join("data/cifar10/train", f) - for f in os.listdir("data/cifar10/train") - if f.endswith(".png")])[:args.num_train_images] - - cifar10_test_paths = sorted([os.path.join("data/cifar10/test", f) - for f in os.listdir("data/cifar10/test") - if f.endswith(".png")])[:args.num_test_images] - - cifar100_test_paths = sorted([os.path.join("data/cifar100/test", f) - for f in os.listdir("data/cifar100/test") - if f.endswith(".png")])[:args.num_test_images] - + cifar10_train_paths = sorted( + [ + os.path.join("data/cifar10/train", f) + for f in os.listdir("data/cifar10/train") + if f.endswith(".png") + ] + )[: args.num_train_images] + + cifar10_test_paths = sorted( + [ + os.path.join("data/cifar10/test", f) + for f in os.listdir("data/cifar10/test") + if f.endswith(".png") + ] + )[: args.num_test_images] + + cifar100_test_paths = sorted( + [ + os.path.join("data/cifar100/test", f) + for f in os.listdir("data/cifar100/test") + if f.endswith(".png") + ] + )[: args.num_test_images] + logger.info(f"Number of CIFAR10 training images: {len(cifar10_train_paths)}") logger.info(f"Number of CIFAR10 test images: {len(cifar10_test_paths)}") logger.info(f"Number of CIFAR100 test images: {len(cifar100_test_paths)}") - + # Phase 2-4: Feature Extraction, PRDC Computation, and Detector Training logger.info("\n=== Phase 2-4: Feature Extraction, PRDC Computation, and Detector Training ===") start_time = time.time() - logger.info(f"Creating ForteOODDetector with method: {args.method}, nearest_k: {args.nearest_k}") + logger.info( + f"Creating ForteOODDetector with method: {args.method}, nearest_k: {args.nearest_k}" + ) detector = ForteOODDetector( batch_size=args.batch_size, device=args.device, embedding_dir=args.embedding_dir, method=args.method, - nearest_k=args.nearest_k + nearest_k=args.nearest_k, ) - + # Fit the detector - this performs feature extraction, PRDC computation, and detector training logger.info(f"Fitting ForteOODDetector on {len(cifar10_train_paths)} in-distribution images...") detector.fit(cifar10_train_paths, val_split=0.2, random_state=args.seed) training_time = time.time() - start_time logger.info(f"Training completed in {training_time:.2f} seconds") - + # Phase 5: Evaluation logger.info("\n=== Phase 5: Evaluation ===") - + # Benchmark on ID data (CIFAR10 test) logger.info("Benchmarking detector on CIFAR10 (in-distribution)...") start_time = time.time() id_scores = detector._get_ood_scores(cifar10_test_paths, cache_name="id_benchmark") id_prediction_time = time.time() - start_time - logger.info(f"ID prediction time for {len(cifar10_test_paths)} images: {id_prediction_time:.2f} seconds " + - f"({id_prediction_time/len(cifar10_test_paths):.4f} sec/image)") - + logger.info( + f"ID prediction time for {len(cifar10_test_paths)} images: {id_prediction_time:.2f} seconds " + + f"({id_prediction_time/len(cifar10_test_paths):.4f} sec/image)" + ) + # Benchmark on OOD data (CIFAR100 test) logger.info("Benchmarking detector on CIFAR100 (out-of-distribution)...") start_time = time.time() ood_scores = detector._get_ood_scores(cifar100_test_paths, cache_name="ood_benchmark") ood_prediction_time = time.time() - start_time - logger.info(f"OOD prediction time for {len(cifar100_test_paths)} images: {ood_prediction_time:.2f} seconds " + - f"({ood_prediction_time/len(cifar100_test_paths):.4f} sec/image)") - + logger.info( + f"OOD prediction time for {len(cifar100_test_paths)} images: {ood_prediction_time:.2f} seconds " + + f"({ood_prediction_time/len(cifar100_test_paths):.4f} sec/image)" + ) + # Score statistics logger.info("\nScore Statistics:") - logger.info(f"CIFAR10 (ID) - Mean: {np.mean(id_scores):.4f}, Std: {np.std(id_scores):.4f}, " + - f"Min: {np.min(id_scores):.4f}, Max: {np.max(id_scores):.4f}") - logger.info(f"CIFAR100 (OOD) - Mean: {np.mean(ood_scores):.4f}, Std: {np.std(ood_scores):.4f}, " + - f"Min: {np.min(ood_scores):.4f}, Max: {np.max(ood_scores):.4f}") - + logger.info( + f"CIFAR10 (ID) - Mean: {np.mean(id_scores):.4f}, Std: {np.std(id_scores):.4f}, " + + f"Min: {np.min(id_scores):.4f}, Max: {np.max(id_scores):.4f}" + ) + logger.info( + f"CIFAR100 (OOD) - Mean: {np.mean(ood_scores):.4f}, Std: {np.std(ood_scores):.4f}, " + + f"Min: {np.min(ood_scores):.4f}, Max: {np.max(ood_scores):.4f}" + ) + # Calculate threshold based on ID scores threshold = np.percentile(id_scores, 5) # 5th percentile logger.info(f"Suggested decision threshold (5th percentile of ID scores): {threshold:.4f}") - + # Calculate detection accuracy id_correct = (id_scores > threshold).mean() - ood_correct = (ood_scores <= threshold).mean() - overall_acc = (id_correct * len(id_scores) + ood_correct * len(ood_scores)) / (len(id_scores) + len(ood_scores)) + ood_correct = (ood_scores <= threshold).mean() + overall_acc = (id_correct * len(id_scores) + ood_correct * len(ood_scores)) / ( + len(id_scores) + len(ood_scores) + ) logger.info(f"ID Detection Rate: {id_correct:.4f}, OOD Detection Rate: {ood_correct:.4f}") logger.info(f"Overall Accuracy: {overall_acc:.4f}") - + # Full evaluation on mixed test set logger.info("\nPerforming full evaluation on CIFAR10/CIFAR100 test sets...") evaluation_start_time = time.time() results = detector.evaluate(cifar10_test_paths, cifar100_test_paths) evaluation_time = time.time() - evaluation_start_time - + # Print performance metrics logger.info("\n=== OOD Detection Performance ===") logger.info(f"Method: {args.method}, Nearest_k: {args.nearest_k}") @@ -234,126 +261,160 @@ def main(args): logger.info(f"AUPRC: {results['AUPRC']:.4f}") logger.info(f"F1 Score: {results['F1']:.4f}") logger.info(f"Evaluation time: {evaluation_time:.2f} seconds") - + # Visualize results if args.visualize: logger.info("\nGenerating visualizations...") - + # Plot score distributions plt.figure(figsize=(10, 6)) - bins = np.linspace(min(np.min(id_scores), np.min(ood_scores)), - max(np.max(id_scores), np.max(ood_scores)), - 30) - - plt.hist(id_scores, bins=bins, alpha=0.7, label='CIFAR10 (In-Distribution)', density=True) - plt.hist(ood_scores, bins=bins, alpha=0.7, label='CIFAR100 (Out-of-Distribution)', density=True) - + bins = np.linspace( + min(np.min(id_scores), np.min(ood_scores)), + max(np.max(id_scores), np.max(ood_scores)), + 30, + ) + + plt.hist(id_scores, bins=bins, alpha=0.7, label="CIFAR10 (In-Distribution)", density=True) + plt.hist( + ood_scores, bins=bins, alpha=0.7, label="CIFAR100 (Out-of-Distribution)", density=True + ) + # Add threshold line - plt.axvline(x=threshold, color='r', linestyle='--', alpha=0.7, label=f'Threshold ({threshold:.4f})') - + plt.axvline( + x=threshold, color="r", linestyle="--", alpha=0.7, label=f"Threshold ({threshold:.4f})" + ) + plt.legend() - plt.title(f'ForteOODDetector Scores ({args.method}, nearest_k={args.nearest_k})') - plt.xlabel('OOD Score (higher = more in-distribution like)') - plt.ylabel('Density') + plt.title(f"ForteOODDetector Scores ({args.method}, nearest_k={args.nearest_k})") + plt.xlabel("OOD Score (higher = more in-distribution like)") + plt.ylabel("Density") plt.grid(True, alpha=0.3) - + # Save figure plt.savefig(f"forte_{args.method}_results.png") logger.info(f"Score distribution saved to forte_{args.method}_results.png") - + # Show examples with predictions num_examples = min(5, len(cifar10_test_paths), len(cifar100_test_paths)) - + fig, axes = plt.subplots(2, num_examples, figsize=(15, 6)) - + # CIFAR10 examples (should be classified as in-distribution) for i in range(num_examples): img = Image.open(cifar10_test_paths[i]) axes[0, i].imshow(img) - + score = id_scores[i] is_id = score > threshold correct = is_id # For ID samples, prediction is correct if classified as ID - - color = 'green' if correct else 'red' + + color = "green" if correct else "red" pred = "ID" if is_id else "OOD" - axes[0, i].set_title(f"CIFAR10 (true=ID)\nPred: {pred}\nScore: {score:.2f}", color=color) - axes[0, i].axis('off') - + axes[0, i].set_title( + f"CIFAR10 (true=ID)\nPred: {pred}\nScore: {score:.2f}", color=color + ) + axes[0, i].axis("off") + # CIFAR100 examples (should be classified as out-of-distribution) for i in range(num_examples): img = Image.open(cifar100_test_paths[i]) axes[1, i].imshow(img) - + score = ood_scores[i] is_id = score > threshold correct = not is_id # For OOD samples, prediction is correct if classified as OOD - - color = 'green' if correct else 'red' + + color = "green" if correct else "red" pred = "ID" if is_id else "OOD" - axes[1, i].set_title(f"CIFAR100 (true=OOD)\nPred: {pred}\nScore: {score:.2f}", color=color) - axes[1, i].axis('off') - + axes[1, i].set_title( + f"CIFAR100 (true=OOD)\nPred: {pred}\nScore: {score:.2f}", color=color + ) + axes[1, i].axis("off") + plt.tight_layout() plt.savefig("forte_examples.png") logger.info("Example predictions saved to forte_examples.png") - + # ROC curve plt.figure(figsize=(8, 6)) - + # Create labels (1 for ID, 0 for OOD) labels = np.concatenate([np.ones(len(id_scores)), np.zeros(len(ood_scores))]) scores_combined = np.concatenate([id_scores, ood_scores]) - + # Calculate ROC curve - from sklearn.metrics import roc_curve, auc + from sklearn.metrics import auc, roc_curve + fpr, tpr, _ = roc_curve(labels, scores_combined) roc_auc = auc(fpr, tpr) - - plt.plot(fpr, tpr, lw=2, label=f'ROC curve (area = {roc_auc:.2f})') - plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random') - + + plt.plot(fpr, tpr, lw=2, label=f"ROC curve (area = {roc_auc:.2f})") + plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--", label="Random") + # Mark the FPR at 95% TPR idx_95tpr = np.argmin(np.abs(tpr - 0.95)) fpr_at_95tpr = fpr[idx_95tpr] - plt.scatter(fpr_at_95tpr, 0.95, color='red', - label=f'FPR@95TPR = {fpr_at_95tpr:.4f}', zorder=5) - + plt.scatter( + fpr_at_95tpr, 0.95, color="red", label=f"FPR@95TPR = {fpr_at_95tpr:.4f}", zorder=5 + ) + plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) - plt.xlabel('False Positive Rate') - plt.ylabel('True Positive Rate') - plt.title(f'ROC Curve - {args.method.upper()}') + plt.xlabel("False Positive Rate") + plt.ylabel("True Positive Rate") + plt.title(f"ROC Curve - {args.method.upper()}") plt.legend(loc="lower right") plt.grid(alpha=0.3) - + plt.savefig(f"forte_{args.method}_roc.png") logger.info(f"ROC curve saved to forte_{args.method}_roc.png") + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Forte OOD Detection Demo") parser.add_argument("--batch_size", type=int, default=32, help="Batch size for processing") - parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", - help="Device to use") - parser.add_argument("--method", type=str, default="gmm", choices=["gmm", "kde", "ocsvm"], - help="OOD detection method") - parser.add_argument("--nearest_k", type=int, default=5, help="Number of nearest neighbors for PRDC") - parser.add_argument("--num_train_images", type=int, default=1000, help="Number of training images") - parser.add_argument("--num_test_images", type=int, default=500, help="Number of test images") + parser.add_argument( + "--device", + type=str, + default="cuda:0" if torch.cuda.is_available() else "mps", + help="Device to use", + ) + parser.add_argument( + "--method", + type=str, + default="gmm", + choices=["gmm", "kde", "ocsvm"], + help="OOD detection method", + ) + parser.add_argument( + "--nearest_k", type=int, default=5, help="Number of nearest neighbors for PRDC" + ) + parser.add_argument( + "--num_train_images", type=int, default=10000, help="Number of training images" + ) + parser.add_argument("--num_test_images", type=int, default=5000, help="Number of test images") parser.add_argument("--seed", type=int, default=42, help="Random seed") parser.add_argument("--visualize", action="store_true", help="Visualize results") - parser.add_argument("--force_save", action="store_true", help="Force save images even if they exist") - parser.add_argument("--embedding_dir", type=str, default="embeddings", help="Directory to store embeddings") - parser.add_argument("--log_level", type=str, default="INFO", - choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], - help="Logging level") - + parser.add_argument( + "--force_save", action="store_true", help="Force save images even if they exist" + ) + parser.add_argument( + "--embedding_dir", type=str, default="embeddings", help="Directory to store embeddings" + ) + parser.add_argument( + "--log_level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Logging level", + ) + args = parser.parse_args() - + # Set logging level numeric_level = getattr(logging, args.log_level.upper(), None) if not isinstance(numeric_level, int): - raise ValueError(f'Invalid log level: {args.log_level}') + raise ValueError(f"Invalid log level: {args.log_level}") logging.getLogger().setLevel(numeric_level) - - main(args) \ No newline at end of file + + main(args) diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..e14a6cb --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,132 @@ +site_name: Forte Detector +site_description: Finding Outliers with Representation Typicality Estimation - A PyTorch library for OOD detection +site_author: Debargha Ganguly +site_url: https://debarghag.github.io/forte-api + +repo_name: DebarghaG/forte-api +repo_url: https://github.com/DebarghaG/forte-api +edit_uri: edit/main/docs/ + +theme: + name: material + palette: + # Light mode + - media: "(prefers-color-scheme: light)" + scheme: default + primary: indigo + accent: indigo + toggle: + icon: material/brightness-7 + name: Switch to dark mode + # Dark mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + primary: indigo + accent: indigo + toggle: + icon: material/brightness-4 + name: Switch to light mode + + features: + - navigation.instant + - navigation.tracking + - navigation.tabs + - navigation.sections + - navigation.expand + - navigation.top + - search.suggest + - search.highlight + - content.code.copy + - content.code.annotate + + icon: + repo: fontawesome/brands/github + + font: + text: Roboto + code: Roboto Mono + +nav: + - Home: index.md + - Getting Started: + - Installation: installation.md + - Quickstart: quickstart.md + - Reference: + - Algorithm: methods.md + - Configuration: user-guide.md + - API: api-reference.md + - Examples: examples.md + - Citation: citation.md + +plugins: + - search + - mkdocstrings: + handlers: + python: + options: + docstring_style: google + show_source: true + show_root_heading: true + show_category_heading: true + members_order: source + separate_signature: true + show_signature_annotations: true + signature_crossrefs: true + +markdown_extensions: + # Python Markdown + - abbr + - admonition + - attr_list + - def_list + - footnotes + - md_in_html + - tables + - toc: + permalink: true + toc_depth: 3 + + # Python Markdown Extensions + - pymdownx.arithmatex: + generic: true + - pymdownx.betterem: + smart_enable: all + - pymdownx.caret + - pymdownx.details + - pymdownx.highlight: + anchor_linenums: true + line_spans: __span + pygments_lang_class: true + - pymdownx.inlinehilite + - pymdownx.keys + - pymdownx.mark + - pymdownx.smartsymbols + - pymdownx.superfences: + custom_fences: + - name: mermaid + class: mermaid + format: !!python/name:pymdownx.superfences.fence_code_format + - pymdownx.tabbed: + alternate_style: true + - pymdownx.tasklist: + custom_checkbox: true + - pymdownx.tilde + +extra: + social: + - icon: fontawesome/brands/github + link: https://github.com/DebarghaG/forte-api + - icon: fontawesome/solid/paper-plane + link: https://openreview.net/pdf?id=7XNgVPxCiA + + version: + provider: mike + +extra_css: + - stylesheets/extra.css + +extra_javascript: + - javascripts/mathjax.js + - https://unpkg.com/mathjax@3/es5/tex-mml-chtml.js + +copyright: Copyright © 2025 Debargha Ganguly diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..3920346 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,197 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "forte-detector" +version = "0.1.0" +description = "Forte: Finding Outliers with Representation Typicality Estimation - A PyTorch library for OOD detection using topology-aware representation learning" +readme = "README.md" +requires-python = ">=3.9" +license = {text = "MIT"} +authors = [ + {name = "Debargha Ganguly", email = "debargha.ganguly@gmail.com"}, + {name = "Warren Richard Morningstar"}, + {name = "Andrew Seohwan Yu"}, + {name = "Vipin Chaudhary"} +] +maintainers = [ + {name = "Debargha Ganguly", email = "debargha.ganguly@gmail.com"} +] +keywords = [ + "outlier-detection", + "out-of-distribution", + "ood-detection", + "computer-vision", + "deep-learning", + "pytorch", + "representation-learning", + "anomaly-detection", + "prdc", + "clip", + "vision-transformers" +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Image Recognition", + "Topic :: Software Development :: Libraries :: Python Modules", + "Operating System :: OS Independent", +] + +dependencies = [ + "torch>=2.0.0", + "torchvision>=0.15.0", + "transformers>=4.30.0", + "numpy>=1.24.0", + "scipy>=1.10.0", + "scikit-learn>=1.3.0", + "pillow>=9.0.0", + "tqdm>=4.65.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0.0", + "pytest-cov>=4.0.0", + "pytest-mock>=3.10.0", + "black>=23.0.0", + "flake8>=6.0.0", + "isort>=5.12.0", + "mypy>=1.0.0", + "pre-commit>=3.0.0", + "bandit>=1.7.0", +] +docs = [ + "mkdocs>=1.5.0", + "mkdocs-material>=9.0.0", + "mkdocstrings[python]>=0.24.0", +] +viz = [ + "matplotlib>=3.7.0", +] +all = [ + "forte-detector[dev,docs,viz]", +] + +[project.urls] +Homepage = "https://github.com/debarghag/forte-detector" +Documentation = "https://debarghag.github.io/forte-api" +"Source Code" = "https://github.com/debarghag/forte-detector" +"Bug Tracker" = "https://github.com/debarghag/forte-detector/issues" +"Paper" = "https://openreview.net/pdf?id=7XNgVPxCiA" +"ICLR 2025" = "https://openreview.net/pdf?id=7XNgVPxCiA" + +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.packages.find] +where = ["src"] +include = ["forte*"] +exclude = ["tests*", "docs*", "examples*"] + +[tool.black] +line-length = 100 +target-version = ['py39', 'py310', 'py311', 'py312'] +include = '\.pyi?$' +extend-exclude = ''' +/( + # directories + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | build + | dist + | env +)/ +''' + +[tool.isort] +profile = "black" +line_length = 100 +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true +ensure_newline_before_comments = true + +[tool.pytest.ini_options] +minversion = "7.0" +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = [ + "-v", + "--strict-markers", + "--strict-config", + "--cov=forte", + "--cov-report=term-missing", + "--cov-report=html", + "--cov-report=xml", +] +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "integration: marks tests as integration tests", + "unit: marks tests as unit tests", +] + +[tool.mypy] +python_version = "3.9" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = false +disallow_incomplete_defs = false +check_untyped_defs = false +no_implicit_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_no_return = true +strict_equality = true + +[tool.coverage.run] +source = ["src/forte"] +omit = [ + "*/tests/*", + "*/test_*.py", +] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise AssertionError", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", + "class .*\\bProtocol\\):", + "@(abc\\.)?abstractmethod", +] + +[tool.bandit] +exclude_dirs = ["tests", "env", ".venv", "venv"] +skips = ["B101"] + +[tool.flake8] +max-line-length = 100 +extend-ignore = ["E203", "W503", "D100", "D104"] +per-file-ignores = [ + "forte_api.py:D,F401,F841", + "forte_demo.py:D,F401", + "examples/*.py:D", + "tests/*.py:D", + "__init__.py:F401", +] +docstring-convention = "google" diff --git a/requirements.txt b/requirements.txt index 7382121..20c42eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,4 +35,4 @@ torchvision==0.21.0 tqdm==4.67.1 transformers==4.50.3 typing_extensions==4.13.1 -urllib3==2.3.0 \ No newline at end of file +urllib3==2.3.0 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..e22e130 --- /dev/null +++ b/setup.py @@ -0,0 +1,11 @@ +""" +Setup script for forte-detector package. + +This file provides backwards compatibility for tools that still use setup.py. +All configuration is in pyproject.toml. +""" + +from setuptools import setup + +if __name__ == "__main__": + setup() diff --git a/src/forte/__init__.py b/src/forte/__init__.py new file mode 100644 index 0000000..da17d30 --- /dev/null +++ b/src/forte/__init__.py @@ -0,0 +1,26 @@ +"""Forte: Finding Outliers with Representation Typicality Estimation. + +A PyTorch-based library for out-of-distribution (OOD) detection using +topology-aware representation learning from multiple pretrained vision models. + +Based on the ICLR 2025 paper: +Ganguly, D., Morningstar, W. R., Yu, A. S., & Chaudhary, V. (2025). +Forte: Finding Outliers with Representation Typicality Estimation. +In The Thirteenth International Conference on Learning Representations. +""" + +__version__ = "0.1.0" +__author__ = "Debargha Ganguly" +__email__ = "debargha.ganguly@gmail.com" +__license__ = "MIT" + +from .detector import ForteOODDetector +from .models import TorchGMM, TorchKDE, TorchOCSVM + +__all__ = [ + "ForteOODDetector", + "TorchGMM", + "TorchKDE", + "TorchOCSVM", + "__version__", +] diff --git a/src/forte/detector.py b/src/forte/detector.py new file mode 100644 index 0000000..ebc1d4f --- /dev/null +++ b/src/forte/detector.py @@ -0,0 +1,584 @@ +""" +Forte OOD Detector: Finding Outliers with Representation Typicality Estimation. + +This module implements the main ForteOODDetector class based on the ICLR 2025 paper. +""" + +import os +import time +from typing import Dict, List + +import numpy as np +import torch +from PIL import Image +from scipy.stats import gaussian_kde +from sklearn.metrics import ( + average_precision_score, + precision_recall_curve, + roc_auc_score, + roc_curve, +) +from sklearn.mixture import GaussianMixture +from sklearn.model_selection import train_test_split +from sklearn.svm import OneClassSVM +from tqdm import tqdm +from transformers import ( + AutoFeatureExtractor, + AutoImageProcessor, + AutoModel, + CLIPModel, + CLIPProcessor, + ViTMSNModel, +) + +from .models import TorchGMM, TorchKDE, TorchOCSVM + + +class ForteOODDetector: + """ + Forte OOD Detector: Finding Outliers Using Representation Typicality Estimation. + + This class implements the Forte method for OOD detection. It extracts features using + pretrained models and computes PRDC features using PyTorch tensors on GPU. + + Detector training can use either a custom GPU-based implementation + or fall back to CPU-based detectors from scikit-learn/SciPy. + + Example: + >>> detector = ForteOODDetector(method='gmm', nearest_k=5) + >>> detector.fit(id_image_paths) + >>> predictions = detector.predict(test_image_paths) + >>> metrics = detector.evaluate(id_test_paths, ood_test_paths) + """ + + def __init__( + self, batch_size=32, device=None, embedding_dir="./embeddings", nearest_k=5, method="gmm" + ): + """ + Initialize the ForteOODDetector. + + Args: + batch_size (int): Batch size for processing images. + device (str): Device to use for computation (e.g., 'cuda:0' or 'cpu'). + embedding_dir (str): Directory to store embeddings. + nearest_k (int): Number of nearest neighbors for PRDC computation. + method (str): Detector method ('gmm', 'kde', or 'ocsvm'). + """ + self.batch_size = batch_size + if device is None: + if torch.cuda.is_available(): + device = "cuda:0" + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + self.device = device + self.embedding_dir = embedding_dir + self.nearest_k = nearest_k + self.method = method + self.custom_detector = self.device != "cpu" + self.models = None + self.is_fitted = False + + # These will be set during fit + self.id_train_features = None # GPU tensors for feature extraction + self.id_train_prdc = None # Combined PRDC features (GPU tensor) + self.detector = None + + os.makedirs(self.embedding_dir, exist_ok=True) + + def _load_image(self, path): + """Load an image from path.""" + try: + return Image.open(path).convert("RGB") + except Exception as e: + print(f"Error loading image {path}: {e}") + return None + + def _init_models(self): + """Initialize the models used for feature extraction.""" + print(f"Initializing models on {self.device}...") + device = self.device + models = [ + ( + "clip", + CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device), + CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32"), + ), + ( + "vitmsn", + ViTMSNModel.from_pretrained("facebook/vit-msn-base").to(device), + AutoFeatureExtractor.from_pretrained("facebook/vit-msn-base"), + ), + ( + "dinov2", + AutoModel.from_pretrained("facebook/dinov2-base").to(device), + AutoImageProcessor.from_pretrained("facebook/dinov2-base"), + ), + ] + return models + + def _extract_features_batch(self, image_paths, batch_idx=0): + """ + Extract features for a batch of images using multiple models. + + Args: + image_paths (list): List of image paths. + batch_idx (int): Batch index for progress tracking. + + Returns: + dict: Dictionary of features for each model (torch tensors on GPU). + """ + # Load images using the helper method and filter out failures + images = [self._load_image(path) for path in image_paths] + images = [img for img in images if img is not None] + + if not images: + return { + model_name: torch.empty(0, device=self.device) for model_name, _, _ in self.models + } + + all_features = {} + # Process each model using its corresponding processor + for model_name, model, processor in self.models: + inputs = processor(images=images, return_tensors="pt", padding=True).to(self.device) + try: + with torch.no_grad(): + if model_name == "clip": + features = model.get_image_features(**inputs) + elif model_name in ["vitmsn", "dinov2"]: + features = model(**inputs).last_hidden_state[:, 0, :] + else: + raise ValueError(f"Unsupported model: {model_name}") + all_features[model_name] = features + except Exception as e: + print(f"Error extracting features with {model_name}: {e}") + all_features[model_name] = torch.empty(0, device=self.device) + return all_features + + def _extract_features(self, image_paths, name="tmp"): + """ + Extract features from all images using the models. + + Args: + image_paths (list): List of image paths. + name (str): Identifier for caching. + + Returns: + dict: Dictionary of features for each model (torch tensors on GPU). + """ + if self.models is None: + self.models = self._init_models() + + all_features = {model_name: [] for model_name, _, _ in self.models} + models_to_process = [] + + for model_name, _, _ in self.models: + embedding_file = os.path.join(self.embedding_dir, f"{name}_{model_name}_features.pt") + if os.path.exists(embedding_file): + print(f"Loading pre-computed features from {embedding_file}") + loaded = torch.load(embedding_file, map_location=self.device) + all_features[model_name] = loaded + if loaded.size(0) != len(image_paths): + print( + f"Warning: Cached features count ({loaded.size(0)}) doesn't " + f"match image count ({len(image_paths)}). " + f"Recomputing for {model_name}." + ) + all_features[model_name] = [] + models_to_process.append(model_name) + else: + print(f"Feature shape for {model_name}: {loaded.shape}") + else: + models_to_process.append(model_name) + + if not models_to_process: + return all_features + + for i in tqdm(range(0, len(image_paths), self.batch_size), desc="Extracting features"): + batch_paths = image_paths[i : i + self.batch_size] + batch_features = self._extract_features_batch(batch_paths, i // self.batch_size) + for model_name, features in batch_features.items(): + if features.numel() > 0 and model_name in models_to_process: + all_features[model_name].append(features) + + for model_name in models_to_process: + if all_features[model_name]: + all_features[model_name] = torch.cat(all_features[model_name], dim=0) + embedding_file = os.path.join( + self.embedding_dir, f"{name}_{model_name}_features.pt" + ) + torch.save(all_features[model_name], embedding_file) + print( + f"Saved {model_name} features with shape " + f"{all_features[model_name].shape} to {embedding_file}" + ) + else: + all_features[model_name] = torch.empty(0, device=self.device) + + return all_features + + def _compute_pairwise_distance(self, data_x, data_y=None): + """ + Compute pairwise distances between two sets of points using torch operations. + + Args: + data_x (torch.Tensor): Data points. + data_y (torch.Tensor, optional): Data points. + + Returns: + torch.Tensor: Pairwise distances. + """ + if data_y is None: + data_y = data_x + return torch.cdist(data_x, data_y, p=2) + + def _get_kth_value(self, unsorted, k, axis=-1): + """ + Get the kth smallest values along an axis using torch.topk. + + Args: + unsorted (torch.Tensor): Input tensor. + k (int): k value. + axis (int): Axis. + + Returns: + torch.Tensor: kth smallest values along the specified axis. + """ + values, _ = torch.topk(unsorted, k, largest=False) + return values.max(dim=axis).values + + def _compute_nearest_neighbour_distances(self, input_features, nearest_k): + """ + Compute distances to kth nearest neighbours using torch operations. + + Args: + input_features (torch.Tensor): Input features. + nearest_k (int): Number of nearest neighbors. + + Returns: + torch.Tensor: Distances to kth nearest neighbours. + """ + distances = self._compute_pairwise_distance(input_features) + radii = self._get_kth_value(distances, k=nearest_k + 1, axis=-1) + return radii + + def _compute_prdc_features(self, real_features, fake_features): + """ + Compute PRDC features using GPU-based tensor operations. + + Args: + real_features (torch.Tensor): Reference features. + fake_features (torch.Tensor): Query features. + + Returns: + torch.Tensor: PRDC features (recall, density, precision, coverage). + """ + num_real = real_features.size(0) + real_distances = self._compute_nearest_neighbour_distances(real_features, self.nearest_k) + fake_distances = self._compute_nearest_neighbour_distances(fake_features, self.nearest_k) + distance_matrix = self._compute_pairwise_distance(real_features, fake_features) + + precision = (distance_matrix < real_distances.unsqueeze(1)).any(dim=0).float() + recall = (distance_matrix < fake_distances).sum(dim=0).float() / num_real + density = (1.0 / float(self.nearest_k)) * ( + distance_matrix < real_distances.unsqueeze(1) + ).sum(dim=0).float() + coverage = (distance_matrix.min(dim=0).values < fake_distances).float() + + return torch.stack((recall, density, precision, coverage), dim=1) + + def fit( + self, id_image_paths: List[str], val_split: float = 0.2, random_state: int = 42 + ) -> "ForteOODDetector": + """ + Fit the OOD detector on in-distribution images. + + Args: + id_image_paths (list): Paths to in-distribution images. + val_split (float): Fraction for validation. + random_state (int): Random seed. + + Returns: + ForteOODDetector: The fitted detector. + """ + start_time = time.time() + print(f"Fitting ForteOODDetector on {len(id_image_paths)} images...") + + # Split paths into training and validation + id_train_paths, id_val_paths = train_test_split( + id_image_paths, test_size=val_split, random_state=random_state + ) + + print(f"Extracting features from {len(id_train_paths)} training images...") + self.id_train_features = self._extract_features(id_train_paths, name="id_train") + + print(f"Extracting features from {len(id_val_paths)} validation images...") + id_val_features = self._extract_features(id_val_paths, name="id_val") + + # Compute PRDC features for each model using GPU tensor operations + print("Computing PRDC features...") + X_id_train_prdc = [] + X_id_val_prdc = [] + for model_name in self.id_train_features: + print(f"Computing PRDC for {model_name}...") + features = self.id_train_features[model_name] + # Use torch-based splitting on GPU + train_idx = torch.randperm(features.size(0), device=self.device) + split = int(features.size(0) * 0.5) + id_train_part1 = features[train_idx[:split]] + id_train_part2 = features[train_idx[split:]] + + print(f" Training PRDC: {id_train_part1.shape} vs {id_train_part2.shape}") + train_prdc = self._compute_prdc_features(id_train_part1, id_train_part2) + X_id_train_prdc.append(train_prdc) + + val_feats = id_val_features[model_name] + print(f" Validation PRDC: {id_train_part1.shape} vs {val_feats.shape}") + val_prdc = self._compute_prdc_features(id_train_part1, val_feats) + X_id_val_prdc.append(val_prdc) + + self.id_train_prdc = torch.cat(X_id_train_prdc, dim=1) # still on GPU + id_val_prdc = torch.cat(X_id_val_prdc, dim=1) + print( + f"Combined PRDC features - Training: {self.id_train_prdc.shape}, " + f"Validation: {id_val_prdc.shape}" + ) + + print(f"Training detector ({self.method}) with custom_detector={self.custom_detector}...") + if self.method == "gmm": + best_bic = np.inf + best_n_components = 1 + best_model = None + for n_components in [1, 2, 4, 8, 16, 32, 64]: + if self.custom_detector: + gmm = TorchGMM( + n_components=n_components, max_iter=100, tol=1e-3, device=self.device + ) + gmm.fit(self.id_train_prdc) + bic_val = gmm.bic(self.id_train_prdc) + else: + id_train_prdc_cpu = self.id_train_prdc.cpu() + id_train_prdc_np = id_train_prdc_cpu.numpy() + gmm_sklearn: GaussianMixture = GaussianMixture( + n_components=n_components, + covariance_type="full", + random_state=random_state, + max_iter=100, + ) + gmm_sklearn.fit(id_train_prdc_np) + bic_val = float(gmm_sklearn.bic(id_train_prdc_np)) + gmm = gmm_sklearn + if bic_val < best_bic: + best_bic = bic_val + best_n_components = n_components + best_gmm = gmm + print(f"Selected {best_n_components} components for GMM with BIC={best_bic:.2f}") + self.detector = best_gmm + + elif self.method == "kde": + self.detector = ( + TorchKDE(self.id_train_prdc.T, bw_method="scott", device=self.device) + if self.custom_detector + else gaussian_kde(self.id_train_prdc.cpu().numpy().T, bw_method="scott") + ) + + elif self.method == "ocsvm": + if self.custom_detector: + best_accuracy = 0.0 + best_nu = 0.01 + best_model = None + for nu in [0.01, 0.05, 0.1, 0.2, 0.5]: + model = TorchOCSVM(nu=nu, n_iters=1000, lr=1e-3, device=self.device) + model.fit(self.id_train_prdc) + decision = model.decision_function(self.id_train_prdc) + accuracy = ( + torch.where(decision.detach() >= 0, 1, -1).float().mean().item() + 1 + ) / 2.0 + if accuracy > best_accuracy: + best_accuracy = accuracy + best_nu = nu + best_model = model + print(f"Selected nu={best_nu} for TorchOCSVM with accuracy {best_accuracy:.4f}") + self.detector = best_model + else: + best_accuracy = 0.0 + best_nu = 0.01 + for nu in [0.01, 0.05, 0.1, 0.2, 0.5]: + try: + id_train_prdc_np = self.id_train_prdc.cpu().numpy() + ocsvm = OneClassSVM(kernel="rbf", gamma="scale", nu=nu) + ocsvm.fit(id_train_prdc_np) + val_pred = ocsvm.predict(id_train_prdc_np) + accuracy = float(np.mean(val_pred == 1)) + if accuracy > best_accuracy: + best_accuracy = accuracy + best_nu = nu + except Exception as e: + print(f"Error with nu={nu}: {e}") + continue + print(f"Selected nu={best_nu} for OCSVM with accuracy {best_accuracy:.4f}") + id_train_prdc_np = self.id_train_prdc.cpu().numpy() + self.detector = OneClassSVM(kernel="rbf", gamma="scale", nu=best_nu) + self.detector.fit(id_train_prdc_np) + + self.is_fitted = True + fit_time = time.time() - start_time + print(f"ForteOODDetector fitted in {fit_time:.2f} seconds.") + return self + + def _get_ood_scores(self, image_paths, cache_name="test"): + """ + Get OOD scores for a set of images. + + Args: + image_paths (list): Paths to images. + cache_name (str): Identifier for caching. + + Returns: + np.ndarray: Array of scores. + """ + if not self.is_fitted: + raise RuntimeError("Detector must be fitted before prediction") + + test_features = self._extract_features(image_paths, name=cache_name) + X_test_prdc = [] + for model_name in test_features: + ref_features = self.id_train_features[model_name] + train_idx = torch.randperm(ref_features.size(0), device=self.device) + split = int(ref_features.size(0) * 0.5) + id_train_part1 = ref_features[train_idx[:split]] + test_tensor = test_features[model_name] + print( + f"Computing test PRDC for {model_name}: " + f"{id_train_part1.shape} vs {test_tensor.shape}" + ) + test_prdc = self._compute_prdc_features(id_train_part1, test_tensor) + X_test_prdc.append(test_prdc) + + X_test_prdc = torch.cat(X_test_prdc, dim=1) + print(f"Combined test PRDC shape: {X_test_prdc.shape}") + + # For custom (GPU-based) detectors, use torch outputs; then convert to numpy if needed. + if self.custom_detector: + if self.method == "gmm": + scores = self.detector.score_samples(X_test_prdc) + scores = scores.cpu().numpy() + elif self.method == "kde": + scores = self.detector.logpdf(X_test_prdc) + scores = scores.cpu().numpy() + elif self.method == "ocsvm": + scores = self.detector.decision_function(X_test_prdc) + scores = scores.detach().cpu().numpy() + else: + X_test_prdc_cpu = X_test_prdc.cpu().numpy() + if self.method == "gmm": + scores = self.detector.score_samples(X_test_prdc_cpu) + elif self.method == "kde": + scores = self.detector.logpdf(X_test_prdc_cpu.T) + elif self.method == "ocsvm": + scores = self.detector.decision_function(X_test_prdc_cpu) + return scores + + def predict(self, image_paths: List[str]) -> np.ndarray: + """ + Predict OOD status. + + Args: + image_paths (list): Paths to images. + + Returns: + np.ndarray: Binary predictions (1 for in-distribution, -1 for OOD). + """ + scores = self._get_ood_scores(image_paths) + threshold: float + if self.method == "ocsvm": + threshold = 0.0 + else: + if self.custom_detector: + ref_features = self.id_train_prdc + # Use a simple split for threshold estimation + train_idx = torch.randperm(ref_features.size(0), device=self.device) + split = int(ref_features.size(0) * 0.5) + id_train_part1 = ref_features[train_idx[:split]] + if self.method == "gmm": + id_scores = self.detector.score_samples(id_train_part1).cpu().numpy() + elif self.method == "kde": + id_scores = self.detector.score_samples(id_train_part1).cpu().numpy() + else: + id_train_part1_np, _ = train_test_split( + self.id_train_prdc.cpu().numpy(), test_size=0.5, random_state=42 + ) + if self.method == "gmm": + id_scores = self.detector.score_samples(id_train_part1_np) + elif self.method == "kde": + id_scores = self.detector.logpdf(id_train_part1_np.T) + threshold = float(np.percentile(id_scores, 5)) + predictions: np.ndarray = np.where(scores > threshold, 1, -1).astype(np.int64) + return predictions + + def predict_proba(self, image_paths: List[str]) -> np.ndarray: + """ + Return normalized probability scores for OOD detection. + + Args: + image_paths (list): Paths to images. + + Returns: + np.ndarray: Normalized scores. + """ + scores = self._get_ood_scores(image_paths) + min_score: float = float(np.min(scores)) + max_score: float = float(np.max(scores)) + if max_score > min_score: + normalized_scores = (scores - min_score) / (max_score - min_score) + else: + normalized_scores = np.ones_like(scores) * 0.5 + result: np.ndarray = np.asarray(normalized_scores) + return result + + def evaluate(self, id_image_paths: List[str], ood_image_paths: List[str]) -> Dict[str, float]: + """ + Evaluate the detector. + + Args: + id_image_paths (list): In-distribution image paths. + ood_image_paths (list): OOD image paths. + + Returns: + dict: Evaluation metrics. + """ + if not self.is_fitted: + raise RuntimeError("Detector must be fitted before evaluation") + + print(f"Evaluating on {len(id_image_paths)} ID and {len(ood_image_paths)} OOD images...") + + # Fuse ID and OOD samples for processing together + all_image_paths = id_image_paths + ood_image_paths + all_scores = self._get_ood_scores(all_image_paths, cache_name="eval_fused") + + # Split the scores back to ID and OOD + id_scores = all_scores[: len(id_image_paths)] + ood_scores = all_scores[len(id_image_paths) :] + + print("\nScore Statistics:") + print( + f"ID - Mean: {np.mean(id_scores):.4f}, Std: {np.std(id_scores):.4f}, " + f"Min: {np.min(id_scores):.4f}, Max: {np.max(id_scores):.4f}" + ) + print( + f"OOD - Mean: {np.mean(ood_scores):.4f}, Std: {np.std(ood_scores):.4f}, " + f"Min: {np.min(ood_scores):.4f}, Max: {np.max(ood_scores):.4f}" + ) + + labels = np.concatenate([np.ones(len(id_scores)), np.zeros(len(ood_scores))]) + scores_all = np.concatenate([id_scores, ood_scores]) + auroc = roc_auc_score(labels, scores_all) + fpr, tpr, _ = roc_curve(labels, scores_all) + idx = np.argmin(np.abs(tpr - 0.95)) + fpr95 = fpr[idx] if idx < len(fpr) else 1.0 + precision_vals, recall_vals, _ = precision_recall_curve(labels, scores_all) + auprc = average_precision_score(labels, scores_all) + f1_scores = 2 * (precision_vals * recall_vals) / (precision_vals + recall_vals + 1e-10) + f1_score: float = float(np.max(f1_scores)) + return {"AUROC": auroc, "FPR@95TPR": fpr95, "AUPRC": auprc, "F1": f1_score} diff --git a/src/forte/models.py b/src/forte/models.py new file mode 100644 index 0000000..f306cb8 --- /dev/null +++ b/src/forte/models.py @@ -0,0 +1,394 @@ +""" +Custom PyTorch implementations of OOD detection models. + +This module provides GPU-accelerated implementations of: +- Gaussian Mixture Models (GMM) +- Kernel Density Estimation (KDE) +- One-Class Support Vector Machines (OCSVM) +""" + +import math +from typing import Callable, Optional, Union + +import numpy as np +import torch + + +class TorchGMM: + """PyTorch implementation of Gaussian Mixture Model with GPU acceleration.""" + + def __init__( + self, + n_components=1, + covariance_type="full", + max_iter=100, + tol=1e-3, + reg_covar=1e-6, + device="cuda", + ): + """Initialize a PyTorch Gaussian Mixture Model. + + A PyTorch implementation that closely follows scikit-learn's + GaussianMixture (for the 'full' covariance case). + + Args: + n_components (int): Number of mixture components. + covariance_type (str): Only 'full' is implemented in this example. + max_iter (int): Maximum number of iterations. + tol (float): Convergence threshold. + reg_covar (float): Non-negative regularization added to the diagonal + of covariance matrices. + device (str): 'cuda' or 'cpu'. + """ + if covariance_type != "full": + raise NotImplementedError("Only 'full' covariance is implemented.") + self.n_components = n_components + self.covariance_type = covariance_type + self.max_iter = max_iter + self.tol = tol + self.reg_covar = reg_covar + self.device = device + + # Parameters to be learned + self.weights_ = None # shape: (n_components,) + self.means_ = None # shape: (n_components, n_features) + # shape: (n_components, n_features, n_features) + self.covariances_ = None + self.converged_ = False + self.lower_bound_ = -np.inf + + def _initialize_parameters(self, X): + n_samples, n_features = X.shape + K = self.n_components + # Initialize weights uniformly + self.weights_ = torch.full((K,), 1.0 / K, device=self.device) + # Initialize means by randomly selecting K samples + indices = torch.randperm(n_samples, device=self.device)[:K] + self.means_ = X[indices].clone() + # Initialize covariances as diagonal matrices based on sample variance + variance = torch.var(X, dim=0) + self.reg_covar + self.covariances_ = torch.stack([torch.diag(variance) for _ in range(K)], dim=0) + + def _estimate_log_gaussian_prob(self, X): + # X: (n_samples, n_features) + n_samples, n_features = X.shape + # Create a batched MultivariateNormal distribution for each component + covariances = self.covariances_ + self.reg_covar * torch.eye(n_features, device=self.device) + + # MPS doesn't support MultivariateNormal with cholesky, so we fall back to CPU + if self.device == "mps": + means_cpu = self.means_.cpu() + covariances_cpu = covariances.cpu() + X_cpu = X.cpu() + mvn = torch.distributions.MultivariateNormal( + means_cpu, + covariance_matrix=covariances_cpu, + ) + log_prob = mvn.log_prob(X_cpu.unsqueeze(1)).to(self.device) + else: + mvn = torch.distributions.MultivariateNormal( + self.means_, + covariance_matrix=covariances, + ) + # X has shape (n_samples, n_features); unsqueeze to (n_samples, 1, n_features) + # to broadcast over components + # Expected shape: (n_samples, n_components) + log_prob = mvn.log_prob(X.unsqueeze(1)) + return log_prob + + def _e_step(self, X): + # Compute log probabilities for each sample and each component + log_prob = self._estimate_log_gaussian_prob(X) # shape: (n_samples, n_components) + # Add log weights + weighted_log_prob = log_prob + torch.log(self.weights_ + 1e-10) + # Compute log-sum-exp for each sample + log_prob_norm = torch.logsumexp(weighted_log_prob, dim=1, keepdim=True) + # Compute responsibilities: r_ik = exp(weighted_log_prob - log_prob_norm) + log_resp = weighted_log_prob - log_prob_norm + resp = torch.exp(log_resp) + return resp, log_prob_norm.sum().item() + + def _m_step(self, X, resp): + n_samples, n_features = X.shape + Nk = resp.sum(dim=0) # shape: (n_components,) + self.weights_ = Nk / n_samples + # Update means + self.means_ = (resp.t() @ X) / (Nk.unsqueeze(1) + 1e-10) + # Update covariances + K = self.n_components + covariances = [] + for k in range(K): + diff = X - self.means_[k] + weighted_diff = diff * resp[:, k].unsqueeze(1) + cov_k = (weighted_diff.t() @ diff) / (Nk[k] + 1e-10) + # Add regularization for numerical stability + cov_k = cov_k + self.reg_covar * torch.eye(n_features, device=self.device) + covariances.append(cov_k) + self.covariances_ = torch.stack(covariances, dim=0) + + def fit(self, X: torch.Tensor) -> "TorchGMM": + """Fit the GMM model on data X. + + Args: + X (torch.Tensor): Input data of shape (n_samples, n_features) on self.device. + + Returns: + TorchGMM: The fitted model instance. + """ + X = X.to(self.device) + self._initialize_parameters(X) + lower_bound = -np.inf + + for i in range(self.max_iter): + resp, curr_lower_bound = self._e_step(X) + self._m_step(X, resp) + change = abs(curr_lower_bound - lower_bound) + lower_bound = curr_lower_bound + if change < self.tol: + self.converged_ = True + break + self.lower_bound_ = lower_bound + return self + + def score_samples(self, X: torch.Tensor) -> torch.Tensor: + """Compute the log-likelihood of each sample under the model. + + Args: + X (torch.Tensor): Data of shape (n_samples, n_features) on self.device. + + Returns: + torch.Tensor: Log probability for each sample. + """ + X = X.to(self.device) + log_prob = self._estimate_log_gaussian_prob(X) + weighted_log_prob = log_prob + torch.log(self.weights_ + 1e-10) + log_prob_norm = torch.logsumexp(weighted_log_prob, dim=1) + return log_prob_norm + + def bic(self, X: torch.Tensor) -> float: + """Bayesian Information Criterion for the current model. + + Args: + X (torch.Tensor): Data of shape (n_samples, n_features) on self.device. + + Returns: + float: BIC value. + """ + n_samples, n_features = X.shape + p = ( + (self.n_components - 1) + + self.n_components * n_features + + self.n_components * n_features * (n_features + 1) / 2 + ) + log_likelihood = self.score_samples(X).sum().item() + return float(-2 * log_likelihood + p * np.log(n_samples)) + + +class TorchKDE: + """PyTorch implementation of Kernel Density Estimation with GPU acceleration.""" + + def __init__( + self, + dataset: torch.Tensor, + bw_method: Optional[Union[str, float, Callable]] = None, + weights: Optional[torch.Tensor] = None, + device: str = "cuda", + ): + """Initialize Kernel Density Estimator. + + Args: + dataset (torch.Tensor): Data points of shape (d, n) where d is dimensionality. + bw_method (str or float): Bandwidth method ('scott', 'silverman', or float value). + weights (torch.Tensor, optional): Sample weights. + device (str): Device for computation ('cuda', 'mps', or 'cpu'). + """ + self.device = device + self.dataset = dataset # shape: (d, n) + self.d, self.n = self.dataset.shape + + # Process weights (assumed to be a torch.Tensor on device if provided). + if weights is not None: + self.weights = (weights / weights.sum()).to(dtype=torch.float32) + self.neff = ((self.weights.sum() ** 2) / (self.weights**2).sum()).item() + # Weighted covariance: cov = sum_i w_i (x_i - mean)(x_i - mean)^T / (1 - sum(w_i^2)) + weighted_mean = (self.dataset * self.weights.unsqueeze(0)).sum(dim=1, keepdim=True) + diff = self.dataset - weighted_mean + cov = (diff * self.weights.unsqueeze(0)) @ diff.T / (1 - (self.weights**2).sum()) + else: + self.weights = torch.full( + (self.n,), 1.0 / self.n, dtype=torch.float32, device=self.device + ) + self.neff = float(self.n) + weighted_mean = self.dataset.mean(dim=1, keepdim=True) + diff = self.dataset - weighted_mean + cov = diff @ diff.T / (self.n - 1) + self._data_covariance = cov # computed entirely on GPU + + # Set bandwidth and compute scaled covariance. + self.set_bandwidth(bw_method) + + def scotts_factor(self): + """Scott's rule for bandwidth selection.""" + return self.neff ** (-1.0 / (self.d + 4)) + + def silverman_factor(self): + """Silverman's rule for bandwidth selection.""" + return (self.neff * (self.d + 2) / 4.0) ** (-1.0 / (self.d + 4)) + + def set_bandwidth(self, bw_method=None): + """Set the bandwidth for the kernel.""" + if bw_method is None or bw_method == "scott": + self.factor = self.scotts_factor() + elif bw_method == "silverman": + self.factor = self.silverman_factor() + elif isinstance(bw_method, (int, float)): + self.factor = float(bw_method) + elif callable(bw_method): + self.factor = float(bw_method(self)) + else: + raise ValueError("Invalid bw_method.") + self._compute_covariance() + + def _compute_covariance(self): + # Scale the data covariance by the bandwidth factor squared. + self.covariance = self._data_covariance * (self.factor**2) + # Increase regularization to ensure positive definiteness. + reg = 1e-6 + cov_matrix = self.covariance + reg * torch.eye( + self.d, device=self.device, dtype=self.dataset.dtype + ) + + # MPS doesn't support linalg.cholesky, so we fall back to CPU for this operation + if self.device == "mps": + cov_cpu = cov_matrix.cpu() + self.cho_cov = torch.linalg.cholesky(cov_cpu).to(self.device) + else: + self.cho_cov = torch.linalg.cholesky(cov_matrix) + + self.log_det = 2.0 * torch.log(torch.diag(self.cho_cov)).sum() + + def evaluate(self, points: torch.Tensor) -> torch.Tensor: + """Evaluate the KDE at given points. + + Args: + points (torch.Tensor): Points to evaluate, shape (d, m) or (m, d). + + Returns: + torch.Tensor: Density estimates. + """ + # Assume points is already a torch.Tensor on the proper device. + if points.dim() == 1: + points = points.unsqueeze(0) + # If points are provided in (n, d) format (n > d), transpose them to (d, m) + if points.shape[0] > points.shape[1]: + points = points.T + if points.shape[0] != self.d: + raise ValueError( + f"Expected input with one dimension = {self.d}, but got shape {points.shape}" + ) + # Compute differences: shape (d, n, m) + diff = self.dataset.unsqueeze(2) - points.unsqueeze(1) + # Flatten differences for cholesky_solve: (d, n*m) + diff_flat = diff.reshape(self.d, -1) + + # MPS doesn't support cholesky_solve, so we fall back to CPU for this operation + if self.device == "mps": + diff_cpu = diff_flat.cpu() + cho_cov_cpu = self.cho_cov.cpu() + sol_flat = torch.cholesky_solve(diff_cpu, cho_cov_cpu).to(self.device) + else: + sol_flat = torch.cholesky_solve(diff_flat, self.cho_cov) + + sol = sol_flat.view(diff.shape) + energy = 0.5 * (diff * sol).sum(dim=0) # shape: (n, m) + result = torch.exp(-energy).T @ self.weights # shape: (m,) + norm_const = torch.exp(-self.log_det) / ((2 * math.pi) ** (self.d / 2)) + return torch.as_tensor(result * norm_const) + + def logpdf(self, points: torch.Tensor) -> torch.Tensor: + """Compute log probability density at given points. + + Args: + points (torch.Tensor): Points to evaluate. + + Returns: + torch.Tensor: Log probability densities. + """ + return torch.log(self.evaluate(points) + 1e-10) + + __call__ = evaluate + + +class TorchOCSVM: + """PyTorch implementation of One-Class SVM with GPU acceleration.""" + + def __init__(self, nu=0.1, n_iters=1000, lr=1e-3, device="cuda"): + """Initialize One-Class SVM. + + Args: + nu (float): Upper bound on fraction of outliers (between 0 and 1). + n_iters (int): Number of optimization iterations. + lr (float): Learning rate for Adam optimizer. + device (str): Device for computation. + """ + self.nu = nu + self.n_iters = n_iters + self.lr = lr + self.device = device + self.w = None + self.rho = None + + def fit(self, X: torch.Tensor) -> "TorchOCSVM": + """Fit the One-Class SVM model. + + Args: + X (torch.Tensor): Training data of shape (n_samples, n_features). + + Returns: + TorchOCSVM: The fitted model instance. + """ + # Ensure X is on the correct device. + X = X.to(self.device) + n, d = X.shape + # Initialize w and rho as nn.Parameter to ensure they are leaf tensors. + self.w = torch.nn.Parameter(torch.randn(d, device=self.device) * 0.01) + self.rho = torch.nn.Parameter(torch.tensor(0.0, device=self.device)) + # TODO: Adam is a good default choice, we can try SGD or adding a learning + # rate scheduler to adapt the learning rate during training. + optimizer = torch.optim.Adam([self.w, self.rho], lr=self.lr) + for i in range(self.n_iters): + optimizer.zero_grad() + scores = X @ self.w # shape: (n,) + # Compute slack = max(0, rho - w^T x) for each sample. + # apply a smooth approximation? + slack = torch.clamp(self.rho - scores, min=0) + loss = 0.5 * torch.norm(self.w) ** 2 - self.rho + (1 / (self.nu * n)) * slack.sum() + loss.backward() + optimizer.step() + if (i + 1) % 200 == 0: + print(f"OCSVM iter {i+1}/{self.n_iters}, loss: {loss.item():.4f}") + return self + + def decision_function(self, X: torch.Tensor) -> torch.Tensor: + """Compute the decision function for samples. + + Args: + X (torch.Tensor): Samples of shape (n_samples, n_features). + + Returns: + torch.Tensor: Decision values. + """ + X = X.to(self.device) + return torch.as_tensor(X @ self.w - self.rho) + + def predict(self, X: torch.Tensor) -> torch.Tensor: + """Predict class labels. + + Args: + X (torch.Tensor): Samples of shape (n_samples, n_features). + + Returns: + torch.Tensor: Predictions (1 for inlier, -1 for outlier). + """ + decision = self.decision_function(X) + return torch.where(decision >= 0, 1, -1) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..6c06e40 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test suite for forte-detector package.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..b36b6d9 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,119 @@ +""" +Pytest configuration and shared fixtures for forte-detector tests. +""" + +import os +import shutil +import tempfile + +import numpy as np +import pytest +import torch +from PIL import Image + + +@pytest.fixture(scope="session") +def device(): + """Determine the best available device for testing.""" + if torch.cuda.is_available(): + return "cuda:0" + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return "mps" + else: + return "cpu" + + +@pytest.fixture(scope="session") +def tmp_dir(): + """Create a temporary directory for test artifacts.""" + tmpdir = tempfile.mkdtemp() + yield tmpdir + # Cleanup after all tests + shutil.rmtree(tmpdir, ignore_errors=True) + + +@pytest.fixture +def mock_image_paths(tmp_dir): + """Create mock image files for testing.""" + image_dir = os.path.join(tmp_dir, "mock_images") + os.makedirs(image_dir, exist_ok=True) + + paths = [] + for i in range(10): + # Create a small random RGB image + img = Image.fromarray(np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)) + path = os.path.join(image_dir, f"image_{i}.png") + img.save(path) + paths.append(path) + + return paths + + +@pytest.fixture +def small_mock_images(tmp_dir): + """Create a small set of mock images for quick tests.""" + image_dir = os.path.join(tmp_dir, "small_mock_images") + os.makedirs(image_dir, exist_ok=True) + + paths = [] + for i in range(3): + img = Image.fromarray(np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)) + path = os.path.join(image_dir, f"small_image_{i}.png") + img.save(path) + paths.append(path) + + return paths + + +@pytest.fixture +def mock_features(device): + """Create mock feature tensors for testing.""" + # Simulate features from 3 models + n_samples = 20 + feature_dims = [512, 768, 768] # CLIP, ViTMSN, DINOv2 + + features = {} + for i, dim in enumerate(feature_dims): + model_name = ["clip", "vitmsn", "dinov2"][i] + features[model_name] = torch.randn(n_samples, dim, device=device) + + return features + + +@pytest.fixture +def mock_prdc_features(device): + """Create mock PRDC features for testing detectors.""" + # PRDC features have 4 dimensions per model (precision, recall, density, coverage) + # With 3 models, total dimension is 12 + n_samples = 50 + n_features = 12 # 4 PRDC metrics * 3 models + + return torch.randn(n_samples, n_features, device=device) + + +@pytest.fixture +def embedding_dir(tmp_dir): + """Create a temporary embedding directory.""" + emb_dir = os.path.join(tmp_dir, "embeddings") + os.makedirs(emb_dir, exist_ok=True) + return emb_dir + + +@pytest.fixture(autouse=True) +def set_random_seeds(): + """Set random seeds for reproducibility in all tests.""" + np.random.seed(42) + torch.manual_seed(42) + if torch.cuda.is_available(): + torch.cuda.manual_seed(42) + + +@pytest.fixture +def sample_dataset(): + """Create a small synthetic dataset for testing.""" + # In-distribution: samples from N(0, 1) + id_samples = torch.randn(100, 10) + # Out-of-distribution: samples from N(5, 2) + ood_samples = torch.randn(100, 10) * 2 + 5 + + return {"id": id_samples, "ood": ood_samples} diff --git a/tests/test_detector.py b/tests/test_detector.py new file mode 100644 index 0000000..12496c3 --- /dev/null +++ b/tests/test_detector.py @@ -0,0 +1,187 @@ +""" +Tests for ForteOODDetector class. +""" + +import numpy as np +import pytest +import torch + +from forte import ForteOODDetector + + +class TestForteOODDetectorInit: + """Test ForteOODDetector initialization.""" + + def test_default_initialization(self, device): + """Test detector with default parameters.""" + detector = ForteOODDetector() + assert detector.batch_size == 32 + assert detector.device in ["cuda:0", "mps", "cpu"] + assert detector.embedding_dir == "./embeddings" + assert detector.nearest_k == 5 + assert detector.method == "gmm" + assert not detector.is_fitted + + def test_custom_parameters(self, device, embedding_dir): + """Test detector with custom parameters.""" + detector = ForteOODDetector( + batch_size=16, device=device, embedding_dir=embedding_dir, nearest_k=10, method="kde" + ) + assert detector.batch_size == 16 + assert detector.device == device + assert detector.embedding_dir == embedding_dir + assert detector.nearest_k == 10 + assert detector.method == "kde" + + @pytest.mark.parametrize("method", ["gmm", "kde", "ocsvm"]) + def test_all_methods(self, method, device, embedding_dir): + """Test initialization with all supported methods.""" + detector = ForteOODDetector(method=method, device=device, embedding_dir=embedding_dir) + assert detector.method == method + + +class TestForteOODDetectorHelperMethods: + """Test private helper methods of ForteOODDetector.""" + + def test_compute_pairwise_distance(self, device): + """Test pairwise distance computation.""" + detector = ForteOODDetector(device=device) + X = torch.randn(10, 5, device=device) + Y = torch.randn(8, 5, device=device) + + dist = detector._compute_pairwise_distance(X, Y) + assert dist.shape == (10, 8) + assert (dist >= 0).all() # Distances should be non-negative + + def test_get_kth_value(self, device): + """Test k-th value extraction.""" + detector = ForteOODDetector(device=device) + X = torch.randn(10, 20, device=device) + k = 5 + + kth_vals = detector._get_kth_value(X, k=k) + assert kth_vals.shape == (10,) + + def test_compute_nearest_neighbour_distances(self, device): + """Test nearest neighbor distance computation.""" + detector = ForteOODDetector(device=device, nearest_k=5) + X = torch.randn(20, 10, device=device) + + distances = detector._compute_nearest_neighbour_distances(X, nearest_k=5) + assert distances.shape == (20,) + assert (distances >= 0).all() + + def test_compute_prdc_features(self, device): + """Test PRDC feature computation.""" + detector = ForteOODDetector(device=device, nearest_k=5) + real_features = torch.randn(30, 10, device=device) + fake_features = torch.randn(25, 10, device=device) + + prdc = detector._compute_prdc_features(real_features, fake_features) + assert prdc.shape == (25, 4) # 4 PRDC metrics + assert not torch.isnan(prdc).any() + # PRDC values should be in reasonable ranges + assert (prdc >= 0).all() + assert (prdc <= 1).any() # At least some values should be normalized + + +@pytest.mark.slow +class TestForteOODDetectorFit: + """Test ForteOODDetector fitting (slower tests).""" + + def test_fit_not_implemented_full(self, small_mock_images, device, embedding_dir): + """Test that fit raises error before implementation.""" + # This is a placeholder - in real implementation, we'd need actual models + # For now, we just test the basic structure + detector = ForteOODDetector( + device="cpu", # Use CPU to avoid downloading large models + embedding_dir=embedding_dir, + method="gmm", + ) + + # Note: This test would actually download models and run feature extraction + # For unit tests, we might want to mock this + # For now, we just check the structure exists + assert hasattr(detector, "fit") + assert hasattr(detector, "predict") + assert hasattr(detector, "predict_proba") + assert hasattr(detector, "evaluate") + + def test_fit_sets_is_fitted(self, device): + """Test that fit sets the is_fitted flag.""" + detector = ForteOODDetector(device=device) + assert not detector.is_fitted + # After fit, should be True (mocking this for now) + + def test_predict_before_fit_raises_error(self, small_mock_images, device, embedding_dir): + """Test that predict raises error if not fitted.""" + detector = ForteOODDetector(device=device, embedding_dir=embedding_dir) + + with pytest.raises(RuntimeError, match="Detector must be fitted"): + detector._get_ood_scores(small_mock_images) + + +class TestForteOODDetectorPredict: + """Test ForteOODDetector prediction methods.""" + + def test_predict_output_shape(self): + """Test that predict returns correct shape.""" + # This would require a fitted detector + # Placeholder for now + pass + + def test_predict_proba_output_range(self): + """Test that predict_proba returns values in [0, 1].""" + # Placeholder - would need fitted detector + pass + + def test_predict_binary_values(self): + """Test that predict returns only 1 and -1.""" + # Placeholder - would need fitted detector + pass + + +class TestForteOODDetectorEvaluate: + """Test ForteOODDetector evaluation methods.""" + + def test_evaluate_before_fit_raises_error(self, small_mock_images, device, embedding_dir): + """Test that evaluate raises error if not fitted.""" + detector = ForteOODDetector(device=device, embedding_dir=embedding_dir) + + with pytest.raises(RuntimeError, match="Detector must be fitted"): + detector.evaluate(small_mock_images[:2], small_mock_images[2:]) + + def test_evaluate_returns_correct_metrics(self): + """Test that evaluate returns all expected metrics.""" + # Placeholder - would need fitted detector + # Should return dict with keys: AUROC, FPR@95TPR, AUPRC, F1 + pass + + +@pytest.mark.integration +class TestForteOODDetectorIntegration: + """Integration tests for complete ForteOODDetector workflow.""" + + @pytest.mark.slow + def test_full_pipeline_mock_data(self): + """Test complete pipeline with mocked data.""" + # This would be a full end-to-end test + # Requires significant resources, so marked as slow + pass + + def test_device_compatibility(self, device): + """Test that detector works on the available device.""" + detector = ForteOODDetector(device=device) + assert detector.device == device + + # Test that custom_detector flag is set correctly + if device == "cpu": + assert not detector.custom_detector + else: + assert detector.custom_detector + + def test_method_compatibility(self, device, embedding_dir): + """Test all methods are compatible with device.""" + for method in ["gmm", "kde", "ocsvm"]: + detector = ForteOODDetector(device=device, embedding_dir=embedding_dir, method=method) + assert detector.method == method diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 0000000..450dfa9 --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,281 @@ +""" +Integration tests for forte-detector package. +These tests verify end-to-end functionality. +""" + +import os + +import numpy as np +import pytest +import torch +from PIL import Image + + +@pytest.mark.integration +@pytest.mark.slow +class TestEndToEndWorkflow: + """Test complete end-to-end workflows.""" + + def test_package_imports(self): + """Test that all main imports work correctly.""" + from forte import ForteOODDetector, TorchGMM, TorchKDE, TorchOCSVM, __version__ + + assert ForteOODDetector is not None + assert TorchGMM is not None + assert TorchKDE is not None + assert TorchOCSVM is not None + assert __version__ == "0.1.0" + + def test_detector_initialization_all_methods(self, device, embedding_dir): + """Test detector initialization with all methods.""" + from forte import ForteOODDetector + + for method in ["gmm", "kde", "ocsvm"]: + detector = ForteOODDetector( + method=method, device=device, embedding_dir=embedding_dir, batch_size=8, nearest_k=3 + ) + assert detector.method == method + assert not detector.is_fitted + + def test_image_loading(self, mock_image_paths): + """Test image loading functionality.""" + from forte import ForteOODDetector + + detector = ForteOODDetector(device="cpu") + + # Test loading a valid image + img = detector._load_image(mock_image_paths[0]) + assert img is not None + assert isinstance(img, Image.Image) + + # Test loading an invalid path + img = detector._load_image("/nonexistent/path.png") + assert img is None + + def test_prdc_computation_pipeline(self, device): + """Test PRDC computation on synthetic data.""" + from forte import ForteOODDetector + + detector = ForteOODDetector(device=device, nearest_k=5) + + # Create synthetic features + real_features = torch.randn(50, 128, device=device) + fake_features = torch.randn(40, 128, device=device) + + prdc = detector._compute_prdc_features(real_features, fake_features) + + assert prdc.shape == (40, 4) # 4 PRDC metrics per sample + assert not torch.isnan(prdc).any() + assert (prdc >= 0).all() + + def test_models_work_with_synthetic_features(self, device): + """Test that all models work with synthetic PRDC features.""" + from forte.models import TorchGMM, TorchKDE, TorchOCSVM + + # Generate synthetic PRDC features + X = torch.randn(100, 12, device=device) # 12 = 4 PRDC * 3 models + + # Test GMM + gmm = TorchGMM(n_components=4, max_iter=20, device=device) + gmm.fit(X) + gmm_scores = gmm.score_samples(X) + assert gmm_scores.shape == (100,) + assert not torch.isnan(gmm_scores).any() + + # Test KDE + kde = TorchKDE(X.T, device=device) + kde_scores = kde.logpdf(X) + assert kde_scores.shape == (100,) + assert not torch.isnan(kde_scores).any() + + # Test OCSVM + ocsvm = TorchOCSVM(nu=0.1, n_iters=50, lr=1e-3, device=device) + ocsvm.fit(X) + ocsvm_scores = ocsvm.decision_function(X) + assert ocsvm_scores.shape == (100,) + assert not torch.isnan(ocsvm_scores).any() + + +@pytest.mark.integration +class TestModelSelection: + """Test model selection and hyperparameter optimization.""" + + def test_gmm_bic_selection(self, device): + """Test GMM BIC-based model selection.""" + from forte.models import TorchGMM + + X = torch.randn(100, 10, device=device) + + bic_scores = [] + for n_components in [1, 2, 4, 8]: + gmm = TorchGMM(n_components=n_components, max_iter=20, device=device) + gmm.fit(X) + bic = gmm.bic(X) + bic_scores.append(bic) + + # BIC scores should be finite + assert all(np.isfinite(bic) for bic in bic_scores) + + def test_ocsvm_nu_selection(self, device): + """Test OCSVM with different nu values.""" + from forte.models import TorchOCSVM + + X = torch.randn(100, 10, device=device) + + for nu in [0.01, 0.05, 0.1, 0.2]: + ocsvm = TorchOCSVM(nu=nu, n_iters=30, device=device) + ocsvm.fit(X) + scores = ocsvm.decision_function(X) + assert not torch.isnan(scores).any() + + +@pytest.mark.integration +class TestDeviceCompatibility: + """Test compatibility across different devices.""" + + def test_cpu_device(self): + """Test that everything works on CPU.""" + from forte import ForteOODDetector + + detector = ForteOODDetector(device="cpu") + assert detector.device == "cpu" + assert not detector.custom_detector # CPU uses non-custom detectors + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_cuda_device(self): + """Test that everything works on CUDA.""" + from forte import ForteOODDetector + + detector = ForteOODDetector(device="cuda:0") + assert detector.device == "cuda:0" + assert detector.custom_detector # GPU uses custom detectors + + @pytest.mark.skipif( + not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()), + reason="MPS not available", + ) + def test_mps_device(self): + """Test that everything works on MPS (Apple Silicon).""" + from forte import ForteOODDetector + + detector = ForteOODDetector(device="mps") + assert detector.device == "mps" + assert detector.custom_detector # GPU uses custom detectors + + +@pytest.mark.integration +class TestCaching: + """Test feature caching functionality.""" + + def test_embedding_directory_creation(self, tmp_dir): + """Test that embedding directory is created.""" + import os + + from forte import ForteOODDetector + + emb_dir = os.path.join(tmp_dir, "test_embeddings") + detector = ForteOODDetector(embedding_dir=emb_dir, device="cpu") + + assert os.path.exists(emb_dir) + + def test_feature_caching_structure(self, tmp_dir): + """Test that feature caching saves files correctly.""" + import os + + from forte import ForteOODDetector + + emb_dir = os.path.join(tmp_dir, "cache_test") + os.makedirs(emb_dir, exist_ok=True) + + # Create a mock cached feature + cache_path = os.path.join(emb_dir, "test_clip_features.pt") + torch.save(torch.randn(10, 512), cache_path) + + assert os.path.exists(cache_path) + loaded = torch.load(cache_path) + assert loaded.shape == (10, 512) + + +@pytest.mark.integration +class TestErrorHandling: + """Test error handling and edge cases.""" + + def test_invalid_method_raises_error(self, device, embedding_dir): + """Test that invalid method raises appropriate error.""" + from forte import ForteOODDetector + + detector = ForteOODDetector( + method="invalid_method", device=device, embedding_dir=embedding_dir + ) + # Should initialize but may fail during fit + assert detector.method == "invalid_method" + + def test_empty_image_list_handling(self, device, embedding_dir): + """Test handling of empty image lists.""" + from forte import ForteOODDetector + + detector = ForteOODDetector(device=device, embedding_dir=embedding_dir) + # This should be handled gracefully + # Actual behavior depends on implementation + + def test_invalid_image_path_handling(self, device): + """Test handling of invalid image paths.""" + from forte import ForteOODDetector + + detector = ForteOODDetector(device=device) + img = detector._load_image("/invalid/path/image.png") + assert img is None # Should return None, not raise error + + +@pytest.mark.integration +class TestReproducibility: + """Test reproducibility with fixed random seeds.""" + + def test_prdc_reproducibility(self, device): + """Test that PRDC computation is reproducible.""" + import numpy as np + + from forte import ForteOODDetector + + # Set seeds + torch.manual_seed(42) + np.random.seed(42) + + detector1 = ForteOODDetector(device=device, nearest_k=5) + real_features = torch.randn(50, 128, device=device) + fake_features = torch.randn(40, 128, device=device) + prdc1 = detector1._compute_prdc_features(real_features, fake_features) + + # Reset seeds + torch.manual_seed(42) + np.random.seed(42) + + detector2 = ForteOODDetector(device=device, nearest_k=5) + prdc2 = detector2._compute_prdc_features(real_features, fake_features) + + assert torch.allclose(prdc1, prdc2, atol=1e-6) + + def test_model_fitting_reproducibility(self, device): + """Test that model fitting is reproducible with same seed.""" + import numpy as np + + from forte.models import TorchGMM + + X = torch.randn(100, 10, device=device) + + # First fit + torch.manual_seed(42) + np.random.seed(42) + gmm1 = TorchGMM(n_components=2, max_iter=20, device=device) + gmm1.fit(X) + scores1 = gmm1.score_samples(X) + + # Second fit with same seed + torch.manual_seed(42) + np.random.seed(42) + gmm2 = TorchGMM(n_components=2, max_iter=20, device=device) + gmm2.fit(X) + scores2 = gmm2.score_samples(X) + + # Results should be very similar (allowing for small numerical differences) + assert torch.allclose(scores1, scores2, rtol=1e-3, atol=1e-3) diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..db73697 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,220 @@ +""" +Tests for custom PyTorch model implementations (TorchGMM, TorchKDE, TorchOCSVM). +""" + +import numpy as np +import pytest +import torch + +from forte.models import TorchGMM, TorchKDE, TorchOCSVM + + +class TestTorchGMM: + """Test suite for TorchGMM implementation.""" + + def test_initialization(self, device): + """Test GMM initialization.""" + gmm = TorchGMM(n_components=2, device=device) + assert gmm.n_components == 2 + assert gmm.device == device + assert gmm.weights_ is None + assert gmm.means_ is None + assert gmm.covariances_ is None + + def test_fit(self, device, sample_dataset): + """Test GMM fitting.""" + X = sample_dataset["id"].to(device) + gmm = TorchGMM(n_components=2, max_iter=10, device=device) + gmm.fit(X) + + assert gmm.weights_ is not None + assert gmm.means_ is not None + assert gmm.covariances_ is not None + assert gmm.weights_.shape == (2,) + assert gmm.means_.shape == (2, X.shape[1]) + assert torch.allclose(gmm.weights_.sum(), torch.tensor(1.0), atol=1e-5) + + def test_score_samples(self, device, sample_dataset): + """Test GMM score_samples method.""" + X = sample_dataset["id"].to(device) + gmm = TorchGMM(n_components=2, max_iter=10, device=device) + gmm.fit(X) + + scores = gmm.score_samples(X) + assert scores.shape == (X.shape[0],) + assert not torch.isnan(scores).any() + assert not torch.isinf(scores).any() + + def test_bic(self, device, sample_dataset): + """Test GMM BIC computation.""" + X = sample_dataset["id"].to(device) + gmm = TorchGMM(n_components=2, max_iter=10, device=device) + gmm.fit(X) + + bic = gmm.bic(X) + assert isinstance(bic, float) + assert not np.isnan(bic) + assert not np.isinf(bic) + + def test_convergence(self, device): + """Test GMM convergence on simple data.""" + # Create clear clusters + cluster1 = torch.randn(50, 5, device=device) + 0 + cluster2 = torch.randn(50, 5, device=device) + 5 + X = torch.cat([cluster1, cluster2], dim=0) + + gmm = TorchGMM(n_components=2, max_iter=100, tol=1e-3, device=device) + gmm.fit(X) + + # Should converge + assert gmm.converged_ or gmm.lower_bound_ > -np.inf + + +class TestTorchKDE: + """Test suite for TorchKDE implementation.""" + + def test_initialization(self, device): + """Test KDE initialization.""" + dataset = torch.randn(5, 20, device=device) # (d, n) + kde = TorchKDE(dataset, device=device) + + assert kde.d == 5 + assert kde.n == 20 + assert kde.device == device + assert kde.weights is not None + + def test_scotts_silverman_factor(self, device): + """Test bandwidth factor calculations.""" + dataset = torch.randn(5, 20, device=device) + kde_scott = TorchKDE(dataset, bw_method="scott", device=device) + kde_silverman = TorchKDE(dataset, bw_method="silverman", device=device) + + assert kde_scott.factor > 0 + assert kde_silverman.factor > 0 + + def test_evaluate(self, device): + """Test KDE evaluation.""" + dataset = torch.randn(5, 20, device=device) + kde = TorchKDE(dataset, bw_method="scott", device=device) + + # Evaluate at test points + test_points = torch.randn(5, 10, device=device) + densities = kde.evaluate(test_points) + + assert densities.shape == (10,) + assert (densities >= 0).all() # Densities should be non-negative + assert not torch.isnan(densities).any() + + def test_logpdf(self, device): + """Test KDE log probability density.""" + dataset = torch.randn(5, 20, device=device) + kde = TorchKDE(dataset, device=device) + + test_points = torch.randn(5, 10, device=device) + log_densities = kde.logpdf(test_points) + + assert log_densities.shape == (10,) + assert not torch.isnan(log_densities).any() + assert not torch.isinf(log_densities).any() + + def test_custom_bandwidth(self, device): + """Test KDE with custom bandwidth.""" + dataset = torch.randn(5, 20, device=device) + custom_bw = 0.5 + kde = TorchKDE(dataset, bw_method=custom_bw, device=device) + + assert kde.factor == custom_bw + + +class TestTorchOCSVM: + """Test suite for TorchOCSVM implementation.""" + + def test_initialization(self, device): + """Test OCSVM initialization.""" + ocsvm = TorchOCSVM(nu=0.1, n_iters=100, lr=1e-3, device=device) + + assert ocsvm.nu == 0.1 + assert ocsvm.n_iters == 100 + assert ocsvm.lr == 1e-3 + assert ocsvm.device == device + assert ocsvm.w is None + assert ocsvm.rho is None + + def test_fit(self, device, sample_dataset): + """Test OCSVM fitting.""" + X = sample_dataset["id"].to(device) + ocsvm = TorchOCSVM(nu=0.1, n_iters=50, lr=1e-3, device=device) + ocsvm.fit(X) + + assert ocsvm.w is not None + assert ocsvm.rho is not None + assert ocsvm.w.shape == (X.shape[1],) + assert ocsvm.rho.shape == () + + def test_decision_function(self, device, sample_dataset): + """Test OCSVM decision function.""" + X = sample_dataset["id"].to(device) + ocsvm = TorchOCSVM(nu=0.1, n_iters=50, lr=1e-3, device=device) + ocsvm.fit(X) + + decisions = ocsvm.decision_function(X) + assert decisions.shape == (X.shape[0],) + assert not torch.isnan(decisions).any() + + def test_predict(self, device, sample_dataset): + """Test OCSVM prediction.""" + X = sample_dataset["id"].to(device) + ocsvm = TorchOCSVM(nu=0.1, n_iters=50, lr=1e-3, device=device) + ocsvm.fit(X) + + predictions = ocsvm.predict(X) + assert predictions.shape == (X.shape[0],) + assert torch.all((predictions == 1) | (predictions == -1)) + + def test_ood_detection(self, device, sample_dataset): + """Test OCSVM can distinguish ID from OOD.""" + X_id = sample_dataset["id"].to(device) + X_ood = sample_dataset["ood"].to(device) + + ocsvm = TorchOCSVM(nu=0.1, n_iters=100, lr=1e-3, device=device) + ocsvm.fit(X_id) + + # Get decisions for both + decision_id = ocsvm.decision_function(X_id).mean() + decision_ood = ocsvm.decision_function(X_ood).mean() + + # ID samples should generally have higher decision values + # (though not guaranteed for all random seeds) + assert decision_id.item() != decision_ood.item() + + +@pytest.mark.integration +class TestModelsIntegration: + """Integration tests for all models working together.""" + + def test_all_models_on_same_data(self, device, mock_prdc_features): + """Test that all models can work with the same data.""" + X = mock_prdc_features + + # GMM + gmm = TorchGMM(n_components=2, max_iter=20, device=device) + gmm.fit(X) + gmm_scores = gmm.score_samples(X) + + # KDE + kde = TorchKDE(X.T, device=device) # KDE expects (d, n) + kde_scores = kde.logpdf(X) + + # OCSVM + ocsvm = TorchOCSVM(nu=0.1, n_iters=50, device=device) + ocsvm.fit(X) + ocsvm_scores = ocsvm.decision_function(X) + + # All should produce valid scores + assert gmm_scores.shape == (X.shape[0],) + assert kde_scores.shape == (X.shape[0],) + assert ocsvm_scores.shape == (X.shape[0],) + + assert not torch.isnan(gmm_scores).any() + assert not torch.isnan(kde_scores).any() + assert not torch.isnan(ocsvm_scores).any()