Skip to content

Commit a5733eb

Browse files
Error in pulling sam2.1 vs sam2 when SAM-2 is installed (#1782)
Signed-off-by: Sachidanand Alle <salle@nvidia.com>
1 parent b885a59 commit a5733eb

2 files changed

Lines changed: 40 additions & 31 deletions

File tree

monailabel/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111
import os
12-
from importlib.util import find_spec
12+
from importlib.metadata import distributions
1313
from typing import Any, Dict, List, Optional
1414

1515
from pydantic import AnyHttpUrl
1616
from pydantic_settings import BaseSettings, SettingsConfigDict
1717

1818

1919
def is_package_installed(name):
20-
return False if find_spec(name) is None else True
20+
return name in sorted(x.name for x in distributions())
2121

2222

2323
class Settings(BaseSettings):

sample-apps/radiology/lib/configs/segmentation.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from monailabel.interfaces.config import TaskConfig
2222
from monailabel.interfaces.tasks.infer_v2 import InferTask
2323
from monailabel.interfaces.tasks.train import TrainTask
24-
from monailabel.utils.others.generic import download_file, strtobool
24+
from monailabel.utils.others.generic import download_file, remove_file, strtobool
2525

2626
_, has_cp = optional_import("cupy")
2727
_, has_cucim = optional_import("cucim")
@@ -34,33 +34,38 @@ def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **
3434
super().init(name, model_dir, conf, planner, **kwargs)
3535

3636
# Labels
37-
self.labels = {
38-
"spleen": 1,
39-
"kidney_right": 2,
40-
"kidney_left": 3,
41-
"gallbladder": 4,
42-
"liver": 5,
43-
"stomach": 6,
44-
"aorta": 7,
45-
"inferior_vena_cava": 8,
46-
"portal_vein_and_splenic_vein": 9,
47-
"pancreas": 10,
48-
"adrenal_gland_right": 11,
49-
"adrenal_gland_left": 12,
50-
"lung_upper_lobe_left": 13,
51-
"lung_lower_lobe_left": 14,
52-
"lung_upper_lobe_right": 15,
53-
"lung_middle_lobe_right": 16,
54-
"lung_lower_lobe_right": 17,
55-
"esophagus": 42,
56-
"trachea": 43,
57-
"heart_myocardium": 44,
58-
"heart_atrium_left": 45,
59-
"heart_ventricle_left": 46,
60-
"heart_atrium_right": 47,
61-
"heart_ventricle_right": 48,
62-
"pulmonary_artery": 49,
63-
}
37+
conf_labels = self.conf.get("labels")
38+
self.labels = (
39+
{label: idx for idx, label in enumerate(conf_labels.split(","), start=1)}
40+
if conf_labels
41+
else {
42+
"spleen": 1,
43+
"kidney_right": 2,
44+
"kidney_left": 3,
45+
"gallbladder": 4,
46+
"liver": 5,
47+
"stomach": 6,
48+
"aorta": 7,
49+
"inferior_vena_cava": 8,
50+
"portal_vein_and_splenic_vein": 9,
51+
"pancreas": 10,
52+
"adrenal_gland_right": 11,
53+
"adrenal_gland_left": 12,
54+
"lung_upper_lobe_left": 13,
55+
"lung_lower_lobe_left": 14,
56+
"lung_upper_lobe_right": 15,
57+
"lung_middle_lobe_right": 16,
58+
"lung_lower_lobe_right": 17,
59+
"esophagus": 42,
60+
"trachea": 43,
61+
"heart_myocardium": 44,
62+
"heart_atrium_left": 45,
63+
"heart_ventricle_left": 46,
64+
"heart_atrium_right": 47,
65+
"heart_ventricle_right": 48,
66+
"pulmonary_artery": 49,
67+
}
68+
)
6469

6570
# Model Files
6671
self.path = [
@@ -69,11 +74,15 @@ def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **
6974
]
7075

7176
# Download PreTrained Model
72-
if strtobool(self.conf.get("use_pretrained_model", "true")):
77+
if not conf_labels and strtobool(self.conf.get("use_pretrained_model", "true")):
7378
url = f"{self.conf.get('pretrained_path', self.PRE_TRAINED_PATH)}"
7479
url = f"{url}/radiology_segmentation_segresnet_multilabel.pt"
7580
download_file(url, self.path[0])
7681

82+
# Remove pre-trained pt if user is using his/her custom labels.
83+
if conf_labels:
84+
remove_file(self.path[0])
85+
7786
self.target_spacing = (1.5, 1.5, 1.5) # target space for image
7887
# Setting ROI size - This is for the image padding
7988
self.roi_size = (96, 96, 96)

0 commit comments

Comments
 (0)