[BUG] DistributedAttention produces numerically incorrect results vs its own SeqAllToAll-equivalent baseline when sequence length is not divisible by world size #7842

@griffinstalha

Describe the bug
deepspeed.sequence.layer.DistributedAttention returns outputs that differ numerically from a reference implementation that mirrors the steps of DistributedAttention.forward: _SeqAllToAll.apply on q, k, and v, then scaled_dot_product_attention, then _SeqAllToAll.apply on the result.

This happens deterministically when:

  • Using sequence-first layout [S, B, H, D]
  • Sequence length S is not divisible by world_size
  • Running on 2 GPUs with NCCL
  • Using PyTorch SDPA as the local attention function
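To make the shapes in these conditions concrete, here is a small pure-Python sketch of the shape bookkeeping for the two all-to-all calls. It assumes the all-to-all divides the scatter dimension by world_size and multiplies the gather dimension by it; `all_to_all_shape` is a hypothetical helper for illustration, not a DeepSpeed function.

```python
def all_to_all_shape(shape, scatter_idx, gather_idx, world_size):
    """Shape after an all-to-all that splits `scatter_idx` and concatenates `gather_idx`.

    Assumption (not taken from DeepSpeed source): the scatter dim is divided by
    world_size and the gather dim is multiplied by it.
    """
    if shape[scatter_idx] % world_size != 0:
        raise ValueError(
            f"dim {scatter_idx} (size {shape[scatter_idx]}) "
            f"is not divisible by world_size={world_size}"
        )
    out = list(shape)
    out[scatter_idx] //= world_size
    out[gather_idx] *= world_size
    return out


world_size = 2
S, B, H, D = world_size * 2 + 1, 8, 8, 64  # same sizes as the testcase

local = [S, B, H, D]
# Forward direction used by the testcase: scatter heads (axis 2), gather sequence (axis 0).
gathered = all_to_all_shape(local, scatter_idx=2, gather_idx=0, world_size=world_size)
# Reverse direction applied to the attention output.
restored = all_to_all_shape(gathered, scatter_idx=0, gather_idx=2, world_size=world_size)

print(local, "->", gathered, "->", restored)
print("S % world_size =", S % world_size)
```

The helper raises instead of truncating when the scattered dimension is not divisible, which is one way a silent mismatch could be turned into an explicit error.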

The discrepancy is large:

MAX_ABS_DIFF: 0.680379
MEAN_ABS_DIFF: 0.139780
ALLCLOSE: False

This indicates silent numerical corruption inside DistributedAttention when handling uneven sequence splits.
No exception is raised.
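For reference, the three metrics above can be computed as in the sketch below. It is pure Python with made-up values, and `compare` is a hypothetical helper; the closeness rule follows the documented behaviour of torch.allclose, |a - b| <= atol + rtol * |b|.

```python
def compare(a, b, rtol=1e-4, atol=1e-4):
    """Return (max_abs_diff, mean_abs_diff, allclose) for two equal-length flat lists."""
    diffs = [abs(x - y) for x, y in zip(a, b)]
    # Same elementwise rule as torch.allclose(a, b, rtol, atol), with b as the reference.
    close = all(d <= atol + rtol * abs(y) for d, y in zip(diffs, b))
    return max(diffs), sum(diffs) / len(diffs), close


ds_out = [0.50, 0.25, 0.75]   # made-up values, for illustration only
ref_out = [0.50, 0.25, 0.10]

max_diff, mean_diff, ok = compare(ds_out, ref_out)
print("MAX_ABS_DIFF:", max_diff)
print("MEAN_ABS_DIFF:", mean_diff)
print("ALLCLOSE:", ok)
```

The testcase uses rtol=1e-4 and atol=1e-4, so a maximum difference of about 0.68 is several orders of magnitude outside the tolerance.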

To Reproduce

  1. Use 2 GPUs
  2. Install DeepSpeed (latest pip release; the ds_report output below shows 0.14.2) and PyTorch 2.2
  3. Save the testcase below as deepspeed_testcase.py
  4. Run:

PYTHONUNBUFFERED=1 torchrun --nproc_per_node=2 deepspeed_testcase.py

You will see:

MAX_ABS_DIFF: ~0.6
ALLCLOSE: False
Test Passed ✅

Note that the testcase's oracle is inverted: "Test Passed ✅" means the mismatch was reproduced.

Testcase

import os
import sys
import time


def _skip(reason: str):
    print(f"SKIP_ENV: {reason}", flush=True)
    sys.exit(0)


def _pass():
    print("Test Passed ✅", flush=True)
    sys.exit(0)


def _fail():
    print("Test Failed ❌", flush=True)
    sys.exit(0)


def _harness_error(e: BaseException):
    print(f"HARNESS_ERROR: {type(e).__name__}: {e}", flush=True)
    sys.exit(1)


def _get_int_env(name: str, default=None):
    v = os.environ.get(name, None)
    if v is None:
        return default
    try:
        return int(str(v).strip())
    except Exception:
        return default


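def _barrier_with_timeout(group, timeout_s: float = 20.0):
    """Best-effort sync point: all-reduce a dummy tensor and wait for it.

    timeout_s is kept for call compatibility but is not enforced here; a hung
    collective is bounded only by the process group's own timeout.
    """
    import torch
    import torch.distributed as dist

    x = torch.ones(1, device=torch.device("cuda"))
    try:
        dist.all_reduce(x, op=dist.ReduceOp.SUM, group=group)
        torch.cuda.synchronize()  # surface asynchronous NCCL/CUDA errors here
        return True
    except Exception:
        return False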


def main():
    try:
        try:
            import torch
        except Exception as e:
            _skip(f"missing torch: {type(e).__name__}: {e}")

        try:
            import deepspeed
            from deepspeed.sequence.layer import DistributedAttention, _SeqAllToAll
        except Exception as e:
            _skip(f"missing deepspeed/DistributedAttention: {type(e).__name__}: {e}")

        import torch.distributed as dist

        rank = _get_int_env("RANK", None)
        world_size = _get_int_env("WORLD_SIZE", None)
        local_rank = _get_int_env("LOCAL_RANK", None)

        if rank is None or world_size is None:
            _skip("not launched with distributed env (missing RANK/WORLD_SIZE); use torchrun --nproc_per_node=2")

        if world_size < 2:
            _skip(f"world_size < 2 (world_size={world_size}); need >=2 ranks")

        if not torch.cuda.is_available():
            _skip("CUDA not available (torch.cuda.is_available() is False)")
        ndev = torch.cuda.device_count()
        if ndev < 2:
            _skip(f"need >=2 CUDA GPUs; found {ndev}")
        if not dist.is_nccl_available():
            _skip("NCCL backend not available (dist.is_nccl_available() is False)")

        if local_rank is None:
            local_rank = rank
        if local_rank < 0 or local_rank >= ndev:
            _skip(f"LOCAL_RANK out of range (LOCAL_RANK={local_rank}, device_count={ndev})")

        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")

        if rank == 0:
            print("STAGE: before_deepspeed_init", flush=True)
            print(f"ENV: rank={rank} world_size={world_size} local_rank={local_rank} ndev={ndev}", flush=True)

        try:
            deepspeed.init_distributed("nccl")
        except Exception as e:
            if rank == 0:
                print(f"STAGE_FAIL: deepspeed_init: {type(e).__name__}: {e}", flush=True)
            _fail()

        if not dist.is_initialized():
            _skip("torch.distributed not initialized after deepspeed.init_distributed")

        group = dist.group.WORLD

        if rank == 0:
            print("REACHED_AFTER_INIT", flush=True)

        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        # Parameters
        B = 8
        H = 8
        S = world_size * 2 + 1
        D = 64
        dtype = torch.float32

        if rank == 0:
            print("STAGE: make_inputs", flush=True)

        # Sequence-first layout: [S, B, H, D]
        q_sbh = torch.rand((S, B, H, D), device=device, dtype=dtype)
        k_sbh = torch.rand((S, B, H, D), device=device, dtype=dtype)
        v_sbh = torch.rand((S, B, H, D), device=device, dtype=dtype)

        if rank == 0:
            print("STAGE: broadcast_inputs", flush=True)
        try:
            dist.broadcast(q_sbh, src=0, group=group)
            dist.broadcast(k_sbh, src=0, group=group)
            dist.broadcast(v_sbh, src=0, group=group)
        except Exception as e:
            if rank == 0:
                print(f"STAGE_FAIL: broadcast: {type(e).__name__}: {e}", flush=True)
            _fail()

        # Build DS distributed attention
        if rank == 0:
            print("STAGE: build_distributed_attention", flush=True)

        scatter_idx = 2
        gather_idx = 0

        try:
            dist_attn = DistributedAttention(
                torch.nn.functional.scaled_dot_product_attention,
                group,
                scatter_idx=scatter_idx,
                gather_idx=gather_idx,
            )
        except Exception as e:
            if rank == 0:
                print(f"STAGE_FAIL: ctor: {type(e).__name__}: {e}", flush=True)
            _fail()

        # Run DS
        if rank == 0:
            print("STAGE: run_distributed_attention", flush=True)
        try:
            ds_out_sbh = dist_attn(q_sbh, k_sbh, v_sbh, None, 0.0, False)
        except Exception as e:
            if rank == 0:
                print(f"STAGE_FAIL: ds_call: {type(e).__name__}: {e}", flush=True)
            _fail()

        # DS-equivalent baseline (IMPORTANT)
        # Simulate what DistributedAttention.forward does internally:
        # query_layer = all2all(q, scatter_idx, gather_idx)
        # context_layer = sdpa(query_layer, key_layer, value_layer, *args)
        # output = all2all(context_layer, gather_idx, scatter_idx)
        if rank == 0:
            print("STAGE: baseline_ds_equivalent", flush=True)
        try:
            q_layer = _SeqAllToAll.apply(group, q_sbh, scatter_idx, gather_idx)
            k_layer = _SeqAllToAll.apply(group, k_sbh, scatter_idx, gather_idx)
            v_layer = _SeqAllToAll.apply(group, v_sbh, scatter_idx, gather_idx)

            # SDPA expects [B, H, S, D]. Axis 0 is treated as the sequence axis after
            # the all-to-all, so permute [S, B, H, D] -> [B, H, S, D].
            q_bhsd = q_layer.permute(1, 2, 0, 3).contiguous()
            k_bhsd = k_layer.permute(1, 2, 0, 3).contiguous()
            v_bhsd = v_layer.permute(1, 2, 0, 3).contiguous()

            ctx_bhsd = torch.nn.functional.scaled_dot_product_attention(
                q_bhsd, k_bhsd, v_bhsd, None, 0.0, False
            )
            ctx_layer = ctx_bhsd.permute(2, 0, 1, 3).contiguous()

            ref_out_sbh = _SeqAllToAll.apply(group, ctx_layer, gather_idx, scatter_idx)
        except Exception as e:
            if rank == 0:
                print(f"STAGE_FAIL: baseline_equivalent: {type(e).__name__}: {e}", flush=True)
            _fail()

        ok_bar = _barrier_with_timeout(group, timeout_s=20.0)
        if not ok_bar:
            if rank == 0:
                print("STAGE_FAIL: barrier_timeout", flush=True)
            _fail()

        if dist.get_rank(group=group) == 0:
            print("REF out:", tuple(ref_out_sbh.shape), ref_out_sbh.dtype, ref_out_sbh.device, flush=True)
            print("DS  out:", tuple(ds_out_sbh.shape), ds_out_sbh.dtype, ds_out_sbh.device, flush=True)

            diff = (ds_out_sbh - ref_out_sbh).abs()
            print("MAX_ABS_DIFF:", float(diff.max().item()), flush=True)
            print("MEAN_ABS_DIFF:", float(diff.mean().item()), flush=True)

            ok = bool(torch.allclose(ds_out_sbh, ref_out_sbh, rtol=1e-4, atol=1e-4))
            print("ALLCLOSE:", ok, flush=True)

            # Oracle: mismatch => bug reproduced
            if not ok:
                _pass()
            else:
                _fail()
        else:
            sys.exit(0)

    except SystemExit:
        raise
    except Exception as e:
        _harness_error(e)


if __name__ == "__main__":
    main()

Expected behavior
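The reference path follows the same all-to-all, SDPA, all-to-all sequence as DistributedAttention.forward() (with an explicit permute to [B, H, S, D] around the SDPA call), so the two outputs should match within the rtol=1e-4, atol=1e-4 tolerance used by the testcase. They do not.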

ds_report output

[2026-02-10 17:34:10,457] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
 [WARNING]  using untested triton version (2.2.0), only 1.0.0 is known to be compatible
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fp_quantizer ........... [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.2
 [WARNING]  using untested triton version (2.2.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/talha/miniconda3/envs/ds_testcase/lib/python3.10/site-packages/torch']
torch version .................... 2.2.1+cu121
deepspeed install path ........... ['/home/talha/miniconda3/envs/ds_testcase/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.14.2, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 12.0
deepspeed wheel compiled w. ...... torch 0.0, cuda 0.0
shared memory (/dev/shm) size .... 62.88 GB

Screenshots
No screenshots.

System info (please complete the following information):

  • OS: Ubuntu 22.04
  • GPU count: 4× RTX 3090 (2 used for this test)
  • Launcher: torchrun
  • Python version: 3.10
  • PyTorch: 2.2 (2.2.1+cu121 per ds_report)
  • DeepSpeed: latest pip version (0.14.2 per ds_report)
  • CUDA: 12.4

Launcher context
Launched using:
torchrun --nproc_per_node=2

Docker context
No docker. Native conda environment.

Additional context
This issue resembles previously reported silent-corruption problems in sequence-parallel attention, where the math is correct but the tensor layout after the all-to-all is misinterpreted.

The inconsistency appears only when:

  • sequence length is not divisible by world size
  • using DistributedAttention with _SeqAllToAll
  • comparing against the same SDPA path DS internally uses
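As an illustration of how a layout mix-up of this kind can change results without raising an error: SDPA-style attention attends across the second-to-last axis, so whichever axis sits there determines what gets mixed. The sketch below is a pure-Python toy (no DeepSpeed, no GPU, batch axis omitted, all helper names made up for illustration). It shows only that the two interpretations give same-shaped but different outputs; it does not establish that this is the cause of the mismatch reported here.

```python
import math
import random


def attend(q, k, v):
    """Softmax attention for one [L, E] problem given as lists of lists."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * vj[d] for w, vj in zip(weights, v)) / total
            for d in range(len(v[0]))
        ])
    return out


def over_second_to_last_axis(q, k, v):
    """Treat an [S, H, D] tensor as S independent [H, D] problems (attends across H)."""
    return [attend(q[s], k[s], v[s]) for s in range(len(q))]


def over_sequence_axis(q, k, v):
    """Regroup [S, H, D] as H independent [S, D] problems (attends across S), then regroup back."""
    S, H = len(q), len(q[0])
    per_head = [
        attend([q[s][h] for s in range(S)],
               [k[s][h] for s in range(S)],
               [v[s][h] for s in range(S)])
        for h in range(H)
    ]
    return [[per_head[h][s] for h in range(H)] for s in range(S)]


rng = random.Random(0)
S, H, D = 5, 4, 8  # toy sizes


def rand_tensor():
    return [[[rng.random() for _ in range(D)] for _ in range(H)] for _ in range(S)]


q, k, v = rand_tensor(), rand_tensor(), rand_tensor()
a = over_second_to_last_axis(q, k, v)
b = over_sequence_axis(q, k, v)

max_abs_diff = max(
    abs(a[s][h][d] - b[s][h][d])
    for s in range(S) for h in range(H) for d in range(D)
)
print("same shape:", (len(a), len(a[0]), len(a[0][0])) == (len(b), len(b[0]), len(b[0][0])))
print("MAX_ABS_DIFF:", max_abs_diff)
```

Because both results have the same shape, a comparison like the one in the testcase is the only thing that reveals the difference.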

Labels: bug, training
