idna/
punycode.rs

1// Copyright 2013 The rust-url developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! Punycode ([RFC 3492](http://tools.ietf.org/html/rfc3492)) implementation.
10//!
//! Since Punycode fundamentally works on Unicode code points,
//! `encode` takes a slice of `char` and `decode` returns a vector of `char`.
//! `encode_str` and `decode_to_string` provide convenience wrappers
//! that convert from and to Rust’s UTF-8 based `str` and `String` types.
15
16use alloc::{string::String, vec::Vec};
17use core::char;
18use core::u32;
19
20// Bootstring parameters for Punycode
21static BASE: u32 = 36;
22static T_MIN: u32 = 1;
23static T_MAX: u32 = 26;
24static SKEW: u32 = 38;
25static DAMP: u32 = 700;
26static INITIAL_BIAS: u32 = 72;
27static INITIAL_N: u32 = 0x80;
28static DELIMITER: char = '-';
29
30#[inline]
31fn adapt(mut delta: u32, num_points: u32, first_time: bool) -> u32 {
32    delta /= if first_time { DAMP } else { 2 };
33    delta += delta / num_points;
34    let mut k = 0;
35    while delta > ((BASE - T_MIN) * T_MAX) / 2 {
36        delta /= BASE - T_MIN;
37        k += BASE;
38    }
39    k + (((BASE - T_MIN + 1) * delta) / (delta + SKEW))
40}
41
42/// Convert Punycode to an Unicode `String`.
43///
44/// This is a convenience wrapper around `decode`.
45#[inline]
46pub fn decode_to_string(input: &str) -> Option<String> {
47    decode(input).map(|chars| chars.into_iter().collect())
48}
49
50/// Convert Punycode to Unicode.
51///
52/// Return None on malformed input or overflow.
53/// Overflow can only happen on inputs that take more than
54/// 63 encoded bytes, the DNS limit on domain name labels.
55pub fn decode(input: &str) -> Option<Vec<char>> {
56    Some(Decoder::default().decode(input).ok()?.collect())
57}
58
59#[derive(Default)]
60pub(crate) struct Decoder {
61    insertions: Vec<(usize, char)>,
62}
63
64impl Decoder {
65    /// Split the input iterator and return a Vec with insertions of encoded characters
66    pub(crate) fn decode<'a>(&'a mut self, input: &'a str) -> Result<Decode<'a>, ()> {
67        self.insertions.clear();
68        // Handle "basic" (ASCII) code points.
69        // They are encoded as-is before the last delimiter, if any.
70        let (base, input) = match input.rfind(DELIMITER) {
71            None => ("", input),
72            Some(position) => (
73                &input[..position],
74                if position > 0 {
75                    &input[position + 1..]
76                } else {
77                    input
78                },
79            ),
80        };
81
82        if !base.is_ascii() {
83            return Err(());
84        }
85
86        let base_len = base.len();
87        let mut length = base_len as u32;
88        let mut code_point = INITIAL_N;
89        let mut bias = INITIAL_BIAS;
90        let mut i = 0;
91        let mut iter = input.bytes();
92        loop {
93            let previous_i = i;
94            let mut weight = 1;
95            let mut k = BASE;
96            let mut byte = match iter.next() {
97                None => break,
98                Some(byte) => byte,
99            };
100
101            // Decode a generalized variable-length integer into delta,
102            // which gets added to i.
103            loop {
104                let digit = match byte {
105                    byte @ b'0'..=b'9' => byte - b'0' + 26,
106                    byte @ b'A'..=b'Z' => byte - b'A',
107                    byte @ b'a'..=b'z' => byte - b'a',
108                    _ => return Err(()),
109                } as u32;
110                if digit > (u32::MAX - i) / weight {
111                    return Err(()); // Overflow
112                }
113                i += digit * weight;
114                let t = if k <= bias {
115                    T_MIN
116                } else if k >= bias + T_MAX {
117                    T_MAX
118                } else {
119                    k - bias
120                };
121                if digit < t {
122                    break;
123                }
124                if weight > u32::MAX / (BASE - t) {
125                    return Err(()); // Overflow
126                }
127                weight *= BASE - t;
128                k += BASE;
129                byte = match iter.next() {
130                    None => return Err(()), // End of input before the end of this delta
131                    Some(byte) => byte,
132                };
133            }
134
135            bias = adapt(i - previous_i, length + 1, previous_i == 0);
136            if i / (length + 1) > u32::MAX - code_point {
137                return Err(()); // Overflow
138            }
139
140            // i was supposed to wrap around from length+1 to 0,
141            // incrementing code_point each time.
142            code_point += i / (length + 1);
143            i %= length + 1;
144            let c = match char::from_u32(code_point) {
145                Some(c) => c,
146                None => return Err(()),
147            };
148
149            // Move earlier insertions farther out in the string
150            for (idx, _) in &mut self.insertions {
151                if *idx >= i as usize {
152                    *idx += 1;
153                }
154            }
155            self.insertions.push((i as usize, c));
156            length += 1;
157            i += 1;
158        }
159
160        self.insertions.sort_by_key(|(i, _)| *i);
161        Ok(Decode {
162            base: base.chars(),
163            insertions: &self.insertions,
164            inserted: 0,
165            position: 0,
166            len: base_len + self.insertions.len(),
167        })
168    }
169}
170
171pub(crate) struct Decode<'a> {
172    base: core::str::Chars<'a>,
173    pub(crate) insertions: &'a [(usize, char)],
174    inserted: usize,
175    position: usize,
176    len: usize,
177}
178
179impl<'a> Iterator for Decode<'a> {
180    type Item = char;
181
182    fn next(&mut self) -> Option<Self::Item> {
183        loop {
184            match self.insertions.get(self.inserted) {
185                Some((pos, c)) if *pos == self.position => {
186                    self.inserted += 1;
187                    self.position += 1;
188                    return Some(*c);
189                }
190                _ => {}
191            }
192            if let Some(c) = self.base.next() {
193                self.position += 1;
194                return Some(c);
195            } else if self.inserted >= self.insertions.len() {
196                return None;
197            }
198        }
199    }
200
201    fn size_hint(&self) -> (usize, Option<usize>) {
202        let len = self.len - self.position;
203        (len, Some(len))
204    }
205}
206
207impl<'a> ExactSizeIterator for Decode<'a> {
208    fn len(&self) -> usize {
209        self.len - self.position
210    }
211}
212
213/// Convert an Unicode `str` to Punycode.
214///
215/// This is a convenience wrapper around `encode`.
216#[inline]
217pub fn encode_str(input: &str) -> Option<String> {
218    if input.len() > u32::MAX as usize {
219        return None;
220    }
221    let mut buf = String::with_capacity(input.len());
222    encode_into(input.chars(), &mut buf).ok().map(|()| buf)
223}
224
225/// Convert Unicode to Punycode.
226///
227/// Return None on overflow, which can only happen on inputs that would take more than
228/// 63 encoded bytes, the DNS limit on domain name labels.
229pub fn encode(input: &[char]) -> Option<String> {
230    if input.len() > u32::MAX as usize {
231        return None;
232    }
233    let mut buf = String::with_capacity(input.len());
234    encode_into(input.iter().copied(), &mut buf)
235        .ok()
236        .map(|()| buf)
237}
238
239pub(crate) fn encode_into<I>(input: I, output: &mut String) -> Result<(), ()>
240where
241    I: Iterator<Item = char> + Clone,
242{
243    // Handle "basic" (ASCII) code points. They are encoded as-is.
244    let (mut input_length, mut basic_length) = (0u32, 0);
245    for c in input.clone() {
246        input_length = input_length.checked_add(1).ok_or(())?;
247        if c.is_ascii() {
248            output.push(c);
249            basic_length += 1;
250        }
251    }
252
253    if basic_length > 0 {
254        output.push('-')
255    }
256    let mut code_point = INITIAL_N;
257    let mut delta = 0;
258    let mut bias = INITIAL_BIAS;
259    let mut processed = basic_length;
260    while processed < input_length {
261        // All code points < code_point have been handled already.
262        // Find the next larger one.
263        let min_code_point = input
264            .clone()
265            .map(|c| c as u32)
266            .filter(|&c| c >= code_point)
267            .min()
268            .unwrap();
269        if min_code_point - code_point > (u32::MAX - delta) / (processed + 1) {
270            return Err(()); // Overflow
271        }
272        // Increase delta to advance the decoder’s <code_point,i> state to <min_code_point,0>
273        delta += (min_code_point - code_point) * (processed + 1);
274        code_point = min_code_point;
275        for c in input.clone() {
276            let c = c as u32;
277            if c < code_point {
278                delta = delta.checked_add(1).ok_or(())?;
279            }
280            if c == code_point {
281                // Represent delta as a generalized variable-length integer:
282                let mut q = delta;
283                let mut k = BASE;
284                loop {
285                    let t = if k <= bias {
286                        T_MIN
287                    } else if k >= bias + T_MAX {
288                        T_MAX
289                    } else {
290                        k - bias
291                    };
292                    if q < t {
293                        break;
294                    }
295                    let value = t + ((q - t) % (BASE - t));
296                    output.push(value_to_digit(value));
297                    q = (q - t) / (BASE - t);
298                    k += BASE;
299                }
300                output.push(value_to_digit(q));
301                bias = adapt(delta, processed + 1, processed == basic_length);
302                delta = 0;
303                processed += 1;
304            }
305        }
306        delta += 1;
307        code_point += 1;
308    }
309    Ok(())
310}
311
312#[inline]
313fn value_to_digit(value: u32) -> char {
314    match value {
315        0..=25 => (value as u8 + b'a') as char,       // a..z
316        26..=35 => (value as u8 - 26 + b'0') as char, // 0..9
317        _ => panic!(),
318    }
319}
320
321#[test]
322#[ignore = "slow"]
323#[cfg(target_pointer_width = "64")]
324fn huge_encode() {
325    let mut buf = String::new();
326    assert!(encode_into(std::iter::repeat('ß').take(u32::MAX as usize + 1), &mut buf).is_err());
327    assert_eq!(buf.len(), 0);
328}