Skip to content

Commit

Permalink
Add example C# Simulator (VowpalWabbit#1790)
Browse files Browse the repository at this point in the history
* add simulator

* Appveyor uses VS2017

* Bump msbuild version

* Update msbuild detection

* Fix image

* try changing call

* Add echo

* Fix quoting of strings

* Add back echo

* Add delayed expansion

* remove extra quotes

* move back to hard coded path

* Revert solution upgrade, remove C# explore code

* Add restore

* Add sampling, fix small bugs
  • Loading branch information
jackgerrits committed May 15, 2019
1 parent a6fc722 commit ae3d6a4
Show file tree
Hide file tree
Showing 8 changed files with 444 additions and 1 deletion.
4 changes: 4 additions & 0 deletions .scripts/restore.cmd
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ ECHO Restoring "%vwRoot%\vowpalwabbit\packages.config"
"%nugetPath%" restore -o "%vwRoot%\vowpalwabbit\packages" "%vwRoot%\vowpalwabbit\packages.config"
ECHO.

ECHO Restoring "%vwRoot%\cs\examples\simulator\packages.config"
"%nugetPath%" restore -o "%vwRoot%\vowpalwabbit\packages" "%vwRoot%\cs\examples\simulator\packages.config"
ECHO.

POPD

ENDLOCAL
6 changes: 6 additions & 0 deletions cs/examples/simulator/App.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8" ?>
<configuration>
<startup>
<supportedRuntime version="v4.0" sku=".NETFramework,Version=v4.5.2" />
</startup>
</configuration>
77 changes: 77 additions & 0 deletions cs/examples/simulator/Program.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
using System;

namespace simulator
{
class Program
{
private static readonly string help_string = "usage: simulator initial_random tot_iter mod_iter reward_seed vw_seed exp_iter num_contexts num_actions ml_args_snips";

static void Main(string[] args)
{
string ml_args = args[0] + " --quiet";

int initial_random;
int tot_iter;
int mod_iter;
int reward_seed;
ulong vw_seed;
int exp_iter;
int num_contexts;
int num_actions;

if (!int.TryParse(args[1], out initial_random))
{
Console.WriteLine(help_string);
return;
}

if (!int.TryParse(args[2], out tot_iter))
{
Console.WriteLine(help_string);
return;
}

if (!int.TryParse(args[3], out mod_iter))
{
Console.WriteLine(help_string);
return;
}

if (!int.TryParse(args[4], out reward_seed))
{
Console.WriteLine(help_string);
return;
}

if (!ulong.TryParse(args[5], out vw_seed))
{
Console.WriteLine(help_string);
return;
}

if (!int.TryParse(args[6], out exp_iter))
{
Console.WriteLine(help_string);
return;
}

if (!int.TryParse(args[7], out num_contexts))
{
Console.WriteLine(help_string);
return;
}

if (!int.TryParse(args[8], out num_actions))
{
Console.WriteLine(help_string);
return;
}

string ml_args_snips = "--cb_explore_adf --epsilon .05 --cb_type mtr -l 1e-8 --power_t 0 --quiet";
if (args.Length > 9)
ml_args_snips = args[9] + " --quiet";

VowpalWabbitSimulator.Run(ml_args, initial_random, tot_iter, mod_iter, reward_seed, vw_seed, exp_iter, num_contexts, num_actions, ml_args_snips);
}
}
}
36 changes: 36 additions & 0 deletions cs/examples/simulator/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

// General Information about an assembly is controlled through the following
// set of attributes. Change these attribute values to modify the information
// associated with an assembly.
[assembly: AssemblyTitle("simulator")]
[assembly: AssemblyDescription("")]
[assembly: AssemblyConfiguration("")]
[assembly: AssemblyCompany("")]
[assembly: AssemblyProduct("simulator")]
[assembly: AssemblyCopyright("Copyright © 2019")]
[assembly: AssemblyTrademark("")]
[assembly: AssemblyCulture("")]

// Setting ComVisible to false makes the types in this assembly not visible
// to COM components. If you need to access a type in this assembly from
// COM, set the ComVisible attribute to true on that type.
[assembly: ComVisible(false)]

// The following GUID is for the ID of the typelib if this project is exposed to COM
[assembly: Guid("2a9c6717-3b6c-4db7-a626-16c04fcfeccf")]

// Version information for an assembly consists of the following four values:
//
// Major Version
// Minor Version
// Build Number
// Revision
//
// You can specify all the values or you can default the Build and Revision Numbers
// by using the '*' as shown below:
// [assembly: AssemblyVersion("1.0.*")]
[assembly: AssemblyVersion("1.0.0.0")]
[assembly: AssemblyFileVersion("1.0.0.0")]
202 changes: 202 additions & 0 deletions cs/examples/simulator/VowpalWabbitSimulator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
//using Microsoft.Research.MultiWorldTesting.ExploreLibrary;
using Newtonsoft.Json;
using System;
using System.IO;
using System.Linq;
using System.Text;
using VW;
using VW.Labels;

namespace simulator
{
public static class VowpalWabbitSimulator
{
public class SimulatorExample
{
private readonly int length;

private readonly byte[] exampleBuffer;

public float[] PDF { get; }

public SimulatorExample(int numActions, int sharedContext)
{
// generate distinct per user context with 2 seperate prefered actions
this.PDF = Enumerable.Range(0, numActions).Select(_ => 0.005f).ToArray();
this.PDF[sharedContext] = 0.03f;

this.exampleBuffer = new byte[32 * 1024];

var str = JsonConvert.SerializeObject(
new
{
Version = "1",
EventId = "1", // can be ignored
a = Enumerable.Range(1, numActions).ToArray(),
c = new
{
// shared user context
U = new { C = sharedContext.ToString() },
_multi = Enumerable
.Range(0, numActions)
.Select(i => new { A = new { Constant = 1, Id = i.ToString() }, B = new { Id = i.ToString() } })
.ToArray()
},
p = Enumerable.Range(0, numActions).Select(i => 0.0f).ToArray()
});

Console.WriteLine(str);

// allow for \0 at the end
this.length = Encoding.UTF8.GetBytes(str, 0, str.Length, exampleBuffer, 0);
exampleBuffer[this.length] = 0;
this.length++;
}

public VowpalWabbitMultiLineExampleCollection CreateExample(VowpalWabbit vw)
{
VowpalWabbitDecisionServiceInteractionHeader header;
var examples = vw.ParseDecisionServiceJson(this.exampleBuffer, 0, this.length, true, out header);

var adf = new VowpalWabbitExample[examples.Count - 1];
examples.CopyTo(1, adf, 0, examples.Count - 1);

return new VowpalWabbitMultiLineExampleCollection(vw, examples[0], adf);
}
}

private static void ExportScoringModel(VowpalWabbit learner, ref VowpalWabbit scorer)
{
scorer?.Dispose();
using (var memStream = new MemoryStream())
{
learner.SaveModel(memStream);

memStream.Seek(0, SeekOrigin.Begin);

// Note: the learner doesn't use save-resume as done online
scorer = new VowpalWabbit(new VowpalWabbitSettings { Arguments = "--quiet", ModelStream = memStream });
}
}

public static void Run(string ml_args, int initial_random, int tot_iter, int mod_iter, int rewardSeed, ulong vwSeed, int exp_iter, int numContexts, int numActions, string ml_args_snips)
{
// byte buffer outside so one can change the example and keep the memory around
var exampleBuffer = new byte[32 * 1024];

var randGen = new Random(rewardSeed);
var userGen = new Random();

var simExamples = Enumerable.Range(0, numContexts)
.Select(i => new SimulatorExample(numActions, i))
.ToArray();

var scorerPdf = new float[numActions];
var histPred = new int[numActions, numContexts];
var histPred2 = new int[numActions, numContexts];
var histActions = new int[numActions, numContexts];
var histCost = new int[numActions, numContexts];
var histContext = new int[numContexts];
int clicks = 0;
double snips_num = 0, snips_den = 0;

using (var learner = new VowpalWabbit(ml_args))
using (var learner2 = new VowpalWabbit(ml_args_snips))
{
VowpalWabbit scorer = null;

scorer = new VowpalWabbit("--cb_explore_adf --epsilon 1 --quiet");
for (int i = 1; i <= tot_iter; i++)
{
// sample uniform among users
int userIndex = userGen.Next(simExamples.Length);
var simExample = simExamples[userIndex];
var pdf = simExample.PDF;

histContext[userIndex]++;

using (var ex = simExample.CreateExample(learner))
{
var scores = ex.Predict(VowpalWabbitPredictionType.ActionProbabilities, scorer);

var total = 0.0;

foreach (var actionScore in scores)
{
total += actionScore.Score;
scorerPdf[actionScore.Action] = actionScore.Score;
}

var draw = randGen.NextDouble() * total;
var sum = 0.0;
uint topAction = 0;
foreach (var actionScore in scores)
{
sum += actionScore.Score;
if(sum > draw)
{
topAction = actionScore.Action;
break;
}
}

int modelAction = (int)scores[0].Action;
if (i > initial_random)
histPred[modelAction, userIndex] += 1;
histActions[topAction, userIndex] += 1;

// simulate behavior
float cost = 0;
if (randGen.NextDouble() < pdf[topAction])
{
cost = -1;
histCost[topAction, userIndex] += 1;
clicks += 1;
}

ex.Examples[topAction].Label = new ContextualBanditLabel((uint)topAction, cost, scorerPdf[topAction]);

// simulate delay
if (i >= initial_random && (i % exp_iter == 0))
{
ExportScoringModel(learner, ref scorer);
}

// invoke learning
var oneStepAheadScores = ex.Learn(VowpalWabbitPredictionType.ActionProbabilities, learner);
histPred2[oneStepAheadScores[0].Action, userIndex] += 1;

var oneStepAheadScores2 = ex.Learn(VowpalWabbitPredictionType.ActionProbabilities, learner2);

// SNIPS
snips_num -= oneStepAheadScores2.First(f => f.Action == topAction).Score * cost / scorerPdf[topAction];
snips_den += oneStepAheadScores2.First(f => f.Action == topAction).Score / scorerPdf[topAction];

if (i % mod_iter == 0 || i == tot_iter)
{
Console.WriteLine(JsonConvert.SerializeObject(new
{
Iter = i,
clicks,
CTR = clicks / (float)i,
aveLoss = learner.PerformanceStatistics.AverageLoss,
CTR_snips = snips_num / snips_den,
CTR_ips = snips_num / (float)i,
aveLoss2 = learner2.PerformanceStatistics.AverageLoss,
snips_num,
snips_den,
histActions,
histPred,
histCost,
histContext,
pdf
}));
}
}
}
Console.WriteLine("---------------------");
scorer?.Dispose();
}
}
}
}
4 changes: 4 additions & 0 deletions cs/examples/simulator/packages.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="Newtonsoft.Json" version="12.0.1" targetFramework="net452" />
</packages>
Loading

0 comments on commit ae3d6a4

Please sign in to comment.