
refactor: extract load_pretrain_weights + apply_lora into models/weights.py#840

Open
Borda wants to merge 1 commit into develop from refactor/weights-module

Conversation

Member

@Borda Borda commented Mar 20, 2026

What does this PR do?

Eliminates duplicated weight-loading and LoRA application logic between detr.py (_load_pretrain_weights_into, _apply_lora_to) and module_model.py (_load_pretrain_weights, _apply_lora). The canonical unified implementation lives in the new src/rfdetr/models/weights.py and both callers delegate to it.

Key changes:

  • New src/rfdetr/models/weights.py: load_pretrain_weights() uses the more complete module_model.py logic (Pydantic model_fields_set for user-override detection, auto-align to fine-tuned checkpoints, expand-back for explicit num_classes) plus class_names extraction from detr.py. apply_lora() is the unified LoRA application function.
  • detr.py: removes ~80 lines of private weight/LoRA methods; delegates to weights.py; fixes mutable COCO_CLASS_NAMES returned directly from class_names property (now returns list copy).
  • module_model.py: removes ~95 lines of private weight/LoRA methods; delegates to weights.py.
  • tests/models/test_weights.py: 8 new characterization tests (4 reinit scenarios, 2 class_names extraction, 2 LoRA application).
  • test_load_pretrain_weights.py, test_module_model.py, test_detr_shim.py: migrated patch targets and call sites to rfdetr.models.weights.
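The delegation pattern described in these bullets can be sketched roughly as follows. Only the load_pretrain_weights / apply_lora names come from this PR; the facade class, signatures, and all bodies are illustrative stubs, not the actual rfdetr code:

```python
# Sketch of the delegation described above. Only the load_pretrain_weights /
# apply_lora names come from the PR; all bodies are illustrative stubs.
from typing import Any, List


def load_pretrain_weights(model_config: Any, nn_model: Any) -> List[str]:
    """Unified loader: would download the checkpoint, reinit/auto-align the
    class head, load the state dict, and return class_names from the ckpt."""
    return []  # stub: real code returns names extracted from checkpoint args


def apply_lora(model_config: Any, nn_model: Any) -> Any:
    """Unified LoRA application: would wrap target modules with adapters."""
    return nn_model  # stub


class DetrFacade:
    """Stand-in for the facade in detr.py after the refactor."""

    def __init__(self, model_config: Any, nn_model: Any) -> None:
        # Former private helpers (_load_pretrain_weights_into, _apply_lora_to)
        # are gone; both steps delegate to the shared module.
        self._class_names = load_pretrain_weights(model_config, nn_model)
        if getattr(model_config, "use_lora", False):
            nn_model = apply_lora(model_config, nn_model)
        self.model = nn_model

    @property
    def class_names(self) -> List[str]:
        # Return a copy so module-level defaults such as COCO_CLASS_NAMES
        # cannot be mutated by callers.
        return list(self._class_names)
```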

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • Refactoring (no functional changes)

Testing

  • I have tested this change locally
  • I have added/updated tests for this change

Additional Context

…hts.py


All 1079 tests pass.

Co-Authored-By: Codex <codex@openai.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings March 20, 2026 21:02
Contributor

Copilot AI left a comment


Pull request overview

This PR centralizes pretrained checkpoint loading and LoRA application into a single shared module (rfdetr.models.weights) and updates both the inference facade (rfdetr.detr) and the training LightningModule (rfdetr.training.module_model) plus tests to use that unified implementation.

Changes:

  • Added src/rfdetr/models/weights.py providing canonical load_pretrain_weights() and apply_lora().
  • Refactored src/rfdetr/detr.py and src/rfdetr/training/module_model.py to delegate weight loading / LoRA application to rfdetr.models.weights.
  • Updated and added tests to patch and validate the new unified code paths (including new characterization tests).

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.

Show a summary per file

  • src/rfdetr/models/weights.py: New unified implementation for loading pretrained weights and applying LoRA.
  • src/rfdetr/training/module_model.py: Removes local helpers and calls the shared weights/LoRA utilities.
  • src/rfdetr/detr.py: Removes duplicate helpers, delegates to shared utilities, and returns a copy for the COCO class-name fallback.
  • src/rfdetr/models/__init__.py: Re-exports load_pretrain_weights / apply_lora from rfdetr.models.
  • tests/models/test_weights.py: New unit tests directly exercising rfdetr.models.weights.
  • tests/training/test_module_model.py: Migrates patch targets and call sites to rfdetr.models.weights.
  • tests/training/test_load_pretrain_weights.py: Refactors regression coverage to exercise the unified loader directly.
  • tests/training/test_detr_shim.py: Migrates detr-path checkpoint compatibility tests to the unified loader and updates class_names identity assertions.
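The patch-target migration listed for the test files can be sketched as below. The `weights` namespace here is a local stand-in, since the real rfdetr package is not imported in this sketch:

```python
# Illustrates patching the unified entry point instead of per-class private
# helpers. `weights` is a stand-in namespace, not the real rfdetr module.
from types import SimpleNamespace
from unittest.mock import patch

weights = SimpleNamespace(load_pretrain_weights=lambda cfg, model: [])


def test_patches_unified_loader() -> None:
    # Before: tests patched e.g. RFDETR._load_pretrain_weights_into and
    # RFDETRModelModule._load_pretrain_weights separately.
    # After: one patch target covers both call sites.
    with patch.object(weights, "load_pretrain_weights",
                      return_value=["cat", "dog"]) as mock_load:
        assert weights.load_pretrain_weights(None, None) == ["cat", "dog"]
    mock_load.assert_called_once()


test_patches_unified_loader()
```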

Comment on lines +77 to +88
    mc = model_config
    pretrain_weights = mc.pretrain_weights
    class_names: List[str] = []

    # Download first (no-op if already present and hash is valid).
    download_pretrain_weights(pretrain_weights)
    # If the first download attempt didn't produce the file (e.g. stale MD5
    # caused an earlier ValueError that was silently swallowed), retry with
    # MD5 validation disabled so a stale registry hash can't block training.
    if not os.path.isfile(pretrain_weights):
        logger.warning("Pretrain weights not found after initial download; retrying without MD5 validation.")
        download_pretrain_weights(pretrain_weights, redownload=True, validate_md5=False)

Copilot AI Mar 20, 2026


pretrain_weights = mc.pretrain_weights can be None (ModelConfig allows it), but this function immediately passes it into download_pretrain_weights() / os.path.isfile(), which will raise a type error if a caller invokes load_pretrain_weights() without a configured path. Since this helper is now a public import, add an explicit guard (e.g., early return [] or raise a clear ValueError) when pretrain_weights is None.
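A minimal sketch of such a guard, with download_pretrain_weights stubbed out since only the None handling is being illustrated:

```python
# Sketch of the suggested None guard; download_pretrain_weights is a stub.
import os
from typing import List, Optional


def download_pretrain_weights(path: str, redownload: bool = False,
                              validate_md5: bool = True) -> None:
    pass  # stand-in for the real downloader


def load_pretrain_weights_guarded(pretrain_weights: Optional[str]) -> List[str]:
    if pretrain_weights is None:
        # ModelConfig allows pretrain_weights=None: nothing to load, so
        # return early instead of passing None into os.path.isfile().
        return []
    download_pretrain_weights(pretrain_weights)
    if not os.path.isfile(pretrain_weights):
        download_pretrain_weights(pretrain_weights, redownload=True,
                                  validate_md5=False)
    return []
```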


    # Extract class_names from the checkpoint if available (ported from detr.py).
    if "args" in checkpoint:
        class_names = _ckpt_args_get(checkpoint["args"], "class_names") or []

Copilot AI Mar 20, 2026


class_names = _ckpt_args_get(checkpoint["args"], "class_names") or [] doesn't normalize the result to a list[str]. _ckpt_args_get can return any type stored in the checkpoint (e.g., tuple, numpy array, or other sequence), which would violate the annotated return type and can leak a mutable reference if it's already a list. Consider coercing to a new list (and optionally filtering to strings) before returning.

Suggested change

    - class_names = _ckpt_args_get(checkpoint["args"], "class_names") or []
    + raw_class_names = _ckpt_args_get(checkpoint["args"], "class_names")
    + if raw_class_names:
    +     # Normalize to a new List[str] to avoid leaking mutable references and
    +     # to respect the annotated return type.
    +     if isinstance(raw_class_names, str):
    +         class_names = [raw_class_names]
    +     else:
    +         try:
    +             iterator = iter(raw_class_names)
    +         except TypeError:
    +             # Non-iterable, ignore and keep the default empty list.
    +             class_names = []
    +         else:
    +             class_names = [name for name in iterator if isinstance(name, str)]

        # non-default override for num_classes (i.e., left it at the
        # ModelConfig default): treat the checkpoint as authoritative.
        num_classes = checkpoint_num_classes - 1
        configured_num_classes_plus_bg = checkpoint_num_classes

Copilot AI Mar 20, 2026


When auto-aligning to a fine-tuned checkpoint (checkpoint_num_classes < configured_num_classes_plus_bg and the user did not override the default), the code updates only the local num_classes variable but does not persist the aligned value back to model_config.num_classes. That leaves model_config (and any pre-built namespaces based on it) inconsistent with the loaded model head for callers that don't separately sync from nn_model.num_classes (unlike RFDETRModelModule). Consider updating model_config.num_classes as part of the auto-align path (or returning the effective class count) so all callers stay consistent.

Suggested change

    - configured_num_classes_plus_bg = checkpoint_num_classes
    + configured_num_classes_plus_bg = checkpoint_num_classes
    + # Keep the ModelConfig in sync with the effective class count.
    + mc.num_classes = num_classes

    are present.

    Raises:
        RuntimeError: If the checkpoint file cannot be loaded after a re-download.

Copilot AI Mar 20, 2026


The docstring claims this function raises RuntimeError when the checkpoint cannot be loaded after re-download, but the implementation will re-raise whatever exception torch.load() throws on the second attempt. Either wrap the second failure and raise RuntimeError as documented, or adjust the docstring to match actual behavior.

Suggested change

    - RuntimeError: If the checkpoint file cannot be loaded after a re-download.
    + Exception: If the checkpoint file cannot be loaded even after a re-download.

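Alternatively, the implementation could be made to match the documented RuntimeError. A sketch, with a generic load_fn and redownload_fn standing in for torch.load and the re-download helper:

```python
# Sketch of wrapping the second load attempt so RuntimeError is raised as the
# docstring promises. load_fn / redownload_fn stand in for torch.load and the
# re-download helper; they are not the real rfdetr API.
from typing import Any, Callable


def load_checkpoint_with_retry(path: str,
                               load_fn: Callable[[str], Any],
                               redownload_fn: Callable[[str], None]) -> Any:
    try:
        return load_fn(path)
    except Exception:
        # First attempt failed (e.g. corrupt file): force a re-download, then
        # try once more, converting a second failure into RuntimeError.
        redownload_fn(path)
        try:
            return load_fn(path)
        except Exception as exc:
            raise RuntimeError(
                f"Checkpoint {path!r} could not be loaded even after re-download."
            ) from exc
```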

codecov bot commented Mar 20, 2026

Codecov Report

❌ Patch coverage is 94.66667% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 75%. Comparing base (3cd6898) to head (dc2357b).

Additional details and impacted files
@@          Coverage Diff           @@
##           develop   #840   +/-   ##
======================================
  Coverage       75%    75%           
======================================
  Files           92     93    +1     
  Lines         7148   7134   -14     
======================================
- Hits          5383   5377    -6     
+ Misses        1765   1757    -8     

