summaryrefslogtreecommitdiff
path: root/src/lib.rs
blob: 2992069612798465d7d1aea4531be5fcaff440df (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#![feature(iter_array_chunks)]

pub fn encode(bytes: &[u8]) -> String {
    // Indexed by a 6-bit value: 0..=25 -> 'A'..'Z', 26..=51 -> 'a'..'z',
    // 52..=61 -> '0'..'9', 62 -> '+', 63 -> '/'.
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    // Every (possibly partial) 3-byte group produces exactly 4 characters.
    let mut string = String::with_capacity(bytes.len().div_ceil(3) * 4);

    // `slice::chunks` is stable and yields a shorter final chunk, so neither a
    // zero-padded iterator chain nor the nightly-only `array_chunks` is needed.
    for chunk in bytes.chunks(3) {
        // Pack the group big-endian into the low 24 bits, zero-filling missing bytes.
        let byte_at = |k: usize| u32::from(chunk.get(k).copied().unwrap_or(0));
        let group = (byte_at(0) << 16) | (byte_at(1) << 8) | byte_at(2);

        // A group of `len` input bytes carries real bits only in its first
        // `len + 1` sextets; the remaining output positions become '='.
        let significant = chunk.len() + 1;
        for k in 0..4 {
            if k < significant {
                let sextet = (group >> (18 - 6 * k)) & 0x3f;
                string.push(char::from(ALPHABET[sextet as usize]));
            } else {
                string.push('=');
            }
        }
    }
    string
}

/// Decodes a standard base64 string (alphabet `A`–`Z`, `a`–`z`, `0`–`9`, `+`,
/// `/`) that is padded with `=` to a multiple of four characters.
///
/// Each group of four characters yields three bytes, and one trailing byte is
/// dropped per `=` in the input. Unused low bits in a padded final group are
/// not checked, so non-canonical input such as `"AB=="` is accepted.
///
/// # Errors
///
/// Returns `Err` with a human-readable message when:
/// - the input's length in bytes is not a multiple of 4;
/// - any character other than `=` follows a `=` (the message reports the
///   position of the `=` immediately before the offending character);
/// - a character that is neither in the alphabet nor `=` is encountered;
/// - the input ends with more than two `=`.
///
/// Positions in messages are 0-based character indices.
pub fn decode(string: &str) -> Result<Vec<u8>, String> {
    if string.len() % 4 != 0 {
        return Err("Input string length is not divisible by 4".into());
    }

    let mut out = Vec::new();
    // Sextets of the current 4-character group, accumulated most significant first.
    let mut group = 0u32;
    // Number of '=' seen so far; once non-zero, only further '=' are allowed.
    let mut pad_count = 0usize;

    for (pos, ch) in string.chars().enumerate() {
        if pad_count > 0 && ch != '=' {
            // `pos - 1` cannot underflow: a '=' was already seen, so `pos >= 1`.
            return Err(format!(
                "Character at position {} is a non-terminating padding character",
                pos - 1
            ));
        }

        let sextet = match ch {
            'A'..='Z' => u32::from(ch) - u32::from('A'),
            'a'..='z' => u32::from(ch) - u32::from('a') + 26,
            '0'..='9' => u32::from(ch) - u32::from('0') + 52,
            '+' => 62,
            '/' => 63,
            '=' => {
                // Padding contributes zero bits; the bytes it yields are trimmed below.
                pad_count += 1;
                0
            }
            _ => {
                return Err(format!(
                    "Character '{ch}' at position {pos} is not a member of the base64 alphabet"
                ))
            }
        };

        group = (group << 6) | sextet;

        // The fourth character completes a 24-bit group: emit its three bytes.
        if pos % 4 == 3 {
            let [_, hi, mid, lo] = group.to_be_bytes();
            out.extend_from_slice(&[hi, mid, lo]);
            group = 0;
        }
    }

    if pad_count > 2 {
        return Err(String::from(
            "Input string contains more than 2 terminating padding characters",
        ));
    }

    // Each '=' stands for one output byte that was not part of the original data.
    out.truncate(out.len().saturating_sub(pad_count));
    Ok(out)
}