/// The standard base64 alphabet (RFC 4648 §4), indexed by 6-bit value.
const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encodes `bytes` as standard, padded base64 (RFC 4648 §4).
///
/// Every group of three input bytes becomes four output symbols. A trailing
/// group of one or two bytes is zero-filled before encoding, and the output
/// positions that carry no input bits are written as `=`. The output length
/// is therefore always a multiple of four. An empty input yields an empty
/// string.
pub fn encode(bytes: &[u8]) -> String {
    let mut encoded = String::with_capacity(bytes.len().div_ceil(3) * 4);
    for chunk in bytes.chunks(3) {
        // Zero-fill a short final chunk, then pack it big-endian into the low 24 bits.
        let mut group = [0u8; 3];
        group[..chunk.len()].copy_from_slice(chunk);
        let n = u32::from_be_bytes([0, group[0], group[1], group[2]]);
        // k input bytes carry 8k bits, which need k + 1 six-bit symbols;
        // the remaining positions of the 4-symbol group are padding.
        let significant = chunk.len() + 1;
        for position in 0..4usize {
            if position < significant {
                let sextet = (n >> (18 - 6 * position)) & 0b11_1111;
                encoded.push(char::from(ALPHABET[sextet as usize]));
            } else {
                encoded.push('=');
            }
        }
    }
    encoded
}

/// Decodes standard, padded base64 (RFC 4648 §4) back into bytes.
///
/// Input is processed in groups of four symbols, each producing three bytes.
/// Up to two trailing `=` characters are accepted. Each `=` removes one byte
/// from the end of the output. Bits in the final symbol that do not
/// contribute to an output byte are not validated.
///
/// # Errors
///
/// Returns `Err` with a human-readable message when:
/// - the input length (in bytes) is not a multiple of 4;
/// - a non-`=` character follows a `=` (padding in a non-terminal position);
/// - a character outside the base64 alphabet appears;
/// - there are more than two terminating `=` characters.
pub fn decode(string: &str) -> Result<Vec<u8>, String> {
    if string.len() % 4 != 0 {
        return Err("Input string length is not divisible by 4".into());
    }
    let mut bytes = Vec::with_capacity(string.len() / 4 * 3);
    let mut n = 0u32;
    let mut padding = 0usize;
    for (i, c) in string.chars().enumerate() {
        // Once padding has started, only more padding may follow. `i >= 1` here
        // because at least one '=' has already been consumed, so `i - 1` is the
        // position of the last padding character.
        if padding > 0 && c != '=' {
            return Err(format!(
                "Character at position {} is a non-terminating padding character",
                i - 1
            ));
        }
        let sextet = match c {
            'A'..='Z' => u32::from(c) - u32::from('A'),
            'a'..='z' => u32::from(c) - u32::from('a') + 26,
            '0'..='9' => u32::from(c) - u32::from('0') + 52,
            '+' => 62,
            '/' => 63,
            // Padding contributes zero bits; the surplus bytes are trimmed below.
            '=' => {
                padding += 1;
                0
            }
            _ => {
                return Err(format!(
                    "Character '{c}' at position {i} is not a member of the base64 alphabet"
                ))
            }
        };
        n = (n << 6) | sextet;
        // Every fourth symbol completes 24 bits: emit them as three bytes.
        if i % 4 == 3 {
            bytes.extend_from_slice(&n.to_be_bytes()[1..]);
            n = 0;
        }
    }
    if padding > 2 {
        return Err("Input string contains more than 2 terminating padding characters".to_string());
    }
    // Each '=' stands for one byte that was never present in the original input.
    bytes.truncate(bytes.len().saturating_sub(padding));
    Ok(bytes)
}