diff options
Diffstat (limited to 'src/source_unit.rs')
-rw-r--r-- | src/source_unit.rs | 33 |
1 files changed, 17 insertions, 16 deletions
diff --git a/src/source_unit.rs b/src/source_unit.rs index 3e674be..69eb436 100644 --- a/src/source_unit.rs +++ b/src/source_unit.rs @@ -6,8 +6,8 @@ use vagabond::*; type ParseFn = fn(&str, Option<&Path>) -> Vec<Symbol>; -/// Gather all source units from a PATH-style environment variable. -pub fn gather_from_path_variable(variable: &str, extension: &str, parse: ParseFn) -> Vec<SourceUnit> { +/// Gather all source units with a given extension using a PATH-style environment variable. +pub fn gather_from_path_variable(variable: &str, extension: Option<&str>, parse: ParseFn) -> Vec<SourceUnit> { let mut source_units = Vec::new(); if let Ok(string) = std::env::var(variable) { for path in string.split(":").map(PathBuf::from) { @@ -17,8 +17,8 @@ pub fn gather_from_path_variable(variable: &str, extension: &str, parse: ParseFn return source_units; } -/// Gather source units at or descending from a path. -pub fn gather_from_path(path: &Path, extension: &str, parse: ParseFn) -> Vec<SourceUnit> { +/// Gather source units with a given extension at or descending from a path. +pub fn gather_from_path(path: &Path, extension: Option<&str>, parse: ParseFn) -> Vec<SourceUnit> { let mut source_units = Vec::new(); if let Ok(entry) = Entry::from_path(path) { if EntryType::File == entry.entry_type { @@ -47,22 +47,22 @@ pub struct SourceUnit { impl SourceUnit { /// Load source from a main file and an associated head and tail file. - pub fn from_path<P: AsRef<Path>>(path: P, extension: &str, parse: ParseFn) -> Result<Self, FileError> { + pub fn from_path<P: AsRef<Path>>(path: P, extension: Option<&str>, parse: ParseFn) -> Result<Self, FileError> { let main_path = { path.as_ref().canonicalize().unwrap_or_else(|_| path.as_ref().to_path_buf()) }; let main_path_str = main_path.as_os_str().to_string_lossy().to_string(); - let head_extension = format!("head.{extension}"); - let tail_extension = format!("tail.{extension}"); - let is_head = main_path_str.ends_with(&head_extension); - let is_tail = main_path_str.ends_with(&tail_extension); - let is_not_main = !main_path_str.ends_with(extension); + // Attempt to extract an extension from main path if no extension was provided. + let extension = extension.or_else(|| main_path.extension().and_then(|ext| ext.to_str())); + + let head_extension = extension.map(|ext| format!("head.{ext}")); + let tail_extension = extension.map(|ext| format!("tail.{ext}")); + let is_head = head_extension.as_ref().map_or(false, |ext| main_path_str.ends_with(ext.as_str())); + let is_tail = tail_extension.as_ref().map_or(false, |ext| main_path_str.ends_with(ext.as_str())); + let is_main = extension.map_or(true, |ext| main_path_str.ends_with(ext)); // Head and tail files will be picked up later along with the main file. - if is_not_main || is_head || is_tail { return Err(FileError::InvalidExtension); } + if !is_main || is_head || is_tail { return Err(FileError::InvalidExtension); } let source_code = read_file(path.as_ref())?; let symbols = parse(&source_code, Some(path.as_ref())); - let head_path = main_path.with_extension(head_extension); - let tail_path = main_path.with_extension(tail_extension); - macro_rules! parse_file { ($path:expr) => { read_file(&$path).ok().map(|source_code| { @@ -72,9 +72,10 @@ impl SourceUnit { }) }; } + + let head = head_extension.map_or(None, |ext| parse_file!(main_path.with_extension(&ext))); + let tail = tail_extension.map_or(None, |ext| parse_file!(main_path.with_extension(&ext))); let main = SourceFile { path: main_path, source_code, symbols }; - let head = parse_file!(head_path); - let tail = parse_file!(tail_path); Ok( SourceUnit { main, head, tail } ) } |