Skip to content

Commit

Permalink
test: minor cleanup and change assert_loss function to equal instead …
Browse files Browse the repository at this point in the history
…of lower
  • Loading branch information
michiboo committed Aug 4, 2023
1 parent 51c7045 commit 4e8002e
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 20 deletions.
10 changes: 8 additions & 2 deletions python/tests/test_framework/assert_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,17 @@ def assert_loss(job, **kwargs):
assert job.status == ExecutionStatus.Success, "job should be successful"
assert type(job[0].loss) == float, "loss should be an float"
decimal = kwargs.get("decimal", 2)
if job[0].loss < kwargs["expected_loss"]:
return
assert_almost_equal(job[0].loss, kwargs["expected_loss"], decimal=decimal)


def assert_loss_below(job, **kwargs):
assert job.status == ExecutionStatus.Success, "job should be successful"
assert type(job[0].loss) == float, "loss should be an float"
assert (
job[0].loss <= kwargs["expected_loss"]
), f"loss should be below {kwargs['expected_loss']}"


def assert_prediction_with_generated_data(job, **kwargs):
assert job.status == ExecutionStatus.Success, "job should be successful"
expected_class = []
Expand Down
10 changes: 5 additions & 5 deletions python/tests/test_framework/cb/data_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ def generate_cb_data(
num_actions,
reward_function,
logging_policy,
no_context=1,
context_name=None,
context_name=["1"],
):

dataFile = f"cb_test_{num_examples}_{num_actions}_{num_features}.txt"
Expand All @@ -32,13 +31,14 @@ def generate_cb_data(
features = [f"feature{index}" for index in range(1, num_features + 1)]
with open(os.path.join(script_directory, dataFile), "w") as f:
for _ in range(num_examples):
no_context = len(context_name)
if no_context > 1:
context = random.randint(1, no_context)
if not context_name:
context_name = [f"{index}" for index in range(1, no_context + 1)]
else:
context = 1

def return_cost_probability(chosen_action, context=1):
cost = reward_function_obj(
cost = -reward_function_obj(
chosen_action, context, **reward_function["params"]
)
if "params" not in logging_policy:
Expand Down
8 changes: 3 additions & 5 deletions python/tests/test_framework/slate/data_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ def generate_slate_data(
reward_function,
logging_policy,
action_space,
num_context=1,
context_name=None,
context_name=["1"],
):

action_space_obj = get_function_object("slate.action_space", action_space["name"])
Expand All @@ -25,7 +24,7 @@ def generate_slate_data(
)

def return_cost_probability(chosen_action, chosen_slot, context):
cost = reward_function_obj(
cost = -reward_function_obj(
chosen_action, context, chosen_slot, **reward_function["params"]
)
logging_policy["params"]["num_action"] = num_actions[chosen_slot - 1]
Expand All @@ -42,12 +41,11 @@ def return_cost_probability(chosen_action, chosen_slot, context):
num_actions = [len(slot) for slot in action_spaces]
slot_name = [f"slot_{index}" for index in range(1, num_slots + 1)]
chosen_actions = []
num_context = len(context_name)
if num_context > 1:
context = random.randint(1, num_context)
else:
context = 1
if not context_name:
context_name = [f"{index}" for index in range(1, num_context + 1)]
for s in range(num_slots):
chosen_actions.append(random.randint(1, num_actions[s]))
chosen_actions_cost_prob = [
Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_framework/slate/reward_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ def fixed_reward(chosen_action, context, slot, **kwargs):
return reward[slot - 1][chosen_action - 1]


def reverse_reward_after_iteration(chosen_action, context, slot, **kwargs):
def reverse_reward_after_threshold(chosen_action, context, slot, **kwargs):
reward = kwargs["reward"]
iteration = kwargs.get("iteration", 0)
threshold = kwargs.get("threshold", 0)
Expand Down
11 changes: 6 additions & 5 deletions python/tests/test_framework/test_configs/cb.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
"params": {
}
},
"no_context": 2
"context_name": ["1", "2"]
}
},
"assert_functions": [
{
"name": "assert_loss",
"params": {
"expected_loss": 0.1
"expected_loss": -1,
"decimal": 1
}
},
{
Expand Down Expand Up @@ -124,7 +125,7 @@
{
"name": "assert_loss",
"params": {
"expected_loss": 1
"expected_loss": -1
}
},
{
Expand Down Expand Up @@ -173,14 +174,14 @@
"params": {
}
},
"no_context": 2
"context_name": ["1", "2"]
}
},
"assert_functions": [
{
"name": "assert_loss",
"params": {
"expected_loss": 0.6,
"expected_loss": -0.4,
"decimal": 1
}
},
Expand Down
4 changes: 2 additions & 2 deletions python/tests/test_framework/test_configs/slate.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"params": {
"num_examples": 1000,
"reward_function": {
"name": "reverse_reward_after_iteration",
"name": "reverse_reward_after_threshold",
"params": {
"reward": [
[
Expand Down Expand Up @@ -57,7 +57,7 @@
{
"name": "assert_loss",
"params": {
"expected_loss": 0.8,
"expected_loss": -1.9,
"decimal": 0.1
}
},
Expand Down

0 comments on commit 4e8002e

Please sign in to comment.