forked from VowpalWabbit/vowpal_wabbit
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add example C# Simulator (VowpalWabbit#1790)
* add simulator * Appveyor uses VS2017 * Bump msbuild version * Update msbuild detection * Fix image * try changing call * Add echo * Fix quoting of strings * Add back echo * Add delayed expansion * remove extra quotes * move back to hard coded path * Revert solution upgrade, remove C# explore code * Add restore * Add sampling, fix small bugs
- Loading branch information
1 parent
a6fc722
commit ae3d6a4
Showing
8 changed files
with
444 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
<?xml version="1.0" encoding="utf-8" ?> | ||
<configuration> | ||
<startup> | ||
<supportedRuntime version="v4.0" sku=".NETFramework,Version=v4.5.2" /> | ||
</startup> | ||
</configuration> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
using System; | ||
|
||
namespace simulator | ||
{ | ||
class Program | ||
{ | ||
private static readonly string help_string = "usage: simulator initial_random tot_iter mod_iter reward_seed vw_seed exp_iter num_contexts num_actions ml_args_snips"; | ||
|
||
static void Main(string[] args) | ||
{ | ||
string ml_args = args[0] + " --quiet"; | ||
|
||
int initial_random; | ||
int tot_iter; | ||
int mod_iter; | ||
int reward_seed; | ||
ulong vw_seed; | ||
int exp_iter; | ||
int num_contexts; | ||
int num_actions; | ||
|
||
if (!int.TryParse(args[1], out initial_random)) | ||
{ | ||
Console.WriteLine(help_string); | ||
return; | ||
} | ||
|
||
if (!int.TryParse(args[2], out tot_iter)) | ||
{ | ||
Console.WriteLine(help_string); | ||
return; | ||
} | ||
|
||
if (!int.TryParse(args[3], out mod_iter)) | ||
{ | ||
Console.WriteLine(help_string); | ||
return; | ||
} | ||
|
||
if (!int.TryParse(args[4], out reward_seed)) | ||
{ | ||
Console.WriteLine(help_string); | ||
return; | ||
} | ||
|
||
if (!ulong.TryParse(args[5], out vw_seed)) | ||
{ | ||
Console.WriteLine(help_string); | ||
return; | ||
} | ||
|
||
if (!int.TryParse(args[6], out exp_iter)) | ||
{ | ||
Console.WriteLine(help_string); | ||
return; | ||
} | ||
|
||
if (!int.TryParse(args[7], out num_contexts)) | ||
{ | ||
Console.WriteLine(help_string); | ||
return; | ||
} | ||
|
||
if (!int.TryParse(args[8], out num_actions)) | ||
{ | ||
Console.WriteLine(help_string); | ||
return; | ||
} | ||
|
||
string ml_args_snips = "--cb_explore_adf --epsilon .05 --cb_type mtr -l 1e-8 --power_t 0 --quiet"; | ||
if (args.Length > 9) | ||
ml_args_snips = args[9] + " --quiet"; | ||
|
||
VowpalWabbitSimulator.Run(ml_args, initial_random, tot_iter, mod_iter, reward_seed, vw_seed, exp_iter, num_contexts, num_actions, ml_args_snips); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
using System.Reflection; | ||
using System.Runtime.CompilerServices; | ||
using System.Runtime.InteropServices; | ||
|
||
// General Information about an assembly is controlled through the following | ||
// set of attributes. Change these attribute values to modify the information | ||
// associated with an assembly. | ||
[assembly: AssemblyTitle("simulator")] | ||
[assembly: AssemblyDescription("")] | ||
[assembly: AssemblyConfiguration("")] | ||
[assembly: AssemblyCompany("")] | ||
[assembly: AssemblyProduct("simulator")] | ||
[assembly: AssemblyCopyright("Copyright © 2019")] | ||
[assembly: AssemblyTrademark("")] | ||
[assembly: AssemblyCulture("")] | ||
|
||
// Setting ComVisible to false makes the types in this assembly not visible | ||
// to COM components. If you need to access a type in this assembly from | ||
// COM, set the ComVisible attribute to true on that type. | ||
[assembly: ComVisible(false)] | ||
|
||
// The following GUID is for the ID of the typelib if this project is exposed to COM | ||
[assembly: Guid("2a9c6717-3b6c-4db7-a626-16c04fcfeccf")] | ||
|
||
// Version information for an assembly consists of the following four values: | ||
// | ||
// Major Version | ||
// Minor Version | ||
// Build Number | ||
// Revision | ||
// | ||
// You can specify all the values or you can default the Build and Revision Numbers | ||
// by using the '*' as shown below: | ||
// [assembly: AssemblyVersion("1.0.*")] | ||
[assembly: AssemblyVersion("1.0.0.0")] | ||
[assembly: AssemblyFileVersion("1.0.0.0")] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
//using Microsoft.Research.MultiWorldTesting.ExploreLibrary; | ||
using Newtonsoft.Json; | ||
using System; | ||
using System.IO; | ||
using System.Linq; | ||
using System.Text; | ||
using VW; | ||
using VW.Labels; | ||
|
||
namespace simulator | ||
{ | ||
public static class VowpalWabbitSimulator | ||
{ | ||
public class SimulatorExample | ||
{ | ||
private readonly int length; | ||
|
||
private readonly byte[] exampleBuffer; | ||
|
||
public float[] PDF { get; } | ||
|
||
public SimulatorExample(int numActions, int sharedContext) | ||
{ | ||
// generate distinct per user context with 2 seperate prefered actions | ||
this.PDF = Enumerable.Range(0, numActions).Select(_ => 0.005f).ToArray(); | ||
this.PDF[sharedContext] = 0.03f; | ||
|
||
this.exampleBuffer = new byte[32 * 1024]; | ||
|
||
var str = JsonConvert.SerializeObject( | ||
new | ||
{ | ||
Version = "1", | ||
EventId = "1", // can be ignored | ||
a = Enumerable.Range(1, numActions).ToArray(), | ||
c = new | ||
{ | ||
// shared user context | ||
U = new { C = sharedContext.ToString() }, | ||
_multi = Enumerable | ||
.Range(0, numActions) | ||
.Select(i => new { A = new { Constant = 1, Id = i.ToString() }, B = new { Id = i.ToString() } }) | ||
.ToArray() | ||
}, | ||
p = Enumerable.Range(0, numActions).Select(i => 0.0f).ToArray() | ||
}); | ||
|
||
Console.WriteLine(str); | ||
|
||
// allow for \0 at the end | ||
this.length = Encoding.UTF8.GetBytes(str, 0, str.Length, exampleBuffer, 0); | ||
exampleBuffer[this.length] = 0; | ||
this.length++; | ||
} | ||
|
||
public VowpalWabbitMultiLineExampleCollection CreateExample(VowpalWabbit vw) | ||
{ | ||
VowpalWabbitDecisionServiceInteractionHeader header; | ||
var examples = vw.ParseDecisionServiceJson(this.exampleBuffer, 0, this.length, true, out header); | ||
|
||
var adf = new VowpalWabbitExample[examples.Count - 1]; | ||
examples.CopyTo(1, adf, 0, examples.Count - 1); | ||
|
||
return new VowpalWabbitMultiLineExampleCollection(vw, examples[0], adf); | ||
} | ||
} | ||
|
||
private static void ExportScoringModel(VowpalWabbit learner, ref VowpalWabbit scorer) | ||
{ | ||
scorer?.Dispose(); | ||
using (var memStream = new MemoryStream()) | ||
{ | ||
learner.SaveModel(memStream); | ||
|
||
memStream.Seek(0, SeekOrigin.Begin); | ||
|
||
// Note: the learner doesn't use save-resume as done online | ||
scorer = new VowpalWabbit(new VowpalWabbitSettings { Arguments = "--quiet", ModelStream = memStream }); | ||
} | ||
} | ||
|
||
public static void Run(string ml_args, int initial_random, int tot_iter, int mod_iter, int rewardSeed, ulong vwSeed, int exp_iter, int numContexts, int numActions, string ml_args_snips) | ||
{ | ||
// byte buffer outside so one can change the example and keep the memory around | ||
var exampleBuffer = new byte[32 * 1024]; | ||
|
||
var randGen = new Random(rewardSeed); | ||
var userGen = new Random(); | ||
|
||
var simExamples = Enumerable.Range(0, numContexts) | ||
.Select(i => new SimulatorExample(numActions, i)) | ||
.ToArray(); | ||
|
||
var scorerPdf = new float[numActions]; | ||
var histPred = new int[numActions, numContexts]; | ||
var histPred2 = new int[numActions, numContexts]; | ||
var histActions = new int[numActions, numContexts]; | ||
var histCost = new int[numActions, numContexts]; | ||
var histContext = new int[numContexts]; | ||
int clicks = 0; | ||
double snips_num = 0, snips_den = 0; | ||
|
||
using (var learner = new VowpalWabbit(ml_args)) | ||
using (var learner2 = new VowpalWabbit(ml_args_snips)) | ||
{ | ||
VowpalWabbit scorer = null; | ||
|
||
scorer = new VowpalWabbit("--cb_explore_adf --epsilon 1 --quiet"); | ||
for (int i = 1; i <= tot_iter; i++) | ||
{ | ||
// sample uniform among users | ||
int userIndex = userGen.Next(simExamples.Length); | ||
var simExample = simExamples[userIndex]; | ||
var pdf = simExample.PDF; | ||
|
||
histContext[userIndex]++; | ||
|
||
using (var ex = simExample.CreateExample(learner)) | ||
{ | ||
var scores = ex.Predict(VowpalWabbitPredictionType.ActionProbabilities, scorer); | ||
|
||
var total = 0.0; | ||
|
||
foreach (var actionScore in scores) | ||
{ | ||
total += actionScore.Score; | ||
scorerPdf[actionScore.Action] = actionScore.Score; | ||
} | ||
|
||
var draw = randGen.NextDouble() * total; | ||
var sum = 0.0; | ||
uint topAction = 0; | ||
foreach (var actionScore in scores) | ||
{ | ||
sum += actionScore.Score; | ||
if(sum > draw) | ||
{ | ||
topAction = actionScore.Action; | ||
break; | ||
} | ||
} | ||
|
||
int modelAction = (int)scores[0].Action; | ||
if (i > initial_random) | ||
histPred[modelAction, userIndex] += 1; | ||
histActions[topAction, userIndex] += 1; | ||
|
||
// simulate behavior | ||
float cost = 0; | ||
if (randGen.NextDouble() < pdf[topAction]) | ||
{ | ||
cost = -1; | ||
histCost[topAction, userIndex] += 1; | ||
clicks += 1; | ||
} | ||
|
||
ex.Examples[topAction].Label = new ContextualBanditLabel((uint)topAction, cost, scorerPdf[topAction]); | ||
|
||
// simulate delay | ||
if (i >= initial_random && (i % exp_iter == 0)) | ||
{ | ||
ExportScoringModel(learner, ref scorer); | ||
} | ||
|
||
// invoke learning | ||
var oneStepAheadScores = ex.Learn(VowpalWabbitPredictionType.ActionProbabilities, learner); | ||
histPred2[oneStepAheadScores[0].Action, userIndex] += 1; | ||
|
||
var oneStepAheadScores2 = ex.Learn(VowpalWabbitPredictionType.ActionProbabilities, learner2); | ||
|
||
// SNIPS | ||
snips_num -= oneStepAheadScores2.First(f => f.Action == topAction).Score * cost / scorerPdf[topAction]; | ||
snips_den += oneStepAheadScores2.First(f => f.Action == topAction).Score / scorerPdf[topAction]; | ||
|
||
if (i % mod_iter == 0 || i == tot_iter) | ||
{ | ||
Console.WriteLine(JsonConvert.SerializeObject(new | ||
{ | ||
Iter = i, | ||
clicks, | ||
CTR = clicks / (float)i, | ||
aveLoss = learner.PerformanceStatistics.AverageLoss, | ||
CTR_snips = snips_num / snips_den, | ||
CTR_ips = snips_num / (float)i, | ||
aveLoss2 = learner2.PerformanceStatistics.AverageLoss, | ||
snips_num, | ||
snips_den, | ||
histActions, | ||
histPred, | ||
histCost, | ||
histContext, | ||
})); | ||
} | ||
} | ||
} | ||
Console.WriteLine("---------------------"); | ||
scorer?.Dispose(); | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
<?xml version="1.0" encoding="utf-8"?> | ||
<packages> | ||
<package id="Newtonsoft.Json" version="12.0.1" targetFramework="net452" /> | ||
</packages> |
Oops, something went wrong.