Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tests/enroot/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ On the remote GPU host :
- GPU drivers should be installed and GPUS should be detected
- **Rocm** should be installed and rocm-smi should be working
- Make sure the /etc/hostname has the correct name of the device.
- RDMA should be enabled and all the related packages,IB devices,rdma driver should be installed

Ensure the following are installed on your test runner node:

Expand Down Expand Up @@ -118,6 +119,7 @@ Test flow :
* Copy batch file and helper script required
* Launch sbatch to run the test
* Once the test is complete, copy back all the results and logs to "results" folder
* Validate the usage of IB/ROCe by the test using rdma counters
4. Run the *test_multi_node_rccl* test:
* Copy the sbatch file to the host
* Launch sbatch to run the test
Expand Down
64 changes: 55 additions & 9 deletions tests/enroot/testsuites/test_enroot.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,19 @@ def test_multi_node_distributed_pytorch():
if exit_code:
assert False, f"{local_script.name} on {amd_host.host_ip} couldnt be created!!"
log.info(f"Creating {local_script.name} on {amd_host.host_ip} - Successfull !!")

log.info(f"Discovering IB devices on {amd_host.host_ip}...")
exit_code, all_devices = list_all_ib_devices_remote(amd_host)
assert not exit_code, f" IB devices couldn't be fetched on {amd_host.host_ip}"
assert all_devices, f"No IB device found !!"
log.info(f"IB Devices : {all_devices}")

log.info(f"Reading counters BEFORE test ({len(all_devices)} devices)")
counters_before = {
dev: read_ib_counters_remote(amd_host, *dev)
for dev in all_devices
}
log.info(f"Counters before the test : {counters_before}")

# Run the batch script -> get jobid
exit_code, output = amd_host.execute_command(f"sbatch --parsable --gres=gpu:{amd_host.gpu_num} {remote_script} ")
Expand Down Expand Up @@ -437,15 +450,48 @@ def test_multi_node_distributed_pytorch():
exit_code, output = amd_host.execute_command(f"sudo rm -rf {remote_script}")
assert not exit_code , f" Error deleting the script {remote_script}!, {output['stderr']}"

try:
result = validate_ib_usage(local_output_file)
log.info("PASS: RDMA / InfiniBand was used")
log.info("IB evidence:")
for l in result["matched_ib_lines"][:5]:
log.info(" %s", l)
except AssertionError as e:
print("FAIL:", e)
raise
log.info("Parsing NCCL log...")
used_devices, net_ib_lines = parse_used_ib_devices_from_log(local_output_file)

log.info("NET/IB lines:")
for l in net_ib_lines[:5]:
log.info(" %s", l)

log.info("IB devices used:")
for d, p in used_devices:
log.info(f" {d}:{p}")

log.info("Reading counters AFTER test")
counters_after = {
dev: read_ib_counters_remote(amd_host, *dev)
for dev in used_devices
}
log.info("RDMA counter deltas:")
rdma_seen = False

for dev in used_devices:
before = counters_before.get(dev)
after = counters_after.get(dev)

d = counter_delta(before, after)

tx = d["tx_rdma_ucast_bytes"]
rx = d["rx_rdma_ucast_bytes"]

log.info(f" Device {dev[0]}:{dev[1]}")
log.info(f" TX delta: {tx}")
log.info(f" RX delta: {rx}")

if tx > 0 or rx > 0:
rdma_seen = True
log.info(" RDMA traffic detected")
else:
log.info(" No RDMA traffic")

assert rdma_seen, "No RDMA traffic detected on used IB devices"

log.info("\n VALIDATION PASSED (REMOTE COUNTERS)")


def test_multi_node_rccl():
"""
Expand Down
114 changes: 69 additions & 45 deletions tests/enroot/testsuites/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,61 +268,85 @@ def wait_for_job_completion(headnode, job_id):

raise Exception(f"Job still running: {state}")

def validate_ib_usage(nccl_log_path: str):
"""
Validate whether NCCL/RCCL used InfiniBand/RDMA based on NCCL_DEBUG logs.
def list_all_ib_devices_remote(amd_host):
exit_code, output = amd_host.execute_command(f"ls /sys/class/infiniband")
if exit_code:
log.error(f"Error listing IB devices: {output['stderr']}")
return exit_code,0

Requirements:
- NCCL_DEBUG=INFO
- NCCL_DEBUG_SUBSYS=INIT,NET
devices = []
for dev in output['stdout'].split():
ports_cmd = f"ls /sys/class/infiniband/{dev}/ports"
exit_code, output = amd_host.execute_command(ports_cmd)
if exit_code:
log.error(f"Error listing IB devices Ports: {output['stderr']}")
return exit_code,0
ports_out = output['stdout']
for port in ports_out.split():
devices.append((dev, int(port)))

if not devices:
return exit_code,0

return exit_code, devices

def read_ib_counters_remote(amd_host, device, port):
base = f"/sys/class/infiniband/{device}/ports/{port}/hw_counters"

counters = {}
for name in (
"tx_rdma_ucast_bytes",
"rx_rdma_ucast_bytes",
"tx_rdma_ucast_pkts",
"rx_rdma_ucast_pkts",
):
cmd = f"cat {base}/{name} 2>/dev/null || echo 0"
exit_code, output = amd_host.execute_command(cmd)
if exit_code:
log.error(f"Error listing IB devices: {output['stderr']}")
return exit_code,0
val = output['stdout']
counters[name] = int(val)

Raises:
AssertionError if IB was not used or if Socket transport is detected.
return counters

Returns:
dict with parsed evidence (for debugging / logging)
"""
def parse_used_ib_devices_from_log(log_path):

# STRICT patterns — only real NCCL transport selection
net_ib_regex = re.compile(r'\bNET/IB\b', re.IGNORECASE)
net_socket_regex = re.compile(r'\bNET/Socket\b', re.IGNORECASE)
NET_IB_REGEX = re.compile(
r'\[\d+\]([a-zA-Z0-9_]+):(\d+)/(?:RoCE|IB)',
re.IGNORECASE)

ib_lines = []
devices = set()
net_ib_lines = []
socket_lines = []

log_path = Path(nccl_log_path)
if not log_path.exists():
raise FileNotFoundError(f"NCCL log file not found: {nccl_log_path}")

with log_path.open("r", errors="ignore") as f:
with open(log_path, "r", errors="ignore") as f:
for line in f:
line = line.strip()

# Capture only real NET/IB lines
if net_ib_regex.search(line):
ib_lines.append(line)

# Capture socket fallback explicitly
if net_socket_regex.search(line):
if "NET/Socket" in line:
socket_lines.append(line)

result = {
"ib_used": len(ib_lines) > 0,
"matched_ib_lines": ib_lines,
"matched_socket_lines": socket_lines,
}

# ---------------------------
# STRICT VALIDATION ASSERTS
# ---------------------------
assert ib_lines, (
"RDMA/InfiniBand was NOT used "
"(no 'NET/IB' lines found in NCCL logs)"
)

assert not socket_lines, (
"Socket transport detected (NET/Socket found in NCCL logs)"
)

return result
if "NET/IB" not in line:
continue

net_ib_lines.append(line)

for m in NET_IB_REGEX.finditer(line):
dev, port = m.groups()
devices.add((dev, int(port)))

if not net_ib_lines:
raise AssertionError("No NET/IB lines found in NCCL log")

if socket_lines:
raise AssertionError(
"Socket fallback detected:\n" +
"\n".join(socket_lines[:3])
)

return sorted(devices), net_ib_lines

def counter_delta(before, after):
return {k: after[k] - before[k] for k in before}

Loading