-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test: [RLOS2023] add test for slate (#4629)
* test: add test for slate * test: test cleanup and slate test update * test: minor cleanup and change assert_loss function to equal instead of lower
- Loading branch information
Showing
13 changed files
with
413 additions
and
86 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
def new_action_after_threshold(**kwargs): | ||
iteration = kwargs.get("iteration", 0) | ||
threshold = kwargs.get("threshold", 0) | ||
# before iteration 500, it is sunny and after it is raining | ||
if iteration > threshold: | ||
return kwargs["after"] | ||
return kwargs["before"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from numpy.testing import assert_allclose, assert_almost_equal | ||
from vw_executor.vw import ExecutionStatus | ||
import numpy as np | ||
|
||
|
||
def majority_close(arr1, arr2, rtol, atol, threshold): | ||
# Check if the majority of elements are close | ||
close_count = np.count_nonzero(np.isclose(arr1, arr2, rtol=rtol, atol=atol)) | ||
return close_count >= len(arr1) * threshold | ||
|
||
|
||
def assert_prediction(job, **kwargs): | ||
assert job.status == ExecutionStatus.Success, "job should be successful" | ||
atol = kwargs.get("atol", 10e-8) | ||
rtol = kwargs.get("rtol", 10e-5) | ||
threshold = kwargs.get("threshold", 0.9) | ||
expected_value = kwargs["expected_value"] | ||
predictions = job.outputs["-p"] | ||
res = [] | ||
with open(predictions[0], "r") as f: | ||
exampleRes = [] | ||
while True: | ||
line = f.readline() | ||
if not line: | ||
break | ||
if line.count(":") == 0: | ||
res.append(exampleRes) | ||
exampleRes = [] | ||
continue | ||
slotRes = [0] * line.count(":") | ||
slot = line.split(",") | ||
for i in range(len(slot)): | ||
actionInd = int(slot[i].split(":")[0]) | ||
slotRes[i] = float(slot[actionInd].split(":")[1]) | ||
exampleRes.append(slotRes) | ||
|
||
assert majority_close( | ||
res, | ||
[expected_value] * len(res), | ||
rtol=rtol, | ||
atol=atol, | ||
threshold=threshold, | ||
), f"predicted value should be {expected_value}, \n actual values are {res}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import random | ||
import os | ||
from test_helper import get_function_object | ||
|
||
script_directory = os.path.dirname(os.path.realpath(__file__)) | ||
random.seed(10) | ||
|
||
|
||
def generate_slate_data( | ||
num_examples, | ||
reward_function, | ||
logging_policy, | ||
action_space, | ||
context_name=["1"], | ||
): | ||
|
||
action_space_obj = get_function_object("slate.action_space", action_space["name"]) | ||
|
||
reward_function_obj = get_function_object( | ||
"slate.reward_functions", reward_function["name"] | ||
) | ||
logging_policy_obj = get_function_object( | ||
"slate.logging_policies", logging_policy["name"] | ||
) | ||
|
||
def return_cost_probability(chosen_action, chosen_slot, context): | ||
cost = -reward_function_obj( | ||
chosen_action, context, chosen_slot, **reward_function["params"] | ||
) | ||
logging_policy["params"]["num_action"] = num_actions[chosen_slot - 1] | ||
logging_policy["params"]["chosen_action"] = chosen_action | ||
probability = logging_policy_obj(**logging_policy["params"]) | ||
return cost, probability | ||
|
||
dataFile = f"slate_test_{num_examples}_{generate_slate_data.__name__}.txt" | ||
with open(os.path.join(script_directory, dataFile), "w") as f: | ||
for i in range(num_examples): | ||
action_space["params"]["iteration"] = i | ||
action_spaces = action_space_obj(**action_space["params"]) | ||
num_slots = len(action_spaces) | ||
num_actions = [len(slot) for slot in action_spaces] | ||
slot_name = [f"slot_{index}" for index in range(1, num_slots + 1)] | ||
chosen_actions = [] | ||
num_context = len(context_name) | ||
if num_context > 1: | ||
context = random.randint(1, num_context) | ||
else: | ||
context = 1 | ||
for s in range(num_slots): | ||
chosen_actions.append(random.randint(1, num_actions[s])) | ||
chosen_actions_cost_prob = [ | ||
return_cost_probability(action, slot + 1, context) | ||
for slot, action in enumerate(chosen_actions) | ||
] | ||
total_cost = sum([cost for cost, _ in chosen_actions_cost_prob]) | ||
|
||
f.write(f"slates shared {total_cost} |User {context_name[context-1]}\n") | ||
# write actions | ||
for ind, slot in enumerate(action_spaces): | ||
for a in slot: | ||
f.write( | ||
f"slates action {ind} |Action {a}\n", | ||
) | ||
|
||
for s in range(num_slots): | ||
f.write( | ||
f"slates slot {chosen_actions[s]}:{chosen_actions_cost_prob[s][1]} |Slot {slot_name[s]}\n" | ||
) | ||
f.write("\n") | ||
return os.path.join(script_directory, dataFile) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
def even_probability(chosen_action, **kwargs): | ||
num_actions = kwargs["num_action"] | ||
return round(1 / num_actions, 2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
def fixed_reward(chosen_action, context, slot, **kwargs): | ||
reward = kwargs["reward"] | ||
return reward[slot - 1][chosen_action - 1] | ||
|
||
|
||
def reverse_reward_after_threshold(chosen_action, context, slot, **kwargs): | ||
reward = kwargs["reward"] | ||
iteration = kwargs.get("iteration", 0) | ||
threshold = kwargs.get("threshold", 0) | ||
if iteration > threshold: | ||
reward = [i[::-1] for i in reward] | ||
return reward[slot - 1][chosen_action - 1] |
Oops, something went wrong.