File tree Expand file tree Collapse file tree 1 file changed +14
-3
lines changed
src/diffusers/pipelines/prx Expand file tree Collapse file tree 1 file changed +14
-3
lines changed Original file line number Diff line number Diff line change 2424else :
2525 _import_structure ["pipeline_prx" ] = ["PRXPipeline" ]
2626
27- # Import T5GemmaEncoder for pipeline loading compatibility
27+ # Wrap T5GemmaEncoder to pass config.encoder (T5GemmaModuleConfig) instead of the
28+ # composite T5GemmaConfig, which lacks flat attributes expected by T5GemmaEncoder.__init__.
2829try :
2930 if is_transformers_available ():
3031 import transformers
31- from transformers .models .t5gemma .modeling_t5gemma import T5GemmaEncoder
32+ from transformers .models .t5gemma .modeling_t5gemma import T5GemmaEncoder as _T5GemmaEncoder
33+
34+ class T5GemmaEncoder (_T5GemmaEncoder ):
35+ @classmethod
36+ def from_pretrained (cls , pretrained_model_name_or_path , * args , ** kwargs ):
37+ if "config" not in kwargs :
38+ from transformers .models .t5gemma .configuration_t5gemma import T5GemmaConfig
39+
40+ config = T5GemmaConfig .from_pretrained (pretrained_model_name_or_path )
41+ if hasattr (config , "encoder" ):
42+ kwargs ["config" ] = config .encoder
43+ return super ().from_pretrained (pretrained_model_name_or_path , * args , ** kwargs )
3244
3345 _additional_imports ["T5GemmaEncoder" ] = T5GemmaEncoder
34- # Patch transformers module directly for serialization
3546 if not hasattr (transformers , "T5GemmaEncoder" ):
3647 transformers .T5GemmaEncoder = T5GemmaEncoder
3748except ImportError :
You can’t perform that action at this time.
0 commit comments