comrak/
arena_tree.rs

1/*!
2  Included from <https://github.com/SimonSapin/rust-forest/blob/5783c8be8680b84c0438638bdee07d4e4aca40ac/arena-tree/lib.rs>.
3  MIT license (per Cargo.toml).
4
5A DOM-like tree data structure based on `&Node` references.
6
7Any non-trivial tree involves reference cycles
8(e.g. if a node has a first child, the parent of the child is that node).
9To enable this, nodes need to live in an arena allocator
10such as `arena::TypedArena` distributed with rustc (which is `#[unstable]` as of this writing)
11or [`typed_arena::Arena`](https://crates.io/crates/typed-arena).
12
13If you need mutability in the node’s `data`,
14make it a cell (`Cell` or `RefCell`) or use cells inside of it.
15
16*/
17
18use std::cell::Cell;
19use std::fmt;
20
/// A node inside a DOM-like tree.
///
/// Links to neighbouring nodes are `Option<&'a Node>` values held in `Cell`s,
/// so they can be rewired through a shared `&Node` reference. A link that is
/// `None` means the corresponding neighbour does not exist.
pub struct Node<'a, T: 'a> {
    /// The node this one is attached under, if any.
    parent: Cell<Option<&'a Node<'a, T>>>,
    /// The sibling immediately before this node, if any.
    previous_sibling: Cell<Option<&'a Node<'a, T>>>,
    /// The sibling immediately after this node, if any.
    next_sibling: Cell<Option<&'a Node<'a, T>>>,
    /// The first node in this node's list of children, if any.
    first_child: Cell<Option<&'a Node<'a, T>>>,
    /// The last node in this node's list of children, if any.
    last_child: Cell<Option<&'a Node<'a, T>>>,

    /// The data held by the node.
    pub data: T,
}
32
33/// A simple Debug implementation that prints the children as a tree, without
34/// looping through the various interior pointer cycles.
35impl<'a, T: 'a> fmt::Debug for Node<'a, T>
36where
37    T: fmt::Debug,
38{
39    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
40        struct Children<'a, T>(Option<&'a Node<'a, T>>);
41        impl<T: fmt::Debug> fmt::Debug for Children<'_, T> {
42            fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
43                f.debug_list()
44                    .entries(std::iter::successors(self.0, |child| {
45                        child.next_sibling.get()
46                    }))
47                    .finish()
48            }
49        }
50
51        let mut struct_fmt = f.debug_struct("Node");
52        struct_fmt.field("data", &self.data);
53        struct_fmt.field("children", &Children(self.first_child.get()));
54        struct_fmt.finish()?;
55
56        Ok(())
57    }
58}
59
60impl<'a, T> Node<'a, T> {
61    /// Create a new node from its associated data.
62    ///
63    /// Typically, this node needs to be moved into an arena allocator
64    /// before it can be used in a tree.
65    pub fn new(data: T) -> Node<'a, T> {
66        Node {
67            parent: Cell::new(None),
68            first_child: Cell::new(None),
69            last_child: Cell::new(None),
70            previous_sibling: Cell::new(None),
71            next_sibling: Cell::new(None),
72            data,
73        }
74    }
75
76    /// Return a reference to the parent node, unless this node is the root of the tree.
77    pub fn parent(&self) -> Option<&'a Node<'a, T>> {
78        self.parent.get()
79    }
80
81    /// Return a reference to the first child of this node, unless it has no child.
82    pub fn first_child(&self) -> Option<&'a Node<'a, T>> {
83        self.first_child.get()
84    }
85
86    /// Return a reference to the last child of this node, unless it has no child.
87    pub fn last_child(&self) -> Option<&'a Node<'a, T>> {
88        self.last_child.get()
89    }
90
91    /// Return a reference to the previous sibling of this node, unless it is a first child.
92    pub fn previous_sibling(&self) -> Option<&'a Node<'a, T>> {
93        self.previous_sibling.get()
94    }
95
96    /// Return a reference to the next sibling of this node, unless it is a last child.
97    pub fn next_sibling(&self) -> Option<&'a Node<'a, T>> {
98        self.next_sibling.get()
99    }
100
101    /// Returns whether two references point to the same node.
102    pub fn same_node(&self, other: &Node<'a, T>) -> bool {
103        std::ptr::eq(self, other)
104    }
105
106    /// Return an iterator of references to this node and its ancestors.
107    ///
108    /// Call `.next().unwrap()` once on the iterator to skip the node itself.
109    pub fn ancestors(&'a self) -> Ancestors<'a, T> {
110        Ancestors(Some(self))
111    }
112
113    /// Return an iterator of references to this node and the siblings before it.
114    ///
115    /// Call `.next().unwrap()` once on the iterator to skip the node itself.
116    pub fn preceding_siblings(&'a self) -> PrecedingSiblings<'a, T> {
117        PrecedingSiblings(Some(self))
118    }
119
120    /// Return an iterator of references to this node and the siblings after it.
121    ///
122    /// Call `.next().unwrap()` once on the iterator to skip the node itself.
123    pub fn following_siblings(&'a self) -> FollowingSiblings<'a, T> {
124        FollowingSiblings(Some(self))
125    }
126
127    /// Return an iterator of references to this node’s children.
128    pub fn children(&'a self) -> Children<'a, T> {
129        Children(self.first_child.get())
130    }
131
132    /// Return an iterator of references to this node’s children, in reverse order.
133    pub fn reverse_children(&'a self) -> ReverseChildren<'a, T> {
134        ReverseChildren(self.last_child.get())
135    }
136
137    /// Return an iterator of references to this `Node` and its descendants, in tree order.
138    ///
139    /// Parent nodes appear before the descendants.
140    /// Call `.next().unwrap()` once on the iterator to skip the node itself.
141    ///
142    /// *Similar Functions:* Use `traverse()` or `reverse_traverse` if you need
143    /// references to the `NodeEdge` structs associated with each `Node`
144    pub fn descendants(&'a self) -> Descendants<'a, T> {
145        Descendants(self.traverse())
146    }
147
148    /// Return an iterator of references to `NodeEdge` enums for each `Node` and its descendants,
149    /// in tree order.
150    ///
151    /// `NodeEdge` enums represent the `Start` or `End` of each node.
152    ///
153    /// *Similar Functions:* Use `descendants()` if you don't need `Start` and `End`.
154    pub fn traverse(&'a self) -> Traverse<'a, T> {
155        Traverse {
156            root: self,
157            next: Some(NodeEdge::Start(self)),
158        }
159    }
160
161    /// Return an iterator of references to `NodeEdge` enums for each `Node` and its descendants,
162    /// in *reverse* order.
163    ///
164    /// `NodeEdge` enums represent the `Start` or `End` of each node.
165    ///
166    /// *Similar Functions:* Use `descendants()` if you don't need `Start` and `End`.
167    pub fn reverse_traverse(&'a self) -> ReverseTraverse<'a, T> {
168        ReverseTraverse {
169            root: self,
170            next: Some(NodeEdge::End(self)),
171        }
172    }
173
174    /// Detach a node from its parent and siblings. Children are not affected.
175    pub fn detach(&self) {
176        let parent = self.parent.take();
177        let previous_sibling = self.previous_sibling.take();
178        let next_sibling = self.next_sibling.take();
179
180        if let Some(next_sibling) = next_sibling {
181            next_sibling.previous_sibling.set(previous_sibling);
182        } else if let Some(parent) = parent {
183            parent.last_child.set(previous_sibling);
184        }
185
186        if let Some(previous_sibling) = previous_sibling {
187            previous_sibling.next_sibling.set(next_sibling);
188        } else if let Some(parent) = parent {
189            parent.first_child.set(next_sibling);
190        }
191    }
192
193    /// Append a new child to this node, after existing children.
194    pub fn append(&'a self, new_child: &'a Node<'a, T>) {
195        new_child.detach();
196        new_child.parent.set(Some(self));
197        if let Some(last_child) = self.last_child.take() {
198            new_child.previous_sibling.set(Some(last_child));
199            debug_assert!(last_child.next_sibling.get().is_none());
200            last_child.next_sibling.set(Some(new_child));
201        } else {
202            debug_assert!(self.first_child.get().is_none());
203            self.first_child.set(Some(new_child));
204        }
205        self.last_child.set(Some(new_child));
206    }
207
208    /// Prepend a new child to this node, before existing children.
209    pub fn prepend(&'a self, new_child: &'a Node<'a, T>) {
210        new_child.detach();
211        new_child.parent.set(Some(self));
212        if let Some(first_child) = self.first_child.take() {
213            debug_assert!(first_child.previous_sibling.get().is_none());
214            first_child.previous_sibling.set(Some(new_child));
215            new_child.next_sibling.set(Some(first_child));
216        } else {
217            debug_assert!(self.first_child.get().is_none());
218            self.last_child.set(Some(new_child));
219        }
220        self.first_child.set(Some(new_child));
221    }
222
223    /// Insert a new sibling after this node.
224    pub fn insert_after(&'a self, new_sibling: &'a Node<'a, T>) {
225        new_sibling.detach();
226        new_sibling.parent.set(self.parent.get());
227        new_sibling.previous_sibling.set(Some(self));
228        if let Some(next_sibling) = self.next_sibling.take() {
229            debug_assert!(std::ptr::eq(
230                next_sibling.previous_sibling.get().unwrap(),
231                self
232            ));
233            next_sibling.previous_sibling.set(Some(new_sibling));
234            new_sibling.next_sibling.set(Some(next_sibling));
235        } else if let Some(parent) = self.parent.get() {
236            debug_assert!(std::ptr::eq(parent.last_child.get().unwrap(), self));
237            parent.last_child.set(Some(new_sibling));
238        }
239        self.next_sibling.set(Some(new_sibling));
240    }
241
242    /// Insert a new sibling before this node.
243    pub fn insert_before(&'a self, new_sibling: &'a Node<'a, T>) {
244        new_sibling.detach();
245        new_sibling.parent.set(self.parent.get());
246        new_sibling.next_sibling.set(Some(self));
247        if let Some(previous_sibling) = self.previous_sibling.take() {
248            new_sibling.previous_sibling.set(Some(previous_sibling));
249            debug_assert!(std::ptr::eq(
250                previous_sibling.next_sibling.get().unwrap(),
251                self
252            ));
253            previous_sibling.next_sibling.set(Some(new_sibling));
254        } else if let Some(parent) = self.parent.get() {
255            debug_assert!(std::ptr::eq(parent.first_child.get().unwrap(), self));
256            parent.first_child.set(Some(new_sibling));
257        }
258        self.previous_sibling.set(Some(new_sibling));
259    }
260}
261
// Generates a public iterator type `$name` wrapping an optional current node.
// Each call to `next` yields the current node and then moves on to whatever
// that node's `$next` link (a field of `Node`) points at.
macro_rules! axis_iterator {
    (#[$attr:meta] $name:ident : $next:ident) => {
        #[$attr]
        #[derive(Debug)]
        pub struct $name<'a, T: 'a>(Option<&'a Node<'a, T>>);

        impl<'a, T> Iterator for $name<'a, T> {
            type Item = &'a Node<'a, T>;

            fn next(&mut self) -> Option<&'a Node<'a, T>> {
                // `take()` leaves `None` behind, so once the chain runs out
                // every later call keeps returning `None`.
                let current = self.0.take()?;
                self.0 = current.$next.get();
                Some(current)
            }
        }
    };
}
283
284axis_iterator! {
285    #[doc = "An iterator of references to the ancestors a given node."]
286    Ancestors: parent
287}
288
289axis_iterator! {
290    #[doc = "An iterator of references to the siblings before a given node."]
291    PrecedingSiblings: previous_sibling
292}
293
294axis_iterator! {
295    #[doc = "An iterator of references to the siblings after a given node."]
296    FollowingSiblings: next_sibling
297}
298
299axis_iterator! {
300    #[doc = "An iterator of references to the children of a given node."]
301    Children: next_sibling
302}
303
304axis_iterator! {
305    #[doc = "An iterator of references to the children of a given node, in reverse order."]
306    ReverseChildren: previous_sibling
307}
308
309/// An iterator of references to a given node and its descendants, in tree order.
310#[derive(Debug)]
311pub struct Descendants<'a, T: 'a>(Traverse<'a, T>);
312
313impl<'a, T> Iterator for Descendants<'a, T> {
314    type Item = &'a Node<'a, T>;
315
316    fn next(&mut self) -> Option<&'a Node<'a, T>> {
317        loop {
318            match self.0.next() {
319                Some(NodeEdge::Start(node)) => return Some(node),
320                Some(NodeEdge::End(_)) => {}
321                None => return None,
322            }
323        }
324    }
325}
326
/// An edge of the node graph returned by a traversal iterator.
///
/// Every visited node is reported through both variants, whether or not it
/// has children.
#[derive(Debug, Clone)]
pub enum NodeEdge<T> {
    /// Indicates the start of a node.
    /// Yielded by `Traverse::next` before the node’s descendants.
    /// In HTML or XML, this corresponds to an opening tag like `<div>`
    Start(T),

    /// Indicates the end of a node.
    /// Yielded by `Traverse::next` after the node’s descendants.
    /// In HTML or XML, this corresponds to a closing tag like `</div>`
    End(T),
}
340
341macro_rules! traverse_iterator {
342    (#[$attr:meta] $name:ident : $first_child:ident, $next_sibling:ident) => {
343        #[$attr]
344        #[derive(Debug)]
345        pub struct $name<'a, T: 'a> {
346            root: &'a Node<'a, T>,
347            next: Option<NodeEdge<&'a Node<'a, T>>>,
348        }
349
350        impl<'a, T> Iterator for $name<'a, T> {
351            type Item = NodeEdge<&'a Node<'a, T>>;
352
353            fn next(&mut self) -> Option<NodeEdge<&'a Node<'a, T>>> {
354                match self.next.take() {
355                    Some(item) => {
356                        self.next = match item {
357                            NodeEdge::Start(node) => match node.$first_child.get() {
358                                Some(child) => Some(NodeEdge::Start(child)),
359                                None => Some(NodeEdge::End(node)),
360                            },
361                            NodeEdge::End(node) => {
362                                if node.same_node(self.root) {
363                                    None
364                                } else {
365                                    match node.$next_sibling.get() {
366                                        Some(sibling) => Some(NodeEdge::Start(sibling)),
367                                        None => match node.parent.get() {
368                                            Some(parent) => Some(NodeEdge::End(parent)),
369                                            None => panic!("tree modified during iteration"),
370                                        },
371                                    }
372                                }
373                            }
374                        };
375                        Some(item)
376                    }
377                    None => None,
378                }
379            }
380        }
381    };
382}
383
384traverse_iterator! {
385    #[doc = "An iterator of the start and end edges of a given
386    node and its descendants, in tree order."]
387    Traverse: first_child, next_sibling
388}
389
390traverse_iterator! {
391    #[doc = "An iterator of the start and end edges of a given
392    node and its descendants, in reverse tree order."]
393    ReverseTraverse: last_child, previous_sibling
394}
395
/// Builds a small tree in an arena, rearranges it with every mutating method,
/// and checks both the resulting tree order and that no node data is dropped
/// before the arena itself is dropped.
#[test]
fn it_works() {
    // Increments the shared counter when dropped, so the test can observe
    // exactly when node data is destroyed.
    struct DropTracker<'a>(&'a Cell<u32>);
    impl<'a> Drop for DropTracker<'a> {
        fn drop(&mut self) {
            self.0.set(self.0.get() + 1);
        }
    }

    let drop_counter = Cell::new(0);
    {
        let mut new_counter = 0;
        let arena = typed_arena::Arena::new();
        // Allocates a node whose data is (sequence number, drop tracker).
        // The trailing `// N` comments below give each node's sequence number.
        let mut new = || {
            new_counter += 1;
            arena.alloc(Node::new((new_counter, DropTracker(&drop_counter))))
        };

        let a = new(); // 1
        a.append(new()); // 2
        a.append(new()); // 3
        a.prepend(new()); // 4
        // Children of `a` are now [4, 2, 3].
        let b = new(); // 5
        b.append(a);
        a.insert_before(new()); // 6
        a.insert_before(new()); // 7
        a.insert_after(new()); // 8
        a.insert_after(new()); // 9
        // Children of `b` are now [6, 7, 1, 9, 8].
        let c = new(); // 10
        b.append(c);

        assert_eq!(drop_counter.get(), 0);
        // Detaches node 8 (the sibling just before `c`); detaching must not drop it.
        c.previous_sibling.get().unwrap().detach();
        assert_eq!(drop_counter.get(), 0);

        assert_eq!(
            b.descendants().map(|node| node.data.0).collect::<Vec<_>>(),
            [5, 6, 7, 1, 4, 2, 3, 9, 10]
        );
    }

    // All ten nodes, including the detached one, are dropped with the arena.
    assert_eq!(drop_counter.get(), 10);
}