Description
Hi,
Cross-filing pytorch/pytorch#153476 here, because the issue may be better handled on the einops side.
einops was supported by `torch.compile` in torch 2.6: Dynamo could capture einops calls into the graph. This appears to be broken in torch 2.7. Minimal repro (`rearr.py`):
```python
import torch
import einops

print(f"PyTorch version: {torch.__version__}")

@torch.compile(fullgraph=True)
def fn(x):
    return einops.rearrange(x, 'b c -> c b')

x = torch.randn(2, 3)
fn(x)
```

With torch 2.6:
```
$ python rearr.py
PyTorch version: 2.6.0+cu124
```
With torch 2.7:

```
$ python rearr.py
PyTorch version: 2.7.0+cu126
Traceback (most recent call last):
  File "/tmp/rearr.py", line 11, in <module>
    fn(x)
  File "/home/vadim/venv_torchpure2.7/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 659, in _fn
    raise e.with_traceback(None) from None
torch._dynamo.exc.Unsupported: Unsupported method call
  Explanation: Dynamo does not know how to trace method `symmetric_difference` of class `type`
  Hint: Avoid calling `type.symmetric_difference` in your code.
  Hint: Please report an issue to PyTorch.

  Developer debug context: call_method BuiltinVariable(set) symmetric_difference [SetVariable(), SetVariable()] {}

from user code:
  File "/tmp/rearr.py", line 8, in fn
    return einops.rearrange(x, 'b c -> c b')
  File "/home/vadim/venv_torchpure2.7/lib/python3.10/site-packages/einops/einops.py", line 600, in rearrange
    return reduce(tensor, pattern, reduction="rearrange", **axes_lengths)
  File "/home/vadim/venv_torchpure2.7/lib/python3.10/site-packages/einops/einops.py", line 531, in reduce
    recipe = _prepare_transformation_recipe(pattern, reduction, axes_names=tuple(axes_lengths), ndim=len(shape))
  File "/home/vadim/venv_torchpure2.7/lib/python3.10/site-packages/torch/_dynamo/polyfills/__init__.py", line 140, in getattr_and_trace
    return fn(*args[2:], **kwargs)
  File "/home/vadim/venv_torchpure2.7/lib/python3.10/site-packages/einops/einops.py", line 311, in _prepare_transformation_recipe
    difference = set.symmetric_difference(left.identifiers, rght.identifiers)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```
Dynamo traces into `einops.rearrange` and fails on the unsupported `set.symmetric_difference` call.
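For reference, the call Dynamo rejects is an ordinary set comparison. The plain-Python sketch below uses hand-written identifier sets as stand-ins for the parsed sides of `'b c -> c b'` (they are not einops objects). Calling the method through the class, as einops does, is equivalent to the bound-method and `^` forms; the class `set` is itself an instance of `type`, which is presumably why the error message says "of class `type`". As far as the einops source shows, a non-empty result means some axis appears on only one side of a `rearrange` pattern.

```python
# Stand-ins for the identifiers on each side of 'b c -> c b'.
left_identifiers = {"b", "c"}
right_identifiers = {"c", "b"}

# Unbound-method form, as written in einops: the method is looked up on the
# class `set`, which is an instance of `type`.
difference = set.symmetric_difference(left_identifiers, right_identifiers)

# Equivalent bound-method and operator forms.
same_as_bound = left_identifiers.symmetric_difference(right_identifiers)
same_as_operator = left_identifiers ^ right_identifiers

print(difference)  # set(): every axis appears on both sides

# An axis present on only one side gives a non-empty difference.
bad = set.symmetric_difference({"b", "c"}, {"c"})
print(bad)  # {'b'}
```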
On the PyTorch side, the cause appears to be commit pytorch/pytorch@270ad51.
That commit limited `import einops._torch_specific` to einops versions < 0.7.0, with the comment `# version > 0.7.0 does allow_in_graph out of tree`. However, the out-of-tree registration does not seem to happen in every case.
`einops._torch_specific.allow_ops_in_compiled_graph()` is what calls `allow_in_graph` for the einops operations. With only `import einops`, neither is `allow_ops_in_compiled_graph()` called nor is `einops._torch_specific` imported, so `allow_in_graph` is never called and Dynamo traces into the einops implementation instead.
PS: for anyone who runs into this issue and wants a quick workaround, add `import einops._torch_specific` after `import einops`. That fixed the issue.