Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 149 additions & 138 deletions Rhapso/split_dataset/split_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@
import math

class SplitImages:
def __init__(self, target_image_size, target_overlap, min_step_size, data_gloabl, n5_path, point_density, min_points, max_points,
def __init__(self, target_image_size, target_overlap, min_step_size, data_global, n5_path, point_density, min_points, max_points,
error, excludeRadius):
self.target_image_size = target_image_size
self.target_overlap = target_overlap
self.min_step_size = min_step_size
self.data_global = data_gloabl
self.image_loader_df = data_gloabl['image_loader']
self.view_setups_df = data_gloabl['view_setups']
self.view_registrations_df = data_gloabl['view_registrations']
self.view_interest_points_df = data_gloabl['view_interest_points']
self.data_global = data_global
self.image_loader_df = data_global['image_loader']
self.view_setups_df = data_global['view_setups']
self.view_registrations_df = data_global['view_registrations']
self.view_interest_points_df = data_global['view_interest_points']
self.n5_path = n5_path
self.point_density = point_density
self.min_points = min_points
Expand Down Expand Up @@ -96,21 +96,24 @@ def split_dims(self, input, i, final_size, overlap):
to_val = 0
from_val = input_min[i]

while to_val < input[i]:
while to_val < input[i]-1:
to_val = min(input[i], from_val + final_size - 1)
dim_intervals.append((from_val, to_val))
from_val = to_val - overlap + 1

return dim_intervals

def last_image_size(self, l, s, o):
num = l - 2 * (s - o) - o
den = s - o
rem = num % den if num >= 0 else -((-num) % den)
size = o + rem
if size < 0:
size = l + size
return size

def last_image_size(self, L, S, O):
    """Return the extent of the final tile when a length-``L`` axis is
    cut into tiles of size ``S`` overlapping by ``O`` pixels.

    Tiles are anchored every ``S - O`` pixels; the last anchor is the
    largest such multiple not exceeding ``L - S``, and the returned size
    is whatever remains from that anchor to the end of the axis (exactly
    ``S`` when the tiling fits perfectly).

    Raises:
        ValueError: unless ``0 <= O < S`` and ``L > 0``.
    """
    step = S - O
    if not (0 <= O < S):
        raise ValueError("Require 0 <= O < S")
    if L <= 0:
        raise ValueError("Require L > 0")

    # Anchor of the last tile: largest multiple of `step` <= max(L - S, 0).
    overhang = max(L - S, 0)
    anchor = overhang - overhang % step
    return L - anchor


def distribute_intervals_fixed_overlap(self, input):
input = list(map(int, input.split()))
Expand All @@ -127,8 +130,8 @@ def distribute_intervals_fixed_overlap(self, input):
length = input[i]

if length <= self.target_image_size[i]:
pass
dim_intervals.append((0, length - 1))

else:
l = length
s = self.target_image_size[i]
Expand Down Expand Up @@ -340,137 +343,141 @@ def split_images(self, timepoints, interest_points, fake_label):
new_registrations[(new_view_id_key)] = new_view_registration

new_v_ip_l = []

old_v_ip_l = {
'points': interest_points[old_view_id],
'setup': old_id,
'timepoint': t,
}

id = 0
new_ip1 = []
old_ip_l1 = old_v_ip_l['points']
old_ip_1 = deepcopy(old_ip_l1['points'])

for ip in old_ip_1:
if self.contains(ip, interval):
l = deepcopy(ip)
for j in range(len(interval[0])):
l[j] -= interval[0][j]

new_ip1.append((id, l))
id += 1

new_ip_l1 = {
'base_directory': old_ip_l1['base_path'],
'corresponding_interest_points': None,
'interest_points': new_ip1,
'modified_corresponding_interest_points': None,
'modified_interest_points': None,
'n5_path': f"interestpoints.n5/tpId_{t}_viewSetupId_{new_view_id['setup']}/beads_split",
'xml_n5_path': f"tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}",
"parameters": old_ip_l1['parameters_split']
}

new_v_ip_l.append({
'label': "beads_split",
'ip_list': new_ip_l1
})

new_ip = []
id = 0

for j in range(i):
other_interval = intervals[j]
intersection = self.intersect(interval, other_interval)
if old_view_id in interest_points:
old_v_ip_l = {
'points': interest_points[old_view_id],
'setup': old_id,
'timepoint': t,
}

id = 0
new_ip1 = []
old_ip_l1 = old_v_ip_l['points']
old_ip_1 = deepcopy(old_ip_l1['points'])

if not self.is_empty(intersection):
other_setup = interval_to_view_setup[(tuple(other_interval[0]), tuple(other_interval[1]))]
other_view_id = f"timepoint: {t}, setup: {other_setup['id']}"
other_ip_list = new_interest_points[other_view_id]

n = len(interval[0])
num_pixels = 1

for k in range(n):
num_pixels *= (intersection[1][k] - intersection[0][k] + 1)

num_points = min(self.max_points, max(self.min_points, math.ceil(self.point_density * num_pixels / (100.0*100.0*100.0))))
other_points = (next((x for x in other_ip_list if x.get("label") == fake_label), {"ip_list": {}})["ip_list"].get("interest_points") or [])
other_id = len(other_points)

tree2 = None
search2 = None
for ip in old_ip_1:
if self.contains(ip, interval):
l = deepcopy(ip)
for j in range(len(interval[0])):
l[j] -= interval[0][j]

new_ip1.append((id, l))
id += 1

new_ip_l1 = {
'base_directory': old_ip_l1['base_path'],
'corresponding_interest_points': None,
'interest_points': new_ip1,
'modified_corresponding_interest_points': None,
'modified_interest_points': None,
'n5_path': f"interestpoints.n5/tpId_{t}_viewSetupId_{new_view_id['setup']}/beads_split",
'xml_n5_path': f"tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}",
"parameters": old_ip_l1['parameters_split']
}

new_v_ip_l.append({
'label': "beads_split",
'ip_list': new_ip_l1
})

if self.max_points > 0:
new_ip = []
id = 0

if self.exclude_radius > 0:
other_ip_global = []
for j in range(i):
other_interval = intervals[j]
intersection = self.intersect(interval, other_interval)

for k, ip in enumerate(other_points):
l = deepcopy(ip[1])
if not self.is_empty(intersection):
other_setup = interval_to_view_setup[(tuple(other_interval[0]), tuple(other_interval[1]))]
other_view_id = f"timepoint: {t}, setup: {other_setup['id']}"
other_ip_list = new_interest_points[other_view_id]

for m in range(n):
l[m] = l[m] + other_interval[0][m]

other_ip_global.append((k, l))
n = len(interval[0])
num_pixels = 1

if len(other_ip_global) > 0:
coords = np.vstack([l for _, l in other_ip_global])
tree2 = cKDTree(coords)
for k in range(n):
num_pixels *= (intersection[1][k] - intersection[0][k] + 1)

num_points = min(self.max_points, max(self.min_points, math.ceil(self.point_density * num_pixels / (100.0*100.0*100.0))))
other_points = (next((x for x in other_ip_list if x.get("label") == fake_label), {"ip_list": {}})["ip_list"].get("interest_points") or [])
other_id = len(other_points)

def search2(q_point_global, radius=self.exclude_radius):
idxs = tree2.query_ball_point(np.asarray(q_point_global, float), radius)
return [other_ip_global[k] for k in idxs]
else:
tree2 = None
search2 = None

else:
tree2 = None
search2 = None

tmp = [0.0] * n

for k in range(num_points):
p = [0.0] * n
op = [0.0] * n

for d in range(n):
l = rnd.random() * (intersection[1][d] - intersection[0][d] + 1) + intersection[0][d]
p[d] = (l + (rnd.random() - 0.5) * self.error) - interval[0][d]
op[d] = (l + (rnd.random() - 0.5) * self.error) - other_interval[0][d]
tmp[d] = l

num_neighbors = 0
if self.exclude_radius > 0:
tmp_ip = (0, np.asarray(tmp, dtype=float))
if self.exclude_radius > 0:
other_ip_global = []

for k, ip in enumerate(other_points):
l = deepcopy(ip[1])

for m in range(n):
l[m] = l[m] + other_interval[0][m]

other_ip_global.append((k, l))

if len(other_ip_global) > 0:
coords = np.vstack([l for _, l in other_ip_global])
tree2 = cKDTree(coords)

def search2(q_point_global, radius=self.exclude_radius):
idxs = tree2.query_ball_point(np.asarray(q_point_global, float), radius)
return [other_ip_global[k] for k in idxs]
else:
tree2 = None
search2 = None

if search2 is not None:
neighbors = search2(tmp_ip[1], self.exclude_radius)
num_neighbors += len(neighbors)

if num_neighbors == 0:
new_ip.append((id, p))
other_points.append((other_id, op))
id += 1
other_id += 1
else:
tree2 = None
search2 = None

tmp = [0.0] * n

for k in range(num_points):
p = [0.0] * n
op = [0.0] * n

for d in range(n):
l = rnd.random() * (intersection[1][d] - intersection[0][d] + 1) + intersection[0][d]
p[d] = (l + (rnd.random() - 0.5) * self.error) - interval[0][d]
op[d] = (l + (rnd.random() - 0.5) * self.error) - other_interval[0][d]
tmp[d] = l

num_neighbors = 0
if self.exclude_radius > 0:
tmp_ip = (0, np.asarray(tmp, dtype=float))

if search2 is not None:
neighbors = search2(tmp_ip[1], self.exclude_radius)
num_neighbors += len(neighbors)

if num_neighbors == 0:
new_ip.append((id, p))
other_points.append((other_id, op))
id += 1
other_id += 1

next(x for x in other_ip_list if x.get("label") == fake_label)["ip_list"]["interest_points"] = other_points

next(x for x in other_ip_list if x.get("label") == fake_label)["ip_list"]["interest_points"] = other_points

new_ip_l = {
'base_directory': old_ip_l1['base_path'],
'corresponding_interest_points': None,
'interest_points': new_ip,
'modified_corresponding_interest_points': None,
'modified_interest_points': None,
'n5_path': f"interestpoints.n5/tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}",
'xml_n5_path': f"tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}",
"parameters": old_ip_l1['parameters_fake']
}

new_v_ip_l.append({
'label': fake_label,
'ip_list': new_ip_l
})
new_ip_l = {
'base_directory': old_ip_l1['base_path'],
'corresponding_interest_points': None,
'interest_points': new_ip,
'modified_corresponding_interest_points': None,
'modified_interest_points': None,
'n5_path': f"interestpoints.n5/tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}",
'xml_n5_path': f"tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}",
"parameters": old_ip_l1['parameters_fake']
}

new_v_ip_l.append({
'label': fake_label,
'ip_list': new_ip_l
})

if len(new_v_ip_l) > 0:
new_interest_points[new_view_id_key] = new_v_ip_l

self.setup_definition.append({
'interval': interval,
Expand All @@ -484,7 +491,6 @@ def search2(q_point_global, radius=self.exclude_radius):
'old_models': transform_list
})

new_interest_points[new_view_id_key] = new_v_ip_l
new_id += 1

return new_interest_points
Expand All @@ -493,6 +499,10 @@ def load_interest_points(self, fake_label):
full_path = self.n5_path + "interestpoints.n5"
interest_points = {}

# Skip loading interest points if dataframe is empty
if self.view_interest_points_df.empty:
return {}

if full_path.startswith("s3://"):
path = full_path.rstrip("/")
s3 = s3fs.S3FileSystem(anon=False)
Expand All @@ -503,6 +513,7 @@ def load_interest_points(self, fake_label):
store = zarr.N5Store(full_path)
root = zarr.open(store, mode="r")


for _, row in self.view_interest_points_df.iterrows():
view_id = f"timepoint: {row['timepoint']}, setup: {row['setup']}"
interestpoints_prefix = f"{row['path']}/interestpoints/loc/"
Expand Down