Description
cc @zhangfei829 @wuyl1 @TianDi101
Summary
This RFC proposes integrating an SDMA (System DMA) engine-based collective communication backend into DeepSpeed's ZeRO Stage 3 for AMD GPUs. Unlike RCCL, which relies on GPU Compute Units (CUs) to drive collective operations, the SDMA-based approach offloads communication to dedicated hardware DMA engines on AMD GPUs. This frees CUs entirely for computation, yielding significantly better computation–communication overlap efficiency—a critical advantage for ZeRO-3 training where AllGather and compute are heavily overlapped.
Motivation
The Problem: CU Contention in Communication–Computation Overlap
ZeRO Stage 3 partitions model parameters across data-parallel ranks and uses AllGather to reconstruct full parameters on-the-fly before each forward/backward computation. DeepSpeed's PartitionedParameterCoordinator aggressively prefetches parameters via AllGather on a separate CUDA/HIP stream, overlapping communication with GEMM and other compute operations.
However, current communication libraries (NCCL on NVIDIA, RCCL on AMD) execute collective operations using GPU compute resources:
- RCCL AllGather kernels occupy CUs, consuming shared resources (wavefront slots, LDS, register file, memory bandwidth) that could otherwise be used for computation.
- When AllGather and GEMM run concurrently, they compete for the same CU resources, causing compute throughput degradation.
- This contention increases with message size. For large models with tens or hundreds of millions of parameters per layer, the AllGather payloads easily reach 128MB–512MB, making CU contention a significant bottleneck.
The net effect: even though DeepSpeed's overlap mechanism is well-designed, the actual compute efficiency during overlap is lower than the theoretical maximum because the communication library itself consumes compute resources.
The Opportunity: SDMA Engines on AMD GPUs
AMD GPUs (MI300X, MI325X, MI355X, and future generations) contain dedicated SDMA (System DMA) engines—hardware copy engines that can perform memory transfers independently of the CU array. Key properties:
- SDMA engines operate autonomously from CUs: they have their own command queues and execution pipelines.
- They can perform intra-node GPU-to-GPU transfers over XGMI/Infinity Fabric without involving any CU.
- Multiple SDMA engines are available per GPU (e.g., MI300X has multiple SDMA engines across its chiplets), enabling parallelism.
By routing collective communication through SDMA engines instead of CU-based kernels, we achieve true zero-CU-overhead communication, enabling compute kernels to run at full throughput even during concurrent communication.
Proposed Solution
MORI-CCL: Lightweight SDMA-Based Collective Communication
We propose integrating MORI-CCL (part of the MORI framework) as an alternative communication backend for DeepSpeed's ZeRO-3 AllGather operations on AMD GPUs.
MORI-CCL is built on top of MORI's modular transport framework, which provides:
- SDMA transport: Direct submission of copy and atomic packets to SDMA hardware queues via HSAKMT, bypassing CU-based kernel launches entirely.
- OpenSHMEM-style APIs: PutMemNbi (non-blocking put), Quiet (completion), and Barrier for building collectives from one-sided primitives.
- PyTorch integration: Bootstrap via torch.distributed process groups; compatible with existing DeepSpeed initialization.
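The one-sided primitives listed above compose into an enqueue-then-drain pattern. A minimal plain-Python model of that pattern (the MoriTransport class and its method spellings are illustrative stand-ins, not MORI's actual bindings):

```python
# Toy model of the enqueue-then-drain pattern; MoriTransport and its
# method names are illustrative stand-ins, not MORI's actual bindings.
class MoriTransport:
    def __init__(self):
        self.pending = []          # copy packets enqueued, not yet complete

    def put_mem_nbi(self, dst, src):
        # Non-blocking put: enqueue the copy and return immediately.
        self.pending.append((dst, src))

    def quiet(self):
        # Completion point: drain every previously enqueued put.
        for dst, src in self.pending:
            dst[:] = src           # stands in for the SDMA engine's copy
        self.pending.clear()


transport = MoriTransport()
src, dst = [1, 2, 3, 4], [0, 0, 0, 0]
transport.put_mem_nbi(dst, src)    # returns before the copy happens
transport.quiet()                  # blocks until all pending puts finish
print(dst)                         # → [1, 2, 3, 4]
```

A Barrier would sit on top of the same primitives; collectives are then built by scheduling puts between ranks and calling quiet at each synchronization point.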
Architecture
┌──────────────────────────────────────────────────────────────┐
│ DeepSpeed ZeRO-3 │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ PartitionedParameterCoordinator │ │
│ │ (prefetch, fetch submit/wait, release) │ │
│ └────────────────────┬─────────────────────────────────┘ │
│ │ AllGather │
│ ┌────────┴────────┐ │
│ │ Backend Switch │ │
│ └───┬─────────┬──┘ │
│ ┌───────┘ └───────┐ │
│ ▼ ▼ │
│ ┌───────────────┐ ┌─────────────────┐ │
│ │ RCCL Path │ │ MORI-CCL Path │ │
│ │ (via torch. │ │ (SDMA engines) │ │
│ │ distributed) │ │ │ │
│ │ │ │ ┌───────────┐ │ │
│ │ Uses CUs ⚠ │ │ │SDMA Engine│ │ │
│ │ │ │ │ Queue 0 │ │ │
│ │ │ │ ├───────────┤ │ │
│ │ │ │ │SDMA Engine│ │ │
│ │ │ │ │ Queue 1 │ │ │
│ │ │ │ └───────────┘ │ │
│ │ │ │ Zero CU usage │ │
│ └───────────────┘ └─────────────────┘ │
│ │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ GPU CU Array (Compute) │ │
│ │ GEMM / Attention / Normalization / Activation │ │
│ │ 100% CU availability with MORI-CCL path │ │
│ └──────────────────────────────────────────────────────┘ │
└──────────────────────────────────────────────────────────────┘
Integration Points in DeepSpeed
The integration touches the following components:
1. Communication Backend Registration
Add MORI-CCL as a recognized backend in deepspeed/comm/constants.py:
```python
MORI_BACKEND = 'mori'
```
Create a MORI backend wrapper in deepspeed/comm/mori.py that implements the Backend interface (extending deepspeed/comm/backend.py), routing all_gather_into_tensor and other collectives to MORI-CCL's SDMA-based implementation.
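The wrapper could take roughly the following shape. This is a hedged sketch: the MoriBackend class, its initialization flow, and the handle it returns are assumptions about the eventual integration, not existing DeepSpeed or MORI code.

```python
# Hedged sketch of the wrapper shape for deepspeed/comm/mori.py; the
# class, its initialization flow, and the handle it returns are
# assumptions about the eventual integration, not existing APIs.
class MoriBackend:
    def __init__(self, name='mori'):
        self.name = name
        self.initialized = False

    def init_process_group(self, rank, world_size):
        # Real code would bootstrap MORI-CCL from the torch.distributed
        # process group; this stub only records the topology.
        self.rank, self.world_size = rank, world_size
        self.initialized = True

    def all_gather_into_tensor(self, output, input_partition, group=None):
        # Real code would decompose the gather into PutMemNbi packets on
        # SDMA queues and return an async handle with a wait() method.
        assert self.initialized, "init_process_group must run first"
        return ('mori_sdma_allgather', output, input_partition)


backend = MoriBackend()
backend.init_process_group(rank=0, world_size=8)
handle = backend.all_gather_into_tensor([0] * 8, [1])
print(handle[0])   # → mori_sdma_allgather
```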
2. ZeRO-3 AllGather Path (deepspeed/runtime/zero/partition_parameters.py)
Modify AllGatherCoalescedHandle and the AllGather dispatch logic to optionally use MORI-CCL when configured. The key method is allgather_fn used in _allgather_params:
```python
if zero_comm_backend == 'mori':
    handle = mori_ccl.all_gather_into_tensor(output_buffer, input_partition, group)
else:
    handle = dist.all_gather_into_tensor(output_buffer, input_partition, group, async_op=True)
```
3. Configuration (deepspeed/runtime/zero/config.py)
Add a new configuration option under zero_optimization:
```json
{
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "comm_backend": "mori",
    "mori_config": {
      "enable_sdma": true,
      "num_sdma_queues": 4,
      "heap_size": "2G"
    }
  }
}
```
4. Accelerator Integration
Extend accelerator/cuda_accelerator.py (ROCm path) to detect SDMA engine availability and expose MORI-CCL as a communication option when running on supported AMD hardware.
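As a small illustration of how the mori_config block from point 3 might be validated at load time, here is a hypothetical parser for the heap_size field. The helper is not part of any existing API; it assumes K/M/G suffixes in line with the "2G" example above.

```python
# Hypothetical validator for the proposed mori_config block; the
# parse_heap_size helper is illustrative, assuming K/M/G suffixes as in
# the "2G" example above.
def parse_heap_size(value):
    units = {'K': 1 << 10, 'M': 1 << 20, 'G': 1 << 30}
    if isinstance(value, int):
        return value               # already in bytes
    suffix = value[-1].upper()
    if suffix in units:
        return int(value[:-1]) * units[suffix]
    return int(value)


mori_config = {"enable_sdma": True, "num_sdma_queues": 4, "heap_size": "2G"}
print(parse_heap_size(mori_config["heap_size"]))   # → 2147483648
```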
Performance Data
Benchmark: GEMM + AllGather Overlap Total Time
We measured the total wall-clock time for overlapped GEMM computation with AllGather communication using RCCL vs. SDMA-based AllGather on AMD MI300X GPUs. The GEMM kernel and AllGather run concurrently on the same GPU—RCCL AllGather competes for CUs while SDMA AllGather uses dedicated DMA engines.
Key observations:
- SDMA-based AllGather achieves 1.2x–1.55x speedup in overlapped total time across all message sizes.
- The improvement is consistent across message sizes from 4MB to 512MB, covering the typical range of ZeRO-3 AllGather payloads.
- At larger message sizes (128MB–512MB), which are common for large model training, the absolute time savings are substantial (1–5ms per AllGather), and these add up across hundreds of layers per training step.
Why This Matters for ZeRO-3
In ZeRO Stage 3 training:
- AllGather dominates communication: Every forward and backward pass requires AllGathering the full parameters for each layer, resulting in hundreds of AllGather operations per training step.
- Overlap is nearly complete: DeepSpeed's PartitionedParameterCoordinator with prefetching ensures that most AllGather operations overlap with computation. The overlap ratio is close to 100% for sufficiently large models.
- CU contention is the remaining bottleneck: With near-complete overlap, the primary performance limiter is that RCCL's CU usage degrades compute throughput during overlap. Eliminating CU contention via SDMA directly translates to faster training step time.
Estimated end-to-end impact for large model training (e.g., 70B+ parameters):
- Per-layer AllGather payloads: 64MB–256MB (depending on hidden size and world size)
- Per-step AllGather count: 200–400 (forward + backward, all layers)
- Per-AllGather savings: 0.5–2ms
- Estimated per-step savings: 100–800ms (5–15% of total step time)
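The per-step range follows directly from multiplying the bounds above; a quick back-of-envelope check:

```python
# Back-of-envelope check of the per-step estimate above.
allgathers_per_step = (200, 400)        # forward + backward, all layers
savings_per_allgather_ms = (0.5, 2.0)   # per-AllGather savings range

low = allgathers_per_step[0] * savings_per_allgather_ms[0]
high = allgathers_per_step[1] * savings_per_allgather_ms[1]
print(low, high)   # → 100.0 800.0 (ms per step)
```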
Scope and Limitations
In Scope
- ZeRO Stage 3 AllGather: Primary target, as this is the dominant overlapped collective in ZeRO-3.
- Intra-node communication: SDMA engines handle GPU-to-GPU transfers within a node over XGMI/Infinity Fabric.
- AMD MI300X, MI325X, MI355X: Current hardware with SDMA support.
Future Extensions
- ReduceScatter for ZeRO-3 gradients: Apply SDMA to gradient reduce-scatter as well.
- Inter-node communication: Combine SDMA (intra-node) with RDMA/IBGDA (inter-node) for multi-node training.
- ZeRO Stage 1/2: AllReduce operations could also benefit, though the overlap pattern differs.
- ZeRO++: Quantized AllGather (qwZ) could leverage SDMA for the communication portion.
Limitations
- AMD-only: SDMA engines are specific to the AMD GPU architecture. NVIDIA GPUs have Copy Engines (CEs), but with different programming interfaces—this RFC targets AMD specifically.
- Intra-node first: The initial implementation focuses on intra-node AllGather. Multi-node scenarios require integration with RDMA transports.
- ROCm dependency: Requires ROCm >= 6.4.0 and MORI library installation.
Design Details
SDMA Queue Management
MORI uses HSAKMT to create and manage SDMA queues:
```c
hsaKmtCreateQueueExt(..., HSA_QUEUE_SDMA_BY_ENG_ID, sdma_engine_id, ...)
```
Each AllGather operation is decomposed into a series of PutMemNbi operations (non-blocking one-sided puts) that are submitted to SDMA queues. Completion is tracked via atomic signals: the SDMA engine writes a signal value upon completion, which can be polled without CU involvement.
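The signal-based completion scheme can be modeled in plain Python: a worker thread stands in for the SDMA engine, incrementing a signal value as packets complete, while the host simply polls. This is a toy model of the mechanism, not MORI code:

```python
import threading
import time

# Toy model of signal-based completion: a worker thread stands in for the
# SDMA engine, writing a monotonically increasing signal value as packets
# finish; the host polls the signal with no compute-kernel involvement.
class CompletionSignal:
    def __init__(self):
        self.value = 0

    def wait_until(self, target, timeout=5.0):
        # Quiet-style wait: poll until the engine has signaled `target`.
        deadline = time.monotonic() + timeout
        while self.value < target:
            if time.monotonic() > deadline:
                raise TimeoutError("completion signal never arrived")
            time.sleep(0.001)


signal = CompletionSignal()

def fake_sdma_engine(num_packets):
    for _ in range(num_packets):
        time.sleep(0.002)     # stands in for one DMA copy
        signal.value += 1     # engine writes the signal on completion

engine = threading.Thread(target=fake_sdma_engine, args=(3,))
engine.start()
signal.wait_until(3)          # blocks until all three packets complete
engine.join()
print(signal.value)           # → 3
```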
Ring AllGather via SDMA
The AllGather is implemented as a ring algorithm where each rank sends its partition to the next rank in a ring, repeated world_size - 1 times:
Step 0: Rank 0 → Rank 1 (chunk 0), Rank 1 → Rank 2 (chunk 1), ...
Step 1: Rank 0 → Rank 1 (chunk N-1), Rank 1 → Rank 2 (chunk 0), ...
...
Step N-1: All ranks have all chunks
Each step uses SDMA PutMemNbi + Quiet:
- PutMemNbi: Enqueues a copy packet on the SDMA engine (source → destination GPU).
- Quiet: Waits for the SDMA engine to complete all pending operations via signal polling.
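The ring schedule can be checked with a small simulation: at each step, every rank forwards the chunk it received in the previous step (initially its own) to its neighbor, and after world_size - 1 steps every rank holds all chunks. The function below is a plain-Python model of the schedule, not the SDMA implementation:

```python
# Plain-Python model of the ring schedule (not the SDMA implementation):
# at each step, rank r forwards the chunk it received in the previous
# step (initially its own chunk r) to rank (r + 1) % world_size.
def ring_all_gather(world_size):
    buffers = [{r} for r in range(world_size)]   # each rank's chunks
    send = list(range(world_size))               # chunk each rank sends next
    for _ in range(world_size - 1):
        incoming = [send[(r - 1) % world_size] for r in range(world_size)]
        for r in range(world_size):
            buffers[r].add(incoming[r])          # put from the left neighbor
        send = incoming                          # forward what just arrived
    return buffers


# After world_size - 1 steps, every rank holds all chunks.
assert all(b == set(range(8)) for b in ring_all_gather(8))
```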
Stream Integration
MORI-CCL operations are submitted on a dedicated HIP stream, compatible with DeepSpeed's existing allgather_stream used by PartitionedParameterCoordinator. Synchronization between the AllGather stream and the compute stream uses HIP events, preserving DeepSpeed's existing overlap mechanism.
Memory Management
MORI uses symmetric memory (shmem_malloc) for communication buffers. For ZeRO-3 integration:
- AllGather output buffers (flat parameter buffers) need to be registered with MORI for SDMA access.
- shmem_buffer_register(ptr, size) enables SDMA transfers on existing PyTorch tensor memory.
- This avoids extra copies: DeepSpeed's existing parameter buffers are used directly.
Alternatives Considered
1. Tuning RCCL Thread/CU Allocation
RCCL provides some knobs to limit CU usage (e.g., NCCL_MAX_NCHANNELS), but this trades communication bandwidth for compute availability. SDMA eliminates this trade-off entirely.
2. CPU-Based Communication (GDR + Host)
Using CPU-side RDMA with GPUDirect could offload from CUs, but adds PCIe latency and bandwidth bottlenecks. SDMA operates directly on the GPU's memory fabric (XGMI), avoiding PCIe.
3. Priority Streams / Preemption
Using different stream priorities to favor compute over communication doesn't solve the fundamental resource contention—CUs are still shared. SDMA provides hardware-level isolation.
Compatibility
- Backward compatible: The MORI backend is opt-in via configuration. Default behavior (RCCL via torch.distributed) is unchanged.
- No API changes: DeepSpeed's public API remains the same. The change is purely in the communication backend selection.
- Fallback: If MORI is not installed or hardware is unsupported, DeepSpeed falls back to the default backend with a warning.
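The fallback policy above could be expressed as a small selection helper. This is a sketch: select_comm_backend and its probe arguments (mori_available, sdma_supported) are hypothetical names, not existing DeepSpeed functions.

```python
import warnings

# Sketch of the opt-in/fallback policy; select_comm_backend and its probe
# arguments (mori_available, sdma_supported) are hypothetical names.
def select_comm_backend(zero_config, mori_available, sdma_supported):
    requested = zero_config.get("comm_backend", "nccl")  # RCCL uses 'nccl'
    if requested == "mori":
        if mori_available and sdma_supported:
            return "mori"
        warnings.warn("MORI not installed or hardware unsupported; "
                      "falling back to the default backend.")
    return "nccl"


print(select_comm_backend({"comm_backend": "mori"}, True, True))    # → mori
print(select_comm_backend({"comm_backend": "mori"}, False, True))   # → nccl
```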