2121from monailabel .interfaces .config import TaskConfig
2222from monailabel .interfaces .tasks .infer_v2 import InferTask
2323from monailabel .interfaces .tasks .train import TrainTask
24- from monailabel .utils .others .generic import download_file , strtobool
24+ from monailabel .utils .others .generic import download_file , remove_file , strtobool
2525
2626_ , has_cp = optional_import ("cupy" )
2727_ , has_cucim = optional_import ("cucim" )
@@ -34,33 +34,38 @@ def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **
3434 super ().init (name , model_dir , conf , planner , ** kwargs )
3535
3636 # Labels
37- self .labels = {
38- "spleen" : 1 ,
39- "kidney_right" : 2 ,
40- "kidney_left" : 3 ,
41- "gallbladder" : 4 ,
42- "liver" : 5 ,
43- "stomach" : 6 ,
44- "aorta" : 7 ,
45- "inferior_vena_cava" : 8 ,
46- "portal_vein_and_splenic_vein" : 9 ,
47- "pancreas" : 10 ,
48- "adrenal_gland_right" : 11 ,
49- "adrenal_gland_left" : 12 ,
50- "lung_upper_lobe_left" : 13 ,
51- "lung_lower_lobe_left" : 14 ,
52- "lung_upper_lobe_right" : 15 ,
53- "lung_middle_lobe_right" : 16 ,
54- "lung_lower_lobe_right" : 17 ,
55- "esophagus" : 42 ,
56- "trachea" : 43 ,
57- "heart_myocardium" : 44 ,
58- "heart_atrium_left" : 45 ,
59- "heart_ventricle_left" : 46 ,
60- "heart_atrium_right" : 47 ,
61- "heart_ventricle_right" : 48 ,
62- "pulmonary_artery" : 49 ,
63- }
37+ conf_labels = self .conf .get ("labels" )
38+ self .labels = (
39+ {label : idx for idx , label in enumerate (conf_labels .split ("," ), start = 1 )}
40+ if conf_labels
41+ else {
42+ "spleen" : 1 ,
43+ "kidney_right" : 2 ,
44+ "kidney_left" : 3 ,
45+ "gallbladder" : 4 ,
46+ "liver" : 5 ,
47+ "stomach" : 6 ,
48+ "aorta" : 7 ,
49+ "inferior_vena_cava" : 8 ,
50+ "portal_vein_and_splenic_vein" : 9 ,
51+ "pancreas" : 10 ,
52+ "adrenal_gland_right" : 11 ,
53+ "adrenal_gland_left" : 12 ,
54+ "lung_upper_lobe_left" : 13 ,
55+ "lung_lower_lobe_left" : 14 ,
56+ "lung_upper_lobe_right" : 15 ,
57+ "lung_middle_lobe_right" : 16 ,
58+ "lung_lower_lobe_right" : 17 ,
59+ "esophagus" : 42 ,
60+ "trachea" : 43 ,
61+ "heart_myocardium" : 44 ,
62+ "heart_atrium_left" : 45 ,
63+ "heart_ventricle_left" : 46 ,
64+ "heart_atrium_right" : 47 ,
65+ "heart_ventricle_right" : 48 ,
66+ "pulmonary_artery" : 49 ,
67+ }
68+ )
6469
6570 # Model Files
6671 self .path = [
@@ -69,11 +74,15 @@ def init(self, name: str, model_dir: str, conf: Dict[str, str], planner: Any, **
6974 ]
7075
7176 # Download PreTrained Model
72- if strtobool (self .conf .get ("use_pretrained_model" , "true" )):
77+ if not conf_labels and strtobool (self .conf .get ("use_pretrained_model" , "true" )):
7378 url = f"{ self .conf .get ('pretrained_path' , self .PRE_TRAINED_PATH )} "
7479 url = f"{ url } /radiology_segmentation_segresnet_multilabel.pt"
7580 download_file (url , self .path [0 ])
7681
82+ # Remove pre-trained pt if user is using his/her custom labels.
83+ if conf_labels :
84+ remove_file (self .path [0 ])
85+
7786 self .target_spacing = (1.5 , 1.5 , 1.5 ) # target space for image
7887 # Setting ROI size - This is for the image padding
7988 self .roi_size = (96 , 96 , 96 )
0 commit comments