diff --git a/CLAUDE.md b/CLAUDE.md index f3aeb08..665e39c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -72,6 +72,53 @@ def compute(datasetId, apiUrl, token, params): Interface types: `number`, `text`, `select`, `checkbox`, `channel`, `channelCheckboxes`, `tags`, `layer`, `notes` +### Interface Parameter Data Types (What `params['workerInterface']` Returns) + +Each interface type returns a specific data type in `params['workerInterface']['FieldName']`: + +| Interface Type | Returns | Example Value | +|----------------|---------|---------------| +| `number` | `int` or `float` | `32`, `0.5` | +| `text` | `str` | `"1-3, 5-8"`, `""` | +| `select` | `str` | `"sam2.1_hiera_small.pt"` | +| `checkbox` | `bool` | `True`, `False` | +| `channel` | `int` | `0` | +| `channelCheckboxes` | `dict` of `str` → `bool` | `{"0": True, "1": False, "2": True}` | +| `tags` | **`list` of `str`** | `["DAPI blob"]`, `["cell", "nucleus"]` | +| `layer` | `str` | `"layer_id"` | + +**Common pitfall with `tags`**: The `tags` type returns a **plain list of strings**, NOT a dict. Do not call `.get('tags')` on the result. + +```python +# CORRECT - tags returns a list directly: +training_tags = params['workerInterface'].get('Training Tag', []) +# training_tags = ["DAPI blob"] + +# WRONG - will crash with AttributeError: 'list' object has no attribute 'get': +training_tags = params['workerInterface'].get('Training Tag', {}).get('tags', []) +``` + +**Note**: `params['tags']` (the top-level output tags for the worker, NOT a workerInterface field) is also a plain list of strings (e.g., `["DAPI blob"]`). Meanwhile, `params['tags']` used in property workers via `workerClient.get_annotation_list_by_shape()` uses `{'tags': [...], 'exclusive': bool}` — these are two different things. + +**Validating tags** (recommended pattern from cellpose_train, piscis): +```python +tags = workerInterface.get('My Tag Field', []) +if not tags or len(tags) == 0: + sendError("No tag selected", "Please select at least one tag.") + return +``` + +**Using tags to filter annotations**: +```python +# Pass the list directly to annotation_tools +filtered = annotation_tools.get_annotations_with_tags( + annotation_list, tags, exclusive=False) + +# Or with Girder API (must JSON-serialize) +annotations = annotationClient.getAnnotationsByDatasetId( + datasetId, shape='polygon', tags=json.dumps(tags)) +``` + ### Key APIs **annotation_client** (installed from NimbusImage repo): diff --git a/build_machine_learning_workers.sh b/build_machine_learning_workers.sh index 732b51a..b4c0069 100755 --- a/build_machine_learning_workers.sh +++ b/build_machine_learning_workers.sh @@ -39,6 +39,11 @@ docker build . -f ./workers/annotations/sam2_automatic_mask_generator/Dockerfile # Command for M1: # docker build . -f ./workers/annotations/sam2_automatic_mask_generator/Dockerfile_M1 -t annotations/sam2_automatic_mask_generator:latest $NO_CACHE +echo "Building SAM2 few-shot segmentation worker" +docker build . -f ./workers/annotations/sam2_fewshot_segmentation/Dockerfile -t annotations/sam2_fewshot_segmentation:latest $NO_CACHE +# Command for M1: +# docker build . -f ./workers/annotations/sam2_fewshot_segmentation/Dockerfile_M1 -t annotations/sam2_fewshot_segmentation:latest $NO_CACHE + echo "Building SAM2 propagate worker" docker build . 
-f ./workers/annotations/sam2_propagate/$DOCKERFILE -t annotations/sam2_propagate_worker:latest $NO_CACHE diff --git a/workers/annotations/sam2_fewshot_segmentation/Dockerfile b/workers/annotations/sam2_fewshot_segmentation/Dockerfile new file mode 100644 index 0000000..f3d057a --- /dev/null +++ b/workers/annotations/sam2_fewshot_segmentation/Dockerfile @@ -0,0 +1,70 @@ +FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu22.04 as base +LABEL isUPennContrastWorker=True +LABEL com.nvidia.volumes.needed="nvidia_driver" + +ENV NVIDIA_VISIBLE_DEVICES all +ENV NVIDIA_DRIVER_CAPABILITIES compute,utility + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -qy tzdata && \ + apt-get install -qy software-properties-common python3-software-properties && \ + apt-get update && apt-get install -qy \ + build-essential \ + wget \ + python3 \ + r-base \ + libffi-dev \ + libssl-dev \ + libjpeg-dev \ + zlib1g-dev \ + r-base \ + git \ + libpython3-dev && \ + apt-get clean && rm -rf /var/lib/apt/lists/* + +ENV PATH="/root/miniforge3/bin:$PATH" +ARG PATH="/root/miniforge3/bin:$PATH" + +RUN wget \ + https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh \ + && mkdir /root/.conda \ + && bash Miniforge3-Linux-x86_64.sh -b \ + && rm -f Miniforge3-Linux-x86_64.sh + +FROM base as build + +COPY ./workers/annotations/sam2_fewshot_segmentation/environment.yml / +RUN conda env create --file /environment.yml +SHELL ["conda", "run", "-n", "worker", "/bin/bash", "-c"] + +RUN pip install rtree shapely + +RUN git clone https://github.com/arjunrajlaboratory/NimbusImage/ + +RUN pip install -r /NimbusImage/devops/girder/annotation_client/requirements.txt +RUN pip install -e /NimbusImage/devops/girder/annotation_client/ + +RUN mkdir -p /code +RUN git clone https://github.com/facebookresearch/sam2.git /code/sam2 +RUN pip install -e /code/sam2 + +# Change directory to sam2/checkpoints +WORKDIR /code/sam2/checkpoints +# Download the checkpoints into the checkpoints directory +RUN ./download_ckpts.sh +# Change back to the root directory +WORKDIR / + +COPY ./workers/annotations/sam2_fewshot_segmentation/entrypoint.py / + +COPY ./annotation_utilities /annotation_utilities +RUN pip install /annotation_utilities + +LABEL isUPennContrastWorker="" \ + isAnnotationWorker="" \ + interfaceName="SAM2 few-shot segmentation" \ + interfaceCategory="SAM2" \ + description="Uses SAM2 features for few-shot segmentation based on training annotations" \ + annotationShape="polygon" \ + defaultToolName="SAM2 few-shot" + +ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "worker", "python", "/entrypoint.py"] diff --git a/workers/annotations/sam2_fewshot_segmentation/Dockerfile_M1 b/workers/annotations/sam2_fewshot_segmentation/Dockerfile_M1 new file mode 100644 index 0000000..5e70586 --- /dev/null +++ b/workers/annotations/sam2_fewshot_segmentation/Dockerfile_M1 @@ -0,0 +1,73 @@ +FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 as base +LABEL isUPennContrastWorker=True +LABEL com.nvidia.volumes.needed="nvidia_driver" + +ENV NVIDIA_VISIBLE_DEVICES all +ENV NVIDIA_DRIVER_CAPABILITIES compute,utility + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -qy tzdata && \ + apt-get install -qy software-properties-common python3-software-properties && \ + apt-get update && apt-get install -qy \ + build-essential \ + wget \ + python3 \ + r-base \ + libffi-dev \ + libssl-dev \ + libjpeg-dev \ + zlib1g-dev \ + r-base \ + git \ + libpython3-dev && \ + apt-get clean && rm -rf 
/var/lib/apt/lists/* + +# The below is for the M1 Macs and should be changed for other architectures +ENV PATH="/root/miniconda3/bin:$PATH" +ARG PATH="/root/miniconda3/bin:$PATH" + +RUN wget \ + https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh \ + && mkdir /root/.conda \ + && bash Miniconda3-latest-Linux-aarch64.sh -b \ + && rm -f Miniconda3-latest-Linux-aarch64.sh +# END M1 Mac specific + + +FROM base as build + +COPY ./workers/annotations/sam2_fewshot_segmentation/environment.yml / +RUN conda env create --file /environment.yml +SHELL ["conda", "run", "-n", "worker", "/bin/bash", "-c"] + +RUN pip install rtree shapely + +RUN git clone https://github.com/arjunrajlaboratory/NimbusImage/ + +RUN pip install -r /NimbusImage/devops/girder/annotation_client/requirements.txt +RUN pip install -e /NimbusImage/devops/girder/annotation_client/ + +RUN mkdir -p /code +RUN git clone https://github.com/facebookresearch/sam2.git /code/sam2 +RUN pip install -e /code/sam2 + +# Change directory to sam2/checkpoints +WORKDIR /code/sam2/checkpoints +# Download the checkpoints into the checkpoints directory +RUN ./download_ckpts.sh +# Change back to the root directory +WORKDIR / + +COPY ./workers/annotations/sam2_fewshot_segmentation/entrypoint.py / + +COPY ./annotation_utilities /annotation_utilities +RUN pip install /annotation_utilities + +LABEL isUPennContrastWorker="" \ + isAnnotationWorker="" \ + interfaceName="SAM2 few-shot segmentation" \ + interfaceCategory="SAM2" \ + description="Uses SAM2 features for few-shot segmentation based on training annotations" \ + annotationShape="polygon" \ + defaultToolName="SAM2 few-shot" + +ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "worker", "python", "/entrypoint.py"] diff --git a/workers/annotations/sam2_fewshot_segmentation/SAM2_FEWSHOT.md b/workers/annotations/sam2_fewshot_segmentation/SAM2_FEWSHOT.md new file mode 100644 index 0000000..5030298 --- /dev/null +++ b/workers/annotations/sam2_fewshot_segmentation/SAM2_FEWSHOT.md @@ -0,0 +1,143 @@ +# SAM2 Few-Shot Segmentation Worker + +## Overview + +This worker segments objects in microscopy images using few-shot learning with SAM2. Users annotate a small number of training examples (5-20 objects) with a specific tag, and the worker finds similar objects across the dataset using SAM2's frozen image encoder features. No model training is required. + +## How It Works + +### Phase 1: Training Feature Extraction + +For each polygon annotation matching the user-specified Training Tag: + +1. Load the merged multi-channel image at the annotation's location +2. Convert the annotation polygon to a binary mask +3. Crop the image around the object with context padding (object occupies ~20% of crop area by default) +4. Encode the crop through SAM2's image encoder via `SAM2ImagePredictor.set_image()` +5. Extract the `image_embed` feature map (shape: `1, 256, 64, 64`) +6. Pool the feature map using mask-weighted averaging to produce a 256-dimensional feature vector +7. Average all training feature vectors into a single L2-normalized prototype + +### Phase 2: Inference + +For each image frame in the batch: + +1. Run `SAM2AutomaticMaskGenerator` to generate all candidate masks +2. For each candidate mask: + - Apply the same crop-encode-pool pipeline as training + - Compute cosine similarity between the candidate's feature vector and the training prototype + - Keep the mask if similarity >= threshold +3. Convert passing masks to polygon annotations via `find_contours` + `polygons_to_annotations` +4. 
Upload all annotations to the server + +## Interface Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| Training Tag | tags | (required) | Tag identifying training annotation examples | +| Batch XY | text | current | XY positions to process (e.g., "1-3, 5-8") | +| Batch Z | text | current | Z slices to process | +| Batch Time | text | current | Time points to process | +| Model | select | sam2.1_hiera_small.pt | SAM2 checkpoint to use | +| Similarity Threshold | number | 0.5 | Minimum cosine similarity to keep a mask (0.0-1.0) | +| Target Occupancy | number | 0.20 | Fraction of crop area the object should occupy (0.05-0.80) | +| Points per side | number | 32 | Grid density for SAM2 mask generation (16-128) | +| Min Mask Area | number | 100 | Minimum mask area in pixels to consider | +| Max Mask Area | number | 0 | Maximum mask area in pixels (0 = no limit) | +| Smoothing | number | 0.3 | Polygon simplification tolerance | + +## Key Design Decisions + +### Context Padding (Target Occupancy) + +SAM2 was trained on images where objects occupy a reasonable fraction of the frame. Tight crops around objects would be out-of-distribution. The `Target Occupancy` parameter controls how much of the crop the object fills: + +- `crop_side = sqrt(object_area / target_occupancy)` +- Default 0.20 means the object occupies ~20% of the crop area +- The same occupancy is used for both training and inference to ensure consistent feature extraction + +### Mask-Weighted Feature Pooling + +Since we have binary masks for both training annotations and candidate masks, we use them to focus the feature pooling on the actual object pixels rather than background: + +``` +feature_vector = (features * mask).sum(dim=[2,3]) / mask.sum() +``` + +The mask is bilinearly resized from the crop resolution to the feature map resolution (64x64). + +### SAM2ImagePredictor for Encoding + +We use `SAM2ImagePredictor.set_image()` rather than calling `forward_image` directly. This ensures proper handling of: +- Image transforms (resize to 1024x1024, normalization) +- `no_mem_embed` addition (SAM2's learned "no memory" token) +- Consistent feature extraction matching SAM2's internal pipeline + +The `image_embed` from `predictor._features["image_embed"]` gives a `(1, 256, 64, 64)` feature map -- the lowest-resolution, highest-semantic features from SAM2's FPN neck. + +## Tuning Guide + +### Similarity Threshold + +- **Too many false positives**: Increase threshold (try 0.6-0.8) +- **Too few detections (missing objects)**: Decrease threshold (try 0.3-0.4) +- **Start at 0.5** and adjust based on results + +### Target Occupancy + +- **Objects are very small in the image**: Try 0.10-0.15 (more context) +- **Objects are large in the image**: Try 0.30-0.40 (less context) +- **Default 0.20** works well for most microscopy objects + +### Points per side + +- **More masks needed (small objects)**: Increase to 48-64 +- **Faster processing**: Decrease to 16-24 +- **Default 32** balances coverage and speed + +### Min/Max Mask Area + +- Use training annotation areas as a guide +- Set Min to ~50% of smallest training annotation area +- Set Max to ~200% of largest training annotation area +- Set Max to 0 to disable upper limit + +## Performance Characteristics + +- **GPU required**: SAM2 encoder needs CUDA +- **Memory**: ~4GB VRAM for SAM2 small model +- **Speed**: Most time is spent encoding candidate masks individually (one forward pass per candidate). 
With 32 points per side, expect ~50-200 candidate masks per image. +- **Data efficiency**: Works with 5-20 training examples + +## Possible Future Improvements + +- **Multiple prototypes**: Keep all training vectors instead of averaging, use max similarity (helps when training examples show multiple morphologies) +- **Full-image encoding**: Encode each image once and pool from the full feature map instead of cropping each candidate (faster but lower feature quality for small objects) +- **Negative examples**: Allow users to tag "not this" examples to reduce false positives +- **Size/shape priors**: Learn area distribution from training and filter candidates by size +- **Adaptive thresholding**: Use relative ranking (e.g., top 25%) instead of fixed threshold + +## TODO / Future Work + +- [ ] **Tiled image support**: Large microscopy images should be processed in tiles (like cellposesam's deeptile approach) rather than loading the entire image at once. This would reduce memory usage and allow processing of arbitrarily large images. +- [ ] **Multiple prototypes**: Keep all training feature vectors instead of averaging into a single prototype. Use max similarity or k-NN voting at inference. This would help when training examples show significant morphological variation. +- [ ] **Full-image encoding optimization**: Encode each inference image once and pool from the full feature map for each candidate mask, instead of cropping and re-encoding per candidate. Much faster but may reduce feature quality for small objects. +- [ ] **Negative examples**: Add a "Negative Tag" interface field so users can tag objects they do NOT want to match. Subtract negative similarity from positive similarity to reduce false positives. +- [ ] **Size/shape priors**: Learn area and aspect ratio distributions from training annotations and use them as an additional filter (e.g., reject candidates whose area is >2 std from training mean). +- [ ] **Adaptive thresholding**: Instead of a fixed similarity threshold, use relative ranking (e.g., keep top N% of candidates) or Otsu-style automatic thresholding on the similarity distribution. +- [ ] **Multi-scale feature extraction**: Extract features at multiple occupancy levels (e.g., 0.15, 0.25, 0.40) and concatenate for a richer feature vector. Helps when objects vary significantly in size. +- [ ] **Batch encoding**: Group multiple candidate crops into a batch tensor and encode them in a single forward pass through SAM2 for better GPU utilization. +- [ ] **Cache training prototype**: If the same training tag is used repeatedly, cache the prototype to avoid re-computing features on every run. +- [ ] **Similarity score as property**: Expose the similarity score as an annotation property so users can sort/filter results by confidence. +- [ ] **Support for point annotations as training**: Allow users to provide point prompts (not just polygon masks) as training examples, using SAM2's prompt-based segmentation to generate masks from points first. 
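
To make some of the items above concrete (for example, exposing the similarity score as a property, or adaptive thresholding on the score distribution), note that the per-candidate decision reduces to a single cosine-similarity comparison. The sketch below is illustrative rather than a drop-in piece of the worker: it assumes an already-initialized `SAM2ImagePredictor` and the precomputed, L2-normalized training prototype, and it reuses the helpers defined in `entrypoint.py`.

```python
# Illustrative sketch of the per-candidate scoring step, assuming a loaded
# SAM2ImagePredictor (`predictor`) and an L2-normalized prototype of shape (256,).
import torch.nn.functional as F

from entrypoint import (
    extract_crop_with_context,
    encode_image_with_sam2,
    pool_features_with_mask,
    ensure_rgb,
)


def score_candidate(predictor, prototype, image_rgb, candidate_mask,
                    target_occupancy=0.20):
    """Return the cosine similarity between one candidate mask and the prototype."""
    # Crop around the candidate so it occupies ~target_occupancy of the crop area.
    crop_image, crop_mask = extract_crop_with_context(
        image_rgb, candidate_mask, target_occupancy)

    # Encode the crop with SAM2's frozen image encoder -> (1, 256, 64, 64).
    features = encode_image_with_sam2(predictor, ensure_rgb(crop_image))

    # Mask-weighted pooling to a single 256-d feature vector.
    vec = pool_features_with_mask(
        features, crop_mask, features.shape[2], features.shape[3])

    # Normalize and compare against the training prototype.
    vec = F.normalize(vec.unsqueeze(0), dim=1)
    return F.cosine_similarity(vec, prototype.unsqueeze(0)).item()
```

A candidate is kept when `score_candidate(...)` is at least the configured Similarity Threshold; the same crop-encode-pool chain runs once per training annotation to build the prototype in the first place.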
+ +## Files + +| File | Purpose | +|------|---------| +| `entrypoint.py` | Worker logic: interface definition, feature extraction, inference pipeline | +| `Dockerfile` | x86_64 production build (CUDA 12.1, SAM2 checkpoints) | +| `Dockerfile_M1` | arm64/M1 Mac build (CUDA 11.8) | +| `environment.yml` | Conda environment specification | +| `tests/test_sam2_fewshot.py` | Unit tests for helper functions | +| `tests/Dockerfile_Test` | Test Docker image | diff --git a/workers/annotations/sam2_fewshot_segmentation/entrypoint.py b/workers/annotations/sam2_fewshot_segmentation/entrypoint.py new file mode 100644 index 0000000..0908c73 --- /dev/null +++ b/workers/annotations/sam2_fewshot_segmentation/entrypoint.py @@ -0,0 +1,488 @@ +import argparse +import json +import sys +import os + +from itertools import product + +import annotation_client.annotations as annotations_client +import annotation_client.workers as workers +import annotation_client.tiles as tiles + +import annotation_utilities.annotation_tools as annotation_tools +import annotation_utilities.batch_argument_parser as batch_argument_parser + +import numpy as np +from shapely.geometry import Polygon +from skimage.measure import find_contours + +import torch +import torch.nn.functional as F +from sam2.build_sam import build_sam2 +from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator +from sam2.sam2_image_predictor import SAM2ImagePredictor + +from annotation_client.utils import sendProgress, sendError + + +def interface(image, apiUrl, token): + client = workers.UPennContrastWorkerPreviewClient(apiUrl=apiUrl, token=token) + + models = [f for f in os.listdir('/code/sam2/checkpoints') if f.endswith('.pt')] + default_model = 'sam2.1_hiera_small.pt' if 'sam2.1_hiera_small.pt' in models else models[0] if models else None + + interface = { + 'Training Tag': { + 'type': 'tags', + 'displayOrder': 0, + }, + 'Batch XY': { + 'type': 'text', + 'displayOrder': 1, + }, + 'Batch Z': { + 'type': 'text', + 'displayOrder': 2, + }, + 'Batch Time': { + 'type': 'text', + 'displayOrder': 3, + }, + 'Model': { + 'type': 'select', + 'items': models, + 'default': default_model, + 'displayOrder': 4, + }, + 'Similarity Threshold': { + 'type': 'number', + 'min': 0.0, + 'max': 1.0, + 'default': 0.5, + 'displayOrder': 5, + }, + 'Target Occupancy': { + 'type': 'number', + 'min': 0.05, + 'max': 0.80, + 'default': 0.20, + 'displayOrder': 6, + }, + 'Points per side': { + 'type': 'number', + 'min': 16, + 'max': 128, + 'default': 32, + 'displayOrder': 7, + }, + 'Min Mask Area': { + 'type': 'number', + 'min': 0, + 'max': 100000, + 'default': 100, + 'displayOrder': 8, + }, + 'Max Mask Area': { + 'type': 'number', + 'min': 0, + 'max': 10000000, + 'default': 0, + 'displayOrder': 9, + }, + 'Smoothing': { + 'type': 'number', + 'min': 0, + 'max': 3, + 'default': 0.3, + 'displayOrder': 10, + }, + } + client.setWorkerImageInterface(image, interface) + + +def extract_crop_with_context(image, mask, target_occupancy=0.20): + """Extract a crop of the image where the masked object occupies roughly + target_occupancy fraction of the crop area. 
+ + Args: + image: numpy array (H, W, C) or (H, W) + mask: binary numpy array (H, W) + target_occupancy: desired fraction of crop area occupied by object + + Returns: + crop_image: numpy array resized/cropped region + crop_mask: binary numpy array of same spatial size as crop_image + """ + ys, xs = np.where(mask > 0) + if len(ys) == 0: + return image, mask + + y_min, y_max = ys.min(), ys.max() + x_min, x_max = xs.min(), xs.max() + obj_h = y_max - y_min + 1 + obj_w = x_max - x_min + 1 + + obj_area = mask.sum() + if obj_area == 0: + return image, mask + + # Determine crop size so that object occupies target_occupancy of area + crop_area = obj_area / target_occupancy + crop_side = int(np.sqrt(crop_area)) + # Ensure crop is at least as large as the object bounding box + crop_side = max(crop_side, obj_h, obj_w) + + # Center the crop on the object center + cy = (y_min + y_max) / 2.0 + cx = (x_min + x_max) / 2.0 + + h, w = image.shape[:2] + + half = crop_side / 2.0 + top = int(max(0, cy - half)) + left = int(max(0, cx - half)) + bottom = int(min(h, top + crop_side)) + right = int(min(w, left + crop_side)) + + # Adjust if we hit boundaries + if bottom - top < crop_side and top > 0: + top = max(0, bottom - crop_side) + if right - left < crop_side and left > 0: + left = max(0, right - crop_side) + + crop_image = image[top:bottom, left:right] + crop_mask = mask[top:bottom, left:right] + + return crop_image, crop_mask + + +def encode_image_with_sam2(predictor, image_np): + """Encode an image crop using SAM2's image encoder via SAM2ImagePredictor. + + Uses set_image() which handles transforms, backbone encoding, and + no_mem_embed addition consistently with SAM2's internal pipeline. + + Args: + predictor: SAM2ImagePredictor instance + image_np: numpy array (H, W, 3) uint8 + + Returns: + features: tensor of shape [1, 256, 64, 64] (image_embed) + """ + predictor.set_image(image_np) + # image_embed is the lowest-resolution, highest-semantic feature map + # Shape: (1, 256, 64, 64) for 1024x1024 input + return predictor._features["image_embed"] + + +def pool_features_with_mask(features, mask_np, feat_h, feat_w): + """Pool feature map using a binary mask via weighted averaging. 
+ + Args: + features: tensor (1, C, feat_h, feat_w) + mask_np: binary numpy array (crop_h, crop_w) + feat_h: feature map height + feat_w: feature map width + + Returns: + feature_vector: tensor of shape (C,) + """ + # Resize mask to feature map dimensions + mask_tensor = torch.from_numpy(mask_np.astype(np.float32)).unsqueeze(0).unsqueeze(0) + mask_resized = F.interpolate(mask_tensor, size=(feat_h, feat_w), mode='bilinear', align_corners=False) + mask_resized = mask_resized.to(features.device) + + # Weighted pooling + mask_sum = mask_resized.sum() + if mask_sum > 0: + weighted = (features * mask_resized).sum(dim=[2, 3]) / mask_sum + else: + weighted = features.mean(dim=[2, 3]) + + return weighted.squeeze(0) # (C,) + + +def ensure_rgb(image): + """Ensure image is (H, W, 3) uint8 RGB.""" + if image.ndim == 2: + image = np.stack([image, image, image], axis=-1) + elif image.ndim == 3 and image.shape[2] == 1: + image = np.repeat(image, 3, axis=2) + elif image.ndim == 3 and image.shape[2] == 4: + image = image[:, :, :3] + + if image.dtype == np.float32 or image.dtype == np.float64: + if image.max() <= 1.0 and image.min() >= 0.0: + image = (image * 255).astype(np.uint8) + else: + image = np.clip(image, 0, 255).astype(np.uint8) + elif image.dtype == np.uint16: + image = (image / 256).astype(np.uint8) + elif image.dtype != np.uint8: + image = image.astype(np.uint8) + + return image + + +def annotation_to_mask(annotation, image_shape): + """Convert a polygon annotation to a binary mask. + + Args: + annotation: annotation dict with 'coordinates' list of {'x': ..., 'y': ...} + image_shape: (H, W) of the target image + + Returns: + mask: binary numpy array (H, W) + """ + from skimage.draw import polygon as draw_polygon + + coords = annotation['coordinates'] + # Annotation coordinates: 'x' and 'y' in image pixel space + rows = np.array([c['y'] for c in coords]) + cols = np.array([c['x'] for c in coords]) + + mask = np.zeros(image_shape[:2], dtype=np.uint8) + rr, cc = draw_polygon(rows, cols, shape=image_shape[:2]) + mask[rr, cc] = 1 + return mask + + +def compute(datasetId, apiUrl, token, params): + annotationClient = annotations_client.UPennContrastAnnotationClient(apiUrl=apiUrl, token=token) + workerClient = workers.UPennContrastWorkerClient(datasetId, apiUrl, token, params) + tileClient = tiles.UPennContrastDataset(apiUrl=apiUrl, token=token, datasetId=datasetId) + + # Parse parameters + model_name = params['workerInterface']['Model'] + similarity_threshold = float(params['workerInterface']['Similarity Threshold']) + target_occupancy = float(params['workerInterface']['Target Occupancy']) + points_per_side = int(params['workerInterface']['Points per side']) + min_mask_area = int(params['workerInterface']['Min Mask Area']) + max_mask_area = int(params['workerInterface']['Max Mask Area']) + smoothing = float(params['workerInterface']['Smoothing']) + + batch_xy = params['workerInterface'].get('Batch XY', '') + batch_z = params['workerInterface'].get('Batch Z', '') + batch_time = params['workerInterface'].get('Batch Time', '') + + batch_xy = batch_argument_parser.process_range_list(batch_xy, convert_one_to_zero_index=True) + batch_z = batch_argument_parser.process_range_list(batch_z, convert_one_to_zero_index=True) + batch_time = batch_argument_parser.process_range_list(batch_time, convert_one_to_zero_index=True) + + tile = params['tile'] + channel = params['channel'] + tags = params['tags'] + + if batch_xy is None: + batch_xy = [tile['XY']] + if batch_z is None: + batch_z = [tile['Z']] + if batch_time is 
None: + batch_time = [tile['Time']] + + # Parse training tag - 'type': 'tags' returns a list of strings directly + training_tags = params['workerInterface'].get('Training Tag', []) + if not training_tags or len(training_tags) == 0: + sendError("No training tag selected", + "Please select a tag that identifies your training annotations.") + return + + # ── SAM2 model setup ── + sendProgress(0.0, "Loading model", "Initializing SAM2...") + torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() + if torch.cuda.get_device_properties(0).major >= 8: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + checkpoint_path = f"/code/sam2/checkpoints/{model_name}" + model_to_cfg = { + 'sam2.1_hiera_base_plus.pt': 'sam2.1_hiera_b+.yaml', + 'sam2.1_hiera_large.pt': 'sam2.1_hiera_l.yaml', + 'sam2.1_hiera_small.pt': 'sam2.1_hiera_s.yaml', + 'sam2.1_hiera_tiny.pt': 'sam2.1_hiera_t.yaml', + } + model_cfg = f"configs/sam2.1/{model_to_cfg[model_name]}" + sam2_model = build_sam2(model_cfg, checkpoint_path, device='cuda', apply_postprocessing=False) + predictor = SAM2ImagePredictor(sam2_model) + + # ── Phase 1: Extract training prototype ── + sendProgress(0.05, "Extracting training features", "Fetching training annotations...") + + # Fetch all polygon annotations from the dataset + all_annotations = annotationClient.getAnnotationsByDatasetId(datasetId, shape='polygon') + training_annotations = annotation_tools.get_annotations_with_tags( + all_annotations, training_tags, exclusive=False + ) + + if len(training_annotations) == 0: + sendError("No training annotations found", f"No polygon annotations found with tags: {training_tags}") + return + + print(f"Found {len(training_annotations)} training annotations") + + feature_vectors = [] + for idx, annotation in enumerate(training_annotations): + loc = annotation['location'] + ann_xy = loc.get('XY', 0) + ann_z = loc.get('Z', 0) + ann_time = loc.get('Time', 0) + + # Get the merged image at the annotation's location + images = annotation_tools.get_images_for_all_channels(tileClient, datasetId, ann_xy, ann_z, ann_time) + layers = annotation_tools.get_layers(tileClient.client, datasetId) + merged_image = annotation_tools.process_and_merge_channels(images, layers) + merged_image = ensure_rgb(merged_image) + + # Convert annotation to mask + mask = annotation_to_mask(annotation, merged_image.shape) + + if mask.sum() == 0: + print(f"Warning: training annotation {idx} produced empty mask, skipping") + continue + + # Extract crop with context padding + crop_image, crop_mask = extract_crop_with_context(merged_image, mask, target_occupancy) + crop_image = ensure_rgb(crop_image) + + # Encode with SAM2 + features = encode_image_with_sam2(predictor, crop_image) + feat_h, feat_w = features.shape[2], features.shape[3] + + # Pool features with mask + feature_vec = pool_features_with_mask(features, crop_mask, feat_h, feat_w) + feature_vectors.append(feature_vec) + + sendProgress(0.05 + 0.15 * (idx + 1) / len(training_annotations), + "Extracting training features", + f"Processed {idx + 1}/{len(training_annotations)} training examples") + + if len(feature_vectors) == 0: + sendError("No valid training features", "All training annotations produced empty masks") + return + + # Create prototype by averaging feature vectors + training_prototype = torch.stack(feature_vectors).mean(dim=0) + training_prototype = F.normalize(training_prototype.unsqueeze(0), dim=1).squeeze(0) + + print(f"Training prototype shape: {training_prototype.shape}") + + # 
Optionally learn size statistics from training annotations + training_areas = [] + for annotation in training_annotations: + coords = annotation['coordinates'] + rows = [c['y'] for c in coords] + cols = [c['x'] for c in coords] + poly = Polygon(zip(cols, rows)) + if poly.is_valid: + training_areas.append(poly.area) + + mean_area = np.mean(training_areas) if training_areas else None + std_area = np.std(training_areas) if training_areas else None + print(f"Training area stats: mean={mean_area}, std={std_area}") + + # ── Phase 2: Inference ── + mask_generator = SAM2AutomaticMaskGenerator( + sam2_model, + points_per_side=points_per_side, + pred_iou_thresh=0.88, + stability_score_thresh=0.95, + min_mask_region_area=min_mask_area, + ) + + batches = list(product(batch_xy, batch_z, batch_time)) + total_batches = len(batches) + new_annotations = [] + + for i, batch in enumerate(batches): + XY, Z, Time = batch + + sendProgress(0.2 + 0.7 * i / total_batches, + "Segmenting", + f"Processing frame {i + 1}/{total_batches}") + + # Get merged image for this batch + images = annotation_tools.get_images_for_all_channels(tileClient, datasetId, XY, Z, Time) + layers = annotation_tools.get_layers(tileClient.client, datasetId) + merged_image = annotation_tools.process_and_merge_channels(images, layers) + merged_image_rgb = ensure_rgb(merged_image) + + # Generate candidate masks with SAM2 + candidate_masks = mask_generator.generate(merged_image_rgb.astype(np.float32)) + print(f"Frame {i + 1}: generated {len(candidate_masks)} candidate masks") + + # Filter candidates by similarity to training prototype + filtered_polygons = [] + for mask_data in candidate_masks: + mask = mask_data['segmentation'] + area = mask.sum() + + # Area filtering + if min_mask_area > 0 and area < min_mask_area: + continue + if max_mask_area > 0 and area > max_mask_area: + continue + + # Extract crop with context, encode, and compare + crop_image, crop_mask = extract_crop_with_context( + merged_image_rgb, mask, target_occupancy + ) + crop_image = ensure_rgb(crop_image) + + if crop_mask.sum() == 0: + continue + + features = encode_image_with_sam2(predictor, crop_image) + feat_h, feat_w = features.shape[2], features.shape[3] + feature_vec = pool_features_with_mask(features, crop_mask, feat_h, feat_w) + + # Compute cosine similarity + feature_vec_norm = F.normalize(feature_vec.unsqueeze(0), dim=1) + similarity = F.cosine_similarity( + feature_vec_norm, + training_prototype.unsqueeze(0) + ).item() + + if similarity >= similarity_threshold: + # Convert mask to polygon + contours = find_contours(mask, 0.5) + if len(contours) == 0: + continue + polygon = Polygon(contours[0]).simplify(smoothing, preserve_topology=True) + if polygon.is_valid and not polygon.is_empty: + filtered_polygons.append(polygon) + + print(f"Frame {i + 1}: {len(filtered_polygons)} masks passed similarity filter") + + # Convert polygons to annotations + temp_annotations = annotation_tools.polygons_to_annotations( + filtered_polygons, datasetId, XY=XY, Time=Time, Z=Z, tags=tags, channel=channel + ) + new_annotations.extend(temp_annotations) + + sendProgress(0.9, "Uploading annotations", f"Sending {len(new_annotations)} annotations to server") + annotationClient.createMultipleAnnotations(new_annotations) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='SAM2 Few-Shot Segmentation') + + parser.add_argument('--datasetId', type=str, required=False, action='store') + parser.add_argument('--apiUrl', type=str, required=True, action='store') + 
parser.add_argument('--token', type=str, required=True, action='store') + parser.add_argument('--request', type=str, required=True, action='store') + parser.add_argument('--parameters', type=str, + required=True, action='store') + + args = parser.parse_args(sys.argv[1:]) + + params = json.loads(args.parameters) + datasetId = args.datasetId + apiUrl = args.apiUrl + token = args.token + + match args.request: + case 'compute': + compute(datasetId, apiUrl, token, params) + case 'interface': + interface(params['image'], apiUrl, token) diff --git a/workers/annotations/sam2_fewshot_segmentation/environment.yml b/workers/annotations/sam2_fewshot_segmentation/environment.yml new file mode 100644 index 0000000..39e0881 --- /dev/null +++ b/workers/annotations/sam2_fewshot_segmentation/environment.yml @@ -0,0 +1,18 @@ +name: worker +channels: + - conda-forge + - defaults +dependencies: + - python=3.10 + - pip + - imageio + - rasterio + - shapely + - pillow + - opencv + - matplotlib + - scikit-image + - pip: + - pycocotools + - onnxruntime + - onnx diff --git a/workers/annotations/sam2_fewshot_segmentation/tests/Dockerfile_Test b/workers/annotations/sam2_fewshot_segmentation/tests/Dockerfile_Test new file mode 100644 index 0000000..f04c20a --- /dev/null +++ b/workers/annotations/sam2_fewshot_segmentation/tests/Dockerfile_Test @@ -0,0 +1,14 @@ +# Use the existing sam2_fewshot_segmentation worker as the base +FROM annotations/sam2_fewshot_segmentation:latest AS test + +# Install test dependencies +SHELL ["conda", "run", "-n", "worker", "/bin/bash", "-c"] +RUN pip install pytest pytest-mock + +# Copy test files +RUN mkdir -p /tests +COPY ./workers/annotations/sam2_fewshot_segmentation/tests/*.py /tests +WORKDIR /tests + +# Command to run tests +ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "worker", "python3", "-m", "pytest", "-v"] diff --git a/workers/annotations/sam2_fewshot_segmentation/tests/__init__.py b/workers/annotations/sam2_fewshot_segmentation/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/workers/annotations/sam2_fewshot_segmentation/tests/test_sam2_fewshot.py b/workers/annotations/sam2_fewshot_segmentation/tests/test_sam2_fewshot.py new file mode 100644 index 0000000..00bcc41 --- /dev/null +++ b/workers/annotations/sam2_fewshot_segmentation/tests/test_sam2_fewshot.py @@ -0,0 +1,278 @@ +import pytest +import numpy as np +from unittest.mock import patch, MagicMock + +from entrypoint import ( + extract_crop_with_context, + pool_features_with_mask, + ensure_rgb, + annotation_to_mask, + interface, +) + + +class TestExtractCropWithContext: + """Tests for the context-aware crop extraction.""" + + def test_basic_crop_centered(self): + """Test that crop is centered on the object.""" + image = np.zeros((200, 200, 3), dtype=np.uint8) + mask = np.zeros((200, 200), dtype=np.uint8) + # Place a 20x20 object in the center + mask[90:110, 90:110] = 1 + image[90:110, 90:110] = 128 + + crop_image, crop_mask = extract_crop_with_context(image, mask, target_occupancy=0.20) + + # Object should still be present in the crop + assert crop_mask.sum() > 0 + # Crop should be larger than the object itself + assert crop_image.shape[0] >= 20 + assert crop_image.shape[1] >= 20 + + def test_small_object_gets_more_context(self): + """A small object should get a proportionally larger crop.""" + image = np.zeros((500, 500, 3), dtype=np.uint8) + mask = np.zeros((500, 500), dtype=np.uint8) + # Small 10x10 object + mask[245:255, 245:255] = 1 + + crop_image, crop_mask = 
extract_crop_with_context(image, mask, target_occupancy=0.20) + + # With 100 pixels of object area and 0.20 occupancy, + # crop area should be ~500, so side ~22 + obj_area = 100 + expected_crop_area = obj_area / 0.20 + expected_side = int(np.sqrt(expected_crop_area)) + assert crop_image.shape[0] >= expected_side - 2 # Allow small margin + + def test_object_at_edge(self): + """Object near image edge should still produce valid crop.""" + image = np.zeros((100, 100, 3), dtype=np.uint8) + mask = np.zeros((100, 100), dtype=np.uint8) + # Object at top-left corner + mask[0:10, 0:10] = 1 + + crop_image, crop_mask = extract_crop_with_context(image, mask, target_occupancy=0.20) + + # Should not crash and mask should be preserved + assert crop_mask.sum() > 0 + assert crop_image.shape[0] > 0 + assert crop_image.shape[1] > 0 + + def test_empty_mask_returns_original(self): + """Empty mask should return the original image and mask.""" + image = np.zeros((100, 100, 3), dtype=np.uint8) + mask = np.zeros((100, 100), dtype=np.uint8) + + crop_image, crop_mask = extract_crop_with_context(image, mask, target_occupancy=0.20) + + assert np.array_equal(crop_image, image) + assert np.array_equal(crop_mask, mask) + + def test_large_object_respects_bounding_box(self): + """Crop should be at least as large as the object bounding box.""" + image = np.zeros((200, 200, 3), dtype=np.uint8) + mask = np.zeros((200, 200), dtype=np.uint8) + # Large 80x80 object + mask[60:140, 60:140] = 1 + + crop_image, crop_mask = extract_crop_with_context(image, mask, target_occupancy=0.20) + + # Crop must encompass the full object + assert crop_image.shape[0] >= 80 + assert crop_image.shape[1] >= 80 + # And the mask pixels should all be within the crop + assert crop_mask.sum() == mask.sum() + + +class TestPoolFeaturesWithMask: + """Tests for the weighted feature pooling.""" + + def test_basic_pooling(self): + """Pooling with full mask should equal global average.""" + import torch + + C, H, W = 32, 8, 8 + features = torch.ones(1, C, H, W) + mask = np.ones((H, W), dtype=np.float32) + + result = pool_features_with_mask(features, mask, H, W) + + assert result.shape == (C,) + # With all-ones features and all-ones mask, result should be all ones + assert torch.allclose(result, torch.ones(C), atol=1e-3) + + def test_masked_region_pooling(self): + """Pooling should focus on masked region.""" + import torch + + C, H, W = 16, 8, 8 + features = torch.zeros(1, C, H, W) + # Set top-left quadrant to 1.0 + features[:, :, :4, :4] = 1.0 + + # Mask only the top-left quadrant + mask = np.zeros((H, W), dtype=np.float32) + mask[:4, :4] = 1.0 + + result = pool_features_with_mask(features, mask, H, W) + + # Should be close to 1.0 since we're pooling from the region with value 1 + assert torch.allclose(result, torch.ones(C), atol=0.2) + + def test_empty_mask_fallback(self): + """Empty mask should fall back to global average pooling.""" + import torch + + C, H, W = 16, 8, 8 + features = torch.ones(1, C, H, W) * 3.0 + mask = np.zeros((H, W), dtype=np.float32) + + result = pool_features_with_mask(features, mask, H, W) + + # Should fall back to mean pooling + assert result.shape == (C,) + assert torch.allclose(result, torch.ones(C) * 3.0, atol=1e-3) + + def test_mask_upscaling(self): + """Test that mask is properly resized to match feature dimensions.""" + import torch + + C = 16 + feat_h, feat_w = 8, 8 + features = torch.ones(1, C, feat_h, feat_w) + + # Mask at different resolution than features + mask = np.ones((32, 32), dtype=np.float32) + + result = 
pool_features_with_mask(features, mask, feat_h, feat_w) + + assert result.shape == (C,) + assert torch.allclose(result, torch.ones(C), atol=1e-3) + + +class TestEnsureRgb: + """Tests for image format normalization.""" + + def test_grayscale_to_rgb(self): + image = np.zeros((100, 100), dtype=np.uint8) + result = ensure_rgb(image) + assert result.shape == (100, 100, 3) + assert result.dtype == np.uint8 + + def test_single_channel_to_rgb(self): + image = np.zeros((100, 100, 1), dtype=np.uint8) + result = ensure_rgb(image) + assert result.shape == (100, 100, 3) + + def test_rgba_to_rgb(self): + image = np.zeros((100, 100, 4), dtype=np.uint8) + result = ensure_rgb(image) + assert result.shape == (100, 100, 3) + + def test_float_0_1_to_uint8(self): + image = np.ones((100, 100, 3), dtype=np.float32) * 0.5 + result = ensure_rgb(image) + assert result.dtype == np.uint8 + assert result.max() == 127 or result.max() == 128 # rounding + + def test_float_0_255_to_uint8(self): + image = np.ones((100, 100, 3), dtype=np.float32) * 200.0 + result = ensure_rgb(image) + assert result.dtype == np.uint8 + assert result.max() == 200 + + def test_uint16_to_uint8(self): + image = np.ones((100, 100, 3), dtype=np.uint16) * 512 + result = ensure_rgb(image) + assert result.dtype == np.uint8 + assert result.max() == 2 # 512 / 256 = 2 + + def test_rgb_uint8_passthrough(self): + image = np.ones((100, 100, 3), dtype=np.uint8) * 42 + result = ensure_rgb(image) + assert result.dtype == np.uint8 + assert np.array_equal(result, image) + + +class TestAnnotationToMask: + """Tests for converting polygon annotations to binary masks.""" + + def test_square_annotation(self): + annotation = { + 'coordinates': [ + {'x': 10, 'y': 10}, + {'x': 10, 'y': 20}, + {'x': 20, 'y': 20}, + {'x': 20, 'y': 10}, + ] + } + mask = annotation_to_mask(annotation, (30, 30)) + assert mask.shape == (30, 30) + assert mask.sum() > 0 + # Center of the square should be 1 + assert mask[15, 15] == 1 + # Outside should be 0 + assert mask[0, 0] == 0 + + def test_mask_matches_image_shape(self): + annotation = { + 'coordinates': [ + {'x': 5, 'y': 5}, + {'x': 5, 'y': 15}, + {'x': 15, 'y': 15}, + {'x': 15, 'y': 5}, + ] + } + mask = annotation_to_mask(annotation, (100, 200)) + assert mask.shape == (100, 200) + + +class TestInterface: + """Test the interface function.""" + + @patch('annotation_client.workers.UPennContrastWorkerPreviewClient') + def test_interface_sets_all_fields(self, mock_client_class): + mock_client = mock_client_class.return_value + + # Mock the checkpoint directory + with patch('os.listdir', return_value=['sam2.1_hiera_small.pt', 'sam2.1_hiera_large.pt']): + interface('test_image', 'http://test-api', 'test-token') + + mock_client.setWorkerImageInterface.assert_called_once() + interface_data = mock_client.setWorkerImageInterface.call_args[0][1] + + # Verify all expected fields are present + expected_fields = [ + 'Training Tag', 'Batch XY', 'Batch Z', 'Batch Time', + 'Model', 'Similarity Threshold', 'Target Occupancy', + 'Points per side', 'Min Mask Area', 'Max Mask Area', 'Smoothing', + ] + for field in expected_fields: + assert field in interface_data, f"Missing interface field: {field}" + + # Verify types + assert interface_data['Training Tag']['type'] == 'tags' + assert interface_data['Model']['type'] == 'select' + assert interface_data['Similarity Threshold']['type'] == 'number' + assert interface_data['Target Occupancy']['type'] == 'number' + assert interface_data['Points per side']['type'] == 'number' + assert 
interface_data['Smoothing']['type'] == 'number' + + # Verify defaults + assert interface_data['Similarity Threshold']['default'] == 0.5 + assert interface_data['Target Occupancy']['default'] == 0.20 + assert interface_data['Points per side']['default'] == 32 + assert interface_data['Model']['default'] == 'sam2.1_hiera_small.pt' + + @patch('annotation_client.workers.UPennContrastWorkerPreviewClient') + def test_interface_with_no_models(self, mock_client_class): + mock_client = mock_client_class.return_value + + with patch('os.listdir', return_value=[]): + interface('test_image', 'http://test-api', 'test-token') + + interface_data = mock_client.setWorkerImageInterface.call_args[0][1] + assert interface_data['Model']['default'] is None + assert interface_data['Model']['items'] == []