From d9a2bc4b63dedb1687b48fb1133f96548d5d3cfe Mon Sep 17 00:00:00 2001
From: Jun Hyeok Lee
Date: Thu, 2 Apr 2026 23:53:36 +0900
Subject: [PATCH] fix(auto3dseg): handle precomputed crops and safe no-grad
 cleanup

Signed-off-by: Jun Hyeok Lee
---
 monai/auto3dseg/analyzer.py  | 204 +++++++++++++++++------------------
 tests/apps/test_auto3dseg.py |  47 +++++++-
 2 files changed, 144 insertions(+), 107 deletions(-)

diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py
index d48d5fc878..e6a946725e 100644
--- a/monai/auto3dseg/analyzer.py
+++ b/monai/auto3dseg/analyzer.py
@@ -252,39 +252,35 @@ def __call__(self, data):
         """
         d = dict(data)
         start = time.time()
-        restore_grad_state = torch.is_grad_enabled()
-        torch.set_grad_enabled(False)
-
-        ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
-        if "nda_croppeds" not in d:
-            nda_croppeds = [get_foreground_image(nda) for nda in ndas]
-
-        # perform calculation
-        report = deepcopy(self.get_report_format())
-
-        report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
-        report[ImageStatsKeys.CHANNELS] = len(ndas)
-        report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
-        report[ImageStatsKeys.SPACING] = (
-            affine_to_spacing(data[self.image_key].affine).tolist()
-            if isinstance(data[self.image_key], MetaTensor)
-            else [1.0] * min(3, data[self.image_key].ndim)
-        )
+        with torch.no_grad():
+            ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
+            nda_croppeds = d["nda_croppeds"] if "nda_croppeds" in d else [get_foreground_image(nda) for nda in ndas]
+
+            # perform calculation
+            report = deepcopy(self.get_report_format())
+
+            report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
+            report[ImageStatsKeys.CHANNELS] = len(ndas)
+            report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
+            report[ImageStatsKeys.SPACING] = (
+                affine_to_spacing(data[self.image_key].affine).tolist()
+                if isinstance(data[self.image_key], MetaTensor)
+                else [1.0] * min(3, data[self.image_key].ndim)
+            )
 
-        report[ImageStatsKeys.SIZEMM] = [
-            a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
-        ]
+            report[ImageStatsKeys.SIZEMM] = [
+                a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
+            ]
 
-        report[ImageStatsKeys.INTENSITY] = [
-            self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
-        ]
+            report[ImageStatsKeys.INTENSITY] = [
+                self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
+            ]
 
-        if not verify_report_format(report, self.get_report_format()):
-            raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
+            if not verify_report_format(report, self.get_report_format()):
+                raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
 
-        d[self.stats_name] = report
+            d[self.stats_name] = report
 
-        torch.set_grad_enabled(restore_grad_state)
         logger.debug(f"Get image stats spent {time.time() - start}")
         return d
 
@@ -341,31 +337,28 @@ def __call__(self, data: Mapping) -> dict:
         """
         d = dict(data)
         start = time.time()
-        restore_grad_state = torch.is_grad_enabled()
-        torch.set_grad_enabled(False)
+        with torch.no_grad():
+            ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
+            ndas_label = d[self.label_key]  # (H,W,D)
 
-        ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
-        ndas_label = d[self.label_key]  # (H,W,D)
+            if ndas_label.shape != ndas[0].shape:
+                raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
 
-        if ndas_label.shape != ndas[0].shape:
-            raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
+            nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
+            nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
 
-        nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
-        nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
-
-        # perform calculation
-        report = deepcopy(self.get_report_format())
+            # perform calculation
+            report = deepcopy(self.get_report_format())
 
-        report[ImageStatsKeys.INTENSITY] = [
-            self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds
-        ]
+            report[ImageStatsKeys.INTENSITY] = [
+                self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds
+            ]
 
-        if not verify_report_format(report, self.get_report_format()):
-            raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
+            if not verify_report_format(report, self.get_report_format()):
+                raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
 
-        d[self.stats_name] = report
+            d[self.stats_name] = report
 
-        torch.set_grad_enabled(restore_grad_state)
         logger.debug(f"Get foreground image stats spent {time.time() - start}")
         return d
 
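For context, the failure mode the `with torch.no_grad():` rewrite removes: if an analyzer raised anywhere between `torch.set_grad_enabled(False)` and the restore call, the restore never ran and autograd stayed disabled process-wide. A minimal sketch with toy analyzers (not MONAI code) showing the difference:

```python
import torch

def analyze_manual(data):
    # pre-patch pattern: the restore line is skipped if anything below raises
    restore_grad_state = torch.is_grad_enabled()
    torch.set_grad_enabled(False)
    result = data["image"].sum()  # KeyError when "image" is missing
    torch.set_grad_enabled(restore_grad_state)
    return result

def analyze_ctx(data):
    # patched pattern: the context manager restores grad mode on every exit path
    with torch.no_grad():
        return data["image"].sum()

for fn in (analyze_manual, analyze_ctx):
    torch.set_grad_enabled(True)
    try:
        fn({})
    except KeyError:
        pass
    # analyze_manual: False (grad mode leaked); analyze_ctx: True (restored)
    print(fn.__name__, torch.is_grad_enabled())
```
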
ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") - if ndas_label.shape != ndas[0].shape: - raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") + nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas] + nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds] - nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas] - nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds] - - # perform calculation - report = deepcopy(self.get_report_format()) + # perform calculation + report = deepcopy(self.get_report_format()) - report[ImageStatsKeys.INTENSITY] = [ - self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds - ] + report[ImageStatsKeys.INTENSITY] = [ + self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds + ] - if not verify_report_format(report, self.get_report_format()): - raise RuntimeError(f"report generated by {self.__class__} differs from the report format.") + if not verify_report_format(report, self.get_report_format()): + raise RuntimeError(f"report generated by {self.__class__} differs from the report format.") - d[self.stats_name] = report + d[self.stats_name] = report - torch.set_grad_enabled(restore_grad_state) logger.debug(f"Get foreground image stats spent {time.time() - start}") return d @@ -470,78 +463,77 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe start = time.time() image_tensor = d[self.image_key] label_tensor = d[self.label_key] - using_cuda = any( - isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor) - ) - restore_grad_state = torch.is_grad_enabled() - torch.set_grad_enabled(False) - - if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance( - label_tensor, (MetaTensor, torch.Tensor) - ): - if label_tensor.device != image_tensor.device: - label_tensor = label_tensor.to(image_tensor.device) # type: ignore + with torch.no_grad(): + using_cuda = any( + isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" + for t in (image_tensor, label_tensor) + ) - ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore - ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D) + if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance( + label_tensor, (MetaTensor, torch.Tensor) + ): + if label_tensor.device != image_tensor.device: + label_tensor = label_tensor.to(image_tensor.device) # type: ignore - if ndas_label.shape != ndas[0].shape: - raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") + ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore + ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D) - nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas] - nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds] + if ndas_label.shape != ndas[0].shape: + raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}") - unique_label = unique(ndas_label) - if isinstance(ndas_label, (MetaTensor, torch.Tensor)): - unique_label = unique_label.data.cpu().numpy() # type: ignore[assignment] + nda_foregrounds: list[torch.Tensor] = 
diff --git a/tests/apps/test_auto3dseg.py b/tests/apps/test_auto3dseg.py
index 2159265873..60c037d8ef 100644
--- a/tests/apps/test_auto3dseg.py
+++ b/tests/apps/test_auto3dseg.py
@@ -53,7 +53,7 @@
     SqueezeDimd,
     ToDeviced,
 )
-from monai.utils.enums import DataStatsKeys, LabelStatsKeys
+from monai.utils.enums import DataStatsKeys, ImageStatsKeys, LabelStatsKeys
 from tests.test_utils import skip_if_no_cuda
 
 device = "cpu"
@@ -322,6 +322,18 @@ def test_image_stats_case_analyzer(self):
         report_format = analyzer.get_report_format()
         assert verify_report_format(d["image_stats"], report_format)
 
+    def test_image_stats_uses_precomputed_nda_croppeds(self):
+        analyzer = ImageStats(image_key="image")
+        image = torch.arange(64.0, dtype=torch.float32).reshape(1, 4, 4, 4)
+        nda_croppeds = [torch.ones((2, 2, 2), dtype=torch.float32)]
+
+        result = analyzer({"image": image, "nda_croppeds": nda_croppeds})
+        report = result["image_stats"]
+
+        assert verify_report_format(report, analyzer.get_report_format())
+        assert report[ImageStatsKeys.CROPPED_SHAPE] == [[2, 2, 2]]
+        self.assertAlmostEqual(report[ImageStatsKeys.INTENSITY][0]["mean"], 1.0)
+
     def test_foreground_image_stats_cases_analyzer(self):
         analyzer = FgImageStats(image_key="image", label_key="label")
         transform_list = [
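The test above pins down the new caller-side contract: when the data dictionary already carries a `"nda_croppeds"` entry, `ImageStats` reuses it instead of calling `get_foreground_image` again. A hedged sketch of the pattern this enables, computing the crops once and sharing them (the pipeline wiring is illustrative; only the key name comes from this patch):

```python
import torch
from monai.auto3dseg.analyzer import ImageStats
from monai.auto3dseg.utils import get_foreground_image
from monai.utils.enums import ImageStatsKeys

# Crop the foreground once per channel, then pass the crops alongside the image
# so the analyzer skips its own get_foreground_image call.
image = torch.rand(1, 32, 32, 32)
crops = [get_foreground_image(image[i]) for i in range(image.shape[0])]

analyzer = ImageStats(image_key="image")
report = analyzer({"image": image, "nda_croppeds": crops})["image_stats"]
print(report[ImageStatsKeys.CROPPED_SHAPE])  # shapes of the precomputed crops
```
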
@@ -411,6 +423,39 @@ def test_label_stats_mixed_device_analyzer(self, input_params):
         self.assertAlmostEqual(foreground_stats[0]["mean"], 4.75)
         self.assertAlmostEqual(foreground_stats[1]["mean"], 14.75)
 
+    def test_case_analyzers_restore_grad_state_on_exception(self):
+        cases = [
+            (
+                "image_stats",
+                ImageStats(image_key="image"),
+                {"image": torch.randn(1, 4, 4, 4), "nda_croppeds": [None]},
+                AttributeError,
+            ),
+            (
+                "fg_image_stats",
+                FgImageStats(image_key="image", label_key="label"),
+                {"image": torch.randn(1, 4, 4, 4), "label": torch.ones(3, 4, 4)},
+                ValueError,
+            ),
+            (
+                "label_stats",
+                LabelStats(image_key="image", label_key="label"),
+                {"image": MetaTensor(torch.randn(1, 4, 4, 4)), "label": MetaTensor(torch.ones(3, 4, 4))},
+                ValueError,
+            ),
+        ]
+
+        original_grad_state = torch.is_grad_enabled()
+        try:
+            for name, analyzer, data, error in cases:
+                with self.subTest(analyzer=name):
+                    torch.set_grad_enabled(True)
+                    with self.assertRaises(error):
+                        analyzer(data)
+                    self.assertTrue(torch.is_grad_enabled())
+        finally:
+            torch.set_grad_enabled(original_grad_state)
+
     def test_filename_case_analyzer(self):
         analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH)
         analyzer_label = FilenameStats("label", DataStatsKeys.BY_CASE_LABEL_PATH)
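Outside the unittest harness, the property the last test asserts can be smoke-checked in a few lines against the patched analyzers; the mismatched label shape below is chosen to trigger the `ValueError` inside the `no_grad` block:

```python
import torch
from monai.auto3dseg.analyzer import FgImageStats

# The (3, 4, 4) label cannot match the (4, 4, 4) image channel, so __call__
# raises ValueError inside the with-block; the context manager must still
# restore the global grad mode on the way out.
torch.set_grad_enabled(True)
analyzer = FgImageStats(image_key="image", label_key="label")
try:
    analyzer({"image": torch.randn(1, 4, 4, 4), "label": torch.ones(3, 4, 4)})
except ValueError:
    pass
assert torch.is_grad_enabled(), "autograd state leaked"
```
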