Skip to content

Commit 4fa1bbc

Browse files
committed
- Fix CC Anterior to Posterior Order
- Pass the correct segmentation file into fastsurfer_cc - Make reduce_to_aseg in reduce_to_aseg.py NOT be silently an inplace operation - Clean up typing in CorpusCallosum/paint_into_pred.py
1 parent 55ac2b4 commit 4fa1bbc

4 files changed

Lines changed: 59 additions & 66 deletions

File tree

CorpusCallosum/paint_cc_into_pred.py

Lines changed: 50 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,20 @@
1919
import sys
2020
from functools import partial
2121
from pathlib import Path
22-
from typing import TypeVar, cast
2322

2423
import nibabel as nib
2524
import numpy as np
26-
from numpy import typing as npt
2725
from scipy import ndimage
2826

29-
import FastSurferCNN.utils.logging as logging
3027
from CorpusCallosum.data.constants import FORNIX_LABEL, SUBSEGMENT_LABELS
3128
from FastSurferCNN.data_loader.conform import is_conform
29+
from FastSurferCNN.data_loader.data_utils import load_image
3230
from FastSurferCNN.reduce_to_aseg import reduce_to_aseg_and_save
31+
from FastSurferCNN.utils import Mask2d, Mask3d, Shape3d, logging
3332
from FastSurferCNN.utils.arg_types import path_or_none
3433
from FastSurferCNN.utils.brainvolstats import mask_in_array
3534
from FastSurferCNN.utils.parallel import thread_executor
3635

37-
_T = TypeVar("_T", bound=np.number)
38-
3936
logger = logging.get_logger(__name__)
4037

4138
HELPTEXT = """
@@ -55,7 +52,8 @@
5552
5653
Original Author: Leonie Henschel
5754
Date: Jul-10-2020
58-
55+
Modified by: Clemens Pollak, David Kügler
56+
Date: Dec-2025
5957
"""
6058

6159

@@ -110,26 +108,23 @@ def make_parser() -> argparse.ArgumentParser:
110108
return parser
111109

112110

113-
def paint_in_cc(pred: npt.NDArray[np.int_],
114-
aseg_cc: npt.NDArray[np.int_]) -> npt.NDArray[np.int_]:
111+
def paint_in_cc(
112+
pred: np.ndarray[Shape3d, np.dtype[int]],
113+
aseg_cc: np.ndarray[Shape3d, np.dtype[int]],
114+
) -> np.ndarray[Shape3d, np.dtype[int]]:
115115
"""Paint corpus callosum segmentation into aseg+dkt segmentation map.
116116
117117
Parameters
118118
----------
119-
pred : npt.NDArray[np.int_]
119+
pred : np.ndarray
120120
Deep-learning segmentation map.
121-
aseg_cc : npt.NDArray[np.int_]
121+
aseg_cc : np.ndarray
122122
Aseg segmentation with CC.
123123
124124
Returns
125125
-------
126-
npt.NDArray[np.int_]
126+
np.ndarray
127127
Segmentation map with added CC.
128-
129-
Notes
130-
-----
131-
This function modifies the original array and does not create a copy.
132-
The CC labels (251-255) from aseg_cc are copied into pred.
133128
"""
134129
cc_mask = mask_in_array(aseg_cc, SUBSEGMENT_LABELS)
135130

@@ -142,14 +137,14 @@ def paint_in_cc(pred: npt.NDArray[np.int_],
142137
logger.info(f"Painting CC: {np.sum(cc_mask)} voxels (replacing {num_wm_replaced} WM, "
143138
f"{num_background_replaced} background, {num_other_replaced} other)")
144139

145-
pred[cc_mask] = aseg_cc[cc_mask]
146-
return pred
147-
148-
def _fill_gaps_in_direction(
149-
corrected_pred: npt.NDArray[np.int_],
150-
potential_fill: npt.NDArray[np.bool_],
151-
source_binary: npt.NDArray[np.bool_],
152-
target_binary: npt.NDArray[np.bool_],
140+
out = np.where(cc_mask, aseg_cc, pred)
141+
return out
142+
143+
def _fill_gaps_in_direction_(
144+
corrected_pred: np.ndarray[Shape3d, np.dtype[int]],
145+
potential_fill: Mask2d,
146+
source_binary: Mask2d,
147+
target_binary: Mask2d,
153148
x_slice: int,
154149
direction: str,
155150
max_gap_voxels: int,
@@ -159,13 +154,13 @@ def _fill_gaps_in_direction(
159154
160155
Parameters
161156
----------
162-
corrected_pred : npt.NDArray[np.int_]
157+
corrected_pred : np.ndarray
163158
The segmentation array to modify in place.
164-
potential_fill : npt.NDArray[np.bool_]
159+
potential_fill : np.ndarray
165160
2D mask of potential fill regions for this slice.
166-
source_binary : npt.NDArray[np.bool_]
161+
source_binary : np.ndarray
167162
2D binary mask of source structure (e.g., CC).
168-
target_binary : npt.NDArray[np.bool_]
163+
target_binary : np.ndarray
169164
2D binary mask of target structure (e.g., ventricle).
170165
x_slice : int
171166
The x-coordinate of the current slice.
@@ -254,10 +249,10 @@ def _fill_gaps_in_direction(
254249
return voxels_filled
255250

256251

257-
def _fill_gaps_between_structures(
258-
corrected_pred: npt.NDArray[np.int_],
259-
source_mask: npt.NDArray[np.bool_],
260-
target_mask: npt.NDArray[np.bool_],
252+
def _fill_gaps_between_structures_(
253+
corrected_pred: np.ndarray[Shape3d, np.dtype[int]],
254+
source_mask: Mask3d,
255+
target_mask: Mask3d,
261256
voxel_size: tuple[float, float, float],
262257
close_gap_size_mm: float,
263258
fillable_labels: set[int],
@@ -267,11 +262,11 @@ def _fill_gaps_between_structures(
267262
268263
Parameters
269264
----------
270-
corrected_pred : npt.NDArray[np.int_]
265+
corrected_pred : np.ndarray
271266
The segmentation array to modify in place.
272-
source_mask : npt.NDArray[np.bool_]
267+
source_mask : np.ndarray
273268
3D binary mask of source structure (e.g., CC).
274-
target_mask : npt.NDArray[np.bool_]
269+
target_mask : np.ndarray
275270
3D binary mask of target structure (e.g., ventricle or background).
276271
voxel_size : tuple[float, float, float]
277272
Voxel size in mm.
@@ -315,13 +310,13 @@ def _fill_gaps_between_structures(
315310
potential_fill = (source_dilated & target_dilated) & ~(source_binary | target_binary)
316311

317312
# Fill gaps in inferior-superior direction
318-
voxels_filled += _fill_gaps_in_direction(
313+
voxels_filled += _fill_gaps_in_direction_(
319314
corrected_pred, potential_fill, source_binary, target_binary,
320315
x, 'inferior-superior', max_gap_vox_inferior_superior, fillable_labels
321316
)
322317

323318
# Fill gaps in anterior-posterior direction
324-
voxels_filled += _fill_gaps_in_direction(
319+
voxels_filled += _fill_gaps_in_direction_(
325320
corrected_pred, potential_fill, source_binary, target_binary,
326321
x, 'anterior-posterior', max_gap_vox_anterior_posterior, fillable_labels
327322
)
@@ -333,11 +328,11 @@ def _fill_gaps_between_structures(
333328

334329

335330
def correct_wm_ventricles(
336-
aseg_cc: npt.NDArray[np.int_],
337-
fornix_mask: npt.NDArray[np.bool_],
331+
aseg_cc: np.ndarray[Shape3d, np.dtype[int]],
332+
fornix_mask: Mask3d,
338333
voxel_size: tuple[float, float, float],
339334
close_gap_size_mm: float = 3.0
340-
) -> npt.NDArray[np.int_]:
335+
) -> np.ndarray[Shape3d, np.dtype[int]]:
341336
"""Fill small gaps between corpus callosum, ventricles, and background.
342337
343338
This function performs two gap-filling operations:
@@ -349,9 +344,9 @@ def correct_wm_ventricles(
349344
350345
Parameters
351346
----------
352-
aseg_cc : npt.NDArray[np.int_]
347+
aseg_cc : np.ndarray
353348
Aseg segmentation with CC already painted in.
354-
fornix_mask : npt.NDArray[np.bool_]
349+
fornix_mask : np.ndarray
355350
Mask of the fornix. Not currently used (kept for interface compatibility).
356351
voxel_size : tuple[float, float, float]
357352
Voxel size of the aseg image in mm.
@@ -360,7 +355,7 @@ def correct_wm_ventricles(
360355
361356
Returns
362357
-------
363-
npt.NDArray[np.int_]
358+
np.ndarray
364359
Corrected segmentation map with filled gaps.
365360
"""
366361
# Create a copy to avoid modifying the original
@@ -374,37 +369,38 @@ def correct_wm_ventricles(
374369

375370
# Get background mask
376371
background_mask = aseg_cc == 0
377-
372+
print(np.unique(corrected_pred))
373+
378374
# 1. Fill gaps between CC and ventricles (replace WM and background with ventricle labels)
379-
_fill_gaps_between_structures(
375+
_fill_gaps_between_structures_(
380376
corrected_pred, cc_mask, ventricle_mask, voxel_size, close_gap_size_mm,
381377
fillable_labels={0, 2, 41}, # background and WM
382378
description="between CC and ventricles (WM/background → ventricle)"
383379
)
384-
380+
print(np.unique(corrected_pred))
381+
385382
# 2. Fill WM gaps between CC and background (replace WM with background)
386-
_fill_gaps_between_structures(
383+
_fill_gaps_between_structures_(
387384
corrected_pred, cc_mask, background_mask, voxel_size, close_gap_size_mm,
388385
fillable_labels={2, 41}, # only WM
389386
description="between CC and background (WM → background)"
390387
)
388+
print(np.unique(corrected_pred))
391389

392390
return corrected_pred
393391

394392

395393
if __name__ == "__main__":
396-
from FastSurferCNN.utils import nibabelImage
397394

398395
# Command Line options are error checking done here
399396
options = argument_parse()
400397

401398
logging.setup_logging()
402399

403400
logger.info(f"Reading inputs: {options.input_cc} {options.input_pred}...")
404-
cc_seg_image = cast(nibabelImage, nib.load(options.input_cc))
405-
cc_seg_data = np.asanyarray(cc_seg_image.dataobj)
406-
aseg_image = cast(nibabelImage, nib.load(options.input_pred))
407-
aseg_data = np.asanyarray(aseg_image.dataobj)
401+
402+
tmap = thread_executor().map
403+
(cc_seg_image, cc_seg_data), (aseg_image, aseg_data) = tmap(load_image, (options.input_cc, options.input_pred))
408404

409405
def _is_conform(img, dtype, verbose):
410406
return is_conform(img, vox_size=None, img_size=None, verbose=verbose, dtype=dtype)
@@ -433,8 +429,8 @@ def _is_conform(img, dtype, verbose):
433429
initial_wm = np.sum((aseg_data == 2) | (aseg_data == 41))
434430
initial_ventricles = np.sum((aseg_data == 4) | (aseg_data == 43))
435431

436-
# Paint CC into prediction (modifies aseg_data in place)
437-
paint_in_cc(aseg_data, cc_seg_data)
432+
# Paint CC into prediction
433+
aseg_data = paint_in_cc(aseg_data, cc_seg_data)
438434

439435
# Apply ventricle gap filling corrections
440436
fornix_mask = cc_seg_data == FORNIX_LABEL

CorpusCallosum/shape/postprocessing.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -592,14 +592,13 @@ def make_subdivision_mask(
592592
# Use only as many labels as needed based on the number of subdivisions
593593
# Number of regions = number of division lines + 1
594594
num_labels_needed = len(subdivision_lines) + 1
595-
cc_labels_posterior_to_anterior = SUBSEGMENT_LABELS[:num_labels_needed]
595+
cc_labels_anterior_to_posterior = SUBSEGMENT_LABELS[:num_labels_needed][::-1]
596596

597597
# Initialize with first segment label
598-
subdivision_mask = np.full(slice_shape, cc_labels_posterior_to_anterior[0], dtype=np.int32)
599-
598+
subdivision_mask = np.full(slice_shape, cc_labels_anterior_to_posterior[0], dtype=np.int32)
600599
# Process each subdivision line, subdivision_lines has for each division line the two points that are on the
601600
# contour and divide the subsegments
602-
for label, segment_points in zip(cc_labels_posterior_to_anterior[1:], subdivision_lines, strict=True):
601+
for label, segment_points in zip(cc_labels_anterior_to_posterior[1:], subdivision_lines, strict=True):
603602
# line_start and line_end are the intersection points of the CC subsegmentation boundary and the contour line
604603
line_start, line_end = segment_points
605604

@@ -617,14 +616,14 @@ def make_subdivision_mask(
617616
from FastSurferCNN.utils.plotting import backend
618617
with backend("qtagg"):
619618
plt.figure(figsize=(10, 8))
620-
plt.imshow(subdivision_mask, cmap='tab10')
619+
plkwargs = {f"v{op}": getattr(np, op)(cc_labels_anterior_to_posterior) for op in ("min", "max")}
620+
plt.imshow(subdivision_mask, cmap='tab10', **plkwargs)
621621
plt.colorbar(label='Subdivision')
622622
plt.title('CC Subdivision Mask')
623623
plt.xlabel('X')
624624
plt.ylabel('Y')
625625
plt.tight_layout()
626626
plt.show()
627-
628627
return subdivision_mask
629628

630629

FastSurferCNN/reduce_to_aseg.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,9 @@ def reduce_to_aseg(data_inseg: np.ndarray[ShapeType, _TDType]) -> np.ndarray[Sha
119119
The reduced segmentation.
120120
"""
121121
LOGGER.info("Reducing to aseg ...")
122-
# replace 2000... with 42
123-
data_inseg[data_inseg >= 2000] = 42
124-
# replace 1000... with 3
125-
data_inseg[data_inseg >= 1000] = 3
126-
return data_inseg
122+
cortical_fill = np.full_like(data_inseg, 3)
123+
cortical_fill[data_inseg >= 2000] = 42
124+
return np.where(data_inseg >= 1000, cortical_fill, data_inseg)
127125

128126

129127
def create_mask(aseg_data: np.ndarray[ShapeType, _TDType], dnum: int, enum: int) \

run_fastsurfer.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1189,7 +1189,7 @@ then
11891189
callosum_seg_manedit="$(add_file_suffix "$callosum_seg" "manedit")"
11901190
# generate callosum segmentation, mesh, shape and downstream measure files
11911191
cmd=($python "$CorpusCallosumDir/fastsurfer_cc.py" --sd "$sd" --sid "$subject"
1192-
"--threads" "$threads_seg" "--conformed_name" "$conformed_name" "--aseg_name" "$asegdkt_segfile"
1192+
"--threads" "$threads_seg" "--conformed_name" "$conformed_name" "--aseg_name" "$aseg_segfile"
11931193
"--segmentation_in_orig" "$callosum_seg" "${cc_flags[@]}")
11941194
echo_quoted "${cmd[@]}" | tee -a "$seg_log"
11951195
"${wrap[@]}" "${cmd[@]}" 2>&1 | tee -a "$seg_log"

0 commit comments

Comments
 (0)