use crate::*;

use std::mem::take;


/// Resolve symbol references across source units.
///
/// A resolver starts from a root unit (see `from_source_unit`). When `resolve`
/// is called, it pulls in queued library units that define names which the
/// already-added units reference but which have no definition yet.
pub struct SymbolResolver {
    /// The first definition found for each symbol name, in the order added.
    pub definitions: Vec<TrackedSymbol>,
    /// References whose name has no entry in `definitions` yet. Every such
    /// reference is kept, so several entries may share the same name.
    pub unresolved: Vec<TrackedSymbol>,
    /// Definitions of names that were already defined when they were added.
    /// Contains the ID of the owner of the original definition.
    pub redefinitions: Vec<(TrackedSymbol, usize)>,
    /// Every source unit added so far; a unit's ID is its index in this list.
    pub source_units: Vec<HeirarchicalSourceUnit>,
    /// IDs of the units that were added without a parent.
    pub root_unit_ids: Vec<usize>,
    /// Library units queued by `add_library_units` that have not yet been
    /// moved into `source_units` by `resolve`.
    pub unused_library_units: Vec<SourceUnit>,
}


impl SymbolResolver {
    /// Construct a resolver from a root source unit.
    pub fn from_source_unit(source_unit: SourceUnit) -> Self {
        let mut new = Self {
            definitions: Vec::new(),
            unresolved: Vec::new(),
            redefinitions: Vec::new(),
            source_units: Vec::new(),
            root_unit_ids: Vec::new(),
            unused_library_units: Vec::new(),
        };
        new.add_source_unit(source_unit, None);
        return new;
    }

    pub fn add_library_units(&mut self, mut source_units: Vec<SourceUnit>) {
        self.unused_library_units.append(&mut source_units);
    }

    pub fn resolve(&mut self) {
        // Repeatedly test if any unused source unit resolves an unresolved symbol,
        // breaking the loop when no new resolutions are found.
        'outer: loop {
            for (i, source_unit) in self.unused_library_units.iter().enumerate() {
                if let Some(id) = self.resolves_reference(&source_unit) {
                    let source_unit = self.unused_library_units.remove(i);
                    self.add_source_unit(source_unit, Some(id));
                    continue 'outer;
                }
            }
            break;
        }
    }

    /// Add a source unit to the resolver and link it to a parent unit.
    pub fn add_source_unit(&mut self, mut source_unit: SourceUnit, parent_id: Option<usize>) {
        let source_id = self.source_units.len();

        // Add all main symbols.
        if let Some(definitions) = take(&mut source_unit.main.symbols.definitions) {
            self.add_definitions(definitions, source_id, SourceRole::Main); }
        if let Some(references) = take(&mut source_unit.main.symbols.references) {
            self.add_references(references, source_id, SourceRole::Main); }

        // Add all head symbols.
        if let Some(head) = &mut source_unit.head {
            if let Some(references) = take(&mut head.symbols.references) {
                self.add_references(references, source_id, SourceRole::Head); }
            if let Some(definitions) = take(&mut head.symbols.definitions) {
                self.add_definitions(definitions, source_id, SourceRole::Head); }
        }

        // Add all tail symbols.
        if let Some(tail) = &mut source_unit.tail {
            if let Some(references) = take(&mut tail.symbols.references) {
                self.add_references(references, source_id, SourceRole::Tail); }
            if let Some(definitions) = take(&mut tail.symbols.definitions) {
                self.add_definitions(definitions, source_id, SourceRole::Tail); }
        }

        if let Some(parent_id) = parent_id {
            if let Some(parent_unit) = self.source_units.get_mut(parent_id) {
                parent_unit.child_ids.push(source_id);
            }
        } else {
            self.root_unit_ids.push(source_id);
        }

        let source_unit = HeirarchicalSourceUnit { source_unit, child_ids: Vec::new() };
        self.source_units.push(source_unit);
    }

    fn add_references(&mut self, references: Vec<Symbol>, source_id: usize, source_role: SourceRole) {
        for symbol in references {
            let reference = TrackedSymbol { symbol, source_id, source_role };
            if !self.definitions.contains(&reference) {
                self.unresolved.push(reference);
            }
        }
    }

    fn add_definitions(&mut self, definitions: Vec<Symbol>, source_id: usize, source_role: SourceRole) {
        for symbol in definitions {
            let predicate = |d: &&TrackedSymbol| { &d.symbol.name == &symbol.name };
            if let Some(def) = self.definitions.iter().find(predicate) {
                let definition = TrackedSymbol { symbol, source_id, source_role };
                let redefinition = (definition, def.source_id);
                self.redefinitions.push(redefinition);
            } else {
                self.unresolved.retain(|s| s.symbol.name != symbol.name);
                let definition = TrackedSymbol { symbol, source_id, source_role };
                self.definitions.push(definition);
            }
        }
    }

    /// Returns the ID of the owner of a symbol resolved by this unit.
    pub fn resolves_reference(&self, source_unit: &SourceUnit) -> Option<usize> {
        if let Some(definitions) = &source_unit.main.symbols.definitions {
            if let Some(id) = self.source_id_of_unresolved(&definitions) {
                return Some(id);
            }
        }
        if let Some(head) = &source_unit.head {
            if let Some(definitions) = &head.symbols.definitions {
                if let Some(id) = self.source_id_of_unresolved(&definitions) {
                    return Some(id);
                }
            }
        }
        if let Some(tail) = &source_unit.tail {
            if let Some(definitions) = &tail.symbols.definitions {
                if let Some(id) = self.source_id_of_unresolved(&definitions) {
                    return Some(id);
                }
            }
        }
        return None;
    }

    /// Returns the ID of the owner of a reference to one of these symbols.
    fn source_id_of_unresolved(&self, symbols: &[Symbol]) -> Option<usize> {
        for symbol in symbols {
            let opt = self.unresolved.iter().find(|s| s.symbol.name == symbol.name);
            if let Some(unresolved) = opt {
                return Some(unresolved.source_id);
            }
        }
        return None;
    }

    pub fn get_source_code_for_tracked_symbol(&self, symbol: &TrackedSymbol) -> &str {
        let source_unit = &self.source_units[symbol.source_id].source_unit;
        match symbol.source_role {
            SourceRole::Main => source_unit.main.symbols.source_code.as_str(),
            SourceRole::Head => match &source_unit.head {
                Some(head) => head.symbols.source_code.as_str(),
                None => unreachable!("Failed to find source for token"),
            }
            SourceRole::Tail => match &source_unit.tail {
                Some(tail) => tail.symbols.source_code.as_str(),
                None => unreachable!("Failed to find source for token"),
            }
        }
    }

    /// Create a source file by concatenating all source units.
    pub fn get_merged_source_code(&self) -> String {
        // The first source unit is guaranteed to be the root unit, so we can
        // just push source files in their current order.
        let mut source_code = String::new();

        // Push head source code.
        for source_unit in self.source_units.iter().rev() {
            if let Some(head) = &source_unit.source_unit.head {
                push_source_code_to_string(&mut source_code, head);
            }
        }
        // Push main source code.
        for source_unit in self.source_units.iter() {
            push_source_code_to_string(&mut source_code, &source_unit.source_unit.main);
        }
        // Push tail source code.
        for source_unit in self.source_units.iter().rev() {
            if let Some(tail) = &source_unit.source_unit.tail {
                push_source_code_to_string(&mut source_code, tail);
            }
        }
        return source_code;
    }
}


/// Append `source_file` to `string`, preceded by a `(: path )` header line.
///
/// If `string` already has content, it is first padded with newlines until it
/// ends in a blank line (`"\n\n"`). This keeps consecutive sections visually
/// separated.
fn push_source_code_to_string(string: &mut String, source_file: &SourceFile) {
    if !string.is_empty() {
        // At most two newlines are ever needed to reach the "\n\n" ending.
        while !string.ends_with("\n\n") {
            string.push('\n');
        }
    }

    let path = source_file.path.as_os_str().to_string_lossy();
    string.push_str("(: ");
    string.push_str(&path);
    string.push_str(" )\n");
    string.push_str(&source_file.symbols.source_code);
}


/// A source unit stored by a `SymbolResolver`, together with links to the
/// units that were added with this unit as their parent.
pub struct HeirarchicalSourceUnit {
    pub source_unit: SourceUnit,
    /// IDs (indices into `SymbolResolver::source_units`) of child units, in
    /// the order they were added.
    pub child_ids: Vec<usize>,
}


/// A symbol tagged with the source unit and file it was taken from.
///
/// Equality (see the `PartialEq` impl for this type) compares only
/// `symbol.name`; `source_id` and `source_role` are ignored.
pub struct TrackedSymbol {
    pub symbol: Symbol,
    /// ID of the owning source unit (its index in `SymbolResolver::source_units`).
    pub source_id: usize,
    /// Which file of the owning unit (main, head or tail) the symbol came from.
    pub source_role: SourceRole,
}


/// Which file of a source unit a symbol was taken from: the unit's `main`
/// file, or its optional `head` or `tail` file.
///
/// Derives `Debug`, `PartialEq` and `Eq` so roles can be printed and compared
/// directly instead of only being matched on.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SourceRole {
    Main,
    Head,
    Tail,
}


impl PartialEq for TrackedSymbol {
    fn eq(&self, other: &TrackedSymbol) -> bool {
        self.symbol.name.eq(&other.symbol.name)
    }
}