diff --git a/cc/CLAUDE.md b/cc/CLAUDE.md new file mode 100644 index 0000000..a51ee85 --- /dev/null +++ b/cc/CLAUDE.md @@ -0,0 +1,820 @@ +# Claude Experiment: Automated Annotation Workflow + +## Overview + +This experiment explores automating the human annotation workflow for the docket-entry classifier project. The goal is to have Claude perform the iterative steps that humans currently do—creating synthetic annotations, grounding decision boundaries, generating training sets, and preparing data for BERT fine-tuning. + +## Project Context + +- **Target Project**: `docket-entry` - classifying docket entries (motions, orders, etc.) +- **Pipeline**: Synthetic annotations → Human grounding → AI annotation of training sets → BERT fine-tuning + +## Note on Dictation + +The user often uses dictation, so be forgiving of typos and odd formatting in spoken instructions. Defer to what's written in the code for canonical names and spellings (e.g., "Claude" may be transcribed as "Cloud"). + +## Important: Import Pattern + +Always import models using the shim pattern: +```python +from clx.models import Label, LabelHeuristic, LabelDecision, Project +``` +NOT `from clx.app.models import ...`. The shim at `clx/models.py` auto-initializes Django. + +--- + +## Step 1: Creating Heuristics for a Label + +### Purpose + +Heuristics partition the corpus into three buckets for efficient annotation: +- **Excluded**: High-confidence negatives (don't meet minimal conditions) +- **Neutral**: Uncertain cases (meet minimal but not likely conditions) +- **Likely**: High-confidence positives (meet both minimal and likely conditions) + +### Two Types of Heuristics + +#### 1. Query String Heuristics + +Simple keyword conditions using a mini-language: + +| Operator | Meaning | Example | +|----------|---------|---------| +| `,` | AND (all must match) | `motion, court` | +| `\|` | OR (any can match) | `motion\|filing` | +| `~` | NOT (negation) | `~denied` | +| `^` | Starts with | `^Summary` | + +**Precedence**: ORs are nested within ANDs. So `a, b|c` means `(a) AND (b OR c)`. + +**All matching is case-insensitive.** + +Examples: +- `complaint` - contains "complaint" anywhere +- `^motion` - starts with "motion" +- `motion, ~denied` - contains "motion" AND does not contain "denied" +- `motion|application|request` - contains any of these terms +- `^motion, court|judge` - starts with "motion" AND contains "court" or "judge" + +#### 2. Custom Function Heuristics + +For complex logic, define a decorated function in `clx/app/custom_heuristics.py`: + +```python +from clx.app.custom_heuristics import custom_heuristic + +def within_first(text, term, n): + """Helper: check if term appears in first n words.""" + first_n = " ".join(text.split()[:n]) + return term in first_n + +@custom_heuristic("docket-entry", "Motion") +def first_3_motion(text, **kwargs): + """Matches if 'motion' appears in the first 3 words.""" + return within_first(text.lower(), "motion", 3) +``` + +The decorator registers the function with: +- `project_id`: Which project this applies to +- `label_name`: Which label this is a heuristic for +- The function receives `text` and must return `True`/`False` + +After adding custom heuristics, sync them: +```python +LabelHeuristic.sync_custom_heuristics() +``` + +### Minimal vs Likely Conditions + +**Minimal Conditions** (`is_minimal=True`): +- Define what MUST be true for a positive example +- Used to exclude obvious negatives +- Should be conservative—avoid false exclusions +- Example: A "Complaint" should contain "complaint" (or common misspellings) + +**Likely Conditions** (`is_likely=True`): +- Define patterns that strongly suggest a positive +- Used to identify easy positive cases +- Can be more aggressive +- Example: Text starting with "Complaint" is very likely a complaint + +### The Three Buckets Logic + +``` +EXCLUDED = does not match ANY minimal heuristic +NEUTRAL = matches at least one minimal BUT no likely heuristics +LIKELY = matches at least one minimal AND at least one likely heuristic +``` + +### Creating and Managing Heuristics + +```python +from clx.models import Label, LabelHeuristic + +# Get the label +label = Label.objects.get(project_id="docket-entry", name="Motion") + +# View existing heuristics +for h in label.heuristics.all(): + print(f"ID: {h.id}") + print(f" Query: {h.querystring or h.custom}") + print(f" is_minimal: {h.is_minimal}, is_likely: {h.is_likely}") + print(f" Matches: {h.num_examples}, Applied: {h.applied_at}") + +# Create a querystring heuristic +heuristic = LabelHeuristic.objects.create( + label=label, + querystring="motion|application|request", + is_minimal=True, +) + +# Apply the heuristic (computes across corpus) +heuristic.apply() + +# Create a likely heuristic +likely_heuristic = LabelHeuristic.objects.create( + label=label, + querystring="^motion", + is_likely=True, +) +likely_heuristic.apply() + +# Check bucket counts +label.refresh_from_db() +print(f"Excluded: {label.num_excluded}") +print(f"Neutral: {label.num_neutral}") +print(f"Likely: {label.num_likely}") +``` + +### Guidelines for Claude + +1. **Start simple** - Don't overthink. Simple keyword matches work well. + +2. **Minimal conditions should be conservative**: + - Ask: "Could there ever be a positive example that doesn't match this?" + - If yes, broaden the condition or add alternatives with `|` + - Include common misspellings, abbreviations, synonyms + +3. **Likely conditions can be aggressive**: + - These just identify easy cases, not all cases + - Prefix matches (`^term`) are often good likely conditions + +4. **Iterate based on counterexamples**: + - If you find a positive example in the "excluded" bucket → expand minimal condition + - If you find obvious positives in "neutral" → add likely conditions + +5. **Multiple heuristics combine with OR**: + - Multiple minimal heuristics: excluded if matches NONE of them + - Multiple likely heuristics: likely if matches ANY of them + +--- + +## Step 2: Create Annotation Decisions + +### Purpose + +Decisions are reason-annotated examples that define decision boundaries. They serve two purposes: +1. Document where we're drawing the line on edge cases +2. Provide training examples for the GEPA predictor optimization + +### What Makes a Good Decision + +- **Keep it minimal**: Humans should be able to review all decisions and understand the labeling policy +- **Include obvious examples**: At least one clear positive example ("This is obviously a complaint") +- **Focus on edge cases**: Where the boundary isn't obvious +- **Short reasons**: 1-2 sentences explaining why + +### Examples of Good Decisions + +For a "Complaint" label: +- **Positive**: "Complaint for Damages" → `True`, "This is clearly a complaint filing" +- **Negative**: "Submission of Complaint as Exhibit" → `False`, "This references a complaint but is not the complaint itself" +- **Negative**: "Response to Complaint" → `False`, "This is a response document, not the complaint" + +For a "Motion" label: +- **Positive**: "Motion for Summary Judgment" → `True`, "Standard motion filing" +- **Positive**: "Application for Extension of Time" → `True`, "Applications that request court action are functionally motions" +- **Negative**: "Opposition to Motion" → `False`, "This opposes a motion but is not itself a motion" + +### Creating Decisions + +```python +from clx.models import Label, LabelDecision +from clx import generate_hash + +label = Label.objects.get(project_id="docket-entry", name="Motion") + +# View existing decisions +for d in label.decisions.all(): + print(f"Value: {d.value}") + print(f"Text: {d.text[:100]}...") + print(f"Reason: {d.reason}") + print() + +# Create a decision from text +text = "Motion for Summary Judgment filed by Defendant" +decision = LabelDecision.objects.create( + label=label, + text_hash=generate_hash(text), + text=text, + value=True, + reason="Standard motion filing requesting summary judgment" +) + +# Or create from a search result (has text_hash already) +# See Search section for how to find examples +``` + +### Guidelines for Claude + +1. **Start with 1-2 obvious decisions** per label +2. **Add edge case decisions as you encounter them** during review +3. **Keep reasons brief but clear** - they'll be used for predictor training +4. **Update decisions if needed** - the same text_hash will update the existing decision + +--- + +## Step 3: Sample the Training Set + +### Purpose + +The training set is a diverse sample of examples used for: +- Running predictor inference +- Training fine-tuned BERT models +- Evaluating model performance + +### How Sampling Works + +The trainset samples from multiple sources to ensure diversity: + +1. **Heuristic buckets**: Random samples from excluded, neutral, and likely buckets +2. **Decision neighbors**: Semantic neighbors of each decision (finds similar edge cases) + +Default configuration (configurable per label): +- `trainset_num_excluded`: 1000 examples from excluded bucket +- `trainset_num_neutral`: 1000 examples from neutral bucket +- `trainset_num_likely`: 1000 examples from likely bucket +- `trainset_num_decision_neighbors`: 50 neighbors per decision + +The sampling uses "mesh sort" to select diverse examples (not just random). + +### Train vs Eval Split + +- **Train split**: Main sample (ratio=1.0) +- **Eval split**: Smaller sample (ratio=0.2) for evaluation + +### Updating the Trainset + +```python +from clx.models import Label + +label = Label.objects.get(project_id="docket-entry", name="Motion") + +# Configure sampling parameters (optional) +label.trainset_num_excluded = 1000 +label.trainset_num_neutral = 1000 +label.trainset_num_likely = 1000 +label.trainset_num_decision_neighbors = 50 +label.save() + +# Sample the trainset +label.update_trainset() + +# Check what was sampled +print(f"Train examples: {label.trainset_examples.filter(split='train').count()}") +print(f"Eval examples: {label.trainset_examples.filter(split='eval').count()}") +``` + +### When to Resample + +Resample the trainset when: +- You add new decisions (to include their neighbors) +- You change heuristics significantly +- You want different sampling parameters + +**Note**: Resampling will require re-running predictions (Step 4), which costs money. + +--- + +## Step 4: Fit and Run Predictor + +### Purpose + +The predictor is a small LLM (GPT-mini, Gemini Flash, etc.) that classifies examples. It uses GEPA (a DSPY optimization algorithm) to generate an optimized classification prompt based on your decisions. + +### Cost Warning + +Running predictions costs money (~$2-3 per full trainset run). Plan your workflow to minimize re-runs: +- Batch multiple decisions before resampling +- Fix as many issues as possible before re-running predictions +- The iteration loop is: decisions → resample → fit predictor → run predictions → review → repeat + +### Fitting the Predictor + +Fitting uses your decisions to optimize a classification prompt: + +```python +from clx.models import Label + +label = Label.objects.get(project_id="docket-entry", name="Motion") + +# Configure models (optional) +label.inference_model = "openai/gpt-5-mini" # For predictions +label.teacher_model = "openai/gpt-5" # For GEPA optimization +label.save() + +# Fit the predictor (uses decisions as training examples) +label.fit_predictor() +# This will print the cost when done +``` + +### Running Predictions + +After fitting, run predictions across the trainset: + +```python +label.update_trainset_preds(num_threads=128) + +# Check prediction counts +label.refresh_from_db() +print(f"Positive predictions: {label.trainset_num_positive_preds}") +print(f"Negative predictions: {label.trainset_num_negative_preds}") +``` + +### Viewing Predictions with Reasons + +The predictor outputs both a value and a reason for each prediction: + +```python +# View trainset examples with predictions +for ex in label.trainset_examples.filter(pred__isnull=False)[:10]: + print(f"Pred: {ex.pred}") + print(f"Text: {ex.text[:100]}...") + print(f"Reason: {ex.reason}") + print() +``` + +--- + +## Step 5: Train Fine-tuned Models + +### Purpose + +Fine-tuned BERT models are the production output. They're fast and cheap to run at scale. We train them on the predictor's outputs. + +### Two Configs + +- **`main`**: Full training (10 epochs) - the production model +- **`underfit`**: Light training (1 epoch) - useful for finding different failure modes + +Training both configs helps identify disagreements between models. + +### Training Process + +Training can be done via CLI or programmatically: + +```bash +# CLI: Train main model (10 epochs) +clx train docket-entry "Motion" main + +# CLI: Train underfit model (1 epoch) +clx train docket-entry "Motion" underfit +``` + +```python +# Programmatic: Train a specific config +label = Label.objects.get(project_id="docket-entry", name="Motion") +label.train_finetune("main") +``` + +The training process: +1. Prepares training data from the trainset +2. Runs training remotely in the cloud +3. Runs predictions on the trainset using the trained model +4. Updates the finetune tags and saves eval results + +### Update All (Recommended) + +The `update_all` method runs the full pipeline, but **only steps that are out of date**: + +```python +label = Label.objects.get(project_id="docket-entry", name="Motion") + +# Run only what's needed based on timestamps +label.update_all() + +# Force run everything regardless of timestamps +label.update_all(force=True) +``` + +This checks timestamps and runs: +1. **Resample trainset** - if decisions are newer than trainset +2. **Fit predictor** - if trainset is newer than predictor +3. **Run predictions** - if predictor is newer than predictions +4. **Train finetunes** - if predictions are newer than finetunes +5. **Run global corpus predictions** - only if `predict=True` and main finetune is newer than global predictions + +```python +# Also run global corpus predictions (step 5) +label.update_all(predict=True) +``` + +### Programmatic Access + +```python +from clx.models import Label + +label = Label.objects.get(project_id="docket-entry", name="Motion") + +# Prepare finetune data (for inspection) +train_data, eval_data, config = label.prepare_finetune("main") + +# Get the trained model pipeline (runs remotely) +pipe = label.get_finetune_run_pipe("main") +predictions = pipe(["some text to classify"], batch_size=16) + +# View finetune results +for ft in label.fintunes.all(): + print(f"Config: {ft.config_name}") + print(f"Results: {ft.eval_results}") +``` + +### Global Corpus Predictions + +After training a finetune, you can run predictions across the **entire corpus** (not just the trainset). This is a separate step because it's more expensive and not always needed during development. + +Global predictions only run for the **main finetune config** (defined on the search model as `main_finetune_config`). + +```python +# Run predictions across entire corpus using the main finetune config +label.predict_finetune() + +# Force restart (clears cache and starts fresh) +label.predict_finetune(force=True) +``` + +The `predict_finetune` method: +1. **Uses the main finetune config** - defined on the search model (e.g., `DocketEntry.main_finetune_config = "main"`) +2. **Is idempotent** - caches progress and picks up where it left off if interrupted +3. **Uses a CSV cache** in the label's data directory (`data_dir/finetune_predictions_cache.csv`) +4. **Updates the global finetune tag** (`ft`) when complete +5. **Sets `predicted_at` timestamp** on the LabelFinetune object +6. **Deletes the cache** after successful completion + +**Tags**: +- `trainset:ft:{config}` - Predictions on trainset only (set by `train_finetune`) +- `ft` - Predictions on entire corpus for the main config (set by `predict_finetune`) + +**Timestamps** on LabelFinetune: +- `finetuned_at` - When the model was last trained +- `predicted_at` - When global corpus predictions were last run + +--- + +## Step 6: Review and Iterate + +### Finding Issues + +Use search to find examples where models might be wrong: + +1. **Review disagreements**: Examples where predictor and fine-tunes disagree +2. **Search by heuristic bucket**: Look in neutral bucket for edge cases +3. **Keyword search**: Find specific patterns +4. **Semantic search**: Find examples similar to a known problem + +### Fast Annotations + +For quick fixes without full decision reasons, use fast annotations: + +```python +from clx.models import Label + +label = Label.objects.get(project_id="docket-entry", name="Motion") +model = label.project.get_search_model() + +# Get an example +example = model.objects.get(id=12345) + +# Set annotation (no reason needed) +example.set_annotation(label, True) # Mark as positive +example.set_annotation(label, False) # Mark as negative +example.set_annotation(label, "flag") # Flag for exclusion from trainset +example.set_annotation(label, None) # Clear annotation +``` + +Flagged examples are excluded from the trainset entirely. + +### The Iteration Loop + +1. **Search** for potential issues (disagreements, specific patterns, etc.) +2. **Review** examples and identify errors +3. **Fix** via decisions (for edge cases needing reasons) or fast annotations (for quick fixes) +4. **Batch fixes** - do as many as possible before re-running +5. **Resample trainset** (`label.update_trainset()`) +6. **Refit predictor** (`label.fit_predictor()`) +7. **Re-run predictions** (`label.update_trainset_preds()`) +8. **Retrain models** (CLI train commands) +9. **Repeat** + +--- + +## Search Reference + +The search system is the primary way to find and review examples. + +### Basic Search + +```python +from clx.models import Project + +project = Project.objects.get(id="docket-entry") +model = project.get_search_model() + +# Simple search - returns dict with 'data' key +results = model.objects.search(page=1, page_size=100) +for item in results["data"]: + print(item["id"], item["text"][:80]) +``` + +### Search Parameters + +All parameters go in a `params` dict: + +```python +results = model.objects.search( + active_label_id=label.id, # Required for most filters + params={ + # Heuristic bucket filter + "heuristic_bucket": "excluded" | "neutral" | "likely", + + # Trainset filter + "trainset_split": "train" | "eval" | "both", + + # Predictor prediction filter + "predictor_value": "true" | "false", + + # Manual annotation filter + "annotation_value": "true" | "false" | "flag" | "any" | "none", + + # Find disagreements between models + "review_disagreements": True, + + # Keyword search (uses query string syntax) + "querystring": "motion, ~denied", + }, + page=1, + page_size=100, +) +``` + +### Semantic Search + +Find examples similar to a query or embedding: + +```python +# Search by text similarity +results = model.objects.search( + semantic_sort="motion for summary judgment", + page_size=50, +) + +# Or use an embedding directly +embedding = [0.1, 0.2, ...] # 96-dim vector +results = model.objects.search(semantic_sort=embedding) +``` + +### Search Result Format + +Each result includes: + +```python +{ + "id": 12345, + "text_hash": "abc123...", + "text": "Full text of the example", + "tags": [1, 5, 12], # Tag IDs + # If in trainset: + "split": "train" | "eval", + "pred": True | False | None, + "reason": "Predictor's reasoning...", +} +``` + +### Count Only + +```python +result = model.objects.search( + active_label_id=label.id, + params={"heuristic_bucket": "neutral"}, + count=True, +) +print(f"Total: {result['total']}") +``` + +### Query String Syntax (Review) + +| Operator | Meaning | Example | +|----------|---------|---------| +| `,` | AND | `motion, court` | +| `\|` | OR | `motion\|filing` | +| `~` | NOT | `~denied` | +| `^` | Starts with | `^Summary` | + +--- + +## Key Files Reference + +| Component | File | Key Lines | +|-----------|------|-----------| +| Models | `clx/app/models.py` | Full file | +| Search | `clx/app/search_utils.py` | `SearchQuerySet.search` | +| Heuristics | `clx/app/models.py` | `LabelHeuristic` class | +| Custom Heuristics | `clx/app/custom_heuristics.py` | Decorator pattern | +| Train CLI | `clx/cli/train.py` | Export/train/import | +| Views | `clx/app/views.py` | All endpoints | +| **Helpers** | `experiment/helpers.py` | Claude Code utilities | + +--- + +## Helper Scripts for Claude Code + +The `experiment/helpers.py` module provides convenient functions for the annotation workflow: + +### Quick Status Check + +```python +from experiment.helpers import print_label_status + +print_label_status("Motion") +``` + +### Searching and Viewing Examples + +```python +from experiment.helpers import ( + search_examples, + print_examples, + disagreements, + neutral_examples, + similar_to, +) + +# Find disagreements between models +examples = disagreements("Motion") +print_examples(examples) + +# Look at edge cases (neutral bucket) +examples = neutral_examples("Motion", page_size=10) +print_examples(examples) + +# Find similar examples +examples = similar_to("Motion", "application for extension of time") +print_examples(examples, show_full_text=True) + +# Complex search +examples = search_examples( + "Motion", + heuristic_bucket="neutral", + querystring="application", + page_size=20, +) +print_examples(examples) +``` + +### Creating Decisions + +```python +from experiment.helpers import ( + create_decision, + create_decision_from_id, + view_decisions, +) + +# View existing decisions +view_decisions("Motion") + +# Create from text +create_decision( + "Motion", + text="Application for Extension of Time", + value=True, + reason="Applications requesting court action are functionally motions" +) + +# Create from search result ID +create_decision_from_id( + "Motion", + example_id=12345, + value=False, + reason="This is a response to a motion, not a motion itself" +) +``` + +### Fast Annotations + +```python +from experiment.helpers import annotate + +annotate("Motion", example_id=12345, value=True) # Positive +annotate("Motion", example_id=12346, value=False) # Negative +annotate("Motion", example_id=12347, value="flag") # Exclude from trainset +``` + +### Creating Heuristics + +```python +from experiment.helpers import create_heuristic + +create_heuristic( + "Motion", + querystring="motion|application|request", + is_minimal=True, + apply=True, # Immediately computes across corpus +) +``` + +--- + +## Scales OKN Integration (docket-entry only) + +For the docket-entry project, we have predictions from Scales OKN—a similar classification project with pre-trained models for many of the same labels. + +### Available Scales Labels + +The following labels have Scales OKN predictions imported: + +| Scales Label | Our Label | +|--------------|-----------| +| summons | Summons | +| waiver | Waiver | +| brief | Brief / Memorandum | +| arrest | Arrest | +| warrant | Warrant | +| verdict | Verdict | +| answer | Answer | +| complaint | Complaint | +| indictment | Indictment | +| information | Information | +| petition | Petition | +| notice | Notice | +| response | Reply / Response | +| minute entry | Minute Entry | +| plea agreement | Plea Agreement | +| judgment | Judgment | +| stipulation | Stipulation | +| motion | Motion | +| order | Order | + +### How Scales Tags Work + +- Each label has a `LabelTag` with `name="scales"` +- Positive Scales predictions (score > 0.5) are tagged +- Absence of tag means Scales predicted negative (or no prediction) + +### Using Scales for Review + +Scales predictions are another source of feedback when reviewing. You can compare: +- Examples where our models predict TRUE but Scales predicts FALSE +- Examples where our models predict FALSE but Scales predicts TRUE + +**Important caveats:** +1. **Scales is not ground truth** - it has errors and may make different annotation decisions +2. **Scope to trainset** - we only compute our predictions on the trainset, so compare within trainset +3. **Check against decisions** - if our models disagree with Scales but are consistent with our documented decisions, that's fine + +### Searching with Scales + +```python +from experiment.helpers import search_examples + +# Find examples where we predict TRUE but Scales predicts FALSE +# (These might be cases Scales missed, or cases we're wrong about) +label = get_label("Motion") +scales_tag = label.labeltag_set.filter(name="scales").first() + +# Search for trainset examples our predictor says TRUE +examples = search_examples( + "Motion", + trainset_split="train", + predictor_value="true", +) + +# Filter to those without scales tag (Scales said FALSE) +# This requires checking tags manually or using raw search +``` + +### When to Use Scales Feedback + +- **After initial model training** - to find potential blind spots +- **When reviewing disagreements** - as an additional signal +- **NOT as automatic corrections** - always review why there's a disagreement + +--- + +## Notes + +- The docket-entry project uses `DocketEntry` as the search model +- Heuristics create `LabelTag` entries attached to documents via PostgreSQL array fields +- The `apply()` step processes documents in batches of 1M for efficiency +- Predictions cost money - batch your changes before re-running +- The main fine-tune model is the production output; underfit helps find different errors diff --git a/cc/helpers.py b/cc/helpers.py new file mode 100644 index 0000000..1a95b77 --- /dev/null +++ b/cc/helpers.py @@ -0,0 +1,497 @@ +""" +Helper functions for Claude Code to interact with the CLX annotation workflow. + +These utilities make it easier to: +- Search and view examples with predictions +- Create decisions and annotations +- Check label status +""" + +from clx import generate_hash +from clx.models import Label, LabelDecision, LabelHeuristic, Project + + +def get_label(label_name: str, project_id: str = "docket-entry") -> Label: + """Get a label by name.""" + return Label.objects.get(project_id=project_id, name=label_name) + + +def get_project(project_id: str = "docket-entry") -> Project: + """Get a project by ID.""" + return Project.objects.get(id=project_id) + + +def label_status(label_name: str, project_id: str = "docket-entry") -> dict: + """Get comprehensive status of a label including warnings and decisions.""" + label = get_label(label_name, project_id) + + # Get all decisions with full info + decisions = [ + { + "id": d.id, + "value": d.value, + "reason": d.reason, + "text": d.text, + "created_at": d.created_at, + "updated_at": d.updated_at, + } + for d in label.decisions.all().order_by("-updated_at") + ] + + # Get finetunes with timestamps + finetunes = [ + { + "config": ft.config_name, + "results": ft.eval_results, + "created_at": ft.created_at, + "updated_at": ft.updated_at, + "finetuned_at": ft.finetuned_at, + "predicted_at": ft.predicted_at, + "is_main": ft.config_name + == label.project.get_search_model().main_finetune_config, + } + for ft in label.fintunes.all() + ] + + # Generate warnings based on timestamp comparisons + warnings = [] + + # Get latest decision timestamp + latest_decision_at = None + if decisions: + latest_decision_at = max(d["updated_at"] for d in decisions) + + # Warning: decisions newer than trainset + if latest_decision_at and label.trainset_updated_at: + if latest_decision_at > label.trainset_updated_at: + warnings.append( + "Decisions updated since last trainset sampling - consider resampling" + ) + elif latest_decision_at and not label.trainset_updated_at: + warnings.append("Trainset has never been sampled") + + # Warning: trainset newer than predictor + if label.trainset_updated_at and label.predictor_updated_at: + if label.trainset_updated_at > label.predictor_updated_at: + warnings.append( + "Trainset updated since last predictor fit - consider refitting" + ) + elif label.trainset_updated_at and not label.predictor_updated_at: + warnings.append("Predictor has never been fit") + + # Warning: predictor newer than predictions + if label.predictor_updated_at and label.trainset_predictions_updated_at: + if label.predictor_updated_at > label.trainset_predictions_updated_at: + warnings.append( + "Predictor updated since last prediction run - consider rerunning predictions" + ) + elif ( + label.predictor_updated_at + and not label.trainset_predictions_updated_at + ): + warnings.append("Predictions have never been run") + + # Warning: predictions newer than finetunes + if label.trainset_predictions_updated_at and finetunes: + for ft in finetunes: + finetuned_at = ft.get("finetuned_at") + if ( + not finetuned_at + or label.trainset_predictions_updated_at > finetuned_at + ): + warnings.append( + f"Predictions updated since '{ft['config']}' finetune - consider retraining" + ) + + # Warning: finetunes newer than global predictions + for ft in finetunes: + finetuned_at = ft.get("finetuned_at") + predicted_at = ft.get("predicted_at") + if finetuned_at and (not predicted_at or finetuned_at > predicted_at): + warnings.append( + f"'{ft['config']}' finetune updated since global predictions - consider running predict_finetune" + ) + + return { + "name": label.name, + "id": label.id, + "warnings": warnings, + "heuristic_buckets": { + "excluded": label.num_excluded, + "neutral": label.num_neutral, + "likely": label.num_likely, + }, + "heuristics": [ + { + "id": h.id, + "query": h.querystring or f"[custom: {h.custom}]", + "is_minimal": h.is_minimal, + "is_likely": h.is_likely, + "num_examples": h.num_examples, + "applied_at": h.applied_at, + } + for h in label.heuristics.all() + ], + "decisions": decisions, + "trainset": { + "train": label.trainset_examples.filter(split="train").count(), + "eval": label.trainset_examples.filter(split="eval").count(), + "updated_at": label.trainset_updated_at, + }, + "predictor": { + "positive_preds": label.trainset_num_positive_preds, + "negative_preds": label.trainset_num_negative_preds, + "updated_at": label.trainset_predictions_updated_at, + "fitted_at": label.predictor_updated_at, + "inference_model": label.inference_model, + "teacher_model": label.teacher_model, + }, + "finetunes": finetunes, + } + + +def print_label_status(label_name: str, project_id: str = "docket-entry"): + """Print a formatted label status report.""" + status = label_status(label_name, project_id) + + print(f"=== Label: {status['name']} (ID: {status['id']}) ===\n") + + # Show warnings prominently at the top + if status["warnings"]: + print("WARNINGS:") + for warning in status["warnings"]: + print(f" ! {warning}") + print() + + print("Heuristic Buckets:") + for bucket, count in status["heuristic_buckets"].items(): + print(f" {bucket}: {count:,}") + + print(f"\nHeuristics ({len(status['heuristics'])}):") + for h in status["heuristics"]: + flags = [] + if h["is_minimal"]: + flags.append("minimal") + if h["is_likely"]: + flags.append("likely") + flag_str = f" [{', '.join(flags)}]" if flags else "" + print(f" {h['query']}{flag_str} → {h['num_examples']:,} matches") + + print(f"\nDecisions ({len(status['decisions'])}):") + for d in status["decisions"]: + value_str = "TRUE" if d["value"] else "FALSE" + text_preview = ( + d["text"][:80] + "..." + if d["text"] and len(d["text"]) > 80 + else d["text"] + ) + print(f" [{value_str}] {text_preview}") + print(f" Reason: {d['reason']}") + + print("\nTrainset:") + print(f" Train: {status['trainset']['train']:,}") + print(f" Eval: {status['trainset']['eval']:,}") + print(f" Updated: {status['trainset']['updated_at']}") + + print("\nPredictor:") + print(f" Positive preds: {status['predictor']['positive_preds']:,}") + print(f" Negative preds: {status['predictor']['negative_preds']:,}") + print(f" Fitted: {status['predictor']['fitted_at']}") + print(f" Predictions updated: {status['predictor']['updated_at']}") + + if status["finetunes"]: + print("\nFinetunes:") + for ft in status["finetunes"]: + print(f" {ft['config']}:") + print(f" Results: {ft['results']}") + print(f" Finetuned at: {ft['finetuned_at']}") + print(f" Global predictions at: {ft['predicted_at']}") + + +def search_examples( + label_name: str, + project_id: str = "docket-entry", + heuristic_bucket: str | None = None, + trainset_split: str | None = None, + predictor_value: str | None = None, + annotation_value: str | None = None, + review_disagreements: bool = False, + querystring: str | None = None, + semantic_sort: str | None = None, + page: int = 1, + page_size: int = 20, +) -> list[dict]: + """ + Search for examples with full context. + + Returns examples with: + - text and metadata + - predictor prediction and reason (if in trainset) + - finetune predictions + - annotation status + """ + label = get_label(label_name, project_id) + model = label.project.get_search_model() + + params = {} + if heuristic_bucket: + params["heuristic_bucket"] = heuristic_bucket + if trainset_split: + params["trainset_split"] = trainset_split + if predictor_value: + params["predictor_value"] = predictor_value + if annotation_value: + params["annotation_value"] = annotation_value + if review_disagreements: + params["review_disagreements"] = True + if querystring: + params["querystring"] = querystring + + search_kwargs = { + "active_label_id": label.id, + "params": params, + "page": page, + "page_size": page_size, + } + if semantic_sort: + search_kwargs["semantic_sort"] = semantic_sort + + results = model.objects.search(**search_kwargs) + + # Enrich with tag information + enriched = [] + for item in results.get("data", []): + tags = item.get("tags", []) + + # Check annotation status + anno_status = None + if label.anno_true_tag.id in tags: + anno_status = "true" + elif label.anno_false_tag.id in tags: + anno_status = "false" + elif label.anno_flag_tag.id in tags: + anno_status = "flag" + + # Check finetune predictions + finetune_preds = {} + for ft in label.fintunes.all(): + ft_tag = label.get_trainset_finetune_tag(ft.config_name) + finetune_preds[ft.config_name] = ft_tag.id in tags + + # Check predictor prediction + predictor_pred = label.trainset_pred_tag.id in tags + + enriched.append( + { + "id": item["id"], + "text_hash": item["text_hash"], + "text": item["text"], + "annotation": anno_status, + "predictor_pred": predictor_pred + if item.get("split") + else None, + "predictor_reason": item.get("reason"), + "finetune_preds": finetune_preds if finetune_preds else None, + "trainset_split": item.get("split"), + } + ) + + return enriched + + +def print_examples( + examples: list[dict], + show_full_text: bool = False, + max_text_len: int = 120, +): + """Print examples in a readable format.""" + for i, ex in enumerate(examples, 1): + print(f"\n{'=' * 60}") + print(f"[{i}] ID: {ex['id']}") + + text = ex["text"] + if not show_full_text and len(text) > max_text_len: + text = text[:max_text_len] + "..." + print(f"Text: {text}") + + # Predictions + preds = [] + if ex.get("predictor_pred") is not None: + preds.append(f"predictor={ex['predictor_pred']}") + if ex.get("finetune_preds"): + for config, pred in ex["finetune_preds"].items(): + preds.append(f"{config}={pred}") + if preds: + print(f"Predictions: {', '.join(preds)}") + + if ex.get("predictor_reason"): + print(f"Reason: {ex['predictor_reason']}") + + if ex.get("annotation"): + print(f"Annotation: {ex['annotation']}") + + if ex.get("trainset_split"): + print(f"Split: {ex['trainset_split']}") + + +def view_decisions(label_name: str, project_id: str = "docket-entry"): + """View all decisions for a label.""" + label = get_label(label_name, project_id) + + print(f"=== Decisions for {label_name} ===\n") + + for d in label.decisions.all().order_by("-updated_at"): + value_str = "TRUE" if d.value else "FALSE" + print(f"[{value_str}] {d.text[:100]}...") + print(f" Reason: {d.reason}") + print() + + +def create_decision( + label_name: str, + text: str, + value: bool, + reason: str, + project_id: str = "docket-entry", +) -> LabelDecision: + """Create or update a decision.""" + label = get_label(label_name, project_id) + text_hash = generate_hash(text) + + decision, created = LabelDecision.objects.update_or_create( + label=label, + text_hash=text_hash, + defaults={ + "text": text, + "value": value, + "reason": reason, + }, + ) + + action = "Created" if created else "Updated" + print(f"{action} decision: {value} - {reason}") + return decision + + +def create_decision_from_id( + label_name: str, + example_id: int, + value: bool, + reason: str, + project_id: str = "docket-entry", +) -> LabelDecision: + """Create a decision from an example ID.""" + label = get_label(label_name, project_id) + model = label.project.get_search_model() + example = model.objects.get(id=example_id) + + return create_decision( + label_name=label_name, + text=example.text, + value=value, + reason=reason, + project_id=project_id, + ) + + +def annotate( + label_name: str, + example_id: int, + value: bool | str | None, + project_id: str = "docket-entry", +): + """ + Set a fast annotation on an example. + + value can be: + - True: positive + - False: negative + - "flag": exclude from trainset + - None: clear annotation + """ + label = get_label(label_name, project_id) + model = label.project.get_search_model() + example = model.objects.get(id=example_id) + example.set_annotation(label, value) + print(f"Set annotation {value} on example {example_id}") + + +def create_heuristic( + label_name: str, + querystring: str, + is_minimal: bool = False, + is_likely: bool = False, + apply: bool = True, + project_id: str = "docket-entry", +) -> LabelHeuristic: + """Create and optionally apply a heuristic.""" + label = get_label(label_name, project_id) + + heuristic = LabelHeuristic.objects.create( + label=label, + querystring=querystring, + is_minimal=is_minimal, + is_likely=is_likely, + ) + + if apply: + print(f"Applying heuristic: {querystring}") + heuristic.apply() + label.refresh_from_db() + print(f"Matches: {heuristic.num_examples:,}") + print( + f"Buckets - Excluded: {label.num_excluded:,}, Neutral: {label.num_neutral:,}, Likely: {label.num_likely:,}" + ) + + return heuristic + + +# Quick aliases for common operations +def disagreements(label_name: str, page_size: int = 20, **kwargs): + """Find examples where models disagree.""" + return search_examples( + label_name, + review_disagreements=True, + page_size=page_size, + **kwargs, + ) + + +def neutral_examples(label_name: str, page_size: int = 20, **kwargs): + """Get examples from the neutral bucket (edge cases).""" + return search_examples( + label_name, + heuristic_bucket="neutral", + page_size=page_size, + **kwargs, + ) + + +def likely_examples(label_name: str, page_size: int = 20, **kwargs): + """Get examples from the likely bucket (probable positives).""" + return search_examples( + label_name, + heuristic_bucket="likely", + page_size=page_size, + **kwargs, + ) + + +def excluded_examples(label_name: str, page_size: int = 20, **kwargs): + """Get examples from the excluded bucket (probable negatives).""" + return search_examples( + label_name, + heuristic_bucket="excluded", + page_size=page_size, + **kwargs, + ) + + +def similar_to(label_name: str, text: str, page_size: int = 20, **kwargs): + """Find examples semantically similar to the given text.""" + return search_examples( + label_name, + semantic_sort=text, + page_size=page_size, + **kwargs, + ) diff --git a/clx/app/migrations/0004_labelfinetune_finetuned_at_and_more.py b/clx/app/migrations/0004_labelfinetune_finetuned_at_and_more.py new file mode 100644 index 0000000..c5e4b19 --- /dev/null +++ b/clx/app/migrations/0004_labelfinetune_finetuned_at_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 5.2.7 on 2026-01-15 18:50 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0003_labelfinetune'), + ] + + operations = [ + migrations.AddField( + model_name='labelfinetune', + name='finetuned_at', + field=models.DateTimeField(blank=True, null=True), + ), + migrations.AddField( + model_name='labelfinetune', + name='predicted_at', + field=models.DateTimeField(blank=True, null=True), + ), + ] diff --git a/clx/app/models.py b/clx/app/models.py index 0d37451..94a6b5a 100644 --- a/clx/app/models.py +++ b/clx/app/models.py @@ -9,8 +9,9 @@ from clx import label2slug from clx.llm import GEPAPredictor, SingleLabelPredictor, batch_embed, mesh_sort -from clx.ml import pipeline +from clx.ml import pipeline, training_run from clx.settings import CLX_HOME +from clx.utils import pd_save_or_append from .custom_heuristics import custom_heuristics from .search_utils import BaseModel, SearchDocumentModel @@ -273,6 +274,8 @@ def load_annos(self): flag_annos["value"] = None annos = pd.concat([pos_annos, neg_annos, flag_annos]) + if annos.empty: + return pd.DataFrame(columns=["text_hash", "text", "value"]) return annos def load_trainset(self): @@ -326,8 +329,9 @@ def update_trainset_pred_counts(self): data = self.load_trainset() if len(data) and "pred" in data.columns: data = data.dropna(subset=["pred"]) - self.trainset_num_positive_preds = data["pred"].sum() - self.trainset_num_negative_preds = (~data["pred"]).sum() + preds = data["pred"].astype(bool) + self.trainset_num_positive_preds = preds.sum() + self.trainset_num_negative_preds = (~preds).sum() else: self.trainset_num_positive_preds = 0 self.trainset_num_negative_preds = 0 @@ -379,6 +383,14 @@ def get_trainset_finetune_tag(self, config_name): ) return tag + @property + def finetune_tag(self): + tag, _ = LabelTag.objects.get_or_create( + name="ft", + label=self, + ) + return tag + @property def anno_true_tag(self): tag, _ = LabelTag.objects.get_or_create( @@ -521,6 +533,196 @@ def prepare_finetune( return train_data, eval_data, run_config + def train_finetune(self, config_name): + """Train a finetune model for this label.""" + train_data, eval_data, run_config = self.prepare_finetune(config_name) + + run = training_run(**run_config) + outputs = run.train(train_data, eval_data, overwrite=True, remote=True) + + data = pd.concat([train_data, eval_data]) + + pipe = self.get_finetune_run_pipe(config_name) + data["pred"] = pipe(data["text"].tolist(), batch_size=16) + data = data[data["pred"] == "yes"] + + tag = self.get_trainset_finetune_tag(config_name) + model = self.project.get_search_model() + example_ids = model.objects.filter( + text_hash__in=data["text_hash"].tolist() + ) + example_ids = example_ids.values_list("id", flat=True) + model.bulk_replace_tag(tag.id, example_ids) + + finetune, _ = LabelFinetune.objects.get_or_create( + label=self, config_name=config_name + ) + finetune.eval_results = outputs["results"] + finetune.finetuned_at = timezone.now() + finetune.save() + + return finetune + + def predict_finetune(self, batch_size=16, num_workers=64, force=False): + """Run finetune predictions across the entire corpus.""" + cache_path = self.data_dir / "finetune_predictions_cache.csv" + self.data_dir.mkdir(parents=True, exist_ok=True) + config_name = self.project.get_search_model().main_finetune_config + if config_name is None: + raise ValueError("Set main_finetune_config for this project") + + if force and cache_path.exists(): + cache_path.unlink() + + cached_ids = set() + if cache_path.exists(): + cached_data = pd.read_csv(cache_path) + cached_ids = set(cached_data["id"].unique().tolist()) + + model = self.project.get_search_model() + pipe = self.get_finetune_run_pipe(config_name) + + total_examples = model.objects.count() + outer_batch_size = 1024 * 500 + for batch in tqdm( + model.objects.batch_df("id", "text", batch_size=outer_batch_size), + desc=f"Predicting {config_name}", + total=total_examples // outer_batch_size, + ): + batch = batch[~batch["id"].isin(cached_ids)] + if len(batch) > 0: + batch["value"] = pipe( + batch["text"].tolist(), + batch_size=batch_size, + num_workers=num_workers, + max_length=768, + truncation=True, + ) + batch["value"] = batch["value"].apply(lambda x: x == "yes") + pd_save_or_append(batch[["id", "value"]], cache_path) + + if cache_path.exists(): + all_preds = pd.read_csv(cache_path) + positive_ids = all_preds[all_preds["value"]]["id"].tolist() + tag = self.finetune_tag + model.bulk_replace_tag(tag.id, positive_ids) + finetune = self.fintunes.filter(config_name=config_name).first() + if finetune: + finetune.predicted_at = timezone.now() + finetune.save() + cache_path.unlink() + + print( + f"Predictions complete: {len(positive_ids):,} positive out of {len(all_preds):,} total" + ) + + def update_all(self, num_threads=128, predict=False, force=False): + """Update all components that are out of date based on timestamps. + + Runs the full pipeline in order, but only steps that need updating: + 1. Resample trainset (if decisions newer than trainset) + 2. Fit predictor (if trainset newer than predictor) + 3. Run predictions (if predictor newer than predictions) + 4. Train finetunes (if predictions newer than finetunes) + 5. Run global corpus predictions (if predict is True and finetune newer than global predictions) + """ + missing = [] + if not self.heuristics.filter(is_minimal=True).exists(): + missing.append("at least one minimal heuristic") + if not self.heuristics.filter(is_likely=True).exists(): + missing.append("at least one likely heuristic") + if not self.decisions.filter(value=True).exists(): + missing.append("at least one positive decision") + if not self.decisions.filter(value=False).exists(): + missing.append("at least one negative decision") + + if missing: + print("Cannot run update_all - missing required setup:") + for item in missing: + print(f" - {item}") + return + + model = self.project.get_search_model() + finetune_configs = list(model.finetune_configs.keys()) + + # Get latest decision timestamp + latest_decision = self.decisions.order_by("-updated_at").first() + latest_decision_at = ( + latest_decision.updated_at if latest_decision else None + ) + + # Step 1: Resample trainset if decisions are newer + if force or ( + latest_decision_at + and ( + not self.trainset_updated_at + or latest_decision_at > self.trainset_updated_at + ) + ): + print("Resampling trainset...") + self.update_trainset() + self.refresh_from_db() + + # Step 2: Fit predictor if trainset is newer + if force or ( + self.trainset_updated_at + and ( + not self.predictor_updated_at + or self.trainset_updated_at > self.predictor_updated_at + ) + ): + print("Fitting predictor...") + self.fit_predictor() + self.refresh_from_db() + + # Step 3: Run predictions if predictor is newer + if force or ( + self.predictor_updated_at + and ( + not self.trainset_predictions_updated_at + or self.predictor_updated_at + > self.trainset_predictions_updated_at + ) + ): + print("Running predictions...") + self.update_trainset_preds(num_threads=num_threads) + self.refresh_from_db() + + # Step 4: Train finetunes if predictions are newer + for config_name in finetune_configs: + finetune = self.fintunes.filter(config_name=config_name).first() + finetuned_at = finetune.finetuned_at if finetune else None + + if force or ( + self.trainset_predictions_updated_at + and ( + not finetuned_at + or self.trainset_predictions_updated_at > finetuned_at + ) + ): + print(f"Training finetune: {config_name}...") + self.train_finetune(config_name) + + # Step 5: Run global corpus predictions if finetune is newer + if predict: + ft = self.fintunes.filter( + config_name=self.project.get_search_model().main_finetune_config + ).first() + if ft and ( + force + or ( + ft.finetuned_at + and ( + not ft.predicted_at + or ft.finetuned_at > ft.predicted_at + ) + ) + ): + print(f"Running global predictions: {ft.config_name}...") + self.predict_finetune(force=force) + + print("Update complete!") + class Meta: unique_together = ("project", "name") @@ -710,6 +912,8 @@ class LabelFinetune(BaseModel): ) config_name = models.CharField(max_length=255) eval_results = models.JSONField(null=True, blank=True) + finetuned_at = models.DateTimeField(null=True, blank=True) + predicted_at = models.DateTimeField(null=True, blank=True) class DocketEntry(SearchDocumentModel): @@ -736,6 +940,7 @@ class DocketEntry(SearchDocumentModel): }, }, } + main_finetune_config = "main" id = models.BigIntegerField(primary_key=True) recap_id = models.BigIntegerField(unique=True) diff --git a/clx/app/search_utils.py b/clx/app/search_utils.py index a3de4d9..712ba43 100644 --- a/clx/app/search_utils.py +++ b/clx/app/search_utils.py @@ -351,6 +351,7 @@ class SearchDocumentModel(BaseModel, metaclass=SearchDocumentModelBase): project_id = None finetune_configs = {} + main_finetune_config = None id = models.BigIntegerField(primary_key=True) text = models.TextField() diff --git a/clx/app/templates/search/results_tab.html b/clx/app/templates/search/results_tab.html index ccee5b9..7cec4d4 100644 --- a/clx/app/templates/search/results_tab.html +++ b/clx/app/templates/search/results_tab.html @@ -140,6 +140,15 @@ > +
+ +
diff --git a/clx/cli/cleanup.py b/clx/cli/cleanup.py index cea0169..dc420c4 100644 --- a/clx/cli/cleanup.py +++ b/clx/cli/cleanup.py @@ -3,7 +3,9 @@ @click.command() -def cleanup(): +@click.option("--predict", is_flag=True, help="Run global corpus predictions") +@click.option("--label-name", "--label", help="Label name") +def cleanup(predict, label_name): """Sync app data.""" from clx.models import LabelHeuristic, Project @@ -39,3 +41,8 @@ def cleanup(): ): print(f"Updating heuristic {heuristic.name}...") heuristic.apply() + + for label in project.labels.all(): + if label_name is None or label.name == label_name: + print(f"Updating label {label.name}...") + label.update_all(predict=predict) diff --git a/clx/cli/train.py b/clx/cli/train.py index b5921eb..c208463 100644 --- a/clx/cli/train.py +++ b/clx/cli/train.py @@ -1,7 +1,4 @@ import click -import pandas as pd - -from clx.ml import training_run @click.command() @@ -9,31 +6,8 @@ @click.argument("label_name") @click.argument("config_name") def train(project_id, label_name, config_name): - """Train a label.""" - from clx.models import Label, LabelFinetune, Project - - project = Project.objects.get(id=project_id) - label = Label.objects.get(project=project, name=label_name) - train_data, eval_data, run_config = label.prepare_finetune(config_name) - - run = training_run(**run_config) - outputs = run.train(train_data, eval_data, overwrite=True, remote=True) - - data = pd.concat([train_data, eval_data]) - - pipe = label.get_finetune_run_pipe(config_name) - data["pred"] = pipe(data["text"].tolist(), batch_size=16) - data = data[data["pred"] == "yes"] + """Train a finetune model for a label.""" + from clx.models import Label - tag = label.get_trainset_finetune_tag(config_name) - model = project.get_search_model() - example_ids = model.objects.filter( - text_hash__in=data["text_hash"].tolist() - ) - example_ids = example_ids.values_list("id", flat=True) - model.bulk_replace_tag(tag.id, example_ids) - finetune, _ = LabelFinetune.objects.get_or_create( - label=label, config_name=config_name - ) - finetune.eval_results = outputs["results"] - finetune.save() + label = Label.objects.get(project_id=project_id, name=label_name) + label.train_finetune(config_name) diff --git a/clx/ml/remote_pipeline.py b/clx/ml/remote_pipeline.py index d117c8a..c2e1ca1 100644 --- a/clx/ml/remote_pipeline.py +++ b/clx/ml/remote_pipeline.py @@ -52,7 +52,7 @@ def handle_megabatch(megabatch: list): ) outputs = response.json() except Exception as e: - errors.append(str(e)) + outputs = {"status": "FAILED", "error": str(e)} if outputs["status"] != "COMPLETED": errors.append(outputs) else: