Skip to content

Commit

Permalink
perf(Ensemble): allow dumping of ensemble output into already existin…
Browse files Browse the repository at this point in the history
…g directories
  • Loading branch information
muellerdo committed Jun 13, 2022
1 parent ea04aad commit ab6c0bc
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
5 changes: 3 additions & 2 deletions aucmedi/ensemble/bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,11 +365,12 @@ def dump(self, directory_path):
if self.cache_dir is None:
raise FileNotFoundError("Bagging does not have a valid model cache directory!")
elif isinstance(self.cache_dir, tempfile.TemporaryDirectory):
shutil.copytree(self.cache_dir.name, directory_path)
shutil.copytree(self.cache_dir.name, directory_path,
dirs_exist_ok=True)
self.cache_dir.cleanup()
self.cache_dir = directory_path
else:
shutil.copytree(self.cache_dir, directory_path)
shutil.copytree(self.cache_dir, directory_path, dirs_exist_ok=True)
self.cache_dir = directory_path

# Load model from file
Expand Down
5 changes: 3 additions & 2 deletions aucmedi/ensemble/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,11 +525,12 @@ def dump(self, directory_path):
if self.cache_dir is None:
raise FileNotFoundError("Composite does not have a valid model cache directory!")
elif isinstance(self.cache_dir, tempfile.TemporaryDirectory):
shutil.copytree(self.cache_dir.name, directory_path)
shutil.copytree(self.cache_dir.name, directory_path,
dirs_exist_ok=True)
self.cache_dir.cleanup()
self.cache_dir = directory_path
else:
shutil.copytree(self.cache_dir, directory_path)
shutil.copytree(self.cache_dir, directory_path, dirs_exist_ok=True)
self.cache_dir = directory_path

# Load model from file
Expand Down
5 changes: 3 additions & 2 deletions aucmedi/ensemble/stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,11 +514,12 @@ def dump(self, directory_path):
if self.cache_dir is None:
raise FileNotFoundError("Stacking does not have a valid model cache directory!")
elif isinstance(self.cache_dir, tempfile.TemporaryDirectory):
shutil.copytree(self.cache_dir.name, directory_path)
shutil.copytree(self.cache_dir.name, directory_path,
dirs_exist_ok=True)
self.cache_dir.cleanup()
self.cache_dir = directory_path
else:
shutil.copytree(self.cache_dir, directory_path)
shutil.copytree(self.cache_dir, directory_path, dirs_exist_ok=True)
self.cache_dir = directory_path

# Load model from file
Expand Down

0 comments on commit ab6c0bc

Please sign in to comment.