Skip to content

Commit

Permalink
A B C D
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 570979575
  • Loading branch information
achoum authored and copybara-github committed Oct 5, 2023
1 parent a576c23 commit 373e434
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 55 deletions.
2 changes: 1 addition & 1 deletion yggdrasil_decision_forests/port/python/ydf/metric/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ py_test(
python_version = "PY3",
deps = [
":metric",
# absl/flags dep,
# absl/testing:absltest dep,
# numpy dep,
"@ydf_cc//yggdrasil_decision_forests/dataset:data_spec_py_proto",
"@ydf_cc//yggdrasil_decision_forests/metric:metric_py_proto",
"//ydf/utils:test_utils",
"@ydf_cc//yggdrasil_decision_forests/utils:distribution_py_proto",
],
)
65 changes: 34 additions & 31 deletions yggdrasil_decision_forests/port/python/ydf/metric/display_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Any, Optional, Tuple
from xml.dom import minidom

# TODO: Add matplotlib as a requirement, or fail.
import matplotlib.pyplot as plt

from ydf.metric import metric
Expand Down Expand Up @@ -488,42 +489,44 @@ def _object_to_html(
def _plot_roc(characteristic: metric.Characteristic):
"""Plots a ROC curve."""

fig, ax = plt.subplots(1, figsize=(4, 4))
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_box_aspect(1)
ax.plot([0, 1], [0, 1], linestyle="--", color="black", linewidth=0.5)
ax.plot(
characteristic.false_positive_rates,
characteristic.recalls,
color="red",
linewidth=0.5,
)
ax.set_xlabel("false positive rate")
ax.set_ylabel("true positive rate (recall)")
ax.grid()
fig.tight_layout()
return fig
with plt.ioff():
fig, ax = plt.subplots(1, figsize=(4, 4))
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_box_aspect(1)
ax.plot([0, 1], [0, 1], linestyle="--", color="black", linewidth=0.5)
ax.plot(
characteristic.false_positive_rates,
characteristic.recalls,
color="red",
linewidth=0.5,
)
ax.set_xlabel("false positive rate")
ax.set_ylabel("true positive rate (recall)")
ax.grid()
fig.tight_layout()
return fig


def _plot_pr(characteristic: metric.Characteristic):
"""Plots a precision-recall curve."""

fig, ax = plt.subplots(1, figsize=(4, 4))
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_box_aspect(1)
ax.plot(
characteristic.recalls,
characteristic.precisions,
color="red",
linewidth=0.5,
)
ax.set_xlabel("recall")
ax.set_ylabel("precision")
ax.grid()
fig.tight_layout()
return fig
with plt.ioff():
fig, ax = plt.subplots(1, figsize=(4, 4))
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_box_aspect(1)
ax.plot(
characteristic.recalls,
characteristic.precisions,
color="red",
linewidth=0.5,
)
ax.set_xlabel("recall")
ax.set_ylabel("precision")
ax.grid()
fig.tight_layout()
return fig


def _fig_to_dom(doc: html.Doc, fig) -> html.Elem:
Expand Down
33 changes: 10 additions & 23 deletions yggdrasil_decision_forests/port/python/ydf/metric/metric_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,20 @@

"""Testing Metrics."""

import logging
import os
import textwrap

from absl import flags
from absl.testing import absltest
import numpy as np
from numpy import testing as npt

from yggdrasil_decision_forests.dataset import data_spec_pb2 as ds_pb
from yggdrasil_decision_forests.metric import metric_pb2
from ydf.metric import metric
from ydf.utils import test_utils
from yggdrasil_decision_forests.utils import distribution_pb2


def data_root_path() -> str:
return ""


def pydf_test_data_path() -> str:
return os.path.join(
data_root_path(),
"ydf/test_data",
)


class EvaluationTest(absltest.TestCase):

def test_no_metrics(self):
Expand Down Expand Up @@ -120,17 +108,16 @@ def test_all_metrics(self):
"""),
)

golden_path = os.path.join(
pydf_test_data_path(), "golden", "display_metric_to_html.html.expected"
test_utils.golden_check_string(
self,
e._repr_html_(),
os.path.join(
test_utils.pydf_test_data_path(),
"golden",
"display_metric_to_html.html.expected",
),
postfix=".html",
)
golden_data = open(golden_path).read()
effective_data = e._repr_html_()
if golden_data != effective_data:
effective_path = "/tmp/golden_test_value.html"
with open(effective_path, "w") as f:
f.write(effective_data)
logging.info("Saving effective data to %s", effective_path)
self.assertEqual(e._repr_html_(), golden_data)


class ConfusionTest(absltest.TestCase):
Expand Down
10 changes: 10 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ py_library(
srcs = ["documentation.py"],
)

py_library(
name = "test_utils",
testonly = True,
srcs = ["test_utils.py"],
deps = [
# absl/flags dep,
# absl/testing:absltest dep,
],
)

# Tests
# =====

Expand Down
69 changes: 69 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for unit tests."""

import logging
import os
from absl import flags
from absl.testing import absltest


def data_root_path() -> str:
"""Root directory of the repo."""
return ""


def ydf_test_data_path() -> str:
return os.path.join(
data_root_path(),
"external/ydf_cc/yggdrasil_decision_forests/test_data",
)


def pydf_test_data_path() -> str:
return os.path.join(
data_root_path(),
"ydf/test_data",
)


def golden_check_string(
test, value: str, golden_path: str, postfix: str = ""
) -> None:
"""Ensures that "value" is equal to the content of the file "golden_path".
Args:
test: A test.
value: Value to test.
golden_path: Path to golden file expressed from the root of the repo.
postfix: Optional postfix to the path of the file containing the actual
value.
"""

golden_data = open(os.path.join(data_root_path(), golden_path)).read()

if value != golden_data:
value_path = os.path.join(
absltest.TEST_TMPDIR.value, os.path.basename(golden_path) + postfix
)
logging.info("os.path.dirname(value_path): %s", os.path.dirname(value_path))
os.makedirs(os.path.dirname(value_path), exist_ok=True)
logging.info(
"Golden test failed. Save the effetive value to %s", value_path
)
with open(value_path, "w") as f:
f.write(value)

test.assertEqual(value, golden_data)

0 comments on commit 373e434

Please sign in to comment.