diff --git a/scripts/import_mlx.py b/scripts/import_mlx.py index 8c2fe346..92344446 100644 --- a/scripts/import_mlx.py +++ b/scripts/import_mlx.py @@ -40,9 +40,6 @@ def import_model( model = {} for name in ["text_emb.weight", "text_linear.weight"]: model[name] = tch_model[name] - for name in tch_model.keys(): - if name.startswith("condition_provider.conditioners"): - model[name] = tch_model[name] model["out_norm.weight"] = tch_model["out_norm.alpha"][0, 0] for idx in range(in_n_q): src_name = f"emb.{idx}.weight" @@ -99,6 +96,13 @@ def import_model( model[layer + "gating.linear_out.weight"] = tch_model[ f"depformer.layers.{layer_idx}.gating.{idx}.linear_out.weight" ] + if "condition_provider.conditioners.description.embed.weight" in tch_model: + e = tch_model["condition_provider.conditioners.description.embed.weight"] + w = tch_model["condition_provider.conditioners.description.output_proj.weight"] + # 4 is very_good + e = e[4:5] @ w.T + print(f"adding the very_good conditioning {e.shape} to {model['text_emb.weight'].shape}") + model["text_emb.weight"] += e save_file(model, out_path)