Skip to content

Commit

Permalink
bug fix from refactoring and pre-commits
Browse files Browse the repository at this point in the history
Signed-off-by: miguel <miguel.brandao@ibm.com>
  • Loading branch information
miguel-brandao-ibm committed May 31, 2023
1 parent 80f88da commit 3effcfc
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
4 changes: 3 additions & 1 deletion deepsearch/model/cli/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import typer

from deepsearch.model.model_download.artifact_management_typer import artifact_download_app
from deepsearch.model.model_download.artifact_management_typer import (
artifact_download_app,
)

app = typer.Typer(no_args_is_help=True, add_completion=False)
app.add_typer(
Expand Down
8 changes: 5 additions & 3 deletions deepsearch/model/model_download/artifact_management_typer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@

@artifact_download_app.command()
def download(
artifact: str = typer.Option(None, "--artifact", "-a"),
artifact_name: str = typer.Option(None, "--artifact", "-a"),
list_artifacts: bool = typer.Option(False, "--list", "-l"),
):

if list_artifacts:
artifacts = artifact_manager.get_index_artifact_list()
for artifact in artifacts:
typer.echo(artifact["folder_name"])
elif artifact is not None:
artifact_manager.download_artifact_to_cache(artifact, with_progess_bar=True)
elif artifact_name is not None:
artifact_manager.download_artifact_to_cache(
artifact_name, with_progess_bar=True
)


@artifact_download_app.command()
Expand Down
6 changes: 3 additions & 3 deletions deepsearch/model/model_download/artifact_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ def get_artifact_location_in_cache(self, artifact_name: str):
if "folder_name" in artifact and artifact["folder_name"] == artifact_name:
return artifact

def delete_artifact_from_cache(self, artifact: str):
def delete_artifact_from_cache(self, artifact_name: str):
target_artifacts = []
for artifact in self.get_artifact_cache_list():
if "folder_name" in artifact and artifact["folder_name"] == artifact:
if "folder_name" in artifact and artifact["folder_name"] == artifact_name:
target_artifacts.append(artifact)

for artifact in target_artifacts:
Expand Down Expand Up @@ -119,7 +119,7 @@ def _download_file(
self, artifact_info: Dict, directory: Any, with_progress_bar: bool = False
) -> str:
# Get the filename from the URL
filename = artifact_info["artifact_filename"]
filename = artifact_info["model_filename"]
file_path = directory.name + f"/{filename}"

# Download the file
Expand Down

0 comments on commit 3effcfc

Please sign in to comment.