
Let PyTorch set -gencode flags #7871

@Flamefire

Description

The following code parses TORCH_CUDA_ARCH_LIST:

def compute_capability_args(self, cross_compile_archs=None):
    """
    Returns nvcc compute capability compile flags.

    1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`.
    2. If neither is set default compute capabilities will be used
    3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX

    Format:

    - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:

      TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6;9.0;10.0" pip install ...
      TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 9.0 10.0+PTX" pip install ...

    - `cross_compile_archs` uses ; separator.
    """
    ccs = []
    if self.jit_mode:
        # Compile for underlying architectures since we know those at runtime
        for i in range(torch.cuda.device_count()):
            CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i)
            cc = f"{CC_MAJOR}.{CC_MINOR}"
            if cc not in ccs:
                ccs.append(cc)
        ccs = sorted(ccs)
        ccs[-1] += '+PTX'
    else:
        # Cross-compile mode, compile for various architectures
        # env override takes priority
        cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
        if cross_compile_archs_env is not None:
            if cross_compile_archs is not None:
                print(
                    f"{WARNING} env var TORCH_CUDA_ARCH_LIST={cross_compile_archs_env} overrides cross_compile_archs={cross_compile_archs}"
                )
            cross_compile_archs = cross_compile_archs_env.replace(' ', ';')
        else:
            if cross_compile_archs is None:
                cross_compile_archs = get_default_compute_capabilities()
        ccs = cross_compile_archs.split(';')

    ccs = self.filter_ccs(ccs)
    if len(ccs) == 0:
        raise RuntimeError(
            f"Unable to load {self.name} op due to no compute capabilities remaining after filtering")

    args = []
    self.enable_bf16 = True
    for cc in ccs:
        cc = cc.split('.')  # split "major.minor[+PTX]" into [major, minor(+PTX)]
        num = cc[0] + cc[1].split('+')[0]
        args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
        if cc[1].endswith('+PTX'):
            args.append(f'-gencode=arch=compute_{num},code=compute_{num}')

        if int(cc[0]) <= 7:
            self.enable_bf16 = False

    return args
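
For illustration, here is a minimal standalone sketch of the same parsing, showing which -gencode flags a given TORCH_CUDA_ARCH_LIST value maps to. The helper name gencode_flags_from_env and the default arch list are made up for this example and are not part of DeepSpeed:

import os

def gencode_flags_from_env(default_archs="7.0;8.0"):
    # Mirrors the parsing above: ';' or whitespace separated entries,
    # each "major.minor" with an optional "+PTX" suffix.
    arch_list = os.environ.get('TORCH_CUDA_ARCH_LIST', default_archs).replace(' ', ';')
    flags = []
    for cc in arch_list.split(';'):
        major, minor = cc.split('.')
        num = major + minor.split('+')[0]
        flags.append(f'-gencode=arch=compute_{num},code=sm_{num}')
        if cc.endswith('+PTX'):
            # '+PTX' additionally embeds PTX for forward compatibility
            flags.append(f'-gencode=arch=compute_{num},code=compute_{num}')
    return flags

# TORCH_CUDA_ARCH_LIST="8.0 9.0+PTX" yields:
#   -gencode=arch=compute_80,code=sm_80
#   -gencode=arch=compute_90,code=sm_90
#   -gencode=arch=compute_90,code=compute_90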

This mostly duplicates the logic already present in PyTorch, including the fallback to the device architectures of the visible devices.

Is there any specific reason for that?
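
For comparison, a minimal sketch of calling PyTorch's own flag generation, assuming the private helper torch.utils.cpp_extension._get_cuda_arch_flags (an internal API, so its name and behavior may change between releases):

from torch.utils.cpp_extension import _get_cuda_arch_flags  # private PyTorch helper (assumption: available in this build)

# Honors TORCH_CUDA_ARCH_LIST when set and otherwise falls back to the
# capabilities of the visible devices, which is what the DeepSpeed code above re-implements.
print(_get_cuda_arch_flags())
# e.g. ['-gencode=arch=compute_80,code=sm_80', '-gencode=arch=compute_90,code=sm_90']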

I just noticed this while investigating #7863 and found that PyTorch sorts the -gencode flags while DeepSpeed does not.
Although sorting would have resolved the issue in my environment, it only shifts the problem; at least it would make the behavior independent of the order used in TORCH_CUDA_ARCH_LIST.
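
As a sketch of that idea (not a DeepSpeed patch, just an illustration): sorting the parsed capabilities numerically before emitting the flags would make the result independent of the order in TORCH_CUDA_ARCH_LIST.

def sort_ccs(ccs):
    # Sort by (major, minor) as integers so that e.g. "10.0" sorts after "9.0",
    # making the emitted -gencode flags independent of the input order.
    def key(cc):
        major, minor = cc.split('+')[0].split('.')
        return (int(major), int(minor))
    return sorted(ccs, key=key)

print(sort_ccs(['9.0', '8.6+PTX', '10.0', '7.5']))
# ['7.5', '8.6+PTX', '9.0', '10.0']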
