refactor: extract load_pretrain_weights + apply_lora into models/weights.py #840
Conversation
Eliminates duplicated weight-loading and LoRA application logic between detr.py (_load_pretrain_weights_into, _apply_lora_to) and module_model.py (_load_pretrain_weights, _apply_lora). The canonical unified implementation lives in the new src/rfdetr/models/weights.py and both callers delegate to it.

Key changes:

- New src/rfdetr/models/weights.py: load_pretrain_weights() uses the more complete module_model.py logic (Pydantic model_fields_set for user-override detection, auto-align to fine-tuned checkpoints, expand-back for explicit num_classes) plus class_names extraction from detr.py. apply_lora() is the unified LoRA application function.
- detr.py: removes ~80 lines of private weight/LoRA methods; delegates to weights.py; fixes the mutable COCO_CLASS_NAMES list that was returned directly from the class_names property (now returns a list copy).
- module_model.py: removes ~95 lines of private weight/LoRA methods; delegates to weights.py.
- tests/models/test_weights.py: 8 new characterization tests (4 reinit scenarios, 2 class_names extraction, 2 LoRA application).
- test_load_pretrain_weights.py, test_module_model.py, test_detr_shim.py: migrated patch targets and call sites to rfdetr.models.weights.

All 1079 tests pass.

Co-Authored-By: Codex <codex@openai.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Pull request overview
This PR centralizes pretrained checkpoint loading and LoRA application into a single shared module (rfdetr.models.weights) and updates both the inference facade (rfdetr.detr) and the training LightningModule (rfdetr.training.module_model) plus tests to use that unified implementation.
Changes:
- Added src/rfdetr/models/weights.py providing canonical load_pretrain_weights() and apply_lora().
- Refactored src/rfdetr/detr.py and src/rfdetr/training/module_model.py to delegate weight loading / LoRA application to rfdetr.models.weights.
- Updated and added tests to patch and validate the new unified code paths (including new characterization tests).
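The shape of this refactor can be sketched roughly as follows. This is an illustrative stand-in, not the real rfdetr code: the function and class bodies here are toy placeholders, and only the names load_pretrain_weights / the delegation pattern come from the PR description.

```python
# Sketch of the refactor's shape: both former call sites delegate to one
# shared implementation instead of keeping private near-duplicate copies.

# models/weights.py (canonical implementation, illustrative stand-in)
def load_pretrain_weights(config: dict, model: dict) -> list:
    """Copy pretrained "weights" into model and return any class names found."""
    model.update(config.get("weights", {}))
    return list(config.get("class_names", []))

# detr.py / module_model.py now just delegate instead of reimplementing:
class DETRFacade:
    def __init__(self, config: dict):
        self.model: dict = {}
        self.class_names = load_pretrain_weights(config, self.model)
```

The benefit is that checkpoint-compatibility fixes land in one place and both the inference facade and the training module pick them up automatically.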
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| src/rfdetr/models/weights.py | New unified implementation for loading pretrained weights and applying LoRA. |
| src/rfdetr/training/module_model.py | Removes local helpers and calls the shared weights/LoRA utilities. |
| src/rfdetr/detr.py | Removes duplicate helpers, delegates to shared utilities, and returns a copy for COCO class-name fallback. |
| src/rfdetr/models/__init__.py | Re-exports load_pretrain_weights / apply_lora so they are importable from rfdetr.models. |
| tests/models/test_weights.py | New unit tests directly exercising rfdetr.models.weights. |
| tests/training/test_module_model.py | Migrates patch targets and call sites to rfdetr.models.weights. |
| tests/training/test_load_pretrain_weights.py | Refactors regression coverage to exercise the unified loader directly. |
| tests/training/test_detr_shim.py | Migrates detr-path checkpoint compatibility tests to the unified loader and updates class_names identity assertions. |
```python
mc = model_config
pretrain_weights = mc.pretrain_weights
class_names: List[str] = []

# Download first (no-op if already present and hash is valid).
download_pretrain_weights(pretrain_weights)
# If the first download attempt didn't produce the file (e.g. stale MD5
# caused an earlier ValueError that was silently swallowed), retry with
# MD5 validation disabled so a stale registry hash can't block training.
if not os.path.isfile(pretrain_weights):
    logger.warning("Pretrain weights not found after initial download; retrying without MD5 validation.")
    download_pretrain_weights(pretrain_weights, redownload=True, validate_md5=False)
```
pretrain_weights = mc.pretrain_weights can be None (ModelConfig allows it), but this function immediately passes it into download_pretrain_weights() / os.path.isfile(), which will raise a type error if a caller invokes load_pretrain_weights() without a configured path. Since this helper is now a public import, add an explicit guard (e.g., early return [] or raise a clear ValueError) when pretrain_weights is None.
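Such a guard could look like the following. This is a minimal sketch of the reviewer's suggestion only: load_pretrain_weights_sketch is a hypothetical stand-in, and the real loading logic is elided.

```python
from typing import List, Optional

def load_pretrain_weights_sketch(pretrain_weights: Optional[str]) -> List[str]:
    """Illustrative guard only; the actual download/load logic is elided."""
    if pretrain_weights is None:
        # Early return keeps the now-public helper safe to call without a
        # configured path, instead of passing None into
        # download_pretrain_weights() / os.path.isfile() and crashing.
        return []
    # ... download, torch.load, and class_names extraction would follow ...
    return []
```

Raising a clear ValueError instead of returning [] would also satisfy the review; the key point is failing (or no-op'ing) explicitly rather than with an incidental TypeError.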
```python
# Extract class_names from the checkpoint if available (ported from detr.py).
if "args" in checkpoint:
    class_names = _ckpt_args_get(checkpoint["args"], "class_names") or []
```
class_names = _ckpt_args_get(checkpoint["args"], "class_names") or [] doesn't normalize the result to a list[str]. _ckpt_args_get can return any type stored in the checkpoint (e.g., tuple, numpy array, or other sequence), which would violate the annotated return type and can leak a mutable reference if it's already a list. Consider coercing to a new list (and optionally filtering to strings) before returning.
Suggested change:

```diff
-    class_names = _ckpt_args_get(checkpoint["args"], "class_names") or []
+    raw_class_names = _ckpt_args_get(checkpoint["args"], "class_names")
+    if raw_class_names:
+        # Normalize to a new List[str] to avoid leaking mutable references and
+        # to respect the annotated return type.
+        if isinstance(raw_class_names, str):
+            class_names = [raw_class_names]
+        else:
+            try:
+                iterator = iter(raw_class_names)
+            except TypeError:
+                # Non-iterable, ignore and keep the default empty list.
+                class_names = []
+            else:
+                class_names = [name for name in iterator if isinstance(name, str)]
```
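For example, the suggested normalization behaves like this standalone helper (a restatement of the suggestion above as a free function, purely for illustration):

```python
def normalize_class_names(raw) -> list:
    """Coerce whatever the checkpoint stored into a fresh list of strings."""
    if not raw:
        return []
    if isinstance(raw, str):
        # A bare string would otherwise be iterated character by character.
        return [raw]
    try:
        iterator = iter(raw)
    except TypeError:
        # Non-iterable (e.g. an int stored in the checkpoint): ignore it.
        return []
    # Build a new list so callers never share the checkpoint's mutable object.
    return [name for name in iterator if isinstance(name, str)]
```

A tuple or other sequence of names becomes a new list, a lone string becomes a one-element list, and anything non-iterable falls back to the empty default.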
```python
    # non-default override for num_classes (i.e., left it at the
    # ModelConfig default): treat the checkpoint as authoritative.
    num_classes = checkpoint_num_classes - 1
    configured_num_classes_plus_bg = checkpoint_num_classes
```
When auto-aligning to a fine-tuned checkpoint (checkpoint_num_classes < configured_num_classes_plus_bg and the user did not override the default), the code updates only the local num_classes variable but does not persist the aligned value back to model_config.num_classes. That leaves model_config (and any pre-built namespaces based on it) inconsistent with the loaded model head for callers that don't separately sync from nn_model.num_classes (unlike RFDETRModelModule). Consider updating model_config.num_classes as part of the auto-align path (or returning the effective class count) so all callers stay consistent.
Suggested change:

```diff
     configured_num_classes_plus_bg = checkpoint_num_classes
+    # Keep the ModelConfig in sync with the effective class count.
+    mc.num_classes = num_classes
```
```python
    are present.

    Raises:
        RuntimeError: If the checkpoint file cannot be loaded after a re-download.
```
The docstring claims this function raises RuntimeError when the checkpoint cannot be loaded after re-download, but the implementation will re-raise whatever exception torch.load() throws on the second attempt. Either wrap the second failure and raise RuntimeError as documented, or adjust the docstring to match actual behavior.
Suggested change:

```diff
-        RuntimeError: If the checkpoint file cannot be loaded after a re-download.
+        Exception: If the checkpoint file cannot be loaded even after a re-download.
```
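If the docstring's RuntimeError contract is kept instead, the retry path would need to wrap the second failure explicitly. A sketch of that contract, using a hypothetical helper (load_checkpoint_with_retry is not rfdetr API, and the callables stand in for torch.load and the re-download):

```python
def load_checkpoint_with_retry(load_fn, redownload_fn):
    """Try once; on failure re-download and retry, surfacing RuntimeError."""
    try:
        return load_fn()
    except Exception:
        redownload_fn()
        try:
            return load_fn()
        except Exception as exc:
            # Wrap the second failure so callers see the documented type,
            # while chaining the original error for debugging.
            raise RuntimeError(
                "checkpoint could not be loaded even after a re-download"
            ) from exc
```

Chaining with `from exc` preserves the underlying torch.load failure in the traceback while keeping the documented exception type.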
Codecov Report

❌ Patch coverage is …

Additional details and impacted files:

```diff
@@           Coverage Diff            @@
##           develop     #840   +/-  ##
=======================================
  Coverage       75%      75%
=======================================
  Files           92       93     +1
  Lines         7148     7134    -14
=======================================
- Hits          5383     5377     -6
+ Misses        1765     1757     -8
```