diff --git a/podme_api/cli.py b/podme_api/cli.py deleted file mode 100644 index 9329527..0000000 --- a/podme_api/cli.py +++ /dev/null @@ -1,158 +0,0 @@ -"""podme_api cli tool.""" - -import argparse -import asyncio -import contextlib -import logging - -from rich.console import Console -from rich.logging import RichHandler - -from podme_api.auth import PodMeDefaultAuthClient -from podme_api.auth.models import PodMeUserCredentials -from podme_api.client import PodMeClient - -console = Console() - - -def main_parser() -> argparse.ArgumentParser: - """Create the ArgumentParser with all relevant subparsers.""" - parser = argparse.ArgumentParser(description="A simple executable to use and test the library.") - _add_default_arguments(parser) - - subparsers = parser.add_subparsers(dest="cmd") - subparsers.required = True - - subscription_parser = subparsers.add_parser( - "subscription", description="Get your current PodMe subscription." - ) - subscription_parser.set_defaults(func=get_subscription) - - categories_parser = subparsers.add_parser("categories", description="Get a list of PodMe categories.") - categories_parser.set_defaults(func=get_categories) - - favourites_parser = subparsers.add_parser( - "favourites", description="Get a list of your favourite podcasts." - ) - favourites_parser.set_defaults(func=get_favourites) - - popular_parser = subparsers.add_parser( - "popular", description="Get a list of PodMe's most popular podcasts." - ) - _add_paging_arguments(popular_parser) - popular_parser.add_argument("--category", help="(optional) Limit by category", nargs="?", default=None) - popular_parser.add_argument( - "--type", help="(optional, default=2) Limit by podcast type", nargs="?", default=None - ) - popular_parser.set_defaults(func=get_popular) - - return parser - - -async def get_subscription(args) -> None: - """Retrieve PodMe subscription.""" - async with _get_client(args) as client: - subscriptions = await client.get_user_subscription() - for s in subscriptions: - console.print(s) - - -async def get_categories(args) -> None: - async with _get_client(args) as client: - categories = await client.get_categories() - for c in categories: - console.print(c) - - -async def get_favourites(args) -> None: - """Retrieve user favourite podcasts.""" - async with _get_client(args) as client: - podcasts = await client.get_user_podcasts() - for p in podcasts: - console.print(p) - - -async def get_popular(args) -> None: - """Retrieve favourite podcasts.""" - async with _get_client(args) as client: - podcasts = await client.get_popular_podcasts( - page_size=args.per_page, - pages=args.pages, - category=args.category, - podcast_type=args.type, - ) - for p in podcasts: - console.print(p) - - -def _add_default_arguments(parser: argparse.ArgumentParser): - """Add the default arguments username, password, region to the parser.""" - parser.add_argument( - "--credentials", - "-c", - type=argparse.FileType(), - default=None, - help="Path to credentials to load. Defaults to ~/.config/podme_api/credentials.json", - ) - parser.add_argument("--username", "-u", help="PodMe.com username") - parser.add_argument("--password", "-p", help="PodMe.com password") - parser.add_argument("-v", "--verbose", action="count", default=0, help="Logging verbosity level") - parser.add_argument("--debug", action="store_true", help="Enable debug mode") - - -def _add_paging_arguments(parser: argparse.ArgumentParser): - """Add paging options to the parser.""" - parser.add_argument( - "pages", - type=int, - nargs="?", - default=1, - help="(optional) Maximum number of pages to fetch (default=1)", - ) - parser.add_argument( - "per_page", type=int, nargs="?", default=25, help="(optional) Number of results per page (default=25)" - ) - - -@contextlib.asynccontextmanager -async def _get_client(args) -> PodMeClient: - """Return PodMeClient based on args.""" - if args.username and args.password: - user_creds = PodMeUserCredentials(args.username, args.password) - else: - user_creds = None - auth_client = PodMeDefaultAuthClient(user_credentials=user_creds) - client = PodMeClient(auth_client=auth_client) - try: - await client.__aenter__() - yield client - finally: - await client.__aexit__(None, None, None) - - -def main(): - """Run.""" - parser = main_parser() - args = parser.parse_args() - - if args.debug: - logging_level = logging.DEBUG - elif args.verbose: - logging_level = 50 - (args.verbose * 10) - if logging_level <= 0: - logging_level = logging.NOTSET - else: - logging_level = logging.ERROR - - logging.basicConfig( - level=logging_level, - format="%(message)s", - datefmt="[%X]", - handlers=[RichHandler(console=console)], - ) - - asyncio.run(args.func(args)) - - -if __name__ == "__main__": - main() diff --git a/podme_api/cli/__init__.py b/podme_api/cli/__init__.py new file mode 100644 index 0000000..a02f7b9 --- /dev/null +++ b/podme_api/cli/__init__.py @@ -0,0 +1,5 @@ +from .cli import main + +__all__ = [ + "main", +] diff --git a/podme_api/cli/cli.py b/podme_api/cli/cli.py new file mode 100644 index 0000000..62bf5f8 --- /dev/null +++ b/podme_api/cli/cli.py @@ -0,0 +1,442 @@ +"""podme_api cli tool.""" + +import argparse +import asyncio +import contextlib +import logging + +from rich.console import Console +from rich.live import Live +from rich.logging import RichHandler +from rich.panel import Panel +from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TextColumn, + TimeElapsedColumn, +) +from rich.table import Table + +from podme_api.__version__ import __version__ +from podme_api.auth import PodMeDefaultAuthClient +from podme_api.auth.models import PodMeUserCredentials +from podme_api.cli.utils import bold_star, is_valid_writable_dir, pretty_dataclass, pretty_dataclass_list +from podme_api.client import PodMeClient +from podme_api.models import PodMeDownloadProgressTask + +console = Console() + + +def main_parser() -> argparse.ArgumentParser: + """Create the ArgumentParser with all relevant subparsers.""" + parser = argparse.ArgumentParser(description="A simple executable to use and test the library.") + _add_default_arguments(parser) + + subparsers = parser.add_subparsers(dest="cmd") + subparsers.required = True + + # + # Login + # + login_parser = subparsers.add_parser("login", description="Log in") + login_parser.add_argument("username", type=str, help="Username / E-mail") + login_parser.add_argument("password", type=str, help="Password") + login_parser.set_defaults(func=login) + + # + # Subscription + # + subscription_parser = subparsers.add_parser( + "subscription", description="Get your current PodMe subscription." + ) + subscription_parser.set_defaults(func=get_subscription) + + # + # Favourites + # + favourites_parser = subparsers.add_parser( + "favourites", description="Get a list of your favourite podcasts." + ) + favourites_parser.set_defaults(func=get_favourites) + + # + # Podcasts + # + podcast_parser = subparsers.add_parser("podcast", description="Get podcast(s).") + podcast_parser.add_argument("podcast_slug", type=str, nargs="+", help="Podcast slug(s).") + podcast_parser.add_argument("--episodes", action="store_true", help="Get episodes.") + _add_paging_arguments(podcast_parser) + podcast_parser.set_defaults(func=get_podcasts) + + # + # Episodes + # + episode_parser = subparsers.add_parser("episode", description="Get episode(s).") + episode_parser.add_argument("episode_id", type=int, nargs="+", help="Episode id(s).") + episode_parser.add_argument("-d", "--download", action="store_true", help="Download episode(s).") + episode_parser.add_argument( + "-o", + dest="output_dir", + help="Directory to download episode(s) to.", + metavar="OUTPUT_DIR", + type=lambda x: is_valid_writable_dir(parser, x), + ) + episode_parser.set_defaults(func=get_episodes) + + # + # Categories + # + categories_parser = subparsers.add_parser("categories", description="Get a list of PodMe categories.") + categories_parser.set_defaults(func=get_categories) + + # + # Popular + # + popular_parser = subparsers.add_parser( + "popular", description="Get a list of PodMe's most popular podcasts." + ) + _add_paging_arguments(popular_parser) + popular_parser.add_argument("--category", help="(optional) Limit by category", nargs="?", default=None) + popular_parser.add_argument( + "--type", help="(optional, default=2) Limit by podcast type", nargs="?", default=None + ) + popular_parser.set_defaults(func=get_popular) + + # + # Search + # + search_parser = subparsers.add_parser("search", description="Search.") + search_parser.add_argument("query", type=str, help="Search query.") + search_parser.set_defaults(func=search) + _add_paging_arguments(search_parser) + + return parser + + +def _add_default_arguments(parser: argparse.ArgumentParser): + """Add default arguments to the parser.""" + parser.add_argument("-V", "--version", action="version", version=f"%(prog)s v{__version__}") + parser.add_argument("-v", "--verbose", action="count", default=0, help="Logging verbosity level") + parser.add_argument("--debug", action="store_true", help="Enable debug mode") + + +def _add_paging_arguments(parser: argparse.ArgumentParser): + """Add paging options to the parser.""" + parser.add_argument( + "--limit", + type=int, + default=20, + help="Number of results to return per page (default=20).", + ) + parser.add_argument("--pages", type=int, default=1, help="Maxium number of pages to fetch (default=1)") + + +async def login(args): + """Login.""" + async with _get_client(args) as client: + client.auth_client.invalidate_credentials() + username = await client.get_username() + console.print(f"Logged in as {username}") + + +async def get_subscription(args) -> None: + """Retrieve PodMe subscription.""" + async with _get_client(args) as client: + subscriptions = await client.get_user_subscription() + for s in subscriptions: + console.print( + *[ + pretty_dataclass( + s, + visible_fields=[ + "expiration_date", + "start_date", + "will_be_renewed", + ], + ), + pretty_dataclass( + s.subscription_plan, + visible_fields=[ + "name", + "price_decimal", + "currency", + "plan_guid", + ], + ), + ], + ) + + +async def get_favourites(args) -> None: + """Retrieve user favourite podcasts.""" + async with _get_client(args) as client: + podcasts = await client.get_user_podcasts() + console.print( + pretty_dataclass_list( + podcasts, + visible_fields=[ + "id", + "slug", + "title", + "categories", + ], + field_formatters={ + "title": lambda t, obj: f"{bold_star(obj.is_premium)}{t}", + "categories": lambda v, _: ", ".join([c.name for c in v]), + }, + field_order=[ + "title", + "slug", + "id", + "categories", + ], + ) + ) + + +async def get_podcasts(args) -> None: + async with _get_client(args) as client: + console.print(f"{args.podcast_slug}") + podcasts = await client.get_podcasts_info(args.podcast_slug) + + console.print( + pretty_dataclass_list( + podcasts, + visible_fields=[ + "id", + "slug", + "title", + "categories", + ], + field_formatters={ + "title": lambda t, obj: f"{bold_star(obj.is_premium)}{t}", + "categories": lambda v, _: ", ".join([c.name for c in v]), + }, + field_order=[ + "title", + "slug", + "id", + "categories", + ], + ) + ) + if args.episodes: + for podcast in podcasts: + episodes = await client.get_latest_episodes(podcast.slug, episodes_limit=args.limit) + console.print( + pretty_dataclass_list( + episodes, + title=f"Latest episodes of {podcast.title}", + visible_fields=[ + "id", + "title", + "date_added", + "length", + ], + field_formatters={ + "title": lambda t, obj: f"{bold_star(obj.is_premium)}{t}", + }, + field_order=[ + "id", + "title", + "length", + "date_added", + ], + ) + ) + + +async def get_episodes(args) -> None: + async with _get_client(args) as client: + episodes = await client.get_episodes_info(args.episode_id) + for episode in episodes: + console.print( + pretty_dataclass( + episode, + title=f"{episode.podcast_title} - {episode.title}", + hidden_fields=[ + "current_spot", + "current_spot_sec", + "has_completed", + ], + ) + ) + if args.download: + if not args.output_dir: + console.print("[red]Please specify an output directory[/red]") + return + output_path = args.output_dir + console.print(f"Downloading to: {output_path} ...") + ids = [e.id for e in episodes] + + job_progress = Progress( + "{task.description}", + SpinnerColumn(), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TextColumn("{task.fields[step]}"), + ) + overall_progress = Progress( + TimeElapsedColumn(), + TextColumn("{task.description}"), + ) + overall_task_id = overall_progress.add_task("", total=len(ids)) + + progress_table = Table.grid() + progress_table.add_row(overall_progress) + progress_table.add_row( + Panel.fit(job_progress, title="[b]Episodes", border_style="red", padding=(1, 2)), + ) + + with Live(progress_table, console=console, refresh_per_second=10): + overall_progress.update(overall_task_id, description="Preparing download urls") + download_infos = [] + download_tasks = {} + downloads = await client.get_episode_download_url_bulk(ids) + for episode_id, download_url in downloads: + path = output_path / f"episode_{episode_id}.mp3" + download_infos.append((download_url, path)) + download_tasks[str(download_url)] = job_progress.add_task( + f"Episode {episode_id}", + url=download_url, + save_path=path, + episode_id=episode_id, + step="Starting", + ) + + def on_progress(task: PodMeDownloadProgressTask, url: str, current: int, total: int): + progress_task_id = download_tasks[url] + percentage = float((current / total) * 100) if total else 0 + task_friendly_name = { + PodMeDownloadProgressTask.INITIALIZE: "Starting", + PodMeDownloadProgressTask.DOWNLOAD_FILE: "Downloading", + PodMeDownloadProgressTask.TRANSCODE_FILE: "Transcoding", + } + job_progress.update(progress_task_id, step=task_friendly_name[task], completed=percentage) + + def on_finished(url: str, saved_filename: str): + progress_task_id = download_tasks[url] + job_progress.update(progress_task_id, step=f"Finished: {saved_filename}", completed=100) + overall_progress.update(overall_task_id, advance=1) + + overall_progress.update(overall_task_id, description="Downloading/processing files") + + await client.download_files(download_infos, on_progress, on_finished) + + overall_progress.update(overall_task_id, description="[bold green]Completed") + await asyncio.sleep(1) + + +async def get_categories(args) -> None: + async with _get_client(args) as client: + categories = await client.get_categories() + console.print( + pretty_dataclass_list( + categories, + visible_fields=[ + "id", + "key", + "name", + ], + field_order=["id", "key", "name"], + ), + ) + + +async def get_popular(args) -> None: + """Retrieve favourite podcasts.""" + async with _get_client(args) as client: + podcasts = await client.get_popular_podcasts( + page_size=args.limit, + pages=args.pages, + category=args.category, + podcast_type=args.type, + ) + console.print( + pretty_dataclass_list( + podcasts, + field_formatters={ + "title": lambda t, obj: f"{bold_star(obj.is_premium)}{t}", + }, + hidden_fields=[ + "is_premium", + "image_url", + ], + ) + ) + + +async def search(args) -> None: + async with _get_client(args) as client: + results = await client.search_podcast( + args.query, + page_size=args.limit, + pages=args.pages, + ) + console.print( + pretty_dataclass_list( + results, + visible_fields=[ + "podcast_id", + "slug", + "podcast_title", + "author_full_name", + "date_added", + ], + field_formatters={ + "podcast_title": lambda t, obj: f"{bold_star(obj.is_premium)}{t}", + }, + field_order=[ + "podcast_id", + "slug", + "podcast_title", + "author_full_name", + "date_added", + ], + ) + ) + + +@contextlib.asynccontextmanager +async def _get_client(args) -> PodMeClient: + """Return PodMeClient based on args.""" + if hasattr(args, "username") and hasattr(args, "password"): + user_creds = PodMeUserCredentials(args.username, args.password) + else: + user_creds = None + auth_client = PodMeDefaultAuthClient(user_credentials=user_creds) + client = PodMeClient(auth_client=auth_client) + try: + await client.__aenter__() + yield client + finally: + await client.__aexit__(None, None, None) + + +def main(): + """Run.""" + parser = main_parser() + args = parser.parse_args() + + if args.debug: + logging_level = logging.DEBUG + elif args.verbose: + logging_level = 50 - (args.verbose * 10) + if logging_level <= 0: + logging_level = logging.NOTSET + else: + logging_level = logging.ERROR + + logging.basicConfig( + level=logging_level, + format="%(message)s", + datefmt="[%X]", + handlers=[RichHandler(console=console)], + ) + + asyncio.run(args.func(args)) + + +if __name__ == "__main__": + main() diff --git a/podme_api/cli/utils.py b/podme_api/cli/utils.py new file mode 100644 index 0000000..5413f70 --- /dev/null +++ b/podme_api/cli/utils.py @@ -0,0 +1,186 @@ +"""podme-api cli tool.""" + +from __future__ import annotations + +from dataclasses import fields +import os +from pathlib import Path +from typing import Callable, TypeVar + +from rich.box import SIMPLE +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +from podme_api.models import BaseDataClassORJSONMixin + +T = TypeVar("T", bound=BaseDataClassORJSONMixin) + + +def pretty_dataclass( # noqa: C901 + dataclass_obj: T, + field_formatters: dict[str, Callable[[any, T], any]] | None = None, + hidden_fields: list[str] | None = None, + visible_fields: list[str] | None = None, + title: str | None = None, + hide_none: bool = True, + hide_default: bool = True, +) -> Table: + """Render a dataclass object in a pretty format using rich.""" + + field_formatters = field_formatters or {} + hidden_fields = hidden_fields or [] + visible_fields = visible_fields or [] + + table = Table(title=title, show_header=False, title_justify="left") + table.add_column("Field", style="cyan", no_wrap=True) + table.add_column("Value", style="magenta") + + if visible_fields: + # Render fields in the order specified by visible_fields + for field_name in visible_fields: + if hidden_fields and field_name in hidden_fields: + continue + + field = next((f for f in fields(dataclass_obj) if f.name == field_name), None) + if not field: + continue + + field_value = getattr(dataclass_obj, field_name) + + if hide_none and field_value is None: + continue + + if hide_none and isinstance(field_value, list) and len(field_value) == 0: + continue + + if hide_default and field_value == field.default: + continue + + if field_name in field_formatters: + field_value = field_formatters[field_name](field_value, dataclass_obj) + table.add_row(field_name, str(field_value)) + else: + # Render all fields (except hidden ones) in the default order + for field in fields(dataclass_obj): + if hidden_fields and field.name in hidden_fields: + continue + + field_value = getattr(dataclass_obj, field.name) + + if hide_none and field_value is None: + continue + + if hide_none and isinstance(field_value, list) and len(field_value) == 0: + continue + + if hide_default and field_value == field.default: + continue + + if field.name in field_formatters: + field_value = field_formatters[field.name](field_value, dataclass_obj) + table.add_row(field.name, str(field_value)) + + return table + + +def pretty_dataclass_list( # noqa: C901 + dataclass_objs: list[T], + field_formatters: dict[str, Callable[[any, T], any]] | None = None, + hidden_fields: list[str] | None = None, + visible_fields: list[str] | None = None, + field_widths: dict[str, int] | None = None, + field_order: list[str] | None = None, + title: str | None = None, + hide_none: bool = True, + hide_default: bool = True, +) -> Table | Text: + """Render a list of dataclass objects in a table format using rich.""" + + field_formatters = field_formatters or {} + hidden_fields = hidden_fields or [] + visible_fields = visible_fields or [] + field_widths = field_widths or {} + field_order = field_order or [] + + if not dataclass_objs: + if title is not None: + return Text(f"{title}: No results") + return Text("No results") + + dataclass_fields = list(fields(dataclass_objs[0])) + ordered_fields = [f for f in field_order if f in [field.name for field in dataclass_fields]] + remaining_fields = [f.name for f in dataclass_fields if f.name not in ordered_fields] + fields_to_render = ordered_fields + remaining_fields + + table = Table(title=title, expand=True) + + for field_name in fields_to_render: + if hidden_fields and field_name in hidden_fields: + continue + + if visible_fields and field_name not in visible_fields: + continue + + table.add_column( + field_name, + style="cyan", + no_wrap=not field_widths.get(field_name, None), + width=field_widths.get(field_name, None), + ) + + for obj in dataclass_objs: + row = [] + for field_name in fields_to_render: + if hidden_fields and field_name in hidden_fields: + continue + + if visible_fields and field_name not in visible_fields: + continue + + field = next((f for f in fields(obj) if f.name == field_name), None) + if not field: + continue + + field_value = getattr(obj, field_name) + + if hide_none and field_value is None: + continue + + if hide_default and field_value == field.default: + continue + + if field_name in field_formatters: + field_value = field_formatters[field_name](field_value, obj) + row.append(str(field_value)) + table.add_row(*row) + + return table + + +def header_panel(title: str, subtitle: str): + grid = Table.grid(expand=True) + grid.add_column(justify="center", ratio=1) + grid.add_column(justify="right") + grid.add_row( + title, + subtitle, + ) + return Panel( + grid, + style="white on black", + box=SIMPLE, + ) + + +def bold_star(visible: bool = True, suffix=" ", prefix=""): + return f"{prefix}[bold]*[/bold]{suffix}" if visible else "" + + +def is_valid_writable_dir(parser, x): + """Check if directory exists and is writable.""" + if not Path(x).is_dir(): + parser.error(f"{x} is not a valid directory.") + if not os.access(x, os.W_OK | os.X_OK): + parser.error(f"{x} is not writable.") + return Path(x) diff --git a/podme_api/client.py b/podme_api/client.py index a190805..8875a44 100644 --- a/podme_api/client.py +++ b/podme_api/client.py @@ -41,6 +41,7 @@ from podme_api.models import ( PodMeCategory, PodMeCategoryPage, + PodMeDownloadProgressTask, PodMeEpisode, PodMeHomeScreen, PodMeLanguage, @@ -391,7 +392,7 @@ async def download_file( self, download_url: URL | str, path: PathLike | str, - on_progress: Callable[[str, int, int], None] | None = None, + on_progress: Callable[[PodMeDownloadProgressTask, str, int, int], None] | None = None, on_finished: Callable[[str, str], None] | None = None, transcode: bool = True, ) -> None: @@ -400,9 +401,9 @@ async def download_file( Args: download_url (URL | str): The URL of the file to download. path (PathLike | str): The local path where the file will be saved. - on_progress (Callable[[str, int, int], None], optional): + on_progress (Callable[[PodMeDownloadProgressTask, str, int, int], None], optional): A callback function to report download progress. It should accept - the download URL, current size, and total size as arguments. + the download URL/path, current and total as arguments (current==total means 100%). on_finished (Callable[[str, str], None], optional): A callback function to be called when the download is complete. It should accept the download URL and save path as arguments. @@ -414,20 +415,26 @@ async def download_file( """ download_url = URL(download_url) save_path = Path(path) + if on_progress is None: + on_progress = lambda task, url, current, total: None # noqa: E731, ARG005 + if on_finished is None: + on_finished = lambda url, _path: None # noqa: E731, ARG005 self._ensure_session() try: resp = await self.session.get(download_url, raise_for_status=True) total_size = int(resp.headers.get("Content-Length", 0)) + on_progress(PodMeDownloadProgressTask.DOWNLOAD_FILE, str(download_url), 0, total_size) current_size = 0 async with aiofiles.open(save_path, mode="wb") as f: _LOGGER.debug("Starting download of <%s>", download_url) async for chunk, _ in resp.content.iter_chunks(): await f.write(chunk) current_size += len(chunk) - if on_progress: - on_progress(str(download_url), current_size, total_size) + on_progress( + PodMeDownloadProgressTask.DOWNLOAD_FILE, str(download_url), current_size, total_size + ) except (ClientPayloadError, ClientResponseError) as err: msg = f"Error while downloading {download_url}" raise PodMeApiDownloadError(msg) from err @@ -435,18 +442,19 @@ async def download_file( _LOGGER.debug("Finished download of <%s> to <%s>", download_url, save_path) if transcode: + on_progress(PodMeDownloadProgressTask.TRANSCODE_FILE, str(download_url), 0, 100) new_save_path = await self.transcode_file(save_path) if new_save_path != save_path: _LOGGER.debug("Moving transcoded file %s to %s", new_save_path, save_path) await aiofiles.os.replace(new_save_path, save_path) + on_progress(PodMeDownloadProgressTask.TRANSCODE_FILE, str(download_url), 100, 100) - if on_finished: - on_finished(str(download_url), str(save_path)) + on_finished(str(download_url), str(save_path)) async def download_files( self, download_info: list[tuple[URL | str, PathLike]], - on_progress: Callable[[str, int, int], None] | None = None, + on_progress: Callable[[PodMeDownloadProgressTask, str, int, int], None] | None = None, on_finished: Callable[[str, str], None] | None = None, ): """Download multiple files concurrently. @@ -454,9 +462,9 @@ async def download_files( Args: download_info (list[tuple[URL | str, Path | str]]): A list of tuples containing the download URL and save path for each file. - on_progress (Callable[[str, int, int], None], optional): + on_progress (Callable[[PodMeDownloadProgressTask, str, int, int], None], optional): A callback function to report download progress. It should accept - the download URL, current size, and total size as arguments. + the download URL/path, current and total as arguments (current==total means 100%). on_finished (Callable[[str, str], None], optional): A callback function to be called when the download is complete. It should accept the download URL and save path as arguments. diff --git a/podme_api/models.py b/podme_api/models.py index 2cf6cd7..a456aab 100644 --- a/podme_api/models.py +++ b/podme_api/models.py @@ -65,6 +65,16 @@ def default_language(self): # pragma: no cover return PodMeLanguage[self.name] +class PodMeDownloadProgressTask(StrEnum): + """Enumeration of PodMe download progress tasks.""" + + INITIALIZE = auto() + RESOLVE_URL = auto() + DOWNLOAD_FILE = auto() + TRANSCODE_FILE = auto() + COMPLETE = auto() + + @dataclass class PodMeCategory(BaseDataClassORJSONMixin): """Represents a PodMe category.""" diff --git a/tests/test_client.py b/tests/test_client.py index b41c182..4a929c3 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -41,6 +41,7 @@ PodMeSearchResult, PodMeSubscription, PodMeSubscriptionPlan, + PodMeDownloadProgressTask, ) from .helpers import ( @@ -678,14 +679,15 @@ def get_file_ending(url: URL) -> str: # Check progress calls assert on_progress.call_count > 0 for args in on_progress.call_args_list: - url, current, total = args[0] + task, url, current, total = args[0] + assert isinstance(task, PodMeDownloadProgressTask) assert url in [str(u) for u, _ in download_infos] assert 0 <= current <= total # Check that the last progress call for each URL has current == total last_calls = on_progress.call_args_list[-2:] for call_args in last_calls: - _, current, total = call_args[0] + _, _, current, total = call_args[0] assert current == total # Check finished calls