diff --git a/runners/ssh/README.md b/runners/ssh/README.md new file mode 100644 index 00000000..78b3b42d --- /dev/null +++ b/runners/ssh/README.md @@ -0,0 +1,200 @@ +# SSH Multi-Node Runner for MAD Engine + +This SSH runner automates the execution of PyTorch Megatron-LM training across multiple nodes using SSH connections. + +## Features + +- ✅ Automated SSH connection management +- ✅ Parallel execution across multiple nodes +- ✅ Real-time output streaming from all nodes +- ✅ Robust error handling and connectivity checking +- ✅ Support for both SSH key and password authentication +- ✅ Configurable network interfaces (NCCL/GLOO) +- ✅ Shared filesystem support + +## Prerequisites + +1. **Python Dependencies**: + ```bash + pip install -r requirements.txt + ``` + + Or use the quick-start script: + ```bash + bash quick-start.sh + ``` + +2. **SSH Access**: Ensure you have SSH access to all target nodes with either: + - SSH key-based authentication (recommended) + - Password-based authentication + +3. **Shared Filesystem**: All nodes should have access to a shared filesystem for data (e.g., NFS mount) + +4. **MAD Engine**: Ensure `madengine` is installed and accessible on all target nodes + +## Usage + +### Basic Usage with SSH Key + +```bash +python run.py --model pyt_megatron_lm_train_llama2_7b \ + --nodes 192.168.1.1,192.168.1.2 \ + --master-addr 192.168.0.1 \ + --ssh-user ubuntu \ + --ssh-key ~/.ssh/id_rsa \ + --shared-data-path /nfs/data +``` + +### Usage with Password Authentication + +```bash +python run.py --model pyt_megatron_lm_train_llama2_7b \ + --nodes node1.cluster.com,node2.cluster.com \ + --ssh-user root \ + --ssh-password mypassword \ + --shared-data-path /shared/data +``` + +### Advanced Configuration + +```bash +python run.py --model pyt_megatron_lm_train_llama2_7b \ + --nodes 192.168.1.10,192.168.1.11,192.168.1.12 \ + --master-addr 192.168.1.10 \ + --master-port 5000 \ + --ssh-user mluser \ + --ssh-key /home/user/.ssh/cluster_key \ + --shared-data-path /mnt/nfs/datasets \ + --nccl-interface eth0 \ + --gloo-interface eth0 \ + --timeout 7200 \ + --additional-args "--some-extra-flag" +``` + +## Command Line Arguments + +### Required Arguments + +- `--model`: Model tag to run (e.g., `pyt_megatron_lm_train_llama2_7b`) +- `--nodes`: Comma-separated list of node hostnames/IPs +- `--ssh-user`: SSH username for all nodes + +### Authentication (one required) + +- `--ssh-password`: SSH password for all nodes +- `--ssh-key`: Path to SSH private key file + +### Optional Arguments + +- `--master-addr`: Master node address (defaults to first node) +- `--master-port`: Master node port (default: 4000) +- `--shared-data-path`: Path to shared data filesystem (default: /nfs/data) +- `--nccl-interface`: NCCL socket interface (default: ens14np0) +- `--gloo-interface`: GLOO socket interface (default: ens14np0) +- `--timeout`: Execution timeout in seconds (default: 3600) +- `--madengine-path`: Path to madengine executable (default: madengine) +- `--additional-args`: Additional arguments to pass to madengine + +## How It Works + +1. **Connectivity Check**: Verifies SSH connectivity to all nodes +2. **Command Generation**: Builds appropriate `madengine` commands for each node with correct `NODE_RANK` +3. **Parallel Execution**: Executes commands on all nodes simultaneously using threading +4. **Output Streaming**: Streams real-time output from all nodes with node identification +5. 
**Result Aggregation**: Collects and reports results from all nodes
+
+## Example Output
+
+```
+🌐 Starting multi-node training on 2 nodes
+📋 Model: pyt_megatron_lm_train_llama2_7b
+🏠 Master: 192.168.0.1:4000
+📁 Shared data: /nfs/data
+🔗 Nodes: 192.168.1.1, 192.168.1.2
+
+🔍 Checking connectivity to all nodes...
+✓ 192.168.1.1 is reachable
+✓ 192.168.1.2 is reachable
+✅ All nodes are reachable
+
+🚀 Executing on 192.168.1.1 (rank 0): madengine run --tags pyt_megatron_lm_train_llama2_7b ...
+🚀 Executing on 192.168.1.2 (rank 1): madengine run --tags pyt_megatron_lm_train_llama2_7b ...
+
+[192.168.1.1:0] Starting training...
+[192.168.1.2:1] Starting training...
+...
+✅ 192.168.1.1 completed successfully
+✅ 192.168.1.2 completed successfully
+
+📊 Training Results:
+✅ Successful nodes: 2/2
+🎉 Multi-node training completed successfully!
+```
+
+## Network Configuration
+
+For optimal performance, ensure:
+
+1. **Network Interface**: Use the correct network interface names for `--nccl-interface` and `--gloo-interface`
+   ```bash
+   # Check available interfaces on your nodes
+   ssh user@node "ip addr show"
+   ```
+
+2. **Firewall**: Ensure the master port is open between nodes
+   ```bash
+   # Example: Open port 4000 on Ubuntu/Debian
+   sudo ufw allow 4000
+   ```
+
+3. **Shared Storage**: Verify the shared filesystem is mounted on all nodes
+   ```bash
+   # Check if the NFS mount is available
+   ssh user@node "ls -la /nfs/data"
+   ```
+
+## Troubleshooting
+
+### SSH Connection Issues
+
+- Verify SSH key permissions: `chmod 600 ~/.ssh/id_rsa`
+- Test a manual SSH connection: `ssh -i ~/.ssh/id_rsa user@node`
+- Check the SSH agent: `ssh-add ~/.ssh/id_rsa`
+
+### Network Communication Issues
+
+- Verify nodes can reach each other on the master port
+- Check firewall settings
+- Ensure correct network interface names
+
+### MAD Engine Issues
+
+- Verify madengine is installed on all nodes: `ssh user@node "which madengine"`
+- Check that the shared data path exists: `ssh user@node "ls -la /nfs/data"`
+- Review madengine logs for specific errors
+
+## Integration with MAD Engine
+
+This SSH runner integrates with the MAD Engine multi-node framework:
+
+- Automatically configures `multi_node_args` for each node
+- Sets the appropriate `NODE_RANK` for each node (0, 1, 2, ...)
+- Configures `NNODES` based on the number of nodes provided
+- Uses `torchrun` as the distributed runner
+- Handles network interface configuration for NCCL and GLOO
+
+The generated command for each node follows this pattern (shown here
+pretty-printed; `NODE_RANK` differs per node):
+
+```bash
+madengine run --tags pyt_megatron_lm_train_llama2_7b \
+    --additional-context '{"multi_node_args": {
+        "RUNNER": "torchrun",
+        "MASTER_ADDR": "192.168.0.1",
+        "MASTER_PORT": "4000",
+        "NNODES": "2",
+        "NODE_RANK": "0",
+        "NCCL_SOCKET_IFNAME": "ens14np0",
+        "GLOO_SOCKET_IFNAME": "ens14np0"
+    }}' \
+    --force-mirror-local /nfs/data
+```
diff --git a/runners/ssh/__init__.py b/runners/ssh/__init__.py
new file mode 100644
index 00000000..4de2b246
--- /dev/null
+++ b/runners/ssh/__init__.py
@@ -0,0 +1,46 @@
+"""SSH Multi-Node Runner for MAD Engine
+
+This package provides SSH-based multi-node distributed training capabilities
+for the MAD Engine framework.
+
+Main Components:
+- SSHMultiNodeRunner: Main orchestration class
+- SSHClientManager: Robust SSH connection management
+- MultiNodeConfig: Configuration management
+- Configuration validation and setup instructions
+- Utilities: Common helper functions
+
+Example Usage:
+    from runners.ssh import SSHMultiNodeRunner, MultiNodeConfig
+
+    config = MultiNodeConfig.from_config_file('config.ini')
+    runner = SSHMultiNodeRunner(config)
+    success = runner.run()
+
+Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
+"""
+
+from .config_manager import (
+    MultiNodeConfig,
+    SSHConfig,
+    ClusterConfig,
+    TrainingConfig,
+    MadEngineConfig
+)
+from .ssh_client_manager import SSHClientManager
+from .run import SSHMultiNodeRunner
+from . import utils
+
+__version__ = "1.0.0"
+__author__ = "Advanced Micro Devices, Inc."
+
+__all__ = [
+    'SSHMultiNodeRunner',
+    'SSHClientManager',
+    'MultiNodeConfig',
+    'SSHConfig',
+    'ClusterConfig',
+    'TrainingConfig',
+    'MadEngineConfig',
+    'utils'
+]
\ No newline at end of file
diff --git a/runners/ssh/config.ini b/runners/ssh/config.ini
new file mode 100644
index 00000000..07f9b5ec
--- /dev/null
+++ b/runners/ssh/config.ini
@@ -0,0 +1,48 @@
+# Configuration for SSH multi-node runner
+# Customize the values below for your environment (used via: python run.py --config config.ini)
+
+[cluster]
+# Comma-separated list of node hostnames or IPs
+nodes = 192.168.0.1,192.168.0.2
+
+# Master node configuration (defaults to first node if not specified)
+master_addr = 10.227.23.63
+master_port = 4000
+
+[ssh]
+# SSH authentication - use either key_file OR password (key_file is recommended)
+user = username
+
+# SSH key-based authentication (recommended)
+key_file = ~/.ssh/id_ed25519
+
+# Password-based authentication (less secure, comment out if using key_file)
+# password = your_password_here
+
+# SSH connection settings
+timeout = 30
+max_retries = 3
+
+[training]
+# Model to train
+model = pyt_megatron_lm_train_llama2_7b
+
+# Shared filesystem path where data is located
+shared_data_path = /nfs/data
+
+# Network interfaces for distributed communication
+nccl_interface = ens14np0
+gloo_interface = ens14np0
+
+# Execution timeout in seconds (2 hours)
+timeout = 7200
+
+# Additional arguments to pass to madengine (optional)
+# additional_args = --live-output --some-other-flag
+
+[madengine]
+# Path to madengine executable (if not in PATH)
+path = madengine
+
+# Working directory on remote nodes
+working_directory = MAD
diff --git a/runners/ssh/config_manager.py b/runners/ssh/config_manager.py
new file mode 100644
index 00000000..f3135fd3
--- /dev/null
+++ b/runners/ssh/config_manager.py
@@ -0,0 +1,294 @@
+#!/usr/bin/env python3
+"""Configuration Management for SSH Multi-Node Runner
+
+This module provides configuration validation and management for the SSH runner.
+
+Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
+""" + +import configparser +import os +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Any +from pathlib import Path + + +@dataclass +class SSHConfig: + """SSH configuration settings.""" + user: str + password: Optional[str] = None + key_file: Optional[str] = None + timeout: int = 30 + max_retries: int = 3 + + def __post_init__(self): + """Validate SSH configuration after initialization.""" + if not self.password and not self.key_file: + raise ValueError("Either SSH password or key file must be specified") + + if self.key_file and not os.path.exists(self.key_file): + raise FileNotFoundError(f"SSH key file not found: {self.key_file}") + + +@dataclass +class ClusterConfig: + """Cluster configuration settings.""" + nodes: List[str] + master_addr: Optional[str] = None + master_port: int = 4000 + + def __post_init__(self): + """Validate cluster configuration after initialization.""" + if not self.nodes: + raise ValueError("At least one node must be specified") + + # Clean up node list + self.nodes = [node.strip() for node in self.nodes if node.strip()] + + if not self.nodes: + raise ValueError("At least one valid node must be specified") + + # Set master address to first node if not specified + if not self.master_addr: + self.master_addr = self.nodes[0] + + +@dataclass +class TrainingConfig: + """Training configuration settings.""" + model: str + shared_data_path: str = '/nfs/data' + nccl_interface: str = 'ens14np0' + gloo_interface: str = 'ens14np0' + timeout: int = 3600 + additional_args: str = '' + + def __post_init__(self): + """Validate training configuration after initialization.""" + if not self.model: + raise ValueError("Model must be specified") + + +@dataclass +class MadEngineConfig: + """MAD Engine specific configuration.""" + path: str = 'madengine' + working_directory: str = 'MAD' + + +@dataclass +class MultiNodeConfig: + """Complete multi-node runner configuration.""" + ssh: SSHConfig + cluster: ClusterConfig + training: TrainingConfig + madengine: MadEngineConfig = field(default_factory=MadEngineConfig) + + @classmethod + def from_args(cls, args) -> 'MultiNodeConfig': + """Create configuration from command line arguments. + + Args: + args: Parsed command line arguments + + Returns: + Complete configuration object + """ + # Parse nodes list + nodes = [node.strip() for node in args.nodes.split(',') if node.strip()] + + ssh_config = SSHConfig( + user=args.ssh_user, + password=getattr(args, 'ssh_password', None), + key_file=getattr(args, 'ssh_key', None), + timeout=getattr(args, 'ssh_timeout', 30), + max_retries=getattr(args, 'ssh_max_retries', 3) + ) + + cluster_config = ClusterConfig( + nodes=nodes, + master_addr=getattr(args, 'master_addr', None), + master_port=getattr(args, 'master_port', 4000) + ) + + training_config = TrainingConfig( + model=args.model, + shared_data_path=getattr(args, 'shared_data_path', '/nfs/data'), + nccl_interface=getattr(args, 'nccl_interface', 'ens14np0'), + gloo_interface=getattr(args, 'gloo_interface', 'ens14np0'), + timeout=getattr(args, 'timeout', 3600), + additional_args=getattr(args, 'additional_args', '') + ) + + madengine_config = MadEngineConfig( + path=getattr(args, 'madengine_path', 'madengine'), + working_directory=getattr(args, 'working_directory', 'MAD') + ) + + return cls( + ssh=ssh_config, + cluster=cluster_config, + training=training_config, + madengine=madengine_config + ) + + @classmethod + def from_config_file(cls, config_path: str) -> 'MultiNodeConfig': + """Create configuration from INI file. 
+ + Args: + config_path: Path to configuration file + + Returns: + Complete configuration object + + Raises: + FileNotFoundError: If config file doesn't exist + configparser.Error: If config file is malformed + """ + if not os.path.exists(config_path): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + + config = configparser.ConfigParser() + config.read(config_path) + + # Parse SSH configuration + ssh_section = config['ssh'] if 'ssh' in config else {} + ssh_config = SSHConfig( + user=ssh_section.get('user'), + password=ssh_section.get('password'), + key_file=ssh_section.get('key_file'), + timeout=int(ssh_section.get('timeout', 30)), + max_retries=int(ssh_section.get('max_retries', 3)) + ) + + # Parse cluster configuration + cluster_section = config['cluster'] if 'cluster' in config else {} + nodes_str = cluster_section.get('nodes', '') + nodes = [node.strip() for node in nodes_str.split(',') if node.strip()] + + cluster_config = ClusterConfig( + nodes=nodes, + master_addr=cluster_section.get('master_addr'), + master_port=int(cluster_section.get('master_port', 4000)) + ) + + # Parse training configuration + training_section = config['training'] if 'training' in config else {} + training_config = TrainingConfig( + model=training_section.get('model'), + shared_data_path=training_section.get('shared_data_path', '/nfs/data'), + nccl_interface=training_section.get('nccl_interface', 'ens14np0'), + gloo_interface=training_section.get('gloo_interface', 'ens14np0'), + timeout=int(training_section.get('timeout', 3600)), + additional_args=training_section.get('additional_args', '') + ) + + # Parse madengine configuration + madengine_section = config['madengine'] if 'madengine' in config else {} + madengine_config = MadEngineConfig( + path=madengine_section.get('path', 'madengine'), + working_directory=madengine_section.get('working_directory', 'MAD') + ) + + return cls( + ssh=ssh_config, + cluster=cluster_config, + training=training_config, + madengine=madengine_config + ) + + def validate(self) -> None: + """Validate the complete configuration. + + Raises: + ValueError: If configuration is invalid + """ + # Configurations are validated in their __post_init__ methods + pass + + def to_dict(self) -> Dict[str, Any]: + """Convert configuration to dictionary. + + Returns: + Configuration as dictionary + """ + return { + 'ssh': { + 'user': self.ssh.user, + 'has_password': bool(self.ssh.password), + 'has_key_file': bool(self.ssh.key_file), + 'timeout': self.ssh.timeout, + 'max_retries': self.ssh.max_retries + }, + 'cluster': { + 'nodes': self.cluster.nodes, + 'master_addr': self.cluster.master_addr, + 'master_port': self.cluster.master_port + }, + 'training': { + 'model': self.training.model, + 'shared_data_path': self.training.shared_data_path, + 'nccl_interface': self.training.nccl_interface, + 'gloo_interface': self.training.gloo_interface, + 'timeout': self.training.timeout, + 'additional_args': self.training.additional_args + }, + 'madengine': { + 'path': self.madengine.path, + 'working_directory': self.madengine.working_directory + } + } + + +def merge_config_file_with_args(config_path: str, args) -> 'MultiNodeConfig': + """Merge configuration file with command line arguments. + + Command line arguments take precedence over config file values. 
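+
+    Example (illustrative):
+
+        # file values form the baseline; any CLI flags provided win
+        config = merge_config_file_with_args('config.ini', args)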
+
+    Args:
+        config_path: Path to configuration file
+        args: Command line arguments
+
+    Returns:
+        Merged configuration
+    """
+    # Start with config file
+    config = MultiNodeConfig.from_config_file(config_path)
+
+    # Override with command line arguments if provided
+    if hasattr(args, 'nodes') and args.nodes:
+        nodes = [node.strip() for node in args.nodes.split(',') if node.strip()]
+        config.cluster.nodes = nodes
+
+    if hasattr(args, 'master_addr') and args.master_addr:
+        config.cluster.master_addr = args.master_addr
+
+    if hasattr(args, 'master_port') and args.master_port:
+        config.cluster.master_port = args.master_port
+
+    if hasattr(args, 'ssh_user') and args.ssh_user:
+        config.ssh.user = args.ssh_user
+
+    if hasattr(args, 'ssh_password') and args.ssh_password:
+        config.ssh.password = args.ssh_password
+
+    if hasattr(args, 'ssh_key') and args.ssh_key:
+        config.ssh.key_file = args.ssh_key
+
+    if hasattr(args, 'model') and args.model:
+        config.training.model = args.model
+
+    if hasattr(args, 'shared_data_path') and args.shared_data_path:
+        config.training.shared_data_path = args.shared_data_path
+
+    if hasattr(args, 'timeout') and args.timeout:
+        config.training.timeout = args.timeout
+
+    if hasattr(args, 'madengine_path') and args.madengine_path:
+        config.madengine.path = args.madengine_path
+
+    # Re-validate after merging
+    config.validate()
+    return config
diff --git a/runners/ssh/example.sh b/runners/ssh/example.sh
new file mode 100644
index 00000000..bd2f30e3
--- /dev/null
+++ b/runners/ssh/example.sh
@@ -0,0 +1,67 @@
+#!/bin/bash
+# Example script showing how to use the SSH multi-node runner
+
+# Configuration
+MODEL="pyt_megatron_lm_train_llama2_7b"
+NODES="192.168.0.1,192.168.0.2"
+MASTER_ADDR="10.227.23.63"
+MASTER_PORT="4000"
+SSH_USER="username" # Replace with your SSH username
+SSH_KEY="$HOME/.ssh/id_ed25519"
+SHARED_DATA="/nfs/data"
+NCCL_INTERFACE="ens14np0"
+GLOO_INTERFACE="ens14np0"
+
+echo "🚀 Starting multi-node training with SSH runner"
+echo "📋 Model: $MODEL"
+echo "🔗 Nodes: $NODES"
+echo "🏠 Master: $MASTER_ADDR:$MASTER_PORT"
+
+# Validate SSH key exists
+if [ ! -f "$SSH_KEY" ]; then
+    echo "❌ SSH key file not found: $SSH_KEY"
+    echo "💡 Available SSH keys in ~/.ssh/:"
+    ls -la ~/.ssh/*.pub 2>/dev/null || echo "   No SSH keys found"
+    echo ""
+    echo "🔧 To generate a new SSH key, run:"
+    echo "   ssh-keygen -t ed25519 -f ~/.ssh/id_ed25519"
+    echo "   ssh-copy-id -i ~/.ssh/id_ed25519.pub $SSH_USER@<node>"
+    exit 1
+fi
+
+echo "✅ SSH key validated: $SSH_KEY"
+
+# Detect Python command
+PYTHON_CMD=""
+if command -v python3 &> /dev/null; then
+    PYTHON_CMD="python3"
+elif command -v python &> /dev/null; then
+    PYTHON_CMD="python"
+elif [ -n "$VIRTUAL_ENV" ] && [ -x "$VIRTUAL_ENV/bin/python" ]; then
+    PYTHON_CMD="$VIRTUAL_ENV/bin/python"
+else
+    echo "❌ Python is not installed or not in PATH"
+    echo "💡 If you're using a virtual environment, make sure it's activated"
+    exit 1
+fi
+
+# Install requirements if not already installed
+if ! $PYTHON_CMD -c "import paramiko" 2>/dev/null; then
+    echo "📦 Installing required packages..."
+    $PYTHON_CMD -m pip install -r requirements.txt
+fi
+
+# Run the SSH multi-node runner
+$PYTHON_CMD run.py \
+    --model "$MODEL" \
+    --nodes "$NODES" \
+    --master-addr "$MASTER_ADDR" \
+    --master-port "$MASTER_PORT" \
+    --ssh-user "$SSH_USER" \
+    --ssh-key "$SSH_KEY" \
+    --shared-data-path "$SHARED_DATA" \
+    --nccl-interface "$NCCL_INTERFACE" \
+    --gloo-interface "$GLOO_INTERFACE" \
+    --timeout 7200
+
+echo "✅ Multi-node training completed!"
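+
+# Alternatively, drive the same run from a config file; CLI flags override
+# file values when both are given:
+#   $PYTHON_CMD run.py --config config.ini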
diff --git a/runners/ssh/monitoring.py b/runners/ssh/monitoring.py new file mode 100644 index 00000000..d6323798 --- /dev/null +++ b/runners/ssh/monitoring.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python3 +"""Monitoring and logging utilities for SSH Multi-Node Runner + +This module provides enhanced monitoring and logging capabilities. + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" + +import json +import logging +import time +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Any, Optional +from dataclasses import dataclass, asdict + + +@dataclass +class NodeExecutionResult: + """Result of execution on a single node.""" + hostname: str + node_rank: int + success: bool + start_time: float + end_time: float + output: str + error_message: Optional[str] = None + + @property + def duration(self) -> float: + """Get execution duration in seconds.""" + return self.end_time - self.start_time + + @property + def duration_formatted(self) -> str: + """Get formatted duration string.""" + from .utils import format_duration + return format_duration(self.duration) + + +@dataclass +class TrainingSession: + """Complete training session information.""" + session_id: str + model: str + nodes: List[str] + master_addr: str + master_port: int + start_time: float + end_time: Optional[float] = None + results: List[NodeExecutionResult] = None + + def __post_init__(self): + if self.results is None: + self.results = [] + + @property + def duration(self) -> Optional[float]: + """Get total session duration in seconds.""" + if self.end_time is None: + return None + return self.end_time - self.start_time + + @property + def success_rate(self) -> float: + """Get success rate as percentage.""" + if not self.results: + return 0.0 + successful = sum(1 for r in self.results if r.success) + return (successful / len(self.results)) * 100.0 + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for serialization.""" + data = asdict(self) + data['duration'] = self.duration + data['success_rate'] = self.success_rate + return data + + +class SessionLogger: + """Logger for training sessions with structured output.""" + + def __init__(self, log_dir: str = "logs", session_id: Optional[str] = None): + """Initialize session logger. + + Args: + log_dir: Directory to store log files + session_id: Unique session identifier (auto-generated if None) + """ + self.log_dir = Path(log_dir) + self.log_dir.mkdir(exist_ok=True) + + if session_id is None: + session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + self.session_id = session_id + self.session_file = self.log_dir / f"{session_id}.json" + self.log_file = self.log_dir / f"{session_id}.log" + + # Setup file logger + self.logger = logging.getLogger(f"session.{session_id}") + self.logger.setLevel(logging.DEBUG) + + # Create file handler + file_handler = logging.FileHandler(str(self.log_file)) + file_handler.setLevel(logging.DEBUG) + + # Create formatter + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + file_handler.setFormatter(formatter) + + self.logger.addHandler(file_handler) + + self.session: Optional[TrainingSession] = None + + def start_session(self, model: str, nodes: List[str], master_addr: str, master_port: int) -> None: + """Start a new training session. 
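+
+        Example (illustrative values):
+
+            logger = SessionLogger()
+            logger.start_session('pyt_megatron_lm_train_llama2_7b',
+                                 ['node1', 'node2'], 'node1', 4000)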
+ + Args: + model: Model being trained + nodes: List of node hostnames + master_addr: Master node address + master_port: Master node port + """ + self.session = TrainingSession( + session_id=self.session_id, + model=model, + nodes=nodes.copy(), + master_addr=master_addr, + master_port=master_port, + start_time=time.time() + ) + + self.logger.info(f"Starting training session {self.session_id}") + self.logger.info(f"Model: {model}") + self.logger.info(f"Nodes: {', '.join(nodes)}") + self.logger.info(f"Master: {master_addr}:{master_port}") + + self._save_session() + + def log_node_start(self, hostname: str, node_rank: int, command: str) -> None: + """Log the start of execution on a node. + + Args: + hostname: Node hostname + node_rank: Node rank + command: Command being executed + """ + self.logger.info(f"Node {hostname} (rank {node_rank}) starting: {command}") + + def log_node_output(self, hostname: str, node_rank: int, line: str, is_error: bool = False) -> None: + """Log output from a node. + + Args: + hostname: Node hostname + node_rank: Node rank + line: Output line + is_error: Whether this is error output + """ + level = logging.ERROR if is_error else logging.INFO + prefix = "ERROR" if is_error else "OUTPUT" + self.logger.log(level, f"[{hostname}:{node_rank}] {prefix}: {line}") + + def log_node_result(self, result: NodeExecutionResult) -> None: + """Log the result of execution on a node. + + Args: + result: Node execution result + """ + if self.session is None: + raise RuntimeError("No active session") + + self.session.results.append(result) + + status = "SUCCESS" if result.success else "FAILED" + self.logger.info( + f"Node {result.hostname} (rank {result.node_rank}) {status} " + f"in {result.duration_formatted}" + ) + + if not result.success and result.error_message: + self.logger.error(f"Node {result.hostname} error: {result.error_message}") + + self._save_session() + + def end_session(self, success: bool) -> None: + """End the training session. + + Args: + success: Whether the overall session was successful + """ + if self.session is None: + raise RuntimeError("No active session") + + self.session.end_time = time.time() + + status = "SUCCESS" if success else "FAILED" + duration = self.session.duration + from .utils import format_duration + duration_str = format_duration(duration) if duration else "unknown" + + self.logger.info(f"Training session {status} in {duration_str}") + self.logger.info(f"Success rate: {self.session.success_rate:.1f}%") + + self._save_session() + + def _save_session(self) -> None: + """Save session data to JSON file.""" + if self.session is None: + return + + try: + with open(self.session_file, 'w') as f: + json.dump(self.session.to_dict(), f, indent=2, default=str) + except Exception as e: + self.logger.error(f"Failed to save session data: {e}") + + def get_session_summary(self) -> Dict[str, Any]: + """Get session summary. 
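+
+        Safe to call at any time; returns an empty dict before
+        start_session() has been called.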
+ + Returns: + Dictionary containing session summary + """ + if self.session is None: + return {} + + return { + 'session_id': self.session.session_id, + 'model': self.session.model, + 'total_nodes': len(self.session.nodes), + 'completed_nodes': len(self.session.results), + 'successful_nodes': sum(1 for r in self.session.results if r.success), + 'failed_nodes': sum(1 for r in self.session.results if not r.success), + 'success_rate': self.session.success_rate, + 'duration': self.session.duration, + 'status': 'completed' if self.session.end_time else 'running' + } + + +class ProgressMonitor: + """Monitor training progress across nodes.""" + + def __init__(self, total_nodes: int): + """Initialize progress monitor. + + Args: + total_nodes: Total number of nodes + """ + self.total_nodes = total_nodes + self.completed_nodes = 0 + self.successful_nodes = 0 + self.failed_nodes = 0 + self.start_time = time.time() + + def update(self, success: bool) -> None: + """Update progress with a completed node. + + Args: + success: Whether the node completed successfully + """ + self.completed_nodes += 1 + if success: + self.successful_nodes += 1 + else: + self.failed_nodes += 1 + + def get_progress(self) -> Dict[str, Any]: + """Get current progress information. + + Returns: + Dictionary containing progress information + """ + elapsed_time = time.time() - self.start_time + completion_rate = self.completed_nodes / self.total_nodes if self.total_nodes > 0 else 0 + + # Estimate remaining time + if completion_rate > 0: + estimated_total_time = elapsed_time / completion_rate + estimated_remaining_time = estimated_total_time - elapsed_time + else: + estimated_remaining_time = None + + return { + 'total_nodes': self.total_nodes, + 'completed_nodes': self.completed_nodes, + 'successful_nodes': self.successful_nodes, + 'failed_nodes': self.failed_nodes, + 'completion_rate': completion_rate, + 'elapsed_time': elapsed_time, + 'estimated_remaining_time': estimated_remaining_time, + 'success_rate': self.successful_nodes / self.completed_nodes if self.completed_nodes > 0 else 0 + } + + def print_progress(self) -> None: + """Print current progress to console.""" + progress = self.get_progress() + + from .utils import format_duration + elapsed = format_duration(progress['elapsed_time']) + + if progress['estimated_remaining_time']: + remaining = format_duration(progress['estimated_remaining_time']) + time_info = f"Elapsed: {elapsed}, Remaining: ~{remaining}" + else: + time_info = f"Elapsed: {elapsed}" + + print(f"Progress: {progress['completed_nodes']}/{progress['total_nodes']} " + f"({progress['completion_rate']*100:.1f}%) - " + f"Success: {progress['successful_nodes']}, Failed: {progress['failed_nodes']} - " + f"{time_info}") + + +def load_session_history(log_dir: str = "logs") -> List[Dict[str, Any]]: + """Load session history from log directory. 
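+
+    Example (illustrative):
+
+        for session in load_session_history('logs'):
+            print(session['session_id'], session['success_rate'])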
+ + Args: + log_dir: Directory containing log files + + Returns: + List of session summaries + """ + log_path = Path(log_dir) + if not log_path.exists(): + return [] + + sessions = [] + for json_file in log_path.glob("session_*.json"): + try: + with open(json_file, 'r') as f: + session_data = json.load(f) + sessions.append(session_data) + except Exception: + # Skip corrupted files + continue + + # Sort by start time (most recent first) + sessions.sort(key=lambda x: x.get('start_time', 0), reverse=True) + return sessions diff --git a/runners/ssh/quick-start.sh b/runners/ssh/quick-start.sh new file mode 100644 index 00000000..7f8ab95f --- /dev/null +++ b/runners/ssh/quick-start.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# Quick start script for SSH Multi-Node Runner + +set -e + +echo "🚀 SSH Multi-Node Runner for MAD Engine" +echo "========================================" +echo "" + +# Check if Python is available (try python3, python, or VIRTUAL_ENV python) +PYTHON_CMD="" +if command -v python3 &> /dev/null; then + PYTHON_CMD="python3" +elif command -v python &> /dev/null; then + PYTHON_CMD="python" +elif [ -n "$VIRTUAL_ENV" ] && [ -x "$VIRTUAL_ENV/bin/python" ]; then + PYTHON_CMD="$VIRTUAL_ENV/bin/python" +else + echo "❌ Python is not installed or not in PATH" + echo "💡 If you're using a virtual environment, make sure it's activated" + exit 1 +fi + +echo "✅ Python is available ($PYTHON_CMD)" + +# Check if paramiko is installed +if ! $PYTHON_CMD -c "import paramiko" 2>/dev/null; then + echo "📦 Installing paramiko..." + $PYTHON_CMD -m pip install paramiko +else + echo "✅ paramiko is already installed" +fi + +echo "" +echo "🎯 Quick Start Examples:" +echo "" +echo "1. SSH Key Authentication:" +echo " $PYTHON_CMD run.py --model pyt_megatron_lm_train_llama2_7b \\" +echo " --nodes 192.168.1.1,192.168.1.2 \\" +echo " --master-addr 192.168.0.1 \\" +echo " --ssh-user ubuntu \\" +echo " --ssh-key ~/.ssh/id_rsa" +echo "" +echo "2. Configuration File:" +echo " $PYTHON_CMD run.py --config config.ini" +echo "" +echo "3. Run Tests:" +echo " $PYTHON_CMD test_runner.py" +echo "" +echo "📖 For detailed documentation, see README.md" +echo "✨ Ready to run multi-node training!" diff --git a/runners/ssh/requirements.txt b/runners/ssh/requirements.txt new file mode 100644 index 00000000..0e7b042e --- /dev/null +++ b/runners/ssh/requirements.txt @@ -0,0 +1,4 @@ +# SSH Multi-Node Runner Requirements + +# Core SSH functionality +paramiko>=2.9.0,<4.0.0 diff --git a/runners/ssh/run.py b/runners/ssh/run.py new file mode 100644 index 00000000..82b518f7 --- /dev/null +++ b/runners/ssh/run.py @@ -0,0 +1,541 @@ +#!/usr/bin/env python3 +"""SSH Multi-Node Runner for MAD Engine + +This script orchestrates distributed training across multiple nodes using SSH. +It automatically configures and executes madengine commands on remote nodes +for PyTorch Megatron-LM training workloads. + +Example Usage: + python run.py --model pyt_megatron_lm_train_llama2_7b \ + --nodes 192.168.1.1,192.168.1.2 \ + --master-addr 192.168.0.1 \ + --master-port 4000 \ + --ssh-user username \ + --ssh-key /path/to/ssh/key \ + --shared-data-path /nfs/data + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
+""" + +import argparse +import json +import logging +import os +import sys +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Tuple, Optional + +# Local imports +try: + from config_manager import MultiNodeConfig, merge_config_file_with_args + from ssh_client_manager import SSHClientManager +except ImportError: + # Fallback for direct execution + sys.path.insert(0, os.path.dirname(__file__)) + from config_manager import MultiNodeConfig, merge_config_file_with_args + from ssh_client_manager import SSHClientManager + +# Setup logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) + + +class ValidationError(Exception): + """Exception raised when validation fails.""" + pass + + +class SSHMultiNodeRunner: + """SSH-based multi-node runner for distributed training.""" + + def __init__(self, config: MultiNodeConfig): + """Initialize the SSH multi-node runner. + + Args: + config: Complete configuration object + """ + self.config = config + self.logger = logging.getLogger(self.__class__.__name__) + + # Create SSH client managers for each node + self.ssh_managers = {} + for node in config.cluster.nodes: + self.ssh_managers[node] = SSHClientManager( + hostname=node, + username=config.ssh.user, + password=config.ssh.password, + key_filename=config.ssh.key_file, + timeout=config.ssh.timeout, + max_retries=config.ssh.max_retries + ) + + def _build_madengine_command(self, node_rank: int) -> str: + """Build the madengine command for a specific node, using venv's madengine binary.""" + multi_node_args = { + 'RUNNER': 'torchrun', + 'MASTER_ADDR': self.config.cluster.master_addr, + 'MASTER_PORT': str(self.config.cluster.master_port), + 'NNODES': str(len(self.config.cluster.nodes)), + 'NODE_RANK': str(node_rank), + 'NCCL_SOCKET_IFNAME': self.config.training.nccl_interface, + 'GLOO_SOCKET_IFNAME': self.config.training.gloo_interface + } + additional_context = f"'{json.dumps({'multi_node_args': multi_node_args})}'" + # Use venv/bin/madengine explicitly + madengine_bin = os.path.join('venv', 'bin', 'madengine') + cmd_parts = [ + madengine_bin, + 'run', + '--tags', self.config.training.model, + '--additional-context', additional_context + ] + if self.config.training.shared_data_path: + cmd_parts.extend(['--force-mirror-local', self.config.training.shared_data_path]) + if self.config.training.additional_args: + cmd_parts.append(self.config.training.additional_args) + return ' '.join(cmd_parts) + + def _check_node_connectivity(self) -> List[str]: + """Check connectivity to all nodes. + + Returns: + List of nodes that are reachable + """ + reachable_nodes = [] + + self.logger.info("Checking connectivity to all nodes...") + + for hostname in self.config.cluster.nodes: + ssh_manager = self.ssh_managers[hostname] + if ssh_manager.test_connectivity(): + reachable_nodes.append(hostname) + self.logger.info(f"✓ {hostname} is reachable") + else: + self.logger.error(f"✗ {hostname} is not reachable") + + return reachable_nodes + + def _validate_remote_prerequisites(self, hostname: str) -> Tuple[bool, str]: + """Validate that remote node has required prerequisites. 
+
+        Args:
+            hostname: The hostname/IP of the node to validate
+
+        Returns:
+            Tuple of (success, error_message)
+        """
+        ssh_manager = self.ssh_managers[hostname]
+
+        try:
+            # Check if MAD directory exists, clone if missing
+            self.logger.info(f"Checking if {self.config.madengine.working_directory} directory exists on {hostname}...")
+            exit_code, stdout, stderr = ssh_manager.execute_command(
+                f'test -d {self.config.madengine.working_directory} && echo "exists" || echo "missing"'
+            )
+            if stdout.strip() == "missing":
+                self.logger.info(f"{self.config.madengine.working_directory} not found on {hostname}, cloning MAD repo...")
+                exit_code, stdout, stderr = ssh_manager.execute_command(
+                    f'git clone https://github.com/ROCm/MAD.git {self.config.madengine.working_directory}'
+                )
+                if exit_code != 0:
+                    return False, f"Failed to clone MAD repo to {self.config.madengine.working_directory} on {hostname}: {stderr}"
+            elif exit_code != 0:
+                return False, f"Failed to check {self.config.madengine.working_directory} on {hostname}: {stderr}"
+
+            # Ensure venv exists, create if missing
+            venv_path = os.path.join(self.config.madengine.working_directory, 'venv')
+            self.logger.info(f"Ensuring venv exists at {venv_path} on {hostname}...")
+            exit_code, stdout, stderr = ssh_manager.execute_command(
+                f'cd {self.config.madengine.working_directory} && [ -d venv ] || python3 -m venv venv'
+            )
+            if exit_code != 0:
+                return False, f"Failed to create venv in {self.config.madengine.working_directory} on {hostname}: {stderr}"
+
+            # Install madengine in venv
+            self.logger.info(f"Installing madengine in venv on {hostname}...")
+            exit_code, stdout, stderr = ssh_manager.execute_command(
+                f'cd {self.config.madengine.working_directory} && source venv/bin/activate && pip install --upgrade pip && pip install git+https://github.com/ROCm/madengine.git@main -q'
+            )
+            if exit_code != 0:
+                return False, f"Failed to install madengine in venv on {hostname}: {stderr}"
+
+            # Check if madengine is accessible in venv
+            madengine_check_cmd = f'cd {self.config.madengine.working_directory} && source venv/bin/activate && madengine --help > /dev/null 2>&1 && echo "found" || echo "missing"'
+            self.logger.debug(f"Checking madengine installation in venv on {hostname} with command: {madengine_check_cmd}")
+            exit_code, stdout, stderr = ssh_manager.execute_command(madengine_check_cmd)
+            self.logger.debug(f"madengine check in venv on {hostname}: exit_code={exit_code}, stdout='{stdout}', stderr='{stderr}'")
+            if stdout.strip() != "found":
+                return False, f"madengine not found or not accessible in venv on {hostname} (exit_code={exit_code}, stdout='{stdout}', stderr='{stderr}')"
+
+            # Check if we can access the working directory
+            self.logger.debug(f"Checking access to {self.config.madengine.working_directory} directory on {hostname}...")
+            exit_code, stdout, stderr = ssh_manager.execute_command(
+                f'cd {self.config.madengine.working_directory} && pwd'
+            )
+
+            if stderr or not stdout.endswith(self.config.madengine.working_directory):
+                return False, f"Cannot access {self.config.madengine.working_directory} directory on {hostname}"
+
+            # Check shared data path if specified and not default
+            if self.config.training.shared_data_path != '/nfs/data':
+                self.logger.debug(f"Checking shared data path on {hostname}...")
+                exit_code, stdout, stderr = ssh_manager.execute_command(
+                    f'test -d "{self.config.training.shared_data_path}" && 
echo "exists" || echo "missing"' + ) + + if stdout != "exists": + return False, f"Shared data path '{self.config.training.shared_data_path}' not found on {hostname}" + + return True, "" + + except Exception as e: + return False, f"Error validating prerequisites on {hostname}: {str(e)}" + + def _check_all_prerequisites(self) -> bool: + """Check prerequisites on all nodes. + + Returns: + True if all nodes meet prerequisites, False otherwise + """ + self.logger.info("Validating prerequisites on all nodes...") + + failed_nodes = [] + + for hostname in self.config.cluster.nodes: + success, error_msg = self._validate_remote_prerequisites(hostname) + if not success: + self.logger.error(f"❌ {hostname}: {error_msg}") + failed_nodes.append((hostname, error_msg)) + else: + self.logger.info(f"✅ {hostname}: All prerequisites met") + + if failed_nodes: + self.logger.error(f"Prerequisites check failed for {len(failed_nodes)} node(s)") + self._print_setup_instructions() + return False + + self.logger.info("All nodes meet the prerequisites") + return True + + def _print_setup_instructions(self) -> None: + """Print setup instructions for remote nodes.""" + instructions = f""" +{"="*60} +🔧 REMOTE NODE SETUP INSTRUCTIONS +{"="*60} + +To prepare your remote nodes for multi-node training: + +1. {self.config.madengine.working_directory} Directory: + • Create or ensure the {self.config.madengine.working_directory} directory exists + • Command: mkdir -p ~/{self.config.madengine.working_directory} + +2. MAD Engine Installation: + • Install madengine on each remote node + • Command: pip install git+https://github.com/ROCm/madengine.git@main + • Verify with: {self.config.madengine.path} --help + +3. Shared Data Path: + • Ensure the shared data path '{self.config.training.shared_data_path}' exists + • This should be a shared filesystem mounted on all nodes + +4. SSH Access: + • Ensure SSH key-based or password authentication is configured + • Test SSH access manually before running this script + +5. 
Network Configuration: + • Ensure nodes can communicate on the specified interfaces + • NCCL interface: {self.config.training.nccl_interface} + • GLOO interface: {self.config.training.gloo_interface} + • Master node {self.config.cluster.master_addr} should be accessible on port {self.config.cluster.master_port} + +{"="*60} + """ + print(instructions) + + def _execute_on_node(self, hostname: str, node_rank: int) -> Tuple[str, bool, str]: + """Execute madengine command on a single node, ensuring venv usage.""" + ssh_manager = self.ssh_managers[hostname] + try: + # Build madengine command (uses venv/bin/madengine) + command = self._build_madengine_command(node_rank) + # Compose setup and run commands, always use venv's python/pip + setup_commands = [ + f"cd {self.config.madengine.working_directory}", + "venv/bin/python -m pip install --upgrade pip", + "venv/bin/python -m pip install git+https://github.com/ROCm/madengine.git@main -q" + ] + full_command = f"{' && '.join(setup_commands)} && {command}" + self.logger.info(f"🚀 Executing on {hostname} (rank {node_rank}): {full_command}") + with ssh_manager.get_client() as client: + stdin, stdout, stderr = client.exec_command( + full_command, + timeout=self.config.training.timeout + ) + output_lines = [] + error_lines = [] + for line in stdout: + line = line.strip() + if line: + self.logger.info(f"[{hostname}:{node_rank}] {line}") + output_lines.append(line) + for line in stderr: + line = line.strip() + if line: + self.logger.warning(f"[{hostname}:{node_rank}] ERROR: {line}") + error_lines.append(line) + exit_code = stdout.channel.recv_exit_status() + if exit_code == 0: + return hostname, True, '\n'.join(output_lines) + else: + error_output = '\n'.join(error_lines) if error_lines else "Command failed with no error output" + return hostname, False, f"Exit code {exit_code}: {error_output}" + except Exception as e: + return hostname, False, str(e) + + def run(self) -> bool: + """Run distributed training across all nodes. 
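+
+        The run is fail-fast: nothing is launched unless every node is
+        reachable and passes the prerequisite checks.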
+ + Returns: + True if all nodes completed successfully, False otherwise + """ + self.logger.info(f"Starting multi-node training on {len(self.config.cluster.nodes)} nodes") + self.logger.info(f"Model: {self.config.training.model}") + self.logger.info(f"Master: {self.config.cluster.master_addr}:{self.config.cluster.master_port}") + self.logger.info(f"Shared data: {self.config.training.shared_data_path}") + self.logger.info(f"Nodes: {', '.join(self.config.cluster.nodes)}") + + # Check connectivity to all nodes + reachable_nodes = self._check_node_connectivity() + + if len(reachable_nodes) != len(self.config.cluster.nodes): + unreachable = set(self.config.cluster.nodes) - set(reachable_nodes) + self.logger.error(f"Some nodes are unreachable: {', '.join(unreachable)}") + return False + + self.logger.info("All nodes are reachable") + + # Validate prerequisites on all nodes + if not self._check_all_prerequisites(): + return False + + # Execute on all nodes concurrently + results = [] + + with ThreadPoolExecutor(max_workers=len(self.config.cluster.nodes)) as executor: + # Submit jobs for all nodes + futures = [] + for i, hostname in enumerate(self.config.cluster.nodes): + future = executor.submit(self._execute_on_node, hostname, i) + futures.append(future) + + # Collect results as they complete + for future in as_completed(futures): + hostname, success, output = future.result() + results.append((hostname, success, output)) + + if success: + self.logger.info(f"✅ {hostname} completed successfully") + else: + self.logger.error(f"❌ {hostname} failed: {output}") + + # Check overall success + successful_nodes = [r[0] for r in results if r[1]] + failed_nodes = [r[0] for r in results if not r[1]] + + self.logger.info(f"Training Results:") + self.logger.info(f"✅ Successful nodes: {len(successful_nodes)}/{len(self.config.cluster.nodes)}") + + if failed_nodes: + self.logger.error(f"❌ Failed nodes: {', '.join(failed_nodes)}") + return False + + self.logger.info("🎉 Multi-node training completed successfully!") + return True + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments. 
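+
+    When --config is supplied, values from the file act as the baseline
+    and explicit CLI flags override them (see merge_config_file_with_args).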
+ + Returns: + Parsed arguments + """ + parser = argparse.ArgumentParser( + description="SSH Multi-Node Runner for MAD Engine", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Run with SSH key authentication + python run.py --model pyt_megatron_lm_train_llama2_7b \\ + --nodes 192.168.1.1,192.168.1.2 \\ + --master-addr 192.168.0.1 \\ + --ssh-user ubuntu \\ + --ssh-key ~/.ssh/id_rsa + + # Run with password authentication + python run.py --model pyt_megatron_lm_train_llama2_7b \\ + --nodes node1,node2,node3 \\ + --ssh-user root \\ + --ssh-password mypassword \\ + --shared-data-path /shared/data + + # Run with configuration file + python run.py --config config.ini + """ + ) + + # Configuration file option + parser.add_argument( + '--config', + help='Path to configuration file (INI format)' + ) + + # Required arguments (unless provided in config) + parser.add_argument( + '--model', + help='Model tag to run (e.g., pyt_megatron_lm_train_llama2_7b)' + ) + + parser.add_argument( + '--nodes', + help='Comma-separated list of node hostnames/IPs' + ) + + parser.add_argument( + '--ssh-user', + help='SSH username for all nodes' + ) + + # SSH authentication (one required unless in config) + parser.add_argument( + '--ssh-password', + help='SSH password for all nodes' + ) + parser.add_argument( + '--ssh-key', + help='Path to SSH private key file' + ) + + # Optional arguments + parser.add_argument( + '--master-addr', + help='Master node address (defaults to first node)' + ) + + parser.add_argument( + '--master-port', + type=int, + default=4000, + help='Master node port (default: 4000)' + ) + + parser.add_argument( + '--shared-data-path', + default='/nfs/data', + help='Path to shared data filesystem (default: /nfs/data)' + ) + + parser.add_argument( + '--nccl-interface', + default='ens14np0', + help='NCCL socket interface (default: ens14np0)' + ) + + parser.add_argument( + '--gloo-interface', + default='ens14np0', + help='GLOO socket interface (default: ens14np0)' + ) + + parser.add_argument( + '--timeout', + type=int, + default=3600, + help='Execution timeout in seconds (default: 3600)' + ) + + parser.add_argument( + '--madengine-path', + default='madengine', + help='Path to madengine executable (default: madengine)' + ) + + parser.add_argument( + '--additional-args', + help='Additional arguments to pass to madengine' + ) + + parser.add_argument( + '--working-directory', + default='MAD', + help='Working directory on remote nodes (default: MAD)' + ) + + parser.add_argument( + '--ssh-timeout', + type=int, + default=30, + help='SSH connection timeout in seconds (default: 30)' + ) + + parser.add_argument( + '--ssh-max-retries', + type=int, + default=3, + help='Maximum SSH connection retry attempts (default: 3)' + ) + + parser.add_argument( + '--log-level', + default='INFO', + choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], + help='Logging level (default: INFO)' + ) + + return parser.parse_args() + + +def main(): + """Main entry point.""" + try: + args = parse_args() + + # Set logging level + logging.getLogger().setLevel(getattr(logging, args.log_level)) + + # Create configuration + if args.config: + config = merge_config_file_with_args(args.config, args) + else: + config = MultiNodeConfig.from_args(args) + + # Validate configuration + config.validate() + + # Create and run the runner + runner = SSHMultiNodeRunner(config) + success = runner.run() + + if success: + print("🎯 All nodes completed successfully!") + sys.exit(0) + else: + print("💥 Some nodes failed!") + sys.exit(1) + + 
except KeyboardInterrupt: + print("\n🛑 Interrupted by user") + sys.exit(1) + except Exception as e: + logging.error(f"Error: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/runners/ssh/ssh_client_manager.py b/runners/ssh/ssh_client_manager.py new file mode 100644 index 00000000..38929b29 --- /dev/null +++ b/runners/ssh/ssh_client_manager.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +"""SSH Client Manager for MAD Engine Multi-Node Runner + +This module provides a robust SSH client management class with connection pooling, +error handling, and retry mechanisms. + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" + +import logging +import socket +import time +from contextlib import contextmanager +from typing import Optional, Tuple, Iterator + +try: + import paramiko +except ImportError: + raise ImportError("paramiko is required but not installed. Please install it with: pip install paramiko") + + +class SSHClientManager: + """Manages SSH connections with robust error handling and retry mechanisms.""" + + def __init__(self, hostname: str, username: str, password: Optional[str] = None, + key_filename: Optional[str] = None, timeout: int = 30, max_retries: int = 3): + """Initialize SSH client manager. + + Args: + hostname: Target hostname or IP address + username: SSH username + password: SSH password (if using password auth) + key_filename: Path to SSH private key (if using key auth) + timeout: Connection timeout in seconds + max_retries: Maximum number of connection retries + """ + self.hostname = hostname + self.username = username + self.password = password + self.key_filename = key_filename + self.timeout = timeout + self.max_retries = max_retries + self.logger = logging.getLogger(f"{__name__}.{hostname}") + + if not password and not key_filename: + raise ValueError("Either password or key_filename must be provided") + + def _create_client(self) -> paramiko.SSHClient: + """Create and configure a new SSH client. + + Returns: + Configured SSH client + """ + client = paramiko.SSHClient() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + client.load_system_host_keys() + return client + + def _connect_with_retry(self, client: paramiko.SSHClient) -> None: + """Connect SSH client with retry mechanism. + + Args: + client: SSH client to connect + + Raises: + ConnectionError: If all connection attempts fail + """ + last_exception = None + + for attempt in range(self.max_retries): + try: + if self.key_filename: + client.connect( + hostname=self.hostname, + username=self.username, + key_filename=self.key_filename, + timeout=self.timeout + ) + else: + client.connect( + hostname=self.hostname, + username=self.username, + password=self.password, + timeout=self.timeout + ) + + self.logger.debug(f"Successfully connected to {self.hostname} on attempt {attempt + 1}") + return + + except (paramiko.ssh_exception.AuthenticationException, + paramiko.ssh_exception.SSHException, + socket.error) as e: + last_exception = e + self.logger.warning(f"Connection attempt {attempt + 1} failed: {e}") + + if attempt < self.max_retries - 1: + wait_time = 2 ** attempt # Exponential backoff + self.logger.info(f"Retrying in {wait_time} seconds...") + time.sleep(wait_time) + + raise ConnectionError(f"Failed to connect to {self.hostname} after {self.max_retries} attempts. Last error: {last_exception}") + + @contextmanager + def get_client(self) -> Iterator[paramiko.SSHClient]: + """Get an SSH client with automatic cleanup. 
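+
+        Example (illustrative):
+
+            with manager.get_client() as client:
+                stdin, stdout, stderr = client.exec_command('hostname')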
+ + Yields: + Connected SSH client + + Raises: + ConnectionError: If connection fails + """ + client = self._create_client() + try: + self._connect_with_retry(client) + yield client + finally: + try: + client.close() + except Exception as e: + self.logger.warning(f"Error closing SSH connection: {e}") + + def execute_command(self, command: str, timeout: Optional[int] = None) -> Tuple[int, str, str]: + """Execute a command on the remote host. + + Args: + command: Command to execute + timeout: Command timeout (uses default if None) + + Returns: + Tuple of (exit_code, stdout, stderr) + + Raises: + ConnectionError: If SSH connection fails + TimeoutError: If command times out + """ + with self.get_client() as client: + try: + stdin, stdout, stderr = client.exec_command(command, timeout=timeout or self.timeout) + + exit_code = stdout.channel.recv_exit_status() + stdout_text = stdout.read().decode('utf-8', errors='replace').strip() + stderr_text = stderr.read().decode('utf-8', errors='replace').strip() + + return exit_code, stdout_text, stderr_text + + except socket.timeout: + raise TimeoutError(f"Command timed out after {timeout or self.timeout} seconds: {command}") + + def test_connectivity(self) -> bool: + """Test connectivity to the remote host. + + Returns: + True if connection successful, False otherwise + """ + try: + exit_code, stdout, stderr = self.execute_command('echo "connectivity_test"') + return exit_code == 0 and stdout.strip() == "connectivity_test" + except Exception as e: + self.logger.error(f"Connectivity test failed: {e}") + return False diff --git a/runners/ssh/test_runner.py b/runners/ssh/test_runner.py new file mode 100644 index 00000000..48a3361a --- /dev/null +++ b/runners/ssh/test_runner.py @@ -0,0 +1,420 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +"""Test script for SSH Multi-Node Runner + +This script provides comprehensive unit tests and validation for the SSH runner. 
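+
+Run the suite directly with: python test_runner.py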
+""" + +import os +import sys +import tempfile +import unittest +from unittest.mock import patch, MagicMock, call +from pathlib import Path + +# Add the current directory to the Python path +sys.path.insert(0, os.path.dirname(__file__)) + +try: + from config_manager import MultiNodeConfig, SSHConfig, ClusterConfig, TrainingConfig + from ssh_client_manager import SSHClientManager + from run import SSHMultiNodeRunner +except ImportError as e: + print(f"Error: Could not import modules: {e}") + print("Make sure all modules are in the same directory.") + sys.exit(1) + + +class TestConfigManager(unittest.TestCase): + """Test cases for configuration management.""" + + def test_ssh_config_validation(self): + """Test SSH configuration validation.""" + # Valid configuration with password + config = SSHConfig(user="testuser", password="testpass") + self.assertEqual(config.user, "testuser") + self.assertEqual(config.password, "testpass") + + # Valid configuration with key file (mock file existence) + with patch('os.path.exists', return_value=True): + config = SSHConfig(user="testuser", key_file="/path/to/key") + self.assertEqual(config.key_file, "/path/to/key") + + # Invalid configuration - no auth method + with self.assertRaises(ValueError): + SSHConfig(user="testuser") + + # Invalid configuration - non-existent key file + with patch('os.path.exists', return_value=False): + with self.assertRaises(FileNotFoundError): + SSHConfig(user="testuser", key_file="/nonexistent/key") + + def test_cluster_config_validation(self): + """Test cluster configuration validation.""" + # Valid configuration + config = ClusterConfig(nodes=["node1", "node2"]) + self.assertEqual(config.nodes, ["node1", "node2"]) + self.assertEqual(config.master_addr, "node1") # Should default to first node + + # Valid configuration with master_addr specified + config = ClusterConfig(nodes=["node1", "node2"], master_addr="node1") + self.assertEqual(config.master_addr, "node1") + + # Invalid configuration - empty nodes + with self.assertRaises(ValueError): + ClusterConfig(nodes=[]) + + # Invalid configuration - whitespace-only nodes + with self.assertRaises(ValueError): + ClusterConfig(nodes=["", " ", "\t"]) + + def test_training_config_validation(self): + """Test training configuration validation.""" + # Valid configuration + config = TrainingConfig(model="test_model") + self.assertEqual(config.model, "test_model") + self.assertEqual(config.shared_data_path, "/nfs/data") # Default + + # Invalid configuration - empty model + with self.assertRaises(ValueError): + TrainingConfig(model="") + + def test_config_from_args(self): + """Test configuration creation from arguments.""" + mock_args = MagicMock() + mock_args.model = 'test_model' + mock_args.nodes = 'node1,node2,node3' + mock_args.ssh_user = 'testuser' + mock_args.ssh_password = 'testpass' + mock_args.ssh_key = None + + # Set default values for optional attributes + for attr, default in [ + ('master_addr', None), + ('master_port', 4000), + ('shared_data_path', '/nfs/data'), + ('nccl_interface', 'ens14np0'), + ('gloo_interface', 'ens14np0'), + ('timeout', 3600), + ('additional_args', ''), + ('madengine_path', 'madengine'), + ('working_directory', 'MAD'), + ('ssh_timeout', 30), + ('ssh_max_retries', 3) + ]: + setattr(mock_args, attr, getattr(mock_args, attr, default)) + + config = MultiNodeConfig.from_args(mock_args) + + self.assertEqual(config.training.model, 'test_model') + self.assertEqual(config.cluster.nodes, ['node1', 'node2', 'node3']) + self.assertEqual(config.ssh.user, 'testuser') + 
+        self.assertEqual(config.ssh.password, 'testpass')
+
+    def test_config_from_file(self):
+        """Test configuration loading from file."""
+        config_content = """
+[cluster]
+nodes = node1,node2,node3
+master_addr = node1
+master_port = 5000
+
+[ssh]
+user = testuser
+password = testpass
+
+[training]
+model = test_model
+shared_data_path = /shared/data
+"""
+
+        with tempfile.NamedTemporaryFile(mode='w', suffix='.ini', delete=False) as f:
+            f.write(config_content)
+            temp_config_path = f.name
+
+        try:
+            config = MultiNodeConfig.from_config_file(temp_config_path)
+
+            self.assertEqual(config.cluster.nodes, ['node1', 'node2', 'node3'])
+            self.assertEqual(config.cluster.master_addr, 'node1')
+            self.assertEqual(config.cluster.master_port, 5000)
+            self.assertEqual(config.ssh.user, 'testuser')
+            self.assertEqual(config.ssh.password, 'testpass')
+            self.assertEqual(config.training.model, 'test_model')
+            self.assertEqual(config.training.shared_data_path, '/shared/data')
+
+        finally:
+            os.unlink(temp_config_path)
+
+
+class TestSSHClientManager(unittest.TestCase):
+    """Test cases for SSH client management."""
+
+    @patch('paramiko.SSHClient')
+    def test_connectivity_test_success(self, mock_ssh_client_class):
+        """Test successful connectivity test."""
+        # Mock SSH client
+        mock_client = MagicMock()
+        mock_ssh_client_class.return_value = mock_client
+
+        # Mock successful execution
+        mock_client.exec_command.return_value = (None, MagicMock(), MagicMock())
+        mock_client.exec_command.return_value[1].channel.recv_exit_status.return_value = 0
+        mock_client.exec_command.return_value[1].read.return_value = b'connectivity_test'
+        mock_client.exec_command.return_value[2].read.return_value = b''
+
+        ssh_manager = SSHClientManager(
+            hostname="testhost",
+            username="testuser",
+            password="testpass"
+        )
+
+        result = ssh_manager.test_connectivity()
+        self.assertTrue(result)
+
+    @patch('paramiko.SSHClient')
+    def test_connectivity_test_failure(self, mock_ssh_client_class):
+        """Test failed connectivity test."""
+        # Mock SSH client that raises exception
+        mock_client = MagicMock()
+        mock_ssh_client_class.return_value = mock_client
+        mock_client.connect.side_effect = Exception("Connection failed")
+
+        ssh_manager = SSHClientManager(
+            hostname="testhost",
+            username="testuser",
+            password="testpass"
+        )
+
+        result = ssh_manager.test_connectivity()
+        self.assertFalse(result)
+
+    def test_invalid_authentication(self):
+        """Test invalid authentication configuration."""
+        with self.assertRaises(ValueError):
+            SSHClientManager(
+                hostname="testhost",
+                username="testuser"
+                # No password or key_filename
+            )
+
+
+class TestSSHMultiNodeRunner(unittest.TestCase):
+    """Test cases for SSH Multi-Node Runner."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        self.config = MultiNodeConfig(
+            ssh=SSHConfig(user="testuser", password="testpass"),
+            cluster=ClusterConfig(nodes=["node1", "node2"]),
+            training=TrainingConfig(model="test_model")
+        )
+
+    def test_initialization(self):
+        """Test runner initialization."""
+        runner = SSHMultiNodeRunner(self.config)
+
+        self.assertEqual(len(runner.ssh_managers), 2)
+        self.assertIn("node1", runner.ssh_managers)
+        self.assertIn("node2", runner.ssh_managers)
+
+    def test_command_generation(self):
+        """Test madengine command generation."""
+        runner = SSHMultiNodeRunner(self.config)
+
+        # Test command for node rank 0
+        cmd_0 = runner._build_madengine_command(0)
+        self.assertIn('madengine run', cmd_0)
+        self.assertIn('test_model', cmd_0)
+        self.assertIn('"NODE_RANK": "0"', cmd_0)
self.assertIn('"NNODES": "2"', cmd_0) + + # Test command for node rank 1 + cmd_1 = runner._build_madengine_command(1) + self.assertIn('"NODE_RANK": "1"', cmd_1) + + @patch.object(SSHClientManager, 'test_connectivity') + def test_connectivity_check_success(self, mock_connectivity): + """Test successful connectivity checking.""" + mock_connectivity.return_value = True + + runner = SSHMultiNodeRunner(self.config) + reachable_nodes = runner._check_node_connectivity() + + self.assertEqual(len(reachable_nodes), 2) + self.assertIn("node1", reachable_nodes) + self.assertIn("node2", reachable_nodes) + + @patch.object(SSHClientManager, 'test_connectivity') + def test_connectivity_check_partial_failure(self, mock_connectivity): + """Test connectivity checking with partial failures.""" + # Mock node1 success, node2 failure + mock_connectivity.side_effect = [True, False] + + runner = SSHMultiNodeRunner(self.config) + reachable_nodes = runner._check_node_connectivity() + + self.assertEqual(len(reachable_nodes), 1) + self.assertIn("node1", reachable_nodes) + self.assertNotIn("node2", reachable_nodes) + + @patch.object(SSHClientManager, 'execute_command') + def test_prerequisites_validation_success(self, mock_execute): + """Test successful prerequisites validation.""" + # Mock successful responses for all checks + mock_execute.side_effect = [ + (0, "exists", ""), # Working directory exists + (0, "found", ""), # madengine found + (0, "/home/user/MAD", "") # Directory accessible + ] + + runner = SSHMultiNodeRunner(self.config) + success, error_msg = runner._validate_remote_prerequisites("node1") + + self.assertTrue(success) + self.assertEqual(error_msg, "") + + @patch.object(SSHClientManager, 'execute_command') + def test_prerequisites_validation_missing_directory(self, mock_execute): + """Test prerequisites validation with missing directory.""" + mock_execute.return_value = (0, "missing", "") + + runner = SSHMultiNodeRunner(self.config) + success, error_msg = runner._validate_remote_prerequisites("node1") + + self.assertFalse(success) + self.assertIn("MAD folder not found", error_msg) + + +class TestIntegration(unittest.TestCase): + """Integration tests with mocked components.""" + + def setUp(self): + """Set up test fixtures.""" + self.config = MultiNodeConfig( + ssh=SSHConfig(user="testuser", password="testpass"), + cluster=ClusterConfig(nodes=["node1", "node2"]), + training=TrainingConfig(model="test_model") + ) + + @patch.object(SSHMultiNodeRunner, '_execute_on_node') + @patch.object(SSHMultiNodeRunner, '_check_all_prerequisites') + @patch.object(SSHMultiNodeRunner, '_check_node_connectivity') + def test_successful_run(self, mock_connectivity, mock_prerequisites, mock_execute): + """Test successful end-to-end run.""" + # Mock all checks and execution as successful + mock_connectivity.return_value = ["node1", "node2"] + mock_prerequisites.return_value = True + mock_execute.side_effect = [ + ("node1", True, "Training completed"), + ("node2", True, "Training completed") + ] + + runner = SSHMultiNodeRunner(self.config) + result = runner.run() + + self.assertTrue(result) + self.assertEqual(mock_execute.call_count, 2) + + @patch.object(SSHMultiNodeRunner, '_execute_on_node') + @patch.object(SSHMultiNodeRunner, '_check_all_prerequisites') + @patch.object(SSHMultiNodeRunner, '_check_node_connectivity') + def test_failed_connectivity(self, mock_connectivity, mock_prerequisites, mock_execute): + """Test run with connectivity failure.""" + # Mock partial connectivity failure + mock_connectivity.return_value = 
["node1"] # node2 unreachable + + runner = SSHMultiNodeRunner(self.config) + result = runner.run() + + self.assertFalse(result) + mock_prerequisites.assert_not_called() + mock_execute.assert_not_called() + + @patch.object(SSHMultiNodeRunner, '_execute_on_node') + @patch.object(SSHMultiNodeRunner, '_check_all_prerequisites') + @patch.object(SSHMultiNodeRunner, '_check_node_connectivity') + def test_failed_prerequisites(self, mock_connectivity, mock_prerequisites, mock_execute): + """Test run with prerequisites failure.""" + mock_connectivity.return_value = ["node1", "node2"] + mock_prerequisites.return_value = False + + runner = SSHMultiNodeRunner(self.config) + result = runner.run() + + self.assertFalse(result) + mock_execute.assert_not_called() + + +def run_tests(): + """Run all tests.""" + print("🧪 Running SSH Multi-Node Runner Tests...") + + # Create test suite + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + # Add test cases + suite.addTests(loader.loadTestsFromTestCase(TestConfigManager)) + suite.addTests(loader.loadTestsFromTestCase(TestSSHClientManager)) + suite.addTests(loader.loadTestsFromTestCase(TestSSHMultiNodeRunner)) + suite.addTests(loader.loadTestsFromTestCase(TestIntegration)) + + # Run tests + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + + # Return success status + return result.wasSuccessful() + + +def validate_environment(): + """Validate that the environment is set up correctly.""" + print("🔍 Validating environment...") + + issues = [] + + # Check if paramiko is available + try: + import paramiko + print("✅ paramiko is available") + except ImportError: + issues.append("❌ paramiko is not installed. Run: pip install paramiko") + + # Check if required files exist + required_files = ['run.py', 'config_manager.py', 'ssh_client_manager.py', 'requirements.txt'] + current_dir = os.path.dirname(__file__) + + for filename in required_files: + file_path = os.path.join(current_dir, filename) + if os.path.exists(file_path): + print(f"✅ {filename} found") + else: + issues.append(f"❌ {filename} not found in current directory") + + if issues: + print("\n🚨 Issues found:") + for issue in issues: + print(f" {issue}") + return False + else: + print("\n✅ Environment validation passed!") + return True + + +if __name__ == "__main__": + print("SSH Multi-Node Runner Test Suite") + print("=" * 40) + + # Validate environment first + if not validate_environment(): + print("\n💥 Environment validation failed. Please fix the issues above.") + sys.exit(1) + + # Run tests + if run_tests(): + print("\n🎉 All tests passed!") + sys.exit(0) + else: + print("\n💥 Some tests failed!") + sys.exit(1) diff --git a/runners/ssh/utils.py b/runners/ssh/utils.py new file mode 100644 index 00000000..1c39755a --- /dev/null +++ b/runners/ssh/utils.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +"""Utility functions for SSH Multi-Node Runner + +This module provides common utility functions used across the SSH runner. + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" + +import logging +import socket +import time +from typing import Dict, Any, Optional + + +def setup_logging(level: str = 'INFO', format_string: Optional[str] = None) -> logging.Logger: + """Setup logging configuration. 
+
+    Args:
+        level: Logging level (DEBUG, INFO, WARNING, ERROR)
+        format_string: Custom format string for log messages
+
+    Returns:
+        Configured logger instance
+    """
+    if format_string is None:
+        format_string = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+
+    logging.basicConfig(
+        level=getattr(logging, level.upper()),
+        format=format_string,
+        datefmt='%Y-%m-%d %H:%M:%S'
+    )
+
+    return logging.getLogger(__name__)
+
+
+def validate_network_connectivity(hostname: str, port: int, timeout: int = 5) -> bool:
+    """Test network connectivity to a host and port.
+
+    Args:
+        hostname: Target hostname or IP address
+        port: Target port number
+        timeout: Connection timeout in seconds
+
+    Returns:
+        True if connection successful, False otherwise
+    """
+    try:
+        with socket.create_connection((hostname, port), timeout=timeout):
+            return True
+    except (socket.error, socket.timeout):
+        return False
+
+
+def wait_for_port_ready(hostname: str, port: int, max_wait_time: int = 60, check_interval: int = 2) -> bool:
+    """Wait for a port to become available on a host.
+
+    Args:
+        hostname: Target hostname or IP address
+        port: Target port number
+        max_wait_time: Maximum time to wait in seconds
+        check_interval: Time between checks in seconds
+
+    Returns:
+        True if port becomes available, False if timeout
+    """
+    start_time = time.time()
+
+    while time.time() - start_time < max_wait_time:
+        if validate_network_connectivity(hostname, port):
+            return True
+        time.sleep(check_interval)
+
+    return False
+
+
+def format_duration(seconds: float) -> str:
+    """Format duration in seconds to human-readable string.
+
+    Args:
+        seconds: Duration in seconds
+
+    Returns:
+        Formatted duration string (e.g., "2h 30m 15s")
+    """
+    if seconds < 60:
+        return f"{seconds:.1f}s"
+
+    minutes = int(seconds // 60)
+    remaining_seconds = int(seconds % 60)
+
+    if minutes < 60:
+        return f"{minutes}m {remaining_seconds}s"
+
+    hours = int(minutes // 60)
+    remaining_minutes = int(minutes % 60)
+
+    if hours < 24:
+        return f"{hours}h {remaining_minutes}m {remaining_seconds}s"
+
+    days = int(hours // 24)
+    remaining_hours = int(hours % 24)
+
+    return f"{days}d {remaining_hours}h {remaining_minutes}m {remaining_seconds}s"
+
+
+def sanitize_hostname(hostname: str) -> str:
+    """Sanitize hostname for safe use in file names and logs.
+
+    Args:
+        hostname: Raw hostname or IP address
+
+    Returns:
+        Sanitized hostname safe for file system use
+    """
+    # Replace common problematic characters
+    safe_hostname = hostname.replace(':', '_').replace('/', '_').replace('\\', '_')
+    # Remove any remaining problematic characters
+    safe_hostname = ''.join(c for c in safe_hostname if c.isalnum() or c in '-_.')
+    return safe_hostname
+
+
+def create_node_summary(nodes: list, successful: list, failed: list) -> Dict[str, Any]:
+    """Create a summary of node execution results.
+
+    Args:
+        nodes: List of all nodes
+        successful: List of successful nodes
+        failed: List of failed nodes
+
+    Returns:
+        Dictionary containing execution summary
+    """
+    return {
+        'total_nodes': len(nodes),
+        'successful_nodes': len(successful),
+        'failed_nodes': len(failed),
+        'success_rate': len(successful) / len(nodes) if nodes else 0.0,
+        'successful_node_list': successful,
+        'failed_node_list': failed,
+        'all_successful': len(failed) == 0
+    }
+
+
+def validate_madengine_command(command: str) -> bool:
+    """Validate that a madengine command looks reasonable.
+
+    Args:
+        command: Command string to validate
+
+    Returns:
+        True if command appears valid, False otherwise
+    """
+    required_parts = ['madengine', 'run', '--tags']
+    return all(part in command for part in required_parts)
+
+
+def escape_shell_argument(argument: str) -> str:
+    """Escape shell argument for safe execution.
+
+    Args:
+        argument: Argument to escape
+
+    Returns:
+        Escaped argument safe for shell execution
+    """
+    # Wrap in single quotes; any embedded single quote is closed, escaped,
+    # and reopened (' becomes '\'') per POSIX shell quoting rules.
+    return "'" + argument.replace("'", "'\\''") + "'"
+
+
+def parse_node_list(nodes_string: str) -> list:
+    """Parse a comma-separated list of nodes.
+
+    Args:
+        nodes_string: Comma-separated string of node names/IPs
+
+    Returns:
+        List of cleaned node names
+    """
+    if not nodes_string:
+        return []
+
+    nodes = [node.strip() for node in nodes_string.split(',') if node.strip()]
+    return nodes
+
+
+def get_network_interfaces() -> Dict[str, str]:
+    """Get available network interfaces on the local machine.
+
+    Returns:
+        Dictionary mapping interface names to IP addresses
+    """
+    interfaces = {}
+
+    try:
+        # Get the hostname and its resolved IP (socket is imported at module level)
+        hostname = socket.gethostname()
+        local_ip = socket.gethostbyname(hostname)
+        interfaces['local'] = local_ip
+
+        # Try to get more detailed interface information
+        try:
+            import psutil
+            for interface, addresses in psutil.net_if_addrs().items():
+                for address in addresses:
+                    if address.family == socket.AF_INET:
+                        interfaces[interface] = address.address
+                        break
+        except ImportError:
+            # psutil not available, use basic detection
+            pass
+
+    except Exception:
+        # Fallback to basic detection
+        pass
+
+    return interfaces
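+
+
+if __name__ == "__main__":
+    # Minimal smoke check for the pure helpers above, using made-up example
+    # inputs (illustrative only; no network or SSH access is required).
+    # Run directly with: python utils.py
+    print(parse_node_list("node1, node2 ,node3"))   # -> ['node1', 'node2', 'node3']
+    print(format_duration(9045))                    # -> 2h 30m 45s
+    print(sanitize_hostname("fe80::1%eth0"))        # -> fe80__1eth0
+    print(validate_madengine_command("madengine run --tags demo_model"))  # -> True
+    summary = create_node_summary(["n1", "n2"], ["n1"], ["n2"])
+    print(summary['success_rate'])                  # -> 0.5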