Skip to content

Commit 9b97932

Browse files
authored
[tests] consistency tests for modular index (#13192)
* add a test to check modular index consistency * check for compulsory keys.
1 parent 680076f commit 9b97932

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

tests/modular_pipelines/test_modular_pipelines_common.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import gc
2+
import json
3+
import os
24
import tempfile
35
from typing import Callable
46

@@ -349,6 +351,33 @@ def test_save_from_pretrained(self):
349351

350352
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
351353

354+
def test_modular_index_consistency(self):
355+
pipe = self.get_pipeline()
356+
components_spec = pipe._component_specs
357+
components = sorted(components_spec.keys())
358+
359+
with tempfile.TemporaryDirectory() as tmpdir:
360+
pipe.save_pretrained(tmpdir)
361+
index_file = os.path.join(tmpdir, "modular_model_index.json")
362+
assert os.path.exists(index_file)
363+
364+
with open(index_file) as f:
365+
index_contents = json.load(f)
366+
367+
compulsory_keys = {"_blocks_class_name", "_class_name", "_diffusers_version"}
368+
for k in compulsory_keys:
369+
assert k in index_contents
370+
371+
to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"}
372+
for component in components:
373+
spec = components_spec[component]
374+
for attr in to_check_attrs:
375+
if getattr(spec, "pretrained_model_name_or_path", None) is not None:
376+
for attr in to_check_attrs:
377+
assert component in index_contents, f"{component} should be present in index but isn't."
378+
attr_value_from_index = index_contents[component][2][attr]
379+
assert getattr(spec, attr) == attr_value_from_index
380+
352381
def test_workflow_map(self):
353382
blocks = self.pipeline_blocks_class()
354383
if blocks._workflow_map is None:

0 commit comments

Comments
 (0)