-
Notifications
You must be signed in to change notification settings - Fork 76
Description
https://ontrack-internal.amd.com/browse/SWDEV-573529
Bug Summary:
torch.Tensor.index_copy_ produces NaN values silently on ROCm/HIP devices
under specific conditions involving bfloat16 tensors and certain indexing patterns.
This does NOT raise any exception - the operation appears to succeed but the
resulting tensor contains NaN values.
The bug was discovered through failing CI tests in HuggingFace Transformers:
test_model_4b_crops, test_model_4b_batch_crops, test_model_4b_multiimage
Affected Configuration:
- PyTorch with ROCm/HIP backend
- bfloat16 dtype tensors
- Static cache pattern (pre-allocated tensor with index_copy_ updates)
- Longer sequence lengths (100+ positions)
Impact:
This bug causes NaN logits in transformer model generation when using static cache
with multimodal inputs (e.g., Gemma3 with images). The model generates only 1-2 tokens
before stopping due to NaN values in the attention computation.
Workaround:
Use slice assignment (tensor[:, :, indices] = values) instead of index_copy_()
Run:
python rocm_index_copy_issue.py
Expected Output (CUDA/CPU):
All tests should show "has NaN: False" and "Results match: True"
Actual Output (ROCm - BUG):
Tests with longer sequences show "has NaN: True" or "Results match: False"
Reproducer
#!/usr/bin/env python3
"""
ROCm index_copy_ Silent NaN Bug - Minimal Reproduction
"""
import sys
def print_system_info():
    """Print Python, PyTorch, and accelerator configuration for the bug report."""
    import torch
    banner = "=" * 70
    print(banner)
    print("SYSTEM INFORMATION")
    print(banner)
    print(f"Python version: {sys.version.split()[0]}")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"HIP version: {torch.version.hip}")
    if torch.cuda.is_available():
        n_devices = torch.cuda.device_count()
        print(f"Device count: {n_devices}")
        for idx in range(n_devices):
            print(f"Device {idx}: {torch.cuda.get_device_name(idx)}")
    # A non-None HIP version is the marker that this wheel targets ROCm.
    is_rocm = torch.version.hip is not None
    print(f"\nRunning on: {'ROCm/HIP' if is_rocm else 'CUDA' if torch.cuda.is_available() else 'CPU'}")
    print(banner)
    print()
def test_index_copy_nan(device: str, seq_len: int, dtype=None):
"""
Test if index_copy_ produces NaN values.
"""
import torch
if dtype is None:
dtype = torch.bfloat16
# Simulate KV cache dimensions: [batch_size, num_heads, max_cache_len, head_dim]
batch_size = 1
num_heads = 8
max_cache_len = 256
head_dim = 64
# Create two identical static cache tensors
cache_index_copy = torch.zeros(
(batch_size, num_heads, max_cache_len, head_dim),
dtype=dtype, device=device
)
cache_slice = torch.zeros(
(batch_size, num_heads, max_cache_len, head_dim),
dtype=dtype, device=device
)
# Create cache positions and key states
cache_position = torch.arange(seq_len, device=device)
key_states = torch.randn(
(batch_size, num_heads, seq_len, head_dim),
dtype=dtype, device=device
)
# Method 1: index_copy_ (buggy on ROCm)
cache_index_copy.index_copy_(2, cache_position, key_states)
# Method 2: slice assignment (works correctly)
cache_slice[:, :, cache_position] = key_states
# Check results
has_nan_index_copy = torch.isnan(cache_index_copy).any().item()
has_nan_slice = torch.isnan(cache_slice).any().item()
results_match = torch.allclose(cache_index_copy, cache_slice, equal_nan=True)
return has_nan_index_copy, has_nan_slice, results_match
def test_index_copy_with_compile(device: str, seq_len: int):
    """
    Test index_copy_ with torch.compile (static cache typically uses compilation).

    Args:
        device: torch device string ("cuda" or "cpu").
        seq_len: number of cache positions to fill.

    Returns:
        (has_nan_eager, has_nan_compiled, success) — has_nan_compiled is
        None and success is False when torch.compile raises.
    """
    import torch
    batch_size = 1
    num_heads = 8
    max_cache_len = 256
    head_dim = 64
    class CacheModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # BUG FIX: the original assigned the cache as a plain attribute,
            # so Module.to(device) never moved it off the CPU and any GPU run
            # failed with a device mismatch before the NaN bug could show.
            # register_buffer makes .to()/.cuda() carry the cache along.
            self.register_buffer(
                "cache",
                torch.zeros(
                    (batch_size, num_heads, max_cache_len, head_dim),
                    dtype=torch.bfloat16
                ),
            )
        def forward(self, key_states, cache_position):
            # In-place write into the static cache, then snapshot it.
            self.cache.index_copy_(2, cache_position, key_states)
            return self.cache.clone()
    cache_position = torch.arange(seq_len, device=device)
    key_states = torch.randn(
        (batch_size, num_heads, seq_len, head_dim),
        dtype=torch.bfloat16, device=device
    )
    # Test eager mode
    model_eager = CacheModule().to(device)
    result_eager = model_eager(key_states, cache_position)
    has_nan_eager = torch.isnan(result_eager).any().item()
    # Test compiled mode; compilation can legitimately fail on some stacks,
    # so report that separately instead of crashing the suite.
    try:
        model_compiled = torch.compile(CacheModule().to(device))
        result_compiled = model_compiled(key_states, cache_position)
        has_nan_compiled = torch.isnan(result_compiled).any().item()
        success = True
    except Exception as e:
        has_nan_compiled = None
        success = False
        print(f"    Compilation failed: {e}")
    return has_nan_eager, has_nan_compiled, success
def run_all_tests():
    """Run comprehensive test suite.

    Exercises index_copy_ across sequence lengths, dtypes, and (GPU only)
    torch.compile, printing a result table per configuration.

    Returns:
        bool: True if any configuration produced NaNs or a mismatch
        between index_copy_ and slice assignment, False otherwise.
    """
    import torch
    print_system_info()
    # Prefer the GPU; the reported bug is ROCm/HIP-specific, so a CPU run
    # is only informational.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("WARNING: No GPU detected. This test is designed for ROCm/CUDA devices.")
        print(" Results on CPU may not reproduce the ROCm-specific bug.")
    # Test 1: Basic index_copy_ with different sequence lengths
    print("TEST 1: index_copy_ vs slice assignment (bfloat16)")
    print("-" * 70)
    print(f"{'Seq Len':<10} {'index_copy_ NaN':<18} {'slice NaN':<15} {'Match':<10}")
    print("-" * 70)
    # Longer sequences (100+) are where the failure was observed upstream.
    test_seq_lengths = [10, 50, 100, 150, 200]
    any_failure = False
    for seq_len in test_seq_lengths:
        nan_ic, nan_slice, match = test_index_copy_nan(device, seq_len)
        status = "FAIL" if nan_ic or not match else "OK"
        if nan_ic or not match:
            any_failure = True
        print(f"{seq_len:<10} {str(nan_ic):<18} {str(nan_slice):<15} {str(match):<10} {status}")
    # Test 2: Different dtypes at a fixed sequence length.
    print("TEST 2: index_copy_ with different dtypes (seq_len=100)")
    print("-" * 70)
    print(f"{'Dtype':<15} {'index_copy_ NaN':<18} {'slice NaN':<15} {'Match':<10}")
    print("-" * 70)
    for dtype in [torch.float32, torch.float16, torch.bfloat16]:
        nan_ic, nan_slice, match = test_index_copy_nan(device, 100, dtype)
        status = "FAIL" if nan_ic or not match else "OK"
        if nan_ic or not match:
            any_failure = True
        print(f"{str(dtype):<15} {str(nan_ic):<18} {str(nan_slice):<15} {str(match):<10} {status}")
    # Test 3: With torch.compile — GPU only, since the static-cache path
    # the bug report targets runs compiled on an accelerator.
    if device == "cuda":
        print("TEST 3: index_copy_ with torch.compile")
        print("-" * 70)
        print(f"{'Seq Len':<10} {'Eager NaN':<15} {'Compiled NaN':<15} {'Status':<10}")
        print("-" * 70)
        for seq_len in [50, 100, 150]:
            nan_eager, nan_compiled, success = test_index_copy_with_compile(device, seq_len)
            if success:
                status = "FAIL" if nan_eager or nan_compiled else "OK"
                if nan_eager or nan_compiled:
                    any_failure = True
                print(f"{seq_len:<10} {str(nan_eager):<15} {str(nan_compiled):<15} {status}")
            else:
                # Compilation failure is reported as SKIP, not counted as
                # the NaN bug itself.
                print(f"{seq_len:<10} {str(nan_eager):<15} {'N/A':<15} SKIP")
    print()
    # Summary
    print("=" * 70)
    print("SUMMARY")
    print("=" * 70)
    if any_failure:
        print("RESULT: FAILURES DETECTED")
        print("The index_copy_ operation is producing NaN values or incorrect results.")
        print("This confirms the bug on this system.")
    else:
        print("RESULT: ALL TESTS PASSED")
        print("No NaN values detected. The bug may not affect this system/configuration.")
    print("=" * 70)
    return any_failure
def main():
    """Entry point: verify PyTorch is importable, run the suite, set exit code."""
    try:
        import torch  # presence check only; the test helpers re-import it
    except ImportError:
        print("ERROR: PyTorch is not installed")
        sys.exit(1)
    # Non-zero exit when the bug reproduced, so CI can gate on this script.
    exit_code = 1 if run_all_tests() else 0
    sys.exit(exit_code)
# BUG FIX: the guard was written as `if name == "main":`, which raises
# NameError at import time (the dunder underscores were lost in transit).
if __name__ == "__main__":
    main()