@@ -172,39 +172,47 @@ def dtype_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = Fa
172172 if dtype_conversion is None :
173173 dtype_conversion = DtypeConversion .PRESERVE_INPUT
174174
175- try :
176- # Store original dtype
177- original_dtype = image .dtype
178-
179- # Handle slice_by_slice processing for 3D arrays
180- if slice_by_slice and hasattr (image , "ndim" ) and image .ndim == 3 :
181- result = process_slices (image , func , args , kwargs )
182- else :
183- # Call the original function normally
184- result = func (image , * args , ** kwargs )
185-
186- # Apply dtype conversion based on enum value
187- if hasattr (result , "dtype" ) and dtype_conversion is not None :
188- if dtype_conversion == DtypeConversion .PRESERVE_INPUT :
189- # Preserve input dtype
190- if result .dtype != original_dtype :
191- result = scale_func (result , original_dtype )
192- elif dtype_conversion == DtypeConversion .NATIVE_OUTPUT :
193- # Return framework's native output dtype
194- pass # No conversion needed
195- else :
196- # Force specific dtype
197- target_dtype = dtype_conversion .numpy_dtype
198- if target_dtype is not None :
199- result = scale_func (result , target_dtype )
175+ # Store original dtype
176+ original_dtype = getattr (image , "dtype" , None )
177+
178+ # Handle slice_by_slice processing for 3D arrays
179+ if slice_by_slice and hasattr (image , "ndim" ) and image .ndim == 3 :
180+ result = process_slices (image , func , args , kwargs )
181+ else :
182+ # Call the original function normally
183+ result = func (image , * args , ** kwargs )
184+
185+ def _apply_dtype_conversion (array ):
186+ if dtype_conversion is None or not hasattr (array , "dtype" ):
187+ return array
188+ if dtype_conversion == DtypeConversion .PRESERVE_INPUT :
189+ # Preserve input dtype
190+ if original_dtype is not None and array .dtype != original_dtype :
191+ return scale_func (array , original_dtype )
192+ return array
193+ if dtype_conversion == DtypeConversion .NATIVE_OUTPUT :
194+ # Return framework's native output dtype
195+ return array
196+ # Force specific dtype
197+ target_dtype = dtype_conversion .numpy_dtype
198+ if target_dtype is not None :
199+ return scale_func (array , target_dtype )
200+ return array
200201
201- return result
202+ try :
203+ # Apply dtype conversion to the main output
204+ if isinstance (result , tuple ):
205+ if not result :
206+ return result
207+ converted_main = _apply_dtype_conversion (result [0 ])
208+ return (converted_main , * result [1 :])
209+ return _apply_dtype_conversion (result )
202210 except Exception as e :
203211 logger .error (
204212 f"Error in { mem_type .value } dtype/slice preserving wrapper " f"for { func_name } : { e } "
205213 )
206- # Return original result on error
207- return func ( image , * args , ** kwargs )
214+ # Return unmodified result on conversion errors
215+ return result
208216
209217 # Update function signature to include new parameters
210218 try :
0 commit comments