diff --git a/assembler.py b/assembler.py index 727fe76..5d337a3 100644 --- a/assembler.py +++ b/assembler.py @@ -14,10 +14,10 @@ class FromAssemblerCallGraphBuilder(CallGraphBuilder): def __init__(self, baseDirs, specialModuleFiles = {}): assertType(specialModuleFiles, 'specialModuleFiles', dict) + if isinstance(baseDirs, str): baseDirs = [baseDirs] assertTypeAll(baseDirs, 'baseDirs', str) - for baseDir in baseDirs: if not os.path.isdir(baseDir): raise IOError("Not a directory: " + baseDir); diff --git a/source.py b/source.py index 0bc85ea..7431354 100644 --- a/source.py +++ b/source.py @@ -1533,12 +1533,17 @@ def __removeStringsFromStatement(statement): return cleanStatement class SourceFiles(object): - def __init__(self, baseDir, specialModuleFiles={}): + def __init__(self, baseDirs, specialModuleFiles={}): assertType(specialModuleFiles, 'specialModuleFiles', dict) - if not os.path.isdir(baseDir): - raise IOError("Not a directory: " + baseDir) - self.__baseDir = baseDir + if isinstance(baseDirs, str): + baseDirs = [baseDirs] + assertTypeAll(baseDirs, 'baseDirs', str) + for baseDir in baseDirs: + if not os.path.isdir(baseDir): + raise IOError("Not a directory: " + baseDir); + + self.__baseDirs = baseDirs self.__filesByPath = dict() self.__filesByModules = dict() self.setSpecialModuleFiles(specialModuleFiles) @@ -1555,9 +1560,6 @@ def setSpecialModuleFiles(self, specialModuleFiles): self.__filesByModules = dict() # Clear Module Cache - def getBaseDir(self): - return self.__baseDir - def existsSubroutine(self, subroutineName): assertType(subroutineName, 'subroutineName', SubroutineName) @@ -1621,8 +1623,8 @@ def getRelativePath(self, sourceFile): assertType(sourceFile, 'sourceFile', SourceFile) path = sourceFile.getPath() - if path.startswith(self.__baseDir): - path = path[len(self.__baseDir):].lstrip('/') + if path.startswith(self.__baseDirs): + path = path[len(self.__baseDirs):].lstrip('/') return path def __getModuleFileName(self, moduleName): @@ -1632,10 +1634,11 @@ def __getModuleFileName(self, moduleName): return moduleName + '.f90' def __findFile(self, fileName): - for root, _, files in os.walk(self.__baseDir): - for name in files: - if name.replace('.F90', '.f90') == fileName.replace('.F90', '.f90'): - return os.path.join(root, name) + for baseDir in self.__baseDirs: + for root, _, files in os.walk(baseDir): + for name in files: + if name.replace('.F90', '.f90') == fileName.replace('.F90', '.f90'): + return os.path.join(root, name) return None def clearCache(self):