Skip to content

Commit c23d036

Browse files
committed
Fix dtype wrapper conversion handling
1 parent d8982c7 commit c23d036

1 file changed

Lines changed: 36 additions & 28 deletions

File tree

src/arraybridge/decorators.py

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -172,39 +172,47 @@ def dtype_wrapper(image, *args, dtype_conversion=None, slice_by_slice: bool = Fa
172172
if dtype_conversion is None:
173173
dtype_conversion = DtypeConversion.PRESERVE_INPUT
174174

175-
try:
176-
# Store original dtype
177-
original_dtype = image.dtype
178-
179-
# Handle slice_by_slice processing for 3D arrays
180-
if slice_by_slice and hasattr(image, "ndim") and image.ndim == 3:
181-
result = process_slices(image, func, args, kwargs)
182-
else:
183-
# Call the original function normally
184-
result = func(image, *args, **kwargs)
185-
186-
# Apply dtype conversion based on enum value
187-
if hasattr(result, "dtype") and dtype_conversion is not None:
188-
if dtype_conversion == DtypeConversion.PRESERVE_INPUT:
189-
# Preserve input dtype
190-
if result.dtype != original_dtype:
191-
result = scale_func(result, original_dtype)
192-
elif dtype_conversion == DtypeConversion.NATIVE_OUTPUT:
193-
# Return framework's native output dtype
194-
pass # No conversion needed
195-
else:
196-
# Force specific dtype
197-
target_dtype = dtype_conversion.numpy_dtype
198-
if target_dtype is not None:
199-
result = scale_func(result, target_dtype)
175+
# Store original dtype
176+
original_dtype = getattr(image, "dtype", None)
177+
178+
# Handle slice_by_slice processing for 3D arrays
179+
if slice_by_slice and hasattr(image, "ndim") and image.ndim == 3:
180+
result = process_slices(image, func, args, kwargs)
181+
else:
182+
# Call the original function normally
183+
result = func(image, *args, **kwargs)
184+
185+
def _apply_dtype_conversion(array):
186+
if dtype_conversion is None or not hasattr(array, "dtype"):
187+
return array
188+
if dtype_conversion == DtypeConversion.PRESERVE_INPUT:
189+
# Preserve input dtype
190+
if original_dtype is not None and array.dtype != original_dtype:
191+
return scale_func(array, original_dtype)
192+
return array
193+
if dtype_conversion == DtypeConversion.NATIVE_OUTPUT:
194+
# Return framework's native output dtype
195+
return array
196+
# Force specific dtype
197+
target_dtype = dtype_conversion.numpy_dtype
198+
if target_dtype is not None:
199+
return scale_func(array, target_dtype)
200+
return array
200201

201-
return result
202+
try:
203+
# Apply dtype conversion to the main output
204+
if isinstance(result, tuple):
205+
if not result:
206+
return result
207+
converted_main = _apply_dtype_conversion(result[0])
208+
return (converted_main, *result[1:])
209+
return _apply_dtype_conversion(result)
202210
except Exception as e:
203211
logger.error(
204212
f"Error in {mem_type.value} dtype/slice preserving wrapper " f"for {func_name}: {e}"
205213
)
206-
# Return original result on error
207-
return func(image, *args, **kwargs)
214+
# Return unmodified result on conversion errors
215+
return result
208216

209217
# Update function signature to include new parameters
210218
try:

0 commit comments

Comments
 (0)