@@ -357,7 +357,10 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
357357
358358# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
359359# are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
360- def _convert_kohya_flux_lora_to_diffusers (state_dict ):
360+ def _convert_kohya_flux_lora_to_diffusers (
361+ state_dict ,
362+ version_flux2 = False ,
363+ ):
361364 def _convert_to_ai_toolkit (sds_sd , ait_sd , sds_key , ait_key ):
362365 if sds_key + ".lora_down.weight" not in sds_sd :
363366 return
@@ -448,7 +451,15 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
448451
449452 def _convert_sd_scripts_to_ai_toolkit (sds_sd ):
450453 ait_sd = {}
451- for i in range (19 ):
454+
455+ max_num_double_blocks , max_num_single_blocks = - 1 , - 1
456+ for key in list (sds_sd .keys ()):
457+ if key .startswith ("lora_unet_double_blocks_" ):
458+ max_num_double_blocks = max (max_num_double_blocks , int (key .split ("_" )[4 ]))
459+ if key .startswith ("lora_unet_single_blocks_" ):
460+ max_num_single_blocks = max (max_num_single_blocks , int (key .split ("_" )[4 ]))
461+
462+ for i in range (max_num_double_blocks + 1 ):
452463 _convert_to_ai_toolkit (
453464 sds_sd ,
454465 ait_sd ,
@@ -469,13 +480,21 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
469480 sds_sd ,
470481 ait_sd ,
471482 f"lora_unet_double_blocks_{ i } _img_mlp_0" ,
472- f"transformer.transformer_blocks.{ i } .ff.net.0.proj" ,
483+ (
484+ f"transformer.transformer_blocks.{ i } .ff.linear_in"
485+ if version_flux2 else
486+ f"transformer.transformer_blocks.{ i } .ff.net.0.proj"
487+ ),
473488 )
474489 _convert_to_ai_toolkit (
475490 sds_sd ,
476491 ait_sd ,
477492 f"lora_unet_double_blocks_{ i } _img_mlp_2" ,
478- f"transformer.transformer_blocks.{ i } .ff.net.2" ,
493+ (
494+ f"transformer.transformer_blocks.{ i } .ff.linear_out"
495+ if version_flux2 else
496+ f"transformer.transformer_blocks.{ i } .ff.net.2"
497+ ),
479498 )
480499 _convert_to_ai_toolkit (
481500 sds_sd ,
@@ -503,13 +522,21 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
503522 sds_sd ,
504523 ait_sd ,
505524 f"lora_unet_double_blocks_{ i } _txt_mlp_0" ,
506- f"transformer.transformer_blocks.{ i } .ff_context.net.0.proj" ,
525+ (
526+ f"transformer.transformer_blocks.{ i } .ff_context.linear_in"
527+ if version_flux2 else
528+ f"transformer.transformer_blocks.{ i } .ff_context.net.0.proj"
529+ ),
507530 )
508531 _convert_to_ai_toolkit (
509532 sds_sd ,
510533 ait_sd ,
511534 f"lora_unet_double_blocks_{ i } _txt_mlp_2" ,
512- f"transformer.transformer_blocks.{ i } .ff_context.net.2" ,
535+ (
536+ f"transformer.transformer_blocks.{ i } .ff_context.linear_out"
537+ if version_flux2 else
538+ f"transformer.transformer_blocks.{ i } .ff_context.net.2"
539+ ),
513540 )
514541 _convert_to_ai_toolkit (
515542 sds_sd ,
@@ -518,24 +545,36 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
518545 f"transformer.transformer_blocks.{ i } .norm1_context.linear" ,
519546 )
520547
521- for i in range (38 ):
522- _convert_to_ai_toolkit_cat (
523- sds_sd ,
524- ait_sd ,
525- f"lora_unet_single_blocks_{ i } _linear1" ,
526- [
527- f"transformer.single_transformer_blocks.{ i } .attn.to_q" ,
528- f"transformer.single_transformer_blocks.{ i } .attn.to_k" ,
529- f"transformer.single_transformer_blocks.{ i } .attn.to_v" ,
530- f"transformer.single_transformer_blocks.{ i } .proj_mlp" ,
531- ],
532- dims = [3072 , 3072 , 3072 , 12288 ],
533- )
548+ for i in range (max_num_single_blocks + 1 ):
549+ if version_flux2 :
550+ _convert_to_ai_toolkit (
551+ sds_sd ,
552+ ait_sd ,
553+ f"lora_unet_single_blocks_{ i } _linear1" ,
554+ f"transformer.single_transformer_blocks.{ i } .attn.to_qkv_mlp_proj" ,
555+ )
556+ else :
557+ _convert_to_ai_toolkit_cat (
558+ sds_sd ,
559+ ait_sd ,
560+ f"lora_unet_single_blocks_{ i } _linear1" ,
561+ [
562+ f"transformer.single_transformer_blocks.{ i } .attn.to_q" ,
563+ f"transformer.single_transformer_blocks.{ i } .attn.to_k" ,
564+ f"transformer.single_transformer_blocks.{ i } .attn.to_v" ,
565+ f"transformer.single_transformer_blocks.{ i } .proj_mlp" ,
566+ ],
567+ dims = [3072 , 3072 , 3072 , 12288 ],
568+ )
534569 _convert_to_ai_toolkit (
535570 sds_sd ,
536571 ait_sd ,
537572 f"lora_unet_single_blocks_{ i } _linear2" ,
538- f"transformer.single_transformer_blocks.{ i } .proj_out" ,
573+ (
574+ f"transformer.single_transformer_blocks.{ i } .attn.to_out"
575+ if version_flux2 else
576+ f"transformer.single_transformer_blocks.{ i } .proj_out"
577+ ),
539578 )
540579 _convert_to_ai_toolkit (
541580 sds_sd ,
0 commit comments