iri_string/normalize/
pct_case.rs

1//! Percent-encoding normalization and case normalization.
2
3use core::cmp::Ordering;
4use core::fmt::{self, Write as _};
5use core::marker::PhantomData;
6
7use crate::format::eq_str_display;
8use crate::parser::char::{is_ascii_unreserved, is_unreserved, is_utf8_byte_continue};
9use crate::parser::str::{find_split_hole, take_first_char};
10use crate::parser::trusted::take_xdigits2;
11use crate::spec::Spec;
12
13/// Returns true if the given string is percent-encoding normalized and case
14/// normalized.
15///
16/// Note that normalization of ASCII-only host requires additional case
17/// normalization, so checking by this function is not sufficient for that case.
18pub(crate) fn is_pct_case_normalized<S: Spec>(s: &str) -> bool {
19    eq_str_display(s, &PctCaseNormalized::<S>::new(s))
20}
21
/// Decodes the single character encoded by the given UTF-8 byte sequence.
///
/// Essentially equivalent to
/// `core::str::from_utf8(bytes).ok().and_then(|s| s.chars().next()).ok_or(())`,
/// but this function trusts that every non-first byte is a UTF-8 continue
/// byte and does not re-check them.
///
/// # Errors
///
/// Returns `Err(())` when:
///
/// * the first byte does not have the bit pattern of a leading byte for a
///   sequence of `bytes.len()` bytes (for example `0xF8..=0xFF`, which never
///   appear in UTF-8),
/// * the sequence is a redundant (overlong) encoding, or
/// * the decoded value is not a Unicode scalar value (a surrogate, or a value
///   above `U+10FFFF`).
///
/// # Panics
///
/// Panics if `bytes.len()` is not 2, 3, or 4.
fn into_char_trusted(bytes: &[u8]) -> Result<char, ()> {
    /// The bit mask to get the content part in a continue byte.
    const CONTINUE_BYTE_MASK: u8 = 0b_0011_1111;
    /// Minimum valid values for a code point in a UTF-8 sequence of 2, 3, and 4 bytes.
    const MIN: [u32; 3] = [0x80, 0x800, 0x1_0000];

    let len = bytes.len();
    // For each length: the marker bits of the leading byte, and the value
    // those bits must have. The remaining (low) bits carry the payload.
    let (lead_mask, lead_tag): (u8, u8) = match len {
        2 => (0b_1110_0000, 0b_1100_0000),
        3 => (0b_1111_0000, 0b_1110_0000),
        4 => (0b_1111_1000, 0b_1111_0000),
        len => unreachable!(
            "[consistency] expected 2, 3, or 4 bytes for a character, but got {len} as the length"
        ),
    };
    if bytes[0] & lead_mask != lead_tag {
        // Not a leading byte for a sequence of this length. Without this
        // check, the marker bits would be silently discarded and, for example,
        // `F8 90 80 80` would be decoded as if it were `F0 90 80 80`.
        return Err(());
    }

    // Accumulate the payload bits: leading byte first, then 6 bits per
    // continue byte.
    let c = bytes[1..]
        .iter()
        .fold(u32::from(bytes[0] & !lead_mask), |acc, &b| {
            (acc << 6) | u32::from(b & CONTINUE_BYTE_MASK)
        });
    if c < MIN[len - 2] {
        // Redundant UTF-8 encoding.
        return Err(());
    }
    // Can be an invalid Unicode code point.
    char::from_u32(c).ok_or(())
}
58
59/// Writable as a normalized path segment percent-encoding IRI.
60///
61/// This wrapper does the things below when being formatted:
62///
63/// * Decode unnecessarily percent-encoded characters.
64/// * Convert alphabetic characters uppercase in percent-encoded triplets.
65///
66/// Note that this does not newly encode raw characters.
67///
68/// # Safety
69///
70/// The given string should be the valid path segment.
71#[derive(Debug, Clone, Copy)]
72pub(crate) struct PctCaseNormalized<'a, S> {
73    /// Valid segment name to normalize.
74    segname: &'a str,
75    /// Spec.
76    _spec: PhantomData<fn() -> S>,
77}
78
79impl<'a, S: Spec> PctCaseNormalized<'a, S> {
80    /// Creates a new `PctCaseNormalized` value.
81    #[inline]
82    #[must_use]
83    pub(crate) fn new(source: &'a str) -> Self {
84        Self {
85            segname: source,
86            _spec: PhantomData,
87        }
88    }
89}
90
91impl<S: Spec> fmt::Display for PctCaseNormalized<'_, S> {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        let mut rest = self.segname;
94
95        'outer_loop: while !rest.is_empty() {
96            // Scan the next percent-encoded triplet.
97            let (prefix, after_percent) = match find_split_hole(rest, b'%') {
98                Some(v) => v,
99                None => return f.write_str(rest),
100            };
101            // Write the string before the percent-encoded triplet.
102            f.write_str(prefix)?;
103            // Decode the percent-encoded triplet.
104            let (first_decoded, after_first_triplet) = take_xdigits2(after_percent);
105            rest = after_first_triplet;
106
107            if first_decoded.is_ascii() {
108                if is_ascii_unreserved(first_decoded) {
109                    // Unreserved. Print the decoded.
110                    f.write_char(char::from(first_decoded))?;
111                } else {
112                    write!(f, "%{:02X}", first_decoded)?;
113                }
114                continue 'outer_loop;
115            }
116
117            // Continue byte cannot be the first byte of a character.
118            if is_utf8_byte_continue(first_decoded) {
119                write!(f, "%{:02X}", first_decoded)?;
120                continue 'outer_loop;
121            }
122
123            // Get the expected length of decoded char.
124            let expected_char_len = match (first_decoded & 0xf0).cmp(&0b1110_0000) {
125                Ordering::Less => 2,
126                Ordering::Equal => 3,
127                Ordering::Greater => 4,
128            };
129
130            // Get continue bytes.
131            let c_buf = &mut [first_decoded, 0, 0, 0][..expected_char_len];
132            for (i, buf_dest) in c_buf[1..].iter_mut().enumerate() {
133                match take_first_char(rest) {
134                    Some(('%', after_percent)) => {
135                        let (byte, after_triplet) = take_xdigits2(after_percent);
136                        if !is_utf8_byte_continue(byte) {
137                            // Note that `byte` can start the new string.
138                            // Leave the byte in the `rest` for next try (i.e.
139                            // don't update `rest` in this case).
140                            c_buf[..=i]
141                                .iter()
142                                .try_for_each(|b| write!(f, "%{:02X}", b))?;
143                            continue 'outer_loop;
144                        }
145                        *buf_dest = byte;
146                        rest = after_triplet;
147                    }
148                    // If the next character is not `%`, decoded bytes so far
149                    // won't be valid UTF-8 byte sequence.
150                    // Write the read percent-encoded triplets without decoding.
151                    // Note that all characters in `&c_buf[1..]` (if available)
152                    // will be decoded to "continue byte" of UTF-8, so they
153                    // cannot be the start of a valid UTF-8 byte sequence if
154                    // decoded.
155                    Some((c, after_percent)) => {
156                        c_buf[..=i]
157                            .iter()
158                            .try_for_each(|b| write!(f, "%{:02X}", b))?;
159                        f.write_char(c)?;
160                        rest = after_percent;
161                        continue 'outer_loop;
162                    }
163                    None => {
164                        c_buf[..=i]
165                            .iter()
166                            .try_for_each(|b| write!(f, "%{:02X}", b))?;
167                        // Reached the end of the string.
168                        break 'outer_loop;
169                    }
170                }
171            }
172
173            // Decode the bytes into a character.
174            match into_char_trusted(&c_buf[..expected_char_len]) {
175                Ok(decoded_c) => {
176                    if is_unreserved::<S>(decoded_c) {
177                        // Unreserved. Print the decoded.
178                        f.write_char(decoded_c)?;
179                    } else {
180                        c_buf[0..expected_char_len]
181                            .iter()
182                            .try_for_each(|b| write!(f, "%{:02X}", b))?;
183                    }
184                }
185                Err(_) => {
186                    // Skip decoding of the entire sequence of pct-encoded triplets loaded
187                    // in `c_buf`. This is valid from the reasons below.
188                    //
189                    // * The first byte in `c_buf` is valid as the first byte, and it tells the
190                    //   expected number of bytes for a code unit. The cases the bytes being too
191                    //   short and the sequence being incomplete have already been handled, and
192                    //   the execution does not reach here then.
193                    // * All of the non-first bytes are checked if they are valid as UTF8 continue
194                    //   bytes by `is_utf8_byte_continue()`. If they're not, the decoding of
195                    //   that codepoint is aborted and the bytes in the buffer are immediately
196                    //   emitted as pct-encoded, and the execution does not reach here. This
197                    //   means that the bytes in the current `c_buf` have passed these tests.
198                    // * Since all of the the non-first bytes are UTF8 continue bytes, any of
199                    //   them cannot start the new valid UTF-8 byte sequence. This means that
200                    //   if the bytes in the buffer does not consitute a valid UTF-8 bytes
201                    //   sequence, the whole buffer can immediately be emmitted as pct-encoded.
202
203                    debug_assert!(
204                        c_buf[1..expected_char_len]
205                            .iter()
206                            .copied()
207                            .all(is_utf8_byte_continue),
208                        "[consistency] all non-first bytes have been \
209                         confirmed that they are UTF-8 continue bytes"
210                    );
211                    // Note that the first pct-encoded triplet is stripped from
212                    // `after_first_triplet`.
213                    rest = &after_first_triplet[((expected_char_len - 1) * 3)..];
214                    c_buf[0..expected_char_len]
215                        .iter()
216                        .try_for_each(|b| write!(f, "%{:02X}", b))?;
217                }
218            }
219        }
220
221        Ok(())
222    }
223}
224
/// Formatting wrapper that writes an ASCII-only `host` (optionally followed by
/// a `port`) in normalized form.
#[derive(Debug, Clone, Copy)]
pub(crate) struct NormalizedAsciiOnlyHost<'a> {
    /// Valid host (and additionally port) to normalize.
    host_port: &'a str,
}

impl<'a> NormalizedAsciiOnlyHost<'a> {
    /// Wraps `host_port` so that it is written in normalized form when
    /// formatted.
    ///
    /// # Preconditions
    ///
    /// The given string should be a valid ASCII-only `host` or
    /// `host ":" port` after percent-encoding normalization.
    /// In other words, [`parser::trusted::is_ascii_only_host`] should return
    /// true for the given value.
    ///
    /// [`parser::trusted::is_ascii_only_host`]: `crate::parser::trusted::is_ascii_only_host`
    #[inline]
    #[must_use]
    pub(crate) fn new(host_port: &'a str) -> Self {
        NormalizedAsciiOnlyHost { host_port }
    }
}
249
250impl fmt::Display for NormalizedAsciiOnlyHost<'_> {
251    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
252        let mut rest = self.host_port;
253
254        while !rest.is_empty() {
255            // Scan the next percent-encoded triplet.
256            let (prefix, after_percent) = match find_split_hole(rest, b'%') {
257                Some(v) => v,
258                None => {
259                    return rest
260                        .chars()
261                        .try_for_each(|c| f.write_char(c.to_ascii_lowercase()));
262                }
263            };
264            // Write the string before the percent-encoded triplet.
265            prefix
266                .chars()
267                .try_for_each(|c| f.write_char(c.to_ascii_lowercase()))?;
268            // Decode the percent-encoded triplet.
269            let (first_decoded, after_triplet) = take_xdigits2(after_percent);
270            rest = after_triplet;
271
272            assert!(
273                first_decoded.is_ascii(),
274                "[consistency] this function requires ASCII-only host as an argument"
275            );
276
277            if is_ascii_unreserved(first_decoded) {
278                // Unreserved. Convert to lowercase and print.
279                f.write_char(char::from(first_decoded.to_ascii_lowercase()))?;
280            } else {
281                write!(f, "%{:02X}", first_decoded)?;
282            }
283        }
284
285        Ok(())
286    }
287}
288
#[cfg(test)]
#[cfg(feature = "alloc")]
mod tests {
    use super::*;

    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::string::ToString;

    use crate::spec::{IriSpec, UriSpec};

    /// Asserts that `input` is normalized to `uri` under `UriSpec` and to
    /// `iri` under `IriSpec`.
    #[track_caller]
    fn assert_normalized(input: &str, uri: &str, iri: &str) {
        assert_eq!(PctCaseNormalized::<UriSpec>::new(input).to_string(), uri);
        assert_eq!(PctCaseNormalized::<IriSpec>::new(input).to_string(), iri);
    }

    /// Byte sequences that are not valid UTF-8 stay percent-encoded.
    #[test]
    fn invalid_utf8() {
        assert_normalized("%80%cc%cc%cc", "%80%CC%CC%CC", "%80%CC%CC%CC");
    }

    /// A non-ASCII character is decoded only under the IRI spec.
    #[test]
    fn iri_unreserved() {
        assert_normalized("%ce%b1", "%CE%B1", "\u{03B1}");
    }

    /// A decodable character surrounded by stray bytes is still decoded.
    #[test]
    fn iri_middle_decode() {
        assert_normalized("%ce%ce%b1%b1", "%CE%CE%B1%B1", "%CE\u{03B1}%B1");
    }

    /// Reserved ASCII characters stay encoded, with uppercase hex digits.
    #[test]
    fn ascii_reserved() {
        assert_normalized("%3f", "%3F", "%3F");
    }

    /// ASCII characters that are not allowed raw stay encoded as well.
    #[test]
    fn ascii_forbidden() {
        assert_normalized("%3c%3e", "%3C%3E", "%3C%3E");
    }

    /// Unreserved ASCII characters are decoded under both specs.
    #[test]
    fn ascii_unreserved() {
        assert_normalized("%7ea", "~a", "~a");
    }
}