diff --git a/modules/dynunet_pipeline/transforms.py b/modules/dynunet_pipeline/transforms.py index a02d82b5c..0c18eb011 100644 --- a/modules/dynunet_pipeline/transforms.py +++ b/modules/dynunet_pipeline/transforms.py @@ -40,9 +40,9 @@ def get_task_transforms(mode, task_id, pos_sample_num, neg_sample_num, num_sampl keys = ["image", "label"] else: keys = ["image"] - + # 1. loading load_transforms = [ - LoadImaged(keys=keys), + LoadImaged(keys=keys, image_only=False), EnsureChannelFirstd(keys=keys), ] # 2. sampling @@ -284,6 +284,8 @@ def __init__( def calculate_new_shape(self, spacing, shape): spacing_ratio = np.array(spacing) / np.array(self.target_spacing) + if len(shape) == 4: # If shape includes channel dimension + shape = shape[1:] new_shape = (spacing_ratio * np.array(shape)).astype(int).tolist() return new_shape