diff options
-rw-r--r-- | .gitignore | 2 | ||||
-rw-r--r-- | Cargo.toml | 10 | ||||
-rw-r--r-- | src/lib.rs | 83 |
3 files changed, 95 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4fffb2f --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..c4179bd --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "base64" +version = "1.0.0" +authors = ["Ben Bridle"] +edition = "2021" +description = "Encode and decode base64 strings" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..2992069 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,83 @@ +#![feature(iter_array_chunks)] + +pub fn encode(bytes: &[u8]) -> String { + let remainder = bytes.len() % 3; + let bytes = match remainder { + 0 => bytes.iter().chain(&[0u8;0]), + 1 => bytes.iter().chain(&[0u8;2]), + 2 => bytes.iter().chain(&[0u8;1]), + _ => unreachable!(), + }; + let mut string = String::new(); + let mut triples = bytes.array_chunks::<3>(); + let mut n = 0u32; + let mut hextets = [0u32;4]; + while let Some(triple) = triples.next() { + for i in 0..3 { + n <<= 8; + n += *triple[i] as u32; + } + for i in (0..4).rev() { + hextets[i] = n & 0b111111; + n >>= 6; + } + for i in 0..4 { + let symbol = hextets[i]; + let code = match symbol { + 0..=25 => symbol + 65, // A..Z + 26..=51 => symbol + 71, // a..z + 52..=61 => symbol - 4, // 0..9 + 62 => 43, // '+' + 63 => 47, // '/' + _ => unreachable!(), + + }; + string.push(char::from_u32(code).unwrap()); + } + } + match remainder { + 0 => { string}, + 1 => {string.pop(); string.pop(); string.push_str("=="); string}, + 2 => { string.pop(); string.push('='); string}, + _ => unreachable!(), + } +} + +pub fn decode(string: &str) -> Result<Vec<u8>, String> { + if string.len() % 4 != 0 { + return Err("Input string length is not divisible by 4".into()) + } + + let mut bytes: Vec<u8> = Vec::new(); + let mut n = 0u32; + let mut padding = 0; + + for (i, c) in string.chars().enumerate() { + let code: u32 = c.into(); + if padding > 0 && code != 61 { + return Err(format!("Character at position {} is a non-terminating padding character", i-1)); + } + let hextet = match code { + 65..=90 => code - 65, // A..Z + 97..=122 => code - 71, // a..z + 48..=57 => code + 4, // 0..9 + 43 => 62, // '+' + 47 => 63, // '/' + 61 => {padding += 1; 0}, // '=' + _ => return Err(format!("Character '{c}' at position {i} is not a member of the base64 alphabet")), + }; + n <<= 6; + n += hextet; + if i & 0b11 == 3 { + bytes.extend(&n.to_be_bytes()[1..4]); + n = 0; + } + } + if padding > 2 { + return Err(format!("Input string contains more than 2 terminating padding characters")) + } + for _ in 0..padding { + bytes.pop(); + } + Ok(bytes) +} |