Update benchmarks to use argparse #586
davidmezzetti committed Oct 30, 2023
1 parent 7774fe7 commit a688966
Showing 1 changed file with 73 additions and 51 deletions.
examples/benchmarks.py (124 changes: 73 additions, 51 deletions)
@@ -5,12 +5,12 @@
pip install txtai pytrec_eval rank-bm25 elasticsearch psutil
"""

+import argparse
import csv
import json
import os
import pickle
import sqlite3
-import sys
import time

import psutil
@@ -34,16 +34,18 @@ class Index:
    Base index definition. Defines methods to index and search a dataset.
    """

-    def __init__(self, path, refresh=True):
+    def __init__(self, path, config, refresh):
        """
        Creates a new index.

        Args:
            path: path to dataset
+            config: path to config file
            refresh: overwrites existing index if True, otherwise existing index is loaded
        """

        self.path = path
+        self.config = config
        self.refresh = refresh

        # Build and save index
@@ -139,10 +141,10 @@ def batch(self, data, size):

        return [data[x : x + size] for x in range(0, len(data), size)]

-    def config(self, key, default):
+    def readconfig(self, key, default):
        """
-        Reads configuration from a config.yml file. Returns default configuration
-        if config.yml file is not found or config key isn't present.
+        Reads configuration from a config file. Returns default configuration
+        if config file is not found or config key isn't present.

        Args:
            key: configuration key to lookup
@@ -152,9 +154,9 @@ def config(self, key, default):
            config if found, otherwise returns default config
        """

-        if os.path.exists("config.yml"):
+        if self.config and os.path.exists(self.config):
            # Read configuration
-            with open("config.yml", "r", encoding="utf-8") as f:
+            with open(self.config, "r", encoding="utf-8") as f:
                # Check for config
                config = yaml.safe_load(f)
                if key in config:
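Note: readconfig looks up a single top-level key in the YAML file passed via --config. A minimal sketch of such a config file, assembled from the default values visible in this diff (the file itself is hypothetical; missing keys fall back to the in-code defaults):

scoring:
  method: bm25
  terms: true

embed:
  batch: 8192
  encodebatch: 128
  faiss:
    quantize: true
    sample: 0.05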
@@ -170,7 +172,7 @@ class Score(Index):

    def index(self):
        # Read configuration
-        config = self.config("scoring", {"method": "bm25", "terms": True})
+        config = self.readconfig("scoring", {"method": "bm25", "terms": True})

        # Create scoring instance
        scoring = ScoringFactory.create(config)
@@ -197,7 +199,7 @@ def index(self):
            embeddings.load(path)
        else:
            # Read configuration
-            config = self.config("embed", {"batch": 8192, "encodebatch": 128, "faiss": {"quantize": True, "sample": 0.05}})
+            config = self.readconfig("embed", {"batch": 8192, "encodebatch": 128, "faiss": {"quantize": True, "sample": 0.05}})

            # Build index
            embeddings = Embeddings(config)
@@ -219,7 +221,7 @@ def index(self):
            embeddings.load(path)
        else:
            # Read configuration
-            config = self.config(
+            config = self.readconfig(
                "hybrid",
                {
                    "batch": 8192,
@@ -362,33 +364,6 @@ def index(self):

        return es


-def create(name, path, refresh):
-    """
-    Creates a new index.
-
-    Args:
-        path: dataset path
-        refresh: overwrites existing index if True, otherwise existing index is loaded
-
-    Returns:
-        Index
-    """
-
-    if name == "embed":
-        return Embed(path, refresh)
-    if name == "es":
-        return Elastic(path, refresh)
-    if name == "hybrid":
-        return Hybrid(path, refresh)
-    if name == "sqlite":
-        return SQLiteFTS(path, refresh)
-    if name == "rank":
-        return RankBM25(path, refresh)
-
-    # Default
-    return Score(path, refresh)
-
-
def relevance(path):
    """
    Loads relevance data for evaluation.
@@ -415,6 +390,35 @@ def relevance(path):

    return rel


+def create(method, path, config, refresh):
+    """
+    Creates a new index.
+
+    Args:
+        method: indexing method
+        path: dataset path
+        config: path to config file
+        refresh: overwrites existing index if True, otherwise existing index is loaded
+
+    Returns:
+        Index
+    """
+
+    if method == "es":
+        return Elastic(path, config, refresh)
+    if method == "hybrid":
+        return Hybrid(path, config, refresh)
+    if method == "scoring":
+        return Score(path, config, refresh)
+    if method == "sqlite":
+        return SQLiteFTS(path, config, refresh)
+    if method == "rank":
+        return RankBM25(path, config, refresh)
+
+    # Default
+    return Embed(path, config, refresh)
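Usage sketch (not part of the commit; dataset and config paths are illustrative): the factory keys off the method name and threads the config path through to the selected index class. The returned object is callable for search, matching the index(topk) call in evaluate below.

# Hypothetical usage: build an embeddings index over a local BEIR dataset,
# then run top-10 searches against it
index = create("embed", "beir/trec-covid", "config.yml", True)
results = index(10)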


def compute(results):
    """
    Computes metrics using the results from an evaluation run.
@@ -437,13 +441,14 @@ def compute(results):

    return {metric: round(np.mean(values), 5) for metric, values in metrics.items()}


-def evaluate(path, methods):
+def evaluate(methods, path, args):
    """
    Runs an evaluation.

    Args:
-        path: path to dataset
        methods: list of indexing methods to test
+        path: path to dataset
+        args: command line arguments

    Returns:
        {calculated performance metrics}
@@ -455,7 +460,7 @@ def evaluate(path, methods):

    performance = {}

    # Calculate stats for each model type
-    topk, refresh = 10, True
+    topk = args.topk
    evaluator = RelevanceEvaluator(relevance(path), {f"ndcg_cut.{topk}", f"map_cut.{topk}", f"recall.{topk}", f"P.{topk}"})
    for method in methods:
        # Stats for this source
@@ -464,21 +469,22 @@

        # Create index and get results
        start = time.time()
-        index = create(method, path, refresh)
+        index = create(method, path, args.config, args.refresh)

        # Add indexing metrics
        stats["index"] = round(time.time() - start, 2)
+        stats["memory"] = int(psutil.Process().memory_info().rss / (1024 * 1024))
        stats["disk"] = int(sum(d.stat().st_size for d in os.scandir(f"{path}/{method}") if d.is_file()) / 1024)

        print("INDEX TIME =", time.time() - start)
-        print(f"MEMORY USAGE = {psutil.Process().memory_info().rss / (1024 * 1024)} MB")
+        print(f"MEMORY USAGE = {stats['memory']} MB")
        print(f"DISK USAGE = {stats['disk']} KB")

        start = time.time()
        results = index(topk)

        # Add search metrics
        stats["search"] = round(time.time() - start, 2)

        print("SEARCH TIME =", time.time() - start)

        # Calculate stats
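The "# Calculate stats" step feeds pytrec_eval's per-query results into compute above. A sketch of that function's contract, assuming the usual pytrec_eval result shape (all values hypothetical):

# results maps query id -> {metric: value}, as returned by RelevanceEvaluator.evaluate
results = {"q1": {"ndcg_cut_10": 0.8, "P_10": 0.5}, "q2": {"ndcg_cut_10": 0.6, "P_10": 0.3}}
# compute(results) would average each metric: {"ndcg_cut_10": 0.7, "P_10": 0.4}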
@@ -499,17 +505,19 @@

    return performance


-def benchmarks():
+def benchmarks(args):
    """
    Main benchmark execution method.
+
+    Args:
+        args: command line arguments
    """

    # Directory where BEIR datasets are stored
-    directory = sys.argv[1] if len(sys.argv) > 1 else "beir"
+    directory = args.directory if args.directory else "beir"

-    if len(sys.argv) > 3:
-        sources = [sys.argv[2]]
-        methods = [sys.argv[3]]
+    if args.sources and args.methods:
+        sources, methods = args.sources.split(","), args.methods.split(",")
        mode = "a"
    else:
        # Default sources and methods
@@ -535,16 +543,30 @@

    with open("benchmarks.json", mode, encoding="utf-8") as f:
        for source in sources:
            # Run evaluations
-            results = evaluate(f"{directory}/{source}", methods)
+            results = evaluate(methods, f"{directory}/{source}", args)

            # Save as JSON lines output
            for method, stats in results.items():
                stats["source"] = source
                stats["method"] = method
+                stats["name"] = args.name if args.name else method

                json.dump(stats, f)
                f.write("\n")

-# Calculate benchmarks
-benchmarks()
+if __name__ == "__main__":
+    # Command line parser
+    parser = argparse.ArgumentParser(description="Benchmarks")
+    parser.add_argument("-c", "--config", help="path to config file", metavar="CONFIG")
+    parser.add_argument("-d", "--directory", help="root directory path with datasets", metavar="DIRECTORY")
+    parser.add_argument("-m", "--methods", help="comma separated list of methods", metavar="METHODS")
+    parser.add_argument("-n", "--name", help="name to assign to this run, defaults to method name", metavar="NAME")
+    parser.add_argument("-r", "--refresh", help="refreshes index if set, otherwise uses existing index if available", action="store_true")
+    parser.add_argument("-s", "--sources", help="comma separated list of data sources", metavar="SOURCES")
+    parser.add_argument("-t", "--topk", help="top k results to use for the evaluation", metavar="TOPK", type=int, default=10)
+
+    # Calculate benchmarks
+    benchmarks(parser.parse_args())
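With the parser in place, a run is driven entirely by flags rather than positional sys.argv values. A hypothetical invocation (dataset names are illustrative):

python benchmarks.py -d beir -s trec-covid,nfcorpus -m embed,scoring -c config.yml -r

Comma separated sources and methods are split into lists, and each source/method pair appends one JSON line to benchmarks.json, roughly of this shape (field names from the code above; values illustrative):

{"index": 12.5, "memory": 900, "disk": 2048, "search": 3.2, "source": "trec-covid", "method": "embed", "name": "embed"}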
