From 56d41dac918e0e02bf38b3b0d90dbb980563fecf Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Mon, 27 Nov 2023 14:13:52 -0500 Subject: [PATCH 1/6] fix: fix #4669 by handling empty decision scores elements --- test/core.vwtest.json | 28 ++++++++++ test/train-sets/issue4669.txt | 50 +++++++++++++++++ test/train-sets/ref/issue4669_test.stderr | 23 ++++++++ test/train-sets/ref/issue4669_test.stdout | 1 + test/train-sets/ref/issue4669_train.stderr | 23 ++++++++ test/train-sets/ref/issue4669_train.stdout | 0 vowpalwabbit/core/src/decision_scores.cc | 3 +- .../conditional_contextual_bandit.cc | 13 ++++- vowpalwabbit/core/tests/ccb_test.cc | 56 +++++++++++++++++++ 9 files changed, 194 insertions(+), 3 deletions(-) create mode 100644 test/train-sets/issue4669.txt create mode 100644 test/train-sets/ref/issue4669_test.stderr create mode 100644 test/train-sets/ref/issue4669_test.stdout create mode 100644 test/train-sets/ref/issue4669_train.stderr create mode 100644 test/train-sets/ref/issue4669_train.stdout diff --git a/test/core.vwtest.json b/test/core.vwtest.json index ef6857518f3..6c5fa549831 100644 --- a/test/core.vwtest.json +++ b/test/core.vwtest.json @@ -6073,5 +6073,33 @@ "depends_on": [ 467 ] + }, + { + "id": 469, + "desc": "https://github.com/VowpalWabbit/vowpal_wabbit/issues/4669", + "vw_command": "--ccb_explore_adf -q UA --all_slots_loss -f issue4669.model -d train-sets/issue4669.txt", + "diff_files": { + "stderr": "train-sets/ref/issue4669_train.stderr", + "stdout": "train-sets/ref/issue4669_train.stdout" + }, + "input_files": [ + "train-sets/issue4669.txt" + ] + }, + { + "id": 470, + "desc": "https://github.com/VowpalWabbit/vowpal_wabbit/issues/4669", + "vw_command": "--ccb_explore_adf -q UA --all_slots_loss -i issue4669.model -t -d train-sets/issue4669.txt", + "diff_files": { + "stderr": "train-sets/ref/issue4669_test.stderr", + "stdout": "train-sets/ref/issue4669_test.stdout" + }, + "input_files": [ + "train-sets/issue4669.txt", + "issue4669.model" + ], + "depends_on": [ + 469 + ] } ] \ No newline at end of file diff --git a/test/train-sets/issue4669.txt b/test/train-sets/issue4669.txt new file mode 100644 index 00000000000..2051716f9d2 --- /dev/null +++ b/test/train-sets/issue4669.txt @@ -0,0 +1,50 @@ +ccb shared |User userID='aUser' +ccb action |Action contentId='a' +ccb action |Action contentId='b' +ccb action |Action contentId='c' +ccb action |Action contentId='d' +ccb action |Action contentId='e' +ccb action |Action contentId='f' +ccb action |Action contentId='g' +ccb action |Action contentId='h' +ccb action |Action contentId='i' +ccb action |Action contentId='j' +ccb action |Action contentId='k' +ccb action |Action contentId='l' +ccb action |Action contentId='m' +ccb action |Action contentId='n' +ccb action |Action contentId='o' +ccb action |Action contentId='p' +ccb action |Action contentId='q' +ccb action |Action contentId='r' +ccb action |Action contentId='s' +ccb action |Action contentId='t' +ccb action |Action contentId='x' +ccb action |Action contentId='y' +ccb action |Action contentId='z' +ccb action |Action contentId='aa' +ccb action |Action contentId='ab' +ccb action |Action contentId='ac' +ccb action |Action contentId='ad' +ccb action |Action contentId='ae' +ccb action |Action contentId='af' +ccb action |Action contentId='ag' +ccb slot 7:0:0.2814157009124756,18:0.4986087679862976,0:0.09111946821212769,13:0.034919969737529755,17:0.022231325507164,24:0.021717887371778488,12:0.02091880328953266,22:0.008943566121160984,21:0.0035167494788765907,4:0.003486328525468707,8:0.00306233623996377,14:0.002975673880428076,25:0.002446186961606145,11:0.0013868052046746016,6:0.0007644615834578872,23:0.0007121390081010759,10:0.0006081887986510992,20:0.00046663329703733325,29:0.00024663639487698674,5:0.00020450385636650026,19:6.42634549876675e-05,15:6.378938996931538e-05,9:4.367829387774691e-05,27:2.616703204694204e-05,1:1.8396624000160955e-05,2:1.1368916602805257e-05,26:9.699509064375889e-06,16:5.553145911108004e-06,28:2.4986536573123885e-06,3:2.245711812065565e-06 7,18,0,13,17,24,12,22,21,4,8,14,25,11,6,23,10,20,29,5,19,15,9,27,1,2,26,16,28,3 | +ccb slot 28:0:0.38746780157089233,12:0.4452095031738281,17:0.11129308491945267,13:0.021261485293507576,2:0.010781776160001755,5:0.006623556837439537,14:0.006618801970034838,0:0.00422333087772131,3:0.0029810182750225067,25:0.0011069459142163396,22:0.0010891001438722014,21:0.00043250489397905767,20:0.00032753165578469634,29:0.0002478094829712063,24:0.0001928548444993794,18:3.78785262000747e-05,1:2.8776834369637072e-05,8:2.2285950763034634e-05,10:1.975510713236872e-05,19:1.1510705917316955e-05,15:1.0838572052307427e-05,4:4.2137894524785224e-06,11:2.6999800866178703e-06,16:1.461011038372817e-06,6:1.387319798595854e-06,27:1.0769275604616269e-06,23:9.243367458111607e-07,9:1.2605019605871348e-07,26:6.900691684741389e-10 28,12,17,13,2,5,14,0,3,25,22,21,20,29,24,18,1,8,10,19,15,4,11,16,6,27,23,9,26 | +ccb slot 21:0:0.009457200765609741,17:0.13830232620239258,20:0.06671705096960068,18:0.060711901634931564,13:0.04832380264997482,12:0.02896728552877903,15:0.027813201770186424,1:0.01238782238215208,5:0.5844042301177979,24:0.007916729897260666,22:0.005313423927873373,3:0.0030796348582953215,8:0.0026463039685040712,2:0.001610504579730332,4:0.0004394367279019207,14:0.0003743874840438366,11:0.0003052169340662658,9:0.0002891358162742108,10:0.0002627648937050253,0:0.0002489395847078413,25:0.00015549163799732924,29:0.00012897276610601693,23:7.323167665163055e-05,6:4.6114979340927675e-05,16:1.574967973283492e-05,27:4.545572664937936e-06,26:4.3639242903736886e-06,19:9.780344356613568e-08 21,17,20,18,13,12,15,1,5,24,22,3,8,2,4,14,11,9,10,0,25,29,23,6,16,27,26,19 | +ccb slot 25:0:0.04852084442973137,4:0.18842479586601257,0:0.11689016968011856,11:0.09394704550504684,15:0.07111208140850067,14:0.06413374841213226,1:0.2144840508699417,17:0.04394223168492317,22:0.04030191898345947,6:0.025946179404854774,20:0.02485107257962227,23:0.02306993119418621,8:0.01103772409260273,18:0.009649633429944515,24:0.0051767947152256966,3:0.004639864899218082,5:0.0036481451243162155,13:0.003012213623151183,27:0.0018589160172268748,2:0.0014349337434396148,12:0.0013534734025597572,16:0.0009505416383035481,26:0.0006534871645271778,19:0.0006136561860330403,29:0.00025192913017235696,9:4.8225138016277924e-05,10:4.6369124902412295e-05 25,4,0,11,15,14,1,17,22,6,20,23,8,18,24,3,5,13,27,2,12,16,26,19,29,9,10 | +ccb slot 22:-1:0.2308475524187088,29:0.7056121826171875,24:0.024420326575636864,13:0.017714975401759148,5:0.006036168430000544,15:0.004493136424571276,17:0.002671352354809642,14:0.0016715804813429713,18:0.001555807190015912,12:0.0014371434226632118,19:0.0010492069413885474,10:0.0006786655867472291,0:0.0003868465428240597,2:0.0003725361602846533,3:0.0003565707884263247,4:0.00022809540678281337,6:0.00016483885701745749,11:0.00013846883666701615,1:3.974044739152305e-05,16:3.603336153901182e-05,8:3.055339402635582e-05,27:1.7396419934812002e-05,20:1.5432331565534696e-05,23:1.4172852388583124e-05,9:1.0329640645068139e-05,26:8.735177061680588e-07 22,29,24,13,5,15,17,14,18,12,19,10,0,2,3,4,6,11,1,16,8,27,20,23,9,26 | +ccb slot 12:0:0.9093628525733948,0:0.04226174205541611,17:0.018896808847784996,13:0.009150444529950619,14:0.005060152616351843,24:0.0038528849836438894,11:0.0037376414984464645,1:0.0027163256891071796,20:0.0019985607359558344,5:0.0015763206174597144,6:0.0005092258215881884,18:0.0003184932575095445,2:0.0002330887655261904,8:0.00017497778753750026,23:4.659605838241987e-05,10:4.272087971912697e-05,29:2.479964132362511e-05,3:8.60487580212066e-06,27:6.880749879201176e-06,16:5.5458899623772595e-06,15:4.874308160651708e-06,9:3.708174290295574e-06,26:3.1738302368466975e-06,19:2.801355094561586e-06,4:8.652203291603655e-07 12,0,17,13,14,24,11,1,20,5,6,18,2,8,23,10,29,3,27,16,15,9,26,19,4 | +ccb slot 5:0:0.28719252347946167,17:0.23356057703495026,29:0.17160779237747192,20:0.1328369677066803,0:0.07422535866498947,23:0.050119683146476746,1:0.021847186610102654,13:0.012630262412130833,2:0.004507853649556637,8:0.0024876149836927652,3:0.0022544104140251875,9:0.0015914351679384708,18:0.0013845351058989763,10:0.001123852445743978,4:0.0008636291604489088,15:0.0004619551182258874,16:0.0004587841685861349,6:0.0004078721103724092,11:0.00021510760416276753,14:0.00014530900807585567,24:3.345161894685589e-05,26:2.9229657229734585e-05,27:1.3438097084872425e-05,19:1.2410271210683277e-06 5,17,29,20,0,23,1,13,2,8,3,9,18,10,4,15,16,6,11,14,24,26,27,19 | +ccb slot 20:0:0.028376348316669464,0:0.1662101000547409,1:0.046988535672426224,17:0.6560971140861511,3:0.02463809959590435,2:0.022943779826164246,8:0.02021339163184166,4:0.011480750516057014,6:0.006486969999969006,27:0.005383896175771952,18:0.005046996288001537,10:0.0017316940939053893,16:0.0012900123838335276,9:0.0011327711399644613,13:0.000544899667147547,23:0.000498446635901928,11:0.00048043631250038743,14:0.00018104366608895361,19:0.00012088919174857438,24:6.899452273501083e-05,29:5.0261642172699794e-05,15:3.0958537536207587e-05,26:3.547019559846376e-06 20,0,1,17,3,2,8,4,6,27,18,10,16,9,13,23,11,14,19,24,29,15,26 | +ccb slot 29:0:0.14386186003684998,4:0.18966703116893768,11:0.2729467749595642,1:0.10815022885799408,17:0.10176316648721695,24:0.048249997198581696,3:0.034515734761953354,2:0.0317024290561676,6:0.029643042013049126,14:0.01419538538902998,10:0.008138692937791348,0:0.00682934420183301,13:0.003933712374418974,9:0.0023981353733688593,8:0.0013491454301401973,23:0.000978886615484953,15:0.0006223535747267306,27:0.00042564209434203804,18:0.0003683240502141416,19:0.00018579690367914736,26:5.759065970778465e-05,16:1.6690515622030944e-05 29,4,11,1,17,24,3,2,6,14,10,0,13,9,8,23,15,27,18,19,26,16 | +ccb slot 16:0:0.8358462452888489,18:0.14180654287338257,8:0.006622648332268,0:0.0048446557484567165,13:0.003651235019788146,17:0.002547427313402295,2:0.0023207622580230236,14:0.0007652127533219755,9:0.0003962431219406426,3:0.00039094872772693634,6:0.0003010313375853002,1:0.00015582605556119233,4:0.0001078778732335195,27:8.33657686598599e-05,11:6.325534195639193e-05,26:5.2438019338296726e-05,24:2.5939621991710737e-05,10:1.2438776138878893e-05,15:5.886784038011683e-06,23:1.0229091351732222e-08,19:7.580823080388654e-09 16,18,8,0,13,17,2,14,9,3,6,1,4,27,11,26,24,10,15,23,19 | +ccb slot 17:0:0.5127939581871033,2:0.4746686518192291,19:0.006616346072405577,1:0.001669980469159782,3:0.0010488936677575111,24:0.0009213163866661489,6:0.0007942019728943706,14:0.0006952053518034518,23:0.0005310842534527183,27:0.00010622834088280797,0:5.494915967574343e-05,8:2.4259865313069895e-05,26:2.4085793484118767e-05,10:1.763248656061478e-05,13:1.2499597687565256e-05,4:9.418990885023959e-06,11:6.62406728224596e-06,15:4.250239726388827e-06,9:3.168663909036695e-07 17,2,19,1,3,24,6,14,23,27,0,8,26,10,13,4,11,15,9 | +ccb slot 2:0:0.3009277284145355,11:0.2405196875333786,0:0.12718695402145386,1:0.10248230397701263,4:0.10110407322645187,27:0.0488734170794487,14:0.031138606369495392,23:0.027005093172192574,15:0.01463699247688055,13:0.0021626290399581194,9:0.001511908951215446,24:0.0014617646811529994,26:0.000723238626960665,8:0.00015963710029609501,10:6.141114135971293e-05,3:4.065587563673034e-05,6:2.3371310362563236e-06,19:1.6080829254860873e-06 2,11,0,1,4,27,14,23,15,13,9,24,26,8,10,3,6,19 | +ccb slot 3:0:0.057185981422662735,27:0.14134737849235535,24:0.11761204898357391,11:0.09371144324541092,0:0.5287960171699524,6:0.03242357447743416,8:0.011251688934862614,14:0.007142344955354929,1:0.0070684002712368965,23:0.0012641034554690123,9:0.0005852219182997942,13:0.0005638026632368565,19:0.0004846873343922198,15:0.0002985078317578882,26:0.00022339983843266964,10:3.1262432457879186e-05,4:1.0203940291830804e-05 3,27,24,11,0,6,8,14,1,23,9,13,19,15,26,10,4 | +ccb slot 1:0:0.1249222531914711,24:0.25583717226982117,14:0.12867441773414612,23:0.26537543535232544,13:0.08874479681253433,15:0.04395920783281326,11:0.04107809066772461,26:0.026856984943151474,0:0.01694386824965477,6:0.005332316737622023,9:0.0009168571559712291,8:0.0005701580666936934,10:0.00035485764965415,4:0.00030593250994570553,27:9.87778403214179e-05,19:2.8945323720108718e-05 1,24,14,23,13,15,11,26,0,6,9,8,10,4,27,19 | +ccb slot 0:0:0.7005100846290588,15:0.15935854613780975,26:0.10210220515727997,6:0.01650484837591648,27:0.012556466273963451,14:0.003727053524926305,10:0.0025744284503161907,23:0.0013873651623725891,11:0.0004973539034835994,13:0.00025413508410565555,24:0.0002187014470109716,8:0.00016339073772542179,19:8.715951116755605e-05,4:4.481669020606205e-05,9:1.3443200259644073e-05 0,15,26,6,27,14,10,23,11,13,24,8,19,4,9 | +ccb slot 8:0:0.9874762892723083,6:0.00831417366862297,24:0.0015088269719853997,19:0.0014199139550328255,13:0.0005207555368542671,4:0.0003495477430988103,23:0.00022233030176721513,11:0.0001492148294346407,10:1.3149796359357424e-05,27:9.886987754725851e-06,15:9.333793059340678e-06,9:2.8232302611286286e-06,14:2.0402019345056033e-06,26:1.7364958466714597e-06 8,6,24,19,13,4,23,11,10,27,15,9,14,26 | +ccb slot 14:0:0.12711836397647858,9:0.23751074075698853,10:0.3793737292289734,27:0.08924281597137451,24:0.07040458917617798,15:0.06249556690454483,11:0.02001812309026718,4:0.009080302901566029,26:0.0017807148396968842,23:0.0015691084554418921,19:0.0009121194598264992,13:0.0004805738863069564,6:1.3344042599783279e-05 14,9,10,27,24,15,11,4,26,23,19,13,6 | +ccb slot 24:0:0.9453869462013245,19:0.029324114322662354,23:0.017052331939339638,13:0.004684425424784422,11:0.0012589030666276813,10:0.0009999927133321762,26:0.0009100232855416834,15:0.00019236477965023369,6:0.00014357283362187445,27:4.12082408729475e-05,9:3.6369422105053673e-06,4:2.426173068670323e-06 24,19,23,13,11,10,26,15,6,27,9,4 | +ccb slot 19:0:0.07647673040628433,13:0.26458504796028137,27:0.08122498542070389,26:0.548874020576477,11:0.025309467688202858,15:0.0020130074117332697,6:0.00121505802962929,23:0.00012354919454082847,9:9.067041537491605e-05,4:4.7242716391338035e-05,10:4.006700328318402e-05 19,13,27,26,11,15,6,23,9,4,10 | \ No newline at end of file diff --git a/test/train-sets/ref/issue4669_test.stderr b/test/train-sets/ref/issue4669_test.stderr new file mode 100644 index 00000000000..dea94504ebf --- /dev/null +++ b/test/train-sets/ref/issue4669_test.stderr @@ -0,0 +1,23 @@ +creating quadratic features for pairs: UA +only testing +using no cache +Reading datafile = train-sets/issue4669.txt +num sources = 1 +Num weight bits = 18 +learning rate = 0.5 +initial_t = 1 +power_t = 0.5 +cb_type = mtr +Enabled learners: gd, generate_interactions, scorer-identity, csoaa_ldf-rank, cb_adf, cb_explore_adf_greedy, cb_sample, shared_feature_merger, ccb_explore_adf +Input label = CCB +Output pred = DECISION_PROBS +average since example example current current current +loss last counter weight label predict features +0.000000 0.000000 1 1.0 7:0,28:0,... 22,4,6,25,2... 2059 + +finished run +number of examples = 1 +weighted example sum = 1.000000 +weighted label sum = 0.000000 +average loss = 0.000000 +total feature number = 2059 diff --git a/test/train-sets/ref/issue4669_test.stdout b/test/train-sets/ref/issue4669_test.stdout new file mode 100644 index 00000000000..c310247cee5 --- /dev/null +++ b/test/train-sets/ref/issue4669_test.stdout @@ -0,0 +1 @@ +[warning] model file has set of {-q, --cubic, --interactions} settings stored, but they'll be OVERRIDDEN by set of {-q, --cubic, --interactions} settings from command line. diff --git a/test/train-sets/ref/issue4669_train.stderr b/test/train-sets/ref/issue4669_train.stderr new file mode 100644 index 00000000000..0b14fdf24f2 --- /dev/null +++ b/test/train-sets/ref/issue4669_train.stderr @@ -0,0 +1,23 @@ +creating quadratic features for pairs: UA +final_regressor = issue4669.model +using no cache +Reading datafile = train-sets/issue4669.txt +num sources = 1 +Num weight bits = 18 +learning rate = 0.5 +initial_t = 0 +power_t = 0.5 +cb_type = mtr +Enabled learners: gd, generate_interactions, scorer-identity, csoaa_ldf-rank, cb_adf, cb_explore_adf_greedy, cb_sample, shared_feature_merger, ccb_explore_adf +Input label = CCB +Output pred = DECISION_PROBS +average since example example current current current +loss last counter weight label predict features +-0.16661 -0.16661 1 1.0 7:0,28:0,... 7,28,21,25,... 3120 + +finished run +number of examples = 1 +weighted example sum = 1.000000 +weighted label sum = 0.000000 +average loss = -0.166610 +total feature number = 3120 diff --git a/test/train-sets/ref/issue4669_train.stdout b/test/train-sets/ref/issue4669_train.stdout new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vowpalwabbit/core/src/decision_scores.cc b/vowpalwabbit/core/src/decision_scores.cc index 4bc0810c7c9..02529c5b42c 100644 --- a/vowpalwabbit/core/src/decision_scores.cc +++ b/vowpalwabbit/core/src/decision_scores.cc @@ -26,7 +26,8 @@ void print_update(VW::workspace& all, const VW::multi_ex& slots, const VW::decis std::string delim; for (const auto& slot : decision_scores) { - pred_ss << delim << slot[0].action; + if (slot.empty()) { pred_ss << delim << "None"; } + else { pred_ss << delim << slot[0].action; } delim = ","; } all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off, diff --git a/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc b/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc index b332793ded8..893b5daac55 100644 --- a/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc +++ b/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc @@ -5,6 +5,7 @@ #include "vw/core/reductions/conditional_contextual_bandit.h" #include "vw/config/options.h" +#include "vw/core/cb.h" #include "vw/core/ccb_label.h" #include "vw/core/ccb_reduction_features.h" #include "vw/core/constant.h" @@ -213,8 +214,12 @@ void clear_pred_and_label(ccb_data& data) data.actions[data.action_with_label]->l.cb.costs.clear(); } -// true if there exists at least 1 action in the cb multi-example -bool has_action(VW::multi_ex& cb_ex) { return !cb_ex.empty(); } +// true if there exists at least 2 examples (since there can only be up to 1 +// shared example), or the 0th example is not shared. +bool has_action(VW::multi_ex& cb_ex) +{ + return cb_ex.size() > 1 || (!cb_ex.empty() && !VW::ec_is_example_header_cb(*cb_ex[0])); +} // This function intentionally does not handle increasing the num_features of the example because // the output_example function has special logic to ensure the number of features is correctly calculated. @@ -547,6 +552,10 @@ void update_stats_ccb(const VW::workspace& /* all */, shared_data& sd, const ccb num_labeled++; if (i == 0 || data.all_slots_loss_report) { + // It is possible for the prediction to be empty if there were no actions available at the time of taking the + // slot decision. In this case it does not contribute to loss. + if (preds[i].empty()) { continue; } + const float l = VW::get_cost_estimate(outcome->probabilities[VW::details::TOP_ACTION_INDEX], outcome->cost, preds[i][VW::details::TOP_ACTION_INDEX].action); loss += l * preds[i][VW::details::TOP_ACTION_INDEX].score * ec_seq[VW::details::SHARED_EX_INDEX]->weight; diff --git a/vowpalwabbit/core/tests/ccb_test.cc b/vowpalwabbit/core/tests/ccb_test.cc index d9ba62525bc..86962f51df0 100644 --- a/vowpalwabbit/core/tests/ccb_test.cc +++ b/vowpalwabbit/core/tests/ccb_test.cc @@ -145,3 +145,59 @@ TEST(Ccb, InsertInteractionsImplTest) EXPECT_THAT(result, testing::ContainerEq(expected_after)); } + +TEST(Ccb, ExplicitIncludedActionsNonExistentAction) +{ + auto vw = VW::initialize(vwtest::make_args("--ccb_explore_adf", "--quiet")); + VW::multi_ex examples; + examples.push_back(VW::read_example(*vw, "ccb shared |")); + examples.push_back(VW::read_example(*vw, "ccb action |")); + examples.push_back(VW::read_example(*vw, "ccb slot 0:10:10 10 |")); + + vw->learn(examples); + + auto& decision_scores = examples[0]->pred.decision_scores; + EXPECT_EQ(decision_scores.size(), 1); + EXPECT_EQ(decision_scores[0].size(), 0); + vw->finish_example(examples); +} + +TEST(Ccb, NoAvailableActions) +{ + auto vw = VW::initialize(vwtest::make_args("--ccb_explore_adf", "--quiet", "--all_slots_loss")); + { + VW::multi_ex examples; + examples.push_back(VW::read_example(*vw, "ccb shared |")); + examples.push_back(VW::read_example(*vw, "ccb action | a")); + examples.push_back(VW::read_example(*vw, "ccb action | b")); + examples.push_back(VW::read_example(*vw, "ccb slot 0:-1:0.5 0,1 |")); + examples.push_back(VW::read_example(*vw, "ccb slot |")); + + vw->learn(examples); + + auto& decision_scores = examples[0]->pred.decision_scores; + EXPECT_EQ(decision_scores.size(), 2); + vw->finish_example(examples); + } + + { + VW::multi_ex examples; + examples.push_back(VW::read_example(*vw, "ccb shared |")); + examples.push_back(VW::read_example(*vw, "ccb action | a")); + examples.push_back(VW::read_example(*vw, "ccb action | b")); + examples.push_back(VW::read_example(*vw, "ccb slot 0:-1:0.5 0,1 |")); + // This time restrict slot 1 to only have action 0 available + examples.push_back(VW::read_example(*vw, "ccb slot 0:-1:0.5 0 |")); + + vw->predict(examples); + + auto& decision_scores = examples[0]->pred.decision_scores; + EXPECT_EQ(decision_scores.size(), 2); + EXPECT_EQ(decision_scores[0].size(), 2); + EXPECT_EQ(decision_scores[0][0].action, 0); + EXPECT_EQ(decision_scores[0][1].action, 1); + EXPECT_EQ(decision_scores[1].size(), 0); + + vw->finish_example(examples); + } +} \ No newline at end of file From d711d2045ef8bfdfe755528af8035916a2d80fa5 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Mon, 27 Nov 2023 19:25:42 -0500 Subject: [PATCH 2/6] simplify test --- test/core.vwtest.json | 11 ++--- test/train-sets/issue4669.dsjson | 1 + test/train-sets/issue4669.txt | 50 --------------------- test/train-sets/ref/issue4669_test.stderr | 8 ++-- test/train-sets/ref/issue4669_test.stdout | 1 - test/train-sets/ref/issue4669_test_pred.txt | 3 ++ test/train-sets/ref/issue4669_train.stderr | 9 ++-- 7 files changed, 18 insertions(+), 65 deletions(-) create mode 100644 test/train-sets/issue4669.dsjson delete mode 100644 test/train-sets/issue4669.txt create mode 100644 test/train-sets/ref/issue4669_test_pred.txt diff --git a/test/core.vwtest.json b/test/core.vwtest.json index 6c5fa549831..f569d3d812d 100644 --- a/test/core.vwtest.json +++ b/test/core.vwtest.json @@ -6077,25 +6077,26 @@ { "id": 469, "desc": "https://github.com/VowpalWabbit/vowpal_wabbit/issues/4669", - "vw_command": "--ccb_explore_adf -q UA --all_slots_loss -f issue4669.model -d train-sets/issue4669.txt", + "vw_command": "--ccb_explore_adf --dsjson -d train-sets/issue4669.dsjson -f issue4669.model", "diff_files": { "stderr": "train-sets/ref/issue4669_train.stderr", "stdout": "train-sets/ref/issue4669_train.stdout" }, "input_files": [ - "train-sets/issue4669.txt" + "train-sets/issue4669.dsjson" ] }, { "id": 470, "desc": "https://github.com/VowpalWabbit/vowpal_wabbit/issues/4669", - "vw_command": "--ccb_explore_adf -q UA --all_slots_loss -i issue4669.model -t -d train-sets/issue4669.txt", + "vw_command": "--ccb_explore_adf --dsjson --all_slots_loss --epsilon 0 -t -i issue4669.model -t -d train-sets/issue4669.dsjson -p issue4669_test_pred.txt", "diff_files": { "stderr": "train-sets/ref/issue4669_test.stderr", - "stdout": "train-sets/ref/issue4669_test.stdout" + "stdout": "train-sets/ref/issue4669_test.stdout", + "issue4669_test_pred.txt": "train-sets/ref/issue4669_test_pred.txt" }, "input_files": [ - "train-sets/issue4669.txt", + "train-sets/issue4669.dsjson", "issue4669.model" ], "depends_on": [ diff --git a/test/train-sets/issue4669.dsjson b/test/train-sets/issue4669.dsjson new file mode 100644 index 00000000000..bbd36e32773 --- /dev/null +++ b/test/train-sets/issue4669.dsjson @@ -0,0 +1 @@ +{"c":{"_multi":[{"f":"1"},{"f":"2"}],"_slots":[{"_inc":[0,1]},{"_inc":[1]}]},"_outcomes":[{"_label_cost":1.0,"_a":[0,1],"_p":[0.5,0.5]},{"_label_cost":0.0,"_a":[1],"_p":[1]}]} \ No newline at end of file diff --git a/test/train-sets/issue4669.txt b/test/train-sets/issue4669.txt deleted file mode 100644 index 2051716f9d2..00000000000 --- a/test/train-sets/issue4669.txt +++ /dev/null @@ -1,50 +0,0 @@ -ccb shared |User userID='aUser' -ccb action |Action contentId='a' -ccb action |Action contentId='b' -ccb action |Action contentId='c' -ccb action |Action contentId='d' -ccb action |Action contentId='e' -ccb action |Action contentId='f' -ccb action |Action contentId='g' -ccb action |Action contentId='h' -ccb action |Action contentId='i' -ccb action |Action contentId='j' -ccb action |Action contentId='k' -ccb action |Action contentId='l' -ccb action |Action contentId='m' -ccb action |Action contentId='n' -ccb action |Action contentId='o' -ccb action |Action contentId='p' -ccb action |Action contentId='q' -ccb action |Action contentId='r' -ccb action |Action contentId='s' -ccb action |Action contentId='t' -ccb action |Action contentId='x' -ccb action |Action contentId='y' -ccb action |Action contentId='z' -ccb action |Action contentId='aa' -ccb action |Action contentId='ab' -ccb action |Action contentId='ac' -ccb action |Action contentId='ad' -ccb action |Action contentId='ae' -ccb action |Action contentId='af' -ccb action |Action contentId='ag' -ccb slot 7:0:0.2814157009124756,18:0.4986087679862976,0:0.09111946821212769,13:0.034919969737529755,17:0.022231325507164,24:0.021717887371778488,12:0.02091880328953266,22:0.008943566121160984,21:0.0035167494788765907,4:0.003486328525468707,8:0.00306233623996377,14:0.002975673880428076,25:0.002446186961606145,11:0.0013868052046746016,6:0.0007644615834578872,23:0.0007121390081010759,10:0.0006081887986510992,20:0.00046663329703733325,29:0.00024663639487698674,5:0.00020450385636650026,19:6.42634549876675e-05,15:6.378938996931538e-05,9:4.367829387774691e-05,27:2.616703204694204e-05,1:1.8396624000160955e-05,2:1.1368916602805257e-05,26:9.699509064375889e-06,16:5.553145911108004e-06,28:2.4986536573123885e-06,3:2.245711812065565e-06 7,18,0,13,17,24,12,22,21,4,8,14,25,11,6,23,10,20,29,5,19,15,9,27,1,2,26,16,28,3 | -ccb slot 28:0:0.38746780157089233,12:0.4452095031738281,17:0.11129308491945267,13:0.021261485293507576,2:0.010781776160001755,5:0.006623556837439537,14:0.006618801970034838,0:0.00422333087772131,3:0.0029810182750225067,25:0.0011069459142163396,22:0.0010891001438722014,21:0.00043250489397905767,20:0.00032753165578469634,29:0.0002478094829712063,24:0.0001928548444993794,18:3.78785262000747e-05,1:2.8776834369637072e-05,8:2.2285950763034634e-05,10:1.975510713236872e-05,19:1.1510705917316955e-05,15:1.0838572052307427e-05,4:4.2137894524785224e-06,11:2.6999800866178703e-06,16:1.461011038372817e-06,6:1.387319798595854e-06,27:1.0769275604616269e-06,23:9.243367458111607e-07,9:1.2605019605871348e-07,26:6.900691684741389e-10 28,12,17,13,2,5,14,0,3,25,22,21,20,29,24,18,1,8,10,19,15,4,11,16,6,27,23,9,26 | -ccb slot 21:0:0.009457200765609741,17:0.13830232620239258,20:0.06671705096960068,18:0.060711901634931564,13:0.04832380264997482,12:0.02896728552877903,15:0.027813201770186424,1:0.01238782238215208,5:0.5844042301177979,24:0.007916729897260666,22:0.005313423927873373,3:0.0030796348582953215,8:0.0026463039685040712,2:0.001610504579730332,4:0.0004394367279019207,14:0.0003743874840438366,11:0.0003052169340662658,9:0.0002891358162742108,10:0.0002627648937050253,0:0.0002489395847078413,25:0.00015549163799732924,29:0.00012897276610601693,23:7.323167665163055e-05,6:4.6114979340927675e-05,16:1.574967973283492e-05,27:4.545572664937936e-06,26:4.3639242903736886e-06,19:9.780344356613568e-08 21,17,20,18,13,12,15,1,5,24,22,3,8,2,4,14,11,9,10,0,25,29,23,6,16,27,26,19 | -ccb slot 25:0:0.04852084442973137,4:0.18842479586601257,0:0.11689016968011856,11:0.09394704550504684,15:0.07111208140850067,14:0.06413374841213226,1:0.2144840508699417,17:0.04394223168492317,22:0.04030191898345947,6:0.025946179404854774,20:0.02485107257962227,23:0.02306993119418621,8:0.01103772409260273,18:0.009649633429944515,24:0.0051767947152256966,3:0.004639864899218082,5:0.0036481451243162155,13:0.003012213623151183,27:0.0018589160172268748,2:0.0014349337434396148,12:0.0013534734025597572,16:0.0009505416383035481,26:0.0006534871645271778,19:0.0006136561860330403,29:0.00025192913017235696,9:4.8225138016277924e-05,10:4.6369124902412295e-05 25,4,0,11,15,14,1,17,22,6,20,23,8,18,24,3,5,13,27,2,12,16,26,19,29,9,10 | -ccb slot 22:-1:0.2308475524187088,29:0.7056121826171875,24:0.024420326575636864,13:0.017714975401759148,5:0.006036168430000544,15:0.004493136424571276,17:0.002671352354809642,14:0.0016715804813429713,18:0.001555807190015912,12:0.0014371434226632118,19:0.0010492069413885474,10:0.0006786655867472291,0:0.0003868465428240597,2:0.0003725361602846533,3:0.0003565707884263247,4:0.00022809540678281337,6:0.00016483885701745749,11:0.00013846883666701615,1:3.974044739152305e-05,16:3.603336153901182e-05,8:3.055339402635582e-05,27:1.7396419934812002e-05,20:1.5432331565534696e-05,23:1.4172852388583124e-05,9:1.0329640645068139e-05,26:8.735177061680588e-07 22,29,24,13,5,15,17,14,18,12,19,10,0,2,3,4,6,11,1,16,8,27,20,23,9,26 | -ccb slot 12:0:0.9093628525733948,0:0.04226174205541611,17:0.018896808847784996,13:0.009150444529950619,14:0.005060152616351843,24:0.0038528849836438894,11:0.0037376414984464645,1:0.0027163256891071796,20:0.0019985607359558344,5:0.0015763206174597144,6:0.0005092258215881884,18:0.0003184932575095445,2:0.0002330887655261904,8:0.00017497778753750026,23:4.659605838241987e-05,10:4.272087971912697e-05,29:2.479964132362511e-05,3:8.60487580212066e-06,27:6.880749879201176e-06,16:5.5458899623772595e-06,15:4.874308160651708e-06,9:3.708174290295574e-06,26:3.1738302368466975e-06,19:2.801355094561586e-06,4:8.652203291603655e-07 12,0,17,13,14,24,11,1,20,5,6,18,2,8,23,10,29,3,27,16,15,9,26,19,4 | -ccb slot 5:0:0.28719252347946167,17:0.23356057703495026,29:0.17160779237747192,20:0.1328369677066803,0:0.07422535866498947,23:0.050119683146476746,1:0.021847186610102654,13:0.012630262412130833,2:0.004507853649556637,8:0.0024876149836927652,3:0.0022544104140251875,9:0.0015914351679384708,18:0.0013845351058989763,10:0.001123852445743978,4:0.0008636291604489088,15:0.0004619551182258874,16:0.0004587841685861349,6:0.0004078721103724092,11:0.00021510760416276753,14:0.00014530900807585567,24:3.345161894685589e-05,26:2.9229657229734585e-05,27:1.3438097084872425e-05,19:1.2410271210683277e-06 5,17,29,20,0,23,1,13,2,8,3,9,18,10,4,15,16,6,11,14,24,26,27,19 | -ccb slot 20:0:0.028376348316669464,0:0.1662101000547409,1:0.046988535672426224,17:0.6560971140861511,3:0.02463809959590435,2:0.022943779826164246,8:0.02021339163184166,4:0.011480750516057014,6:0.006486969999969006,27:0.005383896175771952,18:0.005046996288001537,10:0.0017316940939053893,16:0.0012900123838335276,9:0.0011327711399644613,13:0.000544899667147547,23:0.000498446635901928,11:0.00048043631250038743,14:0.00018104366608895361,19:0.00012088919174857438,24:6.899452273501083e-05,29:5.0261642172699794e-05,15:3.0958537536207587e-05,26:3.547019559846376e-06 20,0,1,17,3,2,8,4,6,27,18,10,16,9,13,23,11,14,19,24,29,15,26 | -ccb slot 29:0:0.14386186003684998,4:0.18966703116893768,11:0.2729467749595642,1:0.10815022885799408,17:0.10176316648721695,24:0.048249997198581696,3:0.034515734761953354,2:0.0317024290561676,6:0.029643042013049126,14:0.01419538538902998,10:0.008138692937791348,0:0.00682934420183301,13:0.003933712374418974,9:0.0023981353733688593,8:0.0013491454301401973,23:0.000978886615484953,15:0.0006223535747267306,27:0.00042564209434203804,18:0.0003683240502141416,19:0.00018579690367914736,26:5.759065970778465e-05,16:1.6690515622030944e-05 29,4,11,1,17,24,3,2,6,14,10,0,13,9,8,23,15,27,18,19,26,16 | -ccb slot 16:0:0.8358462452888489,18:0.14180654287338257,8:0.006622648332268,0:0.0048446557484567165,13:0.003651235019788146,17:0.002547427313402295,2:0.0023207622580230236,14:0.0007652127533219755,9:0.0003962431219406426,3:0.00039094872772693634,6:0.0003010313375853002,1:0.00015582605556119233,4:0.0001078778732335195,27:8.33657686598599e-05,11:6.325534195639193e-05,26:5.2438019338296726e-05,24:2.5939621991710737e-05,10:1.2438776138878893e-05,15:5.886784038011683e-06,23:1.0229091351732222e-08,19:7.580823080388654e-09 16,18,8,0,13,17,2,14,9,3,6,1,4,27,11,26,24,10,15,23,19 | -ccb slot 17:0:0.5127939581871033,2:0.4746686518192291,19:0.006616346072405577,1:0.001669980469159782,3:0.0010488936677575111,24:0.0009213163866661489,6:0.0007942019728943706,14:0.0006952053518034518,23:0.0005310842534527183,27:0.00010622834088280797,0:5.494915967574343e-05,8:2.4259865313069895e-05,26:2.4085793484118767e-05,10:1.763248656061478e-05,13:1.2499597687565256e-05,4:9.418990885023959e-06,11:6.62406728224596e-06,15:4.250239726388827e-06,9:3.168663909036695e-07 17,2,19,1,3,24,6,14,23,27,0,8,26,10,13,4,11,15,9 | -ccb slot 2:0:0.3009277284145355,11:0.2405196875333786,0:0.12718695402145386,1:0.10248230397701263,4:0.10110407322645187,27:0.0488734170794487,14:0.031138606369495392,23:0.027005093172192574,15:0.01463699247688055,13:0.0021626290399581194,9:0.001511908951215446,24:0.0014617646811529994,26:0.000723238626960665,8:0.00015963710029609501,10:6.141114135971293e-05,3:4.065587563673034e-05,6:2.3371310362563236e-06,19:1.6080829254860873e-06 2,11,0,1,4,27,14,23,15,13,9,24,26,8,10,3,6,19 | -ccb slot 3:0:0.057185981422662735,27:0.14134737849235535,24:0.11761204898357391,11:0.09371144324541092,0:0.5287960171699524,6:0.03242357447743416,8:0.011251688934862614,14:0.007142344955354929,1:0.0070684002712368965,23:0.0012641034554690123,9:0.0005852219182997942,13:0.0005638026632368565,19:0.0004846873343922198,15:0.0002985078317578882,26:0.00022339983843266964,10:3.1262432457879186e-05,4:1.0203940291830804e-05 3,27,24,11,0,6,8,14,1,23,9,13,19,15,26,10,4 | -ccb slot 1:0:0.1249222531914711,24:0.25583717226982117,14:0.12867441773414612,23:0.26537543535232544,13:0.08874479681253433,15:0.04395920783281326,11:0.04107809066772461,26:0.026856984943151474,0:0.01694386824965477,6:0.005332316737622023,9:0.0009168571559712291,8:0.0005701580666936934,10:0.00035485764965415,4:0.00030593250994570553,27:9.87778403214179e-05,19:2.8945323720108718e-05 1,24,14,23,13,15,11,26,0,6,9,8,10,4,27,19 | -ccb slot 0:0:0.7005100846290588,15:0.15935854613780975,26:0.10210220515727997,6:0.01650484837591648,27:0.012556466273963451,14:0.003727053524926305,10:0.0025744284503161907,23:0.0013873651623725891,11:0.0004973539034835994,13:0.00025413508410565555,24:0.0002187014470109716,8:0.00016339073772542179,19:8.715951116755605e-05,4:4.481669020606205e-05,9:1.3443200259644073e-05 0,15,26,6,27,14,10,23,11,13,24,8,19,4,9 | -ccb slot 8:0:0.9874762892723083,6:0.00831417366862297,24:0.0015088269719853997,19:0.0014199139550328255,13:0.0005207555368542671,4:0.0003495477430988103,23:0.00022233030176721513,11:0.0001492148294346407,10:1.3149796359357424e-05,27:9.886987754725851e-06,15:9.333793059340678e-06,9:2.8232302611286286e-06,14:2.0402019345056033e-06,26:1.7364958466714597e-06 8,6,24,19,13,4,23,11,10,27,15,9,14,26 | -ccb slot 14:0:0.12711836397647858,9:0.23751074075698853,10:0.3793737292289734,27:0.08924281597137451,24:0.07040458917617798,15:0.06249556690454483,11:0.02001812309026718,4:0.009080302901566029,26:0.0017807148396968842,23:0.0015691084554418921,19:0.0009121194598264992,13:0.0004805738863069564,6:1.3344042599783279e-05 14,9,10,27,24,15,11,4,26,23,19,13,6 | -ccb slot 24:0:0.9453869462013245,19:0.029324114322662354,23:0.017052331939339638,13:0.004684425424784422,11:0.0012589030666276813,10:0.0009999927133321762,26:0.0009100232855416834,15:0.00019236477965023369,6:0.00014357283362187445,27:4.12082408729475e-05,9:3.6369422105053673e-06,4:2.426173068670323e-06 24,19,23,13,11,10,26,15,6,27,9,4 | -ccb slot 19:0:0.07647673040628433,13:0.26458504796028137,27:0.08122498542070389,26:0.548874020576477,11:0.025309467688202858,15:0.0020130074117332697,6:0.00121505802962929,23:0.00012354919454082847,9:9.067041537491605e-05,4:4.7242716391338035e-05,10:4.006700328318402e-05 19,13,27,26,11,15,6,23,9,4,10 | \ No newline at end of file diff --git a/test/train-sets/ref/issue4669_test.stderr b/test/train-sets/ref/issue4669_test.stderr index dea94504ebf..9b3fb9ce7cf 100644 --- a/test/train-sets/ref/issue4669_test.stderr +++ b/test/train-sets/ref/issue4669_test.stderr @@ -1,7 +1,7 @@ -creating quadratic features for pairs: UA only testing +predictions = issue4669_test_pred.txt using no cache -Reading datafile = train-sets/issue4669.txt +Reading datafile = train-sets/issue4669.dsjson num sources = 1 Num weight bits = 18 learning rate = 0.5 @@ -13,11 +13,11 @@ Input label = CCB Output pred = DECISION_PROBS average since example example current current current loss last counter weight label predict features -0.000000 0.000000 1 1.0 7:0,28:0,... 22,4,6,25,2... 2059 +0.000000 0.000000 1 1.0 0:1,1:0 1,None 9 finished run number of examples = 1 weighted example sum = 1.000000 weighted label sum = 0.000000 average loss = 0.000000 -total feature number = 2059 +total feature number = 9 diff --git a/test/train-sets/ref/issue4669_test.stdout b/test/train-sets/ref/issue4669_test.stdout index c310247cee5..e69de29bb2d 100644 --- a/test/train-sets/ref/issue4669_test.stdout +++ b/test/train-sets/ref/issue4669_test.stdout @@ -1 +0,0 @@ -[warning] model file has set of {-q, --cubic, --interactions} settings stored, but they'll be OVERRIDDEN by set of {-q, --cubic, --interactions} settings from command line. diff --git a/test/train-sets/ref/issue4669_test_pred.txt b/test/train-sets/ref/issue4669_test_pred.txt new file mode 100644 index 00000000000..ba6b9ca942b --- /dev/null +++ b/test/train-sets/ref/issue4669_test_pred.txt @@ -0,0 +1,3 @@ +1:1,0:0 + + diff --git a/test/train-sets/ref/issue4669_train.stderr b/test/train-sets/ref/issue4669_train.stderr index 0b14fdf24f2..48505ae87ae 100644 --- a/test/train-sets/ref/issue4669_train.stderr +++ b/test/train-sets/ref/issue4669_train.stderr @@ -1,7 +1,6 @@ -creating quadratic features for pairs: UA final_regressor = issue4669.model using no cache -Reading datafile = train-sets/issue4669.txt +Reading datafile = train-sets/issue4669.dsjson num sources = 1 Num weight bits = 18 learning rate = 0.5 @@ -13,11 +12,11 @@ Input label = CCB Output pred = DECISION_PROBS average since example example current current current loss last counter weight label predict features --0.16661 -0.16661 1 1.0 7:0,28:0,... 7,28,21,25,... 3120 +1.000000 1.000000 1 1.0 0:1,1:0 0,1 12 finished run number of examples = 1 weighted example sum = 1.000000 weighted label sum = 0.000000 -average loss = -0.166610 -total feature number = 3120 +average loss = 1.000000 +total feature number = 12 From acb36e3447efd666b2adaa9bc85cbfd19cac3148 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Fri, 8 Dec 2023 16:01:33 -0500 Subject: [PATCH 3/6] ensure empty predictions do not affect num_labeled as well as loss --- .../core/src/reductions/conditional_contextual_bandit.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc b/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc index 893b5daac55..ee3db743b11 100644 --- a/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc +++ b/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc @@ -549,13 +549,13 @@ void update_stats_ccb(const VW::workspace& /* all */, shared_data& sd, const ccb auto* outcome = data.slots[i]->l.conditional_contextual_bandit.outcome; if (outcome != nullptr) { + // It is possible for the prediction to be empty if there were no actions available at the time of taking the + // slot decision. In this case it does not contribute to loss. + if (preds[i].empty()) { continue; } + num_labeled++; if (i == 0 || data.all_slots_loss_report) { - // It is possible for the prediction to be empty if there were no actions available at the time of taking the - // slot decision. In this case it does not contribute to loss. - if (preds[i].empty()) { continue; } - const float l = VW::get_cost_estimate(outcome->probabilities[VW::details::TOP_ACTION_INDEX], outcome->cost, preds[i][VW::details::TOP_ACTION_INDEX].action); loss += l * preds[i][VW::details::TOP_ACTION_INDEX].score * ec_seq[VW::details::SHARED_EX_INDEX]->weight; From 89254ad947b55fce736c8a347f82d938b94b43c4 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Thu, 1 Feb 2024 11:42:15 -0500 Subject: [PATCH 4/6] Update conditional_contextual_bandit.cc --- .../core/src/reductions/conditional_contextual_bandit.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc b/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc index ee3db743b11..75762e8ce30 100644 --- a/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc +++ b/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc @@ -549,11 +549,10 @@ void update_stats_ccb(const VW::workspace& /* all */, shared_data& sd, const ccb auto* outcome = data.slots[i]->l.conditional_contextual_bandit.outcome; if (outcome != nullptr) { + num_labeled++; // It is possible for the prediction to be empty if there were no actions available at the time of taking the // slot decision. In this case it does not contribute to loss. if (preds[i].empty()) { continue; } - - num_labeled++; if (i == 0 || data.all_slots_loss_report) { const float l = VW::get_cost_estimate(outcome->probabilities[VW::details::TOP_ACTION_INDEX], outcome->cost, From 0d8798ec007871adc5bcd674d1f03eb0b19df435 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Tue, 27 Feb 2024 10:43:18 -0500 Subject: [PATCH 5/6] Bounds check for explicit inclusion --- .../core/src/reductions/conditional_contextual_bandit.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc b/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc index 75762e8ce30..d86eb1a272a 100644 --- a/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc +++ b/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc @@ -314,7 +314,14 @@ void build_cb_example(VW::multi_ex& cb_ex, VW::example* slot, const VW::ccb_labe // First time seeing this, initialize the vector with falses so we can start setting each included action. if (data.include_list.empty()) { data.include_list.assign(data.actions.size(), false); } - for (uint32_t included_action_id : explicit_includes) { data.include_list[included_action_id] = true; } + for (uint32_t included_action_id : explicit_includes) + { + // The action may be included but not actually exist in the list of possible actions. + if (included_action_id < data.actions.size()) + { + data.include_list[included_action_id] = true; + } + } } // set the available actions in the cb multi-example From 56b57575877953c027565c78b26a0f4ff03e0da0 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Tue, 27 Feb 2024 10:45:07 -0500 Subject: [PATCH 6/6] Formatting --- .../core/src/reductions/conditional_contextual_bandit.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc b/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc index d86eb1a272a..9d15dd4a1ce 100644 --- a/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc +++ b/vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc @@ -317,10 +317,7 @@ void build_cb_example(VW::multi_ex& cb_ex, VW::example* slot, const VW::ccb_labe for (uint32_t included_action_id : explicit_includes) { // The action may be included but not actually exist in the list of possible actions. - if (included_action_id < data.actions.size()) - { - data.include_list[included_action_id] = true; - } + if (included_action_id < data.actions.size()) { data.include_list[included_action_id] = true; } } }