diff --git a/grounded_sam2_local_demo.py b/grounded_sam2_local_demo.py index 49992541..18b6bf25 100644 --- a/grounded_sam2_local_demo.py +++ b/grounded_sam2_local_demo.py @@ -17,7 +17,7 @@ TEXT_PROMPT = "car. tire." IMG_PATH = "notebooks/images/truck.jpg" SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt" -SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml" +SAM2_MODEL_CONFIG = "./sam2/configs/sam2.1/sam2.1_hiera_l.yaml" GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py" GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth" BOX_THRESHOLD = 0.35 diff --git a/sam2/build_sam.py b/sam2/build_sam.py index 7cfc4513..1ff287b0 100644 --- a/sam2/build_sam.py +++ b/sam2/build_sam.py @@ -8,10 +8,13 @@ import os import torch -from hydra import compose +from hydra import compose, initialize_config_dir from hydra.utils import instantiate +from hydra.core.global_hydra import GlobalHydra from omegaconf import OmegaConf +from pathlib import Path + import sam2 # Check if the user is running Python from the parent directory of the sam2 repo @@ -73,27 +76,59 @@ def build_sam2( ckpt_path=None, device="cuda", mode="eval", - hydra_overrides_extra=[], + hydra_overrides_extra=None, apply_postprocessing=True, + config_root=".", **kwargs, ): + """ + Builds SAM 2 model with support for loading config files from any local directory. + Compatible with Hydra >= 1.3.2. + + Args: + config_file (str): Name of the Hydra config YAML file (e.g., "config.yaml"). + ckpt_path (str): Path to the model checkpoint to load. + device (str): Device to place the model on ("cuda" or "cpu"). + mode (str): "eval" or "train" mode for the model. + hydra_overrides_extra (list): Additional Hydra override strings. + apply_postprocessing (bool): Whether to apply extra SAM-specific config tweaks. + config_root (str): Path to the directory containing Hydra config files. + **kwargs: Additional unused arguments (for compatibility). + + Returns: + torch.nn.Module: The instantiated and loaded model. + """ + + hydra_overrides_extra = hydra_overrides_extra.copy() if hydra_overrides_extra else [] if apply_postprocessing: - hydra_overrides_extra = hydra_overrides_extra.copy() hydra_overrides_extra += [ - # dynamically fall back to multi-mask if the single mask is not stable "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", ] - # Read config and init model - cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) + + config_root = str(Path(config_root).resolve()) + + try: + if GlobalHydra.instance().is_initialized(): + GlobalHydra.instance().clear() + + with initialize_config_dir(config_dir=config_root, job_name="build_sam2"): + cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) + logging.info(f"Successfully loaded config: {config_file}") + except Exception as e: + logging.error(f"Failed to load config '{config_file}' from '{config_root}': {e}") + raise + OmegaConf.resolve(cfg) model = instantiate(cfg.model, _recursive_=True) _load_checkpoint(model, ckpt_path) model = model.to(device) + if mode == "eval": model.eval() + return model