Skip to content

Commit

Permalink
test: add basic cb test and configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
michiboo committed Jun 9, 2023
1 parent bceec84 commit 6991585
Show file tree
Hide file tree
Showing 9 changed files with 219 additions and 92 deletions.
49 changes: 31 additions & 18 deletions python/tests/assert_job.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from numpy.testing import assert_allclose, assert_array_almost_equal
from numpy.testing import assert_allclose, assert_almost_equal
from vw_executor.vw import ExecutionStatus


Expand All @@ -9,11 +9,13 @@ def get_from_kwargs(kwargs, key, default=None):
else:
return default


def majority_close(arr1, arr2, rtol, atol, threshold):
    """Return True when more than `threshold` fraction of paired elements
    of `arr1` and `arr2` are element-wise close (per np.isclose)."""
    closeness = np.isclose(arr1, arr2, rtol=rtol, atol=atol)
    matching = np.count_nonzero(closeness)
    return matching > len(arr1) * threshold


def assert_weight(job, **kwargs):
atol = get_from_kwargs(kwargs, "atol", 10e-8)
rtol = get_from_kwargs(kwargs, "rtol", 10e-5)
Expand All @@ -23,23 +25,34 @@ def assert_weight(job, **kwargs):
with open(data[0], "r") as f:
data = f.readlines()
data = [i.strip() for i in data]
weights = job[0].model9('--readable_model').weights
weights = job[0].model9("--readable_model").weights
weights = weights["weight"].to_list()
assert_allclose(weights, expected_weights, atol=atol, rtol=rtol), f"weights should be {expected_weights}"

def assert_prediction(job, **kwargs):
assert job.status == ExecutionStatus.Success, "job should be successful"
atol = kwargs.get("atol", 10e-8)
rtol = kwargs.get("rtol", 10e-5)
threshold = kwargs.get("threshold", 0.9)
constant = kwargs["expected_value"]
predictions = job.outputs['-p']
with open(predictions[0], "r") as f:
predictions = f.readlines()
predictions = [float(i) for i in predictions[1:]]
assert majority_close(predictions, [constant]*len(predictions), rtol=rtol, atol=atol, threshold=threshold), f"predicted value should be {constant}"

assert_allclose(
weights, expected_weights, atol=atol, rtol=rtol
), f"weights should be {expected_weights}"


def assert_functions():
return
def assert_prediction(job, **kwargs):
    """Assert the job succeeded and the bulk of its predictions are close
    to the expected constant value.

    kwargs:
        expected_value: required; the constant every prediction is compared to.
        atol, rtol: closeness tolerances (defaults 10e-8, 10e-5).
        threshold: fraction of predictions that must be close (default 0.9).
    """
    assert job.status == ExecutionStatus.Success, "job should be successful"
    atol = kwargs.get("atol", 10e-8)
    rtol = kwargs.get("rtol", 10e-5)
    threshold = kwargs.get("threshold", 0.9)
    constant = kwargs["expected_value"]
    prediction_files = job.outputs["-p"]
    with open(prediction_files[0], "r") as f:
        stripped = [line.strip() for line in f]
    # Drop blank lines, then skip the first entry before parsing floats
    # (presumably a header/initial line — mirrors the original pipeline).
    non_blank = [line for line in stripped if line != ""]
    values = [float(line) for line in non_blank[1:]]
    assert majority_close(
        values,
        [constant] * len(values),
        rtol=rtol,
        atol=atol,
        threshold=threshold,
    ), f"predicted value should be {constant}"


def assert_loss(job, **kwargs):
    """Assert the job succeeded and its reported loss matches kwargs["expected_loss"].

    Uses numpy's assert_almost_equal (default 7 decimal places) for the comparison.
    """
    assert job.status == ExecutionStatus.Success, "job should be successful"
    # isinstance is the idiomatic type check (was: type(...) == float).
    assert isinstance(job[0].loss, float), "loss should be a float"
    assert_almost_equal(job[0].loss, kwargs["expected_loss"])
52 changes: 38 additions & 14 deletions python/tests/test_regression.py → python/tests/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,36 @@
import numpy as np
import pytest
import os
from test_helper import json_to_dict_list, dynamic_function_call, get_function_object, generate_string_combinations
from test_helper import (
json_to_dict_list,
dynamic_function_call,
get_function_object,
generate_string_combinations,
)

CURR_DICT = os.path.dirname(os.path.abspath(__file__))


def combine_list_cmds_grids(cmds, base_grid):
    """Expand `cmds` ({option: [values]}) into one grid per value combination.

    Each returned grid is a copy of `base_grid` whose "#base" command string
    has one "<key> <value>" pair appended for every key in `cmds`.
    Empty-string values are dropped. Purely numeric values are rendered via
    a fixed-point format with trailing zeros stripped.
    """
    list_of_key_val = []
    grids = []
    for key, value in cmds.items():
        value = [i for i in value if i != ""]
        # BUG FIX: str(value).isdigit() on a list is always False, so the
        # numeric-format branch was dead; test the items instead.
        if all(str(li).isdigit() for li in value):
            list_of_key_val.append(
                [f" {key} {format(li, '.5f').rstrip('0').rstrip('.') }" for li in value]
            )
        else:
            list_of_key_val.append([f" {key} {li}" for li in value])
    for new_cmd in generate_string_combinations(
        [base_grid["#base"][0]], *list_of_key_val
    ):
        tmp_grid = base_grid.copy()
        # BUG FIX: a shallow dict copy shares the "#base" list, so mutating
        # tmp_grid["#base"][0] clobbered base_grid and every previously built
        # grid with the last combination. Replace the list instead.
        tmp_grid["#base"] = [new_cmd] + base_grid["#base"][1:]
        grids.append(tmp_grid)
    return grids


def cleanup_data_file():
script_directory = os.path.dirname(os.path.realpath(__file__))
# List all files in the directory
Expand All @@ -33,11 +44,12 @@ def cleanup_data_file():
if file.endswith(".txt"):
file_path = os.path.join(script_directory, file)
os.remove(file_path)



@pytest.fixture
def test_description(request):
resource = request.param
yield resource #
yield resource #
cleanup_data_file()


Expand All @@ -48,23 +60,35 @@ def core_test(files, grid, outputs, job_assert, job_assert_args):
job_assert(j, **job_assert_args)


@pytest.mark.parametrize(
    "test_description", json_to_dict_list("test_cb.json"), indirect=True
)
def test_all(test_description):
    """Run one JSON-described VW test case.

    For each grid produced from the "*" multiplier (or the base grid alone),
    generate the data file named by "data_func"/"data_func_args", then run
    every assert function listed under "assert_functions" against the job.
    """
    multiply = test_description.get("*", None)

    base_grid = test_description["grid"]
    if multiply:
        grids = combine_list_cmds_grids(multiply, base_grid)
    else:
        grids = [base_grid]

    # Data files are written next to this script by the data_generation helpers.
    script_directory = os.path.dirname(os.path.realpath(__file__))
    for grid in grids:
        options = Grid(grid)
        data = dynamic_function_call(
            "data_generation",
            test_description["data_func"],
            *test_description["data_func_args"].values(),
        )
        for assert_func in test_description["assert_functions"]:
            assert_job = get_function_object("assert_job", assert_func["assert_func"])
            core_test(
                # BUG FIX: was `script_directory + data`, which concatenates
                # without a path separator (data is a bare filename).
                os.path.join(script_directory, data),
                options,
                test_description["output"],
                assert_job,
                assert_func["assert_func_args"],
            )
36 changes: 35 additions & 1 deletion python/tests/data_generation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,44 @@
import random
import os
from test_helper import get_function_object

script_directory = os.path.dirname(os.path.realpath(__file__))


def constant_function(no_sample, constant, lower_bound, upper_bound):
    """Generate a VW regression data file with a constant label.

    Writes `no_sample` lines of the form "<constant> |f x:<x>" where x is
    drawn uniformly from [lower_bound, upper_bound]; the RNG is seeded so
    the file is reproducible. The file is created next to this script and
    its basename is returned.
    """
    dataFile = f"constant_func_{no_sample}_{constant}_{upper_bound}_{lower_bound}.txt"
    # os.path.join instead of manual "/" concatenation.
    with open(os.path.join(script_directory, dataFile), "w") as f:
        random.seed(10)
        for _ in range(no_sample):
            x = random.uniform(lower_bound, upper_bound)
            f.write(f"{constant} |f x:{x}\n")
    return dataFile


def random_number_items(items):
    """Return a random, non-empty subset of `items` (no repeats), as a list."""
    subset_size = random.randint(1, len(items))
    return random.sample(items, subset_size)


def generate_cb_data(
    num_examples, num_features, num_actions, reward_function, probability_function
):
    """Generate a contextual-bandit data file in VW format.

    Each of the `num_examples` lines has the form
    "<action>:<cost>:<probability> | <features>", with the action chosen
    uniformly from 1..num_actions, cost/probability computed by the named
    reward/probability functions (looked up in the reward_functions and
    probability_functions modules, called with their "params" dicts), and a
    random non-empty subset of feature1..feature<num_features>.

    The file is written next to this script; its basename is returned.
    """
    reward_function_obj = get_function_object(
        "reward_functions", reward_function["name"]
    )
    probability_function_obj = get_function_object(
        "probability_functions", probability_function["name"]
    )
    dataFile = f"cb_test_{num_examples}_{num_actions}_{num_features}.txt"
    features = [f"feature{index}" for index in range(1, num_features + 1)]
    # Seed for reproducible test data, consistent with constant_function.
    random.seed(10)
    with open(script_directory + "/" + dataFile, "w") as f:
        for _ in range(num_examples):
            chosen_action = random.randint(1, num_actions)
            cost = reward_function_obj(chosen_action, **reward_function["params"])
            probability = probability_function_obj(
                chosen_action, **probability_function["params"]
            )
            f.write(
                f'{chosen_action}:{cost}:{probability} | {" ".join(random_number_items(features))}\n'
            )
    return dataFile
2 changes: 2 additions & 0 deletions python/tests/probability_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def constant_probability(chosen_action=None):
    """Logging-policy probability that ignores the chosen action: always 1."""
    probability = 1
    return probability
54 changes: 0 additions & 54 deletions python/tests/pytest.json

This file was deleted.

2 changes: 2 additions & 0 deletions python/tests/reward_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def constant_reward(chosen_action=None):
    """Reward function that ignores the chosen action: always 1."""
    reward = 1
    return reward
60 changes: 60 additions & 0 deletions python/tests/test_cb.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
[
{
"data_func": "generate_cb_data",
"data_func_args": {
"num_examples": 100,
"num_features": 1,
"num_action": 1,
"reward_function": {
"name": "constant_reward",
"params": {}
},
"probability_function": {
"name": "constant_probability",
"params": {}
}
},
"assert_functions": [
{
"assert_func": "assert_loss",
"assert_func_args": {"expected_loss": 1}
},
{
"assert_func": "assert_prediction",
"assert_func_args": {
"expected_value": 0,
"threshold": 0.5
}
},
{
"assert_func": "assert_weight",
"assert_func_args": {
"expected_weights": [
5,
0
],
"atol": 100,
"rtol": 100
}
}
],
"grid": {
"#base": [
"--cb 1 -P 10000 --preserve_performance_counters --save_resume"
],
"--cb_type": [
"ips",
"mtr"
]
},
"*": {
"--cb": [
1
]
},
"output": [
"--readable_model",
"-p"
]
}
]
10 changes: 5 additions & 5 deletions python/tests/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
# Get the current directory
current_dir = os.path.dirname(os.path.abspath(__file__))


def json_to_dict_list(file):
    """Load and return the parsed JSON content of `file`, resolved
    relative to this script's directory."""
    # Use a distinct handle name: the original reused `file`, shadowing the
    # parameter with the open file object.
    with open(current_dir + "/" + file, "r") as fp:
        return json.load(fp)


def dynamic_function_call(module_name, function_name, *args, **kwargs):
try:
Expand All @@ -36,7 +37,6 @@ def get_function_object(module_name, function_name):
print(f"Function '{function_name}' not found in module '{module_name}'.")



def generate_test_function(test_data):
@pytest.dynamic
def test_dynamic():
Expand All @@ -59,9 +59,9 @@ def generate_pytest_from_json(filepath):
for test_case in json_data:
test_function = generate_test_function(test_case)
globals()[test_function.__name__] = test_function


def generate_string_combinations(*lists):
    """Return every string formed by concatenating one element taken from
    each input list, in itertools.product order."""
    return ["".join(parts) for parts in itertools.product(*lists)]
Loading

0 comments on commit 6991585

Please sign in to comment.