Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
lalo committed Mar 13, 2023
1 parent 7c0b212 commit 469cebe
Showing 1 changed file with 57 additions and 4 deletions.
61 changes: 57 additions & 4 deletions python/tests/test_ccb.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,32 @@ def my_ccb_simulation(n=10000, swap_after=5000, variance=0, bad_features=0, seed

envs = [[[0.8, 0.4], [0.2, 0.4]]]
offset = 0

animals = [
"cat",
"dog",
"bird",
"fish",
"horse",
"cow",
"pig",
"sheep",
"goat",
"chicken",
]
colors = [
"red",
"green",
"blue",
"yellow",
"orange",
"purple",
"black",
"white",
"brown",
"gray",
]

for i in range(1, n):
person = random.randint(0, 1)
chosen = [int(i) for i in np.random.permutation(2)]
Expand All @@ -146,7 +172,7 @@ def my_ccb_simulation(n=10000, swap_after=5000, variance=0, bad_features=0, seed
for i in range(len(rewards)):
rewards[i] += np.random.normal(0.5, variance)

yield {
temp = {
"c": {
"shared": {"name": people_ccb[person]},
"_multi": [{"a": {"topic": topics_ccb[i]}} for i in range(2)],
Expand All @@ -162,6 +188,22 @@ def my_ccb_simulation(n=10000, swap_after=5000, variance=0, bad_features=0, seed
],
}

temp["c"]["shared"][random.choice(animals)] = random.random()
temp["c"]["shared"][random.choice(animals)] = random.random()

temp["c"]["_multi"][random.choice(range(len(temp["c"]["_multi"])))][
random.choice(colors)
] = random.random()
temp["c"]["_multi"][random.choice(range(len(temp["c"]["_multi"])))][
random.choice(colors)
] = random.random()

temp["c"]["_slots"][random.choice(range(len(temp["c"]["_slots"])))][
random.choice(colors)
] = random.random()

yield temp

def save_examples(examples, path):
with open(path, "w") as f:
for ex in examples:
Expand All @@ -170,18 +212,29 @@ def save_examples(examples, path):
input_file = "ccb.json"
cache_dir = ".cache"
save_examples(
my_ccb_simulation(n=1000, variance=0.1, bad_features=1, seed=0), input_file
my_ccb_simulation(n=10, variance=0.1, bad_features=1, seed=0), input_file
)

assert os.path.exists(input_file)

vw = Vw(cache_dir, "/root/vowpal_wabbit/build/vowpalwabbit/cli/vw", handler=None)
vw = Vw(cache_dir, handler=None)
q = vw.train(
input_file, "-b 18 -q :: --ccb_explore_adf --dsjson", ["--invert_hash"]
)
automl = vw.train(
input_file,
"-b 20 --ccb_explore_adf --log_output stderr --dsjson --automl 4 --oracle_type one_diff --verbose_metrics",
["--invert_hash", "--extra_metrics"],
)

fts_names_q = set([n for n in q[0].model9("--invert_hash").weights.index])
fts_names_automl = set(
[n for n in automl[0].model9("--invert_hash").weights.index if "[" not in n]
)

assert len(fts_names_q) == 39
# current impl is broken - they should have same feature count
assert len(fts_names_q) == 330
assert len(fts_names_automl) == 200

os.remove(input_file)
shutil.rmtree(cache_dir)
Expand Down

0 comments on commit 469cebe

Please sign in to comment.