Commit e415bbd

Copilot and mawad-amd committed
Use torchrun for distributed tests and increase GPU wait time to 1 hour
- Refactor run_tests_distributed.py to use torchrun instead of manual port management
  - Eliminates EADDRINUSE port conflicts between parallel test jobs
  - torchrun automatically handles port allocation and distributed setup
  - Script detects if running as launcher or worker based on env vars
- Increase GPU allocator wait time from 10 minutes to 1 hour
  - RETRY_DELAY: 2s → 240s (4 minutes between checks)
  - MAX_RETRIES: 300 → 15 (15 attempts total)
  - Total wait time: 10 min → 60 min

Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
1 parent 6b51e89 · commit e415bbd
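
The launcher/worker dispatch described in the commit message can be pictured with a short, self-contained sketch (illustrative only, simplified from the actual diff below; the run_pytest_as_worker stub stands in for the real worker path):

import os
import subprocess
import sys


def run_pytest_as_worker() -> None:
    # Stub for the real worker path, which initializes torch.distributed
    # and runs pytest.main() (see the diff to run_tests_distributed.py below).
    print(f"worker rank {os.environ['RANK']} of {os.environ['WORLD_SIZE']}")


def main() -> None:
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        # torchrun set these variables, so we are inside a worker process.
        run_pytest_as_worker()
        return
    # Otherwise we are the launcher: re-invoke this same file under torchrun.
    # --standalone makes torchrun pick the rendezvous port itself, which is
    # what removes the EADDRINUSE races between parallel test jobs.
    cmd = ["torchrun", "--nproc_per_node=2", "--standalone", sys.argv[0]]
    sys.exit(subprocess.run(cmd, check=False).returncode)


if __name__ == "__main__":
    main()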

2 files changed (+59, -60 lines)

.github/scripts/gpu_allocator.sh
Lines changed: 2 additions & 2 deletions

@@ -25,8 +25,8 @@
 GPU_STATE_FILE="${GPU_STATE_FILE:-/tmp/iris_gpu_state}"
 GPU_LOCK_FILE="${GPU_STATE_FILE}.lock"
 MAX_GPUS="${MAX_GPUS:-8}"
-RETRY_DELAY="${RETRY_DELAY:-2}"
-MAX_RETRIES="${MAX_RETRIES:-300}"  # 10 minutes with 2s delay
+RETRY_DELAY="${RETRY_DELAY:-240}"  # 4 minutes between checks
+MAX_RETRIES="${MAX_RETRIES:-15}"   # 1 hour total wait time (15 * 4 min)

 # Initialize GPU state file and validate its contents
 init_gpu_state() {
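
For context on the new defaults: 15 attempts spaced 240 s apart give 15 × 240 s = 3600 s, i.e. one hour of waiting before the allocator gives up (versus 300 × 2 s = 600 s before). The allocator's polling loop itself is not part of this diff; the Python sketch below only models the assumed RETRY_DELAY/MAX_RETRIES semantics, with gpus_available as a placeholder callback:

import os
import time

# Defaults mirror the new values in gpu_allocator.sh.
RETRY_DELAY = int(os.environ.get("RETRY_DELAY", "240"))  # seconds between checks
MAX_RETRIES = int(os.environ.get("MAX_RETRIES", "15"))   # attempts before giving up


def wait_for_gpus(gpus_available) -> bool:
    # Poll the placeholder gpus_available() callback until it succeeds or the
    # retry budget runs out: 15 attempts x 240 s = 3600 s = 1 hour of waiting.
    for _ in range(MAX_RETRIES):
        if gpus_available():
            return True
        time.sleep(RETRY_DELAY)
    return False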

tests/run_tests_distributed.py
Lines changed: 57 additions & 58 deletions

@@ -3,64 +3,50 @@
 # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

 """
-Simple wrapper to run pytest tests within a single distributed process group.
-This avoids the overhead of creating/destroying process groups for each test case.
+Simple wrapper to run pytest tests within a single distributed process group using torchrun.
+This avoids port conflicts by leveraging torchrun's automatic port management.
 """

 import os
 import sys
-import torch.multiprocessing as mp
-import torch.distributed as dist
-import socket

 # Set required environment variable for RCCL on ROCm
 os.environ.setdefault("HSA_NO_SCRATCH_RECLAIM", "1")


-def _find_free_port():
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.bind(("", 0))
-        return s.getsockname()[1]
-
-
-def _distributed_worker(rank, world_size, test_file, pytest_args, init_method):
-    """Worker function that runs pytest within a distributed process group."""
-    # Set the correct GPU for this specific process
-    # When ROCR_VISIBLE_DEVICES is set, devices are remapped, so rank 0 should use device 0, etc.
+def _distributed_worker_main():
+    """Main function for distributed worker that runs pytest."""
     import torch
-
+    import torch.distributed as dist
+
+    # torchrun sets these environment variables automatically
+    rank = int(os.environ.get("RANK", 0))
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+
+    # Set the correct GPU for this specific process
     if torch.cuda.is_available():
-        torch.cuda.set_device(rank)
-
-    # Initialize distributed once for all tests
+        torch.cuda.set_device(local_rank)
+
+    # Initialize distributed - torchrun already set up the environment
     dist.init_process_group(
         backend="nccl",
-        init_method=init_method,
         rank=rank,
         world_size=world_size,
-        device_id=torch.device(f"cuda:{rank}"),
+        device_id=torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else None,
     )
-
+
     try:
         # Import and run pytest directly
        import pytest
-        import sys
-
-        # Set up sys.argv for pytest
-        original_argv = sys.argv[:]
-        sys.argv = ["pytest", test_file] + pytest_args
-
-        try:
-            # Run pytest directly in this process
-            exit_code = pytest.main([test_file] + pytest_args)
-            # If tests failed, exit with the failure code
-            if exit_code != 0:
-                sys.exit(exit_code)
-            return exit_code
-        finally:
-            # Restore original argv
-            sys.argv = original_argv
-
+
+        # Get pytest args from environment (set by launcher)
+        pytest_args_str = os.environ.get("PYTEST_ARGS", "")
+        pytest_args = pytest_args_str.split() if pytest_args_str else []
+
+        # Run pytest
+        exit_code = pytest.main(pytest_args)
+        sys.exit(exit_code)
     finally:
         if dist.is_initialized():
             dist.destroy_process_group()
@@ -71,7 +57,13 @@ def main():
         print("Usage: python run_tests_distributed.py [--num_ranks N] [pytest_args...] <test_file>")
         sys.exit(1)

-    # Get number of ranks from args or default to 2
+    # Check if we're being called as a torchrun worker
+    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+        # We're running inside torchrun - execute as worker
+        _distributed_worker_main()
+        return
+
+    # We're the launcher - parse args and start torchrun
     num_ranks = 2
     args = sys.argv[1:]

@@ -90,27 +82,34 @@ def main():
     test_file = args[0]
     pytest_args = args[1:]  # Everything after the test file

-    print(f"Running {test_file} with {num_ranks} ranks")
-    print(f"args={args}, test_file={test_file}, pytest_args={pytest_args}")
+    print(f"Running {test_file} with {num_ranks} ranks using torchrun")
+
+    # Build pytest arguments string
+    pytest_cmd_args = [test_file] + pytest_args
+    pytest_args_str = " ".join(pytest_cmd_args)
+
+    # Set environment variable for worker to read
+    os.environ["PYTEST_ARGS"] = pytest_args_str
+
+    # Build torchrun command - it will re-invoke this script as a worker
+    import subprocess
+
+    torchrun_cmd = [
+        "torchrun",
+        f"--nproc_per_node={num_ranks}",
+        "--standalone",  # Single-node training
+        __file__,  # Re-invoke this script
+        "--worker-mode",  # Dummy arg to distinguish from launcher
+    ]

-    # Find a free port for this test run to avoid conflicts with parallel runs
-    free_port = _find_free_port()
-    init_method = f"tcp://127.0.0.1:{free_port}"
-    print(f"Using init_method: {init_method}")
+    print(f"Executing: {' '.join(torchrun_cmd)}")

-    # Run all tests within a single distributed process group
+    # Run torchrun and return its exit code
     try:
-        mp.spawn(
-            _distributed_worker,
-            args=(num_ranks, test_file, pytest_args, init_method),
-            nprocs=num_ranks,
-            join=True,
-        )
-    except SystemExit as e:
-        # Catch sys.exit() from worker and return same exit code
-        sys.exit(e.code if isinstance(e.code, int) else 1)
-    except Exception:
-        # Any other unhandled exception = failure
+        result = subprocess.run(torchrun_cmd, check=False, env=os.environ.copy())
+        sys.exit(result.returncode)
+    except Exception as e:
+        print(f"Error running torchrun: {e}")
         sys.exit(1)

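For reference (not part of the diff), the reason the manual _find_free_port()/init_method plumbing can be deleted is that torchrun injects the rendezvous details into each worker's environment. A minimal sketch of that environment contract, assuming a launch such as torchrun --standalone --nproc_per_node=2 run_tests_distributed.py:

import os

# Variables torchrun sets for each worker process; the printed values below
# are what a 2-rank --standalone run would typically show.
for var in ("RANK", "LOCAL_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
    print(var, "=", os.environ.get(var))

# Expected shape of the output (illustrative):
#   RANK = 0 or 1, LOCAL_RANK = 0 or 1, WORLD_SIZE = 2,
#   MASTER_ADDR = the local host, MASTER_PORT = a port torchrun picked at launch.
# With these in place, dist.init_process_group() needs no explicit init_method,
# so parallel test jobs no longer race for the same hard-coded TCP port.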
