test: [RLOS2023] add test for slate (#4629)
* test: add test for slate

* test: test cleanup and slate test update

* test: minor cleanup and change assert_loss function to equal instead of lower
michiboo authored Aug 7, 2023
1 parent 144912d commit 702604f
Showing 13 changed files with 413 additions and 86 deletions.
10 changes: 8 additions & 2 deletions python/tests/test_framework/assert_job.py
@@ -68,11 +68,17 @@ def assert_loss(job, **kwargs):
    assert job.status == ExecutionStatus.Success, "job should be successful"
    assert type(job[0].loss) == float, "loss should be a float"
    decimal = kwargs.get("decimal", 2)
    if job[0].loss < kwargs["expected_loss"]:
        return
    assert_almost_equal(job[0].loss, kwargs["expected_loss"], decimal=decimal)


def assert_loss_below(job, **kwargs):
    assert job.status == ExecutionStatus.Success, "job should be successful"
    assert type(job[0].loss) == float, "loss should be a float"
    assert (
        job[0].loss <= kwargs["expected_loss"]
    ), f"loss should be below {kwargs['expected_loss']}"


def assert_prediction_with_generated_data(job, **kwargs):
    assert job.status == ExecutionStatus.Success, "job should be successful"
    expected_class = []
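For reference, a minimal usage sketch (the `job` handle and the loss values are illustrative): `assert_loss` passes when the loss is lower than, or almost equal to, `expected_loss`, while `assert_loss_below` only enforces the upper bound.

    # Hypothetical calls from a test case; `job` is a finished vw_executor job.
    assert_loss(job, expected_loss=0.3, decimal=2)  # passes if loss < 0.3 or loss ~= 0.3
    assert_loss_below(job, expected_loss=0.35)      # passes only if loss <= 0.35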
13 changes: 8 additions & 5 deletions python/tests/test_framework/cb/data_generation.py
@@ -17,8 +17,7 @@ def generate_cb_data(
    num_actions,
    reward_function,
    logging_policy,
-    no_context=1,
-    context_name=None,
+    context_name=["1"],
):

    dataFile = f"cb_test_{num_examples}_{num_actions}_{num_features}.txt"
@@ -32,16 +31,20 @@
    features = [f"feature{index}" for index in range(1, num_features + 1)]
    with open(os.path.join(script_directory, dataFile), "w") as f:
        for _ in range(num_examples):
+            no_context = len(context_name)
            if no_context > 1:
                context = random.randint(1, no_context)
-                if not context_name:
-                    context_name = [f"{index}" for index in range(1, no_context + 1)]
            else:
                context = 1

            def return_cost_probability(chosen_action, context=1):
-                cost = reward_function_obj(
+                cost = -reward_function_obj(
                    chosen_action, context, **reward_function["params"]
                )
+                if "params" not in logging_policy:
+                    logging_policy["params"] = {}
+                logging_policy["params"]["chosen_action"] = chosen_action
+                logging_policy["params"]["num_actions"] = num_actions
+                probability = logging_policy_obj(**logging_policy["params"])
                return cost, probability
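Costs are now written as negated rewards, matching VW's lower-is-better convention. A sketch of a call under assumed leading parameters (the hunk above starts mid-signature, so `num_examples`/`num_features` and the spec names below are assumptions):

    # Hypothetical invocation; the reward/logging-policy specs are resolved via
    # get_function_object from the cb.* modules, and the "params" key may be
    # omitted for the logging policy thanks to the guard added above.
    generate_cb_data(
        num_examples=100,
        num_features=1,
        num_actions=2,
        reward_function={"name": "fixed_reward", "params": {"reward": [1, 0]}},
        logging_policy={"name": "even_probability"},
    )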
6 changes: 3 additions & 3 deletions python/tests/test_framework/classification/data_generation.py
@@ -7,20 +7,20 @@


def generate_classification_data(
-    num_sample,
+    num_example,
    num_classes,
    num_features,
    classify_func,
    bounds=None,
):
-    dataFile = f"classification_{num_classes}_{num_features}_{num_sample}.txt"
+    dataFile = f"classification_{num_classes}_{num_features}_{num_example}.txt"
    classify_func_obj = get_function_object(
        "classification.classification_functions", classify_func["name"]
    )
    if not bounds:
        bounds = [[0, 1] for _ in range(num_features)]
    with open(os.path.join(script_directory, dataFile), "w") as f:
-        for _ in range(num_sample):
+        for _ in range(num_example):
            x = [
                random.uniform(bounds[index][0], bounds[index][1])
                for index in range(num_features)
7 changes: 7 additions & 0 deletions python/tests/test_framework/slate/action_space.py
@@ -0,0 +1,7 @@
def new_action_after_threshold(**kwargs):
    iteration = kwargs.get("iteration", 0)
    threshold = kwargs.get("threshold", 0)
    # return the "before" action space up to the threshold iteration,
    # and the "after" action space once the iteration passes it
    if iteration > threshold:
        return kwargs["after"]
    return kwargs["before"]
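A sketch of the spec a test might feed this helper (`generate_slate_data` below injects the `iteration` key on every example; the action lists are illustrative):

    # Hypothetical action-space config: two slots whose candidate actions
    # swap order once the iteration counter passes the threshold.
    action_space = {
        "name": "new_action_after_threshold",
        "params": {
            "threshold": 500,
            "before": [["long_meeting", "short_meeting"], ["tea", "coffee"]],
            "after": [["short_meeting", "long_meeting"], ["coffee", "tea"]],
        },
    }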
43 changes: 43 additions & 0 deletions python/tests/test_framework/slate/assert_job.py
@@ -0,0 +1,43 @@
from numpy.testing import assert_allclose, assert_almost_equal
from vw_executor.vw import ExecutionStatus
import numpy as np


def majority_close(arr1, arr2, rtol, atol, threshold):
    # Check if the majority of elements are close
    close_count = np.count_nonzero(np.isclose(arr1, arr2, rtol=rtol, atol=atol))
    return close_count >= len(arr1) * threshold


def assert_prediction(job, **kwargs):
    assert job.status == ExecutionStatus.Success, "job should be successful"
    atol = kwargs.get("atol", 10e-8)
    rtol = kwargs.get("rtol", 10e-5)
    threshold = kwargs.get("threshold", 0.9)
    expected_value = kwargs["expected_value"]
    predictions = job.outputs["-p"]
    res = []
    with open(predictions[0], "r") as f:
        exampleRes = []
        while True:
            line = f.readline()
            if not line:
                break
            if line.count(":") == 0:
                res.append(exampleRes)
                exampleRes = []
                continue
            slotRes = [0] * line.count(":")
            slot = line.split(",")
            for i in range(len(slot)):
                actionInd = int(slot[i].split(":")[0])
                # store each pair's probability at its action index
                slotRes[actionInd] = float(slot[i].split(":")[1])
            exampleRes.append(slotRes)

    assert majority_close(
        res,
        [expected_value] * len(res),
        rtol=rtol,
        atol=atol,
        threshold=threshold,
    ), f"predicted value should be {expected_value}, \n actual values are {res}"
70 changes: 70 additions & 0 deletions python/tests/test_framework/slate/data_generation.py
@@ -0,0 +1,70 @@
import random
import os
from test_helper import get_function_object

script_directory = os.path.dirname(os.path.realpath(__file__))
random.seed(10)


def generate_slate_data(
    num_examples,
    reward_function,
    logging_policy,
    action_space,
    context_name=["1"],
):

    action_space_obj = get_function_object("slate.action_space", action_space["name"])

    reward_function_obj = get_function_object(
        "slate.reward_functions", reward_function["name"]
    )
    logging_policy_obj = get_function_object(
        "slate.logging_policies", logging_policy["name"]
    )

    def return_cost_probability(chosen_action, chosen_slot, context):
        cost = -reward_function_obj(
            chosen_action, context, chosen_slot, **reward_function["params"]
        )
        logging_policy["params"]["num_action"] = num_actions[chosen_slot - 1]
        logging_policy["params"]["chosen_action"] = chosen_action
        probability = logging_policy_obj(**logging_policy["params"])
        return cost, probability

    dataFile = f"slate_test_{num_examples}_{generate_slate_data.__name__}.txt"
    with open(os.path.join(script_directory, dataFile), "w") as f:
        for i in range(num_examples):
            action_space["params"]["iteration"] = i
            action_spaces = action_space_obj(**action_space["params"])
            num_slots = len(action_spaces)
            num_actions = [len(slot) for slot in action_spaces]
            slot_name = [f"slot_{index}" for index in range(1, num_slots + 1)]
            chosen_actions = []
            num_context = len(context_name)
            if num_context > 1:
                context = random.randint(1, num_context)
            else:
                context = 1
            for s in range(num_slots):
                chosen_actions.append(random.randint(1, num_actions[s]))
            chosen_actions_cost_prob = [
                return_cost_probability(action, slot + 1, context)
                for slot, action in enumerate(chosen_actions)
            ]
            total_cost = sum([cost for cost, _ in chosen_actions_cost_prob])

            f.write(f"slates shared {total_cost} |User {context_name[context-1]}\n")
            # write actions
            for ind, slot in enumerate(action_spaces):
                for a in slot:
                    f.write(
                        f"slates action {ind} |Action {a}\n",
                    )

            for s in range(num_slots):
                f.write(
                    f"slates slot {chosen_actions[s]}:{chosen_actions_cost_prob[s][1]} |Slot {slot_name[s]}\n"
                )
            f.write("\n")
    return os.path.join(script_directory, dataFile)
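Given the write calls above, one generated example would look like this (context "1", two slots with two and one candidate actions; action names, the total cost, and probabilities are illustrative):

    slates shared -1.7 |User 1
    slates action 0 |Action long_meeting
    slates action 0 |Action short_meeting
    slates action 1 |Action tea
    slates slot 1:0.5 |Slot slot_1
    slates slot 1:1.0 |Slot slot_2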
3 changes: 3 additions & 0 deletions python/tests/test_framework/slate/logging_policies.py
@@ -0,0 +1,3 @@
def even_probability(chosen_action, **kwargs):
    num_actions = kwargs["num_action"]
    return round(1 / num_actions, 2)
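For example, a slot with four candidate actions gets a uniform logging probability of 0.25 (hypothetical call):

    even_probability(chosen_action=2, num_action=4)  # round(1 / 4, 2) -> 0.25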
12 changes: 12 additions & 0 deletions python/tests/test_framework/slate/reward_functions.py
@@ -0,0 +1,12 @@
def fixed_reward(chosen_action, context, slot, **kwargs):
    reward = kwargs["reward"]
    return reward[slot - 1][chosen_action - 1]


def reverse_reward_after_threshold(chosen_action, context, slot, **kwargs):
    reward = kwargs["reward"]
    iteration = kwargs.get("iteration", 0)
    threshold = kwargs.get("threshold", 0)
    if iteration > threshold:
        reward = [i[::-1] for i in reward]
    return reward[slot - 1][chosen_action - 1]
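A worked example with a 2-slot by 2-action reward matrix (values illustrative; rows index slots and columns index actions, both 1-based in the callers):

    reward = [[1.0, 0.0], [0.0, 1.0]]
    fixed_reward(1, context=1, slot=2, reward=reward)  # reward[1][0] -> 0.0
    reverse_reward_after_threshold(
        1, context=1, slot=2, reward=reward, iteration=600, threshold=500
    )  # each slot's row is reversed past the threshold: [0.0, 1.0][::-1][0] -> 1.0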