From 8c4b959c71dcf343b3c8b90bbd0288fcb118cfe5 Mon Sep 17 00:00:00 2001 From: raviguptaamd Date: Thu, 5 Feb 2026 07:33:40 +0000 Subject: [PATCH] Updates to slurm launcher --- src/madengine/cli/commands/build.py | 18 + src/madengine/deployment/slurm.py | 344 +++++++++++++++++- src/madengine/execution/container_runner.py | 206 ++++++++++- .../orchestration/build_orchestrator.py | 283 ++++++++++++++ 4 files changed, 843 insertions(+), 8 deletions(-) diff --git a/src/madengine/cli/commands/build.py b/src/madengine/cli/commands/build.py index 99166a47..e7e93a3c 100644 --- a/src/madengine/cli/commands/build.py +++ b/src/madengine/cli/commands/build.py @@ -55,6 +55,20 @@ def build( "--batch-manifest", help="Input batch.json file for batch build mode" ), ] = None, + use_image: Annotated[ + Optional[str], + typer.Option( + "--use-image", + help="Skip Docker build and use pre-built image (e.g., lmsysorg/sglang:v0.5.2rc1-rocm700-mi30x)" + ), + ] = None, + build_on_compute: Annotated[ + bool, + typer.Option( + "--build-on-compute", + help="Build Docker images on SLURM compute node instead of login node" + ), + ] = False, additional_context: Annotated[ str, typer.Option( @@ -183,6 +197,8 @@ def build( verbose=verbose, _separate_phases=True, batch_build_metadata=batch_build_metadata if batch_build_metadata else None, + use_image=use_image, + build_on_compute=build_on_compute, ) # Initialize orchestrator in build-only mode @@ -203,6 +219,8 @@ def build( clean_cache=clean_docker_cache, manifest_output=manifest_output, batch_build_metadata=batch_build_metadata, + use_image=use_image, + build_on_compute=build_on_compute, ) # Load build summary for display diff --git a/src/madengine/deployment/slurm.py b/src/madengine/deployment/slurm.py index 577be47f..7ac6d6c4 100644 --- a/src/madengine/deployment/slurm.py +++ b/src/madengine/deployment/slurm.py @@ -249,6 +249,7 @@ def __init__(self, config: DeploymentConfig): self.gpus_per_node = self.slurm_config.get("gpus_per_node", 8) self.time_limit = self.slurm_config.get("time", "24:00:00") self.output_dir = Path(self.slurm_config.get("output_dir", "./slurm_output")) + self.reservation = self.slurm_config.get("reservation", None) # Setup Jinja2 template engine template_dir = Path(__file__).parent / "templates" / "slurm" @@ -261,6 +262,115 @@ def __init__(self, config: DeploymentConfig): # Generated script path self.script_path = None + # ========== OPTION 2: Detect existing SLURM allocation ========== + # If SLURM_JOB_ID exists, we're inside an salloc allocation + self.inside_allocation = os.environ.get("SLURM_JOB_ID") is not None + self.existing_job_id = os.environ.get("SLURM_JOB_ID", "") + self.allocation_nodes = self._get_allocation_node_count() + + if self.inside_allocation: + self.console.print( + f"[cyan]✓ Detected existing SLURM allocation: Job {self.existing_job_id}[/cyan]" + ) + self.console.print( + f" Allocation has {self.allocation_nodes} nodes available" + ) + + def _get_allocation_node_count(self) -> int: + """ + Get number of nodes in current SLURM allocation. + + Note: SLURM_NNODES reflects the current job step, not the full allocation. + We query the job directly using scontrol to get the actual node count. 
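+        Falls back to SLURM_JOB_NUM_NODES, then SLURM_NNODES, and finally to
+        counting the hosts in SLURM_NODELIST if the scontrol query fails.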
+ """ + if not self.inside_allocation: + return 0 + + job_id = self.existing_job_id + + # Query the actual job's node count using scontrol (most accurate) + try: + result = subprocess.run( + ["scontrol", "show", "job", job_id], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode == 0: + # Parse NumNodes=X from output + for line in result.stdout.split("\n"): + if "NumNodes=" in line: + # Format: "NumNodes=3 NumCPUs=..." + for part in line.split(): + if part.startswith("NumNodes="): + try: + return int(part.split("=")[1]) + except (ValueError, IndexError): + pass + except Exception: + pass + + # Fallback: Try SLURM_JOB_NUM_NODES (full job node count, if set) + job_num_nodes = os.environ.get("SLURM_JOB_NUM_NODES") + if job_num_nodes: + try: + return int(job_num_nodes) + except ValueError: + pass + + # Fallback: SLURM_NNODES (may be step-specific, not full allocation) + nnodes = os.environ.get("SLURM_NNODES") + if nnodes: + try: + return int(nnodes) + except ValueError: + pass + + # Last resort: count nodes in SLURM_NODELIST + nodelist = os.environ.get("SLURM_NODELIST") + if nodelist: + try: + result = subprocess.run( + ["scontrol", "show", "hostname", nodelist], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode == 0: + return len(result.stdout.strip().split("\n")) + except Exception: + pass + + return 0 + + def _validate_allocation_nodes(self) -> tuple[bool, str]: + """ + Validate that existing allocation has enough nodes for the job. + + Returns: + Tuple of (is_valid, error_message) + """ + if not self.inside_allocation: + return True, "" + + requested_nodes = self.nodes + available_nodes = self.allocation_nodes + + if available_nodes < requested_nodes: + return False, ( + f"Insufficient nodes in current allocation. " + f"Requested: {requested_nodes}, Available: {available_nodes}. " + f"Either reduce nodes in config or use a larger allocation." 
+ ) + + if available_nodes > requested_nodes: + self.console.print( + f"[yellow]⚠ Note: Using {requested_nodes} of {available_nodes} " + f"available nodes in allocation[/yellow]" + ) + + return True, "" + def validate(self) -> bool: """Validate SLURM commands are available locally.""" # Check required SLURM CLI tools @@ -362,11 +472,6 @@ def prepare(self) -> bool: """Generate sbatch script from template.""" # Validate environment BEFORE generating job scripts self.console.print("\n[bold]Validating submission environment...[/bold]") - if not self._validate_cli_availability(): - self.console.print( - "\n[yellow]⚠ Tip: Compute nodes inherit your submission environment[/yellow]" - ) - return False try: self.output_dir.mkdir(parents=True, exist_ok=True) @@ -379,6 +484,23 @@ def prepare(self) -> bool: model_key = model_keys[0] model_info = self.manifest["built_models"][model_key] + # Check if this is a baremetal launcher (sglang-disagg, vllm-disagg) + launcher_type = self.distributed_config.get("launcher", "torchrun") + launcher_normalized = launcher_type.lower().replace("_", "-") + + if launcher_normalized in ["sglang-disagg", "vllm-disagg"]: + # For disagg launchers, generate simple wrapper script + # that runs the model's .slurm script directly on baremetal + self.console.print(f"[cyan]Detected baremetal launcher: {launcher_type}[/cyan]") + return self._prepare_baremetal_script(model_info) + + # Standard flow: validate madengine availability for complex job template + if not self._validate_cli_availability(): + self.console.print( + "\n[yellow]⚠ Tip: Compute nodes inherit your submission environment[/yellow]" + ) + return False + # Prepare template context context = self._prepare_template_context(model_info) @@ -400,6 +522,115 @@ def prepare(self) -> bool: self.console.print(f"[red]✗ Failed to generate script: {e}[/red]") return False + def _prepare_baremetal_script(self, model_info: Dict) -> bool: + """ + Generate a simple wrapper script for baremetal launchers (sglang-disagg, vllm-disagg). + + These launchers run the model's .slurm script directly on baremetal, + which then manages Docker containers via srun. No madengine wrapper needed. 
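+
+        Returns:
+            True if the wrapper script was generated successfully, False otherwise.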
+ """ + # Get the model's script path + model_script = model_info.get("scripts", "") + if not model_script: + self.console.print("[red]✗ No scripts defined in model_info[/red]") + return False + + # Get manifest directory (where the model script is relative to) + manifest_dir = Path(self.config.manifest_file).parent.absolute() + model_script_path = manifest_dir / model_script + + if not model_script_path.exists(): + self.console.print(f"[red]✗ Model script not found: {model_script_path}[/red]") + return False + + # Get environment variables + env_vars = {} + + # From model_info.env_vars + if "env_vars" in model_info: + env_vars.update(model_info["env_vars"]) + + # From additional_context.env_vars + if "env_vars" in self.config.additional_context: + env_vars.update(self.config.additional_context["env_vars"]) + + # From distributed config + sglang_disagg_config = self.distributed_config.get("sglang_disagg", {}) + if sglang_disagg_config: + env_vars["xP"] = str(sglang_disagg_config.get("prefill_nodes", 1)) + env_vars["yD"] = str(sglang_disagg_config.get("decode_nodes", 1)) + + # Get model args + model_args = model_info.get("args", "") + + # Generate simple wrapper script + # IMPORTANT: SBATCH directives MUST be at the top, right after #!/bin/bash + script_lines = [ + "#!/bin/bash", + f"#SBATCH --job-name=madengine-{model_info['name']}", + f"#SBATCH --output={self.output_dir}/madengine-{model_info['name']}_%j.out", + f"#SBATCH --error={self.output_dir}/madengine-{model_info['name']}_%j.err", + f"#SBATCH --partition={self.partition}", + f"#SBATCH --nodes={self.nodes}", + f"#SBATCH --ntasks={self.nodes}", + f"#SBATCH --gpus-per-node={self.gpus_per_node}", + f"#SBATCH --time={self.time_limit}", + "#SBATCH --exclusive", + ] + + # Add reservation if specified + if self.reservation: + script_lines.append(f"#SBATCH --reservation={self.reservation}") + + script_lines.extend([ + "", + f"# Baremetal launcher script for {model_info['name']}", + f"# Generated by madengine for sglang-disagg", + "", + "set -e", + "", + "# Environment variables", + ]) + + for key, value in env_vars.items(): + script_lines.append(f"export {key}=\"{value}\"") + + script_lines.append("") + script_lines.extend([ + "echo '=========================================='", + "echo 'Baremetal Launcher - SGLang Disaggregated'", + "echo '=========================================='", + f"echo 'Model: {model_info['name']}'", + f"echo 'Script: {model_script_path}'", + "echo 'SLURM_JOB_ID:' $SLURM_JOB_ID", + "echo 'SLURM_NNODES:' $SLURM_NNODES", + "echo 'SLURM_NODELIST:' $SLURM_NODELIST", + "echo ''", + "", + "# Change to script directory", + f"cd {model_script_path.parent}", + "", + "# Run the model script directly on baremetal", + f"echo 'Executing: bash {model_script_path.name} {model_args}'", + f"bash {model_script_path.name} {model_args}", + "", + "echo ''", + "echo 'Script completed.'", + ]) + + script_content = "\n".join(script_lines) + + # Save script + self.script_path = self.output_dir / f"madengine_{model_info['name']}.sh" + self.script_path.write_text(script_content) + self.script_path.chmod(0o755) + + self.console.print(f"[green]✓ Generated baremetal script: {self.script_path}[/green]") + self.console.print(f" Model script: {model_script_path}") + self.console.print(f" Environment: {len(env_vars)} variables") + + return True + def _prepare_template_context(self, model_info: Dict) -> Dict[str, Any]: """Prepare context for Jinja2 template rendering.""" # Use hierarchical GPU resolution: runtime > deployment > model > default @@ 
-835,7 +1066,12 @@ def _generate_basic_env_command( # Model script should handle launcher invocation''' def deploy(self) -> DeploymentResult: - """Submit sbatch script to SLURM scheduler (locally).""" + """ + Deploy to SLURM - either via sbatch (new job) or bash (existing allocation). + + If SLURM_JOB_ID is set (inside salloc), runs script directly with bash. + Otherwise, submits a new job via sbatch. + """ if not self.script_path or not self.script_path.exists(): return DeploymentResult( status=DeploymentStatus.FAILED, @@ -843,6 +1079,85 @@ def deploy(self) -> DeploymentResult: message="Script not generated. Run prepare() first.", ) + # ========== BRANCH: Inside allocation vs new job ========== + if self.inside_allocation: + return self._run_inside_existing_allocation() + else: + return self._submit_new_job() + + def _run_inside_existing_allocation(self) -> DeploymentResult: + """ + Run script directly inside existing salloc allocation using bash. + + The script will use the nodes already allocated to the current job. + SLURM environment variables (SLURM_NODELIST, etc.) are inherited. + """ + # Validate node count before running + is_valid, error_msg = self._validate_allocation_nodes() + if not is_valid: + return DeploymentResult( + status=DeploymentStatus.FAILED, + deployment_id=self.existing_job_id, + message=error_msg, + ) + + self.console.print( + f"\n[bold cyan]Running inside existing SLURM allocation[/bold cyan]" + ) + self.console.print(f" Job ID: {self.existing_job_id}") + self.console.print(f" Using {self.nodes} of {self.allocation_nodes} allocated nodes") + self.console.print(f" GPUs per node: {self.gpus_per_node}") + self.console.print(f" Script: {self.script_path}") + self.console.print(f"\n[dim]Executing: bash {self.script_path}[/dim]\n") + + try: + # Run script directly with bash (synchronous, blocks until done) + # Don't capture output - let it stream directly to console + result = subprocess.run( + ["bash", str(self.script_path)], + timeout=self.config.timeout if self.config.timeout > 0 else None, + ) + + if result.returncode == 0: + self.console.print( + f"\n[green]✓ Script completed successfully in allocation {self.existing_job_id}[/green]" + ) + return DeploymentResult( + status=DeploymentStatus.SUCCESS, + deployment_id=self.existing_job_id, + message=f"Completed inside existing allocation {self.existing_job_id}", + logs_path=str(self.output_dir), + ) + else: + self.console.print( + f"\n[red]✗ Script failed with exit code {result.returncode}[/red]" + ) + return DeploymentResult( + status=DeploymentStatus.FAILED, + deployment_id=self.existing_job_id, + message=f"Script failed with exit code {result.returncode}", + logs_path=str(self.output_dir), + ) + + except subprocess.TimeoutExpired: + self.console.print( + f"\n[red]✗ Script timed out after {self.config.timeout}s[/red]" + ) + return DeploymentResult( + status=DeploymentStatus.FAILED, + deployment_id=self.existing_job_id, + message=f"Script timed out after {self.config.timeout}s", + ) + except Exception as e: + self.console.print(f"\n[red]✗ Execution error: {e}[/red]") + return DeploymentResult( + status=DeploymentStatus.FAILED, + deployment_id=self.existing_job_id, + message=f"Execution error: {str(e)}", + ) + + def _submit_new_job(self) -> DeploymentResult: + """Submit new SLURM job via sbatch (original behavior).""" # ==================== PREFLIGHT NODE SELECTION ==================== # For multi-node jobs with Ray/vLLM, check for clean nodes first # to avoid OOM errors from stale processes @@ -927,6 +1242,15 @@ def 
deploy(self) -> DeploymentResult: def monitor(self, deployment_id: str) -> DeploymentResult: """Check SLURM job status (locally).""" + # If we ran inside an existing allocation, script already completed synchronously + # No need to poll - just return success (deploy() already handled the result) + if self.inside_allocation: + return DeploymentResult( + status=DeploymentStatus.SUCCESS, + deployment_id=deployment_id, + message=f"Completed (ran inside existing allocation {deployment_id})", + ) + try: # Query job status using squeue (runs locally) result = subprocess.run( @@ -1242,6 +1566,14 @@ def collect_results(self, deployment_id: str) -> Dict[str, Any]: def cleanup(self, deployment_id: str) -> bool: """Cancel SLURM job if still running (locally).""" + # CRITICAL: Never cancel an existing allocation we're running inside! + # The user's salloc session should not be terminated by madengine + if self.inside_allocation: + self.console.print( + f"[dim]Skipping cleanup - running inside existing allocation {deployment_id}[/dim]" + ) + return True + try: subprocess.run( ["scancel", deployment_id], capture_output=True, timeout=10 diff --git a/src/madengine/execution/container_runner.py b/src/madengine/execution/container_runner.py index ba011e81..1df9dd29 100644 --- a/src/madengine/execution/container_runner.py +++ b/src/madengine/execution/container_runner.py @@ -13,6 +13,16 @@ import typing import warnings import re +import subprocess + +# Launchers that should run directly on baremetal (not inside Docker) +# These launchers manage their own Docker containers via SLURM srun commands +BAREMETAL_LAUNCHERS = [ + "sglang-disagg", + "sglang_disagg", + "vllm-disagg", + "vllm_disagg", +] from rich.console import Console as RichConsole from contextlib import redirect_stdout, redirect_stderr from madengine.core.console import Console @@ -580,6 +590,151 @@ def apply_tools( else: print(f" Note: Command '{cmd}' already added by another tool, skipping duplicate.") + def _run_on_baremetal( + self, + model_info: typing.Dict, + build_info: typing.Dict, + log_file_path: str, + timeout: int, + run_results: typing.Dict, + pre_encapsulate_post_scripts: typing.Dict, + run_env: typing.Dict, + ) -> typing.Dict: + """ + Run script directly on baremetal (not inside Docker). + + Used for launchers like sglang-disagg that manage their own Docker containers + via SLURM srun commands. The script is executed directly on the node. 
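+        The subprocess environment is the caller's environment plus run_env, the
+        model's env_vars, and any env_vars from additional_context.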
+ + Args: + model_info: Model configuration from manifest + build_info: Build information from manifest + log_file_path: Path to log file + timeout: Execution timeout in seconds + run_results: Dictionary to store run results + pre_encapsulate_post_scripts: Pre/post script configuration + run_env: Environment variables for the script + + Returns: + Dictionary with run results + """ + import shutil + + self.rich_console.print(f"[dim]{'='*80}[/dim]") + + # Prepare script path + scripts_arg = model_info["scripts"] + + # Get the current working directory (might be temp workspace) + cwd = os.getcwd() + print(f"📂 Current directory: {cwd}") + + if scripts_arg.endswith(".sh") or scripts_arg.endswith(".slurm"): + script_path = scripts_arg + script_name = os.path.basename(scripts_arg) + elif scripts_arg.endswith(".py"): + script_path = scripts_arg + script_name = os.path.basename(scripts_arg) + else: + # Directory specified - look for run.sh + script_path = os.path.join(scripts_arg, "run.sh") + script_name = "run.sh" + + # If script path is relative, make it absolute from cwd + if not os.path.isabs(script_path): + script_path = os.path.join(cwd, script_path) + + # Check script exists + if not os.path.exists(script_path): + print(f"⚠️ Script not found at: {script_path}") + # Try alternative locations + alt_path = os.path.join(cwd, os.path.basename(scripts_arg)) + if os.path.exists(alt_path): + script_path = alt_path + print(f"✓ Found at alternative location: {script_path}") + else: + raise FileNotFoundError(f"Script not found: {script_path}") + + script_dir = os.path.dirname(script_path) or cwd + print(f"📜 Script: {script_path}") + print(f"📁 Working directory: {script_dir}") + + # Prepare model arguments + model_args = self.context.ctx.get("model_args", model_info.get("args", "")) + print(f"📝 Arguments: {model_args}") + + # Build command + if script_path.endswith(".py"): + cmd = f"python3 {script_path} {model_args}" + else: + cmd = f"bash {script_path} {model_args}" + + print(f"🔧 Command: {cmd}") + + # Prepare environment + env = os.environ.copy() + env.update(run_env) + + # Add model-specific env vars from model_info + if "env_vars" in model_info and model_info["env_vars"]: + for key, value in model_info["env_vars"].items(): + env[key] = str(value) + print(f" ENV: {key}={value}") + + # Add env vars from additional_context + if self.additional_context and "env_vars" in self.additional_context: + for key, value in self.additional_context["env_vars"].items(): + env[key] = str(value) + + # Run script with logging + test_start_time = time.time() + self.rich_console.print("\n[bold blue]Running script on baremetal...[/bold blue]") + + try: + with open(log_file_path, mode="w", buffering=1) as outlog: + with redirect_stdout( + PythonicTee(outlog, self.live_output) + ), redirect_stderr(PythonicTee(outlog, self.live_output)): + print(f"⏰ Setting timeout to {timeout} seconds.") + print(f"🚀 Executing: {cmd}") + print(f"📂 Working directory: {script_dir}") + print(f"{'='*80}") + + result = subprocess.run( + cmd, + shell=True, + cwd=script_dir, + env=env, + timeout=timeout if timeout > 0 else None, + ) + + run_results["test_duration"] = time.time() - test_start_time + print(f"\n{'='*80}") + print(f"⏱️ Test Duration: {run_results['test_duration']:.2f} seconds") + + if result.returncode == 0: + run_results["status"] = "SUCCESS" + self.rich_console.print("[bold green]✓ Script completed successfully[/bold green]") + else: + run_results["status"] = "FAILURE" + run_results["status_detail"] = f"Exit code 
{result.returncode}" + self.rich_console.print(f"[bold red]✗ Script failed with exit code {result.returncode}[/bold red]") + raise subprocess.CalledProcessError(result.returncode, cmd) + + except subprocess.TimeoutExpired: + run_results["status"] = "FAILURE" + run_results["status_detail"] = f"Timeout after {timeout}s" + run_results["test_duration"] = time.time() - test_start_time + self.rich_console.print(f"[bold red]✗ Script timed out after {timeout}s[/bold red]") + raise + except Exception as e: + run_results["status"] = "FAILURE" + run_results["status_detail"] = str(e) + run_results["test_duration"] = time.time() - test_start_time + raise + + return run_results + def run_pre_post_script( self, model_docker: Docker, model_dir: str, pre_post: typing.List ) -> None: @@ -813,6 +968,15 @@ def run_container( if merged_count > 0: print(f"ℹ️ Merged {merged_count} environment variables from additional_context") + # Merge env_vars from model_info (models.json) into docker_env_vars + if "env_vars" in model_info and model_info["env_vars"]: + model_env_count = 0 + for key, value in model_info["env_vars"].items(): + self.context.ctx["docker_env_vars"][key] = str(value) + model_env_count += 1 + if model_env_count > 0: + print(f"ℹ️ Merged {model_env_count} environment variables from model_info (models.json)") + if "data" in model_info and model_info["data"] != "" and self.data: mount_datapaths = self.data.get_mountpaths(model_info["data"]) model_dataenv = self.data.get_env(model_info["data"]) @@ -874,6 +1038,44 @@ def run_container( print(f"Docker options: {docker_options}") + # ========== CHECK FOR BAREMETAL LAUNCHERS ========== + # Launchers like sglang-disagg run scripts directly on baremetal, + # not inside Docker. The script itself manages Docker containers via srun. 
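+        # Launcher is resolved in priority order: additional_context.distributed.launcher,
+        # then model_info["distributed"]["launcher"], then the MAD_LAUNCHER_TYPE env var.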
+ launcher = "" + + # Debug: Print all sources + print(f"🔍 Baremetal check - looking for launcher...") + print(f" MAD_LAUNCHER_TYPE env: {os.environ.get('MAD_LAUNCHER_TYPE', '')}") + if self.additional_context: + distributed_config = self.additional_context.get("distributed", {}) + launcher = distributed_config.get("launcher", "") + print(f" additional_context.distributed.launcher: {launcher or ''}") + if not launcher and model_info.get("distributed"): + launcher = model_info["distributed"].get("launcher", "") + print(f" model_info.distributed.launcher: {launcher or ''}") + if not launcher: + launcher = os.environ.get("MAD_LAUNCHER_TYPE", "") + print(f" Fallback to MAD_LAUNCHER_TYPE: {launcher or ''}") + + print(f" Final launcher detected: {launcher or ''}") + + # Normalize launcher name (replace underscores with hyphens) + launcher_normalized = launcher.lower().replace("_", "-") if launcher else "" + + if launcher_normalized and launcher_normalized in [l.lower().replace("_", "-") for l in BAREMETAL_LAUNCHERS]: + self.rich_console.print(f"\n[bold cyan]🖥️ Running on BAREMETAL (launcher: {launcher})[/bold cyan]") + self.rich_console.print(f"[dim]Script will manage its own Docker containers via SLURM[/dim]") + return self._run_on_baremetal( + model_info=model_info, + build_info=build_info, + log_file_path=log_file_path, + timeout=timeout, + run_results=run_results, + pre_encapsulate_post_scripts=pre_encapsulate_post_scripts, + run_env=run_env, + ) + # ========== END BAREMETAL CHECK ========== + self.rich_console.print(f"\n[bold blue]🏃 Starting Docker container execution...[/bold blue]") print(f"🏷️ Image: {docker_image}") print(f"📦 Container: {container_name}") @@ -992,8 +1194,8 @@ def run_container( # Prepare script execution scripts_arg = model_info["scripts"] - if scripts_arg.endswith(".sh"): - # Shell script specified directly + if scripts_arg.endswith(".sh") or scripts_arg.endswith(".slurm"): + # Shell script specified directly (.sh or .slurm for SLURM batch scripts) dir_path = os.path.dirname(scripts_arg) script_name = "bash " + os.path.basename(scripts_arg) elif scripts_arg.endswith(".py"): diff --git a/src/madengine/orchestration/build_orchestrator.py b/src/madengine/orchestration/build_orchestrator.py index 49ee76c2..7be37c6f 100644 --- a/src/madengine/orchestration/build_orchestrator.py +++ b/src/madengine/orchestration/build_orchestrator.py @@ -178,6 +178,8 @@ def execute( clean_cache: bool = False, manifest_output: str = "build_manifest.json", batch_build_metadata: Optional[Dict] = None, + use_image: Optional[str] = None, + build_on_compute: bool = False, ) -> str: """ Execute build workflow. 
@@ -187,6 +189,8 @@ def execute( clean_cache: Whether to use --no-cache for Docker builds manifest_output: Output file for build manifest batch_build_metadata: Optional batch build metadata + use_image: Pre-built Docker image to use (skip Docker build) + build_on_compute: Build on SLURM compute node instead of login node Returns: Path to generated build_manifest.json @@ -195,6 +199,21 @@ def execute( DiscoveryError: If model discovery fails BuildError: If Docker build fails """ + # Handle pre-built image mode + if use_image: + return self._execute_with_prebuilt_image( + use_image=use_image, + manifest_output=manifest_output, + ) + + # Handle build-on-compute mode + if build_on_compute: + return self._execute_build_on_compute( + registry=registry, + clean_cache=clean_cache, + manifest_output=manifest_output, + batch_build_metadata=batch_build_metadata, + ) self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") self.rich_console.print("[bold blue]🔨 BUILD PHASE[/bold blue]") self.rich_console.print("[yellow](Build-only mode - no GPU detection)[/yellow]") @@ -418,3 +437,267 @@ def _save_deployment_config(self, manifest_file: str): # Non-fatal - just warn self.rich_console.print(f"[yellow]Warning: Could not save deployment config: {e}[/yellow]") + def _execute_with_prebuilt_image( + self, + use_image: str, + manifest_output: str = "build_manifest.json", + ) -> str: + """ + Generate manifest for a pre-built Docker image (skip Docker build). + + This is useful when using external images like: + - lmsysorg/sglang:v0.5.2rc1-rocm700-mi30x + - nvcr.io/nvidia/pytorch:24.01-py3 + + Args: + use_image: Pre-built Docker image name + manifest_output: Output file for build manifest + + Returns: + Path to generated build_manifest.json + """ + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + self.rich_console.print("[bold blue]🔨 BUILD PHASE (Pre-built Image Mode)[/bold blue]") + self.rich_console.print(f"[cyan]Using pre-built image: {use_image}[/cyan]") + self.rich_console.print(f"[dim]{'=' * 60}[/dim]\n") + + try: + # Step 1: Discover models + self.rich_console.print("[bold cyan]🔍 Discovering models...[/bold cyan]") + discover_models = DiscoverModels(args=self.args) + models = discover_models.run() + + if not models: + raise DiscoveryError( + "No models discovered", + context=create_error_context( + operation="discover_models", + component="BuildOrchestrator", + ), + suggestions=[ + "Check if models.json exists", + "Verify --tags parameter is correct", + ], + ) + + self.rich_console.print(f"[green]✓ Found {len(models)} models[/green]\n") + + # Step 2: Generate manifest with pre-built image + self.rich_console.print("[bold cyan]📄 Generating manifest for pre-built image...[/bold cyan]") + + manifest = { + "built_images": { + use_image: { + "image_name": use_image, + "dockerfile": "", + "build_time": 0, + "prebuilt": True, + } + }, + "built_models": {}, + "context": self.context.ctx if hasattr(self.context, 'ctx') else {}, + "credentials_required": [], + "summary": { + "successful_builds": [], + "failed_builds": [], + "total_build_time": 0, + "successful_pushes": [], + "failed_pushes": [], + }, + } + + # Add each discovered model with the pre-built image + for model in models: + model_name = model.get("name", "unknown") + manifest["built_models"][model_name] = { + "name": model_name, + "image": use_image, + "dockerfile": model.get("dockerfile", ""), + "scripts": model.get("scripts", ""), + "data": model.get("data", ""), + "n_gpus": model.get("n_gpus", "8"), + "owner": model.get("owner", ""), + 
"training_precision": model.get("training_precision", ""), + "multiple_results": model.get("multiple_results", ""), + "tags": model.get("tags", []), + "timeout": model.get("timeout", -1), + "args": model.get("args", ""), + "slurm": model.get("slurm", {}), + "distributed": model.get("distributed", {}), + "env_vars": model.get("env_vars", {}), + "prebuilt": True, + } + manifest["summary"]["successful_builds"].append(model_name) + + # Save manifest + with open(manifest_output, "w") as f: + json.dump(manifest, f, indent=2) + + # Save deployment config + self._save_deployment_config(manifest_output) + + self.rich_console.print(f"[green]✓ Generated manifest: {manifest_output}[/green]") + self.rich_console.print(f" Pre-built image: {use_image}") + self.rich_console.print(f" Models: {len(models)}") + self.rich_console.print(f"[dim]{'=' * 60}[/dim]\n") + + return manifest_output + + except (DiscoveryError, BuildError): + raise + except Exception as e: + raise BuildError( + f"Failed to generate manifest for pre-built image: {e}", + context=create_error_context( + operation="prebuilt_manifest", + component="BuildOrchestrator", + ), + ) from e + + def _execute_build_on_compute( + self, + registry: Optional[str] = None, + clean_cache: bool = False, + manifest_output: str = "build_manifest.json", + batch_build_metadata: Optional[Dict] = None, + ) -> str: + """ + Execute Docker build on a SLURM compute node instead of login node. + + This submits a SLURM job that runs the Docker build on a compute node, + which is useful when: + - Login node has limited disk space + - Login node shouldn't run heavy workloads + - Compute nodes have faster storage/network + + Args: + registry: Optional registry to push images to + clean_cache: Whether to use --no-cache for Docker builds + manifest_output: Output file for build manifest + batch_build_metadata: Optional batch build metadata + + Returns: + Path to generated build_manifest.json + """ + import subprocess + import os + + self.rich_console.print(f"\n[dim]{'=' * 60}[/dim]") + self.rich_console.print("[bold blue]🔨 BUILD PHASE (Compute Node Mode)[/bold blue]") + self.rich_console.print("[cyan]Building on SLURM compute node...[/cyan]") + self.rich_console.print(f"[dim]{'=' * 60}[/dim]\n") + + # Check if we're inside an existing allocation + inside_allocation = os.environ.get("SLURM_JOB_ID") is not None + existing_job_id = os.environ.get("SLURM_JOB_ID", "") + + # Get SLURM config from additional_context + slurm_config = self.additional_context.get("slurm", {}) + partition = slurm_config.get("partition", "gpu") + reservation = slurm_config.get("reservation", "") + time_limit = slurm_config.get("time", "02:00:00") + + # Build the madengine build command (without --build-on-compute to avoid recursion) + tags = getattr(self.args, 'tags', []) + tags_str = " ".join([f"-t {tag}" for tag in tags]) if tags else "" + + additional_context_str = "" + if self.additional_context: + # Serialize additional context for the compute node + import json + ctx_json = json.dumps(self.additional_context) + additional_context_str = f"--additional-context '{ctx_json}'" + + build_cmd = f"madengine build {tags_str} {additional_context_str} --manifest-output {manifest_output}" + if registry: + build_cmd += f" --registry {registry}" + if clean_cache: + build_cmd += " --clean-docker-cache" + + if inside_allocation: + # Run build on compute node via srun + self.rich_console.print(f"[cyan]Running build via srun (inside allocation {existing_job_id})...[/cyan]") + cmd = ["srun", "-N1", "--ntasks=1", 
"bash", "-c", build_cmd] + else: + # Generate and submit build script + self.rich_console.print("[cyan]Submitting build job via sbatch...[/cyan]") + + build_script_content = f"""#!/bin/bash +#SBATCH --job-name=madengine-build +#SBATCH --partition={partition} +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --time={time_limit} +{f'#SBATCH --reservation={reservation}' if reservation else ''} +#SBATCH --output=madengine_build_%j.out +#SBATCH --error=madengine_build_%j.err + +echo "=== Building on compute node: $(hostname) ===" +echo "Job ID: $SLURM_JOB_ID" +echo "Build command: {build_cmd}" +echo "" + +# Activate virtual environment if available +if [ -f "venv/bin/activate" ]; then + source venv/bin/activate +fi + +# Run the build +{build_cmd} + +echo "" +echo "=== Build completed ===" +""" + build_script_path = Path("madengine_build_job.sh") + build_script_path.write_text(build_script_content) + build_script_path.chmod(0o755) + + self.rich_console.print(f" Build script: {build_script_path}") + cmd = ["sbatch", "--wait", str(build_script_path)] + + # Execute the build + self.rich_console.print(f" Command: {' '.join(cmd)}") + self.rich_console.print("") + + try: + result = subprocess.run( + cmd, + capture_output=False, # Let output flow to console + text=True, + ) + + if result.returncode != 0: + raise BuildError( + f"Build on compute node failed with exit code {result.returncode}", + context=create_error_context( + operation="build_on_compute", + component="BuildOrchestrator", + ), + suggestions=[ + "Check the build log files (madengine_build_*.out/err)", + "Verify SLURM partition and reservation settings", + "Ensure Docker is available on compute nodes", + ], + ) + + self.rich_console.print(f"[green]✓ Build completed on compute node[/green]") + self.rich_console.print(f"[green]✓ Manifest: {manifest_output}[/green]") + return manifest_output + + except subprocess.TimeoutExpired: + raise BuildError( + "Build on compute node timed out", + context=create_error_context( + operation="build_on_compute", + component="BuildOrchestrator", + ), + ) + except Exception as e: + raise BuildError( + f"Failed to build on compute node: {e}", + context=create_error_context( + operation="build_on_compute", + component="BuildOrchestrator", + ), + ) from e +