|
22 | 22 | import pylab |
23 | 23 | import schedule |
24 | 24 | import torch |
| 25 | +from hydra import initialize_config_dir |
| 26 | +from hydra.core.global_hydra import GlobalHydra |
25 | 27 | from monai.transforms import KeepLargestConnectedComponent, LoadImaged |
26 | 28 | from PIL import Image |
27 | 29 | from sam2.build_sam import build_sam2, build_sam2_video_predictor |
@@ -114,13 +116,21 @@ def __init__( |
114 | 116 | self._config.update(config) |
115 | 117 |
|
116 | 118 | # Download PreTrained Model |
117 | | - # https://github.com/facebookresearch/sam2?tab=readme-ov-file#model-description |
118 | | - pt = "sam2.1_hiera_large.pt" |
119 | | - url = f"https://dl.fbaipublicfiles.com/segment_anything_2/092824/{pt}" |
120 | | - self.path = os.path.join(model_dir, f"pretrained_{pt}") |
121 | | - download_file(url, self.path) |
| 119 | + pt_url = settings.MONAI_SAM_MODEL_PT |
| 120 | + conf_url = settings.MONAI_SAM_MODEL_CFG |
| 121 | + sam_pt = pt_url.split("/")[-1] |
| 122 | + sam_conf = conf_url.split("/")[-1] |
| 123 | + |
| 124 | + self.path = os.path.join(model_dir, sam_pt) |
| 125 | + self.config_path = os.path.join(model_dir, sam_conf) |
| 126 | + |
| 127 | + GlobalHydra.instance().clear() |
| 128 | + initialize_config_dir(config_dir=model_dir) |
| 129 | + |
| 130 | + download_file(pt_url, self.path) |
| 131 | + download_file(conf_url, self.config_path) |
| 132 | + self.config_path = sam_conf |
122 | 133 |
|
123 | | - self.config_path = "configs/sam2.1/sam2.1_hiera_l.yaml" |
124 | 134 | self.predictors = {} |
125 | 135 | self.image_cache = {} |
126 | 136 | self.inference_state = None |
@@ -393,8 +403,8 @@ def main(): |
393 | 403 | force=True, |
394 | 404 | ) |
395 | 405 |
|
396 | | - app_name = "pathology" |
397 | | - app_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "sample-apps", app_name)) |
| 406 | + app_name = "radiology" |
| 407 | + app_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "sample-apps", app_name)) |
398 | 408 | model_dir = os.path.join(app_dir, "model") |
399 | 409 | logger.info(f"Model Dir: {model_dir}") |
400 | 410 | if app_name == "pathology": |
|
0 commit comments