Skip to content

Commit 09abb9d

Browse files
Fix loading of Kohya's Flux.2 dev LoRA
1 parent 227d90a commit 09abb9d

File tree

2 files changed

+69
-20
lines changed

2 files changed

+69
-20
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 59 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,10 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
357357

358358
# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
359359
# are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
360-
def _convert_kohya_flux_lora_to_diffusers(state_dict):
360+
def _convert_kohya_flux_lora_to_diffusers(
361+
state_dict,
362+
version_flux2 = False,
363+
):
361364
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
362365
if sds_key + ".lora_down.weight" not in sds_sd:
363366
return
@@ -448,7 +451,15 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
448451

449452
def _convert_sd_scripts_to_ai_toolkit(sds_sd):
450453
ait_sd = {}
451-
for i in range(19):
454+
455+
max_num_double_blocks, max_num_single_blocks = -1, -1
456+
for key in list(sds_sd.keys()):
457+
if key.startswith("lora_unet_double_blocks_"):
458+
max_num_double_blocks = max(max_num_double_blocks, int(key.split("_")[4]))
459+
if key.startswith("lora_unet_single_blocks_"):
460+
max_num_single_blocks = max(max_num_single_blocks, int(key.split("_")[4]))
461+
462+
for i in range(max_num_double_blocks+1):
452463
_convert_to_ai_toolkit(
453464
sds_sd,
454465
ait_sd,
@@ -469,13 +480,21 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
469480
sds_sd,
470481
ait_sd,
471482
f"lora_unet_double_blocks_{i}_img_mlp_0",
472-
f"transformer.transformer_blocks.{i}.ff.net.0.proj",
483+
(
484+
f"transformer.transformer_blocks.{i}.ff.linear_in"
485+
if version_flux2 else
486+
f"transformer.transformer_blocks.{i}.ff.net.0.proj"
487+
),
473488
)
474489
_convert_to_ai_toolkit(
475490
sds_sd,
476491
ait_sd,
477492
f"lora_unet_double_blocks_{i}_img_mlp_2",
478-
f"transformer.transformer_blocks.{i}.ff.net.2",
493+
(
494+
f"transformer.transformer_blocks.{i}.ff.linear_out"
495+
if version_flux2 else
496+
f"transformer.transformer_blocks.{i}.ff.net.2"
497+
),
479498
)
480499
_convert_to_ai_toolkit(
481500
sds_sd,
@@ -503,13 +522,21 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
503522
sds_sd,
504523
ait_sd,
505524
f"lora_unet_double_blocks_{i}_txt_mlp_0",
506-
f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
525+
(
526+
f"transformer.transformer_blocks.{i}.ff_context.linear_in"
527+
if version_flux2 else
528+
f"transformer.transformer_blocks.{i}.ff_context.net.0.proj"
529+
),
507530
)
508531
_convert_to_ai_toolkit(
509532
sds_sd,
510533
ait_sd,
511534
f"lora_unet_double_blocks_{i}_txt_mlp_2",
512-
f"transformer.transformer_blocks.{i}.ff_context.net.2",
535+
(
536+
f"transformer.transformer_blocks.{i}.ff_context.linear_out"
537+
if version_flux2 else
538+
f"transformer.transformer_blocks.{i}.ff_context.net.2"
539+
),
513540
)
514541
_convert_to_ai_toolkit(
515542
sds_sd,
@@ -518,24 +545,36 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
518545
f"transformer.transformer_blocks.{i}.norm1_context.linear",
519546
)
520547

521-
for i in range(38):
522-
_convert_to_ai_toolkit_cat(
523-
sds_sd,
524-
ait_sd,
525-
f"lora_unet_single_blocks_{i}_linear1",
526-
[
527-
f"transformer.single_transformer_blocks.{i}.attn.to_q",
528-
f"transformer.single_transformer_blocks.{i}.attn.to_k",
529-
f"transformer.single_transformer_blocks.{i}.attn.to_v",
530-
f"transformer.single_transformer_blocks.{i}.proj_mlp",
531-
],
532-
dims=[3072, 3072, 3072, 12288],
533-
)
548+
for i in range(max_num_single_blocks+1):
549+
if version_flux2:
550+
_convert_to_ai_toolkit(
551+
sds_sd,
552+
ait_sd,
553+
f"lora_unet_single_blocks_{i}_linear1",
554+
f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj",
555+
)
556+
else:
557+
_convert_to_ai_toolkit_cat(
558+
sds_sd,
559+
ait_sd,
560+
f"lora_unet_single_blocks_{i}_linear1",
561+
[
562+
f"transformer.single_transformer_blocks.{i}.attn.to_q",
563+
f"transformer.single_transformer_blocks.{i}.attn.to_k",
564+
f"transformer.single_transformer_blocks.{i}.attn.to_v",
565+
f"transformer.single_transformer_blocks.{i}.proj_mlp",
566+
],
567+
dims=[3072, 3072, 3072, 12288],
568+
)
534569
_convert_to_ai_toolkit(
535570
sds_sd,
536571
ait_sd,
537572
f"lora_unet_single_blocks_{i}_linear2",
538-
f"transformer.single_transformer_blocks.{i}.proj_out",
573+
(
574+
f"transformer.single_transformer_blocks.{i}.attn.to_out"
575+
if version_flux2 else
576+
f"transformer.single_transformer_blocks.{i}.proj_out"
577+
),
539578
)
540579
_convert_to_ai_toolkit(
541580
sds_sd,

src/diffusers/loaders/lora_pipeline.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5472,6 +5472,16 @@ def lora_state_dict(
54725472
logger.warning(warn_msg)
54735473
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
54745474

5475+
is_kohya = any(".lora_down.weight" in k for k in state_dict)
5476+
if is_kohya:
5477+
state_dict = _convert_kohya_flux_lora_to_diffusers(
5478+
state_dict,
5479+
version_flux2=True,
5480+
)
5481+
# Kohya already takes care of scaling the LoRA parameters with alpha.
5482+
for k in state_dict:
5483+
assert "alpha" not in k, f"Found key with alpha: {k}"
5484+
54755485
is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
54765486
if is_ai_toolkit:
54775487
state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)

0 commit comments

Comments (0)