1919import sys
2020from functools import partial
2121from pathlib import Path
22- from typing import TypeVar , cast
2322
2423import nibabel as nib
2524import numpy as np
26- from numpy import typing as npt
2725from scipy import ndimage
2826
29- import FastSurferCNN .utils .logging as logging
3027from CorpusCallosum .data .constants import FORNIX_LABEL , SUBSEGMENT_LABELS
3128from FastSurferCNN .data_loader .conform import is_conform
29+ from FastSurferCNN .data_loader .data_utils import load_image
3230from FastSurferCNN .reduce_to_aseg import reduce_to_aseg_and_save
31+ from FastSurferCNN .utils import Mask2d , Mask3d , Shape3d , logging
3332from FastSurferCNN .utils .arg_types import path_or_none
3433from FastSurferCNN .utils .brainvolstats import mask_in_array
3534from FastSurferCNN .utils .parallel import thread_executor
3635
37- _T = TypeVar ("_T" , bound = np .number )
38-
3936logger = logging .get_logger (__name__ )
4037
4138HELPTEXT = """
5552
5653Original Author: Leonie Henschel
5754Date: Jul-10-2020
58-
55+ Modified by: Clemens Pollak, David Kügler
56+ Date: Dec-2025
5957"""
6058
6159
@@ -110,26 +108,23 @@ def make_parser() -> argparse.ArgumentParser:
110108 return parser
111109
112110
113- def paint_in_cc (pred : npt .NDArray [np .int_ ],
114- aseg_cc : npt .NDArray [np .int_ ]) -> npt .NDArray [np .int_ ]:
111+ def paint_in_cc (
112+ pred : np .ndarray [Shape3d , np .dtype [int ]],
113+ aseg_cc : np .ndarray [Shape3d , np .dtype [int ]],
114+ ) -> np .ndarray [Shape3d , np .dtype [int ]]:
115115 """Paint corpus callosum segmentation into aseg+dkt segmentation map.
116116
117117 Parameters
118118 ----------
119- pred : npt.NDArray[ np.int_]
119+ pred : np.ndarray
120120 Deep-learning segmentation map.
121- aseg_cc : npt.NDArray[ np.int_]
121+ aseg_cc : np.ndarray
122122 Aseg segmentation with CC.
123123
124124 Returns
125125 -------
126- npt.NDArray[ np.int_]
126+ np.ndarray
127127 Segmentation map with added CC.
128-
129- Notes
130- -----
131- This function modifies the original array and does not create a copy.
132- The CC labels (251-255) from aseg_cc are copied into pred.
133128 """
134129 cc_mask = mask_in_array (aseg_cc , SUBSEGMENT_LABELS )
135130
@@ -142,14 +137,14 @@ def paint_in_cc(pred: npt.NDArray[np.int_],
142137 logger .info (f"Painting CC: { np .sum (cc_mask )} voxels (replacing { num_wm_replaced } WM, "
143138 f"{ num_background_replaced } background, { num_other_replaced } other)" )
144139
145- pred [ cc_mask ] = aseg_cc [ cc_mask ]
146- return pred
147-
148- def _fill_gaps_in_direction (
149- corrected_pred : npt . NDArray [ np .int_ ],
150- potential_fill : npt . NDArray [ np . bool_ ] ,
151- source_binary : npt . NDArray [ np . bool_ ] ,
152- target_binary : npt . NDArray [ np . bool_ ] ,
140+ out = np . where ( cc_mask , aseg_cc , pred )
141+ return out
142+
143+ def _fill_gaps_in_direction_ (
144+ corrected_pred : np . ndarray [ Shape3d , np .dtype [ int ] ],
145+ potential_fill : Mask2d ,
146+ source_binary : Mask2d ,
147+ target_binary : Mask2d ,
153148 x_slice : int ,
154149 direction : str ,
155150 max_gap_voxels : int ,
@@ -159,13 +154,13 @@ def _fill_gaps_in_direction(
159154
160155 Parameters
161156 ----------
162- corrected_pred : npt.NDArray[ np.int_]
157+ corrected_pred : np.ndarray
163158 The segmentation array to modify in place.
164- potential_fill : npt.NDArray[ np.bool_]
159+ potential_fill : np.ndarray
165160 2D mask of potential fill regions for this slice.
166- source_binary : npt.NDArray[ np.bool_]
161+ source_binary : np.ndarray
167162 2D binary mask of source structure (e.g., CC).
168- target_binary : npt.NDArray[ np.bool_]
163+ target_binary : np.ndarray
169164 2D binary mask of target structure (e.g., ventricle).
170165 x_slice : int
171166 The x-coordinate of the current slice.
@@ -254,10 +249,10 @@ def _fill_gaps_in_direction(
254249 return voxels_filled
255250
256251
257- def _fill_gaps_between_structures (
258- corrected_pred : npt . NDArray [ np .int_ ],
259- source_mask : npt . NDArray [ np . bool_ ] ,
260- target_mask : npt . NDArray [ np . bool_ ] ,
252+ def _fill_gaps_between_structures_ (
253+ corrected_pred : np . ndarray [ Shape3d , np .dtype [ int ] ],
254+ source_mask : Mask3d ,
255+ target_mask : Mask3d ,
261256 voxel_size : tuple [float , float , float ],
262257 close_gap_size_mm : float ,
263258 fillable_labels : set [int ],
@@ -267,11 +262,11 @@ def _fill_gaps_between_structures(
267262
268263 Parameters
269264 ----------
270- corrected_pred : npt.NDArray[ np.int_]
265+ corrected_pred : np.ndarray
271266 The segmentation array to modify in place.
272- source_mask : npt.NDArray[ np.bool_]
267+ source_mask : np.ndarray
273268 3D binary mask of source structure (e.g., CC).
274- target_mask : npt.NDArray[ np.bool_]
269+ target_mask : np.ndarray
275270 3D binary mask of target structure (e.g., ventricle or background).
276271 voxel_size : tuple[float, float, float]
277272 Voxel size in mm.
@@ -315,13 +310,13 @@ def _fill_gaps_between_structures(
315310 potential_fill = (source_dilated & target_dilated ) & ~ (source_binary | target_binary )
316311
317312 # Fill gaps in inferior-superior direction
318- voxels_filled += _fill_gaps_in_direction (
313+ voxels_filled += _fill_gaps_in_direction_ (
319314 corrected_pred , potential_fill , source_binary , target_binary ,
320315 x , 'inferior-superior' , max_gap_vox_inferior_superior , fillable_labels
321316 )
322317
323318 # Fill gaps in anterior-posterior direction
324- voxels_filled += _fill_gaps_in_direction (
319+ voxels_filled += _fill_gaps_in_direction_ (
325320 corrected_pred , potential_fill , source_binary , target_binary ,
326321 x , 'anterior-posterior' , max_gap_vox_anterior_posterior , fillable_labels
327322 )
@@ -333,11 +328,11 @@ def _fill_gaps_between_structures(
333328
334329
335330def correct_wm_ventricles (
336- aseg_cc : npt . NDArray [ np .int_ ],
337- fornix_mask : npt . NDArray [ np . bool_ ] ,
331+ aseg_cc : np . ndarray [ Shape3d , np .dtype [ int ] ],
332+ fornix_mask : Mask3d ,
338333 voxel_size : tuple [float , float , float ],
339334 close_gap_size_mm : float = 3.0
340- ) -> npt . NDArray [ np .int_ ]:
335+ ) -> np . ndarray [ Shape3d , np .dtype [ int ] ]:
341336 """Fill small gaps between corpus callosum, ventricles, and background.
342337
343338 This function performs two gap-filling operations:
@@ -349,9 +344,9 @@ def correct_wm_ventricles(
349344
350345 Parameters
351346 ----------
352- aseg_cc : npt.NDArray[ np.int_]
347+ aseg_cc : np.ndarray
353348 Aseg segmentation with CC already painted in.
354- fornix_mask : npt.NDArray[ np.bool_]
349+ fornix_mask : np.ndarray
355350 Mask of the fornix. Not currently used (kept for interface compatibility).
356351 voxel_size : tuple[float, float, float]
357352 Voxel size of the aseg image in mm.
@@ -360,7 +355,7 @@ def correct_wm_ventricles(
360355
361356 Returns
362357 -------
363- npt.NDArray[ np.int_]
358+ np.ndarray
364359 Corrected segmentation map with filled gaps.
365360 """
366361 # Create a copy to avoid modifying the original
@@ -374,37 +369,38 @@ def correct_wm_ventricles(
374369
375370 # Get background mask
376371 background_mask = aseg_cc == 0
377-
372+ print (np .unique (corrected_pred ))
373+
378374 # 1. Fill gaps between CC and ventricles (replace WM and background with ventricle labels)
379- _fill_gaps_between_structures (
375+ _fill_gaps_between_structures_ (
380376 corrected_pred , cc_mask , ventricle_mask , voxel_size , close_gap_size_mm ,
381377 fillable_labels = {0 , 2 , 41 }, # background and WM
382378 description = "between CC and ventricles (WM/background → ventricle)"
383379 )
384-
380+ print (np .unique (corrected_pred ))
381+
385382 # 2. Fill WM gaps between CC and background (replace WM with background)
386- _fill_gaps_between_structures (
383+ _fill_gaps_between_structures_ (
387384 corrected_pred , cc_mask , background_mask , voxel_size , close_gap_size_mm ,
388385 fillable_labels = {2 , 41 }, # only WM
389386 description = "between CC and background (WM → background)"
390387 )
388+ print (np .unique (corrected_pred ))
391389
392390 return corrected_pred
393391
394392
395393if __name__ == "__main__" :
396- from FastSurferCNN .utils import nibabelImage
397394
398395 # Command Line options are error checking done here
399396 options = argument_parse ()
400397
401398 logging .setup_logging ()
402399
403400 logger .info (f"Reading inputs: { options .input_cc } { options .input_pred } ..." )
404- cc_seg_image = cast (nibabelImage , nib .load (options .input_cc ))
405- cc_seg_data = np .asanyarray (cc_seg_image .dataobj )
406- aseg_image = cast (nibabelImage , nib .load (options .input_pred ))
407- aseg_data = np .asanyarray (aseg_image .dataobj )
401+
402+ tmap = thread_executor ().map
403+ (cc_seg_image , cc_seg_data ), (aseg_image , aseg_data ) = tmap (load_image , (options .input_cc , options .input_pred ))
408404
409405 def _is_conform (img , dtype , verbose ):
410406 return is_conform (img , vox_size = None , img_size = None , verbose = verbose , dtype = dtype )
@@ -433,8 +429,8 @@ def _is_conform(img, dtype, verbose):
433429 initial_wm = np .sum ((aseg_data == 2 ) | (aseg_data == 41 ))
434430 initial_ventricles = np .sum ((aseg_data == 4 ) | (aseg_data == 43 ))
435431
436- # Paint CC into prediction (modifies aseg_data in place)
437- paint_in_cc (aseg_data , cc_seg_data )
432+ # Paint CC into prediction
433+ aseg_data = paint_in_cc (aseg_data , cc_seg_data )
438434
439435 # Apply ventricle gap filling corrections
440436 fornix_mask = cc_seg_data == FORNIX_LABEL
0 commit comments