Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions Rhapso/detection/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,26 @@ def fetch_image_data(self, record, dsxy, dsz):
f"crop_min and crop_max must both be length 3 for 3D cropping; "
f"got crop_min={crop_min}, crop_max={crop_max}"
)

# Validate crop bounds are within array dimensions
array_shape = dask_array.shape
for i in range(3):
if crop_min[i] < 0:
raise ValueError(
f"crop_min[{i}]={crop_min[i]} is negative; "
f"crop bounds must be non-negative"
)
if crop_max[i] >= array_shape[i]:
raise ValueError(
f"crop_max[{i}]={crop_max[i]} exceeds array dimension {i} "
f"(shape={array_shape[i]}); crop_max must be < array shape"
)
if crop_min[i] > crop_max[i]:
raise ValueError(
f"crop_min[{i}]={crop_min[i]} > crop_max[{i}]={crop_max[i]}; "
f"crop_min must be <= crop_max"
)

dask_array = dask_array[
crop_min[0]:crop_max[0] + 1,
crop_min[1]:crop_max[1] + 1,
Expand Down
83 changes: 83 additions & 0 deletions tests/test_detection/test_image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,89 @@ def test_fetch_image_data_tiff_no_crop_error(self):

self.assertEqual(view_id, 'timepoint: 0, setup: 0')

def test_fetch_image_data_crop_bounds_validation(self):
"""Test that crop bounds exceeding array dimensions raise clear error"""
reader = ImageReader(file_type='zarr')

# Record with crop_max exceeding array dimensions
record = {
'view_id': 'timepoint: 0, setup: 0',
'file_path': 's3://bucket/test.zarr/0',
'interval_key': ((0, 0, 0), (50, 50, 25), (51, 51, 26)),
'offset': 0,
'lb': (0, 0, 0),
'crop_min': [0, 0, 0],
'crop_max': [15, 5, 5] # Exceeds dimension 0 (10x10x10 array)
}

# Mock the zarr opening to return a known dask array (10x10x10)
mock_array = da.ones((1, 1, 10, 10, 10), dtype=np.float32)

with patch('zarr.open') as mock_zarr, \
patch('s3fs.S3FileSystem'), \
patch('s3fs.S3Map'), \
patch('dask.array.from_zarr', return_value=mock_array):

# Should raise ValueError with clear message
with self.assertRaises(ValueError) as context:
reader.fetch_image_data(record, dsxy=1, dsz=1)

error_msg = str(context.exception)
self.assertIn('crop_max[0]=15 exceeds array dimension 0', error_msg)
self.assertIn('(shape=10)', error_msg)

def test_fetch_image_data_negative_crop_min(self):
"""Test that negative crop_min values raise clear error"""
reader = ImageReader(file_type='zarr')

record = {
'view_id': 'timepoint: 0, setup: 0',
'file_path': 's3://bucket/test.zarr/0',
'interval_key': ((0, 0, 0), (50, 50, 25), (51, 51, 26)),
'offset': 0,
'lb': (0, 0, 0),
'crop_min': [-1, 0, 0],
'crop_max': [5, 5, 5]
}

mock_array = da.ones((1, 1, 10, 10, 10), dtype=np.float32)

with patch('zarr.open') as mock_zarr, \
patch('s3fs.S3FileSystem'), \
patch('s3fs.S3Map'), \
patch('dask.array.from_zarr', return_value=mock_array):

with self.assertRaises(ValueError) as context:
reader.fetch_image_data(record, dsxy=1, dsz=1)

self.assertIn('crop_min[0]=-1 is negative', str(context.exception))

def test_fetch_image_data_crop_min_greater_than_crop_max(self):
"""Test that crop_min > crop_max raises clear error"""
reader = ImageReader(file_type='zarr')

record = {
'view_id': 'timepoint: 0, setup: 0',
'file_path': 's3://bucket/test.zarr/0',
'interval_key': ((0, 0, 0), (50, 50, 25), (51, 51, 26)),
'offset': 0,
'lb': (0, 0, 0),
'crop_min': [5, 0, 0],
'crop_max': [3, 5, 5] # crop_min[0] > crop_max[0]
}

mock_array = da.ones((1, 1, 10, 10, 10), dtype=np.float32)

with patch('zarr.open') as mock_zarr, \
patch('s3fs.S3FileSystem'), \
patch('s3fs.S3Map'), \
patch('dask.array.from_zarr', return_value=mock_array):

with self.assertRaises(ValueError) as context:
reader.fetch_image_data(record, dsxy=1, dsz=1)

self.assertIn('crop_min[0]=5 > crop_max[0]=3', str(context.exception))


if __name__ == "__main__":
unittest.main()