Skip to content
Draft
40 changes: 39 additions & 1 deletion src/tracksdata/graph/_base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
import polars as pl
from numpy.typing import ArrayLike

from tracksdata.attrs import AttrComparison
from tracksdata.attrs import AttrComparison, NodeAttr
from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.utils._logging import LOG
from tracksdata.utils._multiprocessing import multiprocessing_apply

if TYPE_CHECKING:
from tracksdata.graph._graph_view import GraphView
Expand Down Expand Up @@ -792,3 +793,40 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T:
graph.bulk_add_overlaps(overlaps.tolist())

return graph

def set_overlaps(self, iou_threshold: float = 0.0) -> None:
"""
Find overlapping nodes within each frame and add them their overlap relation into the graph.

Parameters
----------
iou_threshold : float
Nodes with an IoU greater than this threshold are considered overlapping.
If 0, all nodes are considered overlapping.

Examples
--------
```python
graph.set_overlaps(iou_threshold=0.5)
```
"""
if iou_threshold < 0.0 or iou_threshold > 1.0:
raise ValueError("iou_threshold must be between 0.0 and 1.0")

def _estimate_overlaps(t: int) -> list[list[int, 2]]:
node_ids = self.filter_nodes_by_attrs(NodeAttr(DEFAULT_ATTR_KEYS.T) == t)
masks = self.node_attrs(node_ids=node_ids, attr_keys=[DEFAULT_ATTR_KEYS.MASK])[DEFAULT_ATTR_KEYS.MASK]
overlaps = []
for i in range(len(masks)):
mask_i = masks[i]
for j in range(i + 1, len(masks)):
if mask_i.iou(masks[j]) > iou_threshold:
overlaps.append([node_ids[i], node_ids[j]])
return overlaps

for overlaps in multiprocessing_apply(
func=_estimate_overlaps,
sequence=self.time_points(),
desc="Setting overlaps",
):
self.bulk_add_overlaps(overlaps)
3 changes: 2 additions & 1 deletion src/tracksdata/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from tracksdata.nodes._generic_nodes import GenericFuncNodeAttrs
from tracksdata.nodes._mask import Mask
from tracksdata.nodes._node_interpolator import NodeInterpolator
from tracksdata.nodes._random import RandomNodes
from tracksdata.nodes._regionprops import RegionPropsNodes

__all__ = ["GenericFuncNodeAttrs", "Mask", "RandomNodes", "RegionPropsNodes"]
__all__ = ["GenericFuncNodeAttrs", "Mask", "NodeInterpolator", "RandomNodes", "RegionPropsNodes"]
58 changes: 57 additions & 1 deletion src/tracksdata/nodes/_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,61 @@
import blosc2
import numpy as np
import skimage.morphology as morph
from numba import njit
from numpy.typing import ArrayLike, NDArray

from tracksdata.functional._iou import fast_intersection_with_bbox, fast_iou_with_bbox


@njit
def bbox_interpolation_offset(
tgt_bbox: np.ndarray,
src_bbox: np.ndarray,
w: float,
) -> np.ndarray:
"""
Interpolate the bounding box between two masks.
The reference is the target mask and w is relative distance to it:
```python
(target_value - new_value) / (target_value - source_value) = w
```

Parameters
----------
tgt_bbox : np.ndarray
The target bounding box.
src_bbox : np.ndarray
The source bounding box.
w : float
The weight of the interpolation.

Returns
-------
np.ndarray
The offset to add to the bounding box.
"""
if w < 0 or w > 1:
raise ValueError(f"w = {w} is not between 0 and 1")

ndim = tgt_bbox.shape[0] // 2
tgt_center = tgt_bbox[ndim:] - tgt_bbox[:ndim] // 2
src_center = src_bbox[ndim:] - src_bbox[:ndim] // 2
signed_dist = tgt_center - src_center
offset = -np.round((1 - w) * signed_dist).astype(np.int32)

for i in range(ndim):
if offset[i] > 0:
new_value = tgt_bbox[ndim + i] - offset[i]
dist_to_border = min(new_value - tgt_bbox[ndim + i], 0)
offset[i] += dist_to_border
else:
new_value = tgt_bbox[i] + offset[i]
dist_to_border = max(tgt_bbox[i] - new_value, 0)
offset[i] += dist_to_border

return offset


@lru_cache(maxsize=5)
def _spherical_mask(
radius: int,
Expand Down Expand Up @@ -109,14 +159,20 @@ def crop(
slicing = tuple(slice(self._bbox[i], self._bbox[i + ndim]) for i in range(ndim))

else:
center = (self._bbox[: self._mask.ndim] + self._bbox[self._mask.ndim :]) // 2
center = self.bbox_center()
half_shape = np.asarray(shape) // 2
start = np.maximum(center - half_shape, 0)
end = np.minimum(center + half_shape, image.shape)
slicing = tuple(slice(s, e) for s, e in zip(start, end, strict=True))

return image[slicing]

def bbox_center(self) -> NDArray[np.integer]:
"""
Get the center of the bounding box.
"""
return (self._bbox[: self._mask.ndim] + self._bbox[self._mask.ndim :]) // 2

def mask_indices(
self,
offset: NDArray[np.integer] | int = 0,
Expand Down
Loading