Skip to content

Commit

Permalink
test: add test for cb with continous action
Browse files Browse the repository at this point in the history
  • Loading branch information
michiboo committed Aug 9, 2023
1 parent 4e8002e commit f50ca3e
Show file tree
Hide file tree
Showing 6 changed files with 279 additions and 1 deletion.
2 changes: 2 additions & 0 deletions python/tests/test_framework/assert_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def assert_prediction(job, **kwargs):
prediction = [i for i in prediction if i != ""]
if ":" in prediction[0]:
prediction = [[j.split(":")[1] for j in i.split(",")] for i in prediction]
elif "," in prediction[0]:
prediction = [[j for j in i.split(",")] for i in prediction]
if type(prediction[0]) == list:
prediction = [[float(remove_non_digits(j)) for j in i] for i in prediction]
else:
Expand Down
62 changes: 62 additions & 0 deletions python/tests/test_framework/cb_cont/data_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import random
import os
from test_helper import get_function_object

script_directory = os.path.dirname(os.path.realpath(__file__))
random.seed(10)


def random_number_items(items):
num_items_to_select = random.randint(1, len(items))
return random.sample(items, num_items_to_select)


def generate_cb_data(
num_examples,
num_features,
action_range,
reward_function,
logging_policy,
context_name=["1"],
):
num_actions = int(abs(action_range[1] - action_range[0]))
dataFile = f"cb_cont_test_{num_examples}_{num_actions}_{num_features}.txt"

reward_function_obj = get_function_object(
"cb_cont.reward_functions", reward_function["name"]
)
logging_policy_obj = get_function_object(
"cb_cont.logging_policies", logging_policy["name"]
)
features = [f"feature{index}" for index in range(1, num_features + 1)]
with open(os.path.join(script_directory, dataFile), "w") as f:
for _ in range(num_examples):
no_context = len(context_name)
if no_context > 1:
context = random.randint(1, no_context)
else:
context = 1

def return_cost_probability(chosen_action, context=1):
cost = -reward_function_obj(
chosen_action, context, **reward_function["params"]
)
if "params" not in logging_policy:
logging_policy["params"] = {}
logging_policy["params"]["chosen_action"] = chosen_action
logging_policy["params"]["num_actions"] = num_actions
probability = logging_policy_obj(**logging_policy["params"])
return cost, probability

chosen_action = round(random.uniform(0, num_actions), 2)
cost, probability = return_cost_probability(chosen_action, context)
if no_context == 1:
f.write(
f'ca {chosen_action}:{cost}:{probability} | {" ".join(random_number_items(features))}\n'
)
else:
f.write(
f'ca {chosen_action}:{cost}:{probability} | {"s_" + context_name[context-1]} {" ".join(random_number_items(features))}\n'
)
f.write("\n")
return os.path.join(script_directory, dataFile)
7 changes: 7 additions & 0 deletions python/tests/test_framework/cb_cont/logging_policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
def constant_probability(chosen_action, **kwargs):
return 1


def even_probability(chosen_action, **kwargs):
num_actions = kwargs["num_actions"]
return round(1 / num_actions, 2)
19 changes: 19 additions & 0 deletions python/tests/test_framework/cb_cont/reward_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
def fixed_reward(chosen_action, context, **kwargs):
return 1


def constant_reward(chosen_action, context, **kwargs):
reward = kwargs["reward"]
return reward[int(chosen_action) - 1]


def fixed_reward_two_action(chosen_action, context, **kwargs):
if context == 1 and chosen_action >= 2:
return 1
elif context == 2 and chosen_action < 2 and chosen_action >= 1:
return 0
elif context == 1 and chosen_action < 1 and chosen_action >= 1:
return 0
elif context == 2 and chosen_action < 1:
return 1
return 1
189 changes: 189 additions & 0 deletions python/tests/test_framework/test_configs/cb_cont.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
[
{
"test_name": "cb_two_action",
"data_func": {
"name": "generate_cb_data",
"params": {
"num_examples": 100,
"num_features": 1,
"action_range": [
0,
2
],
"reward_function": {
"name": "constant_reward",
"params": {
"reward": [
1,
0
]
}
},
"logging_policy": {
"name": "even_probability",
"params": {}
}
}
},
"assert_functions": [
{
"name": "assert_loss",
"params": {
"expected_loss": -1,
"decimal": 1
}
},
{
"name": "assert_prediction",
"params": {
"expected_value": [
1,
0
],
"threshold": 0.8
}
}
],
"grids": {
"cb": {
"#base": [
"--cats 2 --min_value 0 --max_value 2 --bandwidth 1"
]
},
"epsilon": {
"--epsilon": [
0.1,
0.2,
0.3
]
}
},
"grids_expression": "cb * (epsilon)",
"output": [
"--readable_model",
"-p"
]
},
{
"test_name": "cb_two_action_diff_context",
"data_func": {
"name": "generate_cb_data",
"params": {
"num_examples": 100,
"num_features": 2,
"action_range": [
0,
2
],
"reward_function": {
"name": "fixed_reward_two_action",
"params": {}
},
"logging_policy": {
"name": "even_probability",
"params": {}
},
"context_name": [
"1",
"2"
]
}
},
"assert_functions": [
{
"name": "assert_loss",
"params": {
"expected_loss": -0.8,
"decimal": 1
}
},
{
"name": "assert_prediction",
"params": {
"expected_value": [
0.975,
0.025
],
"threshold": 0.1,
"atol": 0.1,
"rtol": 0.1
}
}
],
"grids": {
"cb": {
"#base": [
"--cats 2 --min_value 0 --max_value 2 --bandwidth 1"
]
},
"epsilon": {
"--epsilon": [
0.1,
0.2,
0.3
]
}
},
"grids_expression": "cb * (epsilon)",
"output": [
"--readable_model",
"-p"
]
},
{
"test_name": "cb_one_action",
"data_func": {
"name": "generate_cb_data",
"params": {
"num_examples": 10,
"num_features": 1,
"action_range": [
0,
1
],
"reward_function": {
"name": "fixed_reward",
"params": {}
},
"logging_policy": {
"name": "even_probability"
}
}
},
"assert_functions": [
{
"name": "assert_loss",
"params": {
"expected_loss": -1
}
},
{
"name": "assert_prediction",
"params": {
"expected_value": 0,
"threshold": 0.1
}
}
],
"grids": {
"g0": {
"#base": [
"--cats 1 --min_value 0 --max_value 1 --bandwidth 1"
]
},
"g1": {
"--cb_type": [
"ips",
"mtr",
"dr",
"dm"
]
}
},
"grids_expression": "g0 * g1",
"output": [
"--readable_model",
"-p"
]
}
]
1 change: 0 additions & 1 deletion python/tests/test_framework/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def copy_file(source_file, destination_file):


def call_function_with_dirs(dirs, module_name, function_name, **kargs):

for dir in dirs:
try:
data = dynamic_function_call(
Expand Down

0 comments on commit f50ca3e

Please sign in to comment.