Skip to content

Commit

Permalink
fix standalone print bug when cols too many
Browse files Browse the repository at this point in the history
  • Loading branch information
goldenxinxing committed Nov 18, 2022
1 parent 4c9e8ba commit 0c7ce8c
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 14 deletions.
1 change: 1 addition & 0 deletions client/starwhale/consts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class SWDSSubFileType:

DEFAULT_PAGE_IDX = 1
DEFAULT_PAGE_SIZE = 20
DEFAULT_REPORT_COLS = 20

RECOVER_DIRNAME = ".recover"
OBJECT_STORE_DIRNAME = ".objectstore"
Expand Down
13 changes: 11 additions & 2 deletions client/starwhale/core/eval/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from starwhale.consts.env import SWEnv

from .view import JobTermView, get_term_view, DEFAULT_PAGE_IDX, DEFAULT_PAGE_SIZE
from ...consts import DEFAULT_REPORT_COLS


@click.group(
Expand Down Expand Up @@ -175,9 +176,17 @@ def _cancel(job: str, force: bool) -> None:
@click.option(
"--size", type=int, default=DEFAULT_PAGE_SIZE, help="Page size for tasks list"
)
@click.option(
"--max-report-cols",
type=int,
default=DEFAULT_REPORT_COLS,
help="Max table column size for print",
)
@click.pass_obj
def _info(view: t.Type[JobTermView], job: str, page: int, size: int) -> None:
view(job).info(page, size)
def _info(
view: t.Type[JobTermView], job: str, page: int, size: int, max_report_cols: int
) -> None:
view(job).info(page, size, max_report_cols)


@eval_job_cmd.command("compare", aliases=["cmp"])
Expand Down
53 changes: 41 additions & 12 deletions client/starwhale/core/eval/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DEFAULT_PAGE_IDX,
DEFAULT_PAGE_SIZE,
SHORT_VERSION_CNT,
DEFAULT_REPORT_COLS,
DEFAULT_MANIFEST_NAME,
)
from starwhale.base.uri import URI
Expand Down Expand Up @@ -112,7 +113,12 @@ def compare(self, job_uris: t.List[str]) -> None:
console.print(table)

@BaseTermView._header
def info(self, page: int = DEFAULT_PAGE_IDX, size: int = DEFAULT_PAGE_SIZE) -> None:
def info(
self,
page: int = DEFAULT_PAGE_IDX,
size: int = DEFAULT_PAGE_SIZE,
max_report_cols: int = DEFAULT_REPORT_COLS,
) -> None:
_rt = self.job.info(page, size)
if not _rt:
console.print(":tea: not found info")
Expand Down Expand Up @@ -140,7 +146,9 @@ def info(self, page: int = DEFAULT_PAGE_IDX, size: int = DEFAULT_PAGE_SIZE) -> N
self._render_summary_report(_report["summary"], _kind)

if _kind == MetricKind.MultiClassification.value:
self._render_multi_classification_job_report(_rt["report"])
self._render_multi_classification_job_report(
_rt["report"], max_report_cols
)

def _render_summary_report(self, summary: t.Dict[str, t.Any], kind: str) -> None:
console.rule(f"[bold green]{kind.upper()} Summary")
Expand Down Expand Up @@ -175,7 +183,7 @@ def _print_tasks(self, tasks: t.List[t.Dict[str, t.Any]]) -> None:
console.print(table)

def _render_multi_classification_job_report(
self, report: t.Dict[str, t.Any]
self, report: t.Dict[str, t.Any], max_report_cols: int
) -> None:
if not report:
console.print(":turtle: no report")
Expand All @@ -191,14 +199,23 @@ def _print_labels() -> None:
table = Table(box=box.SIMPLE)
table.add_column("Label", style="cyan")
keys = labels[sort_label_names[0]]
for _k in keys:
for idx, _k in enumerate(keys):
if _k == "id":
continue
table.add_column(_k.capitalize())
if idx < max_report_cols:
table.add_column(_k.capitalize())
else:
table.add_column("...")
break

for _k, _v in labels.items():
table.add_row(
_k, *(f"{float(_v[_k2]):.4f}" for _k2 in keys if _k2 != "id")
_k,
*(
f"{float(_v[_k2]):.4f}"
for idx, _k2 in enumerate(keys)
if _k2 != "id" and idx < max_report_cols + 1
),
)

console.rule(f"[bold green]{report['kind'].upper()} Label Metrics Report")
Expand All @@ -209,16 +226,23 @@ def _print_confusion_matrix() -> None:
if not cm:
return

btable = Table(box=box.SIMPLE)
btable = Table(box=box.ROUNDED)
btable.add_column("", style="cyan")
for n in sort_label_names:
btable.add_column(n)
for idx, n in enumerate(sort_label_names):
if idx < max_report_cols:
btable.add_column(n)
else:
btable.add_column("...")
break
for idx, bl in enumerate(cm.get("binarylabel", [])):
btable.add_row(
sort_label_names[idx],
*[f"{float(bl[i]):.4f}" for i in bl if i != "id"],
*[
f"{float(bl[i]):.4f}"
for idx, i in enumerate(bl)
if i != "id" and idx < max_report_cols + 1
],
)

console.rule(f"[bold green]{report['kind'].upper()} Confusion Matrix")
console.print(btable)

Expand Down Expand Up @@ -385,7 +409,12 @@ def list( # type: ignore
_data, _ = super().list(project_uri, fullname, show_removed, page, size)
cls.pretty_json(_data)

def info(self, page: int = DEFAULT_PAGE_IDX, size: int = DEFAULT_PAGE_SIZE) -> None:
def info(
self,
page: int = DEFAULT_PAGE_IDX,
size: int = DEFAULT_PAGE_SIZE,
max_report_cols: int = DEFAULT_REPORT_COLS,
) -> None:
_rt = self.job.info(page, size)
if not _rt:
console.print(":tea: not found info")
Expand Down

0 comments on commit 0c7ce8c

Please sign in to comment.