Skip to content

[Pytorch] index_copy_ silently produces NaN #2896

@anatoliylitv

Description

@anatoliylitv

https://ontrack-internal.amd.com/browse/SWDEV-573529

Bug Summary:
torch.Tensor.index_copy_ produces NaN values silently on ROCm/HIP devices
under specific conditions involving bfloat16 tensors and certain indexing patterns.
This does NOT raise any exception - the operation appears to succeed but the
resulting tensor contains NaN values.

The bug was discovered through failing CI tests in HuggingFace Transformers:

  • test_model_4b_crops
  • test_model_4b_batch_crops
  • test_model_4b_multiimage

Affected Configuration:
- PyTorch with ROCm/HIP backend
- bfloat16 dtype tensors
- Static cache pattern (pre-allocated tensor with index_copy_ updates)
- Longer sequence lengths (100+ positions)

Impact:
This bug causes NaN logits in transformer model generation when using static cache
with multimodal inputs (e.g., Gemma3 with images). The model generates only 1-2 tokens
before stopping due to NaN values in the attention computation.

Workaround:
Use slice assignment (tensor[:, :, indices] = values) instead of index_copy_()

Run:
python rocm_index_copy_issue.py

Expected Output (CUDA/CPU):
All tests should show "has NaN: False" and "Results match: True"

Actual Output (ROCm - BUG):
Tests with longer sequences show "has NaN: True" or "Results match: False"

Reproducer

#!/usr/bin/env python3
"""
ROCm index_copy_ Silent NaN Bug - Minimal Reproduction
"""

import sys

def print_system_info():
    """Print system and PyTorch configuration.

    Reports Python/PyTorch versions, CUDA/HIP availability, and the device
    names, then states whether we are running on ROCm/HIP, CUDA, or CPU.
    Returns None; output goes to stdout.
    """
    import torch

    print("=" * 70)
    print("SYSTEM INFORMATION")
    print("=" * 70)
    print(f"Python version:  {sys.version.split()[0]}")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available:  {torch.cuda.is_available()}")
    # On ROCm builds torch.version.cuda is None and torch.version.hip is set.
    print(f"CUDA version:    {torch.version.cuda}")
    print(f"HIP version:     {torch.version.hip}")

    if torch.cuda.is_available():
        print(f"Device count:    {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            print(f"Device {i}:        {torch.cuda.get_device_name(i)}")

    # torch.version.hip is non-None exactly when this is a ROCm/HIP build.
    is_rocm = torch.version.hip is not None
    print(f"\nRunning on:      {'ROCm/HIP' if is_rocm else 'CUDA' if torch.cuda.is_available() else 'CPU'}")
    print("=" * 70)
    print()

def test_index_copy_nan(device: str, seq_len: int, dtype=None):
"""
Test if index_copy_ produces NaN values.
"""
import torch

if dtype is None:
    dtype = torch.bfloat16

# Simulate KV cache dimensions: [batch_size, num_heads, max_cache_len, head_dim]
batch_size = 1
num_heads = 8
max_cache_len = 256
head_dim = 64

# Create two identical static cache tensors
cache_index_copy = torch.zeros(
    (batch_size, num_heads, max_cache_len, head_dim),
    dtype=dtype, device=device
)
cache_slice = torch.zeros(
    (batch_size, num_heads, max_cache_len, head_dim),
    dtype=dtype, device=device
)

# Create cache positions and key states
cache_position = torch.arange(seq_len, device=device)
key_states = torch.randn(
    (batch_size, num_heads, seq_len, head_dim),
    dtype=dtype, device=device
)

# Method 1: index_copy_ (buggy on ROCm)
cache_index_copy.index_copy_(2, cache_position, key_states)

# Method 2: slice assignment (works correctly)
cache_slice[:, :, cache_position] = key_states

# Check results
has_nan_index_copy = torch.isnan(cache_index_copy).any().item()
has_nan_slice = torch.isnan(cache_slice).any().item()
results_match = torch.allclose(cache_index_copy, cache_slice, equal_nan=True)

return has_nan_index_copy, has_nan_slice, results_match

def test_index_copy_with_compile(device: str, seq_len: int):
"""
Test index_copy_ with torch.compile (static cache typically uses compilation).
"""
import torch

batch_size = 1
num_heads = 8
max_cache_len = 256
head_dim = 64

class CacheModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.cache = torch.zeros(
            (batch_size, num_heads, max_cache_len, head_dim),
            dtype=torch.bfloat16
        )

    def forward(self, key_states, cache_position):
        self.cache.index_copy_(2, cache_position, key_states)
        return self.cache.clone()

cache_position = torch.arange(seq_len, device=device)
key_states = torch.randn(
    (batch_size, num_heads, seq_len, head_dim),
    dtype=torch.bfloat16, device=device
)

# Test eager mode
model_eager = CacheModule().to(device)
result_eager = model_eager(key_states, cache_position)
has_nan_eager = torch.isnan(result_eager).any().item()

# Test compiled mode
try:
    model_compiled = torch.compile(CacheModule().to(device))
    result_compiled = model_compiled(key_states, cache_position)
    has_nan_compiled = torch.isnan(result_compiled).any().item()
    success = True
except Exception as e:
    has_nan_compiled = None
    success = False
    print(f"    Compilation failed: {e}")

return has_nan_eager, has_nan_compiled, success

def run_all_tests():
    """Run comprehensive test suite.

    Executes three test groups: index_copy_ vs slice assignment across
    sequence lengths, across dtypes, and (CUDA/ROCm only) under
    torch.compile. Prints a formatted report and a summary.

    Returns:
        bool: True if any NaN or mismatch was detected, False otherwise.
    """
    import torch

    print_system_info()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("WARNING: No GPU detected. This test is designed for ROCm/CUDA devices.")
        print("         Results on CPU may not reproduce the ROCm-specific bug.")

    # Test 1: Basic index_copy_ with different sequence lengths
    print("TEST 1: index_copy_ vs slice assignment (bfloat16)")
    print("-" * 70)
    print(f"{'Seq Len':<10} {'index_copy_ NaN':<18} {'slice NaN':<15} {'Match':<10}")
    print("-" * 70)

    test_seq_lengths = [10, 50, 100, 150, 200]
    any_failure = False

    for seq_len in test_seq_lengths:
        nan_ic, nan_slice, match = test_index_copy_nan(device, seq_len)
        status = "FAIL" if nan_ic or not match else "OK"
        if nan_ic or not match:
            any_failure = True
        print(f"{seq_len:<10} {str(nan_ic):<18} {str(nan_slice):<15} {str(match):<10} {status}")

    # Test 2: Different dtypes
    print("TEST 2: index_copy_ with different dtypes (seq_len=100)")
    print("-" * 70)
    print(f"{'Dtype':<15} {'index_copy_ NaN':<18} {'slice NaN':<15} {'Match':<10}")
    print("-" * 70)

    for dtype in [torch.float32, torch.float16, torch.bfloat16]:
        nan_ic, nan_slice, match = test_index_copy_nan(device, 100, dtype)
        status = "FAIL" if nan_ic or not match else "OK"
        if nan_ic or not match:
            any_failure = True
        print(f"{str(dtype):<15} {str(nan_ic):<18} {str(nan_slice):<15} {str(match):<10} {status}")

    # Test 3: With torch.compile (only meaningful on a GPU device)
    if device == "cuda":
        print("TEST 3: index_copy_ with torch.compile")
        print("-" * 70)
        print(f"{'Seq Len':<10} {'Eager NaN':<15} {'Compiled NaN':<15} {'Status':<10}")
        print("-" * 70)

        for seq_len in [50, 100, 150]:
            nan_eager, nan_compiled, success = test_index_copy_with_compile(device, seq_len)
            if success:
                status = "FAIL" if nan_eager or nan_compiled else "OK"
                if nan_eager or nan_compiled:
                    any_failure = True
                print(f"{seq_len:<10} {str(nan_eager):<15} {str(nan_compiled):<15} {status}")
            else:
                print(f"{seq_len:<10} {str(nan_eager):<15} {'N/A':<15} SKIP")

        print()

    # Summary
    print("=" * 70)
    print("SUMMARY")
    print("=" * 70)

    if any_failure:
        print("RESULT: FAILURES DETECTED")
        print("The index_copy_ operation is producing NaN values or incorrect results.")
        print("This confirms the bug on this system.")
    else:
        print("RESULT: ALL TESTS PASSED")
        print("No NaN values detected. The bug may not affect this system/configuration.")

    print("=" * 70)

    return any_failure

def main():
    """Main entry point.

    Verifies PyTorch is importable, runs the test suite, and exits with
    status 1 if PyTorch is missing or any failure was detected, else 0.
    """
    try:
        import torch
    except ImportError:
        print("ERROR: PyTorch is not installed")
        sys.exit(1)

    has_failures = run_all_tests()
    sys.exit(1 if has_failures else 0)

# NOTE: the issue text rendered the dunders away ("if name == "main":");
# the correct script-entry guard is restored here.
if __name__ == "__main__":
    main()

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions