Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
820 changes: 820 additions & 0 deletions cc/CLAUDE.md

Large diffs are not rendered by default.

497 changes: 497 additions & 0 deletions cc/helpers.py

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions clx/app/migrations/0004_labelfinetune_finetuned_at_and_more.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Generated by Django 5.2.7 on 2026-01-15 18:50

from django.db import migrations, models


class Migration(migrations.Migration):
    """Add nullable ``finetuned_at`` and ``predicted_at`` columns to ``labelfinetune``."""

    dependencies = [("app", "0003_labelfinetune")]

    # One AddField per timestamp column. A fresh DateTimeField instance is
    # built for every operation rather than shared between them.
    operations = [
        migrations.AddField(
            model_name="labelfinetune",
            name=column,
            field=models.DateTimeField(blank=True, null=True),
        )
        for column in ("finetuned_at", "predicted_at")
    ]
211 changes: 208 additions & 3 deletions clx/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@

from clx import label2slug
from clx.llm import GEPAPredictor, SingleLabelPredictor, batch_embed, mesh_sort
from clx.ml import pipeline
from clx.ml import pipeline, training_run
from clx.settings import CLX_HOME
from clx.utils import pd_save_or_append

from .custom_heuristics import custom_heuristics
from .search_utils import BaseModel, SearchDocumentModel
Expand Down Expand Up @@ -273,6 +274,8 @@ def load_annos(self):
flag_annos["value"] = None

annos = pd.concat([pos_annos, neg_annos, flag_annos])
if annos.empty:
return pd.DataFrame(columns=["text_hash", "text", "value"])
return annos

def load_trainset(self):
Expand Down Expand Up @@ -326,8 +329,9 @@ def update_trainset_pred_counts(self):
data = self.load_trainset()
if len(data) and "pred" in data.columns:
data = data.dropna(subset=["pred"])
self.trainset_num_positive_preds = data["pred"].sum()
self.trainset_num_negative_preds = (~data["pred"]).sum()
preds = data["pred"].astype(bool)
self.trainset_num_positive_preds = preds.sum()
self.trainset_num_negative_preds = (~preds).sum()
else:
self.trainset_num_positive_preds = 0
self.trainset_num_negative_preds = 0
Expand Down Expand Up @@ -379,6 +383,14 @@ def get_trainset_finetune_tag(self, config_name):
)
return tag

@property
def finetune_tag(self):
    """Return this label's ``LabelTag`` named ``"ft"``, creating it on first access."""
    # get_or_create returns (instance, created); only the instance is needed.
    return LabelTag.objects.get_or_create(label=self, name="ft")[0]

@property
def anno_true_tag(self):
tag, _ = LabelTag.objects.get_or_create(
Expand Down Expand Up @@ -521,6 +533,196 @@ def prepare_finetune(

return train_data, eval_data, run_config

def train_finetune(self, config_name):
    """Train the finetune model for ``config_name`` and record the outcome.

    Steps, in order:

    1. Build train/eval data and the run config via ``prepare_finetune``.
    2. Train with ``overwrite=True, remote=True``.
    3. Predict over the combined train + eval rows with the pipeline from
       ``get_finetune_run_pipe`` and keep the rows predicted ``"yes"``.
    4. Pass the ids of search documents whose ``text_hash`` matches those
       rows to ``bulk_replace_tag`` for the per-config trainset finetune tag.
    5. Store the eval results and a ``finetuned_at`` timestamp on the
       ``LabelFinetune`` row for (label, config_name), creating it if needed.

    Returns:
        The saved ``LabelFinetune`` instance.
    """
    train_df, eval_df, run_kwargs = self.prepare_finetune(config_name)

    train_outputs = training_run(**run_kwargs).train(
        train_df, eval_df, overwrite=True, remote=True
    )

    combined = pd.concat([train_df, eval_df])
    predictor = self.get_finetune_run_pipe(config_name)
    combined["pred"] = predictor(combined["text"].tolist(), batch_size=16)
    positive_hashes = combined.loc[
        combined["pred"] == "yes", "text_hash"
    ].tolist()

    trainset_tag = self.get_trainset_finetune_tag(config_name)
    search_model = self.project.get_search_model()
    matching_ids = search_model.objects.filter(
        text_hash__in=positive_hashes
    ).values_list("id", flat=True)
    search_model.bulk_replace_tag(trainset_tag.id, matching_ids)

    record, _created = LabelFinetune.objects.get_or_create(
        label=self, config_name=config_name
    )
    record.eval_results = train_outputs["results"]
    record.finetuned_at = timezone.now()
    record.save()
    return record

def predict_finetune(self, batch_size=16, num_workers=64, force=False):
    """Run the main finetune model over the whole corpus and tag positives.

    Predictions are appended to a CSV cache under ``self.data_dir`` as each
    outer batch finishes, so a re-run after an interruption skips every id
    already present in the cache. After the loop, the ids predicted positive
    are passed to ``bulk_replace_tag`` for this label's ``ft`` tag,
    ``predicted_at`` is stamped on the matching ``LabelFinetune`` row (when
    one exists) and the cache file is deleted.

    Args:
        batch_size: Batch size forwarded to the prediction pipeline.
        num_workers: Worker count forwarded to the prediction pipeline.
        force: Delete any existing cache first so every document is
            predicted again.

    Raises:
        ValueError: If the project's search model has no
            ``main_finetune_config``.
    """
    model = self.project.get_search_model()
    config_name = model.main_finetune_config
    # Validate before touching the filesystem so a misconfigured project
    # does not leave an empty data directory behind.
    if config_name is None:
        raise ValueError("Set main_finetune_config for this project")

    self.data_dir.mkdir(parents=True, exist_ok=True)
    cache_path = self.data_dir / "finetune_predictions_cache.csv"

    if force and cache_path.exists():
        cache_path.unlink()

    cached_ids = set()
    if cache_path.exists():
        cached_ids = set(pd.read_csv(cache_path, usecols=["id"])["id"].tolist())

    pipe = self.get_finetune_run_pipe(config_name)

    outer_batch_size = 1024 * 500
    total_examples = model.objects.count()
    # Ceiling division: a trailing partial batch is still one progress step
    # (floor division reports 0 steps for a corpus smaller than one batch).
    total_batches = (total_examples + outer_batch_size - 1) // outer_batch_size
    for batch in tqdm(
        model.objects.batch_df("id", "text", batch_size=outer_batch_size),
        desc=f"Predicting {config_name}",
        total=total_batches,
    ):
        # .copy() gives an independent frame, so adding the "value" column
        # below is not a chained assignment on a filtered slice.
        pending = batch[~batch["id"].isin(cached_ids)].copy()
        if pending.empty:
            continue
        preds = pipe(
            pending["text"].tolist(),
            batch_size=batch_size,
            num_workers=num_workers,
            max_length=768,
            truncation=True,
        )
        pending["value"] = [pred == "yes" for pred in preds]
        pd_save_or_append(pending[["id", "value"]], cache_path)

    if cache_path.exists():
        all_preds = pd.read_csv(cache_path)
        positive_ids = all_preds[all_preds["value"]]["id"].tolist()
        model.bulk_replace_tag(self.finetune_tag.id, positive_ids)
        # NOTE(review): the `fintunes` spelling matches the accessor used by
        # update_all in this file; confirm it is the declared related_name.
        finetune = self.fintunes.filter(config_name=config_name).first()
        if finetune:
            finetune.predicted_at = timezone.now()
            finetune.save()
        # The cache is only removed after the tag update has gone through.
        cache_path.unlink()

        # Kept inside this branch: positive_ids / all_preds exist only when
        # a cache file was produced.
        print(
            f"Predictions complete: {len(positive_ids):,} positive out of {len(all_preds):,} total"
        )

def update_all(self, num_threads=128, predict=False, force=False):
    """Bring every out-of-date stage of this label's pipeline up to date.

    Stages run in order, each only when its upstream timestamp is newer
    than its own (or when ``force`` is set):

    1. Resample trainset (decisions newer than trainset)
    2. Fit predictor (trainset newer than predictor)
    3. Run trainset predictions (predictor newer than predictions)
    4. Train each configured finetune (predictions newer than finetune)
    5. Run global corpus predictions (only when ``predict`` is True and
       the main finetune is newer than its last global prediction)

    Nothing runs, and a list of what is missing is printed, unless the
    label has a minimal heuristic, a likely heuristic, a positive decision
    and a negative decision.
    """

    def is_stale(source_at, target_at):
        # Nothing to propagate until the upstream stage has a timestamp.
        if not source_at:
            return False
        return not target_at or source_at > target_at

    prerequisites = (
        (self.heuristics, {"is_minimal": True}, "at least one minimal heuristic"),
        (self.heuristics, {"is_likely": True}, "at least one likely heuristic"),
        (self.decisions, {"value": True}, "at least one positive decision"),
        (self.decisions, {"value": False}, "at least one negative decision"),
    )
    missing = [
        description
        for manager, condition, description in prerequisites
        if not manager.filter(**condition).exists()
    ]
    if missing:
        print("Cannot run update_all - missing required setup:")
        for item in missing:
            print(f" - {item}")
        return

    model = self.project.get_search_model()
    finetune_configs = list(model.finetune_configs.keys())

    # Timestamp of the most recently updated decision, if any.
    newest_decision = self.decisions.order_by("-updated_at").first()
    latest_decision_at = newest_decision.updated_at if newest_decision else None

    # Step 1: resample the trainset when decisions are newer.
    if force or is_stale(latest_decision_at, self.trainset_updated_at):
        print("Resampling trainset...")
        self.update_trainset()
        self.refresh_from_db()

    # Step 2: fit the predictor when the trainset is newer.
    if force or is_stale(self.trainset_updated_at, self.predictor_updated_at):
        print("Fitting predictor...")
        self.fit_predictor()
        self.refresh_from_db()

    # Step 3: run trainset predictions when the predictor is newer.
    if force or is_stale(
        self.predictor_updated_at, self.trainset_predictions_updated_at
    ):
        print("Running predictions...")
        self.update_trainset_preds(num_threads=num_threads)
        self.refresh_from_db()

    # Step 4: train each finetune whose predictions are newer.
    for config_name in finetune_configs:
        existing = self.fintunes.filter(config_name=config_name).first()
        finetuned_at = existing.finetuned_at if existing else None
        if force or is_stale(self.trainset_predictions_updated_at, finetuned_at):
            print(f"Training finetune: {config_name}...")
            self.train_finetune(config_name)

    # Step 5: run global corpus predictions when the main finetune is newer.
    if predict:
        ft = self.fintunes.filter(
            config_name=self.project.get_search_model().main_finetune_config
        ).first()
        if ft and (force or is_stale(ft.finetuned_at, ft.predicted_at)):
            print(f"Running global predictions: {ft.config_name}...")
            self.predict_finetune(force=force)

    print("Update complete!")

class Meta:
unique_together = ("project", "name")

Expand Down Expand Up @@ -710,6 +912,8 @@ class LabelFinetune(BaseModel):
)
config_name = models.CharField(max_length=255)
eval_results = models.JSONField(null=True, blank=True)
finetuned_at = models.DateTimeField(null=True, blank=True)
predicted_at = models.DateTimeField(null=True, blank=True)


class DocketEntry(SearchDocumentModel):
Expand All @@ -736,6 +940,7 @@ class DocketEntry(SearchDocumentModel):
},
},
}
main_finetune_config = "main"

id = models.BigIntegerField(primary_key=True)
recap_id = models.BigIntegerField(unique=True)
Expand Down
1 change: 1 addition & 0 deletions clx/app/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ class SearchDocumentModel(BaseModel, metaclass=SearchDocumentModelBase):

project_id = None
finetune_configs = {}
main_finetune_config = None

id = models.BigIntegerField(primary_key=True)
text = models.TextField()
Expand Down
9 changes: 9 additions & 0 deletions clx/app/templates/search/results_tab.html
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@
></span>
</template>
</div>
<div class="flex flex-wrap gap-2">
<template x-for="tag_id in (result.tags || []).filter(tid => tid.startsWith('ft:'))" :key="tag_id">
<span
class="inline-flex items-center rounded-full bg-red-100 px-2 py-1 text-xs text-gray-700"
x-text="labels?.[tags?.[tag_id]?.label_id]?.name"
x-show="labels?.[tags?.[tag_id]?.label_id]?.name"
></span>
</template>
</div>
</div>
</template>
</div>
Expand Down
9 changes: 8 additions & 1 deletion clx/cli/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@


@click.command()
def cleanup():
@click.option("--predict", is_flag=True, help="Run global corpus predictions")
@click.option("--label-name", "--label", help="Label name")
def cleanup(predict, label_name):
"""Sync app data."""
from clx.models import LabelHeuristic, Project

Expand Down Expand Up @@ -39,3 +41,8 @@ def cleanup():
):
print(f"Updating heuristic {heuristic.name}...")
heuristic.apply()

for label in project.labels.all():
if label_name is None or label.name == label_name:
print(f"Updating label {label.name}...")
label.update_all(predict=predict)
34 changes: 4 additions & 30 deletions clx/cli/train.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,13 @@
import click
import pandas as pd

from clx.ml import training_run


@click.command()
@click.argument("project_id")
@click.argument("label_name")
@click.argument("config_name")
def train(project_id, label_name, config_name):
    """Train a finetune model for a label."""
    # The training pipeline lives in Label.train_finetune; this command only
    # resolves the label and delegates, so the model is trained exactly once.
    # The import stays inside the command body, as in the original.
    from clx.models import Label

    label = Label.objects.get(project_id=project_id, name=label_name)
    label.train_finetune(config_name)
2 changes: 1 addition & 1 deletion clx/ml/remote_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def handle_megabatch(megabatch: list):
)
outputs = response.json()
except Exception as e:
errors.append(str(e))
outputs = {"status": "FAILED", "error": str(e)}
if outputs["status"] != "COMPLETED":
errors.append(outputs)
else:
Expand Down