Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add npe2 list command to discover/display all currently installed plugins #192

Merged
merged 13 commits into from
Jun 25, 2022
195 changes: 175 additions & 20 deletions npe2/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import warnings
from enum import Enum
from pathlib import Path
from textwrap import indent
from typing import List, Optional
from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence

import typer

from npe2 import PluginManifest
from npe2 import PluginManager, PluginManifest

if TYPE_CHECKING:
from rich.console import RenderableType

app = typer.Typer()

Expand All @@ -20,29 +22,61 @@ class Format(str, Enum):
toml = "toml"


class ListFormat(str, Enum):
"""Valid out formats for `npe2 list`."""

table = "table"
json = "json" # alias for json in pandas "records" format
yaml = "yaml"
compact = "compact"


def _pprint_formatted(string, format: Format = Format.yaml): # pragma: no cover
"""Print yaml nicely, depending on available modules."""
try:
from rich.console import Console
from rich.syntax import Syntax
from rich.console import Console
from rich.syntax import Syntax

Console().print(Syntax(string, format.value, theme="fruity"))
except ImportError:
typer.echo(string)
Console().print(Syntax(string, format.value, theme="fruity"))


def _pprint_exception(err: Exception):
from rich.console import Console
from rich.traceback import Traceback

e_info = (type(err), err, err.__traceback__)
try:
from rich.console import Console
from rich.traceback import Traceback

trace = Traceback.extract(*e_info, show_locals=True)
Console().print(Traceback(trace))
except ImportError:
import traceback
trace = Traceback.extract(*e_info, show_locals=True)
Console().print(Traceback(trace))

typer.echo("\n" + "".join(traceback.format_exception(*e_info)))

def _pprint_table(
headers: Sequence["RenderableType"], rows: Sequence[Sequence["RenderableType"]]
):
from itertools import cycle

from rich.console import Console
from rich.table import Table

COLORS = ["cyan", "magenta", "green", "yellow"]
EMOJI_TRUE = ":white_check_mark:"
EMOJI_FALSE = ""

table = Table()
for head, color in zip(headers, cycle(COLORS)):
table.add_column(head, style=color)
for row in rows:
strings = []
for r in row:
val = ""
if isinstance(r, dict):
val = ", ".join(f"{k} ({v})" for k, v in r.items())
elif r:
val = str(r).replace("True", EMOJI_TRUE).replace("False", EMOJI_FALSE)
strings.append(val)
table.add_row(*strings)

console = Console()
console.print(table)


@app.command()
Expand All @@ -60,7 +94,6 @@ def validate(
),
):
"""Validate manifest for a distribution name or manifest filepath."""

err: Optional[Exception] = None
try:
pm = PluginManifest._from_package_or_name(name)
Expand Down Expand Up @@ -126,6 +159,127 @@ def parse(
_pprint_formatted(manifest_string, fmt)


def _make_rows(pm_dict: dict, normed_fields: Sequence[str]) -> Iterator[List]:
"""Cleanup output from pm.dict() into rows for table.

outside of just extracting the fields we care about, this also:

- handles nested fields expressed as dotted strings: `packge_metadata.version`
- negates fields that are prefixed with `!`
- simplifies contributions to a {name: count} dict.

"""
for info in pm_dict["plugins"].values():
row = []
for field in normed_fields:
val = info.get(field.lstrip("!"))

# extact nested fields
if not val and "." in field:
parts = field.split(".")
val = info
while parts:
val = val[parts.pop(0)]

# negate fields starting with !
if field.startswith("!"):
val = not val

# simplify contributions to just the number of contributions
if field == "contributions":
val = {k: len(v) for k, v in val.items() if v}

row.append(val)
yield row


@app.command()
def list(
fields: str = typer.Option(
"name,version,npe2,contributions",
help="Comma seperated list of fields to include in the output."
"Names may contain dots, indicating nested manifest fields "
"(`contributions.readers`). Fields names prefixed with `!` will be "
"negated in the output. Fields will appear in the table in the order in "
"which they are provided.",
metavar="FIELDS",
),
sort: str = typer.Option(
"0",
"-s",
"--sort",
help="Field name or (int) index on which to sort.",
metavar="KEY",
),
format: ListFormat = typer.Option(
"table",
"-f",
"--format",
help="Out format to use. When using 'compact', `--fields` is ignored ",
),
):
"""List currently installed plugins."""

if format == ListFormat.compact:
fields = "name,version,contributions"

requested_fields = [f.lower() for f in fields.split(",")]

# check for sort values that will not work
bad_sort_param_msg = (
f"Invalid sort value {sort!r}. "
f"Must be column index (<{len(requested_fields)}) or one of: "
+ ", ".join(requested_fields)
)
try:
if (sort_index := int(sort)) >= len(requested_fields):
raise typer.BadParameter(bad_sort_param_msg)
except ValueError:
try:
sort_index = requested_fields.index(sort.lower())
except ValueError as e:
raise typer.BadParameter(bad_sort_param_msg) from e

# some convenience aliases
ALIASES = {
"version": "package_metadata.version",
"summary": "package_metadata.summary",
"license": "package_metadata.license",
"author": "package_metadata.author",
"npe2": "!npe1_shim",
"npe1": "npe1_shim",
}
normed_fields = [ALIASES.get(f, f) for f in requested_fields]

pm = PluginManager.instance()
pm.discover(include_npe1=True)
pm_dict = pm.dict(include={f.lstrip("!") for f in normed_fields})
rows = sorted(_make_rows(pm_dict, normed_fields), key=lambda r: r[sort_index])

if format == ListFormat.table:
heads = [f.split(".")[-1].replace("_", " ").title() for f in requested_fields]
_pprint_table(headers=heads, rows=rows)
return

# standard records format used for the other formats
# [{column -> value}, ... , {column -> value}]
data: List[dict] = [dict(zip(requested_fields, row)) for row in rows]

if format == ListFormat.json:
import json

_pprint_formatted(json.dumps(data, indent=1), Format.json)
elif format in (ListFormat.yaml):
import yaml

_pprint_formatted(yaml.safe_dump(data, sort_keys=False), Format.yaml)
elif format in (ListFormat.compact):
template = " - {name}: {version} ({ncontrib} contributions)"
for r in data:
ncontrib = sum(r.get("contributions", {}).values())
typer.echo(template.format(**r, ncontrib=ncontrib))


@app.command()
def fetch(
name: str,
Expand Down Expand Up @@ -202,6 +356,8 @@ def convert(
else:
pm = manifest_from_npe1(str(path))
if w:
from textwrap import indent

typer.secho("Some issues occured:", fg=typer.colors.RED, bold=False)
for r in w:
typer.secho(
Expand Down Expand Up @@ -252,8 +408,7 @@ def cache(
from npe2.manifest._npe1_adapter import ADAPTER_CACHE, clear_cache

if clear:
_cleared = clear_cache(names)
if _cleared:
if _cleared := clear_cache(names):
nf = "\n".join(f" - {i.name}" for i in _cleared)
typer.secho("Cleared these files from cache:")
typer.secho(nf, fg=typer.colors.RED)
Expand Down
3 changes: 1 addition & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ install_requires =
psygnal>=0.3.0
pydantic
pytomlpp
rich
typer
python_requires = >=3.8
include_package_data = True
Expand Down Expand Up @@ -58,7 +59,6 @@ dev =
pydocstyle
pytest
pytest-cov
rich
typer
docs =
Jinja2
Expand All @@ -69,7 +69,6 @@ testing =
numpy
pytest
pytest-cov
rich

[bdist_wheel]
universal = 1
Expand Down
25 changes: 25 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,28 @@ def test_cli_cache_clear_named(mock_cache):
result = runner.invoke(app, ["cache", "--clear", "not-a-plugin"])
assert result.stdout == "Nothing to clear for plugins: not-a-plugin\n"
assert result.exit_code == 0


@pytest.mark.parametrize("format", ["table", "compact", "yaml", "json"])
@pytest.mark.parametrize("fields", [None, "name,version,author"])
def test_cli_list(format, fields, uses_npe1_plugin):
result = runner.invoke(app, ["list", "-f", format, "--fields", fields])
assert result.exit_code == 0
assert "npe1-plugin" in result.output
if fields and "author" in fields and format != "compact":
assert "author" in result.output.lower()
else:
assert "author" not in result.output.lower()


def test_cli_list_sort(uses_npe1_plugin):
result = runner.invoke(app, ["list", "--sort", "version"])
assert result.exit_code == 0

result = runner.invoke(app, ["list", "--sort", "7"])
assert result.exit_code
assert "Invalid sort value '7'" in result.output

result = runner.invoke(app, ["list", "--sort", "notaname"])
assert result.exit_code
assert "Invalid sort value 'notaname'" in result.output