Skip to content

Commit 76c45cb

Browse files
authored
Merge branch 'main' into transformers-v5-pr
2 parents b1034ae + f1e5914 commit 76c45cb

File tree

1 file changed: +14 additions, −3 deletions

src/diffusers/pipelines/prx/__init__.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,14 +24,25 @@
2424
else:
2525
_import_structure["pipeline_prx"] = ["PRXPipeline"]
2626

27-
# Import T5GemmaEncoder for pipeline loading compatibility
27+
# Wrap T5GemmaEncoder to pass config.encoder (T5GemmaModuleConfig) instead of the
28+
# composite T5GemmaConfig, which lacks flat attributes expected by T5GemmaEncoder.__init__.
2829
try:
2930
if is_transformers_available():
3031
import transformers
31-
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
32+
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder as _T5GemmaEncoder
33+
34+
class T5GemmaEncoder(_T5GemmaEncoder):
    """Drop-in ``T5GemmaEncoder`` whose ``from_pretrained`` swaps in the encoder sub-config.

    The composite ``T5GemmaConfig`` lacks the flat attributes that
    ``T5GemmaEncoder.__init__`` expects, so when the caller supplies no explicit
    ``config`` we load the full config and hand its ``encoder`` sub-config to the
    base class instead.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        # Only intervene when the caller has not chosen a config explicitly.
        if "config" not in kwargs:
            from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig

            loaded = T5GemmaConfig.from_pretrained(pretrained_model_name_or_path)
            # Composite configs expose the encoder half under `.encoder`;
            # fall through untouched if this config has no such attribute.
            if hasattr(loaded, "encoder"):
                kwargs["config"] = loaded.encoder
        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
3244

3345
_additional_imports["T5GemmaEncoder"] = T5GemmaEncoder
34-
# Patch transformers module directly for serialization
3546
if not hasattr(transformers, "T5GemmaEncoder"):
3647
transformers.T5GemmaEncoder = T5GemmaEncoder
3748
except ImportError:

0 commit comments

Comments (0)