aho_corasick/packed/
pattern.rs

1use core::{cmp, fmt, mem, u16, usize};
2
3use alloc::{string::String, vec, vec::Vec};
4
5use crate::packed::api::MatchKind;
6
7/// The type used for representing a pattern identifier.
8///
9/// We don't use `usize` here because our packed searchers don't scale to
10/// huge numbers of patterns, so we keep things a bit smaller.
11pub type PatternID = u16;
12
13/// A non-empty collection of non-empty patterns to search for.
14///
15/// This collection of patterns is what is passed around to both execute
16/// searches and to construct the searchers themselves. Namely, this permits
17/// searches to avoid copying all of the patterns, and allows us to keep only
18/// one copy throughout all packed searchers.
19///
20/// Note that this collection is not a set. The same pattern can appear more
21/// than once.
22#[derive(Clone, Debug)]
23pub struct Patterns {
24    /// The match semantics supported by this collection of patterns.
25    ///
26    /// The match semantics determines the order of the iterator over patterns.
27    /// For leftmost-first, patterns are provided in the same order as were
28    /// provided by the caller. For leftmost-longest, patterns are provided in
29    /// descending order of length, with ties broken by the order in which they
30    /// were provided by the caller.
31    kind: MatchKind,
32    /// The collection of patterns, indexed by their identifier.
33    by_id: Vec<Vec<u8>>,
34    /// The order of patterns defined for iteration, given by pattern
35    /// identifiers. The order of `by_id` and `order` is always the same for
36    /// leftmost-first semantics, but may be different for leftmost-longest
37    /// semantics.
38    order: Vec<PatternID>,
39    /// The length of the smallest pattern, in bytes.
40    minimum_len: usize,
41    /// The largest pattern identifier. This should always be equivalent to
42    /// the number of patterns minus one in this collection.
43    max_pattern_id: PatternID,
44    /// The total number of pattern bytes across the entire collection. This
45    /// is used for reporting total heap usage in constant time.
46    total_pattern_bytes: usize,
47}
48
49impl Patterns {
50    /// Create a new collection of patterns for the given match semantics. The
51    /// ID of each pattern is the index of the pattern at which it occurs in
52    /// the `by_id` slice.
53    ///
54    /// If any of the patterns in the slice given are empty, then this panics.
55    /// Similarly, if the number of patterns given is zero, then this also
56    /// panics.
57    pub fn new() -> Patterns {
58        Patterns {
59            kind: MatchKind::default(),
60            by_id: vec![],
61            order: vec![],
62            minimum_len: usize::MAX,
63            max_pattern_id: 0,
64            total_pattern_bytes: 0,
65        }
66    }
67
68    /// Add a pattern to this collection.
69    ///
70    /// This panics if the pattern given is empty.
71    pub fn add(&mut self, bytes: &[u8]) {
72        assert!(!bytes.is_empty());
73        assert!(self.by_id.len() <= u16::MAX as usize);
74
75        let id = self.by_id.len() as u16;
76        self.max_pattern_id = id;
77        self.order.push(id);
78        self.by_id.push(bytes.to_vec());
79        self.minimum_len = cmp::min(self.minimum_len, bytes.len());
80        self.total_pattern_bytes += bytes.len();
81    }
82
83    /// Set the match kind semantics for this collection of patterns.
84    ///
85    /// If the kind is not set, then the default is leftmost-first.
86    pub fn set_match_kind(&mut self, kind: MatchKind) {
87        self.kind = kind;
88        match self.kind {
89            MatchKind::LeftmostFirst => {
90                self.order.sort();
91            }
92            MatchKind::LeftmostLongest => {
93                let (order, by_id) = (&mut self.order, &mut self.by_id);
94                order.sort_by(|&id1, &id2| {
95                    by_id[id1 as usize]
96                        .len()
97                        .cmp(&by_id[id2 as usize].len())
98                        .reverse()
99                });
100            }
101        }
102    }
103
104    /// Return the number of patterns in this collection.
105    ///
106    /// This is guaranteed to be greater than zero.
107    pub fn len(&self) -> usize {
108        self.by_id.len()
109    }
110
111    /// Returns true if and only if this collection of patterns is empty.
112    pub fn is_empty(&self) -> bool {
113        self.len() == 0
114    }
115
116    /// Returns the approximate total amount of heap used by these patterns, in
117    /// units of bytes.
118    pub fn memory_usage(&self) -> usize {
119        self.order.len() * mem::size_of::<PatternID>()
120            + self.by_id.len() * mem::size_of::<Vec<u8>>()
121            + self.total_pattern_bytes
122    }
123
124    /// Clears all heap memory associated with this collection of patterns and
125    /// resets all state such that it is a valid empty collection.
126    pub fn reset(&mut self) {
127        self.kind = MatchKind::default();
128        self.by_id.clear();
129        self.order.clear();
130        self.minimum_len = usize::MAX;
131        self.max_pattern_id = 0;
132    }
133
134    /// Return the maximum pattern identifier in this collection. This can be
135    /// useful in searchers for ensuring that the collection of patterns they
136    /// are provided at search time and at build time have the same size.
137    pub fn max_pattern_id(&self) -> PatternID {
138        assert_eq!((self.max_pattern_id + 1) as usize, self.len());
139        self.max_pattern_id
140    }
141
142    /// Returns the length, in bytes, of the smallest pattern.
143    ///
144    /// This is guaranteed to be at least one.
145    pub fn minimum_len(&self) -> usize {
146        self.minimum_len
147    }
148
149    /// Returns the match semantics used by these patterns.
150    pub fn match_kind(&self) -> &MatchKind {
151        &self.kind
152    }
153
154    /// Return the pattern with the given identifier. If such a pattern does
155    /// not exist, then this panics.
156    pub fn get(&self, id: PatternID) -> Pattern<'_> {
157        Pattern(&self.by_id[id as usize])
158    }
159
160    /// Return the pattern with the given identifier without performing bounds
161    /// checks.
162    ///
163    /// # Safety
164    ///
165    /// Callers must ensure that a pattern with the given identifier exists
166    /// before using this method.
167    #[cfg(all(feature = "std", target_arch = "x86_64"))]
168    pub unsafe fn get_unchecked(&self, id: PatternID) -> Pattern<'_> {
169        Pattern(self.by_id.get_unchecked(id as usize))
170    }
171
172    /// Return an iterator over all the patterns in this collection, in the
173    /// order in which they should be matched.
174    ///
175    /// Specifically, in a naive multi-pattern matcher, the following is
176    /// guaranteed to satisfy the match semantics of this collection of
177    /// patterns:
178    ///
179    /// ```ignore
180    /// for i in 0..haystack.len():
181    ///   for p in patterns.iter():
182    ///     if haystack[i..].starts_with(p.bytes()):
183    ///       return Match(p.id(), i, i + p.bytes().len())
184    /// ```
185    ///
186    /// Namely, among the patterns in a collection, if they are matched in
187    /// the order provided by this iterator, then the result is guaranteed
188    /// to satisfy the correct match semantics. (Either leftmost-first or
189    /// leftmost-longest.)
190    pub fn iter(&self) -> PatternIter<'_> {
191        PatternIter { patterns: self, i: 0 }
192    }
193}
194
195/// An iterator over the patterns in the `Patterns` collection.
196///
197/// The order of the patterns provided by this iterator is consistent with the
198/// match semantics of the originating collection of patterns.
199///
200/// The lifetime `'p` corresponds to the lifetime of the collection of patterns
201/// this is iterating over.
202#[derive(Debug)]
203pub struct PatternIter<'p> {
204    patterns: &'p Patterns,
205    i: usize,
206}
207
208impl<'p> Iterator for PatternIter<'p> {
209    type Item = (PatternID, Pattern<'p>);
210
211    fn next(&mut self) -> Option<(PatternID, Pattern<'p>)> {
212        if self.i >= self.patterns.len() {
213            return None;
214        }
215        let id = self.patterns.order[self.i];
216        let p = self.patterns.get(id);
217        self.i += 1;
218        Some((id, p))
219    }
220}
221
/// A pattern that is used in packed searching.
///
/// This is a cheap view of a pattern's bytes. It wraps a shared slice, so it
/// is `Copy`.
#[derive(Clone, Copy)]
pub struct Pattern<'a>(&'a [u8]);

impl<'a> fmt::Debug for Pattern<'a> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Patterns are arbitrary bytes. Render them lossily as UTF-8 so that
        // textual patterns stay readable in debug output.
        f.debug_struct("Pattern")
            .field("lit", &String::from_utf8_lossy(self.0))
            .finish()
    }
}

impl<'p> Pattern<'p> {
    /// Returns the length of this pattern, in bytes.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns the bytes of this pattern.
    ///
    /// The returned slice lives as long as the underlying pattern storage
    /// (`'p`), not just as long as this `Pattern` value.
    pub fn bytes(&self) -> &'p [u8] {
        self.0
    }

    /// Returns the low nybble (`b & 0xF`) of each of the first `len` bytes of
    /// this pattern, in order.
    ///
    /// If this pattern is shorter than `len`, then one nybble per pattern byte
    /// is returned (i.e., `self.len()` of them). This does not panic.
    #[cfg(all(feature = "std", target_arch = "x86_64"))]
    pub fn low_nybbles(&self, len: usize) -> Vec<u8> {
        self.bytes().iter().take(len).map(|&b| b & 0xF).collect()
    }

    /// Returns true if this pattern is a prefix of the given bytes.
    #[inline(always)]
    pub fn is_prefix(&self, bytes: &[u8]) -> bool {
        match bytes.get(..self.len()) {
            Some(prefix) => self.equals(prefix),
            None => false,
        }
    }

    /// Returns true if and only if this pattern equals the given bytes.
    #[inline(always)]
    pub fn equals(&self, bytes: &[u8]) -> bool {
        // Why not just use memcmp for this? Well, memcmp requires calling out
        // to libc, and this routine is called in fairly hot code paths. Other
        // than just calling out to libc, it also seems to result in worse
        // codegen. By rolling our own memcmp in pure Rust, it seems to appear
        // more friendly to the optimizer.
        //
        // This results in an improvement in just about every benchmark. Some
        // smaller than others, but in some cases, up to 30% faster.

        let (x, y) = (self.bytes(), bytes);
        if x.len() != y.len() {
            return false;
        }
        // If we don't have enough bytes to do 4-byte at a time loads, then
        // fall back to the naive slow version.
        if x.len() < 4 {
            for (&b1, &b2) in x.iter().zip(y) {
                if b1 != b2 {
                    return false;
                }
            }
            return true;
        }
        // When we have 4 or more bytes to compare, then proceed in chunks of 4
        // at a time using unaligned loads.
        //
        // Also, why do 4 byte loads instead of, say, 8 byte loads? The reason
        // is that this particular version of memcmp is likely to be called
        // with tiny needles. That means that if we do 8 byte loads, then a
        // higher proportion of memcmp calls will use the slower variant above.
        // With that said, this is a hypothesis and is only loosely supported
        // by benchmarks. There's likely some improvement that could be made
        // here. The main thing here though is to optimize for latency, not
        // throughput.

        // SAFETY: Via the conditionals above, we know that `x` and `y` have the
        // same length and that this length is at least 4. So `x.len() - 4`
        // cannot underflow, and `px < pxend` implies that `py < pyend`. Thus,
        // dereferencing both `px` and `py` in the loop below is safe.
        //
        // Moreover, we set `pxend` and `pyend` to be 4 bytes before the actual
        // end of `px` and `py`. Thus, the final dereference outside of the
        // loop is guaranteed to be valid. (The final comparison will overlap
        // with the last comparison done in the loop for lengths that aren't
        // multiples of four.)
        //
        // Finally, we needn't worry about alignment here, since we do
        // unaligned loads.
        unsafe {
            let (mut px, mut py) = (x.as_ptr(), y.as_ptr());
            let (pxend, pyend) = (px.add(x.len() - 4), py.add(y.len() - 4));
            while px < pxend {
                let vx = (px as *const u32).read_unaligned();
                let vy = (py as *const u32).read_unaligned();
                if vx != vy {
                    return false;
                }
                px = px.add(4);
                py = py.add(4);
            }
            let vx = (pxend as *const u32).read_unaligned();
            let vy = (pyend as *const u32).read_unaligned();
            vx == vy
        }
    }
}