Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions src/diffusers/modular_pipelines/modular_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,16 +309,16 @@ def load(self, **kwargs) -> Any:
f"`type_hint` is required when loading a single file model but is missing for component: {self.name}"
)

from diffusers import AutoModel

# `torch_dtype` is not an accepted parameter for tokenizers and processors.
# As a result, it gets stored in `init_kwargs`, which are written to the config
# during save. This causes JSON serialization to fail when saving the component.
if self.type_hint is not None and not issubclass(self.type_hint, torch.nn.Module):
if self.type_hint is not None and not issubclass(self.type_hint, (torch.nn.Module, AutoModel)):
kwargs.pop("torch_dtype", None)

if self.type_hint is None:
try:
from diffusers import AutoModel
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved this up because now we need to import AutoModel regardless of whether type_hint is provided or not.


component = AutoModel.from_pretrained(pretrained_model_name_or_path, **load_kwargs, **kwargs)
except Exception as e:
raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
Expand All @@ -332,12 +332,6 @@ def load(self, **kwargs) -> Any:
else getattr(self.type_hint, "from_pretrained")
)

# `torch_dtype` is not an accepted parameter for tokenizers and processors.
# As a result, it gets stored in `init_kwargs`, which are written to the config
# during save. This causes JSON serialization to fail when saving the component.
if not issubclass(self.type_hint, torch.nn.Module):
kwargs.pop("torch_dtype", None)

try:
component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs)
except Exception as e:
Expand Down
Loading