diff --git a/esqa/cli.py b/esqa/cli.py
index f887e4c..9269cd4 100644
--- a/esqa/cli.py
+++ b/esqa/cli.py
@@ -5,6 +5,7 @@
 
 import click
 
+from esqa.distance import load_rankings, compare_rankings
 from esqa.save import RankingSaver
 from esqa.validation_config import load
 from esqa.validator import Validator
@@ -47,6 +48,20 @@ def check(config, index):
 def save(config, index):
     runner = RankingSaver()
     results = runner.run(config=load(config), index_name=index)
+    print(_dump(list(results.values())))
+
+
+@main.command()
+@click.option("-r", "--ranking", type=str, help="ranking file", required=True)
+@click.option("-c", "--config", type=str, help="configuration file")
+@click.option("-t", "--threshold", type=float, help="threshold", default=0.7)
+@click.option("--index", type=str, help="target index name", required=True)
+def ranking(ranking, config, threshold, index):
+    """Print the cases whose similarity to the ranking file does not exceed the threshold."""
+    runner = RankingSaver()
+    rankings = runner.run(config=load(config), index_name=index)
+    compared_rankings = load_rankings(ranking)
+    results = compare_rankings(rankings, compared_rankings, threshold)
     print(_dump(results))
 
 
diff --git a/esqa/distance.py b/esqa/distance.py
new file mode 100644
index 0000000..a158458
--- /dev/null
+++ b/esqa/distance.py
@@ -0,0 +1,75 @@
+"""Compare two sets of rankings with rank-biased overlap (RBO)."""
+import dataclasses
+import json
+import rbo
+from itertools import zip_longest
+from typing import Dict, List
+
+from esqa.save import Ranking
+
+
+@dataclasses.dataclass
+class FailedRanking:
+    """A case whose two rankings are not similar enough."""
+
+    name: str
+    similarity: float
+    # Position-wise (id in ranking_a, id in ranking_b); None pads the shorter side.
+    ranking_pair: List[tuple]
+
+
+def load_rankings(path: str) -> Dict[str, Ranking]:
+    """Load rankings from a JSON file and key them by case name.
+
+    The file must contain a list of objects with "name", "query" and
+    "ranking" keys.
+    """
+    with open(path) as f:
+        rankings = json.load(f)
+    return {
+        ranking["name"]: Ranking(ranking["name"], ranking["query"], ranking["ranking"])
+        for ranking in rankings
+    }
+
+
+def _extract(ranking: Ranking) -> List[str]:
+    """Return the "id" of every entry of the ranking, in stored order."""
+    return [e["id"] for e in ranking.ranking]
+
+
+def _compare(ranking_a: List[str], ranking_b: List[str]) -> float:
+    """Return the RBO similarity of two id lists."""
+    return rbo.rbo.RankingSimilarity(ranking_a, ranking_b).rbo()
+
+
+def _generate(ranking_a: Ranking, ranking_b: Ranking, similarity: float) -> FailedRanking:
+    """Build the report entry for one failed case."""
+    return FailedRanking(
+        name=ranking_a.name,
+        similarity=similarity,
+        # zip() would silently drop the tail of the longer ranking.
+        ranking_pair=list(zip_longest(_extract(ranking_a), _extract(ranking_b))),
+    )
+
+
+def compare_rankings(rankings_a: Dict[str, Ranking], rankings_b: Dict[str, Ranking], threshold: float) -> List[FailedRanking]:
+    """Return the cases whose similarity does not exceed the threshold.
+
+    Every case in rankings_a is compared with the same-named case in
+    rankings_b.
+
+    Raises:
+        KeyError: if a case in rankings_a is missing from rankings_b.
+    """
+    missing = [name for name in rankings_a if name not in rankings_b]
+    if missing:
+        # Name every missing case up front instead of failing on the first lookup.
+        raise KeyError("no ranking to compare against for case(s): " + ", ".join(missing))
+    results = []
+    for ranking_name, ranking_a in rankings_a.items():
+        ranking_b = rankings_b[ranking_name]
+        similarity = _compare(_extract(ranking_a), _extract(ranking_b))
+        if similarity > threshold:
+            continue
+        results.append(_generate(ranking_a, ranking_b, similarity))
+    return results
diff --git a/esqa/save.py b/esqa/save.py
index 277950f..eac05a9 100644
--- a/esqa/save.py
+++ b/esqa/save.py
@@ -1,5 +1,5 @@
 import dataclasses
-from typing import List
+from typing import List, Dict
 
 from elasticsearch import Elasticsearch
 
@@ -11,7 +11,6 @@
 class Ranking:
     name: str
     query: dict
-    asserts: List[EsAssert]
     ranking: List[dict]
 
 
@@ -21,13 +20,14 @@ class RankingSaver:
     def __init__(self):
         self.client = Elasticsearch([ELASTICSEARCH_URL])
 
-    def run(self, config: Configuration, index_name: str):
-        results = []
+    def run(self, config: Configuration, index_name: str) -> Dict[str, Ranking]:
+        results = {}
         for case in config.cases:
-            results.append(self._get(case, index_name))
+            ranking = self._get(case, index_name)
+            results[ranking.name] = ranking
         return results
 
-    def _get(self, case: Case, index_name: str):
+    def _get(self, case: Case, index_name: str) -> Ranking:
         search_results = self.client.search(body=case.query, index=index_name)
         return self._format(search_results, case)
 
@@ -35,6 +35,5 @@ def _format(self, search_results: dict, case: Case) -> Ranking:
         return Ranking(
             case.name,
             case.query,
-            case.asserts,
             [{"id": candidate["_id"], "source": candidate["_source"]} for i, candidate in enumerate(search_results["hits"]["hits"])]
         )
diff --git a/poetry.lock b/poetry.lock
index 25e6948..aebc3dc 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -113,6 +113,14 @@ category = "dev"
 optional = false
 python-versions = "*"
 
+[[package]]
+name = "numpy"
+version = "1.21.1"
+description = "NumPy is the fundamental package for array computing with Python."
+category = "main"
+optional = false
+python-versions = ">=3.7"
+
 [[package]]
 name = "pathspec"
 version = "0.9.0"
@@ -149,6 +157,17 @@ category = "dev"
 optional = false
 python-versions = "*"
 
+[[package]]
+name = "rbo"
+version = "0.1.2"
+description = "Simple library to calculate Rank-biased Overlap between two lists"
+category = "main"
+optional = false
+python-versions = ">=3.7,<4.0"
+
+[package.dependencies]
+numpy = ">=1.18,<2.0"
+
 [[package]]
 name = "toml"
 version = "0.10.2"
@@ -209,7 +228,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.7"
-content-hash = "3fb9de2890fd8c8c0863bd172c8000c1782a4511639353882707d92bd4f81cb2"
+content-hash = "913450220e138173313894eec328c620b035d5774fdcf5555ad964df60c40f97"
 
 [metadata.files]
 black = [
@@ -269,6 +288,36 @@ mypy-extensions = [
     {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"},
     {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"},
 ]
+numpy = [
+    {file = "numpy-1.21.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38e8648f9449a549a7dfe8d8755a5979b45b3538520d1e735637ef28e8c2dc50"},
+    {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:fd7d7409fa643a91d0a05c7554dd68aa9c9bb16e186f6ccfe40d6e003156e33a"},
+    {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a75b4498b1e93d8b700282dc8e655b8bd559c0904b3910b144646dbbbc03e062"},
+    {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1412aa0aec3e00bc23fbb8664d76552b4efde98fb71f60737c83efbac24112f1"},
+    {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e46ceaff65609b5399163de5893d8f2a82d3c77d5e56d976c8b5fb01faa6b671"},
+    {file = "numpy-1.21.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:c6a2324085dd52f96498419ba95b5777e40b6bcbc20088fddb9e8cbb58885e8e"},
+    {file = "numpy-1.21.1-cp37-cp37m-win32.whl", hash = "sha256:73101b2a1fef16602696d133db402a7e7586654682244344b8329cdcbbb82172"},
+    {file = "numpy-1.21.1-cp37-cp37m-win_amd64.whl", hash = "sha256:7a708a79c9a9d26904d1cca8d383bf869edf6f8e7650d85dbc77b041e8c5a0f8"},
+    {file = "numpy-1.21.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95b995d0c413f5d0428b3f880e8fe1660ff9396dcd1f9eedbc311f37b5652e16"},
+    {file = "numpy-1.21.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:635e6bd31c9fb3d475c8f44a089569070d10a9ef18ed13738b03049280281267"},
+    {file = "numpy-1.21.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4a3d5fb89bfe21be2ef47c0614b9c9c707b7362386c9a3ff1feae63e0267ccb6"},
+    {file = "numpy-1.21.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8a326af80e86d0e9ce92bcc1e65c8ff88297de4fa14ee936cb2293d414c9ec63"},
+    {file = "numpy-1.21.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:791492091744b0fe390a6ce85cc1bf5149968ac7d5f0477288f78c89b385d9af"},
+    {file = "numpy-1.21.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0318c465786c1f63ac05d7c4dbcecd4d2d7e13f0959b01b534ea1e92202235c5"},
+    {file = "numpy-1.21.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9a513bd9c1551894ee3d31369f9b07460ef223694098cf27d399513415855b68"},
+    {file = "numpy-1.21.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:91c6f5fc58df1e0a3cc0c3a717bb3308ff850abdaa6d2d802573ee2b11f674a8"},
+    {file = "numpy-1.21.1-cp38-cp38-win32.whl", hash = "sha256:978010b68e17150db8765355d1ccdd450f9fc916824e8c4e35ee620590e234cd"},
+    {file = "numpy-1.21.1-cp38-cp38-win_amd64.whl", hash = "sha256:9749a40a5b22333467f02fe11edc98f022133ee1bfa8ab99bda5e5437b831214"},
+    {file = "numpy-1.21.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d7a4aeac3b94af92a9373d6e77b37691b86411f9745190d2c351f410ab3a791f"},
+    {file = "numpy-1.21.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d9e7912a56108aba9b31df688a4c4f5cb0d9d3787386b87d504762b6754fbb1b"},
+    {file = "numpy-1.21.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:25b40b98ebdd272bc3020935427a4530b7d60dfbe1ab9381a39147834e985eac"},
+    {file = "numpy-1.21.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:8a92c5aea763d14ba9d6475803fc7904bda7decc2a0a68153f587ad82941fec1"},
+    {file = "numpy-1.21.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:05a0f648eb28bae4bcb204e6fd14603de2908de982e761a2fc78efe0f19e96e1"},
+    {file = "numpy-1.21.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f01f28075a92eede918b965e86e8f0ba7b7797a95aa8d35e1cc8821f5fc3ad6a"},
+    {file = "numpy-1.21.1-cp39-cp39-win32.whl", hash = "sha256:88c0b89ad1cc24a5efbb99ff9ab5db0f9a86e9cc50240177a571fbe9c2860ac2"},
+    {file = "numpy-1.21.1-cp39-cp39-win_amd64.whl", hash = "sha256:01721eefe70544d548425a07c80be8377096a54118070b8a62476866d5208e33"},
+    {file = "numpy-1.21.1-pp37-pypy37_pp73-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2d4d1de6e6fb3d28781c73fbde702ac97f03d79e4ffd6598b880b2d95d62ead4"},
+    {file = "numpy-1.21.1.zip", hash = "sha256:dff4af63638afcc57a3dfb9e4b26d434a7a602d225b42d746ea7fe2edf1342fd"},
+]
 pathspec = [
     {file = "pathspec-0.9.0-py2.py3-none-any.whl", hash = "sha256:7d15c4ddb0b5c802d161efc417ec1a2558ea2653c2e8ad9c19098201dc1c993a"},
     {file = "pathspec-0.9.0.tar.gz", hash = "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"},
@@ -285,6 +334,10 @@ pyflakes = [
     {file = "pyflakes-1.2.3-py2.py3-none-any.whl", hash = "sha256:e87bac26c62ea5b45067cc89e4a12f56e1483f1f2cda17e7c9b375b9fd2f40da"},
     {file = "pyflakes-1.2.3.tar.gz", hash = "sha256:2e4a1b636d8809d8f0a69f341acf15b2e401a3221ede11be439911d23ce2139e"},
 ]
+rbo = [
+    {file = "rbo-0.1.2-py3-none-any.whl", hash = "sha256:588f720e928930a01e4631c33623b80f7a892b30e1ac46d7a5cb5256e27a5dbe"},
+    {file = "rbo-0.1.2.tar.gz", hash = "sha256:fae72f4f59f441417c79a1628db8257e75f6a72e6d780544b46131cfd9ed1e1f"},
+]
 toml = [
     {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
     {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
diff --git a/pyproject.toml b/pyproject.toml
index 85ac0e8..8f09743 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,6 +9,7 @@ readme = "README.md"
 python = "^3.7"
 click = ">=8.0.0"
 elasticsearch = "7.10.1"
+rbo = "^0.1.2"
 
 [tool.poetry.dev-dependencies]
 toml = "^0.10.0"