Skip to content

Commit c7d5689

Browse files
authored
Merge branch 'main' into fix-llava-kwargs-tests
2 parents 7c10568 + d1b3555 commit c7d5689

File tree

1 file changed

+3
-9
lines changed

1 file changed

+3
-9
lines changed

src/diffusers/modular_pipelines/modular_pipeline_utils.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -309,16 +309,16 @@ def load(self, **kwargs) -> Any:
309309
f"`type_hint` is required when loading a single file model but is missing for component: {self.name}"
310310
)
311311

312+
from diffusers import AutoModel
313+
312314
# `torch_dtype` is not an accepted parameter for tokenizers and processors.
313315
# As a result, it gets stored in `init_kwargs`, which are written to the config
314316
# during save. This causes JSON serialization to fail when saving the component.
315-
if self.type_hint is not None and not issubclass(self.type_hint, torch.nn.Module):
317+
if self.type_hint is not None and not issubclass(self.type_hint, (torch.nn.Module, AutoModel)):
316318
kwargs.pop("torch_dtype", None)
317319

318320
if self.type_hint is None:
319321
try:
320-
from diffusers import AutoModel
321-
322322
component = AutoModel.from_pretrained(pretrained_model_name_or_path, **load_kwargs, **kwargs)
323323
except Exception as e:
324324
raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
@@ -332,12 +332,6 @@ def load(self, **kwargs) -> Any:
332332
else getattr(self.type_hint, "from_pretrained")
333333
)
334334

335-
# `torch_dtype` is not an accepted parameter for tokenizers and processors.
336-
# As a result, it gets stored in `init_kwargs`, which are written to the config
337-
# during save. This causes JSON serialization to fail when saving the component.
338-
if not issubclass(self.type_hint, torch.nn.Module):
339-
kwargs.pop("torch_dtype", None)
340-
341335
try:
342336
component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs)
343337
except Exception as e:

0 commit comments

Comments (0)