diff --git a/clx/ml/training_run.py b/clx/ml/training_run.py
index 8edaae2..1f5efdf 100644
--- a/clx/ml/training_run.py
+++ b/clx/ml/training_run.py
@@ -11,7 +11,7 @@
 import requests
 import simplejson as json
 import torch
-from datasets import Dataset
+from datasets import Dataset, IterableDataset, load_dataset
 from tqdm import tqdm
 from transformers import (
     AutoModel,
@@ -185,8 +185,8 @@ def pipe(self) -> Pipeline:
 
     def train(
         self,
-        train_data: pd.DataFrame,
-        eval_data: pd.DataFrame | None = None,
+        train_data: pd.DataFrame | Path | str,
+        eval_data: pd.DataFrame | Path | str | None = None,
         overwrite: bool = False,
         resume_from_checkpoint: str | bool | None = None,
         lazy_tokenize: bool = False,
@@ -223,13 +223,26 @@ def train(
                 "Please delete it or set `overwrite=True` to overwrite it.",
             )
 
-        # Validate the input data format
-        self.validate_data_format(train_data)
-        if eval_data is not None:
-            self.validate_data_format(eval_data)
-
         # Prepare the datasets
-        def prepare_dataset(data: pd.DataFrame) -> Dataset:
+        def prepare_dataset(
+            data: pd.DataFrame | Path | str,
+        ) -> tuple[Dataset | IterableDataset, int]:
+            if isinstance(data, str | Path):
+                dataset = load_dataset(
+                    "csv",
+                    data_files=str(data),
+                    streaming=True,
+                )["train"]
+                dataset = dataset.map(self.tokenize, batched=True)
+                dataset = dataset.select_columns(self.dataset_cols)
+                dataset = dataset.shuffle(buffer_size=10_000)
+                # Streaming datasets have no len(), so count rows in chunks.
+                count = 0
+                for chunk in pd.read_csv(data, chunksize=1_000_000):
+                    count += len(chunk)
+                return dataset, count
+
+            self.validate_data_format(data)
             dataset = Dataset.from_pandas(data)
 
             if lazy_tokenize:
@@ -238,11 +251,11 @@ def prepare_dataset(data: pd.DataFrame) -> Dataset:
             dataset = dataset.map(self.tokenize, batched=True)
             dataset = dataset.select_columns(self.dataset_cols)
             dataset.set_format(type="torch")
-            return dataset
+            return dataset, len(data)
 
-        train_dataset = prepare_dataset(train_data)
-        eval_dataset = (
-            None if eval_data is None else prepare_dataset(eval_data)
+        train_dataset, train_count = prepare_dataset(train_data)
+        eval_dataset, eval_count = (
+            (None, 0) if eval_data is None else prepare_dataset(eval_data)
         )
 
         if callbacks is None:
@@ -284,8 +297,8 @@ def prepare_dataset(data: pd.DataFrame) -> Dataset:
         # Evaluate the model
         if eval_dataset is not None:
             eval_results = trainer.evaluate(eval_dataset)
-            eval_results["num_train_examples"] = len(train_data)
-            eval_results["num_eval_examples"] = len(eval_data)
+            eval_results["num_train_examples"] = train_count
+            eval_results["num_eval_examples"] = eval_count
             self.results_path.write_text(json.dumps(eval_results, indent=4))
 
         # Save the model
diff --git a/projects/docketbert/prepare_train_data.py b/projects/docketbert/prepare_train_data.py
new file mode 100644
index 0000000..daba9f2
--- /dev/null
+++ b/projects/docketbert/prepare_train_data.py
@@ -0,0 +1,97 @@
+import os
+from pathlib import Path
+
+import pandas as pd
+import psycopg2
+from tqdm import tqdm
+
+from clx import pd_save_or_append
+from clx.settings import CLX_HOME
+
+PROJECT_DIR = CLX_HOME / "projects" / "docketbert"
+
+DB_CONFIG = {
+    "host": os.getenv("DEV_DB_HOST"),
+    "port": int(os.getenv("DEV_DB_PORT", "5432")),
+    "dbname": os.getenv("DEV_DB_NAME"),
+    "user": os.getenv("DEV_DB_USER"),
+    "password": os.getenv("DEV_DB_PASSWORD"),
+}
+
+
+def pull_dev_data(table_name, nrows=5_000_000, batch_size=1_000_000) -> None:
+    data_path = Path(PROJECT_DIR / "data" / f"{table_name}_descriptions.csv")
+    data_path.parent.mkdir(parents=True, exist_ok=True)
+
+    conn = psycopg2.connect(**DB_CONFIG)
+
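+    # Resume support: if a partial CSV already exists, count its rows and find
+    # the smallest id fetched so the download loop can pick up where it stopped.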
+    last_id = None
+    current_rows = 0
+    if data_path.exists():
+        chunks = pd.read_csv(data_path, chunksize=batch_size)
+        for chunk in chunks:
+            current_rows += len(chunk)
+            min_id = chunk["id"].min()
+            last_id = min_id if last_id is None else min(last_id, min_id)
+
+    progress = tqdm(total=nrows, desc="Downloading")
+    progress.update(current_rows)
+
+    try:
+        while current_rows < nrows:
+            last_id_condition = (
+                "" if last_id is None else f"AND id < {last_id}"
+            )
+            with conn.cursor() as cur:
+                cur.execute(f"""
+                    SELECT id, description FROM {table_name}
+                    WHERE description IS NOT NULL
+                    AND description <> ''
+                    {last_id_condition}
+                    ORDER BY id DESC
+                    LIMIT {batch_size}
+                """)
+                rows = cur.fetchall()
+                if not rows:
+                    break
+                col_names = [desc[0] for desc in cur.description]
+            data = pd.DataFrame(rows, columns=col_names)
+            last_id = data["id"].min()
+            current_rows += len(data)
+            progress.update(len(data))
+            pd_save_or_append(data, data_path)
+    finally:
+        conn.close()
+
+
+def consolidate_data():
+    d1 = pd.read_csv(CLX_HOME / "app_projects" / "docket-entry" / "docs.csv")[
+        "text"
+    ]
+    d2 = pd.read_csv(
+        CLX_HOME / "app_projects" / "docket-entry-short" / "docs.csv"
+    )["text"]
+    d3 = pd.read_csv(
+        PROJECT_DIR / "data" / "search_recapdocument_descriptions.csv",
+        usecols=["description"],
+    ).rename(columns={"description": "text"})
+    d4 = pd.read_csv(
+        PROJECT_DIR / "data" / "search_docketentry_descriptions.csv",
+        usecols=["description"],
+    ).rename(columns={"description": "text"})
+    data = pd.concat([d1, d2, d3, d4])
+    data = data.drop_duplicates("text")
+    data = data.sample(frac=1)
+    return data
+
+
+if __name__ == "__main__":
+    pull_dev_data("search_docketentry", nrows=40_000_000)
+    pull_dev_data("search_recapdocument", nrows=20_000_000)
+    data = consolidate_data()
+    eval_data = data.tail(100_000)
+    data = data.head(-100_000)
+    data.to_csv(PROJECT_DIR / "data" / "train.csv", index=False)
+    eval_data.to_csv(PROJECT_DIR / "data" / "eval.csv", index=False)
diff --git a/projects/docketbert/pull_runpod_logs.py b/projects/docketbert/pull_runpod_logs.py
new file mode 100644
index 0000000..f701559
--- /dev/null
+++ b/projects/docketbert/pull_runpod_logs.py
@@ -0,0 +1,37 @@
+import os
+import subprocess
+
+from clx.settings import CLX_HOME
+
+RUNPOD_POD_IP = os.getenv("RUNPOD_POD_IP")
+RUNPOD_POD_PORT = os.getenv("RUNPOD_POD_PORT")
+RUNPOD_SSH_KEY = os.getenv("RUNPOD_SSH_KEY")
+
+
+if __name__ == "__main__":
+    remote = f"root@{RUNPOD_POD_IP}:/workspace/clx/projects/docketbert/runs"
+    local = CLX_HOME / "projects" / "docketbert"
+    exclude_patterns = [
+        "*.safetensors",
+        "*.pt",
+        "*.csv",
+    ]
+
+    cmd = [
+        "rsync",
+        "-avz",
+        "--progress",
+        "-e",
+        f"ssh -i {RUNPOD_SSH_KEY} -p {RUNPOD_POD_PORT}",
+    ]
+
+    for pattern in exclude_patterns:
+        cmd.append(f"--exclude={pattern}")
+
+    cmd += [
+        remote,
+        str(local),
+    ]
+
+    print("Running:", " ".join(cmd))
+    subprocess.run(cmd, check=True)
diff --git a/projects/docketbert/train.py b/projects/docketbert/train.py
index ae0dcda..3036098 100644
--- a/projects/docketbert/train.py
+++ b/projects/docketbert/train.py
@@ -9,6 +9,9 @@
 from clx.settings import CLX_HOME
 
 PROJECT_DIR = CLX_HOME / "projects" / "docketbert"
+EXP_DATA_PATH = CLX_HOME / "app_projects" / "docket-entry" / "docs.csv"
+FULL_DATA_TRAIN_PATH = PROJECT_DIR / "data" / "train.csv"
+FULL_DATA_EVAL_PATH = PROJECT_DIR / "data" / "eval.csv"
 
 
 def create_sliced_model(
@@ -35,6 +38,8 @@
 
 def get_experiment_config(experiment, batch_size=None):
     config = {
+        "use_full_data": False,
         "task": "mlm",
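+        # These defaults are overridden per-experiment below; "use_full_data"
+        # is popped in train_docketbert() before training starts.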
"run_dir_parent": PROJECT_DIR / "runs", "base_model_name": "answerdotai/ModernBERT-base", @@ -234,6 +238,26 @@ def get_experiment_config(experiment, batch_size=None): "global_attn_every_n_layers": 2, } default_batch_size = 8 + elif experiment == "final-base-150M": + config["training_args"]["max_steps"] = 40761591 // 256 + config["use_full_data"] = True + default_batch_size = 16 + elif experiment == "final-large-395M": + config["base_model_name"] = "answerdotai/ModernBERT-large" + config["training_args"]["max_steps"] = 40761591 // 256 + config["use_full_data"] = True + default_batch_size = 8 + elif experiment == "final-sliced-175M": + base_model_name = ( + PROJECT_DIR / "runs" / "docketbert-final-large-395M" / "model" + ) + config["base_model_name"] = create_sliced_model( + "final-sliced-large-ft-interleaved-10l", + [0, 3, 6, 9, 12, 15, 18, 21, 24, 27], + base_model_name, + ) + config["training_args"]["max_steps"] = 40761591 // 256 + config["use_full_data"] = True else: raise ValueError(f"Invalid experiment: {experiment}") @@ -308,7 +332,6 @@ def train_docketbert( overwrite, resume, check_params, mem_test, experiment, batch_size, exit ): """Train a docket language model.""" - from clx.models import DocketEntry try: if resume and overwrite: @@ -316,17 +339,23 @@ def train_docketbert( "Cannot use --resume and --overwrite together." ) - data = pd.read_csv( - DocketEntry.get_project().cached_documents_path, - usecols=["text"], - nrows=200000 if mem_test else None, - ) - data = data.sample(frac=1, random_state=42) - train_data = data.head(-100000) - eval_data = data.tail(100000) - config = get_experiment_config(experiment, batch_size) + use_full_data = config.pop("use_full_data") + + if use_full_data: + train_data = FULL_DATA_TRAIN_PATH + eval_data = FULL_DATA_EVAL_PATH + else: + data = pd.read_csv( + EXP_DATA_PATH, + usecols=["text"], + nrows=200000 if mem_test else None, + ) + data = data.sample(frac=1, random_state=42) + train_data = data.head(-100000) + eval_data = data.tail(100000) + if mem_test: config["tokenize_args"]["padding"] = "max_length" diff --git a/runpod_init.sh b/runpod_init.sh index d093b31..de1b903 100755 --- a/runpod_init.sh +++ b/runpod_init.sh @@ -1,9 +1,11 @@ #!/usr/bin/env bash -git fetch origin -git reset --hard origin/main -pip install -e . +apt-get update && apt-get install -y rsync +git pull +pip install -e '.[dev]' pip install flash-attn --no-build-isolation clx config --autoload-env on -export CLX_HOME=/workspace/clx/home -export HF_HOME=/workspace/hf +cat > .env << 'EOF' +CLX_HOME=/workspace/clx +HF_HOME=/workspace/hf +EOF