Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 98 additions & 106 deletions monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,39 +252,35 @@ def __call__(self, data):
"""
d = dict(data)
start = time.time()
restore_grad_state = torch.is_grad_enabled()
torch.set_grad_enabled(False)

ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
if "nda_croppeds" not in d:
nda_croppeds = [get_foreground_image(nda) for nda in ndas]

# perform calculation
report = deepcopy(self.get_report_format())

report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
report[ImageStatsKeys.CHANNELS] = len(ndas)
report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
report[ImageStatsKeys.SPACING] = (
affine_to_spacing(data[self.image_key].affine).tolist()
if isinstance(data[self.image_key], MetaTensor)
else [1.0] * min(3, data[self.image_key].ndim)
)
with torch.no_grad():
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
nda_croppeds = d["nda_croppeds"] if "nda_croppeds" in d else [get_foreground_image(nda) for nda in ndas]

# perform calculation
report = deepcopy(self.get_report_format())

report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
report[ImageStatsKeys.CHANNELS] = len(ndas)
report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
report[ImageStatsKeys.SPACING] = (
affine_to_spacing(data[self.image_key].affine).tolist()
if isinstance(data[self.image_key], MetaTensor)
else [1.0] * min(3, data[self.image_key].ndim)
)

report[ImageStatsKeys.SIZEMM] = [
a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
]
report[ImageStatsKeys.SIZEMM] = [
a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
]

report[ImageStatsKeys.INTENSITY] = [
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
]
report[ImageStatsKeys.INTENSITY] = [
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
]

if not verify_report_format(report, self.get_report_format()):
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
if not verify_report_format(report, self.get_report_format()):
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")

d[self.stats_name] = report
d[self.stats_name] = report

torch.set_grad_enabled(restore_grad_state)
logger.debug(f"Get image stats spent {time.time() - start}")
return d

Expand Down Expand Up @@ -341,31 +337,28 @@ def __call__(self, data: Mapping) -> dict:

d = dict(data)
start = time.time()
restore_grad_state = torch.is_grad_enabled()
torch.set_grad_enabled(False)
with torch.no_grad():
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
ndas_label = d[self.label_key] # (H,W,D)

ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
ndas_label = d[self.label_key] # (H,W,D)
if ndas_label.shape != ndas[0].shape:
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")

if ndas_label.shape != ndas[0].shape:
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]

nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]

# perform calculation
report = deepcopy(self.get_report_format())
# perform calculation
report = deepcopy(self.get_report_format())

report[ImageStatsKeys.INTENSITY] = [
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds
]
report[ImageStatsKeys.INTENSITY] = [
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds
]

if not verify_report_format(report, self.get_report_format()):
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
if not verify_report_format(report, self.get_report_format()):
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")

d[self.stats_name] = report
d[self.stats_name] = report

torch.set_grad_enabled(restore_grad_state)
logger.debug(f"Get foreground image stats spent {time.time() - start}")
return d

Expand Down Expand Up @@ -470,78 +463,77 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
start = time.time()
image_tensor = d[self.image_key]
label_tensor = d[self.label_key]
using_cuda = any(
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
)
restore_grad_state = torch.is_grad_enabled()
torch.set_grad_enabled(False)

if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
label_tensor, (MetaTensor, torch.Tensor)
):
if label_tensor.device != image_tensor.device:
label_tensor = label_tensor.to(image_tensor.device) # type: ignore
with torch.no_grad():
using_cuda = any(
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda"
for t in (image_tensor, label_tensor)
)

ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)
if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
label_tensor, (MetaTensor, torch.Tensor)
):
if label_tensor.device != image_tensor.device:
label_tensor = label_tensor.to(image_tensor.device) # type: ignore

if ndas_label.shape != ndas[0].shape:
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)

nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
if ndas_label.shape != ndas[0].shape:
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")

unique_label = unique(ndas_label)
if isinstance(ndas_label, (MetaTensor, torch.Tensor)):
unique_label = unique_label.data.cpu().numpy() # type: ignore[assignment]
nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]

unique_label = unique_label.astype(np.int16).tolist()
unique_label = unique(ndas_label)
if isinstance(ndas_label, (MetaTensor, torch.Tensor)):
unique_label = unique_label.data.cpu().numpy() # type: ignore[assignment]

label_substats = [] # each element is one label
pixel_sum = 0
pixel_arr = []
for index in unique_label:
start_label = time.time()
label_dict: dict[str, Any] = {}
mask_index = ndas_label == index
unique_label = unique_label.astype(np.int16).tolist()

nda_masks = [nda[mask_index] for nda in ndas]
label_dict[LabelStatsKeys.IMAGE_INTST] = [
self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks
]
label_substats = [] # each element is one label
pixel_sum = 0
pixel_arr = []
for index in unique_label:
start_label = time.time()
label_dict: dict[str, Any] = {}
mask_index = ndas_label == index

pixel_count = sum(mask_index)
pixel_arr.append(pixel_count)
pixel_sum += pixel_count
if self.do_ccp: # apply connected component
if using_cuda:
# The back end of get_label_ccp is CuPy
# which is unable to automatically release CUDA GPU memory held by PyTorch
del nda_masks
torch.cuda.empty_cache()
shape_list, ncomponents = get_label_ccp(mask_index)
label_dict[LabelStatsKeys.LABEL_SHAPE] = shape_list
label_dict[LabelStatsKeys.LABEL_NCOMP] = ncomponents

label_substats.append(label_dict)
logger.debug(f" label {index} stats takes {time.time() - start_label}")

for i, _ in enumerate(unique_label):
label_substats[i].update({LabelStatsKeys.PIXEL_PCT: float(pixel_arr[i] / pixel_sum)})
nda_masks = [nda[mask_index] for nda in ndas]
label_dict[LabelStatsKeys.IMAGE_INTST] = [
self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks
]

report = deepcopy(self.get_report_format())
report[LabelStatsKeys.LABEL_UID] = unique_label
report[LabelStatsKeys.IMAGE_INTST] = [
self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds
]
report[LabelStatsKeys.LABEL] = label_substats
pixel_count = sum(mask_index)
pixel_arr.append(pixel_count)
pixel_sum += pixel_count
if self.do_ccp: # apply connected component
if using_cuda:
# The back end of get_label_ccp is CuPy
# which is unable to automatically release CUDA GPU memory held by PyTorch
del nda_masks
torch.cuda.empty_cache()
shape_list, ncomponents = get_label_ccp(mask_index)
label_dict[LabelStatsKeys.LABEL_SHAPE] = shape_list
label_dict[LabelStatsKeys.LABEL_NCOMP] = ncomponents

label_substats.append(label_dict)
logger.debug(f" label {index} stats takes {time.time() - start_label}")

for i, _ in enumerate(unique_label):
label_substats[i].update({LabelStatsKeys.PIXEL_PCT: float(pixel_arr[i] / pixel_sum)})

report = deepcopy(self.get_report_format())
report[LabelStatsKeys.LABEL_UID] = unique_label
report[LabelStatsKeys.IMAGE_INTST] = [
self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds
]
report[LabelStatsKeys.LABEL] = label_substats

if not verify_report_format(report, self.get_report_format()):
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
if not verify_report_format(report, self.get_report_format()):
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")

d[self.stats_name] = report # type: ignore[assignment]
d[self.stats_name] = report # type: ignore[assignment]

torch.set_grad_enabled(restore_grad_state)
logger.debug(f"Get label stats spent {time.time() - start}")
return d # type: ignore[return-value]

Expand Down
47 changes: 46 additions & 1 deletion tests/apps/test_auto3dseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
SqueezeDimd,
ToDeviced,
)
from monai.utils.enums import DataStatsKeys, LabelStatsKeys
from monai.utils.enums import DataStatsKeys, ImageStatsKeys, LabelStatsKeys
from tests.test_utils import skip_if_no_cuda

device = "cpu"
Expand Down Expand Up @@ -322,6 +322,18 @@ def test_image_stats_case_analyzer(self):
report_format = analyzer.get_report_format()
assert verify_report_format(d["image_stats"], report_format)

def test_image_stats_uses_precomputed_nda_croppeds(self):
analyzer = ImageStats(image_key="image")
image = torch.arange(64.0, dtype=torch.float32).reshape(1, 4, 4, 4)
nda_croppeds = [torch.ones((2, 2, 2), dtype=torch.float32)]

result = analyzer({"image": image, "nda_croppeds": nda_croppeds})
report = result["image_stats"]

assert verify_report_format(report, analyzer.get_report_format())
assert report[ImageStatsKeys.CROPPED_SHAPE] == [[2, 2, 2]]
self.assertAlmostEqual(report[ImageStatsKeys.INTENSITY][0]["mean"], 1.0)

def test_foreground_image_stats_cases_analyzer(self):
analyzer = FgImageStats(image_key="image", label_key="label")
transform_list = [
Expand Down Expand Up @@ -411,6 +423,39 @@ def test_label_stats_mixed_device_analyzer(self, input_params):
self.assertAlmostEqual(foreground_stats[0]["mean"], 4.75)
self.assertAlmostEqual(foreground_stats[1]["mean"], 14.75)

def test_case_analyzers_restore_grad_state_on_exception(self):
cases = [
(
"image_stats",
ImageStats(image_key="image"),
{"image": torch.randn(1, 4, 4, 4), "nda_croppeds": [None]},
AttributeError,
),
(
"fg_image_stats",
FgImageStats(image_key="image", label_key="label"),
{"image": torch.randn(1, 4, 4, 4), "label": torch.ones(3, 4, 4)},
ValueError,
),
(
"label_stats",
LabelStats(image_key="image", label_key="label"),
{"image": MetaTensor(torch.randn(1, 4, 4, 4)), "label": MetaTensor(torch.ones(3, 4, 4))},
ValueError,
),
]

original_grad_state = torch.is_grad_enabled()
try:
for name, analyzer, data, error in cases:
with self.subTest(analyzer=name):
torch.set_grad_enabled(True)
with self.assertRaises(error):
analyzer(data)
self.assertTrue(torch.is_grad_enabled())
finally:
torch.set_grad_enabled(original_grad_state)

def test_filename_case_analyzer(self):
analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH)
analyzer_label = FilenameStats("label", DataStatsKeys.BY_CASE_IMAGE_PATH)
Expand Down
Loading