@@ -309,16 +309,16 @@ def load(self, **kwargs) -> Any:
309309 f"`type_hint` is required when loading a single file model but is missing for component: { self .name } "
310310 )
311311
312+ from diffusers import AutoModel
313+
312314 # `torch_dtype` is not an accepted parameter for tokenizers and processors.
313315 # As a result, it gets stored in `init_kwargs`, which are written to the config
314316 # during save. This causes JSON serialization to fail when saving the component.
315- if self .type_hint is not None and not issubclass (self .type_hint , torch .nn .Module ):
317+ if self .type_hint is not None and not issubclass (self .type_hint , ( torch .nn .Module , AutoModel ) ):
316318 kwargs .pop ("torch_dtype" , None )
317319
318320 if self .type_hint is None :
319321 try :
320- from diffusers import AutoModel
321-
322322 component = AutoModel .from_pretrained (pretrained_model_name_or_path , ** load_kwargs , ** kwargs )
323323 except Exception as e :
324324 raise ValueError (f"Unable to load { self .name } without `type_hint`: { e } " )
@@ -332,12 +332,6 @@ def load(self, **kwargs) -> Any:
332332 else getattr (self .type_hint , "from_pretrained" )
333333 )
334334
335- # `torch_dtype` is not an accepted parameter for tokenizers and processors.
336- # As a result, it gets stored in `init_kwargs`, which are written to the config
337- # during save. This causes JSON serialization to fail when saving the component.
338- if not issubclass (self .type_hint , torch .nn .Module ):
339- kwargs .pop ("torch_dtype" , None )
340-
341335 try :
342336 component = load_method (pretrained_model_name_or_path , ** load_kwargs , ** kwargs )
343337 except Exception as e :
0 commit comments