Merged
42 changes: 27 additions & 15 deletions clx/ml/training_run.py
@@ -11,7 +11,7 @@
import requests
import simplejson as json
import torch
from datasets import Dataset
from datasets import Dataset, IterableDataset, load_dataset
from tqdm import tqdm
from transformers import (
AutoModel,
@@ -185,8 +185,8 @@ def pipe(self) -> Pipeline:

def train(
self,
train_data: pd.DataFrame,
eval_data: pd.DataFrame | None = None,
train_data: pd.DataFrame | Path | str,
eval_data: pd.DataFrame | Path | str | None = None,
overwrite: bool = False,
resume_from_checkpoint: str | bool | None = None,
lazy_tokenize: bool = False,
@@ -223,13 +223,25 @@ def train(
"Please delete it or set `overwrite=True` to overwrite it.",
)

# Validate the input data format
self.validate_data_format(train_data)
if eval_data is not None:
self.validate_data_format(eval_data)

# Prepare the datasets
def prepare_dataset(data: pd.DataFrame) -> Dataset:
def prepare_dataset(
data: pd.DataFrame | Path | str,
) -> Dataset | IterableDataset:
if isinstance(data, str | Path):
dataset = load_dataset(
"csv",
data_files=str(data),
streaming=True,
)["train"]
dataset = dataset.map(self.tokenize, batched=True)
dataset = dataset.select_columns(self.dataset_cols)
dataset = dataset.shuffle(buffer_size=10_000)
count = 0
for chunk in pd.read_csv(data, chunksize=1_000_000):
count += len(chunk)
return dataset, count

self.validate_data_format(data)
dataset = Dataset.from_pandas(data)

if lazy_tokenize:
@@ -238,11 +250,11 @@ def prepare_dataset(data: pd.DataFrame) -> Dataset:
dataset = dataset.map(self.tokenize, batched=True)
dataset = dataset.select_columns(self.dataset_cols)
dataset.set_format(type="torch")
return dataset
return dataset, len(data)

train_dataset = prepare_dataset(train_data)
eval_dataset = (
None if eval_data is None else prepare_dataset(eval_data)
train_dataset, train_count = prepare_dataset(train_data)
eval_dataset, eval_count = (
(None, 0) if eval_data is None else prepare_dataset(eval_data)
)

if callbacks is None:
@@ -284,8 +296,8 @@ def prepare_dataset(data: pd.DataFrame) -> Dataset:
# Evaluate the model
if eval_dataset is not None:
eval_results = trainer.evaluate(eval_dataset)
eval_results["num_train_examples"] = len(train_data)
eval_results["num_eval_examples"] = len(eval_data)
eval_results["num_train_examples"] = train_count
eval_results["num_eval_examples"] = eval_count
self.results_path.write_text(json.dumps(eval_results, indent=4))

# Save the model
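As context for the prepare_dataset change above: when train_data is a path, the dataset is streamed rather than loaded into memory, which is the standard datasets pattern for large CSVs. A minimal standalone sketch of that pattern (the file name, tokenizer, and column list below are placeholders, not values from the PR):

import pandas as pd
from datasets import load_dataset

def stream_csv(path, tokenize, cols):
    # Nothing is materialized up front; rows are read lazily from the CSV.
    stream = load_dataset("csv", data_files=str(path), streaming=True)["train"]
    stream = stream.map(tokenize, batched=True)   # tokenization runs per batch, on the fly
    stream = stream.select_columns(cols)          # e.g. ["input_ids", "attention_mask"]
    stream = stream.shuffle(buffer_size=10_000)   # approximate shuffle over a rolling buffer
    # An IterableDataset has no __len__, so the row count comes from a separate chunked pass.
    count = sum(len(chunk) for chunk in pd.read_csv(path, chunksize=1_000_000))
    return stream, count

Because a streamed dataset cannot report its length, the Trainer needs an explicit max_steps; the final-* experiments in train.py below set it for exactly that reason.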
95 changes: 95 additions & 0 deletions projects/docketbert/prepare_train_data.py
@@ -0,0 +1,95 @@
import os
from pathlib import Path

import pandas as pd
import psycopg2
from tqdm import tqdm

from clx import pd_save_or_append
from clx.settings import CLX_HOME

PROJECT_DIR = CLX_HOME / "projects" / "docketbert"

DB_CONFIG = {
"host": os.getenv("DEV_DB_HOST"),
"port": int(os.getenv("DEV_DB_PORT", "5432")),
"dbname": os.getenv("DEV_DB_NAME"),
"user": os.getenv("DEV_DB_USER"),
"password": os.getenv("DEV_DB_PASSWORD"),
}


def pull_dev_data(table_name, nrows=5_000_000, batch_size=1_000_000) -> None:
data_path = Path(PROJECT_DIR / "data" / f"{table_name}_descriptions.csv")
data_path.parent.mkdir(parents=True, exist_ok=True)

conn = psycopg2.connect(**DB_CONFIG)

last_id = None
current_rows = 0
if data_path.exists():
chunks = pd.read_csv(data_path, chunksize=batch_size)
for chunk in chunks:
current_rows += len(chunk)
min_id = chunk["id"].min()
last_id = min_id if last_id is None else min(last_id, min_id)

progress = tqdm(total=nrows, desc="Downloading")
progress.update(current_rows)

try:
while current_rows < nrows:
last_id_condition = (
"" if last_id is None else f"AND id < {last_id}"
)
with conn.cursor() as cur:
cur.execute(f"""
SELECT id, description FROM {table_name}
WHERE description IS NOT NULL
AND description <> ''
{last_id_condition}
ORDER BY id DESC
LIMIT {batch_size}
""")
rows = cur.fetchall()
if not rows:
break
col_names = [desc[0] for desc in cur.description]
data = pd.DataFrame(rows, columns=col_names)
last_id = data["id"].min()
current_rows += len(data)
progress.update(len(data))
pd_save_or_append(data, data_path)
finally:
conn.close()


def consolidate_data():
d1 = pd.read_csv(CLX_HOME / "app_projects" / "docket-entry" / "docs.csv")[
"text"
]
d2 = pd.read_csv(
CLX_HOME / "app_projects" / "docket-entry-short" / "docs.csv"
)["text"]
d3 = pd.read_csv(
PROJECT_DIR / "data" / "search_recapdocument_descriptions.csv",
usecols=["description"],
).rename(columns={"description": "text"})
d4 = pd.read_csv(
PROJECT_DIR / "data" / "search_docketentry_descriptions.csv",
usecols=["description"],
).rename(columns={"description": "text"})
data = pd.concat([d1, d2, d3, d4])
data = data.drop_duplicates("text")
data = data.sample(frac=1)
return data


if __name__ == "__main__":
pull_dev_data("search_docketentry", nrows=40_000_000)
pull_dev_data("search_recapdocument", nrows=20_000_000)
data = consolidate_data()
eval_data = data.tail(100000)
data = data.head(-100000)
data.to_csv(PROJECT_DIR / "data" / "train.csv", index=False)
eval_data.to_csv(PROJECT_DIR / "data" / "eval.csv", index=False)
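pull_dev_data pages through the table with keyset pagination (ORDER BY id DESC ... AND id < last_id) instead of OFFSET, so each batch query stays cheap and the script can resume a partial download by re-reading the CSV it has already written. Batches are appended with clx's pd_save_or_append; a rough sketch of what such a helper presumably does, writing the header only once (the actual clx implementation may differ):

from pathlib import Path
import pandas as pd

def pd_save_or_append(df: pd.DataFrame, path: Path) -> None:
    # First write creates the file with a header; later writes append rows only.
    if path.exists():
        df.to_csv(path, mode="a", header=False, index=False)
    else:
        df.to_csv(path, index=False)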
37 changes: 37 additions & 0 deletions projects/docketbert/pull_runpod_logs.py
@@ -0,0 +1,37 @@
import os
import subprocess

from clx.settings import CLX_HOME

RUNPOD_POD_IP = os.getenv("RUNPOD_POD_IP")
RUNPOD_POD_PORT = os.getenv("RUNPOD_POD_PORT")
RUNPOD_SSH_KEY = os.getenv("RUNPOD_SSH_KEY")


if __name__ == "__main__":
remote = f"root@{RUNPOD_POD_IP}:/workspace/clx/projects/docketbert/runs"
local = CLX_HOME / "projects" / "docketbert"
exclude_patterns = [
"*.safetensors",
"*.pt",
"*.csv",
]

cmd = [
"rsync",
"-avz",
"--progress",
"-e",
f"ssh -i {RUNPOD_SSH_KEY} -p {RUNPOD_POD_PORT}",
]

for pattern in exclude_patterns:
cmd.append(f"--exclude={pattern}")

cmd += [
remote,
str(local),
]

print("Running:", " ".join(cmd))
subprocess.run(cmd, check=True)
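The sync script builds a plain rsync-over-ssh invocation and excludes model weights and CSVs so only logs and small artifacts come back from the pod. It assumes RUNPOD_POD_IP, RUNPOD_POD_PORT, and RUNPOD_SSH_KEY are all set; if one is missing, the command is assembled with a malformed remote or ssh option. A small guard along these lines (not part of the PR) would fail fast instead:

import os
import sys

required = ("RUNPOD_POD_IP", "RUNPOD_POD_PORT", "RUNPOD_SSH_KEY")
missing = [name for name in required if not os.getenv(name)]
if missing:
    sys.exit(f"Missing environment variables: {', '.join(missing)}")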
49 changes: 39 additions & 10 deletions projects/docketbert/train.py
@@ -9,6 +9,9 @@
from clx.settings import CLX_HOME

PROJECT_DIR = CLX_HOME / "projects" / "docketbert"
EXP_DATA_PATH = CLX_HOME / "app_projects" / "docket-entry" / "docs.csv"
FULL_DATA_TRAIN_PATH = PROJECT_DIR / "data" / "train.csv"
FULL_DATA_EVAL_PATH = PROJECT_DIR / "data" / "eval.csv"


def create_sliced_model(
@@ -35,6 +38,7 @@ def get_experiment_config(experiment, batch_size=None):

def get_experiment_config(experiment, batch_size=None):
config = {
"use_full_data": False,
"task": "mlm",
"run_dir_parent": PROJECT_DIR / "runs",
"base_model_name": "answerdotai/ModernBERT-base",
@@ -234,6 +238,26 @@ def get_experiment_config(experiment, batch_size=None):
"global_attn_every_n_layers": 2,
}
default_batch_size = 8
elif experiment == "final-base-150M":
config["training_args"]["max_steps"] = 40761591 // 256
config["use_full_data"] = True
default_batch_size = 16
elif experiment == "final-large-395M":
config["base_model_name"] = "answerdotai/ModernBERT-large"
config["training_args"]["max_steps"] = 40761591 // 256
config["use_full_data"] = True
default_batch_size = 8
elif experiment == "final-sliced-175M":
base_model_name = (
PROJECT_DIR / "runs" / "docketbert-final-large-395M" / "model"
)
config["base_model_name"] = create_sliced_model(
"final-sliced-large-ft-interleaved-10l",
[0, 3, 6, 9, 12, 15, 18, 21, 24, 27],
base_model_name,
)
config["training_args"]["max_steps"] = 40761591 // 256
config["use_full_data"] = True
else:
raise ValueError(f"Invalid experiment: {experiment}")

@@ -308,25 +332,30 @@ def train_docketbert(
overwrite, resume, check_params, mem_test, experiment, batch_size, exit
):
"""Train a docket language model."""
from clx.models import DocketEntry

try:
if resume and overwrite:
raise click.UsageError(
"Cannot use --resume and --overwrite together."
)

data = pd.read_csv(
DocketEntry.get_project().cached_documents_path,
usecols=["text"],
nrows=200000 if mem_test else None,
)
data = data.sample(frac=1, random_state=42)
train_data = data.head(-100000)
eval_data = data.tail(100000)

config = get_experiment_config(experiment, batch_size)

use_full_data = config.pop("use_full_data")

if use_full_data:
train_data = FULL_DATA_TRAIN_PATH
eval_data = FULL_DATA_EVAL_PATH
else:
data = pd.read_csv(
EXP_DATA_PATH,
usecols=["text"],
nrows=200000 if mem_test else None,
)
data = data.sample(frac=1, random_state=42)
train_data = data.head(-100000)
eval_data = data.tail(100000)

if mem_test:
config["tokenize_args"]["padding"] = "max_length"

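The new final-* experiments turn on use_full_data and pass CSV paths through to train(), which takes the streaming path shown in training_run.py; since an iterable dataset cannot tell the Trainer how long an epoch is, max_steps is set from the corpus size. Assuming 40,761,591 is the row count of the full train.csv and 256 the effective batch size (both read off the constants in the diff, not stated explicitly):

# One pass over the full corpus at an effective batch size of 256:
steps_per_epoch = 40_761_591 // 256
assert steps_per_epoch == 159_224  # roughly one epoch of optimizer steps

The sliced 175M variant is built by keeping every third encoder layer of the fine-tuned large checkpoint (indices 0, 3, ..., 27, i.e. 10 layers), which is what create_sliced_model appears to do with the layer index list.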
12 changes: 7 additions & 5 deletions runpod_init.sh
@@ -1,9 +1,11 @@
#!/usr/bin/env bash

git fetch origin
git reset --hard origin/main
pip install -e .
apt-get update && apt-get install -y rsync
git pull
pip install -e '.[dev]'
pip install flash-attn --no-build-isolation
clx config --autoload-env on
export CLX_HOME=/workspace/clx/home
export HF_HOME=/workspace/hf
cat > .env << 'EOF'
CLX_HOME=/workspace/clx
HF_HOME=/workspace/hf
EOF
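The init script now writes CLX_HOME and HF_HOME into a .env file and enables clx's --autoload-env, rather than exporting them in the current shell, so the paths survive new SSH sessions on the pod. A minimal sketch of how such a file is typically picked up, assuming a python-dotenv-style loader (the actual clx mechanism may differ):

import os
from dotenv import load_dotenv  # pip install python-dotenv

load_dotenv(".env")  # reads the CLX_HOME / HF_HOME lines written by runpod_init.sh
print(os.environ["CLX_HOME"], os.environ["HF_HOME"])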