use std::mem::take;
use std::collections::hash_map::Entry;

use SyntacticTokenType as Syn;
use SemanticTokenType as Sem;
use crate::*;

use std::collections::HashMap;

/// The kind of symbol that a name is bound to, together with where it is defined.
///
/// The inner value is the index of the token that defines this symbol.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SymbolDefinition {
    /// A macro, defined by the macro definition token at this index.
    Macro(usize),
    /// A label, defined by the label definition token at this index.
    Label(usize),
}

/// Assembles source code into bytecode.
///
/// Assembly proceeds through `tokenise_source` (fills the syntactic fields),
/// `resolve_references` (converts them into the semantic fields) and
/// `generate_bytecode` (consumes the semantic fields). `Default` yields the same
/// empty assembler as `Assembler::new`.
#[derive(Default)]
pub struct Assembler {
    /// The contents of the program as a list of syntactic tokens.
    syntactic_tokens: Vec<SyntacticToken>,
    /// The contents of the program as a list of semantic tokens.
    semantic_tokens: Vec<SemanticToken>,
    /// Map the name of each defined symbol to the index of the defining token.
    symbol_definitions: HashMap<String, SymbolDefinition>,
    /// Map each macro definition token index to a list of syntactic body tokens.
    syntactic_macro_bodies: HashMap<usize, Vec<SyntacticToken>>,
    /// Map each macro definition token index to a list of semantic body tokens.
    semantic_macro_bodies: HashMap<usize, Vec<SemanticToken>>,
}

impl Assembler {
    /// Traversal depth at which `traverse_macro_definition` gives up. Any macro whose
    /// chain of nested macro references reaches this depth (including every cyclic
    /// macro) is reported as cyclic.
    const MAX_MACRO_DEPTH: usize = 16;

    /// Create an assembler with no tokens and no symbol definitions.
    pub fn new() -> Self {
        Self {
            syntactic_tokens: Vec::new(),
            semantic_tokens: Vec::new(),
            symbol_definitions: HashMap::new(),
            syntactic_macro_bodies: HashMap::new(),
            semantic_macro_bodies: HashMap::new(),
        }
    }

    /// Tokenise `source_code` and record every label and macro definition.
    ///
    /// Top-level tokens are appended to `syntactic_tokens`. The tokens following a macro
    /// definition token, up to and including its terminator, are stored as that macro's
    /// body in `syntactic_macro_bodies`, keyed by the index of the definition token.
    /// A second definition of an already-defined name is marked with
    /// `Error::DuplicateDefinition`, and a terminator found outside of a macro definition
    /// is marked with `Error::OrphanedMacroDefinitionTerminator`.
    ///
    /// If the source ends before a macro definition is terminated, the body collected so
    /// far is still committed, so that every macro definition token has a body.
    pub fn tokenise_source(&mut self, source_code: &str) {
        // The index of the current macro definition token
        let mut macro_definition: Option<usize> = None;
        let mut macro_definition_body_tokens: Vec<SyntacticToken> = Vec::new();

        for mut token in TokenIterator::from_str(source_code) {
            let next_index = self.syntactic_tokens.len();
            if let Some(index) = macro_definition {
                token.use_in_macro_body();
                let is_terminator = token.is_macro_terminator();
                // The terminator itself is kept as the final body token.
                macro_definition_body_tokens.push(token);
                if is_terminator {
                    // Commit the current macro definition
                    self.syntactic_macro_bodies.insert(
                        index, take(&mut macro_definition_body_tokens));
                    macro_definition = None;
                }
            } else {
                if let Syn::MacroDefinition(ref name) = token.r#type {
                    macro_definition = Some(next_index);
                    match self.symbol_definitions.entry(name.to_string()) {
                        Entry::Occupied(_) => {token.set_error(Error::DuplicateDefinition);}
                        Entry::Vacant(v) => {v.insert(SymbolDefinition::Macro(next_index));}
                    }
                } else if let Syn::LabelDefinition(ref name) = token.r#type {
                    match self.symbol_definitions.entry(name.to_string()) {
                        Entry::Occupied(_) => {token.set_error(Error::DuplicateDefinition);}
                        Entry::Vacant(v) => {v.insert(SymbolDefinition::Label(next_index));}
                    }
                } else if token.is_macro_terminator() {
                    token.set_error(Error::OrphanedMacroDefinitionTerminator);
                }
                self.syntactic_tokens.push(token);
            }
        }

        // The source ended inside a macro definition. Commit the partial body anyway,
        // otherwise `convert_syn_token_to_sem_token` finds no body for it and panics.
        if let Some(index) = macro_definition {
            self.syntactic_macro_bodies.insert(index, macro_definition_body_tokens);
        }
    }

    /// Convert every syntactic token into a semantic token, resolving each reference to
    /// the label or macro definition it names.
    ///
    /// Each semantic token records the name of the most recent label definition at or
    /// before it as its parent label. Afterwards, every macro reference inside a macro
    /// body that points at a cyclic (or too deeply nested) macro is replaced with an
    /// `Error::CyclicMacroReference` token.
    pub fn resolve_references(&mut self) {
        let syntactic_tokens = take(&mut self.syntactic_tokens);
        let syntactic_token_count = syntactic_tokens.len();
        self.semantic_tokens.reserve(syntactic_token_count);
        let mut parent_label = None;

        for (index, syntactic_token) in syntactic_tokens.into_iter().enumerate() {
            if let Syn::LabelDefinition(name) = &syntactic_token.r#type {
                parent_label = Some(name.to_owned());
            }
            let semantic_token = self.convert_syn_token_to_sem_token(syntactic_token, index, parent_label.clone());
            self.semantic_tokens.push(semantic_token);
        }
        assert_eq!(syntactic_token_count, self.semantic_tokens.len());

        // Find all macros whose expansion cannot be traversed within the depth limit
        let cyclic_macros: std::collections::HashSet<usize> = self
            .semantic_macro_bodies
            .keys()
            .copied()
            .filter(|&i| !self.traverse_macro_definition(i, 0))
            .collect();
        // Replace each cyclic macro reference in a macro definition with an error
        for body_tokens in self.semantic_macro_bodies.values_mut() {
            for body_token in body_tokens.iter_mut() {
                if let Sem::MacroReference(i) = body_token.r#type {
                    if cyclic_macros.contains(&i) {
                        let name = body_token.source_location.source.clone();
                        body_token.r#type = Sem::Error(Syn::Reference(name), Error::CyclicMacroReference);
                    }
                }
            }
        }
    }

    /// Recursively traverse the body tokens of the macro defined at token `index`.
    ///
    /// Returns false if any chain of nested macro references reaches `MAX_MACRO_DEPTH`
    /// levels, which is always the case for a cyclic macro, and true otherwise.
    /// Panics if `index` has no entry in `semantic_macro_bodies`.
    fn traverse_macro_definition(&self, index: usize, level: usize) -> bool {
        if level == Self::MAX_MACRO_DEPTH {
            return false;
        }
        self.semantic_macro_bodies[&index].iter().all(|token| match token.r#type {
            Sem::MacroReference(i) => self.traverse_macro_definition(i, level + 1),
            _ => true,
        })
    }

    /// Translate the semantic tokens into bytecode.
    ///
    /// Macro references are expanded inline. Label references are filled with the
    /// big-endian 16-bit address of their label definition. Each label and macro
    /// definition token receives the indices of the top-level tokens that reference it.
    /// Each macro definition token also receives its body tokens.
    ///
    /// Returns the bytecode with all trailing null bytes removed (so all-null bytecode
    /// becomes empty), together with the semantic tokens annotated with their bytecode
    /// locations.
    pub fn generate_bytecode(&mut self) -> (Vec<u8>, Vec<SemanticToken>) {
        let mut bytecode: Vec<u8> = Vec::new();
        // Map each label definition token index to the bytecode addresses of the references
        let mut reference_addresses: HashMap<usize, Vec<u16>> = HashMap::new();
        // Map each label and macro definition token to a list of reference token indices
        let mut reference_tokens: HashMap<usize, Vec<usize>> = HashMap::new();

        macro_rules! push_u8 {($v:expr) => {bytecode.push($v)};}
        macro_rules! push_u16 {($v:expr) => {bytecode.extend_from_slice(&u16::to_be_bytes($v))};}
        macro_rules! pad {($p:expr) => {bytecode.resize(bytecode.len() + $p as usize, 0)};}

        let mut semantic_tokens = take(&mut self.semantic_tokens);

        // Translate semantic tokens into bytecode
        for (index, semantic_token) in semantic_tokens.iter_mut().enumerate() {
            // Addresses are 16 bits wide; positions past 0xFFFF wrap around.
            let start_addr = bytecode.len() as u16;
            match &mut semantic_token.r#type {
                Sem::LabelReference(i) => {
                    reference_tokens.entry(*i).or_default().push(index);
                    reference_addresses.entry(*i).or_default().push(start_addr);
                    // Placeholder, filled with the label address below
                    push_u16!(0);
                }
                Sem::MacroReference(i) => {
                    reference_tokens.entry(*i).or_default().push(index);
                    self.expand_macro_reference(*i, &mut bytecode, &mut reference_addresses);
                }
                Sem::LabelDefinition(def) => def.address=start_addr,
                Sem::MacroDefinition(_) => (),

                Sem::Padding(p) => pad!(*p),
                Sem::ByteLiteral(b) => push_u8!(*b),
                Sem::ShortLiteral(s) => push_u16!(*s),
                Sem::Instruction(b) => push_u8!(*b),

                // Orphaned terminators are turned into errors during tokenisation
                Sem::MacroDefinitionTerminator => unreachable!(),
                Sem::Comment => (),
                Sem::Error(..) => (),
            };
            let end_addr = bytecode.len() as u16;
            semantic_token.bytecode_location.start = start_addr;
            // If this token's bytes cross the 16-bit wrap, end_addr is below start_addr;
            // a plain subtraction would overflow (and panic in debug builds).
            semantic_token.bytecode_location.length = end_addr.wrapping_sub(start_addr);
        }

        // Fill each label reference with the address of the matching label definition
        for (index, slots) in reference_addresses {
            if let Sem::LabelDefinition(definition) = &semantic_tokens[index].r#type {
                let [h,l] = definition.address.to_be_bytes();
                for slot in slots {
                    bytecode[slot as usize] = h;
                    bytecode[slot.wrapping_add(1) as usize] = l;
                }
            } else { unreachable!() }
        }

        // Move references and macro body tokens into label and macro definition tokens
        for (index, semantic_token) in semantic_tokens.iter_mut().enumerate() {
            if let Sem::MacroDefinition(definition) = &mut semantic_token.r#type {
                definition.body_tokens = self.semantic_macro_bodies.remove(&index)
                    .expect("every macro definition token has a semantic body");
                if let Some(references) = reference_tokens.remove(&index) {
                    definition.references = references;
                }
            } else if let Sem::LabelDefinition(definition) = &mut semantic_token.r#type {
                if let Some(references) = reference_tokens.remove(&index) {
                    definition.references = references;
                }
            }
        }
        assert_eq!(reference_tokens.len(), 0);

        // Remove trailing null bytes from the bytecode. Bytecode made up only of null
        // bytes is truncated to nothing.
        let truncated_length = bytecode.iter().rposition(|b| *b != 0).map_or(0, |i| i + 1);
        bytecode.truncate(truncated_length);

        (bytecode, semantic_tokens)
    }

    /// Convert one syntactic token into a semantic token.
    ///
    /// `index` is the position of `syn_token` among the top-level syntactic tokens; it is
    /// used to find the body of a macro definition. Macro body tokens are converted
    /// recursively with `index` 0. A macro definition token takes ownership of its body
    /// from `syntactic_macro_bodies`, stores the converted body in
    /// `semantic_macro_bodies`, and has its source location extended to the start of its
    /// terminator. A token that already carries an error becomes `Sem::Error`.
    fn convert_syn_token_to_sem_token(&mut self, mut syn_token: SyntacticToken, index: usize, parent_label: Option<String>) -> SemanticToken {
        SemanticToken {
            r#type: {
                if let Some(err) = syn_token.error {
                    Sem::Error(syn_token.r#type, err)
                } else {
                    match syn_token.r#type {
                        Syn::Reference(ref name) => {
                            match self.symbol_definitions.get(name) {
                                Some(SymbolDefinition::Macro(i)) => Sem::MacroReference(*i),
                                Some(SymbolDefinition::Label(i)) => Sem::LabelReference(*i),
                                None => Sem::Error(syn_token.r#type, Error::UnresolvedReference),
                            }
                        }
                        Syn::LabelDefinition(name) => {Sem::LabelDefinition(LabelDefinition::new(name))},
                        Syn::MacroDefinition(name) => {
                            let mut sem_body_tokens = Vec::new();
                            let syn_body_tokens = self.syntactic_macro_bodies.remove(&index)
                                .expect("macro definition token has no committed body");
                            for syn_body_token in syn_body_tokens {
                                // Make the source location of the macro definition token span the entire definition
                                if syn_body_token.is_macro_terminator() {
                                    syn_token.source_location.end = syn_body_token.source_location.start;
                                }
                                let sem_body_token = self.convert_syn_token_to_sem_token(syn_body_token, 0, parent_label.clone());
                                sem_body_tokens.push(sem_body_token);
                            }
                            self.semantic_macro_bodies.insert(index, sem_body_tokens);
                            Sem::MacroDefinition(MacroDefinition::new(name))
                        },
                        Syn::MacroDefinitionTerminator => Sem::MacroDefinitionTerminator,
                        Syn::Padding(v) => Sem::Padding(v),
                        Syn::ByteLiteral(v) => Sem::ByteLiteral(v),
                        Syn::ShortLiteral(v) => Sem::ShortLiteral(v),
                        Syn::Instruction(v) => Sem::Instruction(v),
                        Syn::Comment => Sem::Comment,
                    }
                }
            },
            source_location: syn_token.source_location,
            bytecode_location: BytecodeLocation::zero(),
            parent_label,
        }
    }

    /// Append the expansion of the macro defined at token `index` to `bytecode`.
    ///
    /// Nested macro references are expanded recursively. The address of every label
    /// reference is recorded in `reference_addresses` with a zero placeholder, so the
    /// caller can fill it in later. Panics if `index` has no semantic macro body.
    fn expand_macro_reference(&self, index: usize, bytecode: &mut Vec<u8>, reference_addresses: &mut HashMap<usize, Vec<u16>>) {
        macro_rules! push_u8 {($v:expr) => {bytecode.push($v)};}
        macro_rules! push_u16 {($v:expr) => {bytecode.extend_from_slice(&u16::to_be_bytes($v))};}
        macro_rules! pad {($p:expr) => {bytecode.resize(bytecode.len() + $p as usize, 0)};}

        for body_token in &self.semantic_macro_bodies[&index] {
            let start_addr = bytecode.len() as u16;
            match &body_token.r#type {
                Sem::LabelReference(i) => {
                    reference_addresses.entry(*i).or_default().push(start_addr);
                    push_u16!(0u16);
                },
                Sem::MacroReference(i) => {
                    self.expand_macro_reference(*i, bytecode, reference_addresses);
                },
                Sem::LabelDefinition(_) => unreachable!(),
                Sem::MacroDefinition(_) => unreachable!(),

                Sem::Padding(p) => pad!(*p),
                Sem::ByteLiteral(b) => push_u8!(*b),
                Sem::ShortLiteral(s) => push_u16!(*s),
                Sem::Instruction(b) => push_u8!(*b),

                Sem::MacroDefinitionTerminator => (),
                Sem::Comment => (),
                Sem::Error(..) => (),
            };
        }
    }
}