
Feature/fit on ray #165

Open

pradyumna-rfai wants to merge 18 commits into main from feature/fit-on-ray

Conversation


@pradyumna-rfai pradyumna-rfai commented Feb 3, 2026

PR Summary: Unified Fit and Evals

This is a major refactoring PR that unifies the codebase for the fit (training) and evals (inference) modes, eliminating code duplication and enabling shared infrastructure for experiment tracking, interactive control, metric logging, and RF setup.

Changes

Major Changes

1. Unified Database Schema

  • Single experiments table for both fit and evals modes
    • Mode-specific configuration stored in JSON config column
  • Unified interactive_control table for dynamic operations
    • target_type field: 'run' (fit) or 'pipeline' (evals)
    • target_id field: holds run_id or pipeline_id
    • config_data field: holds operation-specific JSON configuration
    • Supports operations: stop, resume, delete, clone, clone_warm
  • Mode-specific tables remain separate:
    • Fit mode: runs, worker_task, controller_progress, worker_progress
    • Evals mode: pipelines, contexts, actor_tasks
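The unified interactive_control table described above can be sketched in SQLite as follows. Column names follow this summary; the types, keys, and the operation column are illustrative assumptions, not the actual tables.sql schema:

```python
import sqlite3

# Hypothetical sketch of the unified interactive_control table; only the
# target_type / target_id / config_data column names come from the PR summary.
conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE interactive_control (
        ic_id       INTEGER PRIMARY KEY AUTOINCREMENT,
        target_type TEXT NOT NULL CHECK (target_type IN ('run', 'pipeline')),
        target_id   INTEGER NOT NULL,   -- holds run_id or pipeline_id
        operation   TEXT NOT NULL,      -- stop, resume, delete, clone, clone_warm
        config_data TEXT                -- operation-specific JSON configuration
    )
""")
conn.execute(
    "INSERT INTO interactive_control (target_type, target_id, operation) VALUES (?, ?, ?)",
    ("run", 7, "stop"),
)
# Mode-specific consumers filter on target_type.
rows = conn.execute(
    "SELECT operation FROM interactive_control WHERE target_type = ?", ("run",)
).fetchall()
print(rows)
```

The target_type filter in the final query is what lets fit and evals controllers share one table without seeing each other's operations.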

2. Unified Experiment Class

  • Single entry point Experiment(name, mode="fit"|"evals") for both modes
  • Mode-specific initialization:
    • _init_fit_mode() - Sets up training infrastructure
    • _init_evals_mode() - Sets up inference infrastructure
  • Shared methods:
    • end() - Clean up resources
    • cancel_current() - Cancel current operation
    • get_log_file_path() - Get experiment logs
  • Mode-specific methods:
    • run_fit() - Execute training (fit mode only)
    • run_evals() - Execute inference (evals mode only)
    • get_results() - Get training metrics (fit mode only)
    • get_runs_info() - Get run information (fit mode only)
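The mode dispatch described above might look roughly like this simplified stand-in (the class body, attribute names, and return values are illustrative, not the actual implementation):

```python
class Experiment:
    """Simplified stand-in for the unified Experiment entry point."""

    def __init__(self, name: str, mode: str = "fit"):
        if mode not in ("fit", "evals"):
            raise ValueError(f"unknown mode: {mode}")
        self.name, self.mode = name, mode
        # Dispatch to mode-specific initialization, as in the PR summary.
        (self._init_fit_mode if mode == "fit" else self._init_evals_mode)()

    def _init_fit_mode(self):
        self.infra = "training"    # placeholder for training infrastructure

    def _init_evals_mode(self):
        self.infra = "inference"   # placeholder for inference infrastructure

    def run_fit(self):
        # Mode-specific methods guard against being called in the wrong mode.
        if self.mode != "fit":
            raise RuntimeError("run_fit is available in fit mode only")
        return f"training {self.name}"


exp = Experiment("demo", mode="evals")
print(exp.infra)  # inference
```

A single constructor with a mode guard keeps the shared methods (end, cancel_current, get_log_file_path) on one class while mode-only methods fail fast when misused.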

3. Unified Metric Logging System

4. Unified Status Enums

5. Setup

  • Unified setup for both fit and evals modes; removed the mode flags from the --init command.
  • Added a --clear command that clears all DB, log, and dashboard files.

Testing

  • Ran ChatQA lite notebook for SFT E2E with IC Ops - stop, clone
  • Ran DPO notebook for SFT E2E with IC Ops - stop, clone
  • Ran FIQA RAG notebook for evals E2E with IC Ops - stop, clone

Screenshots

Six screenshots attached (captured 2026-01-31 and 2026-02-02).

Note

High Risk
Large refactor that replaces the evals-specific DB/dispatcher with a new unified SQLite schema and new REST dispatcher, which can impact experiment tracking and interactive control across both fit and evals. CLI install/init behavior also changes (including a new destructive clear command), increasing rollout risk if paths/requirements differ across environments.

Overview
This PR replaces the evals-only database/dispatcher stack with a unified DB layer in rapidfireai/db (new tables.sql, DatabaseInterface, and a consolidated RfDb) that tracks experiments plus mode-specific entities (fit runs/tasks/progress and evals pipelines/contexts/actor tasks) and standardizes interactive control via a single interactive_control table.

It also adds a new unified Flask-based dispatcher in rapidfireai/dispatcher exposing run and pipeline IC endpoints (stop/resume/delete/clone) backed by the new DB, and updates Gunicorn wiring to the new module.

Developer-facing updates include: CLI init simplification (drops --evals, switches to unified setup/rapidfireai requirements selection), CUDA-aware package install adjustments, a new rapidfireai clear command that deletes DB/logs/experiments directories, dependency tweaks in pyproject.toml, and notebook/code import path updates (new rapidfireai.db.RfDb, moved InteractiveController, and logger/exception import relocations).

Written by Cursor Bugbot for commit 5d72352. This will update automatically on new commits.


@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 4 potential issues.

Autofix Details

Bugbot Autofix prepared fixes for all 4 issues found in the latest run.

  • ✅ Fixed: Duplicate torch/torchvision/torchaudio installation in setup
    • Removed the second unconditional torch/torchvision/torchaudio append block so each package is installed only once per init run.
  • ✅ Fixed: Cannot clear ended_by field via set_run_details
    • Updated set_run_details to accept enum-or-string values for source and ended_by so empty strings are persisted and can clear stale values.
  • ✅ Fixed: Fit controller fetches pipeline IC operations from unified table
    • Changed the fit controller to call get_pending_ic_operations(target_type="run") so it ignores pipeline-targeted IC operations.
  • ✅ Fixed: Duplicate encode_payload/decode_db_payload definitions in rf_db.py
    • Removed local duplicate serializer helpers from rf_db.py and imported the canonical implementations from rapidfireai.utils.serialize.

Preview (6d1da7f0c6)
diff --git a/rapidfireai/cli.py b/rapidfireai/cli.py
--- a/rapidfireai/cli.py
+++ b/rapidfireai/cli.py
@@ -252,9 +252,6 @@
         if get_compute_capability() >= 8.0:
             packages.append({"package": "flash-attn>=2.8.3", "extra_args": ["--upgrade", "--no-build-isolation"]})
         packages.append({"package": "transformers>=4.56.1,<5.0.0", "extra_args": ["--upgrade"]})
-        packages.append({"package": f"torch=={torch_version}", "extra_args": ["--upgrade", "--index-url", f"https://download.pytorch.org/whl/{torch_cuda}"]})
-        packages.append({"package": f"torchvision=={torchvision_version}", "extra_args": ["--upgrade", "--index-url", f"https://download.pytorch.org/whl/{torch_cuda}"]})
-        packages.append({"package": f"torchaudio=={torchaudio_version}", "extra_args": ["--upgrade", "--index-url", f"https://download.pytorch.org/whl/{torch_cuda}"]})
 
         packages.append({"package": "numpy<2.3", "extra_args": ["--upgrade"]})
 

diff --git a/rapidfireai/db/rf_db.py b/rapidfireai/db/rf_db.py
--- a/rapidfireai/db/rf_db.py
+++ b/rapidfireai/db/rf_db.py
@@ -10,6 +10,7 @@
 from typing import Any
 
 from rapidfireai.db.db_interface import DatabaseInterface
+from rapidfireai.utils.serialize import decode_db_payload, encode_payload, extract_pipeline_config_json
 from rapidfireai.utils.constants import (
     ContextStatus,
     ControllerTask,
@@ -26,20 +27,6 @@
 )
 
 
-def encode_payload(payload: object) -> str:
-    """Encode the payload for the database using dill."""
-    import base64
-    import dill
-    return base64.b64encode(dill.dumps(payload)).decode("utf-8")
-
-
-def decode_db_payload(payload: str) -> object:
-    """Decode the payload from the database using dill."""
-    import base64
-    import dill
-    return dill.loads(base64.b64decode(payload))
-
-
 class RfDb:
     """
     Database manager for RapidFire AI experiments.
@@ -787,8 +774,8 @@
         num_epochs_completed: int | None = None,
         chunk_offset: int | None = None,
         error: str | None = None,
-        source: RunSource | None = None,
-        ended_by: RunEndedBy | None = None,
+        source: RunSource | str | None = None,
+        ended_by: RunEndedBy | str | None = None,
         warm_started_from: int | None = None,
         cloned_from: int | None = None,
     ) -> None:
@@ -804,8 +791,8 @@
             "num_epochs_completed": num_epochs_completed,
             "chunk_offset": chunk_offset,
             "error": error,
-            "source": source.value if source else None,
-            "ended_by": ended_by.value if ended_by else None,
+            "source": source.value if isinstance(source, RunSource) else source,
+            "ended_by": ended_by.value if isinstance(ended_by, RunEndedBy) else ended_by,
             "warm_started_from": warm_started_from,
             "cloned_from": cloned_from,
         }
@@ -1068,7 +1055,6 @@
         encoded_config = encode_payload(pipeline_config)
 
         # Extract JSON-serializable data
-        from rapidfireai.utils.serialize import extract_pipeline_config_json
         json_config_dict = extract_pipeline_config_json(pipeline_config)
         json_config_str = json.dumps(json_config_dict) if json_config_dict else "{}"
         flattened_config_str = json.dumps(flattened_config) if flattened_config else "{}"

diff --git a/rapidfireai/fit/backend/controller.py b/rapidfireai/fit/backend/controller.py
--- a/rapidfireai/fit/backend/controller.py
+++ b/rapidfireai/fit/backend/controller.py
@@ -423,7 +423,7 @@
     ) -> tuple[dict[str, Any], list[dict[str, Any]]]:
         """Process the interactive control."""
         # get IC Ops scheduled tasks
-        ic_scheduled_tasks = self.db.get_pending_ic_operations()
+        ic_scheduled_tasks = self.db.get_pending_ic_operations(target_type="run")
 
         # track states for each task(run) and collect clone_modify tasks separately
         run_states = {}

"cloned_from": cloned_from,
}

columns = {k: v for k, v in columns.items() if v is not None}


Cannot clear ended_by field via set_run_details

High Severity

set_run_details uses ended_by.value if ended_by else None which treats empty string "" as falsy, converting it to None. Since None values are then filtered out, the controller's call set_run_details(run_id, status=ONGOING, ended_by="") during run resume never clears the ended_by column. The resumed run retains its old ended_by value (e.g., interactive_control), which is incorrect state.
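The falsiness bug and the isinstance-based fix can be demonstrated in isolation. The enum value and the two coercion expressions mirror the report; the helper functions themselves are illustrative:

```python
from enum import Enum


class RunEndedBy(Enum):
    INTERACTIVE_CONTROL = "interactive_control"


def old_coerce(ended_by):
    # Buggy: "" is falsy, so a clear request becomes None and is later
    # filtered out of the UPDATE, leaving the stale column value in place.
    return ended_by.value if ended_by else None


def fixed_coerce(ended_by):
    # Fixed: only unwrap actual enum members; strings (including "") pass
    # through and get persisted, clearing the column.
    return ended_by.value if isinstance(ended_by, RunEndedBy) else ended_by


print(old_coerce(""))    # None -> clear request silently dropped
print(fixed_coerce(""))  # ''   -> empty string persisted
```

The same pattern applies to the source parameter, which the autofix widened to RunSource | str | None for the same reason.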


"created_at": row[7],
"processed_at": row[8],
})
return operations


Fit controller fetches pipeline IC operations from unified table

High Severity

The fit controller calls get_pending_ic_operations() without passing target_type="run". With the unified interactive_control table, this returns both run-targeted and pipeline-targeted IC operations. When the fit controller processes a pipeline operation, it treats target_id (a pipeline_id) as a run_id and calls self.db.get_run(run_id), which will raise an exception because no run exists with that ID, crashing IC processing.
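A toy reproduction of this failure mode, using hypothetical in-memory stand-ins for the DB calls (the real get_pending_ic_operations and get_run live in RfDb):

```python
# Pending IC operations from the unified table: one run-targeted, one
# pipeline-targeted. Without a filter, the fit controller sees both.
pending_ops = [
    {"target_type": "run", "target_id": 1, "operation": "stop"},
    {"target_type": "pipeline", "target_id": 42, "operation": "stop"},
]
runs = {1: {"status": "ongoing"}}  # only run_id 1 exists


def get_run(run_id):
    # Stand-in for the DB lookup that raises when no run matches.
    if run_id not in runs:
        raise ValueError(f"Run not found: {run_id}")
    return runs[run_id]


# Unfiltered loop: the pipeline_id 42 is treated as a run_id and blows up.
errors = []
for op in pending_ops:
    try:
        get_run(op["target_id"])
    except ValueError as e:
        errors.append(str(e))
print(errors)  # ['Run not found: 42']

# With the fix, only run-targeted operations reach the run lookup.
run_ops = [op for op in pending_ops if op["target_type"] == "run"]
print(len(run_ops))  # 1
```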


"""Decode the payload from the database using dill."""
import base64
import dill
return dill.loads(base64.b64decode(payload))


Duplicate encode_payload/decode_db_payload definitions in rf_db.py

Low Severity

encode_payload and decode_db_payload are defined at module level in rf_db.py but are identical to the canonical implementations in rapidfireai/utils/serialize.py. The same file already imports from rapidfireai.utils.serialize elsewhere (line 1071 for extract_pipeline_config_json), so these local copies are pure duplication that risks divergent maintenance.


"""Process the interactive control."""
# get IC Ops scheduled tasks
ic_scheduled_tasks = self.db.get_scheduled_ic_ops_tasks()
ic_scheduled_tasks = self.db.get_pending_ic_operations()


Fit controller processes pipeline IC operations as runs

High Severity

The fit controller calls self.db.get_pending_ic_operations() without passing target_type="run". With the unified interactive_control table now storing both run and pipeline operations, this returns ALL pending operations. Pipeline-targeted IC operations will be processed as run operations — task["target_id"] (a pipeline_id) will be passed to self.db.get_run(), which will either crash with "Run not found" or corrupt state if IDs happen to collide.


The unified db/rf_db.py was missing estimated_runtime, required_workers
columns in the runs table and multi_worker_details in worker_task table.
Also fixed set_run_details missing those params, added set_estimated_runtime
method, and fixed set_experiment_error call in worker_actor.py to pass
experiment_id matching the unified API signature.

Made-with: Cursor
Updated interactive_controller import from fit.utils to fit.backend
and rf_db import from fit.db to db across all tutorial and community
notebooks.

Made-with: Cursor
Renamed MLFlowConfig to MLflowConfig in utils/__init__.py, dispatcher,
and removed duplicate stale MLFlowConfig imports in controller.py.
Fixed automl_utils.py import from fit.utils.exceptions to utils.exceptions
and callbacks.py import from fit.utils.logging to utils.logging.

Made-with: Cursor
Fixed stale imports in conftest.py and test_metric_logger.py to use
the new metrics package paths. Removed tests for DualMetricLogger and
create_metric_logger which no longer exist. Fixed callback import path
from rapidfireai.ml to rapidfireai.fit.ml. Deleted orphaned
fit/utils/serialize.py superseded by utils/serialize.py.

Made-with: Cursor

@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.


"source": [
"# Get experiment path\n",
"from rapidfireai.fit.db.rf_db import RfDb\n",
"from rapidfireai.db.rf_db import RfDb\n",


Notebooks pass string name to integer-typed parameter

High Severity

The notebooks call db.get_experiments_path(my_experiment) where my_experiment is a string (e.g., "pii-masking-gpt2-v1"), but the new RfDb.get_experiments_path method accepts experiment_id: int and queries WHERE experiment_id = ?. SQLite won't match a text string against an integer primary key column, so the query returns no rows and raises Exception("Experiments path not found"). The import path was updated to the new unified DB, but the call site wasn't adapted to the changed method signature.
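The mismatch is easy to reproduce with an in-memory SQLite table (schema assumed from the description; only experiment_id and the failure mode come from the report):

```python
import sqlite3

# INTEGER-keyed experiments table, as in the unified schema described above.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE experiments (experiment_id INTEGER PRIMARY KEY, name TEXT, path TEXT)"
)
conn.execute("INSERT INTO experiments VALUES (1, 'pii-masking-gpt2-v1', '/tmp/exp1')")

# Passing the experiment *name* where an integer id is expected: SQLite
# cannot coerce the text to an integer, so the comparison matches nothing.
by_name = conn.execute(
    "SELECT path FROM experiments WHERE experiment_id = ?", ("pii-masking-gpt2-v1",)
).fetchone()

# Passing the integer id matches as intended.
by_id = conn.execute(
    "SELECT path FROM experiments WHERE experiment_id = ?", (1,)
).fetchone()

print(by_name)  # None -> "Experiments path not found" in the real code
print(by_id)    # ('/tmp/exp1',)
```

This is why the notebook call sites need to resolve the name to an experiment_id (or the method needs a lookup by name) rather than only updating the import path.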

