Skip to content

Commit b1df740

Browse files
committed
update
1 parent 8d20369 commit b1df740

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

tests/models/test_models_auto.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import json
2+
import os
3+
import tempfile
14
import unittest
25
from unittest.mock import MagicMock, patch
36

7+
import torch
48
from transformers import CLIPTextModel, LongformerModel
59

610
from diffusers.models import AutoModel, UNet2DConditionModel
@@ -35,6 +39,45 @@ def test_load_from_model_index(self):
3539
)
3640
assert isinstance(model, CLIPTextModel)
3741

42+
def test_load_dynamic_module_from_local_path_with_subfolder(self):
43+
CUSTOM_MODEL_CODE = (
44+
"import torch\n"
45+
"from diffusers import ModelMixin, ConfigMixin\n"
46+
"from diffusers.configuration_utils import register_to_config\n"
47+
"\n"
48+
"class CustomModel(ModelMixin, ConfigMixin):\n"
49+
" @register_to_config\n"
50+
" def __init__(self, hidden_size=8):\n"
51+
" super().__init__()\n"
52+
" self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
53+
"\n"
54+
" def forward(self, x):\n"
55+
" return self.linear(x)\n"
56+
)
57+
58+
with tempfile.TemporaryDirectory() as tmpdir:
59+
subfolder = "custom_model"
60+
model_dir = os.path.join(tmpdir, subfolder)
61+
os.makedirs(model_dir)
62+
63+
with open(os.path.join(model_dir, "modeling.py"), "w") as f:
64+
f.write(CUSTOM_MODEL_CODE)
65+
66+
config = {
67+
"_class_name": "CustomModel",
68+
"_diffusers_version": "0.0.0",
69+
"auto_map": {"AutoModel": "modeling.CustomModel"},
70+
"hidden_size": 8,
71+
}
72+
with open(os.path.join(model_dir, "config.json"), "w") as f:
73+
json.dump(config, f)
74+
75+
torch.save({}, os.path.join(model_dir, "diffusion_pytorch_model.bin"))
76+
77+
model = AutoModel.from_pretrained(tmpdir, subfolder=subfolder, trust_remote_code=True)
78+
assert model.__class__.__name__ == "CustomModel"
79+
assert model.config["hidden_size"] == 8
80+
3881

3982
class TestAutoModelFromConfig(unittest.TestCase):
4083
@patch(

0 commit comments

Comments
 (0)