 # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

 """
-Simple wrapper to run pytest tests within a single distributed process group.
-This avoids the overhead of creating/destroying process groups for each test case.
+Simple wrapper to run pytest tests within a single distributed process group using torchrun.
+This avoids port conflicts by leveraging torchrun's automatic port management.
 """

 import os
 import sys
-import torch.multiprocessing as mp
-import torch.distributed as dist
-import socket

 # Set required environment variable for RCCL on ROCm
 os.environ.setdefault("HSA_NO_SCRATCH_RECLAIM", "1")


-def _find_free_port():
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.bind(("", 0))
-        return s.getsockname()[1]
-
-
-def _distributed_worker(rank, world_size, test_file, pytest_args, init_method):
-    """Worker function that runs pytest within a distributed process group."""
-    # Set the correct GPU for this specific process
-    # When ROCR_VISIBLE_DEVICES is set, devices are remapped, so rank 0 should use device 0, etc.
+def _distributed_worker_main():
+    """Main function for distributed worker that runs pytest."""
     import torch
-
+    import torch.distributed as dist
+
+    # torchrun sets these environment variables automatically
+    rank = int(os.environ.get("RANK", 0))
+    world_size = int(os.environ.get("WORLD_SIZE", 1))
+    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+
+    # Set the correct GPU for this specific process
     if torch.cuda.is_available():
-        torch.cuda.set_device(rank)
-
-    # Initialize distributed once for all tests
+        torch.cuda.set_device(local_rank)
+
+    # Initialize distributed - torchrun already set up the environment
     dist.init_process_group(
         backend="nccl",
-        init_method=init_method,
         rank=rank,
         world_size=world_size,
-        device_id=torch.device(f"cuda:{rank}"),
+        device_id=torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else None,
     )
-
+
     try:
         # Import and run pytest directly
         import pytest
-        import sys
-
-        # Set up sys.argv for pytest
-        original_argv = sys.argv[:]
-        sys.argv = ["pytest", test_file] + pytest_args
-
-        try:
-            # Run pytest directly in this process
-            exit_code = pytest.main([test_file] + pytest_args)
-            # If tests failed, exit with the failure code
-            if exit_code != 0:
-                sys.exit(exit_code)
-            return exit_code
-        finally:
-            # Restore original argv
-            sys.argv = original_argv
-
+
+        # Get pytest args from environment (set by launcher)
+        pytest_args_str = os.environ.get("PYTEST_ARGS", "")
+        pytest_args = pytest_args_str.split() if pytest_args_str else []
+
+        # Run pytest
+        exit_code = pytest.main(pytest_args)
+        sys.exit(exit_code)
     finally:
         if dist.is_initialized():
             dist.destroy_process_group()
@@ -71,7 +57,13 @@ def main():
         print("Usage: python run_tests_distributed.py [--num_ranks N] [pytest_args...] <test_file>")
         sys.exit(1)

-    # Get number of ranks from args or default to 2
+    # Check if we're being called as a torchrun worker
+    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+        # We're running inside torchrun - execute as worker
+        _distributed_worker_main()
+        return
+
+    # We're the launcher - parse args and start torchrun
     num_ranks = 2
     args = sys.argv[1:]

@@ -90,27 +82,34 @@ def main():
     test_file = args[0]
     pytest_args = args[1:]  # Everything after the test file

-    print(f"Running {test_file} with {num_ranks} ranks")
-    print(f"args={args}, test_file={test_file}, pytest_args={pytest_args}")
+    print(f"Running {test_file} with {num_ranks} ranks using torchrun")
+
+    # Build pytest arguments string
+    pytest_cmd_args = [test_file] + pytest_args
+    pytest_args_str = " ".join(pytest_cmd_args)
+
+    # Set environment variable for worker to read
+    os.environ["PYTEST_ARGS"] = pytest_args_str
+
+    # Build torchrun command - it will re-invoke this script as a worker
+    import subprocess
+
+    torchrun_cmd = [
+        "torchrun",
+        f"--nproc_per_node={num_ranks}",
+        "--standalone",  # Single-node training
+        __file__,  # Re-invoke this script
+        "--worker-mode",  # Dummy arg to distinguish from launcher
+    ]

-    # Find a free port for this test run to avoid conflicts with parallel runs
-    free_port = _find_free_port()
-    init_method = f"tcp://127.0.0.1:{free_port}"
-    print(f"Using init_method: {init_method}")
+    print(f"Executing: {' '.join(torchrun_cmd)}")

-    # Run all tests within a single distributed process group
+    # Run torchrun and return its exit code
     try:
-        mp.spawn(
-            _distributed_worker,
-            args=(num_ranks, test_file, pytest_args, init_method),
-            nprocs=num_ranks,
-            join=True,
-        )
-    except SystemExit as e:
-        # Catch sys.exit() from worker and return same exit code
-        sys.exit(e.code if isinstance(e.code, int) else 1)
-    except Exception:
-        # Any other unhandled exception = failure
+        result = subprocess.run(torchrun_cmd, check=False, env=os.environ.copy())
+        sys.exit(result.returncode)
+    except Exception as e:
+        print(f"Error running torchrun: {e}")
         sys.exit(1)


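For context, the wrapper initializes the NCCL process group before pytest starts, so the tests in the target file can call collectives directly. A minimal sketch of such a test, assuming a hypothetical test file and at least one GPU per rank; the file name and test name are illustrative and not part of this commit:

# Hypothetical test file, run roughly as:
#   python run_tests_distributed.py --num_ranks 2 tests/test_collectives.py -v
# (the launcher re-invokes itself under torchrun; each worker then runs pytest)
import torch
import torch.distributed as dist


def test_all_reduce_sums_ranks():
    # The wrapper already called dist.init_process_group(), so the group is live.
    assert dist.is_initialized()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    # Each rank contributes (rank + 1); the all-reduce should sum to 1 + 2 + ... + world_size.
    t = torch.full((1,), float(rank + 1), device=torch.device("cuda", torch.cuda.current_device()))
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    assert t.item() == world_size * (world_size + 1) / 2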