use crate::*;

use std::mem::take;


/// Resolve symbol references across source units.
///
/// Source units are added one at a time. Their definitions and references are
/// tracked here together with the ID of the unit they came from, where a
/// unit's ID is its index in `source_units`.
pub struct SymbolResolver {
    /// All accepted definitions. A definition whose name is already present
    /// in this list is recorded in `redefinitions` instead.
    pub definitions: Vec<TrackedSymbol>,
    /// All resolved references.
    pub resolved: Vec<TrackedSymbol>,
    /// All unresolved references.
    pub unresolved: Vec<TrackedSymbol>,
    /// Contains the ID of the owner of the original definition.
    pub redefinitions: Vec<(TrackedSymbol, usize)>,
    /// Every source unit added so far; a unit's ID is its index in this list.
    pub source_units: Vec<HeirarchicalSourceUnit>,
    /// IDs of the source units that were added without a parent unit.
    pub root_unit_ids: Vec<usize>,
    /// Library units that have not been pulled in to resolve a reference.
    pub unused_library_units: Vec<SourceUnit>,
}


impl SymbolResolver {
    /// Build a resolver whose only source unit is the given root unit.
    ///
    /// The unit is registered without a parent, so its ID is recorded in
    /// `root_unit_ids`.
    pub fn from_source_unit(source_unit: SourceUnit) -> Self {
        let mut resolver = Self {
            source_units: Vec::new(),
            root_unit_ids: Vec::new(),
            unused_library_units: Vec::new(),
            definitions: Vec::new(),
            redefinitions: Vec::new(),
            resolved: Vec::new(),
            unresolved: Vec::new(),
        };
        resolver.add_source_unit(source_unit, None);
        resolver
    }

    /// Queue library units so that `resolve` can later pull in the ones that
    /// define an unresolved symbol.
    pub fn add_library_units(&mut self, source_units: Vec<SourceUnit>) {
        self.unused_library_units.extend(source_units);
    }

    pub fn resolve(&mut self) {
        // Repeatedly test if any unused source unit resolves an unresolved symbol,
        // breaking the loop when no new resolutions are found.
        'outer: loop {
            for (i, source_unit) in self.unused_library_units.iter().enumerate() {
                if let Some(id) = self.resolves_reference(&source_unit) {
                    let source_unit = self.unused_library_units.remove(i);
                    self.add_source_unit(source_unit, Some(id));
                    continue 'outer;
                }
            }
            break;
        }

        // For every macro reference in every unit, find the ID of the unit which
        // resolves that reference and add it to the .parent_ids field of the
        // referencing unit.
        for reference in &self.resolved {
            let predicate = |d: &&TrackedSymbol| d.symbol.name == reference.symbol.name;
            if let Some(definition) = self.definitions.iter().find(predicate) {
                let is_self = reference.source_id == definition.source_id;
                let is_label = definition.symbol.variant == SymbolVariant::LabelDefinition;
                if  is_self || is_label { continue; }
                let referencing_unit = &mut self.source_units[reference.source_id];
                referencing_unit.parent_ids.push(definition.source_id);
            };
        }
    }

    /// Register a source unit, recording its symbols and its place in the
    /// unit hierarchy.
    ///
    /// The new unit's ID is its index in `source_units`. When `parent_id` is
    /// given, the new ID is appended to that unit's `child_ids` (if a unit
    /// with that ID exists); otherwise the new ID is recorded as a root.
    pub fn add_source_unit(&mut self, mut source_unit: SourceUnit, parent_id: Option<usize>) {
        // The unit is pushed at the very end, so its ID is the current length.
        let unit_id = self.source_units.len();

        // Main file: definitions are registered before references.
        let main = &mut source_unit.main.symbols;
        if let Some(definitions) = take(&mut main.definitions) {
            self.add_definitions(definitions, unit_id, SourceRole::Main);
        }
        if let Some(references) = take(&mut main.references) {
            self.add_references(references, unit_id, SourceRole::Main);
        }

        // Optional head and tail files, in that order: references are
        // registered before definitions.
        let optional_files = [
            (source_unit.head.as_mut(), SourceRole::Head),
            (source_unit.tail.as_mut(), SourceRole::Tail),
        ];
        for (file, role) in optional_files {
            if let Some(file) = file {
                if let Some(references) = take(&mut file.symbols.references) {
                    self.add_references(references, unit_id, role);
                }
                if let Some(definitions) = take(&mut file.symbols.definitions) {
                    self.add_definitions(definitions, unit_id, role);
                }
            }
        }

        // Link the unit into the hierarchy before it is pushed.
        match parent_id {
            Some(parent_id) => {
                if let Some(parent) = self.source_units.get_mut(parent_id) {
                    parent.child_ids.push(unit_id);
                }
            }
            None => self.root_unit_ids.push(unit_id),
        }

        self.source_units.push(HeirarchicalSourceUnit {
            source_unit,
            parent_ids: Vec::new(),
            child_ids: Vec::new(),
        });
    }

    /// Track each reference, filing it under `resolved` when a definition
    /// with the same name is already known and under `unresolved` otherwise.
    fn add_references(&mut self, references: Vec<Symbol>, source_id: usize, source_role: SourceRole) {
        for symbol in references {
            let tracked = TrackedSymbol { symbol, source_id, source_role };
            // Tracked symbols compare by name, so this is a lookup by name.
            let target = if self.definitions.contains(&tracked) {
                &mut self.resolved
            } else {
                &mut self.unresolved
            };
            target.push(tracked);
        }
    }

    fn add_definitions(&mut self, definitions: Vec<Symbol>, source_id: usize, source_role: SourceRole) {
        for symbol in definitions {
            let predicate = |d: &&TrackedSymbol| { &d.symbol.name == &symbol.name };
            if let Some(def) = self.definitions.iter().find(predicate) {
                let definition = TrackedSymbol { symbol, source_id, source_role };
                let redefinition = (definition, def.source_id);
                self.redefinitions.push(redefinition);
            } else {
                let predicate = |s: &mut TrackedSymbol| s.symbol.name == symbol.name;
                for symbol in self.unresolved.extract_if(predicate) {
                    self.resolved.push(symbol);
                }
                self.unresolved.retain(|s| s.symbol.name != symbol.name);
                let definition = TrackedSymbol { symbol, source_id, source_role };
                self.definitions.push(definition);
            }
        }
    }

    /// Check whether a source unit defines a symbol that is currently
    /// unresolved.
    ///
    /// The definitions of the main, head and tail files are checked in that
    /// order. Returns the ID of the unit owning the first matching unresolved
    /// reference, or `None` if the unit resolves nothing.
    pub fn resolves_reference(&self, source_unit: &SourceUnit) -> Option<usize> {
        let main = source_unit.main.symbols.definitions.as_ref();
        let head = source_unit.head.as_ref()
            .and_then(|file| file.symbols.definitions.as_ref());
        let tail = source_unit.tail.as_ref()
            .and_then(|file| file.symbols.definitions.as_ref());

        // Absent files and absent definition lists contribute nothing.
        main.into_iter()
            .chain(head)
            .chain(tail)
            .find_map(|definitions| self.source_id_of_unresolved(definitions))
    }

    /// Find the first of the given symbols that has an unresolved reference,
    /// and return the ID of the unit owning that reference.
    fn source_id_of_unresolved(&self, symbols: &[Symbol]) -> Option<usize> {
        symbols.iter().find_map(|symbol| {
            self.unresolved.iter()
                .find(|reference| reference.symbol.name == symbol.name)
                .map(|reference| reference.source_id)
        })
    }

    /// Return the source code of the file a tracked symbol came from.
    ///
    /// # Panics
    /// Panics if the symbol's source ID is out of range, or if its role is
    /// head or tail but the unit has no such file.
    pub fn get_source_code_for_tracked_symbol(&self, symbol: &TrackedSymbol) -> &str {
        let unit = &self.source_units[symbol.source_id].source_unit;
        // Pick the symbol table of the file named by the role first, then
        // read its source code in one place.
        let symbols = match symbol.source_role {
            SourceRole::Main => &unit.main.symbols,
            SourceRole::Head => match unit.head.as_ref() {
                Some(head) => &head.symbols,
                None => unreachable!("Failed to find source for token"),
            },
            SourceRole::Tail => match unit.tail.as_ref() {
                Some(tail) => &tail.symbols,
                None => unreachable!("Failed to find source for token"),
            },
        };
        symbols.source_code.as_str()
    }

    /// Concatenate the source code of all source units into one string.
    ///
    /// Units are first ordered so that each unit comes after every unit
    /// listed in its `parent_ids`. If no such ordering exists (the dependency
    /// graph contains a cycle), the IDs of all units that could not be
    /// ordered are returned as the error.
    pub fn get_merged_source_code(&self) -> Result<String, Vec<usize>> {
        let source_order = {
            // Start from the highest ID so that the root unit is considered
            // last among units that are ready at the same time.
            let mut pending: Vec<usize> = (0..self.source_units.len()).rev().collect();
            let mut ordered: Vec<usize> = Vec::new();

            while !pending.is_empty() {
                // The first pending unit whose parents have all been ordered.
                let next_ready = pending.iter().position(|&id| {
                    self.source_units[id].parent_ids.iter()
                        .all(|parent_id| ordered.contains(parent_id))
                });
                match next_ready {
                    Some(index) => ordered.push(pending.remove(index)),
                    // Every pending unit waits on another pending unit,
                    // which indicates a dependency cycle.
                    None => return Err(pending),
                }
            }
            ordered
        };

        let mut merged = String::new();

        // Head files, in dependency order.
        for &id in &source_order {
            if let Some(head) = &self.source_units[id].source_unit.head {
                push_source_code_to_string(&mut merged, head);
            }
        }
        // Main files, in reverse dependency order.
        for &id in source_order.iter().rev() {
            push_source_code_to_string(&mut merged, &self.source_units[id].source_unit.main);
        }
        // Tail files, in dependency order.
        for &id in &source_order {
            if let Some(tail) = &self.source_units[id].source_unit.tail {
                push_source_code_to_string(&mut merged, tail);
            }
        }
        Ok(merged)
    }
}


/// Append a source file to the merged output, preceded by a comment naming
/// the file's path.
///
/// If `string` already has content, it is first padded with newlines so that
/// it ends with two newlines before the new section starts.
fn push_source_code_to_string(string: &mut String, source_file: &SourceFile) {
    if !string.is_empty() {
        // Number of newlines still needed for the text to end in "\n\n".
        let missing_newlines = if string.ends_with("\n\n") {
            0
        } else if string.ends_with('\n') {
            1
        } else {
            2
        };
        for _ in 0..missing_newlines {
            string.push('\n');
        }
    }
    // Path comment, then the file's own source code.
    let path = source_file.path.as_os_str().to_string_lossy();
    string.push_str("(: ");
    string.push_str(&path);
    string.push_str(" )\n");
    string.push_str(&source_file.symbols.source_code);
}


/// A source unit together with its links to other units, given as unit IDs.
pub struct HeirarchicalSourceUnit {
    /// The wrapped source unit.
    pub source_unit: SourceUnit,
    /// IDs of units which were added to resolve symbol references in this unit.
    pub child_ids: Vec<usize>,
    /// IDs of units which resolve macro references in this unit.
    pub parent_ids: Vec<usize>,
}


/// A symbol paired with the place it came from.
///
/// Equality between tracked symbols compares the symbol name only.
pub struct TrackedSymbol {
    /// The symbol itself.
    pub symbol: Symbol,
    /// ID of the source unit the symbol came from.
    pub source_id: usize,
    /// Which file of that source unit (main, head or tail) the symbol came from.
    pub source_role: SourceRole,
}


/// Identifies which file of a source unit a symbol came from.
#[derive(Clone, Copy)]
pub enum SourceRole {
    /// The unit's main file.
    Main,
    /// The unit's optional head file.
    Head,
    /// The unit's optional tail file.
    Tail,
}


impl PartialEq for TrackedSymbol {
    fn eq(&self, other: &TrackedSymbol) -> bool {
        self.symbol.name.eq(&other.symbol.name)
    }
}