diff --git a/book/src/tree_sequence_edge_diffs.md b/book/src/tree_sequence_edge_diffs.md index a694f9800..688b92ef6 100644 --- a/book/src/tree_sequence_edge_diffs.md +++ b/book/src/tree_sequence_edge_diffs.md @@ -1,6 +1,6 @@ ## Iterating over edge differences -As with [trees](tree_sequence_iterate_trees.md), the API provides a *lending* iterator over edge differences. +The API provides an iterator over edge differences. Each step of the iterator advances to the next tree in the tree sequence. For each tree, a standard `Iterator` over removals and insertions is available: diff --git a/src/edge_differences.rs b/src/edge_differences.rs index 17e661f07..38330661a 100644 --- a/src/edge_differences.rs +++ b/src/edge_differences.rs @@ -1,41 +1,8 @@ +use crate::EdgeId; use crate::NodeId; use crate::Position; use crate::TreeSequence; -use crate::sys::bindings; - -#[repr(transparent)] -struct LLEdgeDifferenceIterator(bindings::tsk_diff_iter_t); - -impl Drop for LLEdgeDifferenceIterator { - fn drop(&mut self) { - unsafe { bindings::tsk_diff_iter_free(&mut self.0) }; - } -} - -impl LLEdgeDifferenceIterator { - pub fn new_from_treeseq( - treeseq: &TreeSequence, - flags: bindings::tsk_flags_t, - ) -> Result { - let mut inner = std::mem::MaybeUninit::::uninit(); - let treeseq_ptr = treeseq.as_ptr(); - assert!(!treeseq_ptr.is_null()); - // SAFETY: treeseq_ptr is not null - let tables_ptr = - unsafe { (*treeseq_ptr).tables } as *const bindings::tsk_table_collection_t; - assert!(!tables_ptr.is_null()); - // SAFETY: tables_ptr is not null, - // init of inner will be handled by tsk_diff_iter_init - let num_trees: i32 = treeseq.num_trees().try_into()?; - let code = unsafe { - bindings::tsk_diff_iter_init(inner.as_mut_ptr(), tables_ptr, num_trees, flags) - }; - // SAFETY: tsk_diff_iter_init has initialized our object - handle_tsk_return_value!(code, Self(unsafe { inner.assume_init() })) - } -} - /// Marker type for edge insertion. 
pub struct Insertion {} @@ -49,49 +16,6 @@ mod private { impl EdgeDifferenceIteration for super::Removal {} } -struct LLEdgeList { - inner: bindings::tsk_edge_list_t, - marker: std::marker::PhantomData, -} - -macro_rules! build_lledgelist { - ($name: ident, $generic: ty) => { - type $name = LLEdgeList<$generic>; - - impl Default for $name { - fn default() -> Self { - Self { - inner: bindings::tsk_edge_list_t { - head: std::ptr::null_mut(), - tail: std::ptr::null_mut(), - }, - marker: std::marker::PhantomData::<$generic> {}, - } - } - } - }; -} - -build_lledgelist!(LLEdgeInsertionList, Insertion); -build_lledgelist!(LLEdgeRemovalList, Removal); - -/// Concrete type implementing [`Iterator`] over [`EdgeInsertion`] or [`EdgeRemoval`]. -/// Created by [`EdgeDifferencesIterator::edge_insertions`] or -/// [`EdgeDifferencesIterator::edge_removals`], respectively. -pub struct EdgeDifferences<'a, T: private::EdgeDifferenceIteration> { - inner: &'a LLEdgeList, - current: *mut bindings::tsk_edge_list_node_t, -} - -impl<'a, T: private::EdgeDifferenceIteration> EdgeDifferences<'a, T> { - fn new(inner: &'a LLEdgeList) -> Self { - Self { - inner, - current: std::ptr::null_mut(), - } - } -} - /// An edge difference. Edge insertions and removals are differentiated by /// marker types [`Insertion`] and [`Removal`], respectively. #[derive(Debug, Copy, Clone)] @@ -149,103 +73,170 @@ pub type EdgeInsertion = EdgeDifference; /// Type alias for [`EdgeDifference`] pub type EdgeRemoval = EdgeDifference; -impl Iterator for EdgeDifferences<'_, T> -where - T: private::EdgeDifferenceIteration, -{ - type Item = EdgeDifference; +/// Manages iteration over trees to obtain +/// edge differences. 
+pub struct EdgeDifferencesIterator<'ts> { + edges_left: &'ts [Position], + edges_right: &'ts [Position], + edges_parent: &'ts [NodeId], + edges_child: &'ts [NodeId], + insertion_order: &'ts [EdgeId], + removal_order: &'ts [EdgeId], + left: f64, + sequence_length: f64, + insertion_index: usize, + removal_index: usize, +} - fn next(&mut self) -> Option { - if self.current.is_null() { - self.current = self.inner.inner.head; - } else { - self.current = unsafe { *self.current }.next; - } - if self.current.is_null() { - None - } else { - let left = unsafe { (*self.current).edge.left }; - let right = unsafe { (*self.current).edge.right }; - let parent = unsafe { (*self.current).edge.parent }; - let child = unsafe { (*self.current).edge.child }; - Some(Self::Item::new(left, right, parent, child)) +impl<'ts> EdgeDifferencesIterator<'ts> { + pub(crate) fn new(treeseq: &'ts TreeSequence) -> Self { + Self { + edges_left: treeseq.tables().edges().left_slice(), + edges_right: treeseq.tables().edges().right_slice(), + edges_parent: treeseq.tables().edges().parent_slice(), + edges_child: treeseq.tables().edges().child_slice(), + insertion_order: treeseq.edge_insertion_order(), + removal_order: treeseq.edge_removal_order(), + left: 0., + sequence_length: treeseq.tables().sequence_length().into(), + insertion_index: 0, + removal_index: 0, } } } -/// Manages iteration over trees to obtain -/// edge differences. 
-pub struct EdgeDifferencesIterator { - inner: LLEdgeDifferenceIterator, - insertion: LLEdgeInsertionList, - removal: LLEdgeRemovalList, +#[derive(Clone)] +pub struct CurrentTreeEdgeDifferences<'ts> { + edges_left: &'ts [Position], + edges_right: &'ts [Position], + edges_parent: &'ts [NodeId], + edges_child: &'ts [NodeId], + insertion_order: &'ts [EdgeId], + removal_order: &'ts [EdgeId], + removals: (usize, usize), + insertions: (usize, usize), left: f64, right: f64, - advanced: i32, } -impl EdgeDifferencesIterator { - // NOTE: will return None if tskit-c cannot - // allocate memory for internal structures. - pub(crate) fn new_from_treeseq( - treeseq: &TreeSequence, - flags: bindings::tsk_flags_t, - ) -> Result { - LLEdgeDifferenceIterator::new_from_treeseq(treeseq, flags).map(|inner| Self { - inner, - insertion: LLEdgeInsertionList::default(), - removal: LLEdgeRemovalList::default(), - left: f64::default(), - right: f64::default(), - advanced: 0, - }) - } +#[repr(transparent)] +pub struct EdgeRemovalsIterator<'ts>(CurrentTreeEdgeDifferences<'ts>); - fn advance_tree(&mut self) { - // SAFETY: our tree sequence is guaranteed - // to be valid and own its tables. 
- self.advanced = unsafe { - bindings::tsk_diff_iter_next( - &mut self.inner.0, - &mut self.left, - &mut self.right, - &mut self.removal.inner, - &mut self.insertion.inner, - ) - }; - } +#[repr(transparent)] +pub struct EdgeInsertionsIterator<'ts>(CurrentTreeEdgeDifferences<'ts>); - pub fn left(&self) -> Position { - self.left.into() +impl<'ts> Iterator for EdgeRemovalsIterator<'ts> { + type Item = EdgeDifference; + fn next(&mut self) -> Option { + if self.0.removals.0 < self.0.removals.1 { + let index = self.0.removals.0; + self.0.removals.0 += 1; + Some(Self::Item::new( + self.0.edges_left[self.0.removal_order[index].as_usize()], + self.0.edges_right[self.0.removal_order[index].as_usize()], + self.0.edges_parent[self.0.removal_order[index].as_usize()], + self.0.edges_child[self.0.removal_order[index].as_usize()], + )) + } else { + None + } } +} - pub fn right(&self) -> Position { - self.right.into() +impl<'ts> Iterator for EdgeInsertionsIterator<'ts> { + type Item = EdgeDifference; + fn next(&mut self) -> Option { + if self.0.insertions.0 < self.0.insertions.1 { + let index = self.0.insertions.0; + self.0.insertions.0 += 1; + Some(Self::Item::new( + self.0.edges_left[self.0.insertion_order[index].as_usize()], + self.0.edges_right[self.0.insertion_order[index].as_usize()], + self.0.edges_parent[self.0.insertion_order[index].as_usize()], + self.0.edges_child[self.0.insertion_order[index].as_usize()], + )) + } else { + None + } } +} - pub fn interval(&self) -> (Position, Position) { - (self.left(), self.right()) +impl<'ts> CurrentTreeEdgeDifferences<'ts> { + pub fn removals(&self) -> impl Iterator + '_ { + EdgeRemovalsIterator(self.clone()) } - pub fn edge_removals(&self) -> impl Iterator + '_ { - EdgeDifferences::::new(&self.removal) + pub fn insertions(&self) -> impl Iterator + '_ { + EdgeInsertionsIterator(self.clone()) } - pub fn edge_insertions(&self) -> impl Iterator + '_ { - EdgeDifferences::::new(&self.insertion) + pub fn interval(&self) -> (Position, 
Position) { + (self.left.into(), self.right.into()) } } -impl crate::StreamingIterator for EdgeDifferencesIterator { - type Item = EdgeDifferencesIterator; - - fn advance(&mut self) { - self.advance_tree() +fn update_right( + right: f64, + index: usize, + position_slice: &[Position], + diff_slice: &[EdgeId], +) -> f64 { + if index < diff_slice.len() { + let temp = position_slice[diff_slice[index].as_usize()]; + if temp < right { + temp.into() + } else { + right + } + } else { + right } +} - fn get(&self) -> Option<&Self::Item> { - if self.advanced > 0 { - Some(self) +impl<'ts> Iterator for EdgeDifferencesIterator<'ts> { + type Item = CurrentTreeEdgeDifferences<'ts>; + + fn next(&mut self) -> Option { + if self.insertion_index < self.insertion_order.len() && self.left < self.sequence_length { + let removals_start = self.removal_index; + while self.removal_index < self.removal_order.len() + && self.edges_right[self.removal_order[self.removal_index].as_usize()] == self.left + { + self.removal_index += 1; + } + let insertions_start = self.insertion_index; + while self.insertion_index < self.insertion_order.len() + && self.edges_left[self.insertion_order[self.insertion_index].as_usize()] + == self.left + { + self.insertion_index += 1; + } + let right = update_right( + self.sequence_length, + self.insertion_index, + self.edges_left, + self.insertion_order, + ); + let right = update_right( + right, + self.removal_index, + self.edges_right, + self.removal_order, + ); + let diffs = CurrentTreeEdgeDifferences { + edges_left: self.edges_left, + edges_right: self.edges_right, + edges_parent: self.edges_parent, + edges_child: self.edges_child, + insertion_order: self.insertion_order, + removal_order: self.removal_order, + removals: (removals_start, self.removal_index), + insertions: (insertions_start, self.insertion_index), + left: self.left, + right, + }; + self.left = right; + Some(diffs) } else { None } diff --git a/src/trees/treeseq.rs b/src/trees/treeseq.rs index 
6b1e1a094..3f8002459 100644 --- a/src/trees/treeseq.rs +++ b/src/trees/treeseq.rs @@ -477,16 +477,9 @@ impl TreeSequence { handle_tsk_return_value!(rv, crate::ProvenanceId::from(rv)) } - /// Build a lending iterator over edge differences. - /// - /// # Errors - /// - /// * [`TskitError`] if the `C` back end is unable to allocate - /// needed memory - pub fn edge_differences_iter( - &self, - ) -> Result { - crate::edge_differences::EdgeDifferencesIterator::new_from_treeseq(self, 0) + /// Build an iterator over edge differences. + pub fn edge_differences_iter(&self) -> crate::edge_differences::EdgeDifferencesIterator { + crate::edge_differences::EdgeDifferencesIterator::new(self) } /// Reference to the underlying table collection. diff --git a/tests/book_trees.rs b/tests/book_trees.rs index 2e7ac2a02..bb3efa28f 100644 --- a/tests/book_trees.rs +++ b/tests/book_trees.rs @@ -168,17 +168,13 @@ fn initialize_from_table_collection() { // } // ANCHOR: iterate_edge_differences - if let Ok(mut edge_diff_iterator) = treeseq.edge_differences_iter() { - while let Some(diffs) = edge_diff_iterator.next() { - for edge_removal in diffs.edge_removals() { - println!("{}", edge_removal); - } - for edge_insertion in diffs.edge_insertions() { - println!("{}", edge_insertion); - } + for diffs in treeseq.edge_differences_iter() { + for edge_removal in diffs.removals() { + println!("edge removal: {}", edge_removal); + } + for edge_insertion in diffs.insertions() { + println!("edge insertion: {}", edge_insertion); } - } else { - panic!("creating edge diffs iterator failed"); } // ANCHOR_END: iterate_edge_differences @@ -187,25 +183,18 @@ fn initialize_from_table_collection() { // num_nodes + 1 to reflect a "virtual root" present in // the tree arrays let mut parents = vec![NodeId::NULL; num_nodes + 1]; - match treeseq.edge_differences_iter() { - Ok(mut ediff_iter) => match treeseq.tree_iterator(0) { - Ok(mut tree_iter) => { - while let Some(diffs) = ediff_iter.next() { - let tree = 
tree_iter.next().unwrap(); - for edge_out in diffs.edge_removals() { - let c = edge_out.child(); - parents[c.as_usize()] = NodeId::NULL; - } - for edge_in in diffs.edge_insertions() { - let c = edge_in.child(); - parents[c.as_usize()] = edge_in.parent(); - } - assert_eq!(tree.parent_array(), &parents); - } - } - Err(e) => panic!("error creating tree iter: {:?}", e), - }, - Err(e) => panic!("error creating edge diff iter: {:?}", e), + let mut tree_iter = treeseq.tree_iterator(0).unwrap(); + for diffs in treeseq.edge_differences_iter() { + let tree = tree_iter.next().unwrap(); + for edge_out in diffs.removals() { + let c = edge_out.child(); + parents[c.as_usize()] = NodeId::NULL; + } + for edge_in in diffs.insertions() { + let c = edge_in.child(); + parents[c.as_usize()] = edge_in.parent(); + } + assert_eq!(tree.parent_array(), &parents); } // ANCHOR_END: iterate_edge_differences_update_parents } diff --git a/tests/test_edge_difference_iteration.rs b/tests/test_edge_difference_iteration.rs new file mode 100644 index 000000000..48f9763da --- /dev/null +++ b/tests/test_edge_difference_iteration.rs @@ -0,0 +1,151 @@ +fn make_treeseq() -> tskit::TreeSequence { + let mut tables = tskit::TableCollection::new(1000.).unwrap(); + tables + .add_node(0, 2.0, tskit::PopulationId::NULL, tskit::IndividualId::NULL) + .unwrap(); + tables + .add_node(0, 1.0, tskit::PopulationId::NULL, tskit::IndividualId::NULL) + .unwrap(); + tables + .add_node( + tskit::NodeFlags::new_sample(), + 0.0, + tskit::PopulationId::NULL, + tskit::IndividualId::NULL, + ) + .unwrap(); + tables + .add_node( + tskit::NodeFlags::new_sample(), + 0.0, + tskit::PopulationId::NULL, + tskit::IndividualId::NULL, + ) + .unwrap(); + tables + .add_node( + tskit::NodeFlags::new_sample(), + 0.0, + tskit::PopulationId::NULL, + tskit::IndividualId::NULL, + ) + .unwrap(); + tables + .add_node( + tskit::NodeFlags::new_sample(), + 0.0, + tskit::PopulationId::NULL, + tskit::IndividualId::NULL, + ) + .unwrap(); + 
tables.add_edge(500., 1000., 0, 1).unwrap(); + tables.add_edge(0., 500., 0, 2).unwrap(); + tables.add_edge(0., 1000., 0, 3).unwrap(); + tables.add_edge(500., 1000., 1, 2).unwrap(); + tables.add_edge(0., 1000., 1, 4).unwrap(); + tables.add_edge(0., 1000., 1, 5).unwrap(); + + tables + .full_sort(tskit::TableSortOptions::default()) + .unwrap(); + + tables.build_index().unwrap(); + + tables + .tree_sequence(tskit::TreeSequenceFlags::default()) + .unwrap() +} + +// A fundamental property of iterators is that their Items +// are collectible into objects that are valid to use later. + +#[test] +fn test_collected_edge_insertions() { + let ts = make_treeseq(); + // The ergonomics here seem a bit ugly but it is a corner case? + let insertions = ts + .edge_differences_iter() + .flat_map(|d| d.insertions().collect::>()) + .collect::>(); + assert_eq!(insertions.len(), ts.edge_insertion_order().len()); + for (i, j) in insertions.iter().zip(ts.edge_insertion_order().iter()) { + assert_eq!( + i.parent(), + ts.tables().edges().parent_column()[j.as_usize()] + ); + assert_eq!(i.child(), ts.tables().edges().child_column()[j.as_usize()]); + assert_eq!(i.left(), ts.tables().edges().left_column()[j.as_usize()]); + assert_eq!(i.right(), ts.tables().edges().right_column()[j.as_usize()]); + } + + // Better ergonomics + let mut insertions = vec![]; + for diffs in ts.edge_differences_iter() { + insertions.extend(diffs.insertions()); + } + assert_eq!(insertions.len(), ts.edge_insertion_order().len()); + for (i, j) in insertions.iter().zip(ts.edge_insertion_order().iter()) { + assert_eq!( + i.parent(), + ts.tables().edges().parent_column()[j.as_usize()] + ); + assert_eq!(i.child(), ts.tables().edges().child_column()[j.as_usize()]); + assert_eq!(i.left(), ts.tables().edges().left_column()[j.as_usize()]); + assert_eq!(i.right(), ts.tables().edges().right_column()[j.as_usize()]); + } +} + +#[test] +fn test_collect_edge_diff_iterators() { + let ts = make_treeseq(); + + let diffs = 
ts.edge_differences_iter().collect::>(); + + for (di, dj) in diffs.iter().zip(ts.edge_differences_iter()) { + for (ri, rj) in di.removals().zip(dj.removals()) { + assert_eq!(ri.parent(), rj.parent()); + assert_eq!(ri.child(), rj.child()); + assert_eq!(ri.left(), rj.left()); + assert_eq!(ri.right(), rj.right()); + } + } + + let insertions = diffs + .iter() + .flat_map(|d| d.insertions()) + .collect::>(); + assert_eq!(insertions.len(), ts.edge_insertion_order().len()); + for (i, j) in insertions.iter().zip(ts.edge_insertion_order().iter()) { + assert_eq!( + i.parent(), + ts.tables().edges().parent_column()[j.as_usize()] + ); + assert_eq!(i.child(), ts.tables().edges().child_column()[j.as_usize()]); + assert_eq!(i.left(), ts.tables().edges().left_column()[j.as_usize()]); + assert_eq!(i.right(), ts.tables().edges().right_column()[j.as_usize()]); + } + + let removals = diffs.iter().flat_map(|d| d.removals()).collect::>(); + let removal_order = ts.edge_removal_order(); + // Removals have some nuance: + // The "standard" loop ends when all INSERTIONS have + // been processed, which means that all edges + // leaving the tree at the "sequence length" are never visited. + let num_removals_not_at_end = removal_order + .iter() + .filter(|r| { + ts.tables().edges().right_column()[r.as_usize()] != ts.tables().sequence_length() + }) + .count(); + assert_eq!(removals.len(), num_removals_not_at_end); + + for (i, j) in removals.iter().zip(removal_order.iter()) { + assert_eq!( + i.parent(), + ts.tables().edges().parent_column()[j.as_usize()] + ); + assert_eq!(i.child(), ts.tables().edges().child_column()[j.as_usize()]); + assert_eq!(i.left(), ts.tables().edges().left_column()[j.as_usize()]); + assert_eq!(i.right(), ts.tables().edges().right_column()[j.as_usize()]); + } +}