From c94c55118b2662d9b01798d413f65083c43ef06a Mon Sep 17 00:00:00 2001
From: Stephen Shao
Date: Mon, 23 Jun 2025 22:49:06 -0400
Subject: [PATCH 01/17] Implement SSH runner for multi-node usage

---
 runners/ssh/README.md          | 199 ++++++++++++
 runners/ssh/config.ini.example |  37 +++
 runners/ssh/example.sh         |  39 +++
 runners/ssh/quick-start.sh     |  43 +++
 runners/ssh/run.py             | 575 +++++++++++++++++++++++++++++
 runners/ssh/test_runner.py     | 270 ++++++++++++++++
 6 files changed, 1163 insertions(+)
 create mode 100644 runners/ssh/README.md
 create mode 100644 runners/ssh/config.ini.example
 create mode 100644 runners/ssh/example.sh
 create mode 100644 runners/ssh/quick-start.sh
 create mode 100644 runners/ssh/run.py
 create mode 100644 runners/ssh/test_runner.py

diff --git a/runners/ssh/README.md b/runners/ssh/README.md
new file mode 100644
index 00000000..90d83d6b
--- /dev/null
+++ b/runners/ssh/README.md
@@ -0,0 +1,199 @@
+# SSH Multi-Node Runner for MAD Engine
+
+This runner automates the execution of PyTorch Megatron-LM training across multiple nodes over SSH.
+
+## Features
+
+- ✅ Automated SSH connection management
+- ✅ Parallel execution across multiple nodes
+- ✅ Real-time output streaming from all nodes
+- ✅ Robust error handling and connectivity checking
+- ✅ Support for both SSH key and password authentication
+- ✅ Configurable network interfaces (NCCL/GLOO)
+- ✅ Shared filesystem support
+
+## Prerequisites
+
+1. **Python Dependencies**:
+   ```bash
+   pip install -r requirements.txt
+   ```
+
+   Or use the quick-start script:
+   ```bash
+   bash quick-start.sh
+   ```
+
+2. **SSH Access**: Ensure you have SSH access to all target nodes with either:
+   - SSH key-based authentication (recommended)
+   - Password-based authentication
+
+3. **Shared Filesystem**: All nodes should have access to a shared filesystem for data (e.g., an NFS mount); a quick check is sketched below
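+
+   A pre-flight check of the shared mount (a minimal sketch; the node list and mount point are example values, not something this runner requires):
+   ```bash
+   # Confirm the shared path is a real mount and is writable on every node
+   for node in 192.168.1.1 192.168.1.2; do
+       ssh "$node" "mountpoint -q /nfs/data && touch /nfs/data/.write_test" \
+           && echo "$node: OK" \
+           || echo "$node: shared filesystem missing or not writable"
+   done
+   ```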
+4. **MAD Engine**: Ensure `madengine` is installed and accessible on all target nodes
+
+## Usage
+
+### Basic Usage with SSH Key
+
+```bash
+python run.py --model pyt_megatron_lm_train_llama2_7b \
+    --nodes 192.168.1.1,192.168.1.2 \
+    --master-addr 192.168.0.1 \
+    --ssh-user ubuntu \
+    --ssh-key ~/.ssh/id_rsa \
+    --shared-data-path /nfs/data
+```
+
+### Usage with Password Authentication
+
+```bash
+python run.py --model pyt_megatron_lm_train_llama2_7b \
+    --nodes node1.cluster.com,node2.cluster.com \
+    --ssh-user root \
+    --ssh-password mypassword \
+    --shared-data-path /shared/data
+```
+
+### Advanced Configuration
+
+```bash
+python run.py --model pyt_megatron_lm_train_llama2_7b \
+    --nodes 192.168.1.10,192.168.1.11,192.168.1.12 \
+    --master-addr 192.168.1.10 \
+    --master-port 5000 \
+    --ssh-user mluser \
+    --ssh-key /home/user/.ssh/cluster_key \
+    --shared-data-path /mnt/nfs/datasets \
+    --nccl-interface eth0 \
+    --gloo-interface eth0 \
+    --timeout 7200 \
+    --additional-args "--some-extra-flag"
+```
+
+## Command Line Arguments
+
+### Required Arguments
+
+- `--model`: Model tag to run (e.g., `pyt_megatron_lm_train_llama2_7b`)
+- `--nodes`: Comma-separated list of node hostnames/IPs
+- `--ssh-user`: SSH username for all nodes
+
+### Authentication (one required)
+
+- `--ssh-password`: SSH password for all nodes
+- `--ssh-key`: Path to SSH private key file
+
+### Optional Arguments
+
+- `--master-addr`: Master node address (defaults to the first node)
+- `--master-port`: Master node port (default: 4000)
+- `--shared-data-path`: Path to the shared data filesystem (default: /nfs/data)
+- `--nccl-interface`: NCCL socket interface (default: ens14np0)
+- `--gloo-interface`: GLOO socket interface (default: ens14np0)
+- `--timeout`: Execution timeout in seconds (default: 3600)
+- `--madengine-path`: Path to the madengine executable (default: madengine)
+- `--additional-args`: Additional arguments to pass to madengine
+
+## How It Works
+
+1. **Connectivity Check**: Verifies SSH connectivity to all nodes
+2. **Command Generation**: Builds the appropriate `madengine` command for each node with the correct `NODE_RANK`
+3. **Parallel Execution**: Executes commands on all nodes simultaneously using threading
+4. **Output Streaming**: Streams real-time output from all nodes with node identification
+5. **Result Aggregation**: Collects and reports results from all nodes
+
+## Example Output
+
+```
+🌐 Starting multi-node training on 2 nodes
+📋 Model: pyt_megatron_lm_train_llama2_7b
+🏠 Master: 192.168.0.1:4000
+📁 Shared data: /nfs/data
+🔗 Nodes: 192.168.1.1, 192.168.1.2
+
+🔍 Checking connectivity to all nodes...
+✓ 192.168.1.1 is reachable
+✓ 192.168.1.2 is reachable
+✅ All nodes are reachable
+
+🚀 Executing on 192.168.1.1 (rank 0): madengine run --tags pyt_megatron_lm_train_llama2_7b ...
+🚀 Executing on 192.168.1.2 (rank 1): madengine run --tags pyt_megatron_lm_train_llama2_7b ...
+
+[192.168.1.1:0] Starting training...
+[192.168.1.2:1] Starting training...
+...
+✅ 192.168.1.1 completed successfully
+✅ 192.168.1.2 completed successfully
+
+📊 Training Results:
+✅ Successful nodes: 2/2
+🎉 Multi-node training completed successfully!
+```
+
+## Network Configuration
+
+For optimal performance, ensure:
+
+1. **Network Interface**: Use the correct network interface names for `--nccl-interface` and `--gloo-interface`
+   ```bash
+   # Check available interfaces on your nodes
+   ssh user@node "ip addr show"
+   ```
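+
+   If several interfaces are up, you can confirm which one actually holds the route to the master address before picking a name (a minimal sketch; the address is an example value):
+   ```bash
+   # The "dev" field of the output names the interface NCCL/GLOO should bind to
+   ssh user@node "ip route get 192.168.0.1"
+   ```
+
+2.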
**Firewall**: Ensure the master port is open between nodes + ```bash + # Example: Open port 4000 on Ubuntu/Debian + sudo ufw allow 4000 + ``` + +3. **Shared Storage**: Verify shared filesystem is mounted on all nodes + ```bash + # Check if NFS mount is available + ssh user@node "ls -la /nfs/data" + ``` + +## Troubleshooting + +### SSH Connection Issues + +- Verify SSH key permissions: `chmod 600 ~/.ssh/id_rsa` +- Test manual SSH connection: `ssh -i ~/.ssh/id_rsa user@node` +- Check SSH agent: `ssh-add ~/.ssh/id_rsa` + +### Network Communication Issues + +- Verify nodes can reach each other on the master port +- Check firewall settings +- Ensure correct network interface names + +### MAD Engine Issues + +- Verify madengine is installed on all nodes: `ssh user@node "which madengine"` +- Check shared data path exists: `ssh user@node "ls -la /nfs/data"` +- Review madengine logs for specific errors + +## Integration with MAD Engine + +This SSH runner integrates seamlessly with the MAD Engine multi-node framework: + +- Automatically configures `multi_node_args` for each node +- Sets appropriate `NODE_RANK` for each node (0, 1, 2, ...) +- Configures `NNODES` based on the number of nodes provided +- Uses `torchrun` as the distributed runner +- Handles network interface configuration for NCCL and GLOO + +The generated command for each node follows this pattern: + +```bash +madengine run --tags pyt_megatron_lm_train_llama2_7b \ + --additional-context "{'multi_node_args': { + 'RUNNER': 'torchrun', 'MASTER_ADDR': '192.168.0.1', + 'MASTER_PORT': '4000', + 'NNODES': '2', + 'NODE_RANK': '0', # Different for each node + 'NCCL_SOCKET_IFNAME': 'ens14np0', + 'GLOO_SOCKET_IFNAME': 'ens14np0' + }}" \ + --force-mirror-local /nfs/data +``` diff --git a/runners/ssh/config.ini.example b/runners/ssh/config.ini.example new file mode 100644 index 00000000..7d589566 --- /dev/null +++ b/runners/ssh/config.ini.example @@ -0,0 +1,37 @@ +# Configuration for SSH multi-node runner + +[cluster] +# Comma-separated list of node hostnames or IPs +nodes = 192.168.1.1,192.168.1.2 + +# Master node configuration +master_addr = 192.168.0.1 +master_port = 4000 + +[ssh] +# SSH authentication +user = username # Replace with your SSH username +# Use either key_file OR password (key_file is recommended) +key_file = ~/.ssh/id_rsa +# password = your_password_here + +[training] +# Model to train +model = pyt_megatron_lm_train_llama2_7b + +# Shared filesystem path where data is located +shared_data_path = /nfs/data + +# Network interfaces for distributed communication +nccl_interface = ens14np0 +gloo_interface = ens14np0 + +# Execution timeout in seconds (2 hours) +timeout = 7200 + +[madengine] +# Path to madengine executable (if not in PATH) +# madengine_path = /opt/madengine/bin/madengine + +# Additional arguments to pass to madengine +# additional_args = --live-output diff --git a/runners/ssh/example.sh b/runners/ssh/example.sh new file mode 100644 index 00000000..51f130f5 --- /dev/null +++ b/runners/ssh/example.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# Example script showing how to use the SSH multi-node runner + +# Configuration +MODEL="pyt_megatron_lm_train_llama2_7b" +NODES="192.168.1.1,192.168.1.2" +MASTER_ADDR="192.168.0.1" +MASTER_PORT="4000" +SSH_USER="username" # Replace with your SSH username +SSH_KEY="~/.ssh/id_rsa" +SHARED_DATA="/nfs/data" +NCCL_INTERFACE="ens14np0" +GLOO_INTERFACE="ens14np0" + +echo "🚀 Starting multi-node training with SSH runner" +echo "📋 Model: $MODEL" +echo "🔗 Nodes: $NODES" +echo "🏠 Master: 
$MASTER_ADDR:$MASTER_PORT" + +# Install requirements if not already installed +if ! python -c "import paramiko" 2>/dev/null; then + echo "📦 Installing required packages..." + pip install -r requirements.txt +fi + +# Run the SSH multi-node runner +python run.py \ + --model "$MODEL" \ + --nodes "$NODES" \ + --master-addr "$MASTER_ADDR" \ + --master-port "$MASTER_PORT" \ + --ssh-user "$SSH_USER" \ + --ssh-key "$SSH_KEY" \ + --shared-data-path "$SHARED_DATA" \ + --nccl-interface "$NCCL_INTERFACE" \ + --gloo-interface "$GLOO_INTERFACE" \ + --timeout 7200 + +echo "✅ Multi-node training completed!" diff --git a/runners/ssh/quick-start.sh b/runners/ssh/quick-start.sh new file mode 100644 index 00000000..ae012161 --- /dev/null +++ b/runners/ssh/quick-start.sh @@ -0,0 +1,43 @@ +#!/bin/bash +# Quick start script for SSH Multi-Node Runner + +set -e + +echo "🚀 SSH Multi-Node Runner for MAD Engine" +echo "========================================" +echo "" + +# Check if Python is available +if ! command -v python &> /dev/null; then + echo "❌ Python is not installed or not in PATH" + exit 1 +fi + +echo "✅ Python is available" + +# Check if paramiko is installed +if ! python -c "import paramiko" 2>/dev/null; then + echo "📦 Installing paramiko..." + pip install paramiko +else + echo "✅ paramiko is already installed" +fi + +echo "" +echo "🎯 Quick Start Examples:" +echo "" +echo "1. SSH Key Authentication:" +echo " python run.py --model pyt_megatron_lm_train_llama2_7b \\" +echo " --nodes 192.168.1.1,192.168.1.2 \\" +echo " --master-addr 192.168.0.1 \\" +echo " --ssh-user ubuntu \\" +echo " --ssh-key ~/.ssh/id_rsa" +echo "" +echo "2. Configuration File:" +echo " python run.py --config config.ini" +echo "" +echo "3. Run Tests:" +echo " python test_runner.py" +echo "" +echo "📖 For detailed documentation, see README.md" +echo "✨ Ready to run multi-node training!" diff --git a/runners/ssh/run.py b/runners/ssh/run.py new file mode 100644 index 00000000..73cb4421 --- /dev/null +++ b/runners/ssh/run.py @@ -0,0 +1,575 @@ +#!/usr/bin/env python3 +"""SSH Multi-Node Runner for MAD Engine + +This script orchestrates distributed training across multiple nodes using SSH. +It automatically configures and executes madengine commands on remote nodes +for PyTorch Megatron-LM training workloads. + +Example Usage: + python run.py --model pyt_megatron_lm_train_llama2_7b \ + --nodes 192.168.1.1,192.168.1.2 \ + --master-addr 192.168.0.1 \ + --master-port 4000 \ + --ssh-user username \ + --ssh-key /path/to/ssh/key \ + --shared-data-path /nfs/data + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" + +import argparse +import json +import os +import sys +import time +import threading +import socket +import configparser +from typing import List, Dict, Optional, Tuple +from concurrent.futures import ThreadPoolExecutor, as_completed + +# Third-party imports +try: + import paramiko +except ImportError: + print("Error: paramiko is required but not installed.") + print("Please install it with: pip install paramiko") + sys.exit(1) + +# Add madengine to path if needed +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src')) + +from madengine.utils.ssh_to_db import print_ssh_out + + +class SSHMultiNodeRunner: + """SSH-based multi-node runner for distributed training.""" + + def __init__(self, args: argparse.Namespace): + """Initialize the SSH multi-node runner. 
+ + Args: + args: Command line arguments containing configuration + """ + self.args = args + self.nodes = [node.strip() for node in args.nodes.split(',')] + self.master_addr = args.master_addr or self.nodes[0] + self.master_port = str(args.master_port) + self.ssh_user = args.ssh_user + self.ssh_password = getattr(args, 'ssh_password', None) + self.ssh_key = getattr(args, 'ssh_key', None) + self.model_tag = args.model + self.shared_data_path = getattr(args, 'shared_data_path', '/nfs/data') + self.nccl_interface = getattr(args, 'nccl_interface', 'ens14np0') + self.gloo_interface = getattr(args, 'gloo_interface', 'ens14np0') + self.timeout = getattr(args, 'timeout', 3600) # 1 hour default + self.madengine_path = getattr(args, 'madengine_path', 'madengine') + self.additional_args = getattr(args, 'additional_args', '') + + # Validate configuration + self._validate_config() + + def _validate_config(self) -> None: + """Validate the configuration parameters.""" + if not self.nodes: + raise ValueError("At least one node must be specified") + + if not self.ssh_user: + raise ValueError("SSH username must be specified") + + if not self.ssh_password and not self.ssh_key: + raise ValueError("Either SSH password or SSH key must be specified") + + if not self.model_tag: + raise ValueError("Model tag must be specified") + + # Validate SSH key file exists if specified + if self.ssh_key and not os.path.exists(self.ssh_key): + raise FileNotFoundError(f"SSH key file not found: {self.ssh_key}") + + def _create_ssh_client(self) -> paramiko.SSHClient: + """Create and configure SSH client. + + Returns: + Configured SSH client + """ + ssh_client = paramiko.SSHClient() + ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh_client.load_system_host_keys() + return ssh_client + + def _connect_ssh(self, hostname: str) -> paramiko.SSHClient: + """Connect to a remote host via SSH. + + Args: + hostname: The hostname or IP to connect to + + Returns: + Connected SSH client + + Raises: + Exception: If connection fails + """ + ssh_client = self._create_ssh_client() + + try: + if self.ssh_key: + ssh_client.connect( + hostname=hostname, + username=self.ssh_user, + key_filename=self.ssh_key, + timeout=30 + ) + else: + ssh_client.connect( + hostname=hostname, + username=self.ssh_user, + password=self.ssh_password, + timeout=30 + ) + + print(f"✓ Successfully connected to {hostname}") + return ssh_client + + except paramiko.ssh_exception.AuthenticationException as e: + raise Exception(f"Authentication failed for {hostname}: {e}") + except paramiko.ssh_exception.SSHException as e: + raise Exception(f"SSH error for {hostname}: {e}") + except socket.error as e: + raise Exception(f"Socket error for {hostname}: {e}") + + def _build_madengine_command(self, node_rank: int) -> str: + """Build the madengine command for a specific node. 
+ + Args: + node_rank: The rank of this node (0-based) + + Returns: + Complete madengine command string + """ + multi_node_args = { + 'RUNNER': 'torchrun', + 'MASTER_ADDR': self.master_addr, + 'MASTER_PORT': self.master_port, + 'NNODES': str(len(self.nodes)), + 'NODE_RANK': str(node_rank), + 'NCCL_SOCKET_IFNAME': self.nccl_interface, + 'GLOO_SOCKET_IFNAME': self.gloo_interface + } + + # Build the additional context string + additional_context = f"'{{'multi_node_args': {json.dumps(multi_node_args)}}}'" + + # Build the complete command + cmd_parts = [ + self.madengine_path, + 'run', + '--tags', self.model_tag, + '--additional-context', additional_context + ] + + # Add shared data path if specified + if self.shared_data_path: + cmd_parts.extend(['--force-mirror-local', self.shared_data_path]) + + # Add any additional arguments + if self.additional_args: + cmd_parts.append(self.additional_args) + + return ' '.join(cmd_parts) + + def _execute_on_node(self, hostname: str, node_rank: int) -> Tuple[str, bool, str]: + """Execute madengine command on a single node. + + Args: + hostname: The hostname/IP of the node + node_rank: The rank of this node + + Returns: + Tuple of (hostname, success, output/error) + """ + try: + ssh_client = self._connect_ssh(hostname) + + # Build and execute the madengine command + command = self._build_madengine_command(node_rank) + print(f"🚀 Executing on {hostname} (rank {node_rank}): {command}") + + # Execute the command + stdin, stdout, stderr = ssh_client.exec_command( + command, + timeout=self.timeout + ) + + # Read output in real-time + output_lines = [] + error_lines = [] + + # Read stdout + for line in stdout: + line = line.strip() + if line: + print(f"[{hostname}:{node_rank}] {line}") + output_lines.append(line) + + # Read stderr + for line in stderr: + line = line.strip() + if line: + print(f"[{hostname}:{node_rank}] ERROR: {line}") + error_lines.append(line) + + # Get exit code + exit_code = stdout.channel.recv_exit_status() + + ssh_client.close() + + if exit_code == 0: + return hostname, True, '\n'.join(output_lines) + else: + return hostname, False, '\n'.join(error_lines) + + except Exception as e: + return hostname, False, str(e) + + def _check_node_connectivity(self) -> List[str]: + """Check connectivity to all nodes. + + Returns: + List of nodes that are reachable + """ + reachable_nodes = [] + + print("🔍 Checking connectivity to all nodes...") + + for hostname in self.nodes: + try: + ssh_client = self._connect_ssh(hostname) + + # Test basic command execution + stdin, stdout, stderr = ssh_client.exec_command('echo "connectivity_test"') + output = stdout.read().decode().strip() + + if output == "connectivity_test": + reachable_nodes.append(hostname) + print(f"✓ {hostname} is reachable") + else: + print(f"✗ {hostname} failed connectivity test") + + ssh_client.close() + + except Exception as e: + print(f"✗ {hostname} is not reachable: {e}") + + return reachable_nodes + + def _wait_for_master_ready(self, master_host: str) -> bool: + """Wait for master node to be ready to accept connections. 
+ + Args: + master_host: The master node hostname/IP + + Returns: + True if master is ready, False if timeout + """ + print(f"⏳ Waiting for master node {master_host}:{self.master_port} to be ready...") + + max_wait_time = 60 # seconds + start_time = time.time() + + while time.time() - start_time < max_wait_time: + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(5) + result = sock.connect_ex((master_host, int(self.master_port))) + sock.close() + + if result == 0: + print(f"✓ Master node is ready") + return True + + except Exception: + pass + + time.sleep(2) + + print(f"✗ Master node did not become ready within {max_wait_time} seconds") + return False + + def run(self) -> bool: + """Run distributed training across all nodes. + + Returns: + True if all nodes completed successfully, False otherwise + """ + print(f"🌐 Starting multi-node training on {len(self.nodes)} nodes") + print(f"📋 Model: {self.model_tag}") + print(f"🏠 Master: {self.master_addr}:{self.master_port}") + print(f"📁 Shared data: {self.shared_data_path}") + print(f"🔗 Nodes: {', '.join(self.nodes)}") + + # Check connectivity to all nodes + reachable_nodes = self._check_node_connectivity() + + if len(reachable_nodes) != len(self.nodes): + unreachable = set(self.nodes) - set(reachable_nodes) + print(f"❌ Some nodes are unreachable: {', '.join(unreachable)}") + return False + + print("✅ All nodes are reachable") + + # Execute on all nodes concurrently + results = [] + + with ThreadPoolExecutor(max_workers=len(self.nodes)) as executor: + # Submit jobs for all nodes + futures = [] + for i, hostname in enumerate(self.nodes): + future = executor.submit(self._execute_on_node, hostname, i) + futures.append(future) + + # Collect results as they complete + for future in as_completed(futures): + hostname, success, output = future.result() + results.append((hostname, success, output)) + + if success: + print(f"✅ {hostname} completed successfully") + else: + print(f"❌ {hostname} failed: {output}") + + # Check overall success + successful_nodes = [r[0] for r in results if r[1]] + failed_nodes = [r[0] for r in results if not r[1]] + + print(f"\n📊 Training Results:") + print(f"✅ Successful nodes: {len(successful_nodes)}/{len(self.nodes)}") + + if failed_nodes: + print(f"❌ Failed nodes: {', '.join(failed_nodes)}") + return False + + print("🎉 Multi-node training completed successfully!") + return True + + +def load_config_file(config_path: str) -> Dict: + """Load configuration from INI file. + + Args: + config_path: Path to configuration file + + Returns: + Dictionary with configuration values + """ + if not os.path.exists(config_path): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + + config = configparser.ConfigParser() + config.read(config_path) + + # Convert to dictionary with flattened keys + config_dict = {} + for section in config.sections(): + for key, value in config[section].items(): + config_dict[f"{section}_{key}"] = value + + return config_dict + + +def merge_config_and_args(config_dict: Dict, args: argparse.Namespace) -> argparse.Namespace: + """Merge configuration file values with command line arguments. + Command line arguments take precedence over config file values. 
+ + Args: + config_dict: Configuration dictionary from file + args: Command line arguments + + Returns: + Updated arguments with config file values applied + """ + # Mapping of config keys to argument attributes + config_to_arg_map = { + 'cluster_nodes': 'nodes', + 'cluster_master_addr': 'master_addr', + 'cluster_master_port': 'master_port', + 'ssh_user': 'ssh_user', + 'ssh_key_file': 'ssh_key', + 'ssh_password': 'ssh_password', + 'training_model': 'model', + 'training_shared_data_path': 'shared_data_path', + 'training_nccl_interface': 'nccl_interface', + 'training_gloo_interface': 'gloo_interface', + 'training_timeout': 'timeout', + 'madengine_madengine_path': 'madengine_path', + 'madengine_additional_args': 'additional_args' + } + + # Apply config values only if argument was not provided + for config_key, arg_attr in config_to_arg_map.items(): + if config_key in config_dict and not getattr(args, arg_attr, None): + value = config_dict[config_key] + # Convert numeric values + if arg_attr in ['master_port', 'timeout']: + value = int(value) + setattr(args, arg_attr, value) + + return args + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments. + + Returns: + Parsed arguments + """ + parser = argparse.ArgumentParser( + description="SSH Multi-Node Runner for MAD Engine", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Run with SSH key authentication + python run.py --model pyt_megatron_lm_train_llama2_7b \\ + --nodes 192.168.1.1,192.168.1.2 \\ + --master-addr 192.168.0.1 \\ + --ssh-user ubuntu \\ + --ssh-key ~/.ssh/id_rsa + + # Run with password authentication + python run.py --model pyt_megatron_lm_train_llama2_7b \\ + --nodes node1,node2,node3 \\ + --ssh-user root \\ + --ssh-password mypassword \\ + --shared-data-path /shared/data + + # Run with configuration file + python run.py --config config.ini + """ + ) + + # Configuration file option + parser.add_argument( + '--config', + help='Path to configuration file (INI format)' + ) + + # Required arguments (unless provided in config) + parser.add_argument( + '--model', + help='Model tag to run (e.g., pyt_megatron_lm_train_llama2_7b)' + ) + + parser.add_argument( + '--nodes', + help='Comma-separated list of node hostnames/IPs' + ) + + parser.add_argument( + '--ssh-user', + help='SSH username for all nodes' + ) + + # SSH authentication (one required unless in config) + parser.add_argument( + '--ssh-password', + help='SSH password for all nodes' + ) + parser.add_argument( + '--ssh-key', + help='Path to SSH private key file' + ) + + # Optional arguments + parser.add_argument( + '--master-addr', + help='Master node address (defaults to first node)' + ) + + parser.add_argument( + '--master-port', + type=int, + default=4000, + help='Master node port (default: 4000)' + ) + + parser.add_argument( + '--shared-data-path', + default='/nfs/data', + help='Path to shared data filesystem (default: /nfs/data)' + ) + + parser.add_argument( + '--nccl-interface', + default='ens14np0', + help='NCCL socket interface (default: ens14np0)' + ) + + parser.add_argument( + '--gloo-interface', + default='ens14np0', + help='GLOO socket interface (default: ens14np0)' + ) + + parser.add_argument( + '--timeout', + type=int, + default=3600, + help='Execution timeout in seconds (default: 3600)' + ) + + parser.add_argument( + '--madengine-path', + default='madengine', + help='Path to madengine executable (default: madengine)' + ) + + parser.add_argument( + '--additional-args', + help='Additional arguments to pass to 
madengine' + ) + + # Parse arguments + args = parser.parse_args() + + # Load configuration file if provided + if args.config: + config_dict = load_config_file(args.config) + args = merge_config_and_args(config_dict, args) + + # Validate required arguments after config loading + if not args.model: + parser.error("--model is required (can be provided via config file)") + if not args.nodes: + parser.error("--nodes is required (can be provided via config file)") + if not args.ssh_user: + parser.error("--ssh-user is required (can be provided via config file)") + if not args.ssh_password and not args.ssh_key: + parser.error("Either --ssh-password or --ssh-key is required (can be provided via config file)") + + return args + + +def main(): + """Main entry point.""" + try: + args = parse_args() + runner = SSHMultiNodeRunner(args) + success = runner.run() + + if success: + print("🎯 All nodes completed successfully!") + sys.exit(0) + else: + print("💥 Some nodes failed!") + sys.exit(1) + + except KeyboardInterrupt: + print("\n🛑 Interrupted by user") + sys.exit(1) + except Exception as e: + print(f"💥 Error: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/runners/ssh/test_runner.py b/runners/ssh/test_runner.py new file mode 100644 index 00000000..9822f9dc --- /dev/null +++ b/runners/ssh/test_runner.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python3 +"""Test script for SSH Multi-Node Runner + +This script provides basic unit tests and validation for the SSH runner. +""" + +import os +import sys +import tempfile +import unittest +from unittest.mock import patch, MagicMock + +# Add the current directory to the Python path +sys.path.insert(0, os.path.dirname(__file__)) + +try: + from run import SSHMultiNodeRunner, load_config_file, merge_config_and_args, parse_args +except ImportError: + print("Error: Could not import run module. 
Make sure run.py is in the same directory.") + sys.exit(1) + + +class TestSSHMultiNodeRunner(unittest.TestCase): + """Test cases for SSH Multi-Node Runner.""" + + def setUp(self): + """Set up test fixtures.""" + self.mock_args = MagicMock() + self.mock_args.model = 'pyt_megatron_lm_train_llama2_7b' + self.mock_args.nodes = '10.0.0.1,10.0.0.2' + self.mock_args.master_addr = '10.0.0.1' + self.mock_args.master_port = 4000 + self.mock_args.ssh_user = 'testuser' + self.mock_args.ssh_password = 'testpass' + self.mock_args.ssh_key = None + self.mock_args.shared_data_path = '/nfs/data' + self.mock_args.nccl_interface = 'eth0' + self.mock_args.gloo_interface = 'eth0' + self.mock_args.timeout = 3600 + self.mock_args.madengine_path = 'madengine' + self.mock_args.additional_args = '' + + def test_initialization(self): + """Test runner initialization.""" + runner = SSHMultiNodeRunner(self.mock_args) + + self.assertEqual(runner.nodes, ['10.0.0.1', '10.0.0.2']) + self.assertEqual(runner.master_addr, '10.0.0.1') + self.assertEqual(runner.master_port, '4000') + self.assertEqual(runner.ssh_user, 'testuser') + self.assertEqual(runner.model_tag, 'pyt_megatron_lm_train_llama2_7b') + + def test_command_generation(self): + """Test madengine command generation.""" + runner = SSHMultiNodeRunner(self.mock_args) + + # Test command for node rank 0 + cmd_0 = runner._build_madengine_command(0) + self.assertIn('madengine run', cmd_0) + self.assertIn('pyt_megatron_lm_train_llama2_7b', cmd_0) + self.assertIn('"NODE_RANK": "0"', cmd_0) + self.assertIn('"NNODES": "2"', cmd_0) + self.assertIn('--force-mirror-local /nfs/data', cmd_0) + + # Test command for node rank 1 + cmd_1 = runner._build_madengine_command(1) + self.assertIn('"NODE_RANK": "1"', cmd_1) + + def test_validation_errors(self): + """Test configuration validation errors.""" + # Test missing nodes + self.mock_args.nodes = '' + with self.assertRaises(ValueError): + SSHMultiNodeRunner(self.mock_args) + + # Reset nodes + self.mock_args.nodes = '10.0.0.1,10.0.0.2' + + # Test missing SSH user + self.mock_args.ssh_user = '' + with self.assertRaises(ValueError): + SSHMultiNodeRunner(self.mock_args) + + # Reset SSH user + self.mock_args.ssh_user = 'testuser' + + # Test missing authentication + self.mock_args.ssh_password = None + self.mock_args.ssh_key = None + with self.assertRaises(ValueError): + SSHMultiNodeRunner(self.mock_args) + + def test_config_file_loading(self): + """Test configuration file loading.""" + # Create a temporary config file + config_content = """ +[cluster] +nodes = node1,node2,node3 +master_addr = node1 +master_port = 5000 + +[ssh] +user = testuser +key_file = /path/to/key + +[training] +model = test_model +shared_data_path = /shared/data +""" + + with tempfile.NamedTemporaryFile(mode='w', suffix='.ini', delete=False) as f: + f.write(config_content) + temp_config_path = f.name + + try: + config_dict = load_config_file(temp_config_path) + + self.assertEqual(config_dict['cluster_nodes'], 'node1,node2,node3') + self.assertEqual(config_dict['cluster_master_addr'], 'node1') + self.assertEqual(config_dict['cluster_master_port'], '5000') + self.assertEqual(config_dict['ssh_user'], 'testuser') + self.assertEqual(config_dict['ssh_key_file'], '/path/to/key') + self.assertEqual(config_dict['training_model'], 'test_model') + + finally: + os.unlink(temp_config_path) + + def test_config_file_not_found(self): + """Test error handling for missing config file.""" + with self.assertRaises(FileNotFoundError): + load_config_file('/nonexistent/config.ini') + + +class 
MockSSHIntegrationTest(unittest.TestCase): + """Integration tests with mocked SSH connections.""" + + def setUp(self): + """Set up test fixtures.""" + self.mock_args = MagicMock() + self.mock_args.model = 'pyt_megatron_lm_train_llama2_7b' + self.mock_args.nodes = '10.0.0.1,10.0.0.2' + self.mock_args.master_addr = '10.0.0.1' + self.mock_args.master_port = 4000 + self.mock_args.ssh_user = 'testuser' + self.mock_args.ssh_password = 'testpass' + self.mock_args.ssh_key = None + self.mock_args.shared_data_path = '/nfs/data' + self.mock_args.nccl_interface = 'eth0' + self.mock_args.gloo_interface = 'eth0' + self.mock_args.timeout = 3600 + self.mock_args.madengine_path = 'madengine' + self.mock_args.additional_args = '' + + @patch('run.paramiko.SSHClient') + def test_connectivity_check(self, mock_ssh_client_class): + """Test node connectivity checking.""" + # Mock SSH client + mock_ssh_client = MagicMock() + mock_ssh_client_class.return_value = mock_ssh_client + + # Mock successful connection and command execution + mock_stdout = MagicMock() + mock_stdout.read.return_value = b'connectivity_test' + mock_ssh_client.exec_command.return_value = (None, mock_stdout, None) + + runner = SSHMultiNodeRunner(self.mock_args) + reachable_nodes = runner._check_node_connectivity() + + self.assertEqual(len(reachable_nodes), 2) + self.assertIn('10.0.0.1', reachable_nodes) + self.assertIn('10.0.0.2', reachable_nodes) + + @patch('run.paramiko.SSHClient') + def test_command_execution(self, mock_ssh_client_class): + """Test command execution on nodes.""" + # Mock SSH client + mock_ssh_client = MagicMock() + mock_ssh_client_class.return_value = mock_ssh_client + + # Mock successful command execution + mock_stdout = MagicMock() + mock_stdout.__iter__ = lambda self: iter(['Training started...', 'Training completed!']) + mock_stdout.channel.recv_exit_status.return_value = 0 + + mock_stderr = MagicMock() + mock_stderr.__iter__ = lambda self: iter([]) + + mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr) + + runner = SSHMultiNodeRunner(self.mock_args) + hostname, success, output = runner._execute_on_node('10.0.0.1', 0) + + self.assertEqual(hostname, '10.0.0.1') + self.assertTrue(success) + self.assertIn('Training started...', output) + + +def run_tests(): + """Run all tests.""" + print("🧪 Running SSH Multi-Node Runner Tests...") + + # Create test suite + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + # Add test cases + suite.addTests(loader.loadTestsFromTestCase(TestSSHMultiNodeRunner)) + suite.addTests(loader.loadTestsFromTestCase(MockSSHIntegrationTest)) + + # Run tests + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + + # Return success status + return result.wasSuccessful() + + +def validate_environment(): + """Validate that the environment is set up correctly.""" + print("🔍 Validating environment...") + + issues = [] + + # Check if paramiko is available + try: + import paramiko + print("✅ paramiko is available") + except ImportError: + issues.append("❌ paramiko is not installed. 
Run: pip install paramiko")
+
+    # Check if run.py exists
+    run_py_path = os.path.join(os.path.dirname(__file__), 'run.py')
+    if os.path.exists(run_py_path):
+        print("✅ run.py found")
+    else:
+        issues.append("❌ run.py not found in current directory")
+
+    # Check if requirements.txt exists
+    req_path = os.path.join(os.path.dirname(__file__), 'requirements.txt')
+    if os.path.exists(req_path):
+        print("✅ requirements.txt found")
+    else:
+        issues.append("❌ requirements.txt not found")
+
+    if issues:
+        print("\n🚨 Issues found:")
+        for issue in issues:
+            print(f"   {issue}")
+        return False
+    else:
+        print("\n✅ Environment validation passed!")
+        return True
+
+
+if __name__ == "__main__":
+    print("SSH Multi-Node Runner Test Suite")
+    print("=" * 40)
+
+    # Validate environment first
+    if not validate_environment():
+        print("\n💥 Environment validation failed. Please fix the issues above.")
+        sys.exit(1)
+
+    # Run tests
+    if run_tests():
+        print("\n🎉 All tests passed!")
+        sys.exit(0)
+    else:
+        print("\n💥 Some tests failed!")
+        sys.exit(1)

From 7e723bbd99ad643dc78b3e784f69d88d4e8ad4f6 Mon Sep 17 00:00:00 2001
From: Stephen Shao
Date: Wed, 25 Jun 2025 11:26:48 -0400
Subject: [PATCH 02/17] Fix the flow and correct the preparation at the remote node

---
 runners/ssh/run.py | 149 ++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 146 insertions(+), 3 deletions(-)

diff --git a/runners/ssh/run.py b/runners/ssh/run.py
index 73cb4421..7e7d43c4 100644
--- a/runners/ssh/run.py
+++ b/runners/ssh/run.py
@@ -193,11 +193,14 @@ def _execute_on_node(self, hostname: str, node_rank: int) -> Tuple[str, bool, str]:
 
             # Build and execute the madengine command
             command = self._build_madengine_command(node_rank)
-            print(f"🚀 Executing on {hostname} (rank {node_rank}): {command}")
+
+            # Change to the DeepLearningModels directory and execute the command
+            full_command = f"cd DeepLearningModels && {command}"
+            print(f"🚀 Executing on {hostname} (rank {node_rank}): {full_command}")
 
             # Execute the command
             stdin, stdout, stderr = ssh_client.exec_command(
-                command,
+                full_command,
                 timeout=self.timeout
             )
 
@@ -227,7 +230,8 @@ def _execute_on_node(self, hostname: str, node_rank: int) -> Tuple[str, bool, str]:
             if exit_code == 0:
                 return hostname, True, '\n'.join(output_lines)
             else:
-                return hostname, False, '\n'.join(error_lines)
+                error_output = '\n'.join(error_lines) if error_lines else "Command failed with no error output"
+                return hostname, False, f"Exit code {exit_code}: {error_output}"
 
         except Exception as e:
             return hostname, False, str(e)
@@ -296,6 +300,141 @@ def _wait_for_master_ready(self, master_host: str) -> bool:
         print(f"✗ Master node did not become ready within {max_wait_time} seconds")
         return False
 
+    def _validate_remote_node_prerequisites(self, hostname: str) -> Tuple[bool, str]:
+        """Validate that the remote node has the required prerequisites.
+
+        Args:
+            hostname: The hostname/IP of the node to validate
+
+        Returns:
+            Tuple of (success, error_message)
+        """
+        try:
+            ssh_client = self._connect_ssh(hostname)
+
+            # Check if the DeepLearningModels folder exists
+            print(f"🔍 Checking DeepLearningModels folder on {hostname}...")
+            stdin, stdout, stderr = ssh_client.exec_command('test -d DeepLearningModels && echo "exists" || echo "missing"')
+            dl_models_result = stdout.read().decode().strip()
+
+            if dl_models_result != "exists":
+                ssh_client.close()
+                return False, f"DeepLearningModels folder not found on {hostname}. Please set up the remote node with the DeepLearningModels directory."
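+            # Note: these remote checks run relative to the SSH user's home
+            # directory (each exec_command starts a fresh shell there), which
+            # matches the `cd DeepLearningModels && ...` prefix used at
+            # execution time.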
+ + print(f"✓ DeepLearningModels folder found on {hostname}") + + # Check if madengine is installed and accessible + print(f"🔍 Checking madengine installation on {hostname}...") + stdin, stdout, stderr = ssh_client.exec_command(f'which {self.madengine_path} > /dev/null 2>&1 && echo "found" || echo "missing"') + madengine_result = stdout.read().decode().strip() + + if madengine_result != "found": + # Try alternative check - see if madengine can be executed + stdin, stdout, stderr = ssh_client.exec_command(f'{self.madengine_path} --help > /dev/null 2>&1 && echo "found" || echo "missing"') + madengine_alt_result = stdout.read().decode().strip() + + if madengine_alt_result != "found": + ssh_client.close() + return False, f"madengine not found or not accessible on {hostname}. Please install madengine on the remote node or ensure it's in the PATH." + + print(f"✓ madengine installation found on {hostname}") + + # Check if we can access the DeepLearningModels directory + print(f"🔍 Checking access to DeepLearningModels directory on {hostname}...") + stdin, stdout, stderr = ssh_client.exec_command('cd DeepLearningModels && pwd') + cd_result = stdout.read().decode().strip() + cd_error = stderr.read().decode().strip() + + if cd_error or not cd_result.endswith('DeepLearningModels'): + ssh_client.close() + return False, f"Cannot access DeepLearningModels directory on {hostname}. Error: {cd_error or 'Unknown error'}" + + print(f"✓ DeepLearningModels directory is accessible on {hostname}") + + # Check if shared data path exists (if specified and not the default) + if self.shared_data_path and self.shared_data_path != '/nfs/data': + print(f"🔍 Checking shared data path on {hostname}...") + stdin, stdout, stderr = ssh_client.exec_command(f'test -d "{self.shared_data_path}" && echo "exists" || echo "missing"') + shared_data_result = stdout.read().decode().strip() + + if shared_data_result != "exists": + ssh_client.close() + return False, f"Shared data path '{self.shared_data_path}' not found on {hostname}. Please ensure the shared filesystem is mounted." + + print(f"✓ Shared data path '{self.shared_data_path}' found on {hostname}") + + print(f"✓ All checks passed for {hostname}") + + ssh_client.close() + return True, "" + + except Exception as e: + return False, f"Error validating prerequisites on {hostname}: {str(e)}" + + def _check_all_prerequisites(self) -> bool: + """Check prerequisites on all nodes. + + Returns: + True if all nodes meet prerequisites, False otherwise + """ + print("🔧 Validating prerequisites on all nodes...") + + failed_nodes = [] + + for hostname in self.nodes: + success, error_msg = self._validate_remote_node_prerequisites(hostname) + if not success: + print(f"❌ {hostname}: {error_msg}") + failed_nodes.append((hostname, error_msg)) + else: + print(f"✅ {hostname}: All prerequisites met") + + if failed_nodes: + print(f"\n💥 Prerequisites check failed for {len(failed_nodes)} node(s):") + for hostname, error_msg in failed_nodes: + print(f" • {hostname}: {error_msg}") + + self._print_setup_instructions() + return False + + print("✅ All nodes meet the prerequisites") + return True + + def _print_setup_instructions(self) -> None: + """Print setup instructions for remote nodes.""" + print("\n" + "="*60) + print("🔧 REMOTE NODE SETUP INSTRUCTIONS") + print("="*60) + print("\nTo prepare your remote nodes for multi-node training:") + print("\n1. 
DeepLearningModels Directory:")
+        print("   • Create or ensure the DeepLearningModels directory exists in the user's home directory")
+        print("   • Command: mkdir -p ~/DeepLearningModels")
+        print("   • This directory should contain your model configurations and training scripts")
+
+        print("\n2. MAD Engine Installation:")
+        print("   • Install madengine on each remote node")
+        print("   • Command: pip install madengine")
+        print("   • Or ensure madengine is in the PATH and executable")
+        print("   • Verify with: madengine --help")
+
+        if self.shared_data_path and self.shared_data_path != '/nfs/data':
+            print(f"\n3. Shared Data Path:")
+            print(f"   • Ensure the shared data path '{self.shared_data_path}' exists and is accessible")
+            print(f"   • This should be a shared filesystem (NFS, GPFS, etc.) mounted on all nodes")
+            print(f"   • All nodes should have read/write access to this path")
+
+        print("\n4. SSH Access:")
+        print("   • Ensure SSH key-based or password authentication is configured")
+        print("   • Test SSH access manually before running this script")
+
+        print("\n5. Network Configuration:")
+        print(f"   • Ensure nodes can communicate on the specified interfaces:")
+        print(f"   • NCCL interface: {self.nccl_interface}")
+        print(f"   • GLOO interface: {self.gloo_interface}")
+        print(f"   • Master node {self.master_addr} should be accessible on port {self.master_port}")
+
+        print("\n" + "="*60)
+
     def run(self) -> bool:
         """Run distributed training across all nodes.
 
@@ -318,6 +457,10 @@ def run(self) -> bool:
 
         print("✅ All nodes are reachable")
 
+        # Validate prerequisites on all nodes
+        if not self._check_all_prerequisites():
+            return False
+
         # Execute on all nodes concurrently
         results = []
 

From cdb1308e6ddc07cca22990f1ec7dd1479b9212e8 Mon Sep 17 00:00:00 2001
From: Stephen Shao
Date: Wed, 25 Jun 2025 11:46:28 -0400
Subject: [PATCH 03/17] Fixed all the issues in tests, now all tests pass

---
 runners/ssh/__init__.py      |   0
 runners/ssh/requirements.txt |   2 +
 runners/ssh/run.py           |   6 +--
 runners/ssh/test_runner.py   | 100 ++++++++++++++++++++++++++++++-----
 4 files changed, 91 insertions(+), 17 deletions(-)
 create mode 100644 runners/ssh/__init__.py
 create mode 100644 runners/ssh/requirements.txt

diff --git a/runners/ssh/__init__.py b/runners/ssh/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/runners/ssh/requirements.txt b/runners/ssh/requirements.txt
new file mode 100644
index 00000000..b199d422
--- /dev/null
+++ b/runners/ssh/requirements.txt
@@ -0,0 +1,2 @@
+# SSH Multi-Node Runner Requirements
+paramiko>=2.9.0
diff --git a/runners/ssh/run.py b/runners/ssh/run.py
index 73cb4421..8d553606 100644
--- a/runners/ssh/run.py
+++ b/runners/ssh/run.py
@@ -39,8 +39,6 @@
 # Add madengine to path if needed
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
 
-from madengine.utils.ssh_to_db import print_ssh_out
-
 
 class SSHMultiNodeRunner:
     """SSH-based multi-node runner for distributed training."""
@@ -52,7 +50,7 @@ def __init__(self, args: argparse.Namespace):
             args: Command line arguments containing configuration
         """
         self.args = args
-        self.nodes = [node.strip() for node in args.nodes.split(',')]
+        self.nodes = [node.strip() for node in args.nodes.split(',') if node.strip()]
         self.master_addr = args.master_addr or self.nodes[0]
         self.master_port = str(args.master_port)
         self.ssh_user = args.ssh_user
@@ -71,7 +69,7 @@ def __init__(self, args: argparse.Namespace):
 
     def _validate_config(self) -> None:
         """Validate the configuration parameters."""
-        if not self.nodes:
+        if not 
self.nodes or not any(node.strip() for node in self.nodes): raise ValueError("At least one node must be specified") if not self.ssh_user: diff --git a/runners/ssh/test_runner.py b/runners/ssh/test_runner.py index 9822f9dc..320f2ea2 100644 --- a/runners/ssh/test_runner.py +++ b/runners/ssh/test_runner.py @@ -69,26 +69,29 @@ def test_command_generation(self): def test_validation_errors(self): """Test configuration validation errors.""" # Test missing nodes - self.mock_args.nodes = '' + args_copy = MagicMock() + for attr in ['model', 'nodes', 'master_addr', 'master_port', 'ssh_user', 'ssh_password', 'ssh_key']: + setattr(args_copy, attr, getattr(self.mock_args, attr)) + args_copy.nodes = '' with self.assertRaises(ValueError): - SSHMultiNodeRunner(self.mock_args) - - # Reset nodes - self.mock_args.nodes = '10.0.0.1,10.0.0.2' + SSHMultiNodeRunner(args_copy) # Test missing SSH user - self.mock_args.ssh_user = '' + args_copy = MagicMock() + for attr in ['model', 'nodes', 'master_addr', 'master_port', 'ssh_user', 'ssh_password', 'ssh_key']: + setattr(args_copy, attr, getattr(self.mock_args, attr)) + args_copy.ssh_user = '' with self.assertRaises(ValueError): - SSHMultiNodeRunner(self.mock_args) - - # Reset SSH user - self.mock_args.ssh_user = 'testuser' + SSHMultiNodeRunner(args_copy) # Test missing authentication - self.mock_args.ssh_password = None - self.mock_args.ssh_key = None + args_copy = MagicMock() + for attr in ['model', 'nodes', 'master_addr', 'master_port', 'ssh_user', 'ssh_password', 'ssh_key']: + setattr(args_copy, attr, getattr(self.mock_args, attr)) + delattr(args_copy, 'ssh_password') + delattr(args_copy, 'ssh_key') with self.assertRaises(ValueError): - SSHMultiNodeRunner(self.mock_args) + SSHMultiNodeRunner(args_copy) def test_config_file_loading(self): """Test configuration file loading.""" @@ -193,6 +196,77 @@ def test_command_execution(self, mock_ssh_client_class): self.assertEqual(hostname, '10.0.0.1') self.assertTrue(success) self.assertIn('Training started...', output) + + @patch('run.paramiko.SSHClient') + def test_prerequisites_validation_success(self, mock_ssh_client_class): + """Test successful prerequisites validation.""" + # Mock SSH client + mock_ssh_client = MagicMock() + mock_ssh_client_class.return_value = mock_ssh_client + + # Mock successful prerequisites checks + def mock_exec_command(command): + mock_stdout = MagicMock() + mock_stderr = MagicMock() + if 'test -d DeepLearningModels' in command: + mock_stdout.read.return_value = b'exists' + elif 'which madengine' in command: + mock_stdout.read.return_value = b'found' + elif 'cd DeepLearningModels' in command: + mock_stdout.read.return_value = b'/home/user/DeepLearningModels' + mock_stderr.read.return_value = b'' + return (None, mock_stdout, mock_stderr) + + mock_ssh_client.exec_command.side_effect = mock_exec_command + + runner = SSHMultiNodeRunner(self.mock_args) + success, error_msg = runner._validate_remote_node_prerequisites('10.0.0.1') + + self.assertTrue(success) + self.assertEqual(error_msg, "") + + @patch('run.paramiko.SSHClient') + def test_prerequisites_validation_missing_deeplearning_models(self, mock_ssh_client_class): + """Test prerequisites validation with missing DeepLearningModels folder.""" + # Mock SSH client + mock_ssh_client = MagicMock() + mock_ssh_client_class.return_value = mock_ssh_client + + # Mock missing DeepLearningModels folder + mock_stdout = MagicMock() + mock_stdout.read.return_value = b'missing' + mock_ssh_client.exec_command.return_value = (None, mock_stdout, None) + + runner = 
SSHMultiNodeRunner(self.mock_args) + success, error_msg = runner._validate_remote_node_prerequisites('10.0.0.1') + + self.assertFalse(success) + self.assertIn('DeepLearningModels folder not found', error_msg) + + @patch('run.paramiko.SSHClient') + def test_prerequisites_validation_missing_madengine(self, mock_ssh_client_class): + """Test prerequisites validation with missing madengine.""" + # Mock SSH client + mock_ssh_client = MagicMock() + mock_ssh_client_class.return_value = mock_ssh_client + + # Mock DeepLearningModels exists but madengine missing + def mock_exec_command(command): + mock_stdout = MagicMock() + mock_stderr = MagicMock() + if 'test -d DeepLearningModels' in command: + mock_stdout.read.return_value = b'exists' + elif 'which madengine' in command or 'madengine --help' in command: + mock_stdout.read.return_value = b'missing' + return (None, mock_stdout, mock_stderr) + + mock_ssh_client.exec_command.side_effect = mock_exec_command + + runner = SSHMultiNodeRunner(self.mock_args) + success, error_msg = runner._validate_remote_node_prerequisites('10.0.0.1') + + self.assertFalse(success) + self.assertIn('madengine not found', error_msg) def run_tests(): From 4b374e38064e8b149f10d0baf288e9ca94909e53 Mon Sep 17 00:00:00 2001 From: Stephen Shao Date: Wed, 25 Jun 2025 12:05:12 -0400 Subject: [PATCH 04/17] Fix the error of format coding --- runners/ssh/test_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/runners/ssh/test_runner.py b/runners/ssh/test_runner.py index 320f2ea2..ad85e679 100644 --- a/runners/ssh/test_runner.py +++ b/runners/ssh/test_runner.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# -*- coding: utf-8 -*- """Test script for SSH Multi-Node Runner This script provides basic unit tests and validation for the SSH runner. @@ -319,7 +320,7 @@ def validate_environment(): if issues: print("\n🚨 Issues found:") for issue in issues: - print(f" {issue}") + print(" {}".format(issue)) return False else: print("\n✅ Environment validation passed!") From ae04ac9a89877a7964506e535ac695eff46c3fb5 Mon Sep 17 00:00:00 2001 From: Stephen Shao Date: Wed, 25 Jun 2025 14:42:34 -0400 Subject: [PATCH 05/17] Refined the structure of ssh runner for multinodes and cleanup code --- runners/ssh/__init__.py | 46 +++ runners/ssh/config.ini.example | 29 +- runners/ssh/config_manager.py | 294 +++++++++++++ runners/ssh/monitoring.py | 340 +++++++++++++++ runners/ssh/requirements.txt | 4 +- runners/ssh/run.py | 660 +++++++++++------------------- runners/ssh/ssh_client_manager.py | 161 ++++++++ runners/ssh/test_runner.py | 483 +++++++++++++--------- runners/ssh/utils.py | 219 ++++++++++ 9 files changed, 1610 insertions(+), 626 deletions(-) create mode 100644 runners/ssh/config_manager.py create mode 100644 runners/ssh/monitoring.py create mode 100644 runners/ssh/ssh_client_manager.py create mode 100644 runners/ssh/utils.py diff --git a/runners/ssh/__init__.py b/runners/ssh/__init__.py index e69de29b..4de2b246 100644 --- a/runners/ssh/__init__.py +++ b/runners/ssh/__init__.py @@ -0,0 +1,46 @@ +"""SSH Multi-Node Runner for MAD Engine + +This package provides SSH-based multi-node distributed training capabilities +for the MAD Engine framework. 
+ +Main Components: +- SSHMultiNodeRunner: Main orchestration class +- SSHClientManager: Robust SSH connection management +- MultiNodeConfig: Configuration management +- Configuration validation and setup instructions +- Utilities: Common helper functions + +Example Usage: + from runners.ssh import SSHMultiNodeRunner, MultiNodeConfig + + config = MultiNodeConfig.from_config_file('config.ini') + runner = SSHMultiNodeRunner(config) + success = runner.run() + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" + +from .config_manager import ( + MultiNodeConfig, + SSHConfig, + ClusterConfig, + TrainingConfig, + MadEngineConfig +) +from .ssh_client_manager import SSHClientManager +from .run import SSHMultiNodeRunner +from . import utils + +__version__ = "1.0.0" +__author__ = "Advanced Micro Devices, Inc." + +__all__ = [ + 'SSHMultiNodeRunner', + 'SSHClientManager', + 'MultiNodeConfig', + 'SSHConfig', + 'ClusterConfig', + 'TrainingConfig', + 'MadEngineConfig', + 'utils' +] \ No newline at end of file diff --git a/runners/ssh/config.ini.example b/runners/ssh/config.ini.example index 7d589566..2e0d7bad 100644 --- a/runners/ssh/config.ini.example +++ b/runners/ssh/config.ini.example @@ -1,20 +1,28 @@ # Configuration for SSH multi-node runner +# Copy this file to config.ini and customize for your environment [cluster] # Comma-separated list of node hostnames or IPs -nodes = 192.168.1.1,192.168.1.2 +nodes = 192.168.1.1,192.168.1.2,192.168.1.3 -# Master node configuration -master_addr = 192.168.0.1 +# Master node configuration (defaults to first node if not specified) +master_addr = 192.168.1.1 master_port = 4000 [ssh] -# SSH authentication -user = username # Replace with your SSH username -# Use either key_file OR password (key_file is recommended) +# SSH authentication - use either key_file OR password (key_file is recommended) +user = username + +# SSH key-based authentication (recommended) key_file = ~/.ssh/id_rsa + +# Password-based authentication (less secure, comment out if using key_file) # password = your_password_here +# SSH connection settings +timeout = 30 +max_retries = 3 + [training] # Model to train model = pyt_megatron_lm_train_llama2_7b @@ -29,9 +37,12 @@ gloo_interface = ens14np0 # Execution timeout in seconds (2 hours) timeout = 7200 +# Additional arguments to pass to madengine (optional) +# additional_args = --live-output --some-other-flag + [madengine] # Path to madengine executable (if not in PATH) -# madengine_path = /opt/madengine/bin/madengine +path = madengine -# Additional arguments to pass to madengine -# additional_args = --live-output +# Working directory on remote nodes +working_directory = DeepLearningModels diff --git a/runners/ssh/config_manager.py b/runners/ssh/config_manager.py new file mode 100644 index 00000000..98323235 --- /dev/null +++ b/runners/ssh/config_manager.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +"""Configuration Management for SSH Multi-Node Runner + +This module provides configuration validation and management for the SSH runner. + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
+""" + +import configparser +import os +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Any +from pathlib import Path + + +@dataclass +class SSHConfig: + """SSH configuration settings.""" + user: str + password: Optional[str] = None + key_file: Optional[str] = None + timeout: int = 30 + max_retries: int = 3 + + def __post_init__(self): + """Validate SSH configuration after initialization.""" + if not self.password and not self.key_file: + raise ValueError("Either SSH password or key file must be specified") + + if self.key_file and not os.path.exists(self.key_file): + raise FileNotFoundError(f"SSH key file not found: {self.key_file}") + + +@dataclass +class ClusterConfig: + """Cluster configuration settings.""" + nodes: List[str] + master_addr: Optional[str] = None + master_port: int = 4000 + + def __post_init__(self): + """Validate cluster configuration after initialization.""" + if not self.nodes: + raise ValueError("At least one node must be specified") + + # Clean up node list + self.nodes = [node.strip() for node in self.nodes if node.strip()] + + if not self.nodes: + raise ValueError("At least one valid node must be specified") + + # Set master address to first node if not specified + if not self.master_addr: + self.master_addr = self.nodes[0] + + +@dataclass +class TrainingConfig: + """Training configuration settings.""" + model: str + shared_data_path: str = '/nfs/data' + nccl_interface: str = 'ens14np0' + gloo_interface: str = 'ens14np0' + timeout: int = 3600 + additional_args: str = '' + + def __post_init__(self): + """Validate training configuration after initialization.""" + if not self.model: + raise ValueError("Model must be specified") + + +@dataclass +class MadEngineConfig: + """MAD Engine specific configuration.""" + path: str = 'madengine' + working_directory: str = 'DeepLearningModels' + + +@dataclass +class MultiNodeConfig: + """Complete multi-node runner configuration.""" + ssh: SSHConfig + cluster: ClusterConfig + training: TrainingConfig + madengine: MadEngineConfig = field(default_factory=MadEngineConfig) + + @classmethod + def from_args(cls, args) -> 'MultiNodeConfig': + """Create configuration from command line arguments. 
+ + Args: + args: Parsed command line arguments + + Returns: + Complete configuration object + """ + # Parse nodes list + nodes = [node.strip() for node in args.nodes.split(',') if node.strip()] + + ssh_config = SSHConfig( + user=args.ssh_user, + password=getattr(args, 'ssh_password', None), + key_file=getattr(args, 'ssh_key', None), + timeout=getattr(args, 'ssh_timeout', 30), + max_retries=getattr(args, 'ssh_max_retries', 3) + ) + + cluster_config = ClusterConfig( + nodes=nodes, + master_addr=getattr(args, 'master_addr', None), + master_port=getattr(args, 'master_port', 4000) + ) + + training_config = TrainingConfig( + model=args.model, + shared_data_path=getattr(args, 'shared_data_path', '/nfs/data'), + nccl_interface=getattr(args, 'nccl_interface', 'ens14np0'), + gloo_interface=getattr(args, 'gloo_interface', 'ens14np0'), + timeout=getattr(args, 'timeout', 3600), + additional_args=getattr(args, 'additional_args', '') + ) + + madengine_config = MadEngineConfig( + path=getattr(args, 'madengine_path', 'madengine'), + working_directory=getattr(args, 'working_directory', 'DeepLearningModels') + ) + + return cls( + ssh=ssh_config, + cluster=cluster_config, + training=training_config, + madengine=madengine_config + ) + + @classmethod + def from_config_file(cls, config_path: str) -> 'MultiNodeConfig': + """Create configuration from INI file. + + Args: + config_path: Path to configuration file + + Returns: + Complete configuration object + + Raises: + FileNotFoundError: If config file doesn't exist + configparser.Error: If config file is malformed + """ + if not os.path.exists(config_path): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + + config = configparser.ConfigParser() + config.read(config_path) + + # Parse SSH configuration + ssh_section = config['ssh'] if 'ssh' in config else {} + ssh_config = SSHConfig( + user=ssh_section.get('user'), + password=ssh_section.get('password'), + key_file=ssh_section.get('key_file'), + timeout=int(ssh_section.get('timeout', 30)), + max_retries=int(ssh_section.get('max_retries', 3)) + ) + + # Parse cluster configuration + cluster_section = config['cluster'] if 'cluster' in config else {} + nodes_str = cluster_section.get('nodes', '') + nodes = [node.strip() for node in nodes_str.split(',') if node.strip()] + + cluster_config = ClusterConfig( + nodes=nodes, + master_addr=cluster_section.get('master_addr'), + master_port=int(cluster_section.get('master_port', 4000)) + ) + + # Parse training configuration + training_section = config['training'] if 'training' in config else {} + training_config = TrainingConfig( + model=training_section.get('model'), + shared_data_path=training_section.get('shared_data_path', '/nfs/data'), + nccl_interface=training_section.get('nccl_interface', 'ens14np0'), + gloo_interface=training_section.get('gloo_interface', 'ens14np0'), + timeout=int(training_section.get('timeout', 3600)), + additional_args=training_section.get('additional_args', '') + ) + + # Parse madengine configuration + madengine_section = config['madengine'] if 'madengine' in config else {} + madengine_config = MadEngineConfig( + path=madengine_section.get('path', 'madengine'), + working_directory=madengine_section.get('working_directory', 'DeepLearningModels') + ) + + return cls( + ssh=ssh_config, + cluster=cluster_config, + training=training_config, + madengine=madengine_config + ) + + def validate(self) -> None: + """Validate the complete configuration. 
+ + Raises: + ValueError: If configuration is invalid + """ + # Configurations are validated in their __post_init__ methods + pass + + def to_dict(self) -> Dict[str, Any]: + """Convert configuration to dictionary. + + Returns: + Configuration as dictionary + """ + return { + 'ssh': { + 'user': self.ssh.user, + 'has_password': bool(self.ssh.password), + 'has_key_file': bool(self.ssh.key_file), + 'timeout': self.ssh.timeout, + 'max_retries': self.ssh.max_retries + }, + 'cluster': { + 'nodes': self.cluster.nodes, + 'master_addr': self.cluster.master_addr, + 'master_port': self.cluster.master_port + }, + 'training': { + 'model': self.training.model, + 'shared_data_path': self.training.shared_data_path, + 'nccl_interface': self.training.nccl_interface, + 'gloo_interface': self.training.gloo_interface, + 'timeout': self.training.timeout, + 'additional_args': self.training.additional_args + }, + 'madengine': { + 'path': self.madengine.path, + 'working_directory': self.madengine.working_directory + } + } + + +def merge_config_file_with_args(config_path: str, args) -> 'MultiNodeConfig': + """Merge configuration file with command line arguments. + + Command line arguments take precedence over config file values. + + Args: + config_path: Path to configuration file + args: Command line arguments + + Returns: + Merged configuration + """ + # Start with config file + config = MultiNodeConfig.from_config_file(config_path) + + # Override with command line arguments if provided + if hasattr(args, 'nodes') and args.nodes: + nodes = [node.strip() for node in args.nodes.split(',') if node.strip()] + config.cluster.nodes = nodes + + if hasattr(args, 'master_addr') and args.master_addr: + config.cluster.master_addr = args.master_addr + + if hasattr(args, 'master_port') and args.master_port: + config.cluster.master_port = args.master_port + + if hasattr(args, 'ssh_user') and args.ssh_user: + config.ssh.user = args.ssh_user + + if hasattr(args, 'ssh_password') and args.ssh_password: + config.ssh.password = args.ssh_password + + if hasattr(args, 'ssh_key') and args.ssh_key: + config.ssh.key_file = args.ssh_key + + if hasattr(args, 'model') and args.model: + config.training.model = args.model + + if hasattr(args, 'shared_data_path') and args.shared_data_path: + config.training.shared_data_path = args.shared_data_path + + if hasattr(args, 'timeout') and args.timeout: + config.training.timeout = args.timeout + + if hasattr(args, 'madengine_path') and args.madengine_path: + config.madengine.path = args.madengine_path + + # Re-validate after merging + config.validate() + return config diff --git a/runners/ssh/monitoring.py b/runners/ssh/monitoring.py new file mode 100644 index 00000000..d6323798 --- /dev/null +++ b/runners/ssh/monitoring.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python3 +"""Monitoring and logging utilities for SSH Multi-Node Runner + +This module provides enhanced monitoring and logging capabilities. + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
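+
+Typical flow (illustrative sketch; these classes are defined below):
+
+    logger = SessionLogger(log_dir="logs")
+    logger.start_session(model="my_model", nodes=["n1", "n2"],
+                         master_addr="n1", master_port=4000)
+    ...  # record a NodeExecutionResult per node as it finishes
+    logger.end_session(success=True)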
+""" + +import json +import logging +import time +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Any, Optional +from dataclasses import dataclass, asdict + + +@dataclass +class NodeExecutionResult: + """Result of execution on a single node.""" + hostname: str + node_rank: int + success: bool + start_time: float + end_time: float + output: str + error_message: Optional[str] = None + + @property + def duration(self) -> float: + """Get execution duration in seconds.""" + return self.end_time - self.start_time + + @property + def duration_formatted(self) -> str: + """Get formatted duration string.""" + from .utils import format_duration + return format_duration(self.duration) + + +@dataclass +class TrainingSession: + """Complete training session information.""" + session_id: str + model: str + nodes: List[str] + master_addr: str + master_port: int + start_time: float + end_time: Optional[float] = None + results: List[NodeExecutionResult] = None + + def __post_init__(self): + if self.results is None: + self.results = [] + + @property + def duration(self) -> Optional[float]: + """Get total session duration in seconds.""" + if self.end_time is None: + return None + return self.end_time - self.start_time + + @property + def success_rate(self) -> float: + """Get success rate as percentage.""" + if not self.results: + return 0.0 + successful = sum(1 for r in self.results if r.success) + return (successful / len(self.results)) * 100.0 + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for serialization.""" + data = asdict(self) + data['duration'] = self.duration + data['success_rate'] = self.success_rate + return data + + +class SessionLogger: + """Logger for training sessions with structured output.""" + + def __init__(self, log_dir: str = "logs", session_id: Optional[str] = None): + """Initialize session logger. + + Args: + log_dir: Directory to store log files + session_id: Unique session identifier (auto-generated if None) + """ + self.log_dir = Path(log_dir) + self.log_dir.mkdir(exist_ok=True) + + if session_id is None: + session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + self.session_id = session_id + self.session_file = self.log_dir / f"{session_id}.json" + self.log_file = self.log_dir / f"{session_id}.log" + + # Setup file logger + self.logger = logging.getLogger(f"session.{session_id}") + self.logger.setLevel(logging.DEBUG) + + # Create file handler + file_handler = logging.FileHandler(str(self.log_file)) + file_handler.setLevel(logging.DEBUG) + + # Create formatter + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + file_handler.setFormatter(formatter) + + self.logger.addHandler(file_handler) + + self.session: Optional[TrainingSession] = None + + def start_session(self, model: str, nodes: List[str], master_addr: str, master_port: int) -> None: + """Start a new training session. 
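
        The session is saved to <log_dir>/<session_id>.json immediately (and
        after every update), so a crash mid-run still leaves a record on disk.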
+ + Args: + model: Model being trained + nodes: List of node hostnames + master_addr: Master node address + master_port: Master node port + """ + self.session = TrainingSession( + session_id=self.session_id, + model=model, + nodes=nodes.copy(), + master_addr=master_addr, + master_port=master_port, + start_time=time.time() + ) + + self.logger.info(f"Starting training session {self.session_id}") + self.logger.info(f"Model: {model}") + self.logger.info(f"Nodes: {', '.join(nodes)}") + self.logger.info(f"Master: {master_addr}:{master_port}") + + self._save_session() + + def log_node_start(self, hostname: str, node_rank: int, command: str) -> None: + """Log the start of execution on a node. + + Args: + hostname: Node hostname + node_rank: Node rank + command: Command being executed + """ + self.logger.info(f"Node {hostname} (rank {node_rank}) starting: {command}") + + def log_node_output(self, hostname: str, node_rank: int, line: str, is_error: bool = False) -> None: + """Log output from a node. + + Args: + hostname: Node hostname + node_rank: Node rank + line: Output line + is_error: Whether this is error output + """ + level = logging.ERROR if is_error else logging.INFO + prefix = "ERROR" if is_error else "OUTPUT" + self.logger.log(level, f"[{hostname}:{node_rank}] {prefix}: {line}") + + def log_node_result(self, result: NodeExecutionResult) -> None: + """Log the result of execution on a node. + + Args: + result: Node execution result + """ + if self.session is None: + raise RuntimeError("No active session") + + self.session.results.append(result) + + status = "SUCCESS" if result.success else "FAILED" + self.logger.info( + f"Node {result.hostname} (rank {result.node_rank}) {status} " + f"in {result.duration_formatted}" + ) + + if not result.success and result.error_message: + self.logger.error(f"Node {result.hostname} error: {result.error_message}") + + self._save_session() + + def end_session(self, success: bool) -> None: + """End the training session. + + Args: + success: Whether the overall session was successful + """ + if self.session is None: + raise RuntimeError("No active session") + + self.session.end_time = time.time() + + status = "SUCCESS" if success else "FAILED" + duration = self.session.duration + from .utils import format_duration + duration_str = format_duration(duration) if duration else "unknown" + + self.logger.info(f"Training session {status} in {duration_str}") + self.logger.info(f"Success rate: {self.session.success_rate:.1f}%") + + self._save_session() + + def _save_session(self) -> None: + """Save session data to JSON file.""" + if self.session is None: + return + + try: + with open(self.session_file, 'w') as f: + json.dump(self.session.to_dict(), f, indent=2, default=str) + except Exception as e: + self.logger.error(f"Failed to save session data: {e}") + + def get_session_summary(self) -> Dict[str, Any]: + """Get session summary. 
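
        Example (illustrative):

            summary = logger.get_session_summary()
            print(f"{summary['successful_nodes']}/{summary['total_nodes']} nodes succeeded")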
+ + Returns: + Dictionary containing session summary + """ + if self.session is None: + return {} + + return { + 'session_id': self.session.session_id, + 'model': self.session.model, + 'total_nodes': len(self.session.nodes), + 'completed_nodes': len(self.session.results), + 'successful_nodes': sum(1 for r in self.session.results if r.success), + 'failed_nodes': sum(1 for r in self.session.results if not r.success), + 'success_rate': self.session.success_rate, + 'duration': self.session.duration, + 'status': 'completed' if self.session.end_time else 'running' + } + + +class ProgressMonitor: + """Monitor training progress across nodes.""" + + def __init__(self, total_nodes: int): + """Initialize progress monitor. + + Args: + total_nodes: Total number of nodes + """ + self.total_nodes = total_nodes + self.completed_nodes = 0 + self.successful_nodes = 0 + self.failed_nodes = 0 + self.start_time = time.time() + + def update(self, success: bool) -> None: + """Update progress with a completed node. + + Args: + success: Whether the node completed successfully + """ + self.completed_nodes += 1 + if success: + self.successful_nodes += 1 + else: + self.failed_nodes += 1 + + def get_progress(self) -> Dict[str, Any]: + """Get current progress information. + + Returns: + Dictionary containing progress information + """ + elapsed_time = time.time() - self.start_time + completion_rate = self.completed_nodes / self.total_nodes if self.total_nodes > 0 else 0 + + # Estimate remaining time + if completion_rate > 0: + estimated_total_time = elapsed_time / completion_rate + estimated_remaining_time = estimated_total_time - elapsed_time + else: + estimated_remaining_time = None + + return { + 'total_nodes': self.total_nodes, + 'completed_nodes': self.completed_nodes, + 'successful_nodes': self.successful_nodes, + 'failed_nodes': self.failed_nodes, + 'completion_rate': completion_rate, + 'elapsed_time': elapsed_time, + 'estimated_remaining_time': estimated_remaining_time, + 'success_rate': self.successful_nodes / self.completed_nodes if self.completed_nodes > 0 else 0 + } + + def print_progress(self) -> None: + """Print current progress to console.""" + progress = self.get_progress() + + from .utils import format_duration + elapsed = format_duration(progress['elapsed_time']) + + if progress['estimated_remaining_time']: + remaining = format_duration(progress['estimated_remaining_time']) + time_info = f"Elapsed: {elapsed}, Remaining: ~{remaining}" + else: + time_info = f"Elapsed: {elapsed}" + + print(f"Progress: {progress['completed_nodes']}/{progress['total_nodes']} " + f"({progress['completion_rate']*100:.1f}%) - " + f"Success: {progress['successful_nodes']}, Failed: {progress['failed_nodes']} - " + f"{time_info}") + + +def load_session_history(log_dir: str = "logs") -> List[Dict[str, Any]]: + """Load session history from log directory. 
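
    Example (illustrative):

        for session in load_session_history("logs"):
            print(session["session_id"], session.get("success_rate"))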
+ + Args: + log_dir: Directory containing log files + + Returns: + List of session summaries + """ + log_path = Path(log_dir) + if not log_path.exists(): + return [] + + sessions = [] + for json_file in log_path.glob("session_*.json"): + try: + with open(json_file, 'r') as f: + session_data = json.load(f) + sessions.append(session_data) + except Exception: + # Skip corrupted files + continue + + # Sort by start time (most recent first) + sessions.sort(key=lambda x: x.get('start_time', 0), reverse=True) + return sessions diff --git a/runners/ssh/requirements.txt b/runners/ssh/requirements.txt index b199d422..0e7b042e 100644 --- a/runners/ssh/requirements.txt +++ b/runners/ssh/requirements.txt @@ -1,2 +1,4 @@ # SSH Multi-Node Runner Requirements -paramiko>=2.9.0 + +# Core SSH functionality +paramiko>=2.9.0,<4.0.0 diff --git a/runners/ssh/run.py b/runners/ssh/run.py index 8d553606..5b9271d9 100644 --- a/runners/ssh/run.py +++ b/runners/ssh/run.py @@ -19,122 +19,58 @@ import argparse import json +import logging import os import sys import time -import threading -import socket -import configparser -from typing import List, Dict, Optional, Tuple from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Tuple, Optional -# Third-party imports +# Local imports try: - import paramiko + from config_manager import MultiNodeConfig, merge_config_file_with_args + from ssh_client_manager import SSHClientManager except ImportError: - print("Error: paramiko is required but not installed.") - print("Please install it with: pip install paramiko") - sys.exit(1) + # Fallback for direct execution + sys.path.insert(0, os.path.dirname(__file__)) + from config_manager import MultiNodeConfig, merge_config_file_with_args + from ssh_client_manager import SSHClientManager -# Add madengine to path if needed -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src')) +# Setup logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) + + +class ValidationError(Exception): + """Exception raised when validation fails.""" + pass class SSHMultiNodeRunner: """SSH-based multi-node runner for distributed training.""" - def __init__(self, args: argparse.Namespace): + def __init__(self, config: MultiNodeConfig): """Initialize the SSH multi-node runner. 
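+
+        Example (illustrative; mirrors main() at the bottom of this file):
+
+            config = MultiNodeConfig.from_args(parse_args())
+            runner = SSHMultiNodeRunner(config)
+            success = runner.run()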
Args: - args: Command line arguments containing configuration - """ - self.args = args - self.nodes = [node.strip() for node in args.nodes.split(',') if node.strip()] - self.master_addr = args.master_addr or self.nodes[0] - self.master_port = str(args.master_port) - self.ssh_user = args.ssh_user - self.ssh_password = getattr(args, 'ssh_password', None) - self.ssh_key = getattr(args, 'ssh_key', None) - self.model_tag = args.model - self.shared_data_path = getattr(args, 'shared_data_path', '/nfs/data') - self.nccl_interface = getattr(args, 'nccl_interface', 'ens14np0') - self.gloo_interface = getattr(args, 'gloo_interface', 'ens14np0') - self.timeout = getattr(args, 'timeout', 3600) # 1 hour default - self.madengine_path = getattr(args, 'madengine_path', 'madengine') - self.additional_args = getattr(args, 'additional_args', '') - - # Validate configuration - self._validate_config() - - def _validate_config(self) -> None: - """Validate the configuration parameters.""" - if not self.nodes or not any(node.strip() for node in self.nodes): - raise ValueError("At least one node must be specified") - - if not self.ssh_user: - raise ValueError("SSH username must be specified") - - if not self.ssh_password and not self.ssh_key: - raise ValueError("Either SSH password or SSH key must be specified") - - if not self.model_tag: - raise ValueError("Model tag must be specified") - - # Validate SSH key file exists if specified - if self.ssh_key and not os.path.exists(self.ssh_key): - raise FileNotFoundError(f"SSH key file not found: {self.ssh_key}") - - def _create_ssh_client(self) -> paramiko.SSHClient: - """Create and configure SSH client. - - Returns: - Configured SSH client - """ - ssh_client = paramiko.SSHClient() - ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh_client.load_system_host_keys() - return ssh_client - - def _connect_ssh(self, hostname: str) -> paramiko.SSHClient: - """Connect to a remote host via SSH. - - Args: - hostname: The hostname or IP to connect to - - Returns: - Connected SSH client - - Raises: - Exception: If connection fails + config: Complete configuration object """ - ssh_client = self._create_ssh_client() - - try: - if self.ssh_key: - ssh_client.connect( - hostname=hostname, - username=self.ssh_user, - key_filename=self.ssh_key, - timeout=30 - ) - else: - ssh_client.connect( - hostname=hostname, - username=self.ssh_user, - password=self.ssh_password, - timeout=30 - ) - - print(f"✓ Successfully connected to {hostname}") - return ssh_client - - except paramiko.ssh_exception.AuthenticationException as e: - raise Exception(f"Authentication failed for {hostname}: {e}") - except paramiko.ssh_exception.SSHException as e: - raise Exception(f"SSH error for {hostname}: {e}") - except socket.error as e: - raise Exception(f"Socket error for {hostname}: {e}") + self.config = config + self.logger = logging.getLogger(self.__class__.__name__) + + # Create SSH client managers for each node + self.ssh_managers = {} + for node in config.cluster.nodes: + self.ssh_managers[node] = SSHClientManager( + hostname=node, + username=config.ssh.user, + password=config.ssh.password, + key_filename=config.ssh.key_file, + timeout=config.ssh.timeout, + max_retries=config.ssh.max_retries + ) def _build_madengine_command(self, node_rank: int) -> str: """Build the madengine command for a specific node. 
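+
+        Sketch of the resulting command (illustrative; exact quoting comes
+        from the json.dumps call in the body below):
+
+            madengine run --tags <model> \
+                --additional-context '{"multi_node_args": {"RUNNER": "torchrun", ...}}' \
+                --force-mirror-local <shared_data_path>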
@@ -147,93 +83,35 @@ def _build_madengine_command(self, node_rank: int) -> str: """ multi_node_args = { 'RUNNER': 'torchrun', - 'MASTER_ADDR': self.master_addr, - 'MASTER_PORT': self.master_port, - 'NNODES': str(len(self.nodes)), + 'MASTER_ADDR': self.config.cluster.master_addr, + 'MASTER_PORT': str(self.config.cluster.master_port), + 'NNODES': str(len(self.config.cluster.nodes)), 'NODE_RANK': str(node_rank), - 'NCCL_SOCKET_IFNAME': self.nccl_interface, - 'GLOO_SOCKET_IFNAME': self.gloo_interface + 'NCCL_SOCKET_IFNAME': self.config.training.nccl_interface, + 'GLOO_SOCKET_IFNAME': self.config.training.gloo_interface } # Build the additional context string - additional_context = f"'{{'multi_node_args': {json.dumps(multi_node_args)}}}'" + additional_context = f"'{json.dumps({'multi_node_args': multi_node_args})}'" # Build the complete command cmd_parts = [ - self.madengine_path, + self.config.madengine.path, 'run', - '--tags', self.model_tag, + '--tags', self.config.training.model, '--additional-context', additional_context ] # Add shared data path if specified - if self.shared_data_path: - cmd_parts.extend(['--force-mirror-local', self.shared_data_path]) + if self.config.training.shared_data_path: + cmd_parts.extend(['--force-mirror-local', self.config.training.shared_data_path]) # Add any additional arguments - if self.additional_args: - cmd_parts.append(self.additional_args) + if self.config.training.additional_args: + cmd_parts.append(self.config.training.additional_args) return ' '.join(cmd_parts) - def _execute_on_node(self, hostname: str, node_rank: int) -> Tuple[str, bool, str]: - """Execute madengine command on a single node. - - Args: - hostname: The hostname/IP of the node - node_rank: The rank of this node - - Returns: - Tuple of (hostname, success, output/error) - """ - try: - ssh_client = self._connect_ssh(hostname) - - # Build and execute the madengine command - command = self._build_madengine_command(node_rank) - - # Change to DeepLearningModels directory and execute the command - full_command = f"cd DeepLearningModels && {command}" - print(f"🚀 Executing on {hostname} (rank {node_rank}): {full_command}") - - # Execute the command - stdin, stdout, stderr = ssh_client.exec_command( - full_command, - timeout=self.timeout - ) - - # Read output in real-time - output_lines = [] - error_lines = [] - - # Read stdout - for line in stdout: - line = line.strip() - if line: - print(f"[{hostname}:{node_rank}] {line}") - output_lines.append(line) - - # Read stderr - for line in stderr: - line = line.strip() - if line: - print(f"[{hostname}:{node_rank}] ERROR: {line}") - error_lines.append(line) - - # Get exit code - exit_code = stdout.channel.recv_exit_status() - - ssh_client.close() - - if exit_code == 0: - return hostname, True, '\n'.join(output_lines) - else: - error_output = '\n'.join(error_lines) if error_lines else "Command failed with no error output" - return hostname, False, f"Exit code {exit_code}: {error_output}" - - except Exception as e: - return hostname, False, str(e) - def _check_node_connectivity(self) -> List[str]: """Check connectivity to all nodes. 
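+
+        Example (illustrative; run() below aborts when any node is missing
+        from the returned list):
+
+            reachable = self._check_node_connectivity()
+            if len(reachable) != len(self.config.cluster.nodes):
+                return False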
@@ -242,63 +120,19 @@ def _check_node_connectivity(self) -> List[str]: """ reachable_nodes = [] - print("🔍 Checking connectivity to all nodes...") + self.logger.info("Checking connectivity to all nodes...") - for hostname in self.nodes: - try: - ssh_client = self._connect_ssh(hostname) - - # Test basic command execution - stdin, stdout, stderr = ssh_client.exec_command('echo "connectivity_test"') - output = stdout.read().decode().strip() - - if output == "connectivity_test": - reachable_nodes.append(hostname) - print(f"✓ {hostname} is reachable") - else: - print(f"✗ {hostname} failed connectivity test") - - ssh_client.close() - - except Exception as e: - print(f"✗ {hostname} is not reachable: {e}") + for hostname in self.config.cluster.nodes: + ssh_manager = self.ssh_managers[hostname] + if ssh_manager.test_connectivity(): + reachable_nodes.append(hostname) + self.logger.info(f"✓ {hostname} is reachable") + else: + self.logger.error(f"✗ {hostname} is not reachable") return reachable_nodes - def _wait_for_master_ready(self, master_host: str) -> bool: - """Wait for master node to be ready to accept connections. - - Args: - master_host: The master node hostname/IP - - Returns: - True if master is ready, False if timeout - """ - print(f"⏳ Waiting for master node {master_host}:{self.master_port} to be ready...") - - max_wait_time = 60 # seconds - start_time = time.time() - - while time.time() - start_time < max_wait_time: - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(5) - result = sock.connect_ex((master_host, int(self.master_port))) - sock.close() - - if result == 0: - print(f"✓ Master node is ready") - return True - - except Exception: - pass - - time.sleep(2) - - print(f"✗ Master node did not become ready within {max_wait_time} seconds") - return False - - def _validate_remote_node_prerequisites(self, hostname: str) -> Tuple[bool, str]: + def _validate_remote_prerequisites(self, hostname: str) -> Tuple[bool, str]: """Validate that remote node has required prerequisites. Args: @@ -307,131 +141,175 @@ def _validate_remote_node_prerequisites(self, hostname: str) -> Tuple[bool, str] Returns: Tuple of (success, error_message) """ + ssh_manager = self.ssh_managers[hostname] + try: - ssh_client = self._connect_ssh(hostname) - - # Check if DeepLearningModels folder exists - print(f"🔍 Checking DeepLearningModels folder on {hostname}...") - stdin, stdout, stderr = ssh_client.exec_command('test -d DeepLearningModels && echo "exists" || echo "missing"') - dl_models_result = stdout.read().decode().strip() - - if dl_models_result != "exists": - ssh_client.close() - return False, f"DeepLearningModels folder not found on {hostname}. Please set up the remote node with DeepLearningModels directory." 
+ # Check if working directory exists + self.logger.debug(f"Checking {self.config.madengine.working_directory} folder on {hostname}...") + exit_code, stdout, stderr = ssh_manager.execute_command( + f'test -d {self.config.madengine.working_directory} && echo "exists" || echo "missing"' + ) - print(f"✓ DeepLearningModels folder found on {hostname}") + if stdout != "exists": + return False, f"{self.config.madengine.working_directory} folder not found on {hostname}" - # Check if madengine is installed and accessible - print(f"🔍 Checking madengine installation on {hostname}...") - stdin, stdout, stderr = ssh_client.exec_command(f'which {self.madengine_path} > /dev/null 2>&1 && echo "found" || echo "missing"') - madengine_result = stdout.read().decode().strip() + # Check if madengine is accessible + self.logger.debug(f"Checking madengine installation on {hostname}...") + exit_code, stdout, stderr = ssh_manager.execute_command( + f'which {self.config.madengine.path} > /dev/null 2>&1 && echo "found" || echo "missing"' + ) - if madengine_result != "found": - # Try alternative check - see if madengine can be executed - stdin, stdout, stderr = ssh_client.exec_command(f'{self.madengine_path} --help > /dev/null 2>&1 && echo "found" || echo "missing"') - madengine_alt_result = stdout.read().decode().strip() + if stdout != "found": + # Try alternative check + exit_code, stdout, stderr = ssh_manager.execute_command( + f'{self.config.madengine.path} --help > /dev/null 2>&1 && echo "found" || echo "missing"' + ) - if madengine_alt_result != "found": - ssh_client.close() - return False, f"madengine not found or not accessible on {hostname}. Please install madengine on the remote node or ensure it's in the PATH." + if stdout != "found": + return False, f"madengine not found or not accessible on {hostname}" - print(f"✓ madengine installation found on {hostname}") - - # Check if we can access the DeepLearningModels directory - print(f"🔍 Checking access to DeepLearningModels directory on {hostname}...") - stdin, stdout, stderr = ssh_client.exec_command('cd DeepLearningModels && pwd') - cd_result = stdout.read().decode().strip() - cd_error = stderr.read().decode().strip() - - if cd_error or not cd_result.endswith('DeepLearningModels'): - ssh_client.close() - return False, f"Cannot access DeepLearningModels directory on {hostname}. Error: {cd_error or 'Unknown error'}" + # Check if we can access the working directory + self.logger.debug(f"Checking access to {self.config.madengine.working_directory} directory on {hostname}...") + exit_code, stdout, stderr = ssh_manager.execute_command( + f'cd {self.config.madengine.working_directory} && pwd' + ) - print(f"✓ DeepLearningModels directory is accessible on {hostname}") + if stderr or not stdout.endswith(self.config.madengine.working_directory): + return False, f"Cannot access {self.config.madengine.working_directory} directory on {hostname}" - # Check if shared data path exists (if specified and not the default) - if self.shared_data_path and self.shared_data_path != '/nfs/data': - print(f"🔍 Checking shared data path on {hostname}...") - stdin, stdout, stderr = ssh_client.exec_command(f'test -d "{self.shared_data_path}" && echo "exists" || echo "missing"') - shared_data_result = stdout.read().decode().strip() - - if shared_data_result != "exists": - ssh_client.close() - return False, f"Shared data path '{self.shared_data_path}' not found on {hostname}. Please ensure the shared filesystem is mounted." 
+ # Check shared data path if specified and not default + if self.config.training.shared_data_path != '/nfs/data': + self.logger.debug(f"Checking shared data path on {hostname}...") + exit_code, stdout, stderr = ssh_manager.execute_command( + f'test -d "{self.config.training.shared_data_path}" && echo "exists" || echo "missing"' + ) - print(f"✓ Shared data path '{self.shared_data_path}' found on {hostname}") + if stdout != "exists": + return False, f"Shared data path '{self.config.training.shared_data_path}' not found on {hostname}" - print(f"✓ All checks passed for {hostname}") - - ssh_client.close() return True, "" except Exception as e: return False, f"Error validating prerequisites on {hostname}: {str(e)}" - + def _check_all_prerequisites(self) -> bool: """Check prerequisites on all nodes. Returns: True if all nodes meet prerequisites, False otherwise """ - print("🔧 Validating prerequisites on all nodes...") + self.logger.info("Validating prerequisites on all nodes...") failed_nodes = [] - for hostname in self.nodes: - success, error_msg = self._validate_remote_node_prerequisites(hostname) + for hostname in self.config.cluster.nodes: + success, error_msg = self._validate_remote_prerequisites(hostname) if not success: - print(f"❌ {hostname}: {error_msg}") + self.logger.error(f"❌ {hostname}: {error_msg}") failed_nodes.append((hostname, error_msg)) else: - print(f"✅ {hostname}: All prerequisites met") + self.logger.info(f"✅ {hostname}: All prerequisites met") if failed_nodes: - print(f"\n💥 Prerequisites check failed for {len(failed_nodes)} node(s):") - for hostname, error_msg in failed_nodes: - print(f" • {hostname}: {error_msg}") - + self.logger.error(f"Prerequisites check failed for {len(failed_nodes)} node(s)") self._print_setup_instructions() return False - print("✅ All nodes meet the prerequisites") + self.logger.info("All nodes meet the prerequisites") return True def _print_setup_instructions(self) -> None: """Print setup instructions for remote nodes.""" - print("\n" + "="*60) - print("🔧 REMOTE NODE SETUP INSTRUCTIONS") - print("="*60) - print("\nTo prepare your remote nodes for multi-node training:") - print("\n1. DeepLearningModels Directory:") - print(" • Create or ensure the DeepLearningModels directory exists in the user's home directory") - print(" • Command: mkdir -p ~/DeepLearningModels") - print(" • This directory should contain your model configurations and training scripts") - - print("\n2. MAD Engine Installation:") - print(" • Install madengine on each remote node") - print(" • Command: pip install madengine") - print(" • Or ensure madengine is in the PATH and executable") - print(" • Verify with: madengine --help") - - if self.shared_data_path and self.shared_data_path != '/nfs/data': - print(f"\n3. Shared Data Path:") - print(f" • Ensure the shared data path '{self.shared_data_path}' exists and is accessible") - print(f" • This should be a shared filesystem (NFS, GPFS, etc.) mounted on all nodes") - print(f" • All nodes should have read/write access to this path") - - print("\n4. SSH Access:") - print(" • Ensure SSH key-based or password authentication is configured") - print(" • Test SSH access manually before running this script") - - print("\n5. 
Network Configuration:") - print(f" • Ensure nodes can communicate on the specified interfaces:") - print(f" • NCCL interface: {self.nccl_interface}") - print(f" • GLOO interface: {self.gloo_interface}") - print(f" • Master node {self.master_addr} should be accessible on port {self.master_port}") - - print("\n" + "="*60) + instructions = f""" +{"="*60} +🔧 REMOTE NODE SETUP INSTRUCTIONS +{"="*60} + +To prepare your remote nodes for multi-node training: + +1. {self.config.madengine.working_directory} Directory: + • Create or ensure the {self.config.madengine.working_directory} directory exists + • Command: mkdir -p ~/{self.config.madengine.working_directory} + +2. MAD Engine Installation: + • Install madengine on each remote node + • Command: pip install madengine + • Verify with: {self.config.madengine.path} --help + +3. Shared Data Path: + • Ensure the shared data path '{self.config.training.shared_data_path}' exists + • This should be a shared filesystem mounted on all nodes + +4. SSH Access: + • Ensure SSH key-based or password authentication is configured + • Test SSH access manually before running this script + +5. Network Configuration: + • Ensure nodes can communicate on the specified interfaces + • NCCL interface: {self.config.training.nccl_interface} + • GLOO interface: {self.config.training.gloo_interface} + • Master node {self.config.cluster.master_addr} should be accessible on port {self.config.cluster.master_port} + +{"="*60} + """ + print(instructions) + + def _execute_on_node(self, hostname: str, node_rank: int) -> Tuple[str, bool, str]: + """Execute madengine command on a single node. + + Args: + hostname: The hostname/IP of the node + node_rank: The rank of this node + + Returns: + Tuple of (hostname, success, output/error) + """ + ssh_manager = self.ssh_managers[hostname] + + try: + # Build and execute the madengine command + command = self._build_madengine_command(node_rank) + + # Change to working directory and execute the command + full_command = f"cd {self.config.madengine.working_directory} && {command}" + self.logger.info(f"🚀 Executing on {hostname} (rank {node_rank}): {full_command}") + + # Execute the command with streaming output + with ssh_manager.get_client() as client: + stdin, stdout, stderr = client.exec_command( + full_command, + timeout=self.config.training.timeout + ) + + # Read output in real-time + output_lines = [] + error_lines = [] + + # Read stdout + for line in stdout: + line = line.strip() + if line: + self.logger.info(f"[{hostname}:{node_rank}] {line}") + output_lines.append(line) + + # Read stderr + for line in stderr: + line = line.strip() + if line: + self.logger.warning(f"[{hostname}:{node_rank}] ERROR: {line}") + error_lines.append(line) + + # Get exit code + exit_code = stdout.channel.recv_exit_status() + + if exit_code == 0: + return hostname, True, '\n'.join(output_lines) + else: + error_output = '\n'.join(error_lines) if error_lines else "Command failed with no error output" + return hostname, False, f"Exit code {exit_code}: {error_output}" + + except Exception as e: + return hostname, False, str(e) def run(self) -> bool: """Run distributed training across all nodes. 
@@ -439,34 +317,33 @@ def run(self) -> bool: Returns: True if all nodes completed successfully, False otherwise """ - print(f"🌐 Starting multi-node training on {len(self.nodes)} nodes") - print(f"📋 Model: {self.model_tag}") - print(f"🏠 Master: {self.master_addr}:{self.master_port}") - print(f"📁 Shared data: {self.shared_data_path}") - print(f"🔗 Nodes: {', '.join(self.nodes)}") + self.logger.info(f"Starting multi-node training on {len(self.config.cluster.nodes)} nodes") + self.logger.info(f"Model: {self.config.training.model}") + self.logger.info(f"Master: {self.config.cluster.master_addr}:{self.config.cluster.master_port}") + self.logger.info(f"Shared data: {self.config.training.shared_data_path}") + self.logger.info(f"Nodes: {', '.join(self.config.cluster.nodes)}") # Check connectivity to all nodes reachable_nodes = self._check_node_connectivity() - if len(reachable_nodes) != len(self.nodes): - unreachable = set(self.nodes) - set(reachable_nodes) - print(f"❌ Some nodes are unreachable: {', '.join(unreachable)}") + if len(reachable_nodes) != len(self.config.cluster.nodes): + unreachable = set(self.config.cluster.nodes) - set(reachable_nodes) + self.logger.error(f"Some nodes are unreachable: {', '.join(unreachable)}") return False - print("✅ All nodes are reachable") + self.logger.info("All nodes are reachable") # Validate prerequisites on all nodes if not self._check_all_prerequisites(): return False - return False # Execute on all nodes concurrently results = [] - with ThreadPoolExecutor(max_workers=len(self.nodes)) as executor: + with ThreadPoolExecutor(max_workers=len(self.config.cluster.nodes)) as executor: # Submit jobs for all nodes futures = [] - for i, hostname in enumerate(self.nodes): + for i, hostname in enumerate(self.config.cluster.nodes): future = executor.submit(self._execute_on_node, hostname, i) futures.append(future) @@ -476,89 +353,25 @@ def run(self) -> bool: results.append((hostname, success, output)) if success: - print(f"✅ {hostname} completed successfully") + self.logger.info(f"✅ {hostname} completed successfully") else: - print(f"❌ {hostname} failed: {output}") + self.logger.error(f"❌ {hostname} failed: {output}") # Check overall success successful_nodes = [r[0] for r in results if r[1]] failed_nodes = [r[0] for r in results if not r[1]] - print(f"\n📊 Training Results:") - print(f"✅ Successful nodes: {len(successful_nodes)}/{len(self.nodes)}") + self.logger.info(f"Training Results:") + self.logger.info(f"✅ Successful nodes: {len(successful_nodes)}/{len(self.config.cluster.nodes)}") if failed_nodes: - print(f"❌ Failed nodes: {', '.join(failed_nodes)}") + self.logger.error(f"❌ Failed nodes: {', '.join(failed_nodes)}") return False - print("🎉 Multi-node training completed successfully!") + self.logger.info("🎉 Multi-node training completed successfully!") return True -def load_config_file(config_path: str) -> Dict: - """Load configuration from INI file. 
- - Args: - config_path: Path to configuration file - - Returns: - Dictionary with configuration values - """ - if not os.path.exists(config_path): - raise FileNotFoundError(f"Configuration file not found: {config_path}") - - config = configparser.ConfigParser() - config.read(config_path) - - # Convert to dictionary with flattened keys - config_dict = {} - for section in config.sections(): - for key, value in config[section].items(): - config_dict[f"{section}_{key}"] = value - - return config_dict - - -def merge_config_and_args(config_dict: Dict, args: argparse.Namespace) -> argparse.Namespace: - """Merge configuration file values with command line arguments. - Command line arguments take precedence over config file values. - - Args: - config_dict: Configuration dictionary from file - args: Command line arguments - - Returns: - Updated arguments with config file values applied - """ - # Mapping of config keys to argument attributes - config_to_arg_map = { - 'cluster_nodes': 'nodes', - 'cluster_master_addr': 'master_addr', - 'cluster_master_port': 'master_port', - 'ssh_user': 'ssh_user', - 'ssh_key_file': 'ssh_key', - 'ssh_password': 'ssh_password', - 'training_model': 'model', - 'training_shared_data_path': 'shared_data_path', - 'training_nccl_interface': 'nccl_interface', - 'training_gloo_interface': 'gloo_interface', - 'training_timeout': 'timeout', - 'madengine_madengine_path': 'madengine_path', - 'madengine_additional_args': 'additional_args' - } - - # Apply config values only if argument was not provided - for config_key, arg_attr in config_to_arg_map.items(): - if config_key in config_dict and not getattr(args, arg_attr, None): - value = config_dict[config_key] - # Convert numeric values - if arg_attr in ['master_port', 'timeout']: - value = int(value) - setattr(args, arg_attr, value) - - return args - - def parse_args() -> argparse.Namespace: """Parse command line arguments. 
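+
+    Example invocation these flags are meant to handle (illustrative):
+
+        python run.py --model pyt_megatron_lm_train_llama2_7b \
+            --nodes 10.0.0.1,10.0.0.2 --ssh-user ubuntu \
+            --ssh-key ~/.ssh/id_rsa --log-level DEBUG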
@@ -670,32 +483,55 @@ def parse_args() -> argparse.Namespace: help='Additional arguments to pass to madengine' ) - # Parse arguments - args = parser.parse_args() + parser.add_argument( + '--working-directory', + default='DeepLearningModels', + help='Working directory on remote nodes (default: DeepLearningModels)' + ) - # Load configuration file if provided - if args.config: - config_dict = load_config_file(args.config) - args = merge_config_and_args(config_dict, args) + parser.add_argument( + '--ssh-timeout', + type=int, + default=30, + help='SSH connection timeout in seconds (default: 30)' + ) - # Validate required arguments after config loading - if not args.model: - parser.error("--model is required (can be provided via config file)") - if not args.nodes: - parser.error("--nodes is required (can be provided via config file)") - if not args.ssh_user: - parser.error("--ssh-user is required (can be provided via config file)") - if not args.ssh_password and not args.ssh_key: - parser.error("Either --ssh-password or --ssh-key is required (can be provided via config file)") + parser.add_argument( + '--ssh-max-retries', + type=int, + default=3, + help='Maximum SSH connection retry attempts (default: 3)' + ) - return args + parser.add_argument( + '--log-level', + default='INFO', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], + help='Logging level (default: INFO)' + ) + + return parser.parse_args() def main(): """Main entry point.""" try: args = parse_args() - runner = SSHMultiNodeRunner(args) + + # Set logging level + logging.getLogger().setLevel(getattr(logging, args.log_level)) + + # Create configuration + if args.config: + config = merge_config_file_with_args(args.config, args) + else: + config = MultiNodeConfig.from_args(args) + + # Validate configuration + config.validate() + + # Create and run the runner + runner = SSHMultiNodeRunner(config) success = runner.run() if success: @@ -709,7 +545,7 @@ def main(): print("\n🛑 Interrupted by user") sys.exit(1) except Exception as e: - print(f"💥 Error: {e}") + logging.error(f"Error: {e}") sys.exit(1) diff --git a/runners/ssh/ssh_client_manager.py b/runners/ssh/ssh_client_manager.py new file mode 100644 index 00000000..38929b29 --- /dev/null +++ b/runners/ssh/ssh_client_manager.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +"""SSH Client Manager for MAD Engine Multi-Node Runner + +This module provides a robust SSH client management class with connection pooling, +error handling, and retry mechanisms. + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" + +import logging +import socket +import time +from contextlib import contextmanager +from typing import Optional, Tuple, Iterator + +try: + import paramiko +except ImportError: + raise ImportError("paramiko is required but not installed. Please install it with: pip install paramiko") + + +class SSHClientManager: + """Manages SSH connections with robust error handling and retry mechanisms.""" + + def __init__(self, hostname: str, username: str, password: Optional[str] = None, + key_filename: Optional[str] = None, timeout: int = 30, max_retries: int = 3): + """Initialize SSH client manager. 
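+
+        Example (illustrative; host, user, and key path are placeholders):
+
+            manager = SSHClientManager("10.0.0.1", "ubuntu",
+                                       key_filename="/home/ubuntu/.ssh/id_rsa")
+            exit_code, out, err = manager.execute_command("hostname")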
+
+        Args:
+            hostname: Target hostname or IP address
+            username: SSH username
+            password: SSH password (if using password auth)
+            key_filename: Path to SSH private key (if using key auth)
+            timeout: Connection timeout in seconds
+            max_retries: Maximum number of connection retries
+        """
+        self.hostname = hostname
+        self.username = username
+        self.password = password
+        self.key_filename = key_filename
+        self.timeout = timeout
+        self.max_retries = max_retries
+        self.logger = logging.getLogger(f"{__name__}.{hostname}")
+
+        if not password and not key_filename:
+            raise ValueError("Either password or key_filename must be provided")
+
+    def _create_client(self) -> paramiko.SSHClient:
+        """Create and configure a new SSH client.
+
+        Returns:
+            Configured SSH client
+        """
+        client = paramiko.SSHClient()
+        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+        client.load_system_host_keys()
+        return client
+
+    def _connect_with_retry(self, client: paramiko.SSHClient) -> None:
+        """Connect SSH client with retry mechanism.
+
+        Args:
+            client: SSH client to connect
+
+        Raises:
+            ConnectionError: If all connection attempts fail
+        """
+        last_exception = None
+
+        for attempt in range(self.max_retries):
+            try:
+                if self.key_filename:
+                    client.connect(
+                        hostname=self.hostname,
+                        username=self.username,
+                        key_filename=self.key_filename,
+                        timeout=self.timeout
+                    )
+                else:
+                    client.connect(
+                        hostname=self.hostname,
+                        username=self.username,
+                        password=self.password,
+                        timeout=self.timeout
+                    )
+
+                self.logger.debug(f"Successfully connected to {self.hostname} on attempt {attempt + 1}")
+                return
+
+            except (paramiko.ssh_exception.AuthenticationException,
+                    paramiko.ssh_exception.SSHException,
+                    socket.error) as e:
+                last_exception = e
+                self.logger.warning(f"Connection attempt {attempt + 1} failed: {e}")
+
+                if attempt < self.max_retries - 1:
+                    wait_time = 2 ** attempt  # Exponential backoff
+                    self.logger.info(f"Retrying in {wait_time} seconds...")
+                    time.sleep(wait_time)
+
+        raise ConnectionError(f"Failed to connect to {self.hostname} after {self.max_retries} attempts. Last error: {last_exception}")
+
+    @contextmanager
+    def get_client(self) -> Iterator[paramiko.SSHClient]:
+        """Get an SSH client with automatic cleanup.
+
+        Yields:
+            Connected SSH client
+
+        Raises:
+            ConnectionError: If connection fails
+        """
+        client = self._create_client()
+        try:
+            self._connect_with_retry(client)
+            yield client
+        finally:
+            try:
+                client.close()
+            except Exception as e:
+                self.logger.warning(f"Error closing SSH connection: {e}")
+
+    def execute_command(self, command: str, timeout: Optional[int] = None) -> Tuple[int, str, str]:
+        """Execute a command on the remote host.
+
+        Args:
+            command: Command to execute
+            timeout: Command timeout (uses default if None)
+
+        Returns:
+            Tuple of (exit_code, stdout, stderr)
+
+        Raises:
+            ConnectionError: If SSH connection fails
+            TimeoutError: If command times out
+        """
+        with self.get_client() as client:
+            try:
+                stdin, stdout, stderr = client.exec_command(command, timeout=timeout or self.timeout)
+
+                # Drain stdout/stderr before requesting the exit status:
+                # recv_exit_status() blocks until the command finishes, and a
+                # command whose unread output fills the channel buffer would
+                # otherwise deadlock (a documented paramiko pitfall).
+                stdout_text = stdout.read().decode('utf-8', errors='replace').strip()
+                stderr_text = stderr.read().decode('utf-8', errors='replace').strip()
+                exit_code = stdout.channel.recv_exit_status()
+
+                return exit_code, stdout_text, stderr_text
+
+            except socket.timeout:
+                raise TimeoutError(f"Command timed out after {timeout or self.timeout} seconds: {command}")
+
+    def test_connectivity(self) -> bool:
+        """Test connectivity to the remote host.
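+
+        Example (illustrative):
+
+            if not manager.test_connectivity():
+                raise ConnectionError(f"{manager.hostname} is unreachable")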
+ + Returns: + True if connection successful, False otherwise + """ + try: + exit_code, stdout, stderr = self.execute_command('echo "connectivity_test"') + return exit_code == 0 and stdout.strip() == "connectivity_test" + except Exception as e: + self.logger.error(f"Connectivity test failed: {e}") + return False diff --git a/runners/ssh/test_runner.py b/runners/ssh/test_runner.py index ad85e679..2fed0daa 100644 --- a/runners/ssh/test_runner.py +++ b/runners/ssh/test_runner.py @@ -2,101 +2,117 @@ # -*- coding: utf-8 -*- """Test script for SSH Multi-Node Runner -This script provides basic unit tests and validation for the SSH runner. +This script provides comprehensive unit tests and validation for the SSH runner. """ import os import sys import tempfile import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, call +from pathlib import Path # Add the current directory to the Python path sys.path.insert(0, os.path.dirname(__file__)) try: - from run import SSHMultiNodeRunner, load_config_file, merge_config_and_args, parse_args -except ImportError: - print("Error: Could not import run module. Make sure run.py is in the same directory.") + from config_manager import MultiNodeConfig, SSHConfig, ClusterConfig, TrainingConfig + from ssh_client_manager import SSHClientManager + from run import SSHMultiNodeRunner +except ImportError as e: + print(f"Error: Could not import modules: {e}") + print("Make sure all modules are in the same directory.") sys.exit(1) -class TestSSHMultiNodeRunner(unittest.TestCase): - """Test cases for SSH Multi-Node Runner.""" - - def setUp(self): - """Set up test fixtures.""" - self.mock_args = MagicMock() - self.mock_args.model = 'pyt_megatron_lm_train_llama2_7b' - self.mock_args.nodes = '10.0.0.1,10.0.0.2' - self.mock_args.master_addr = '10.0.0.1' - self.mock_args.master_port = 4000 - self.mock_args.ssh_user = 'testuser' - self.mock_args.ssh_password = 'testpass' - self.mock_args.ssh_key = None - self.mock_args.shared_data_path = '/nfs/data' - self.mock_args.nccl_interface = 'eth0' - self.mock_args.gloo_interface = 'eth0' - self.mock_args.timeout = 3600 - self.mock_args.madengine_path = 'madengine' - self.mock_args.additional_args = '' +class TestConfigManager(unittest.TestCase): + """Test cases for configuration management.""" - def test_initialization(self): - """Test runner initialization.""" - runner = SSHMultiNodeRunner(self.mock_args) + def test_ssh_config_validation(self): + """Test SSH configuration validation.""" + # Valid configuration with password + config = SSHConfig(user="testuser", password="testpass") + self.assertEqual(config.user, "testuser") + self.assertEqual(config.password, "testpass") - self.assertEqual(runner.nodes, ['10.0.0.1', '10.0.0.2']) - self.assertEqual(runner.master_addr, '10.0.0.1') - self.assertEqual(runner.master_port, '4000') - self.assertEqual(runner.ssh_user, 'testuser') - self.assertEqual(runner.model_tag, 'pyt_megatron_lm_train_llama2_7b') - - def test_command_generation(self): - """Test madengine command generation.""" - runner = SSHMultiNodeRunner(self.mock_args) + # Valid configuration with key file (mock file existence) + with patch('os.path.exists', return_value=True): + config = SSHConfig(user="testuser", key_file="/path/to/key") + self.assertEqual(config.key_file, "/path/to/key") - # Test command for node rank 0 - cmd_0 = runner._build_madengine_command(0) - self.assertIn('madengine run', cmd_0) - self.assertIn('pyt_megatron_lm_train_llama2_7b', cmd_0) - 
self.assertIn('"NODE_RANK": "0"', cmd_0) - self.assertIn('"NNODES": "2"', cmd_0) - self.assertIn('--force-mirror-local /nfs/data', cmd_0) + # Invalid configuration - no auth method + with self.assertRaises(ValueError): + SSHConfig(user="testuser") - # Test command for node rank 1 - cmd_1 = runner._build_madengine_command(1) - self.assertIn('"NODE_RANK": "1"', cmd_1) + # Invalid configuration - non-existent key file + with patch('os.path.exists', return_value=False): + with self.assertRaises(FileNotFoundError): + SSHConfig(user="testuser", key_file="/nonexistent/key") - def test_validation_errors(self): - """Test configuration validation errors.""" - # Test missing nodes - args_copy = MagicMock() - for attr in ['model', 'nodes', 'master_addr', 'master_port', 'ssh_user', 'ssh_password', 'ssh_key']: - setattr(args_copy, attr, getattr(self.mock_args, attr)) - args_copy.nodes = '' + def test_cluster_config_validation(self): + """Test cluster configuration validation.""" + # Valid configuration + config = ClusterConfig(nodes=["node1", "node2"]) + self.assertEqual(config.nodes, ["node1", "node2"]) + self.assertEqual(config.master_addr, "node1") # Should default to first node + + # Valid configuration with master_addr specified + config = ClusterConfig(nodes=["node1", "node2"], master_addr="node1") + self.assertEqual(config.master_addr, "node1") + + # Invalid configuration - empty nodes with self.assertRaises(ValueError): - SSHMultiNodeRunner(args_copy) + ClusterConfig(nodes=[]) - # Test missing SSH user - args_copy = MagicMock() - for attr in ['model', 'nodes', 'master_addr', 'master_port', 'ssh_user', 'ssh_password', 'ssh_key']: - setattr(args_copy, attr, getattr(self.mock_args, attr)) - args_copy.ssh_user = '' + # Invalid configuration - whitespace-only nodes with self.assertRaises(ValueError): - SSHMultiNodeRunner(args_copy) - - # Test missing authentication - args_copy = MagicMock() - for attr in ['model', 'nodes', 'master_addr', 'master_port', 'ssh_user', 'ssh_password', 'ssh_key']: - setattr(args_copy, attr, getattr(self.mock_args, attr)) - delattr(args_copy, 'ssh_password') - delattr(args_copy, 'ssh_key') + ClusterConfig(nodes=["", " ", "\t"]) + + def test_training_config_validation(self): + """Test training configuration validation.""" + # Valid configuration + config = TrainingConfig(model="test_model") + self.assertEqual(config.model, "test_model") + self.assertEqual(config.shared_data_path, "/nfs/data") # Default + + # Invalid configuration - empty model with self.assertRaises(ValueError): - SSHMultiNodeRunner(args_copy) + TrainingConfig(model="") + + def test_config_from_args(self): + """Test configuration creation from arguments.""" + mock_args = MagicMock() + mock_args.model = 'test_model' + mock_args.nodes = 'node1,node2,node3' + mock_args.ssh_user = 'testuser' + mock_args.ssh_password = 'testpass' + mock_args.ssh_key = None + + # Set default values for optional attributes + for attr, default in [ + ('master_addr', None), + ('master_port', 4000), + ('shared_data_path', '/nfs/data'), + ('nccl_interface', 'ens14np0'), + ('gloo_interface', 'ens14np0'), + ('timeout', 3600), + ('additional_args', ''), + ('madengine_path', 'madengine'), + ('working_directory', 'DeepLearningModels'), + ('ssh_timeout', 30), + ('ssh_max_retries', 3) + ]: + setattr(mock_args, attr, getattr(mock_args, attr, default)) + + config = MultiNodeConfig.from_args(mock_args) + + self.assertEqual(config.training.model, 'test_model') + self.assertEqual(config.cluster.nodes, ['node1', 'node2', 'node3']) + 
self.assertEqual(config.ssh.user, 'testuser') + self.assertEqual(config.ssh.password, 'testpass') - def test_config_file_loading(self): - """Test configuration file loading.""" - # Create a temporary config file + def test_config_from_file(self): + """Test configuration loading from file.""" config_content = """ [cluster] nodes = node1,node2,node3 @@ -105,7 +121,7 @@ def test_config_file_loading(self): [ssh] user = testuser -key_file = /path/to/key +password = testpass [training] model = test_model @@ -117,157 +133,217 @@ def test_config_file_loading(self): temp_config_path = f.name try: - config_dict = load_config_file(temp_config_path) + config = MultiNodeConfig.from_config_file(temp_config_path) - self.assertEqual(config_dict['cluster_nodes'], 'node1,node2,node3') - self.assertEqual(config_dict['cluster_master_addr'], 'node1') - self.assertEqual(config_dict['cluster_master_port'], '5000') - self.assertEqual(config_dict['ssh_user'], 'testuser') - self.assertEqual(config_dict['ssh_key_file'], '/path/to/key') - self.assertEqual(config_dict['training_model'], 'test_model') + self.assertEqual(config.cluster.nodes, ['node1', 'node2', 'node3']) + self.assertEqual(config.cluster.master_addr, 'node1') + self.assertEqual(config.cluster.master_port, 5000) + self.assertEqual(config.ssh.user, 'testuser') + self.assertEqual(config.ssh.password, 'testpass') + self.assertEqual(config.training.model, 'test_model') + self.assertEqual(config.training.shared_data_path, '/shared/data') finally: os.unlink(temp_config_path) + + +class TestSSHClientManager(unittest.TestCase): + """Test cases for SSH client management.""" + + @patch('paramiko.SSHClient') + def test_connectivity_test_success(self, mock_ssh_client_class): + """Test successful connectivity test.""" + # Mock SSH client + mock_client = MagicMock() + mock_ssh_client_class.return_value = mock_client + + # Mock successful execution + mock_client.exec_command.return_value = (None, MagicMock(), MagicMock()) + mock_client.exec_command.return_value[1].channel.recv_exit_status.return_value = 0 + mock_client.exec_command.return_value[1].read.return_value = b'connectivity_test' + mock_client.exec_command.return_value[2].read.return_value = b'' + + ssh_manager = SSHClientManager( + hostname="testhost", + username="testuser", + password="testpass" + ) + + result = ssh_manager.test_connectivity() + self.assertTrue(result) + + @patch('paramiko.SSHClient') + def test_connectivity_test_failure(self, mock_ssh_client_class): + """Test failed connectivity test.""" + # Mock SSH client that raises exception + mock_client = MagicMock() + mock_ssh_client_class.return_value = mock_client + mock_client.connect.side_effect = Exception("Connection failed") + + ssh_manager = SSHClientManager( + hostname="testhost", + username="testuser", + password="testpass" + ) + + result = ssh_manager.test_connectivity() + self.assertFalse(result) - def test_config_file_not_found(self): - """Test error handling for missing config file.""" - with self.assertRaises(FileNotFoundError): - load_config_file('/nonexistent/config.ini') + def test_invalid_authentication(self): + """Test invalid authentication configuration.""" + with self.assertRaises(ValueError): + SSHClientManager( + hostname="testhost", + username="testuser" + # No password or key_filename + ) -class MockSSHIntegrationTest(unittest.TestCase): - """Integration tests with mocked SSH connections.""" +class TestSSHMultiNodeRunner(unittest.TestCase): + """Test cases for SSH Multi-Node Runner.""" def setUp(self): """Set up test 
fixtures.""" - self.mock_args = MagicMock() - self.mock_args.model = 'pyt_megatron_lm_train_llama2_7b' - self.mock_args.nodes = '10.0.0.1,10.0.0.2' - self.mock_args.master_addr = '10.0.0.1' - self.mock_args.master_port = 4000 - self.mock_args.ssh_user = 'testuser' - self.mock_args.ssh_password = 'testpass' - self.mock_args.ssh_key = None - self.mock_args.shared_data_path = '/nfs/data' - self.mock_args.nccl_interface = 'eth0' - self.mock_args.gloo_interface = 'eth0' - self.mock_args.timeout = 3600 - self.mock_args.madengine_path = 'madengine' - self.mock_args.additional_args = '' + self.config = MultiNodeConfig( + ssh=SSHConfig(user="testuser", password="testpass"), + cluster=ClusterConfig(nodes=["node1", "node2"]), + training=TrainingConfig(model="test_model") + ) - @patch('run.paramiko.SSHClient') - def test_connectivity_check(self, mock_ssh_client_class): - """Test node connectivity checking.""" - # Mock SSH client - mock_ssh_client = MagicMock() - mock_ssh_client_class.return_value = mock_ssh_client + def test_initialization(self): + """Test runner initialization.""" + runner = SSHMultiNodeRunner(self.config) - # Mock successful connection and command execution - mock_stdout = MagicMock() - mock_stdout.read.return_value = b'connectivity_test' - mock_ssh_client.exec_command.return_value = (None, mock_stdout, None) + self.assertEqual(len(runner.ssh_managers), 2) + self.assertIn("node1", runner.ssh_managers) + self.assertIn("node2", runner.ssh_managers) + + def test_command_generation(self): + """Test madengine command generation.""" + runner = SSHMultiNodeRunner(self.config) - runner = SSHMultiNodeRunner(self.mock_args) - reachable_nodes = runner._check_node_connectivity() + # Test command for node rank 0 + cmd_0 = runner._build_madengine_command(0) + self.assertIn('madengine run', cmd_0) + self.assertIn('test_model', cmd_0) + self.assertIn('"NODE_RANK": "0"', cmd_0) + self.assertIn('"NNODES": "2"', cmd_0) - self.assertEqual(len(reachable_nodes), 2) - self.assertIn('10.0.0.1', reachable_nodes) - self.assertIn('10.0.0.2', reachable_nodes) + # Test command for node rank 1 + cmd_1 = runner._build_madengine_command(1) + self.assertIn('"NODE_RANK": "1"', cmd_1) - @patch('run.paramiko.SSHClient') - def test_command_execution(self, mock_ssh_client_class): - """Test command execution on nodes.""" - # Mock SSH client - mock_ssh_client = MagicMock() - mock_ssh_client_class.return_value = mock_ssh_client - - # Mock successful command execution - mock_stdout = MagicMock() - mock_stdout.__iter__ = lambda self: iter(['Training started...', 'Training completed!']) - mock_stdout.channel.recv_exit_status.return_value = 0 + @patch.object(SSHClientManager, 'test_connectivity') + def test_connectivity_check_success(self, mock_connectivity): + """Test successful connectivity checking.""" + mock_connectivity.return_value = True - mock_stderr = MagicMock() - mock_stderr.__iter__ = lambda self: iter([]) + runner = SSHMultiNodeRunner(self.config) + reachable_nodes = runner._check_node_connectivity() - mock_ssh_client.exec_command.return_value = (None, mock_stdout, mock_stderr) + self.assertEqual(len(reachable_nodes), 2) + self.assertIn("node1", reachable_nodes) + self.assertIn("node2", reachable_nodes) + + @patch.object(SSHClientManager, 'test_connectivity') + def test_connectivity_check_partial_failure(self, mock_connectivity): + """Test connectivity checking with partial failures.""" + # Mock node1 success, node2 failure + mock_connectivity.side_effect = [True, False] - runner = 
SSHMultiNodeRunner(self.mock_args) - hostname, success, output = runner._execute_on_node('10.0.0.1', 0) + runner = SSHMultiNodeRunner(self.config) + reachable_nodes = runner._check_node_connectivity() - self.assertEqual(hostname, '10.0.0.1') - self.assertTrue(success) - self.assertIn('Training started...', output) + self.assertEqual(len(reachable_nodes), 1) + self.assertIn("node1", reachable_nodes) + self.assertNotIn("node2", reachable_nodes) - @patch('run.paramiko.SSHClient') - def test_prerequisites_validation_success(self, mock_ssh_client_class): + @patch.object(SSHClientManager, 'execute_command') + def test_prerequisites_validation_success(self, mock_execute): """Test successful prerequisites validation.""" - # Mock SSH client - mock_ssh_client = MagicMock() - mock_ssh_client_class.return_value = mock_ssh_client - - # Mock successful prerequisites checks - def mock_exec_command(command): - mock_stdout = MagicMock() - mock_stderr = MagicMock() - if 'test -d DeepLearningModels' in command: - mock_stdout.read.return_value = b'exists' - elif 'which madengine' in command: - mock_stdout.read.return_value = b'found' - elif 'cd DeepLearningModels' in command: - mock_stdout.read.return_value = b'/home/user/DeepLearningModels' - mock_stderr.read.return_value = b'' - return (None, mock_stdout, mock_stderr) - - mock_ssh_client.exec_command.side_effect = mock_exec_command - - runner = SSHMultiNodeRunner(self.mock_args) - success, error_msg = runner._validate_remote_node_prerequisites('10.0.0.1') + # Mock successful responses for all checks + mock_execute.side_effect = [ + (0, "exists", ""), # Working directory exists + (0, "found", ""), # madengine found + (0, "/home/user/DeepLearningModels", "") # Directory accessible + ] + + runner = SSHMultiNodeRunner(self.config) + success, error_msg = runner._validate_remote_prerequisites("node1") self.assertTrue(success) self.assertEqual(error_msg, "") - @patch('run.paramiko.SSHClient') - def test_prerequisites_validation_missing_deeplearning_models(self, mock_ssh_client_class): - """Test prerequisites validation with missing DeepLearningModels folder.""" - # Mock SSH client - mock_ssh_client = MagicMock() - mock_ssh_client_class.return_value = mock_ssh_client - - # Mock missing DeepLearningModels folder - mock_stdout = MagicMock() - mock_stdout.read.return_value = b'missing' - mock_ssh_client.exec_command.return_value = (None, mock_stdout, None) + @patch.object(SSHClientManager, 'execute_command') + def test_prerequisites_validation_missing_directory(self, mock_execute): + """Test prerequisites validation with missing directory.""" + mock_execute.return_value = (0, "missing", "") - runner = SSHMultiNodeRunner(self.mock_args) - success, error_msg = runner._validate_remote_node_prerequisites('10.0.0.1') + runner = SSHMultiNodeRunner(self.config) + success, error_msg = runner._validate_remote_prerequisites("node1") self.assertFalse(success) - self.assertIn('DeepLearningModels folder not found', error_msg) + self.assertIn("DeepLearningModels folder not found", error_msg) + + +class TestIntegration(unittest.TestCase): + """Integration tests with mocked components.""" - @patch('run.paramiko.SSHClient') - def test_prerequisites_validation_missing_madengine(self, mock_ssh_client_class): - """Test prerequisites validation with missing madengine.""" - # Mock SSH client - mock_ssh_client = MagicMock() - mock_ssh_client_class.return_value = mock_ssh_client + def setUp(self): + """Set up test fixtures.""" + self.config = MultiNodeConfig( + ssh=SSHConfig(user="testuser", 
password="testpass"), + cluster=ClusterConfig(nodes=["node1", "node2"]), + training=TrainingConfig(model="test_model") + ) + + @patch.object(SSHMultiNodeRunner, '_execute_on_node') + @patch.object(SSHMultiNodeRunner, '_check_all_prerequisites') + @patch.object(SSHMultiNodeRunner, '_check_node_connectivity') + def test_successful_run(self, mock_connectivity, mock_prerequisites, mock_execute): + """Test successful end-to-end run.""" + # Mock all checks and execution as successful + mock_connectivity.return_value = ["node1", "node2"] + mock_prerequisites.return_value = True + mock_execute.side_effect = [ + ("node1", True, "Training completed"), + ("node2", True, "Training completed") + ] - # Mock DeepLearningModels exists but madengine missing - def mock_exec_command(command): - mock_stdout = MagicMock() - mock_stderr = MagicMock() - if 'test -d DeepLearningModels' in command: - mock_stdout.read.return_value = b'exists' - elif 'which madengine' in command or 'madengine --help' in command: - mock_stdout.read.return_value = b'missing' - return (None, mock_stdout, mock_stderr) + runner = SSHMultiNodeRunner(self.config) + result = runner.run() - mock_ssh_client.exec_command.side_effect = mock_exec_command + self.assertTrue(result) + self.assertEqual(mock_execute.call_count, 2) + + @patch.object(SSHMultiNodeRunner, '_execute_on_node') + @patch.object(SSHMultiNodeRunner, '_check_all_prerequisites') + @patch.object(SSHMultiNodeRunner, '_check_node_connectivity') + def test_failed_connectivity(self, mock_connectivity, mock_prerequisites, mock_execute): + """Test run with connectivity failure.""" + # Mock partial connectivity failure + mock_connectivity.return_value = ["node1"] # node2 unreachable - runner = SSHMultiNodeRunner(self.mock_args) - success, error_msg = runner._validate_remote_node_prerequisites('10.0.0.1') + runner = SSHMultiNodeRunner(self.config) + result = runner.run() - self.assertFalse(success) - self.assertIn('madengine not found', error_msg) + self.assertFalse(result) + mock_prerequisites.assert_not_called() + mock_execute.assert_not_called() + + @patch.object(SSHMultiNodeRunner, '_execute_on_node') + @patch.object(SSHMultiNodeRunner, '_check_all_prerequisites') + @patch.object(SSHMultiNodeRunner, '_check_node_connectivity') + def test_failed_prerequisites(self, mock_connectivity, mock_prerequisites, mock_execute): + """Test run with prerequisites failure.""" + mock_connectivity.return_value = ["node1", "node2"] + mock_prerequisites.return_value = False + + runner = SSHMultiNodeRunner(self.config) + result = runner.run() + + self.assertFalse(result) + mock_execute.assert_not_called() def run_tests(): @@ -279,8 +355,10 @@ def run_tests(): suite = unittest.TestSuite() # Add test cases + suite.addTests(loader.loadTestsFromTestCase(TestConfigManager)) + suite.addTests(loader.loadTestsFromTestCase(TestSSHClientManager)) suite.addTests(loader.loadTestsFromTestCase(TestSSHMultiNodeRunner)) - suite.addTests(loader.loadTestsFromTestCase(MockSSHIntegrationTest)) + suite.addTests(loader.loadTestsFromTestCase(TestIntegration)) # Run tests runner = unittest.TextTestRunner(verbosity=2) @@ -303,24 +381,21 @@ def validate_environment(): except ImportError: issues.append("❌ paramiko is not installed. 
Run: pip install paramiko") - # Check if run.py exists - run_py_path = os.path.join(os.path.dirname(__file__), 'run.py') - if os.path.exists(run_py_path): - print("✅ run.py found") - else: - issues.append("❌ run.py not found in current directory") + # Check if required files exist + required_files = ['run.py', 'config_manager.py', 'ssh_client_manager.py', 'requirements.txt'] + current_dir = os.path.dirname(__file__) - # Check if requirements.txt exists - req_path = os.path.join(os.path.dirname(__file__), 'requirements.txt') - if os.path.exists(req_path): - print("✅ requirements.txt found") - else: - issues.append("❌ requirements.txt not found") + for filename in required_files: + file_path = os.path.join(current_dir, filename) + if os.path.exists(file_path): + print(f"✅ {filename} found") + else: + issues.append(f"❌ {filename} not found in current directory") if issues: print("\n🚨 Issues found:") for issue in issues: - print(" {}".format(issue)) + print(f" {issue}") return False else: print("\n✅ Environment validation passed!") diff --git a/runners/ssh/utils.py b/runners/ssh/utils.py new file mode 100644 index 00000000..1c39755a --- /dev/null +++ b/runners/ssh/utils.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +"""Utility functions for SSH Multi-Node Runner + +This module provides common utility functions used across the SSH runner. + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" + +import logging +import socket +import time +from typing import Dict, Any, Optional + + +def setup_logging(level: str = 'INFO', format_string: Optional[str] = None) -> logging.Logger: + """Setup logging configuration. + + Args: + level: Logging level (DEBUG, INFO, WARNING, ERROR) + format_string: Custom format string for log messages + + Returns: + Configured logger instance + """ + if format_string is None: + format_string = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + + logging.basicConfig( + level=getattr(logging, level.upper()), + format=format_string, + datefmt='%Y-%m-%d %H:%M:%S' + ) + + return logging.getLogger(__name__) + + +def validate_network_connectivity(hostname: str, port: int, timeout: int = 5) -> bool: + """Test network connectivity to a host and port. + + Args: + hostname: Target hostname or IP address + port: Target port number + timeout: Connection timeout in seconds + + Returns: + True if connection successful, False otherwise + """ + try: + with socket.create_connection((hostname, port), timeout=timeout): + return True + except (socket.error, socket.timeout): + return False + + +def wait_for_port_ready(hostname: str, port: int, max_wait_time: int = 60, check_interval: int = 2) -> bool: + """Wait for a port to become available on a host. + + Args: + hostname: Target hostname or IP address + port: Target port number + max_wait_time: Maximum time to wait in seconds + check_interval: Time between checks in seconds + + Returns: + True if port becomes available, False if timeout + """ + start_time = time.time() + + while time.time() - start_time < max_wait_time: + if validate_network_connectivity(hostname, port): + return True + time.sleep(check_interval) + + return False + + +def format_duration(seconds: float) -> str: + """Format duration in seconds to human-readable string. 
+
+    Args:
+        seconds: Duration in seconds
+
+    Returns:
+        Formatted duration string (e.g., "2h 30m 15s")
+    """
+    if seconds < 60:
+        return f"{seconds:.1f}s"
+
+    minutes = int(seconds // 60)
+    remaining_seconds = int(seconds % 60)
+
+    if minutes < 60:
+        return f"{minutes}m {remaining_seconds}s"
+
+    hours = int(minutes // 60)
+    remaining_minutes = int(minutes % 60)
+
+    if hours < 24:
+        return f"{hours}h {remaining_minutes}m {remaining_seconds}s"
+
+    days = int(hours // 24)
+    remaining_hours = int(hours % 24)
+
+    return f"{days}d {remaining_hours}h {remaining_minutes}m {remaining_seconds}s"
+
+
+def sanitize_hostname(hostname: str) -> str:
+    """Sanitize hostname for safe use in file names and logs.
+
+    Args:
+        hostname: Raw hostname or IP address
+
+    Returns:
+        Sanitized hostname safe for file system use
+    """
+    # Replace common problematic characters
+    safe_hostname = hostname.replace(':', '_').replace('/', '_').replace('\\', '_')
+    # Remove any remaining problematic characters
+    safe_hostname = ''.join(c for c in safe_hostname if c.isalnum() or c in '-_.')
+    return safe_hostname
+
+
+def create_node_summary(nodes: list, successful: list, failed: list) -> Dict[str, Any]:
+    """Create a summary of node execution results.
+
+    Args:
+        nodes: List of all nodes
+        successful: List of successful nodes
+        failed: List of failed nodes
+
+    Returns:
+        Dictionary containing execution summary
+    """
+    return {
+        'total_nodes': len(nodes),
+        'successful_nodes': len(successful),
+        'failed_nodes': len(failed),
+        'success_rate': len(successful) / len(nodes) if nodes else 0.0,
+        'successful_node_list': successful,
+        'failed_node_list': failed,
+        'all_successful': len(failed) == 0
+    }
+
+
+def validate_madengine_command(command: str) -> bool:
+    """Validate that a madengine command looks reasonable.
+
+    Args:
+        command: Command string to validate
+
+    Returns:
+        True if command appears valid, False otherwise
+    """
+    required_parts = ['madengine', 'run', '--tags']
+    return all(part in command for part in required_parts)
+
+
+def escape_shell_argument(argument: str) -> str:
+    """Escape shell argument for safe execution.
+
+    Args:
+        argument: Argument to escape
+
+    Returns:
+        Escaped argument safe for shell execution
+    """
+    # Wrap in single quotes; each embedded single quote becomes '\'' (close, escaped quote, reopen)
+    return "'" + argument.replace("'", "'\\''") + "'"
+
+
+def parse_node_list(nodes_string: str) -> list:
+    """Parse a comma-separated list of nodes.
+
+    Args:
+        nodes_string: Comma-separated string of node names/IPs
+
+    Returns:
+        List of cleaned node names
+    """
+    if not nodes_string:
+        return []
+
+    nodes = [node.strip() for node in nodes_string.split(',') if node.strip()]
+    return nodes
+
+
+def get_network_interfaces() -> Dict[str, str]:
+    """Get available network interfaces on the local machine. 
+
+    Returns:
+        Dictionary mapping interface names to IP addresses
+    """
+    interfaces = {}
+
+    try:
+        import socket
+        # Get hostname
+        hostname = socket.gethostname()
+        # Get local IP
+        local_ip = socket.gethostbyname(hostname)
+        interfaces['local'] = local_ip
+
+        # Try to get more detailed interface information
+        try:
+            import psutil
+            for interface, addresses in psutil.net_if_addrs().items():
+                for address in addresses:
+                    if address.family == socket.AF_INET:
+                        interfaces[interface] = address.address
+                        break
+        except ImportError:
+            # psutil not available, use basic detection
+            pass
+
+    except Exception:
+        # Fallback to basic detection
+        pass
+
+    return interfaces

From 2851a36b29ef060f66a26c5014880c23aea6a6e6 Mon Sep 17 00:00:00 2001
From: Stephen Shao
Date: Wed, 2 Jul 2025 14:18:40 -0400
Subject: [PATCH 06/17] Fixed the issue running multi-node training

---
 .../ssh/{config.ini.example => config.ini}   |  6 +--
 runners/ssh/example.sh                       | 40 ++++++++++++++++---
 runners/ssh/quick-start.sh                   | 24 +++++++----
 3 files changed, 53 insertions(+), 17 deletions(-)
 rename runners/ssh/{config.ini.example => config.ini} (92%)

diff --git a/runners/ssh/config.ini.example b/runners/ssh/config.ini
similarity index 92%
rename from runners/ssh/config.ini.example
rename to runners/ssh/config.ini
index 2e0d7bad..c2e826c5 100644
--- a/runners/ssh/config.ini.example
+++ b/runners/ssh/config.ini
@@ -3,10 +3,10 @@
 
 [cluster]
 # Comma-separated list of node hostnames or IPs
-nodes = 192.168.1.1,192.168.1.2,192.168.1.3
+nodes = 10.194.128.135,10.194.129.35
 
 # Master node configuration (defaults to first node if not specified)
-master_addr = 192.168.1.1
+master_addr = 10.227.23.63
 master_port = 4000
 
 [ssh]
@@ -14,7 +14,7 @@ master_port = 4000
 user = username
 
 # SSH key-based authentication (recommended)
-key_file = ~/.ssh/id_rsa
+key_file = ~/.ssh/id_ed25519
 
 # Password-based authentication (less secure, comment out if using key_file)
 # password = your_password_here

diff --git a/runners/ssh/example.sh b/runners/ssh/example.sh
index 51f130f5..828321cd 100644
--- a/runners/ssh/example.sh
+++ b/runners/ssh/example.sh
@@ -3,11 +3,11 @@
 
 # Configuration
 MODEL="pyt_megatron_lm_train_llama2_7b"
-NODES="192.168.1.1,192.168.1.2"
-MASTER_ADDR="192.168.0.1"
+NODES="10.194.128.135,10.194.129.35"
+MASTER_ADDR="10.227.23.63"
 MASTER_PORT="4000"
 SSH_USER="username"  # Replace with your SSH username
-SSH_KEY="~/.ssh/id_rsa"
+SSH_KEY="$HOME/.ssh/id_ed25519"
 SHARED_DATA="/nfs/data"
 NCCL_INTERFACE="ens14np0"
 GLOO_INTERFACE="ens14np0"
@@ -17,14 +17,42 @@ echo "📋 Model: $MODEL"
 echo "🔗 Nodes: $NODES"
 echo "🏠 Master: $MASTER_ADDR:$MASTER_PORT"
 
+# Validate SSH key exists
+if [ ! -f "$SSH_KEY" ]; then
+    echo "❌ SSH key file not found: $SSH_KEY"
+    echo "💡 Available SSH keys in ~/.ssh/:"
+    ls -la ~/.ssh/*.pub 2>/dev/null || echo "   No SSH keys found"
+    echo ""
+    echo "🔧 To generate a new SSH key, run:"
+    echo "   ssh-keygen -t ed25519 -f ~/.ssh/id_ed25519"
+    echo "   ssh-copy-id -i ~/.ssh/id_ed25519.pub $SSH_USER@<node>"
+    exit 1
+fi
+
+echo "✅ SSH key validated: $SSH_KEY"
+
+# Detect Python command
+PYTHON_CMD=""
+if command -v python3 &> /dev/null; then
+    PYTHON_CMD="python3"
+elif command -v python &> /dev/null; then
+    PYTHON_CMD="python"
+elif [ -n "$VIRTUAL_ENV" ] && [ -x "$VIRTUAL_ENV/bin/python" ]; then
+    PYTHON_CMD="$VIRTUAL_ENV/bin/python"
+else
+    echo "❌ Python is not installed or not in PATH"
+    echo "💡 If you're using a virtual environment, make sure it's activated"
+    exit 1
+fi
+
 # Install requirements if not already installed
-if ! 
python -c "import paramiko" 2>/dev/null; then +if ! $PYTHON_CMD -c "import paramiko" 2>/dev/null; then echo "📦 Installing required packages..." - pip install -r requirements.txt + $PYTHON_CMD -m pip install -r requirements.txt fi # Run the SSH multi-node runner -python run.py \ +$PYTHON_CMD run.py \ --model "$MODEL" \ --nodes "$NODES" \ --master-addr "$MASTER_ADDR" \ diff --git a/runners/ssh/quick-start.sh b/runners/ssh/quick-start.sh index ae012161..7f8ab95f 100644 --- a/runners/ssh/quick-start.sh +++ b/runners/ssh/quick-start.sh @@ -7,18 +7,26 @@ echo "🚀 SSH Multi-Node Runner for MAD Engine" echo "========================================" echo "" -# Check if Python is available -if ! command -v python &> /dev/null; then +# Check if Python is available (try python3, python, or VIRTUAL_ENV python) +PYTHON_CMD="" +if command -v python3 &> /dev/null; then + PYTHON_CMD="python3" +elif command -v python &> /dev/null; then + PYTHON_CMD="python" +elif [ -n "$VIRTUAL_ENV" ] && [ -x "$VIRTUAL_ENV/bin/python" ]; then + PYTHON_CMD="$VIRTUAL_ENV/bin/python" +else echo "❌ Python is not installed or not in PATH" + echo "💡 If you're using a virtual environment, make sure it's activated" exit 1 fi -echo "✅ Python is available" +echo "✅ Python is available ($PYTHON_CMD)" # Check if paramiko is installed -if ! python -c "import paramiko" 2>/dev/null; then +if ! $PYTHON_CMD -c "import paramiko" 2>/dev/null; then echo "📦 Installing paramiko..." - pip install paramiko + $PYTHON_CMD -m pip install paramiko else echo "✅ paramiko is already installed" fi @@ -27,17 +35,17 @@ echo "" echo "🎯 Quick Start Examples:" echo "" echo "1. SSH Key Authentication:" -echo " python run.py --model pyt_megatron_lm_train_llama2_7b \\" +echo " $PYTHON_CMD run.py --model pyt_megatron_lm_train_llama2_7b \\" echo " --nodes 192.168.1.1,192.168.1.2 \\" echo " --master-addr 192.168.0.1 \\" echo " --ssh-user ubuntu \\" echo " --ssh-key ~/.ssh/id_rsa" echo "" echo "2. Configuration File:" -echo " python run.py --config config.ini" +echo " $PYTHON_CMD run.py --config config.ini" echo "" echo "3. Run Tests:" -echo " python test_runner.py" +echo " $PYTHON_CMD test_runner.py" echo "" echo "📖 For detailed documentation, see README.md" echo "✨ Ready to run multi-node training!" 
From 5b7c877d60f5eee4e55e04d4ccbb0819e7203ada Mon Sep 17 00:00:00 2001
From: Stephen Shao
Date: Wed, 2 Jul 2025 14:20:01 -0400
Subject: [PATCH 07/17] Updated README.md

---
 runners/ssh/README.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/runners/ssh/README.md b/runners/ssh/README.md
index 90d83d6b..78b3b42d 100644
--- a/runners/ssh/README.md
+++ b/runners/ssh/README.md
@@ -188,7 +188,8 @@ The generated command for each node follows this pattern:
 ```bash
 madengine run --tags pyt_megatron_lm_train_llama2_7b \
   --additional-context "{'multi_node_args': {
-    'RUNNER': 'torchrun', 'MASTER_ADDR': '192.168.0.1',
+    'RUNNER': 'torchrun',
+    'MASTER_ADDR': '192.168.0.1',
     'MASTER_PORT': '4000',
     'NNODES': '2',
     'NODE_RANK': '0',  # Different for each node

From 7224d275d36d2603c93a0b677a103d77c297d43f Mon Sep 17 00:00:00 2001
From: Stephen Shao
Date: Thu, 3 Jul 2025 12:08:04 -0500
Subject: [PATCH 08/17] Masked the node IPs

---
 runners/ssh/config.ini | 2 +-
 runners/ssh/example.sh | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/runners/ssh/config.ini b/runners/ssh/config.ini
index c2e826c5..7a5a3071 100644
--- a/runners/ssh/config.ini
+++ b/runners/ssh/config.ini
@@ -3,7 +3,7 @@
 
 [cluster]
 # Comma-separated list of node hostnames or IPs
-nodes = 10.194.128.135,10.194.129.35
+nodes = 192.168.0.1,192.168.0.2
 
 # Master node configuration (defaults to first node if not specified)
 master_addr = 10.227.23.63
 master_port = 4000

diff --git a/runners/ssh/example.sh b/runners/ssh/example.sh
index 828321cd..bd2f30e3 100644
--- a/runners/ssh/example.sh
+++ b/runners/ssh/example.sh
@@ -3,7 +3,7 @@
 
 # Configuration
 MODEL="pyt_megatron_lm_train_llama2_7b"
-NODES="10.194.128.135,10.194.129.35"
+NODES="192.168.0.1,192.168.0.2"
 MASTER_ADDR="10.227.23.63"
 MASTER_PORT="4000"
 SSH_USER="username"  # Replace with your SSH username

From 48d7e370b9d1b13271df3b807224ec94f35b5a61 Mon Sep 17 00:00:00 2001
From: Stephen Shao
Date: Wed, 16 Jul 2025 10:22:45 -0400
Subject: [PATCH 09/17] Added logic to handle worker node setup

---
 runners/ssh/run.py | 33 ++++++++++++++-------------------
 1 file changed, 14 insertions(+), 19 deletions(-)

diff --git a/runners/ssh/run.py b/runners/ssh/run.py
index 5b9271d9..9d80e2e2 100644
--- a/runners/ssh/run.py
+++ b/runners/ssh/run.py
@@ -151,7 +151,10 @@ def _validate_remote_prerequisites(self, hostname: str) -> Tuple[bool, str]:
         )
 
         if stdout != "exists":
-            return False, f"{self.config.madengine.working_directory} folder not found on {hostname}"
+            return False, (
+                f"{self.config.madengine.working_directory} folder not found on {hostname}. 
" + f"Please run: git clone https://github.com/ROCm/MAD.git {self.config.madengine.working_directory}" + ) # Check if madengine is accessible self.logger.debug(f"Checking madengine installation on {hostname}...") @@ -267,47 +270,39 @@ def _execute_on_node(self, hostname: str, node_rank: int) -> Tuple[str, bool, st ssh_manager = self.ssh_managers[hostname] try: - # Build and execute the madengine command + # Build madengine command command = self._build_madengine_command(node_rank) - - # Change to working directory and execute the command - full_command = f"cd {self.config.madengine.working_directory} && {command}" + # Compose setup and run commands + setup_commands = [ + f"cd {self.config.madengine.working_directory}", + "pip install -r requirements.txt" + ] + full_command = f"{' && '.join(setup_commands)} && {command}" self.logger.info(f"🚀 Executing on {hostname} (rank {node_rank}): {full_command}") - # Execute the command with streaming output with ssh_manager.get_client() as client: stdin, stdout, stderr = client.exec_command( - full_command, + full_command, timeout=self.config.training.timeout ) - - # Read output in real-time output_lines = [] error_lines = [] - - # Read stdout for line in stdout: line = line.strip() if line: self.logger.info(f"[{hostname}:{node_rank}] {line}") output_lines.append(line) - - # Read stderr for line in stderr: line = line.strip() if line: self.logger.warning(f"[{hostname}:{node_rank}] ERROR: {line}") error_lines.append(line) - - # Get exit code exit_code = stdout.channel.recv_exit_status() - if exit_code == 0: return hostname, True, '\n'.join(output_lines) else: error_output = '\n'.join(error_lines) if error_lines else "Command failed with no error output" return hostname, False, f"Exit code {exit_code}: {error_output}" - except Exception as e: return hostname, False, str(e) @@ -485,8 +480,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument( '--working-directory', - default='DeepLearningModels', - help='Working directory on remote nodes (default: DeepLearningModels)' + default='MAD', + help='Working directory on remote nodes (default: MAD)' ) parser.add_argument( From 91d6e34203c55fe1f57275a2d24b882fe7fff10f Mon Sep 17 00:00:00 2001 From: Stephen Shao Date: Wed, 16 Jul 2025 10:50:24 -0400 Subject: [PATCH 10/17] Debug the ssh runner --- runners/ssh/run.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/runners/ssh/run.py b/runners/ssh/run.py index 9d80e2e2..2c9443ac 100644 --- a/runners/ssh/run.py +++ b/runners/ssh/run.py @@ -153,7 +153,7 @@ def _validate_remote_prerequisites(self, hostname: str) -> Tuple[bool, str]: if stdout != "exists": return False, ( f"{self.config.madengine.working_directory} folder not found on {hostname}. 
" - f"Please run: git clone https://github.com/ROCm/MAD.git {self.config.madengine.working_directory}" + f"Please run: git clone https://github.com/ROCm/DeepLearningModels.git {self.config.madengine.working_directory}" ) # Check if madengine is accessible @@ -480,8 +480,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument( '--working-directory', - default='MAD', - help='Working directory on remote nodes (default: MAD)' + default='DeepLearningModels', + help='Working directory on remote nodes (default: DeepLearningModels)' ) parser.add_argument( From 479277a2562f5fbeaf8e4c0abd4bfec3fd9b1398 Mon Sep 17 00:00:00 2001 From: Stephen Shao Date: Wed, 16 Jul 2025 11:12:13 -0400 Subject: [PATCH 11/17] Updated cwd to MAD --- runners/ssh/config.ini | 2 +- runners/ssh/config_manager.py | 6 +++--- runners/ssh/run.py | 10 +++++----- runners/ssh/test_runner.py | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/runners/ssh/config.ini b/runners/ssh/config.ini index 7a5a3071..07f9b5ec 100644 --- a/runners/ssh/config.ini +++ b/runners/ssh/config.ini @@ -45,4 +45,4 @@ timeout = 7200 path = madengine # Working directory on remote nodes -working_directory = DeepLearningModels +working_directory = MAD diff --git a/runners/ssh/config_manager.py b/runners/ssh/config_manager.py index 98323235..f3135fd3 100644 --- a/runners/ssh/config_manager.py +++ b/runners/ssh/config_manager.py @@ -74,7 +74,7 @@ def __post_init__(self): class MadEngineConfig: """MAD Engine specific configuration.""" path: str = 'madengine' - working_directory: str = 'DeepLearningModels' + working_directory: str = 'MAD' @dataclass @@ -123,7 +123,7 @@ def from_args(cls, args) -> 'MultiNodeConfig': madengine_config = MadEngineConfig( path=getattr(args, 'madengine_path', 'madengine'), - working_directory=getattr(args, 'working_directory', 'DeepLearningModels') + working_directory=getattr(args, 'working_directory', 'MAD') ) return cls( @@ -189,7 +189,7 @@ def from_config_file(cls, config_path: str) -> 'MultiNodeConfig': madengine_section = config['madengine'] if 'madengine' in config else {} madengine_config = MadEngineConfig( path=madengine_section.get('path', 'madengine'), - working_directory=madengine_section.get('working_directory', 'DeepLearningModels') + working_directory=madengine_section.get('working_directory', 'MAD') ) return cls( diff --git a/runners/ssh/run.py b/runners/ssh/run.py index 2c9443ac..bac71b18 100644 --- a/runners/ssh/run.py +++ b/runners/ssh/run.py @@ -153,7 +153,7 @@ def _validate_remote_prerequisites(self, hostname: str) -> Tuple[bool, str]: if stdout != "exists": return False, ( f"{self.config.madengine.working_directory} folder not found on {hostname}. " - f"Please run: git clone https://github.com/ROCm/DeepLearningModels.git {self.config.madengine.working_directory}" + f"Please run: git clone https://github.com/ROCm/MAD.git {self.config.madengine.working_directory}" ) # Check if madengine is accessible @@ -236,7 +236,7 @@ def _print_setup_instructions(self) -> None: 2. MAD Engine Installation: • Install madengine on each remote node - • Command: pip install madengine + • Command: pip install git+https://github.com/ROCm/madengine.git@main • Verify with: {self.config.madengine.path} --help 3. 
Shared Data Path: @@ -275,7 +275,7 @@ def _execute_on_node(self, hostname: str, node_rank: int) -> Tuple[str, bool, st # Compose setup and run commands setup_commands = [ f"cd {self.config.madengine.working_directory}", - "pip install -r requirements.txt" + "pip install git+https://github.com/ROCm/madengine.git@main ] full_command = f"{' && '.join(setup_commands)} && {command}" self.logger.info(f"🚀 Executing on {hostname} (rank {node_rank}): {full_command}") @@ -480,8 +480,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument( '--working-directory', - default='DeepLearningModels', - help='Working directory on remote nodes (default: DeepLearningModels)' + default='MAD', + help='Working directory on remote nodes (default: MAD)' ) parser.add_argument( diff --git a/runners/ssh/test_runner.py b/runners/ssh/test_runner.py index 2fed0daa..48a3361a 100644 --- a/runners/ssh/test_runner.py +++ b/runners/ssh/test_runner.py @@ -98,7 +98,7 @@ def test_config_from_args(self): ('timeout', 3600), ('additional_args', ''), ('madengine_path', 'madengine'), - ('working_directory', 'DeepLearningModels'), + ('working_directory', 'MAD'), ('ssh_timeout', 30), ('ssh_max_retries', 3) ]: @@ -265,7 +265,7 @@ def test_prerequisites_validation_success(self, mock_execute): mock_execute.side_effect = [ (0, "exists", ""), # Working directory exists (0, "found", ""), # madengine found - (0, "/home/user/DeepLearningModels", "") # Directory accessible + (0, "/home/user/MAD", "") # Directory accessible ] runner = SSHMultiNodeRunner(self.config) @@ -283,7 +283,7 @@ def test_prerequisites_validation_missing_directory(self, mock_execute): success, error_msg = runner._validate_remote_prerequisites("node1") self.assertFalse(success) - self.assertIn("DeepLearningModels folder not found", error_msg) + self.assertIn("MAD folder not found", error_msg) class TestIntegration(unittest.TestCase): From 3bdaed8bc383abf9a596a22628c81b18c819c10f Mon Sep 17 00:00:00 2001 From: Stephen Shao Date: Wed, 16 Jul 2025 12:55:09 -0400 Subject: [PATCH 12/17] Debug the madengine check --- runners/ssh/run.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/runners/ssh/run.py b/runners/ssh/run.py index bac71b18..959a7a30 100644 --- a/runners/ssh/run.py +++ b/runners/ssh/run.py @@ -159,17 +159,10 @@ def _validate_remote_prerequisites(self, hostname: str) -> Tuple[bool, str]: # Check if madengine is accessible self.logger.debug(f"Checking madengine installation on {hostname}...") exit_code, stdout, stderr = ssh_manager.execute_command( - f'which {self.config.madengine.path} > /dev/null 2>&1 && echo "found" || echo "missing"' + f'{self.config.madengine.path} --help > /dev/null 2>&1 && echo "found" || echo "missing"' ) - - if stdout != "found": - # Try alternative check - exit_code, stdout, stderr = ssh_manager.execute_command( - f'{self.config.madengine.path} --help > /dev/null 2>&1 && echo "found" || echo "missing"' - ) - - if stdout != "found": - return False, f"madengine not found or not accessible on {hostname}" + if stdout.strip() != "found": + return False, f"madengine not found or not accessible on {hostname}" # Check if we can access the working directory self.logger.debug(f"Checking access to {self.config.madengine.working_directory} directory on {hostname}...") From 2c498aa266a5238cf89c83c6151113fb5e82e109 Mon Sep 17 00:00:00 2001 From: Stephen Shao Date: Wed, 16 Jul 2025 13:03:31 -0400 Subject: [PATCH 13/17] Fixed the error in string --- runners/ssh/run.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-)

diff --git a/runners/ssh/run.py b/runners/ssh/run.py
index 959a7a30..8d4ab02a 100644
--- a/runners/ssh/run.py
+++ b/runners/ssh/run.py
@@ -268,7 +268,7 @@ def _execute_on_node(self, hostname: str, node_rank: int) -> Tuple[str, bool, st
         # Compose setup and run commands
         setup_commands = [
             f"cd {self.config.madengine.working_directory}",
-            "pip install git+https://github.com/ROCm/madengine.git@main
+            "pip install git+https://github.com/ROCm/madengine.git@main"
         ]
         full_command = f"{' && '.join(setup_commands)} && {command}"
         self.logger.info(f"🚀 Executing on {hostname} (rank {node_rank}): {full_command}")

From 18e30ab33dd9c096b32185c838bd5d1677a99a60 Mon Sep 17 00:00:00 2001
From: Stephen Shao
Date: Wed, 16 Jul 2025 13:21:53 -0400
Subject: [PATCH 14/17] Debug the command execution

---
 runners/ssh/run.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/runners/ssh/run.py b/runners/ssh/run.py
index 8d4ab02a..0b795278 100644
--- a/runners/ssh/run.py
+++ b/runners/ssh/run.py
@@ -157,12 +157,12 @@ def _validate_remote_prerequisites(self, hostname: str) -> Tuple[bool, str]:
         )
 
         # Check if madengine is accessible
-        self.logger.debug(f"Checking madengine installation on {hostname}...")
-        exit_code, stdout, stderr = ssh_manager.execute_command(
-            f'{self.config.madengine.path} --help > /dev/null 2>&1 && echo "found" || echo "missing"'
-        )
+        madengine_check_cmd = f'{self.config.madengine.path} --help > /dev/null 2>&1 && echo "found" || echo "missing"'
+        self.logger.debug(f"Checking madengine installation on {hostname} with command: {madengine_check_cmd}")
+        exit_code, stdout, stderr = ssh_manager.execute_command(madengine_check_cmd)
+        self.logger.error(f"[DEBUG] madengine check on {hostname}: exit_code={exit_code}, stdout='{stdout}', stderr='{stderr}'")
         if stdout.strip() != "found":
-            return False, f"madengine not found or not accessible on {hostname}"
+            return False, f"madengine not found or not accessible on {hostname} (exit_code={exit_code}, stdout='{stdout}', stderr='{stderr}')"
 
         # Check if we can access the working directory
         self.logger.debug(f"Checking access to {self.config.madengine.working_directory} directory on {hostname}...")

From 929c7418279fee7084f49977e26b996a179eb409 Mon Sep 17 00:00:00 2001
From: Stephen Shao
Date: Wed, 16 Jul 2025 14:11:01 -0400
Subject: [PATCH 15/17] Refactored the flow of setup remote node

---
 runners/ssh/run.py | 49 ++++++++++++++++++++++++++++++++++------------
 1 file changed, 36 insertions(+), 13 deletions(-)

diff --git a/runners/ssh/run.py b/runners/ssh/run.py
index 0b795278..771b84a6 100644
--- a/runners/ssh/run.py
+++ b/runners/ssh/run.py
@@ -144,25 +144,48 @@ def _validate_remote_prerequisites(self, hostname: str) -> Tuple[bool, str]:
         ssh_manager = self.ssh_managers[hostname]
 
         try:
-            # Check if working directory exists
-            self.logger.debug(f"Checking {self.config.madengine.working_directory} folder on {hostname}...")
+
+            # Check if MAD directory exists, clone if missing
+            self.logger.info(f"Checking if {self.config.madengine.working_directory} directory exists on {hostname}...")
             exit_code, stdout, stderr = ssh_manager.execute_command(
                 f'test -d {self.config.madengine.working_directory} && echo "exists" || echo "missing"'
             )
-
-            if stdout != "exists":
-                return False, (
-                    f"{self.config.madengine.working_directory} folder not found on {hostname}. 
" - f"Please run: git clone https://github.com/ROCm/MAD.git {self.config.madengine.working_directory}" + if stdout.strip() == "missing": + self.logger.info(f"{self.config.madengine.working_directory} not found on {hostname}, cloning MAD repo...") + exit_code, stdout, stderr = ssh_manager.execute_command( + f'git clone https://github.com/ROCm/MAD.git {self.config.madengine.working_directory}' ) - - # Check if madengine is accessible - madengine_check_cmd = f'{self.config.madengine.path} --help > /dev/null 2>&1 && echo "found" || echo "missing"' - self.logger.debug(f"Checking madengine installation on {hostname} with command: {madengine_check_cmd}") + if exit_code != 0: + return False, f"Failed to clone MAD repo to {self.config.madengine.working_directory} on {hostname}: {stderr}" + elif exit_code != 0: + return False, f"Failed to check {self.config.madengine.working_directory} on {hostname}: {stderr}" + + # Ensure venv exists, create if missing + venv_path = os.path.join(self.config.madengine.working_directory, 'venv') + venv_python = os.path.join(venv_path, 'bin', 'python3') + venv_madengine = os.path.join(venv_path, 'bin', 'madengine') + self.logger.info(f"Ensuring venv exists at {venv_path} on {hostname}...") + exit_code, stdout, stderr = ssh_manager.execute_command( + f'cd {self.config.madengine.working_directory} && [ -d venv ] || python3 -m venv venv' + ) + if exit_code != 0: + return False, f"Failed to create venv in {self.config.madengine.working_directory} on {hostname}: {stderr}" + + # Install madengine in venv + self.logger.info(f"Installing madengine in venv on {hostname}...") + exit_code, stdout, stderr = ssh_manager.execute_command( + f'cd {self.config.madengine.working_directory} && source venv/bin/activate && pip install --upgrade pip && pip install git+https://github.com/ROCm/madengine.git@main' + ) + if exit_code != 0: + return False, f"Failed to install madengine in venv on {hostname}: {stderr}" + + # Check if madengine is accessible in venv + madengine_check_cmd = f'cd {self.config.madengine.working_directory} && source venv/bin/activate && madengine --help > /dev/null 2>&1 && echo "found" || echo "missing"' + self.logger.debug(f"Checking madengine installation in venv on {hostname} with command: {madengine_check_cmd}") exit_code, stdout, stderr = ssh_manager.execute_command(madengine_check_cmd) - self.logger.error(f"[DEBUG] madengine check on {hostname}: exit_code={exit_code}, stdout='{stdout}', stderr='{stderr}'") + self.logger.error(f"[DEBUG] madengine check in venv on {hostname}: exit_code={exit_code}, stdout='{stdout}', stderr='{stderr}'") if stdout.strip() != "found": - return False, f"madengine not found or not accessible on {hostname} (exit_code={exit_code}, stdout='{stdout}', stderr='{stderr}')" + return False, f"madengine not found or not accessible in venv on {hostname} (exit_code={exit_code}, stdout='{stdout}', stderr='{stderr}')" # Check if we can access the working directory self.logger.debug(f"Checking access to {self.config.madengine.working_directory} directory on {hostname}...") From 133cfd5d8f65d8a554d136012999f94713202515 Mon Sep 17 00:00:00 2001 From: Stephen Shao Date: Wed, 16 Jul 2025 14:24:50 -0400 Subject: [PATCH 16/17] Fix the virtual env for madengine installation --- runners/ssh/run.py | 41 +++++++++-------------------------------- 1 file changed, 9 insertions(+), 32 deletions(-) diff --git a/runners/ssh/run.py b/runners/ssh/run.py index 771b84a6..ec90683a 100644 --- a/runners/ssh/run.py +++ b/runners/ssh/run.py @@ -73,14 +73,7 @@ def 
__init__(self, config: MultiNodeConfig): ) def _build_madengine_command(self, node_rank: int) -> str: - """Build the madengine command for a specific node. - - Args: - node_rank: The rank of this node (0-based) - - Returns: - Complete madengine command string - """ + """Build the madengine command for a specific node, using venv's madengine binary.""" multi_node_args = { 'RUNNER': 'torchrun', 'MASTER_ADDR': self.config.cluster.master_addr, @@ -90,26 +83,19 @@ def _build_madengine_command(self, node_rank: int) -> str: 'NCCL_SOCKET_IFNAME': self.config.training.nccl_interface, 'GLOO_SOCKET_IFNAME': self.config.training.gloo_interface } - - # Build the additional context string additional_context = f"'{json.dumps({'multi_node_args': multi_node_args})}'" - - # Build the complete command + # Use venv/bin/madengine explicitly + madengine_bin = os.path.join('venv', 'bin', 'madengine') cmd_parts = [ - self.config.madengine.path, + madengine_bin, 'run', '--tags', self.config.training.model, '--additional-context', additional_context ] - - # Add shared data path if specified if self.config.training.shared_data_path: cmd_parts.extend(['--force-mirror-local', self.config.training.shared_data_path]) - - # Add any additional arguments if self.config.training.additional_args: cmd_parts.append(self.config.training.additional_args) - return ' '.join(cmd_parts) def _check_node_connectivity(self) -> List[str]: @@ -274,28 +260,19 @@ def _print_setup_instructions(self) -> None: print(instructions) def _execute_on_node(self, hostname: str, node_rank: int) -> Tuple[str, bool, str]: - """Execute madengine command on a single node. - - Args: - hostname: The hostname/IP of the node - node_rank: The rank of this node - - Returns: - Tuple of (hostname, success, output/error) - """ + """Execute madengine command on a single node, ensuring venv usage.""" ssh_manager = self.ssh_managers[hostname] - try: - # Build madengine command + # Build madengine command (uses venv/bin/madengine) command = self._build_madengine_command(node_rank) - # Compose setup and run commands + # Compose setup and run commands, always use venv's python/pip setup_commands = [ f"cd {self.config.madengine.working_directory}", - "pip install git+https://github.com/ROCm/madengine.git@main" + "venv/bin/python -m pip install --upgrade pip", + "venv/bin/python -m pip install git+https://github.com/ROCm/madengine.git@main" ] full_command = f"{' && '.join(setup_commands)} && {command}" self.logger.info(f"🚀 Executing on {hostname} (rank {node_rank}): {full_command}") - # Execute the command with streaming output with ssh_manager.get_client() as client: stdin, stdout, stderr = client.exec_command( full_command, From 48401fe45f44f10446025884244b63d274d3886f Mon Sep 17 00:00:00 2001 From: Stephen Shao Date: Wed, 16 Jul 2025 17:16:59 -0400 Subject: [PATCH 17/17] Updated pip install statement --- runners/ssh/run.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runners/ssh/run.py b/runners/ssh/run.py index ec90683a..82b518f7 100644 --- a/runners/ssh/run.py +++ b/runners/ssh/run.py @@ -160,7 +160,7 @@ def _validate_remote_prerequisites(self, hostname: str) -> Tuple[bool, str]: # Install madengine in venv self.logger.info(f"Installing madengine in venv on {hostname}...") exit_code, stdout, stderr = ssh_manager.execute_command( - f'cd {self.config.madengine.working_directory} && source venv/bin/activate && pip install --upgrade pip && pip install git+https://github.com/ROCm/madengine.git@main' + f'cd 
{self.config.madengine.working_directory} && source venv/bin/activate && pip install --upgrade pip && pip install git+https://github.com/ROCm/madengine.git@main -q' ) if exit_code != 0: return False, f"Failed to install madengine in venv on {hostname}: {stderr}" @@ -269,7 +269,7 @@ def _execute_on_node(self, hostname: str, node_rank: int) -> Tuple[str, bool, st setup_commands = [ f"cd {self.config.madengine.working_directory}", "venv/bin/python -m pip install --upgrade pip", - "venv/bin/python -m pip install git+https://github.com/ROCm/madengine.git@main" + "venv/bin/python -m pip install git+https://github.com/ROCm/madengine.git@main -q" ] full_command = f"{' && '.join(setup_commands)} && {command}" self.logger.info(f"🚀 Executing on {hostname} (rank {node_rank}): {full_command}")