diff --git a/clx/app/migrations/0005_remove_labeltrainsetexample_reason_and_more.py b/clx/app/migrations/0005_remove_labeltrainsetexample_reason_and_more.py new file mode 100644 index 0000000..73b31f7 --- /dev/null +++ b/clx/app/migrations/0005_remove_labeltrainsetexample_reason_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 5.2.7 on 2026-02-13 21:12 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0004_labelfinetune_finetuned_at_and_more'), + ] + + operations = [ + migrations.RemoveField( + model_name='labeltrainsetexample', + name='reason', + ), + migrations.AddField( + model_name='labeltrainsetexample', + name='decision', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='trainset_examples', to='app.labeldecision'), + ), + ] diff --git a/clx/app/migrations/0006_remove_labeltrainsetexample_decision_and_more.py b/clx/app/migrations/0006_remove_labeltrainsetexample_decision_and_more.py new file mode 100644 index 0000000..e8dad74 --- /dev/null +++ b/clx/app/migrations/0006_remove_labeltrainsetexample_decision_and_more.py @@ -0,0 +1,22 @@ +# Generated by Django 5.2.7 on 2026-02-16 22:16 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0005_remove_labeltrainsetexample_reason_and_more'), + ] + + operations = [ + migrations.RemoveField( + model_name='labeltrainsetexample', + name='decision', + ), + migrations.AddField( + model_name='labeltrainsetexample', + name='reason', + field=models.TextField(blank=True, null=True), + ), + ] diff --git a/clx/app/migrations/0007_remove_label_inference_model_and_more.py b/clx/app/migrations/0007_remove_label_inference_model_and_more.py new file mode 100644 index 0000000..e57578e --- /dev/null +++ b/clx/app/migrations/0007_remove_label_inference_model_and_more.py @@ -0,0 +1,48 @@ +# Generated by Django 
5.2.7 on 2026-02-17 15:15 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0006_remove_labeltrainsetexample_decision_and_more'), + ] + + operations = [ + migrations.RemoveField( + model_name='label', + name='inference_model', + ), + migrations.RemoveField( + model_name='label', + name='predictor_data', + ), + migrations.RemoveField( + model_name='label', + name='predictor_updated_at', + ), + migrations.RemoveField( + model_name='label', + name='teacher_model', + ), + migrations.RemoveField( + model_name='label', + name='trainset_examples_per_heuristic_bucket', + ), + migrations.AlterField( + model_name='label', + name='trainset_num_excluded', + field=models.IntegerField(default=50), + ), + migrations.AlterField( + model_name='label', + name='trainset_num_likely', + field=models.IntegerField(default=50), + ), + migrations.AlterField( + model_name='label', + name='trainset_num_neutral', + field=models.IntegerField(default=50), + ), + ] diff --git a/clx/app/migrations/0008_alter_label_trainset_num_decision_neighbors_and_more.py b/clx/app/migrations/0008_alter_label_trainset_num_decision_neighbors_and_more.py new file mode 100644 index 0000000..eecf460 --- /dev/null +++ b/clx/app/migrations/0008_alter_label_trainset_num_decision_neighbors_and_more.py @@ -0,0 +1,33 @@ +# Generated by Django 5.2.7 on 2026-02-18 14:03 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0007_remove_label_inference_model_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='label', + name='trainset_num_decision_neighbors', + field=models.IntegerField(default=20), + ), + migrations.CreateModel( + name='LabelQuerystring', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', 
models.DateTimeField(auto_now=True)), + ('querystring', models.TextField()), + ('num_examples', models.IntegerField(default=30)), + ('label', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='querystrings', to='app.label')), + ], + options={ + 'abstract': False, + }, + ), + ] diff --git a/clx/app/migrations/0009_alter_labelquerystring_num_examples_and_more.py b/clx/app/migrations/0009_alter_labelquerystring_num_examples_and_more.py new file mode 100644 index 0000000..ac46d73 --- /dev/null +++ b/clx/app/migrations/0009_alter_labelquerystring_num_examples_and_more.py @@ -0,0 +1,22 @@ +# Generated by Django 5.2.7 on 2026-02-18 14:36 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0008_alter_label_trainset_num_decision_neighbors_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='labelquerystring', + name='num_examples', + field=models.IntegerField(default=50), + ), + migrations.AlterUniqueTogether( + name='labelquerystring', + unique_together={('label', 'querystring')}, + ), + ] diff --git a/clx/app/migrations/0010_labeldecision_added_to_sample_and_more.py b/clx/app/migrations/0010_labeldecision_added_to_sample_and_more.py new file mode 100644 index 0000000..09b1f56 --- /dev/null +++ b/clx/app/migrations/0010_labeldecision_added_to_sample_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 5.2.7 on 2026-02-18 15:11 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0009_alter_labelquerystring_num_examples_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='labeldecision', + name='added_to_sample', + field=models.BooleanField(default=False), + ), + migrations.AddField( + model_name='labelquerystring', + name='added_to_sample', + field=models.BooleanField(default=False), + ), + ] diff --git a/clx/app/models.py b/clx/app/models.py index 1147d29..5eefa1e 100644 --- 
a/clx/app/models.py +++ b/clx/app/models.py @@ -1,3 +1,6 @@ +import time +from concurrent.futures import ThreadPoolExecutor, as_completed + import lmdb import numpy as np import pandas as pd @@ -8,7 +11,8 @@ from tqdm import tqdm from clx import label2slug -from clx.llm import GEPAPredictor, SingleLabelPredictor, batch_embed, mesh_sort +from clx.llm import batch_embed, mesh_sort +from clx.llm.anno_agent import AnnoAgent from clx.ml import pipeline, training_run from clx.settings import CLX_HOME from clx.utils import pd_save_or_append @@ -84,47 +88,18 @@ class Label(BaseModel): Project, on_delete=models.CASCADE, related_name="labels" ) name = models.CharField(max_length=255) + instructions = models.TextField(null=True, blank=True) # Sample counts num_excluded = models.IntegerField(default=0) num_neutral = models.IntegerField(default=0) num_likely = models.IntegerField(default=0) - # Predictor config - llm_models = [ - ("GPT-5 Mini", "openai/gpt-5-mini"), - ("GPT-5", "openai/gpt-5"), - ("Gemini 2.5 Flash Lite", "gemini/gemini-2.5-flash-lite"), - ("Gemini 2.5 Flash", "gemini/gemini-2.5-flash"), - ("Gemini 2.5 Pro", "gemini/gemini-2.5-pro"), - ("Qwen 235B-A22B", "bedrock/qwen.qwen3-235b-a22b-2507-v1:0"), - ( - "Claude Sonnet 4.5", - "bedrock/us.anthropic.claude-sonnet-4-5-20250929-v1:0", - ), - ] - default_inference_model = "openai/gpt-5-mini" - default_teacher_model = "openai/gpt-5" - instructions = models.TextField(null=True, blank=True) - inference_model = models.CharField( - max_length=255, - choices=llm_models, - default=default_inference_model, - ) - teacher_model = models.CharField( - max_length=255, - choices=llm_models, - default=default_teacher_model, - ) - predictor_data = models.JSONField(null=True, blank=True) - predictor_updated_at = models.DateTimeField(null=True, blank=True) - # Trainset config - trainset_examples_per_heuristic_bucket = models.IntegerField(default=1000) - trainset_num_excluded = models.IntegerField(default=1000) - trainset_num_neutral = 
models.IntegerField(default=1000) - trainset_num_likely = models.IntegerField(default=1000) - trainset_num_decision_neighbors = models.IntegerField(default=50) + trainset_num_excluded = models.IntegerField(default=50) + trainset_num_neutral = models.IntegerField(default=50) + trainset_num_likely = models.IntegerField(default=50) + trainset_num_decision_neighbors = models.IntegerField(default=20) trainset_updated_at = models.DateTimeField(null=True, blank=True) trainset_predictions_updated_at = models.DateTimeField( null=True, blank=True @@ -173,29 +148,83 @@ def likely_query(self, queryset=None): return queryset.none() return queryset.tags(any=minimal_tag_ids).tags(any=likely_tag_ids) + def get_minimal_fn(self): + minimal_fns = [ + x.heuristic.get_apply_fn() + for x in LabelTag.objects.filter( + label=self, heuristic__is_minimal=True + ) + ] + + def minimal_fn(text): + return any(f(text) for f in minimal_fns) + + return minimal_fn + + def get_likely_fn(self): + likely_fns = [ + x.heuristic.get_apply_fn() + for x in LabelTag.objects.filter( + label=self, heuristic__is_likely=True + ) + ] + + def likely_fn(text): + return any(f(text) for f in likely_fns) + + return likely_fn + def update_counts(self): self.num_excluded = self.excluded_query().count() self.num_likely = self.likely_query().count() self.num_neutral = self.neutral_query().count() self.save() - def sample_trainset(self, ratio=1): - """Sample trainset examples.""" - data = [] + def update_trainset(self): + data = self.load_trainset() + model = self.project.get_search_model() + + # Reset predictions for existing anno disagreements + needs_corrections = data[ + data["anno_value"].notna() + & data["pred"].notna() + & (data["anno_value"] != data["pred"]) + ]["text_hash"].tolist() + if len(needs_corrections): + LabelTrainsetExample.objects.filter( + label=self, text_hash__in=needs_corrections + ).update(pred=None, reason=None) + + new_ids = [] + # Sample decision neighbors model = 
self.project.get_search_model() for decision in self.decisions.all(): - embedding = ( - model.objects.filter(text_hash=decision.text_hash) - .first() - .embedding.to_list() - ) - decision_examples = model.objects.search( - semantic_sort=embedding, - page_size=int(self.trainset_num_decision_neighbors * ratio), - ) - data += [{"id": x["id"]} for x in decision_examples["data"]] + if not decision.added_to_sample: + embedding = ( + model.objects.filter(text_hash=decision.text_hash) + .first() + .embedding.to_list() + ) + decision_examples = model.objects.search( + semantic_sort=embedding, + page_size=self.trainset_num_decision_neighbors, + ) + new_ids += [x["id"] for x in decision_examples["data"]] + decision.save(added_to_sample=True) + + # Sample on querystring samplers + for querystring in self.querystrings.all(): + if not querystring.added_to_sample: + querystring_examples = model.objects.search( + params={"querystring": querystring.querystring}, + page_size=querystring.num_examples, + sort=["shuffle_sort", "id"], + ) + new_ids += [x["id"] for x in querystring_examples["data"]] + querystring.save(added_to_sample=True) + # Mesh sort helper def apply_mesh_sort(queryset, n_examples): """Select 10x the number of examples and take most diverse 10%""" cluster_ks = [10, 10] @@ -206,51 +235,69 @@ def apply_mesh_sort(queryset, n_examples): np.array(data["embedding"].tolist()), cluster_ks ) data = data.sort_values(by="sort").head(n_examples) - return data[["id"]].to_dict("records") + return data["id"].tolist() - # Sample heuristic buckets - data += apply_mesh_sort( - self.excluded_query(), int(self.trainset_num_excluded * ratio) + # Sample from heuristic buckets + num_excluded = self.trainset_num_excluded - len( + data[data["bucket"] == "excluded"] ) - data += apply_mesh_sort( - self.neutral_query(), int(self.trainset_num_neutral * ratio) + num_neutral = self.trainset_num_neutral - len( + data[data["bucket"] == "neutral"] ) - data += apply_mesh_sort( - self.likely_query(), 
int(self.trainset_num_likely * ratio) + num_likely = self.trainset_num_likely - len( + data[data["bucket"] == "likely"] ) - data = pd.DataFrame(data).drop_duplicates(subset="id").sample(frac=1) - return data["id"].tolist() - - def update_trainset(self): - self.trainset_examples.all().delete() - model = self.project.get_search_model() - - train_ids = self.sample_trainset(ratio=1) - train_examples = model.objects.filter(id__in=train_ids).values( - "text", "text_hash" + if num_excluded > 0: + new_ids += apply_mesh_sort(self.excluded_query(), num_excluded) + if num_neutral > 0: + new_ids += apply_mesh_sort(self.neutral_query(), num_neutral) + if num_likely > 0: + new_ids += apply_mesh_sort(self.likely_query(), num_likely) + + # Get new examples + cols = ["text", "text_hash"] + new_examples = pd.DataFrame( + model.objects.filter(id__in=new_ids).values(*cols), + columns=cols, ) - train_examples = pd.DataFrame(train_examples) + new_examples = new_examples[ + ~new_examples["text_hash"].isin(data["text_hash"]) + ] + new_examples = new_examples.drop_duplicates(subset="text_hash") + new_examples = new_examples.sample(frac=1) + + # Make train/eval split + split = int(len(new_examples) * 0.8) + train_examples = new_examples.head(split) train_examples["split"] = "train" - - eval_ids = self.sample_trainset(ratio=0.2) - eval_examples = model.objects.filter(id__in=eval_ids).values( - "text", "text_hash" - ) - eval_examples = pd.DataFrame(eval_examples) + eval_examples = new_examples.tail(len(new_examples) - split) eval_examples["split"] = "eval" + new_examples = pd.concat([train_examples, eval_examples]) + + new_examples = pd.concat([train_examples, eval_examples]) - trainset = pd.concat([train_examples, eval_examples]) - trainset = trainset.drop_duplicates(subset="text_hash") - rows = trainset.to_dict("records") + # Add to trainset + rows = new_examples.to_dict("records") LabelTrainsetExample.objects.bulk_create( [LabelTrainsetExample(label_id=self.id, **row) for row in rows], 
batch_size=1000, ) self.sync_trainset_tags() + self.update_trainset_pred_counts() self.trainset_updated_at = timezone.now() self.save() + def reset_trainset(self): + self.trainset_examples.all().delete() + self.decisions.all().update(added_to_sample=False) + self.querystrings.all().update(added_to_sample=False) + self.sync_trainset_tags() + self.update_trainset_pred_counts() + self.trainset_updated_at = None + self.trainset_predictions_updated_at = None + self.save() + def load_annos(self): project = self.project search_model = project.get_search_model() @@ -279,44 +326,73 @@ def load_annos(self): return annos def load_trainset(self): + trainset_hashes = list( + self.trainset_examples.values_list("text_hash", flat=True) + ) + annos = self.load_annos() + + missing_annos = annos[~annos["text_hash"].isin(trainset_hashes)] + missing_annos = missing_annos.drop_duplicates(subset="text_hash") + if len(missing_annos): + updates = [ + LabelTrainsetExample( + label_id=self.id, + text_hash=row["text_hash"], + text=row["text"], + split="train", + ) + for row in missing_annos.to_dict("records") + ] + LabelTrainsetExample.objects.bulk_create(updates, batch_size=1000) + + cols = ["text_hash", "text", "split", "pred", "reason"] data = pd.DataFrame( - self.trainset_examples.all().values( - "text_hash", "text", "split", "pred", "reason" - ) + self.trainset_examples.all().values(*cols), + columns=cols, ) - annos = self.load_annos() flagged_hashes = annos[annos["value"].isna()]["text_hash"].tolist() annos = annos[~annos["value"].isna()] - annos = annos.rename(columns={"value": "pred"}) - annos["split"] = "train" - data = pd.concat([data, annos]) + annos = annos[["text_hash", "value"]].rename( + columns={"value": "anno_value"} + ) + data = data.merge(annos, on="text_hash", how="left") - if len(data) and "text_hash" in data.columns: - data = data.drop_duplicates(subset="text_hash", keep="last") - data = data[~data["text_hash"].isin(flagged_hashes)] + data["value"] = 
data["anno_value"].fillna(data["pred"]) + data.loc[data["text_hash"].isin(flagged_hashes), "value"] = None data = data.sample(frac=1, random_state=42) data = data.reset_index(drop=True) - return data - def update_trainset_preds(self, num_threads=128): - predictor = self.predictor - trainset = self.load_trainset() - preds = predictor.predict( - trainset["text"].tolist(), num_threads=num_threads + minimal_fn = self.get_minimal_fn() + likely_fn = self.get_likely_fn() + data["bucket"] = data["text"].apply( + lambda x: "excluded" + if not minimal_fn(x) + else "likely" + if likely_fn(x) + else "neutral" ) - trainset["pred"] = [x.value for x in preds] - trainset["reason"] = [x.reason for x in preds] + return data + + def update_trainset_preds(self, num_threads=32): + data = self.load_trainset() + data = data[data["pred"].isna()] + texts = data["text"].tolist() + preds = self.batch_predict(texts, num_threads=num_threads) + data["pred"] = [x.get("value") for x in preds] + data["reason"] = [x.get("reason") for x in preds] examples = self.trainset_examples.all() examples = {e.text_hash: e for e in examples} - for row in trainset.to_dict("records"): + updates = [] + for row in data.to_dict("records"): if row["text_hash"] in examples: example = examples[row["text_hash"]] example.pred = row["pred"] example.reason = row["reason"] + updates.append(example) LabelTrainsetExample.objects.bulk_update( - list(examples.values()), + updates, fields=["pred", "reason"], batch_size=1000, ) @@ -337,20 +413,40 @@ def update_trainset_pred_counts(self): self.trainset_num_negative_preds = 0 self.save() - def get_new_predictor(self): - return SingleLabelPredictor( - label_name=self.name, - project_instructions=self.project.instructions, - label_instructions=self.instructions, - model=self.inference_model, - ) + def load_predictor(self): + args = { + "label_name": self.name, + "project_instructions": self.project.instructions, + "label_instructions": self.instructions, + "decisions": 
self.decisions.values("text", "value", "reason"), + } - @property - def predictor(self): - if self.predictor_data is None: - return self.get_new_predictor() - else: - return GEPAPredictor.from_config(self.predictor_data) + def predict_fn(text: str): + for _ in range(3): + try: + agent = AnnoAgent(**args) + anno = agent(text) + return { + "status": "success", + "value": anno.value, + "reason": anno.reason, + } + except Exception as e: + print(f"Error predicting {text}: {e}") + time.sleep(5) + return {"status": "error"} + + return predict_fn + + def batch_predict(self, texts: list[str], num_threads: int = 32): + predictor = self.load_predictor() + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(predictor, text) for text in texts] + for _ in tqdm( + as_completed(futures), total=len(futures), desc="Predicting" + ): + pass + return [future.result() for future in futures] @property def trainset_train_tag(self): @@ -467,23 +563,6 @@ def sync_trainset_pred_tags(self): pos_ids = [] model.bulk_replace_tag(self.trainset_pred_tag, pos_ids) - def fit_predictor(self): - predictor = self.get_new_predictor() - examples = self.decisions.values("text", "value", "reason") - predictor.fit( - examples, - num_threads=8, - reflection_lm={ - "model": self.teacher_model, - "temperature": 1.0, - "max_tokens": 32000, - }, - ) - self.predictor_data = predictor.config - self.predictor_updated_at = timezone.now() - self.save() - print(predictor.last_cost) - def get_finetune_run_name(self, config_name): return f"{self.project_id}__{label2slug(self.name)}__{config_name}" @@ -500,8 +579,8 @@ def prepare_finetune( data = self.load_trainset() data = data.sample(frac=1, random_state=42) data = ( - data[["text_hash", "text", "pred", "split"]] - .rename(columns={"pred": "label"}) + data[["text_hash", "text", "value", "split"]] + .rename(columns={"value": "label"}) .dropna() ) data["label"] = data["label"].apply(lambda x: "yes" if x else "no") @@ -630,10 +709,9 
@@ def update_all(self, num_threads=128, predict=False, force=False): Runs the full pipeline in order, but only steps that need updating: 1. Resample trainset (if decisions newer than trainset) - 2. Fit predictor (if trainset newer than predictor) - 3. Run predictions (if predictor newer than predictions) - 4. Train finetunes (if predictions newer than finetunes) - 5. Run global corpus predictions (if predict is True and finetune newer than global predictions) + 2. Run predictions (if trainset newer than predictions) + 3. Train finetunes (if predictions newer than finetunes) + 4. Run global corpus predictions (if predict is True and finetune newer than global predictions) """ missing = [] if not self.heuristics.filter(is_minimal=True).exists(): @@ -672,32 +750,20 @@ def update_all(self, num_threads=128, predict=False, force=False): self.update_trainset() self.refresh_from_db() - # Step 2: Fit predictor if trainset is newer + # Step 2: Run predictions if trainset is newer if force or ( self.trainset_updated_at - and ( - not self.predictor_updated_at - or self.trainset_updated_at > self.predictor_updated_at - ) - ): - print("Step 2: Fitting predictor") - self.fit_predictor() - self.refresh_from_db() - - # Step 3: Run predictions if predictor is newer - if force or ( - self.predictor_updated_at and ( not self.trainset_predictions_updated_at - or self.predictor_updated_at + or self.trainset_updated_at > self.trainset_predictions_updated_at ) ): - print("Step 3: Running predictions") + print("Step 2: Running predictions") self.update_trainset_preds(num_threads=num_threads) self.refresh_from_db() - # Step 4: Train finetunes if predictions are newer + # Step 3: Train finetunes if predictions are newer for config_name in finetune_configs: finetune = self.fintunes.filter(config_name=config_name).first() finetuned_at = finetune.finetuned_at if finetune else None @@ -709,10 +775,10 @@ def update_all(self, num_threads=128, predict=False, force=False): or 
self.trainset_predictions_updated_at > finetuned_at ) ): - print(f"Step 4: Training finetune: {config_name}") + print(f"Step 3: Training finetune: {config_name}") self.train_finetune(config_name) - # Step 5: Run global corpus predictions if finetune is newer + # Step 4: Run global corpus predictions if finetune is newer if predict: ft = self.fintunes.filter( config_name=self.project.get_search_model().main_finetune_config @@ -727,7 +793,7 @@ def update_all(self, num_threads=128, predict=False, force=False): ) ) ): - print("Step 5: Running global predictions") + print("Step 4: Running global predictions") self.predict_finetune(force=force) print("Update complete!") @@ -770,11 +836,34 @@ class LabelDecision(BaseModel): text = models.TextField(null=True, blank=True) value = models.BooleanField() reason = models.TextField() + added_to_sample = models.BooleanField(default=False) + + def save(self, *args, added_to_sample=False, **kwargs): + self.added_to_sample = added_to_sample + super().save(*args, **kwargs) class Meta: unique_together = ("label", "text_hash") +class LabelQuerystring(BaseModel): + """Model for label querystrings.""" + + label = models.ForeignKey( + Label, on_delete=models.CASCADE, related_name="querystrings" + ) + querystring = models.TextField() + num_examples = models.IntegerField(default=50) + added_to_sample = models.BooleanField(default=False) + + def save(self, *args, added_to_sample=False, **kwargs): + self.added_to_sample = added_to_sample + super().save(*args, **kwargs) + + class Meta: + unique_together = ("label", "querystring") + + class LabelHeuristic(BaseModel): """Model for label heuristics.""" @@ -930,15 +1019,6 @@ class DocketEntry(SearchDocumentModel): project_id = "docket-entry" finetune_configs = { - "underfit": { - "base_model_name": "answerdotai/ModernBERT-base", - "training_args": { - "num_train_epochs": 1, - "learning_rate": 5e-5, - "warmup_ratio": 0.05, - "bf16": True, - }, - }, "main": { "base_model_name": 
"answerdotai/ModernBERT-base", "training_args": { diff --git a/clx/app/templates/search/heuristics_tab.html b/clx/app/templates/search/heuristics_tab.html index 90d6451..3fc8932 100644 --- a/clx/app/templates/search/heuristics_tab.html +++ b/clx/app/templates/search/heuristics_tab.html @@ -135,5 +135,37 @@

Samplers

+
+ +
\ No newline at end of file diff --git a/clx/app/templates/search/index.html b/clx/app/templates/search/index.html index 209244f..c2aa52c 100644 --- a/clx/app/templates/search/index.html +++ b/clx/app/templates/search/index.html @@ -454,6 +454,25 @@ } }, + async saveQuerystring(labelId, querystring, numExamples) { + await this.apiFetch('labels/save-querystring/', { + method: 'POST', + body: { label_id: labelId, querystring: querystring, num_examples: numExamples }, + jobName: 'Saving Querystring', + }); + await this.getLabels(); + this.activeTab = 'heuristics'; + }, + + async deleteQuerystring(labelId, querystring) { + await this.apiFetch('labels/delete-querystring/', { + method: 'POST', + body: { label_id: labelId, querystring: querystring }, + jobName: 'Deleting Querystring', + }); + await this.getLabels(); + }, + // Tags async getTags() { const data = await this.apiFetch('tags/', { method: 'GET' }); @@ -616,20 +635,6 @@ await this.getLabels(); await this.getTags(); }, - async fitPredictor(labelId) { - if (!labelId) return; - await this.apiFetch('predictor/fit/', { - method: 'POST', - body: { - label_id: labelId, - inference_model: this.predictorConfig.inferenceModel, - teacher_model: this.predictorConfig.teacherModel, - }, - jobName: 'Fitting Predictor', - }); - await this.getLabels(); - await this.getTags(); - }, // Finetunes async getFinetunes(labelId = null) { diff --git a/clx/app/templates/search/predictor_tab.html b/clx/app/templates/search/predictor_tab.html index 2eb1e13..5a722f6 100644 --- a/clx/app/templates/search/predictor_tab.html +++ b/clx/app/templates/search/predictor_tab.html @@ -2,22 +2,12 @@ class="flex-1 min-h-0 overflow-y-auto pb-32 px-4 -mx-4" x-show="activeTab === 'predictor'" x-cloak x-data="{ - getPredictorStatus() { - const latest_decision_updated_at = this.decisions && Object.values(this.decisions).length ? 
Math.max(...Object.values(this.decisions).map(d => new Date(d.updated_at))) : null; - const predictor_updated_at = new Date(this.labels?.[activeLabelId]?.predictor_updated_at); - if (!latest_decision_updated_at) return 'no_decisions'; - if (!predictor_updated_at) return 'no_predictor'; - if (latest_decision_updated_at > predictor_updated_at) return 'predictor_out_of_date'; - return 'up_to_date'; - }, getTrainsetStatus() { const trainset_updated_at = new Date(this.labels?.[activeLabelId]?.trainset_updated_at); const trainset_predictions_updated_at = new Date(this.labels?.[activeLabelId]?.trainset_predictions_updated_at); - const predictor_updated_at = new Date(this.labels?.[activeLabelId]?.predictor_updated_at); if (!trainset_updated_at) return 'no_trainset'; if (!trainset_predictions_updated_at) return 'no_preds'; if (trainset_updated_at > trainset_predictions_updated_at) return 'preds_out_of_date'; - if (predictor_updated_at > trainset_predictions_updated_at) return 'preds_out_of_date'; return 'up_to_date'; }, isFitting() { @@ -169,61 +159,11 @@

Trainset

Predictor

-
- - - Fitting... - -
- - -
-
- -
-
- - -
-
- - -
-
-
@@ -296,27 +236,6 @@

Predictor

-
-
- -
-

Instructions generated from the DSPy optimization process.

-
-

-                        
-
-
- -
-
diff --git a/clx/app/templates/search/search_panel.html b/clx/app/templates/search/search_panel.html index d40aeb1..5da852f 100644 --- a/clx/app/templates/search/search_panel.html +++ b/clx/app/templates/search/search_panel.html @@ -49,7 +49,13 @@

Search

placeholder="Enter query…" class="w-full p-2 rounded-md bg-white ring-1 ring-gray-300 focus:ring-blue-500" /> -
+
+ Add as sampler /labels/save-querystring/", + views.labels_save_querystring_endpoint, + name="labels-save-querystring-endpoint", + ), + path( + "api/project//labels/delete-querystring/", + views.labels_delete_querystring_endpoint, + name="labels-delete-querystring-endpoint", + ), path( "api/project//tags/", views.tags_endpoint, @@ -97,11 +107,6 @@ views.predictor_update_trainset_preds_endpoint, name="predictor-update-trainset-preds-endpoint", ), - path( - "api/project//predictor/fit/", - views.predictor_fit_endpoint, - name="predictor-fit-endpoint", - ), path( "api/project//finetunes/", views.finetunes_endpoint, diff --git a/clx/app/views.py b/clx/app/views.py index 8e76beb..5881b27 100644 --- a/clx/app/views.py +++ b/clx/app/views.py @@ -12,6 +12,7 @@ LabelDecision, LabelFinetune, LabelHeuristic, + LabelQuerystring, LabelTag, Project, ) @@ -66,17 +67,13 @@ def project_update_instructions_endpoint(request, project_id): @require_GET def labels_endpoint(request, project_id): project = Project.objects.get(id=project_id) - labels_qs = Label.objects.filter(project=project).values( + labels_query = Label.objects.filter(project=project).values( "id", "name", "num_excluded", "num_neutral", "num_likely", "instructions", - "inference_model", - "teacher_model", - "predictor_data", - "predictor_updated_at", "trainset_num_excluded", "trainset_num_neutral", "trainset_num_likely", @@ -86,7 +83,18 @@ def labels_endpoint(request, project_id): "trainset_predictions_updated_at", "trainset_updated_at", ) - labels = {row["id"]: row for row in labels_qs} + labels = {row["id"]: {**row, "querystrings": []} for row in labels_query} + all_qs = ( + LabelQuerystring.objects.filter(label__project=project) + .values( + "label_id", + "querystring", + "num_examples", + ) + .order_by("created_at") + ) + for qs in all_qs: + labels[qs["label_id"]]["querystrings"].append(qs) return JsonResponse({"labels": labels}) @@ -102,6 +110,34 @@ def labels_update_instructions_endpoint(request, project_id): 
return JsonResponse({"ok": True}) +@csrf_exempt +@require_POST +def labels_save_querystring_endpoint(request, project_id): + payload = {} if request.body is None else json.loads(request.body) + label_id = payload.get("label_id") + querystring = payload.get("querystring") + num_examples = payload.get("num_examples", 50) + label = Label.objects.get(id=label_id, project_id=project_id) + qs, _ = LabelQuerystring.objects.get_or_create( + label=label, querystring=querystring + ) + qs.num_examples = int(num_examples) + qs.save() + return JsonResponse({"ok": True}) + + +@csrf_exempt +@require_POST +def labels_delete_querystring_endpoint(request, project_id): + payload = {} if request.body is None else json.loads(request.body) + label_id = payload.get("label_id") + querystring = payload.get("querystring") + label = Label.objects.get(id=label_id, project_id=project_id) + qs = LabelQuerystring.objects.get(label=label, querystring=querystring) + qs.delete() + return JsonResponse({"ok": True}) + + # Tags Endpoints @require_GET def tags_endpoint(request, project_id): @@ -328,20 +364,6 @@ def predictor_update_trainset_preds_endpoint(request, project_id): return JsonResponse({"ok": True}) -@csrf_exempt -@require_POST -def predictor_fit_endpoint(request, project_id): - payload = {} if request.body is None else json.loads(request.body) - label_id = payload.get("label_id") - assert label_id, "label_id is required" - label = Label.objects.get(id=label_id) - label.inference_model = payload.get("inference_model") - label.teacher_model = payload.get("teacher_model") - label.save() - label.fit_predictor() - return JsonResponse({"ok": True}) - - # Finetunes Endpoints @csrf_exempt @require_POST diff --git a/clx/cli/cleanup.py b/clx/cli/cleanup.py index 3277ec7..d1557e7 100644 --- a/clx/cli/cleanup.py +++ b/clx/cli/cleanup.py @@ -1,14 +1,13 @@ -from concurrent.futures import ThreadPoolExecutor, as_completed - import click from tqdm import tqdm @click.command() +@click.argument("project_id", 
default=None) +@click.argument("label_name", default=None) @click.option("--update", is_flag=True, help="Update labels") @click.option("--predict", is_flag=True, help="Run global corpus predictions") -@click.option("--label-name", "--label", help="Label name") -def cleanup(update, predict, label_name): +def cleanup(project_id, label_name, update, predict): """Sync app data.""" from clx.models import LabelHeuristic, Project @@ -45,17 +44,9 @@ def cleanup(update, predict, label_name): print(f"Updating heuristic {heuristic.name}...") heuristic.apply() - def update_label(label, predict): - print(f"Updating label {label.name}...") - label.update_all(predict=predict, num_threads=32) - - if update: - with ThreadPoolExecutor(max_workers=4) as executor: - futures = [] - for label in project.labels.all().order_by("name"): - if label_name is None or label.name == label_name: - futures.append( - executor.submit(update_label, label, predict) - ) - for _ in tqdm(as_completed(futures), total=len(futures)): - pass + if update and (project_id is None or project_id == project.id): + for label in project.labels.all().order_by("name"): + if label_name is None or label.name == label_name: + print(f"Updating label {label.name}...") + label.update_all(predict=predict, num_threads=8) + print(f"Label {label.name} updated.") diff --git a/clx/llm/anno_agent.py b/clx/llm/anno_agent.py new file mode 100644 index 0000000..4d29f99 --- /dev/null +++ b/clx/llm/anno_agent.py @@ -0,0 +1,82 @@ +import simplejson as json +from pydantic import BaseModel + +from clx.llm.agent import Agent + +PROJECT_INSTRUCTIONS_TEMPLATE = """ +You are an annotation assistant providing single-label classification +annotations for the following label: {label_name}. + +When annotating you will be provided a text example. 
class Annotation(BaseModel):
    """An annotation value with a reason."""

    # True iff the label applies to the annotated text.
    value: bool
    # One-sentence justification, per the prompt's instruction to explain
    # how the decision aligns with the guidelines.
    reason: str


class AnnoAgent(Agent):
    """An annotation agent for single-label classification.

    Builds a fixed few-shot chat prefix (system prompt assembled from the
    project- and label-level instruction templates, followed by prior
    decisions replayed as user/assistant turns) and classifies new text by
    delegating to the base ``Agent``'s ``step``.
    """

    # NOTE(review): these class attributes are consumed by the Agent base
    # class, which is not visible here — confirm how default_model,
    # default_completion_args, and on_init_args are interpreted.
    default_model = "gemini/gemini-2.5-flash-lite"
    default_completion_args = {
        # Structured-output schema: responses are parsed as Annotation JSON.
        "response_format": Annotation,
    }
    on_init_args = [
        "label_name",
        "label_instructions",
        "project_instructions",
        "decisions",
    ]

    def on_init(
        self, label_name, label_instructions, project_instructions, decisions
    ):
        """Precompute the conversation prefix reused by every __call__.

        Args:
            label_name: name of the label being annotated; interpolated into
                the project-level prompt template.
            label_instructions: optional label-specific guidance; appended
                after the project instructions when not None (the template
                states it takes precedence on conflict).
            project_instructions: project-wide annotation guidance.
            decisions: iterable of dicts with "text", "value", "reason"
                keys, replayed as few-shot examples.
        """
        system_prompt = PROJECT_INSTRUCTIONS_TEMPLATE.format(
            label_name=label_name,
            project_instructions=project_instructions,
        )
        if label_instructions is not None:
            # NOTE(review): LABEL_INSTRUCTIONS_TEMPLATE contains no
            # {label_name} placeholder, so this kwarg is ignored by
            # str.format — harmless, but confirm it isn't a template typo.
            system_prompt += "\n\n" + LABEL_INSTRUCTIONS_TEMPLATE.format(
                label_name=label_name,
                label_instructions=label_instructions,
            )
        messages = [{"role": "system", "content": system_prompt}]
        # Replay each prior decision as a user turn (the text) followed by
        # an assistant turn containing serialized Annotation JSON, so the
        # model imitates the structured response format.
        for decision in decisions:
            messages.append({"role": "user", "content": decision["text"]})
            json_content = Annotation(
                value=decision["value"], reason=decision["reason"]
            ).model_dump_json()
            messages.append({"role": "assistant", "content": json_content})
        # Cache the prefix; __call__ copies it so runs don't mutate it.
        self.state["prefix_messages"] = messages

    def __call__(self, text: str) -> Annotation:
        """Annotate `text`, returning a parsed Annotation.

        Resets the conversation to a fresh copy of the precomputed prefix,
        then sends the text as a single user message.
        """
        self.messages = [*self.state["prefix_messages"]]
        response = self.step(messages=[{"role": "user", "content": text}])
        # response["content"] is assumed to be the JSON-serialized
        # Annotation produced under response_format — TODO confirm the
        # Agent.step contract.
        return Annotation(**json.loads(response["content"]))
0000000..976be56 --- /dev/null +++ b/home/projects/docketbert/runs/docketbert-final-large-395M/config.json @@ -0,0 +1,32 @@ +{ + "task": "mlm", + "run_name": "docketbert-final-large-395M", + "run_dir_parent": "/workspace/clx/projects/docketbert/runs", + "base_model_name": "answerdotai/ModernBERT-large", + "tokenizer_name": "answerdotai/ModernBERT-base", + "tokenize_args": { + "max_length": 768, + "padding": false + }, + "model_args": {}, + "training_args": { + "learning_rate": 0.0005, + "weight_decay": 0.01, + "num_train_epochs": 1, + "warmup_ratio": 0.05, + "logging_steps": 5, + "save_strategy": "steps", + "save_steps": 1000, + "save_total_limit": 2, + "eval_strategy": "steps", + "eval_steps": 1000, + "prediction_loss_only": true, + "remove_unused_columns": false, + "bf16": true, + "max_steps": 159224, + "per_device_train_batch_size": 8, + "per_device_eval_batch_size": 8, + "gradient_accumulation_steps": 32 + }, + "mlm_probability": 0.3 +} \ No newline at end of file diff --git a/home/projects/docketbert/runs/docketbert-final-sliced-175M/checkpoints/logs/events.out.tfevents.1770755421.df02e7f16835.4837.0 b/home/projects/docketbert/runs/docketbert-final-sliced-175M/checkpoints/logs/events.out.tfevents.1770755421.df02e7f16835.4837.0 new file mode 100644 index 0000000..e26b278 Binary files /dev/null and b/home/projects/docketbert/runs/docketbert-final-sliced-175M/checkpoints/logs/events.out.tfevents.1770755421.df02e7f16835.4837.0 differ diff --git a/home/projects/docketbert/runs/docketbert-final-sliced-175M/config.json b/home/projects/docketbert/runs/docketbert-final-sliced-175M/config.json new file mode 100644 index 0000000..fb5a7b0 --- /dev/null +++ b/home/projects/docketbert/runs/docketbert-final-sliced-175M/config.json @@ -0,0 +1,32 @@ +{ + "task": "mlm", + "run_name": "docketbert-final-sliced-175M", + "run_dir_parent": "/workspace/clx/projects/docketbert/runs", + "base_model_name": 
"/workspace/clx/projects/docketbert/models/final-sliced-large-ft-interleaved-10l", + "tokenizer_name": "answerdotai/ModernBERT-base", + "tokenize_args": { + "max_length": 768, + "padding": false + }, + "model_args": {}, + "training_args": { + "learning_rate": 0.0005, + "weight_decay": 0.01, + "num_train_epochs": 1, + "warmup_ratio": 0.05, + "logging_steps": 5, + "save_strategy": "steps", + "save_steps": 1000, + "save_total_limit": 2, + "eval_strategy": "steps", + "eval_steps": 1000, + "prediction_loss_only": true, + "remove_unused_columns": false, + "bf16": true, + "max_steps": 159224, + "per_device_train_batch_size": 32, + "per_device_eval_batch_size": 32, + "gradient_accumulation_steps": 8 + }, + "mlm_probability": 0.3 +} \ No newline at end of file diff --git a/home/projects/docketbert/runs/docketbert-scratch-7M/checkpoints/logs/events.out.tfevents.1768838378.a1e1309953bb.1410.0 b/home/projects/docketbert/runs/docketbert-scratch-7M/checkpoints/logs/events.out.tfevents.1768838378.a1e1309953bb.1410.0 new file mode 100644 index 0000000..b03df04 Binary files /dev/null and b/home/projects/docketbert/runs/docketbert-scratch-7M/checkpoints/logs/events.out.tfevents.1768838378.a1e1309953bb.1410.0 differ diff --git a/home/projects/docketbert/runs/docketbert-scratch-7M/config.json b/home/projects/docketbert/runs/docketbert-scratch-7M/config.json index 1da9ed8..efa4ee9 100644 --- a/home/projects/docketbert/runs/docketbert-scratch-7M/config.json +++ b/home/projects/docketbert/runs/docketbert-scratch-7M/config.json @@ -1,7 +1,7 @@ { "task": "mlm", "run_name": "docketbert-scratch-7M", - "run_dir_parent": "/workspace/clx/home/runs", + "run_dir_parent": "/workspace/clx/projects/docketbert/runs", "base_model_name": "answerdotai/ModernBERT-base", "tokenizer_name": "answerdotai/ModernBERT-base", "tokenize_args": {