aho_corasick/packed/teddy/
compile.rs

1// See the README in this directory for an explanation of the Teddy algorithm.
2
3use core::{cmp, fmt};
4
5use alloc::{collections::BTreeMap, format, vec, vec::Vec};
6
7use crate::packed::{
8    pattern::{PatternID, Patterns},
9    teddy::Teddy,
10};
11
/// A builder for constructing a Teddy matcher.
///
/// The builder primarily permits fine-grained configuration of the Teddy
/// matcher. Most options are only available for testing/benchmarking
/// purposes. In practice, options are automatically determined by the nature
/// and number of patterns given to the builder.
#[derive(Clone, Debug)]
pub struct Builder {
    /// When `None`, this is automatically determined. Otherwise, `false`
    /// means slim Teddy is used (8 buckets) and `true` means fat Teddy is used
    /// (16 buckets). Fat Teddy requires AVX2, so if that CPU feature isn't
    /// available and fat Teddy was requested, no matcher will be built.
    fat: Option<bool>,
    /// When `None`, this is automatically determined. Otherwise, `false`
    /// means that 128-bit vectors will be used (up to SSSE3 instructions)
    /// whereas `true` means that 256-bit vectors will be used (AVX2). As with
    /// `fat`, if a vector size is requested whose CPU feature isn't
    /// available, then a searcher will not be built.
    avx: Option<bool>,
}
32
33impl Default for Builder {
34    fn default() -> Builder {
35        Builder::new()
36    }
37}
38
39impl Builder {
40    /// Create a new builder for configuring a Teddy matcher.
41    pub fn new() -> Builder {
42        Builder { fat: None, avx: None }
43    }
44
45    /// Build a matcher for the set of patterns given. If a matcher could not
46    /// be built, then `None` is returned.
47    ///
48    /// Generally, a matcher isn't built if the necessary CPU features aren't
49    /// available, an unsupported target or if the searcher is believed to be
50    /// slower than standard techniques (i.e., if there are too many literals).
51    pub fn build(&self, patterns: &Patterns) -> Option<Teddy> {
52        self.build_imp(patterns)
53    }
54
55    /// Require the use of Fat (true) or Slim (false) Teddy. Fat Teddy uses
56    /// 16 buckets where as Slim Teddy uses 8 buckets. More buckets are useful
57    /// for a larger set of literals.
58    ///
59    /// `None` is the default, which results in an automatic selection based
60    /// on the number of literals and available CPU features.
61    pub fn fat(&mut self, yes: Option<bool>) -> &mut Builder {
62        self.fat = yes;
63        self
64    }
65
66    /// Request the use of 256-bit vectors (true) or 128-bit vectors (false).
67    /// Generally, a larger vector size is better since it either permits
68    /// matching more patterns or matching more bytes in the haystack at once.
69    ///
70    /// `None` is the default, which results in an automatic selection based on
71    /// the number of literals and available CPU features.
72    pub fn avx(&mut self, yes: Option<bool>) -> &mut Builder {
73        self.avx = yes;
74        self
75    }
76
77    fn build_imp(&self, patterns: &Patterns) -> Option<Teddy> {
78        use crate::packed::teddy::runtime;
79
80        // Most of the logic here is just about selecting the optimal settings,
81        // or perhaps even rejecting construction altogether. The choices
82        // we have are: fat (avx only) or not, ssse3 or avx2, and how many
83        // patterns we allow ourselves to search. Additionally, for testing
84        // and benchmarking, we permit callers to try to "force" a setting,
85        // and if the setting isn't allowed (e.g., forcing AVX when AVX isn't
86        // available), then we bail and return nothing.
87
88        if patterns.len() > 64 {
89            debug!("skipping Teddy because of too many patterns");
90            return None;
91        }
92        let has_ssse3 = std::is_x86_feature_detected!("ssse3");
93        let has_avx = std::is_x86_feature_detected!("avx2");
94        let avx = if self.avx == Some(true) {
95            if !has_avx {
96                debug!(
97                    "skipping Teddy because avx was demanded but unavailable"
98                );
99                return None;
100            }
101            true
102        } else if self.avx == Some(false) {
103            if !has_ssse3 {
104                debug!(
105                    "skipping Teddy because ssse3 was demanded but unavailable"
106                );
107                return None;
108            }
109            false
110        } else if !has_ssse3 && !has_avx {
111            debug!("skipping Teddy because ssse3 and avx are unavailable");
112            return None;
113        } else {
114            has_avx
115        };
116        let fat = match self.fat {
117            None => avx && patterns.len() > 32,
118            Some(false) => false,
119            Some(true) if !avx => {
120                debug!(
121                    "skipping Teddy because it needs to be fat, but fat \
122                     Teddy requires avx which is unavailable"
123                );
124                return None;
125            }
126            Some(true) => true,
127        };
128
129        let mut compiler = Compiler::new(patterns, fat);
130        compiler.compile();
131        let Compiler { buckets, masks, .. } = compiler;
132        // SAFETY: It is required that the builder only produce Teddy matchers
133        // that are allowed to run on the current CPU, since we later assume
134        // that the presence of (for example) TeddySlim1Mask256 means it is
135        // safe to call functions marked with the `avx2` target feature.
136        match (masks.len(), avx, fat) {
137            (1, false, _) => {
138                debug!("Teddy choice: 128-bit slim, 1 byte");
139                Some(Teddy {
140                    buckets,
141                    max_pattern_id: patterns.max_pattern_id(),
142                    exec: runtime::Exec::TeddySlim1Mask128(
143                        runtime::TeddySlim1Mask128 {
144                            mask1: runtime::Mask128::new(masks[0]),
145                        },
146                    ),
147                })
148            }
149            (1, true, false) => {
150                debug!("Teddy choice: 256-bit slim, 1 byte");
151                Some(Teddy {
152                    buckets,
153                    max_pattern_id: patterns.max_pattern_id(),
154                    exec: runtime::Exec::TeddySlim1Mask256(
155                        runtime::TeddySlim1Mask256 {
156                            mask1: runtime::Mask256::new(masks[0]),
157                        },
158                    ),
159                })
160            }
161            (1, true, true) => {
162                debug!("Teddy choice: 256-bit fat, 1 byte");
163                Some(Teddy {
164                    buckets,
165                    max_pattern_id: patterns.max_pattern_id(),
166                    exec: runtime::Exec::TeddyFat1Mask256(
167                        runtime::TeddyFat1Mask256 {
168                            mask1: runtime::Mask256::new(masks[0]),
169                        },
170                    ),
171                })
172            }
173            (2, false, _) => {
174                debug!("Teddy choice: 128-bit slim, 2 bytes");
175                Some(Teddy {
176                    buckets,
177                    max_pattern_id: patterns.max_pattern_id(),
178                    exec: runtime::Exec::TeddySlim2Mask128(
179                        runtime::TeddySlim2Mask128 {
180                            mask1: runtime::Mask128::new(masks[0]),
181                            mask2: runtime::Mask128::new(masks[1]),
182                        },
183                    ),
184                })
185            }
186            (2, true, false) => {
187                debug!("Teddy choice: 256-bit slim, 2 bytes");
188                Some(Teddy {
189                    buckets,
190                    max_pattern_id: patterns.max_pattern_id(),
191                    exec: runtime::Exec::TeddySlim2Mask256(
192                        runtime::TeddySlim2Mask256 {
193                            mask1: runtime::Mask256::new(masks[0]),
194                            mask2: runtime::Mask256::new(masks[1]),
195                        },
196                    ),
197                })
198            }
199            (2, true, true) => {
200                debug!("Teddy choice: 256-bit fat, 2 bytes");
201                Some(Teddy {
202                    buckets,
203                    max_pattern_id: patterns.max_pattern_id(),
204                    exec: runtime::Exec::TeddyFat2Mask256(
205                        runtime::TeddyFat2Mask256 {
206                            mask1: runtime::Mask256::new(masks[0]),
207                            mask2: runtime::Mask256::new(masks[1]),
208                        },
209                    ),
210                })
211            }
212            (3, false, _) => {
213                debug!("Teddy choice: 128-bit slim, 3 bytes");
214                Some(Teddy {
215                    buckets,
216                    max_pattern_id: patterns.max_pattern_id(),
217                    exec: runtime::Exec::TeddySlim3Mask128(
218                        runtime::TeddySlim3Mask128 {
219                            mask1: runtime::Mask128::new(masks[0]),
220                            mask2: runtime::Mask128::new(masks[1]),
221                            mask3: runtime::Mask128::new(masks[2]),
222                        },
223                    ),
224                })
225            }
226            (3, true, false) => {
227                debug!("Teddy choice: 256-bit slim, 3 bytes");
228                Some(Teddy {
229                    buckets,
230                    max_pattern_id: patterns.max_pattern_id(),
231                    exec: runtime::Exec::TeddySlim3Mask256(
232                        runtime::TeddySlim3Mask256 {
233                            mask1: runtime::Mask256::new(masks[0]),
234                            mask2: runtime::Mask256::new(masks[1]),
235                            mask3: runtime::Mask256::new(masks[2]),
236                        },
237                    ),
238                })
239            }
240            (3, true, true) => {
241                debug!("Teddy choice: 256-bit fat, 3 bytes");
242                Some(Teddy {
243                    buckets,
244                    max_pattern_id: patterns.max_pattern_id(),
245                    exec: runtime::Exec::TeddyFat3Mask256(
246                        runtime::TeddyFat3Mask256 {
247                            mask1: runtime::Mask256::new(masks[0]),
248                            mask2: runtime::Mask256::new(masks[1]),
249                            mask3: runtime::Mask256::new(masks[2]),
250                        },
251                    ),
252                })
253            }
254            (4, false, _) => {
255                debug!("Teddy choice: 128-bit slim, 4 bytes");
256                Some(Teddy {
257                    buckets,
258                    max_pattern_id: patterns.max_pattern_id(),
259                    exec: runtime::Exec::TeddySlim4Mask128(
260                        runtime::TeddySlim4Mask128 {
261                            mask1: runtime::Mask128::new(masks[0]),
262                            mask2: runtime::Mask128::new(masks[1]),
263                            mask3: runtime::Mask128::new(masks[2]),
264                            mask4: runtime::Mask128::new(masks[3]),
265                        },
266                    ),
267                })
268            }
269            (4, true, false) => {
270                debug!("Teddy choice: 256-bit slim, 4 bytes");
271                Some(Teddy {
272                    buckets,
273                    max_pattern_id: patterns.max_pattern_id(),
274                    exec: runtime::Exec::TeddySlim4Mask256(
275                        runtime::TeddySlim4Mask256 {
276                            mask1: runtime::Mask256::new(masks[0]),
277                            mask2: runtime::Mask256::new(masks[1]),
278                            mask3: runtime::Mask256::new(masks[2]),
279                            mask4: runtime::Mask256::new(masks[3]),
280                        },
281                    ),
282                })
283            }
284            (4, true, true) => {
285                debug!("Teddy choice: 256-bit fat, 4 bytes");
286                Some(Teddy {
287                    buckets,
288                    max_pattern_id: patterns.max_pattern_id(),
289                    exec: runtime::Exec::TeddyFat4Mask256(
290                        runtime::TeddyFat4Mask256 {
291                            mask1: runtime::Mask256::new(masks[0]),
292                            mask2: runtime::Mask256::new(masks[1]),
293                            mask3: runtime::Mask256::new(masks[2]),
294                            mask4: runtime::Mask256::new(masks[3]),
295                        },
296                    ),
297                })
298            }
299            _ => unreachable!(),
300        }
301    }
302}
303
304/// A compiler is in charge of allocating patterns into buckets and generating
305/// the masks necessary for searching.
306#[derive(Clone)]
307struct Compiler<'p> {
308    patterns: &'p Patterns,
309    buckets: Vec<Vec<PatternID>>,
310    masks: Vec<Mask>,
311}
312
313impl<'p> Compiler<'p> {
314    /// Create a new Teddy compiler for the given patterns. If `fat` is true,
315    /// then 16 buckets will be used instead of 8.
316    ///
317    /// This panics if any of the patterns given are empty.
318    fn new(patterns: &'p Patterns, fat: bool) -> Compiler<'p> {
319        let mask_len = cmp::min(4, patterns.minimum_len());
320        assert!(1 <= mask_len && mask_len <= 4);
321
322        Compiler {
323            patterns,
324            buckets: vec![vec![]; if fat { 16 } else { 8 }],
325            masks: vec![Mask::default(); mask_len],
326        }
327    }
328
329    /// Compile the patterns in this compiler into buckets and masks.
330    fn compile(&mut self) {
331        let mut lonibble_to_bucket: BTreeMap<Vec<u8>, usize> = BTreeMap::new();
332        for (id, pattern) in self.patterns.iter() {
333            // We try to be slightly clever in how we assign patterns into
334            // buckets. Generally speaking, we want patterns with the same
335            // prefix to be in the same bucket, since it minimizes the amount
336            // of time we spend churning through buckets in the verification
337            // step.
338            //
339            // So we could assign patterns with the same N-prefix (where N
340            // is the size of the mask, which is one of {1, 2, 3}) to the
341            // same bucket. However, case insensitive searches are fairly
342            // common, so we'd for example, ideally want to treat `abc` and
343            // `ABC` as if they shared the same prefix. ASCII has the nice
344            // property that the lower 4 bits of A and a are the same, so we
345            // therefore group patterns with the same low-nybbe-N-prefix into
346            // the same bucket.
347            //
348            // MOREOVER, this is actually necessary for correctness! In
349            // particular, by grouping patterns with the same prefix into the
350            // same bucket, we ensure that we preserve correct leftmost-first
351            // and leftmost-longest match semantics. In addition to the fact
352            // that `patterns.iter()` iterates in the correct order, this
353            // guarantees that all possible ambiguous matches will occur in
354            // the same bucket. The verification routine could be adjusted to
355            // support correct leftmost match semantics regardless of bucket
356            // allocation, but that results in a performance hit. It's much
357            // nicer to be able to just stop as soon as a match is found.
358            let lonybs = pattern.low_nybbles(self.masks.len());
359            if let Some(&bucket) = lonibble_to_bucket.get(&lonybs) {
360                self.buckets[bucket].push(id);
361            } else {
362                // N.B. We assign buckets in reverse because it shouldn't have
363                // any influence on performance, but it does make it harder to
364                // get leftmost match semantics accidentally correct.
365                let bucket = (self.buckets.len() - 1)
366                    - (id as usize % self.buckets.len());
367                self.buckets[bucket].push(id);
368                lonibble_to_bucket.insert(lonybs, bucket);
369            }
370        }
371        for (bucket_index, bucket) in self.buckets.iter().enumerate() {
372            for &pat_id in bucket {
373                let pat = self.patterns.get(pat_id);
374                for (i, mask) in self.masks.iter_mut().enumerate() {
375                    if self.buckets.len() == 8 {
376                        mask.add_slim(bucket_index as u8, pat.bytes()[i]);
377                    } else {
378                        mask.add_fat(bucket_index as u8, pat.bytes()[i]);
379                    }
380                }
381            }
382        }
383    }
384}
385
386impl<'p> fmt::Debug for Compiler<'p> {
387    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
388        let mut buckets = vec![vec![]; self.buckets.len()];
389        for (i, bucket) in self.buckets.iter().enumerate() {
390            for &patid in bucket {
391                buckets[i].push(self.patterns.get(patid));
392            }
393        }
394        f.debug_struct("Compiler")
395            .field("buckets", &buckets)
396            .field("masks", &self.masks)
397            .finish()
398    }
399}
400
/// Mask represents the low and high nybble masks that will be used during
/// search. Each mask is 32 bytes wide, although only the first 16 bytes are
/// used for the SSSE3 runtime.
///
/// Each byte in the mask is an 8-bit bitset indexed by nybble value: within
/// each 16-byte half, the byte at index `n` (0-15, inclusive) has bit `i` set
/// if and only if nybble `n` belongs to some pattern in the ith bucket. For
/// slim Teddy the upper half mirrors the lower half. For fat Teddy, the lower
/// half holds buckets 0-7 and the upper half holds buckets 8-15, so bit `i`
/// there means bucket `8 + i`.
///
/// Each mask is used as the target of a shuffle, where the indices for the
/// shuffle are taken from the haystack. AND'ing the shuffles for both the
/// low and high masks together also results in 8-bit bitsets, but where bit
/// `i` is set if and only if the corresponding *byte* is in the ith bucket.
///
/// During compilation, masks are just arrays. But during search, these masks
/// are represented as 128-bit or 256-bit vectors.
///
/// (See the README in this directory for more details.)
#[derive(Clone, Copy, Default)]
pub struct Mask {
    lo: [u8; 32],
    hi: [u8; 32],
}
423
424impl Mask {
425    /// Update this mask by adding the given byte to the given bucket. The
426    /// given bucket must be in the range 0-7.
427    ///
428    /// This is for "slim" Teddy, where there are only 8 buckets.
429    fn add_slim(&mut self, bucket: u8, byte: u8) {
430        assert!(bucket < 8);
431
432        let byte_lo = (byte & 0xF) as usize;
433        let byte_hi = ((byte >> 4) & 0xF) as usize;
434        // When using 256-bit vectors, we need to set this bucket assignment in
435        // the low and high 128-bit portions of the mask. This allows us to
436        // process 32 bytes at a time. Namely, AVX2 shuffles operate on each
437        // of the 128-bit lanes, rather than the full 256-bit vector at once.
438        self.lo[byte_lo] |= 1 << bucket;
439        self.lo[byte_lo + 16] |= 1 << bucket;
440        self.hi[byte_hi] |= 1 << bucket;
441        self.hi[byte_hi + 16] |= 1 << bucket;
442    }
443
444    /// Update this mask by adding the given byte to the given bucket. The
445    /// given bucket must be in the range 0-15.
446    ///
447    /// This is for "fat" Teddy, where there are 16 buckets.
448    fn add_fat(&mut self, bucket: u8, byte: u8) {
449        assert!(bucket < 16);
450
451        let byte_lo = (byte & 0xF) as usize;
452        let byte_hi = ((byte >> 4) & 0xF) as usize;
453        // Unlike slim teddy, fat teddy only works with AVX2. For fat teddy,
454        // the high 128 bits of our mask correspond to buckets 8-15, while the
455        // low 128 bits correspond to buckets 0-7.
456        if bucket < 8 {
457            self.lo[byte_lo] |= 1 << bucket;
458            self.hi[byte_hi] |= 1 << bucket;
459        } else {
460            self.lo[byte_lo + 16] |= 1 << (bucket % 8);
461            self.hi[byte_hi + 16] |= 1 << (bucket % 8);
462        }
463    }
464
465    /// Return the low 128 bits of the low-nybble mask.
466    pub fn lo128(&self) -> [u8; 16] {
467        let mut tmp = [0; 16];
468        tmp.copy_from_slice(&self.lo[..16]);
469        tmp
470    }
471
472    /// Return the full low-nybble mask.
473    pub fn lo256(&self) -> [u8; 32] {
474        self.lo
475    }
476
477    /// Return the low 128 bits of the high-nybble mask.
478    pub fn hi128(&self) -> [u8; 16] {
479        let mut tmp = [0; 16];
480        tmp.copy_from_slice(&self.hi[..16]);
481        tmp
482    }
483
484    /// Return the full high-nybble mask.
485    pub fn hi256(&self) -> [u8; 32] {
486        self.hi
487    }
488}
489
490impl fmt::Debug for Mask {
491    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
492        let (mut parts_lo, mut parts_hi) = (vec![], vec![]);
493        for i in 0..32 {
494            parts_lo.push(format!("{:02}: {:08b}", i, self.lo[i]));
495            parts_hi.push(format!("{:02}: {:08b}", i, self.hi[i]));
496        }
497        f.debug_struct("Mask")
498            .field("lo", &parts_lo)
499            .field("hi", &parts_hi)
500            .finish()
501    }
502}