Skip to content

Commit

Permalink
Make bandit tests configurable (#1236)
Browse files Browse the repository at this point in the history
* Make epsilon configurable

* remove old field

* support multiple methodologies

* one method

* naming

* unit tests

* rename to selectVariantUsingMVT

* fix test name

* fix test

* call selectVariantUsingMVT
  • Loading branch information
tomrf1 authored Oct 30, 2024
1 parent 152fa34 commit 408d71c
Show file tree
Hide file tree
Showing 11 changed files with 175 additions and 238 deletions.
6 changes: 5 additions & 1 deletion src/server/bandit/banditData.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,12 @@ async function buildBanditDataForTest(epicTest: EpicTest): Promise<BanditData> {
};
}

function hasBanditMethodology(test: EpicTest): boolean {
return !!test.methodologies?.find((method) => method.name === 'EpsilonGreedyBandit');
}

function buildBanditData(epicTestsProvider: ValueProvider<EpicTest[]>): Promise<BanditData[]> {
const banditTests = epicTestsProvider.get().filter((epicTest) => epicTest.isBanditTest);
const banditTests = epicTestsProvider.get().filter(hasBanditMethodology);
return Promise.all(
banditTests.map((epicTest) =>
buildBanditDataForTest(epicTest).catch((error) => {
Expand Down
49 changes: 16 additions & 33 deletions src/server/bandit/banditSelection.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { EpicTest, EpicVariant } from '../../shared/types';
import { Test, Variant } from '../../shared/types';
import { BanditData } from './banditData';
import { AppliedLearningBanditTestsNames, Result } from '../tests/epics/epicSelection';
import { putMetric } from '../utils/cloudwatch';
import { logError } from '../utils/logging';

Expand All @@ -9,10 +8,10 @@ import { logError } from '../utils/logging';
* https://en.wikipedia.org/wiki/Multi-armed_bandit#Semi-uniform_strategies
*/

export function selectVariantWithHighestMean(
export function selectVariantWithHighestMean<V extends Variant, T extends Test<V>>(
testBanditData: BanditData,
test: EpicTest,
): EpicVariant | undefined {
test: T,
): V | undefined {
const variant =
testBanditData.bestVariants.length < 2
? testBanditData.bestVariants[0]
Expand All @@ -27,33 +26,24 @@ export function selectVariantWithHighestMean(
return test.variants.find((v) => v.name === variant.variantName);
}

function selectRandomVariant(test: EpicTest): Result {
function selectRandomVariant<V extends Variant, T extends Test<V>>(test: T): V | undefined {
const randomVariantIndex = Math.floor(Math.random() * test.variants.length);
const randomVariantData = test.variants[randomVariantIndex];

if (!randomVariantData) {
logError(`Failed to select random variant for bandit test: ${test.name}`);
putMetric('bandit-selection-error');
return {};
return;
}

return {
result: { test, variant: randomVariantData },
};
return randomVariantData;
}

export function epsilonValueForBanditTest(testBanditData: BanditData): number {
if (testBanditData.testName.includes(AppliedLearningBanditTestsNames.BanditTestEpsilon1)) {
return 1;
} else if (
testBanditData.testName.includes(AppliedLearningBanditTestsNames.BanditTestEpsilon2)
) {
return 0.5;
}
return 1;
}

export function selectVariantUsingEpsilonGreedy(banditData: BanditData[], test: EpicTest): Result {
export function selectVariantUsingEpsilonGreedy<V extends Variant, T extends Test<V>>(
banditData: BanditData[],
test: T,
epsilon: number,
): V | undefined {
const testBanditData = banditData.find((bandit) => bandit.testName === test.name);

if (!testBanditData) {
Expand All @@ -64,24 +54,17 @@ export function selectVariantUsingEpsilonGreedy(banditData: BanditData[], test:
// Choose at random with probability epsilon
const random = Math.random();

const EPSILON = epsilonValueForBanditTest(testBanditData);

if (EPSILON > random) {
if (epsilon > random) {
return selectRandomVariant(test);
}

const highestMeanVariantData = selectVariantWithHighestMean(testBanditData, test);
const highestMeanVariantData = selectVariantWithHighestMean<V, T>(testBanditData, test);

if (!highestMeanVariantData) {
logError(`Failed to select best variant for bandit test: ${test.name}`);
putMetric('bandit-selection-error');
return {};
return;
}

return {
result: {
test,
variant: highestMeanVariantData,
},
};
return highestMeanVariantData;
}
46 changes: 37 additions & 9 deletions src/server/lib/ab.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { EpicTest } from '../../shared/types';
import { selectVariant, withinRange, selectWithSeed } from './ab';
import { selectVariantUsingMVT, withinRange, selectWithSeed, selectVariant } from './ab';

const test: EpicTest = {
name: 'example-1', // note - changing this name will change the results of the tests, as it's used for the seed
Expand Down Expand Up @@ -42,34 +42,34 @@ const controlProportionSettings = {
offset: 500000,
};

describe('selectVariant', () => {
describe('selectVariantWithMVT', () => {
it('should select control (no controlProportion)', () => {
const variant = selectVariant(test, 4);
const variant = selectVariantUsingMVT(test, 4);
expect(variant.name).toBe('control');
});

it('should select variant (no controlProportion)', () => {
const variant = selectVariant(test, 2);
const variant = selectVariantUsingMVT(test, 2);
expect(variant.name).toBe('v1');
});

it('should select control (lower end of controlProportion)', () => {
const variant = selectVariant({ ...test, controlProportionSettings }, 500000);
const variant = selectVariantUsingMVT({ ...test, controlProportionSettings }, 500000);
expect(variant.name).toBe('control');
});

it('should select control (upper end of controlProportion)', () => {
const variant = selectVariant({ ...test, controlProportionSettings }, 599999);
const variant = selectVariantUsingMVT({ ...test, controlProportionSettings }, 599999);
expect(variant.name).toBe('control');
});

it('should select variant (below controlProportion)', () => {
const variant = selectVariant({ ...test, controlProportionSettings }, 499999);
const variant = selectVariantUsingMVT({ ...test, controlProportionSettings }, 499999);
expect(variant.name).toBe('v1');
});

it('should select variant (above controlProportion)', () => {
const variant = selectVariant({ ...test, controlProportionSettings }, 600000);
const variant = selectVariantUsingMVT({ ...test, controlProportionSettings }, 600000);
expect(variant.name).toBe('v1');
});

Expand All @@ -78,7 +78,10 @@ describe('selectVariant', () => {
...test,
variants: [test.variants[0]],
};
const variant = selectVariant({ ...controlOnly, controlProportionSettings }, 600000);
const variant = selectVariantUsingMVT(
{ ...controlOnly, controlProportionSettings },
600000,
);
expect(variant.name).toBe('control');
});
});
Expand Down Expand Up @@ -129,3 +132,28 @@ describe('selectWithSeed', () => {
expect(Math.abs(variantCounts.control - variantCounts.v1)).toBeLessThan(10);
});
});

describe('selectVariant', () => {
it('should return same test name if no methodology configured', () => {
const result = selectVariant(test, 1, []);
expect(result?.test.name).toEqual(test.name);
});

it('should return same test name if one methodology is configured', () => {
const testWithMethodology: EpicTest = {
...test,
methodologies: [{ name: 'ABTest' }],
};
const result = selectVariant(testWithMethodology, 1, []);
expect(result?.test.name).toEqual(test.name);
});

it('should return extended test name if one than one methodology is configured', () => {
const testWithMethodology: EpicTest = {
...test,
methodologies: [{ name: 'ABTest' }, { name: 'EpsilonGreedyBandit', epsilon: 0.5 }],
};
const result = selectVariant(testWithMethodology, 1, []);
expect(result?.test.name).toBe('example-1_EpsilonGreedyBandit-0.5');
});
});
95 changes: 90 additions & 5 deletions src/server/lib/ab.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import { Test, Variant, AmountsTests, SelectedAmountsVariant } from '../../shared/types';
import {
Test,
Variant,
AmountsTests,
SelectedAmountsVariant,
Methodology,
} from '../../shared/types';
import { CountryGroupId } from '../../shared/lib';
import seedrandom from 'seedrandom';
import { BanditData } from '../bandit/banditData';
import { selectVariantUsingEpsilonGreedy } from '../bandit/banditSelection';

const maxMvt = 1000000;

Expand Down Expand Up @@ -33,10 +41,14 @@ export const selectWithSeed = <V extends Variant>(
};

/**
* For use in AB tests.
* If controlProportionSettings is set then we use this to define the range of mvt values for the control variant.
* Otherwise we evenly distribute all variants across maxMvt.
*/
export const selectVariant = <V extends Variant, T extends Test<V>>(test: T, mvtId: number): V => {
export const selectVariantUsingMVT = <V extends Variant, T extends Test<V>>(
test: T,
mvtId: number,
): V => {
const control = test.variants.find((v) => v.name.toLowerCase() === 'control');
const seed = test.name.split('__')[0];

Expand All @@ -60,9 +72,82 @@ export const selectVariant = <V extends Variant, T extends Test<V>>(test: T, mvt
return selectWithSeed(mvtId, seed, test.variants);
};

export const selectVariantNonSticky = <V extends Variant, T extends Test<V>>(test: T): V => {
const index = Math.floor(Math.random() * test.variants.length);
return test.variants[index];
const selectVariantWithMethodology = <V extends Variant, T extends Test<V>>(
test: T,
mvtId: number,
banditData: BanditData[],
methodology: Methodology,
): V | undefined => {
if (methodology.name === 'EpsilonGreedyBandit') {
return selectVariantUsingEpsilonGreedy(banditData, test, methodology.epsilon);
}
return selectVariantUsingMVT<V, T>(test, mvtId);
};

const addMethodologyToTestName = (testName: string, methodology: Methodology): string => {
if (methodology.name === 'EpsilonGreedyBandit') {
return `${testName}_EpsilonGreedyBandit-${methodology.epsilon}`;
} else {
return `${testName}_ABTest`;
}
};

/**
* Selects a variant from the test based on any configured methodologies.
* Defaults to an AB test.
*/
export const selectVariant = <V extends Variant, T extends Test<V>>(
test: T,
mvtId: number,
banditData: BanditData[],
): { test: T; variant: V } | undefined => {
if (test.methodologies && test.methodologies.length === 1) {
// Only one configured methodology
const variant = selectVariantWithMethodology<V, T>(
test,
mvtId,
banditData,
test.methodologies[0],
);
if (variant) {
return {
test,
variant,
};
}
} else if (test.methodologies) {
// More than one methodology, pick one of them using the mvt value
const methodology =
test.methodologies[getRandomNumber(test.name, mvtId) % test.methodologies.length];

// Add the methodology to the test name so that we can track them separately
const testWithNameExtension = {
...test,
name: addMethodologyToTestName(test.name, methodology),
};
const variant = selectVariantWithMethodology<V, T>(
testWithNameExtension,
mvtId,
banditData,
methodology,
);
if (variant) {
return {
test: testWithNameExtension,
variant,
};
}
} else {
// No configured methodology, default to AB test
const variant = selectVariantUsingMVT<V, T>(test, mvtId);
if (variant) {
return {
test,
variant,
};
}
}
return;
};

export const selectAmountsTestVariant = (
Expand Down
4 changes: 2 additions & 2 deletions src/server/lib/targetingTesting.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { TargetingAbTest, Test, Variant } from '../../shared/types';
import { selectVariant } from './ab';
import { selectVariantUsingMVT } from './ab';
import { ScheduledBannerDeploys } from '../tests/banners/bannerDeploySchedule';

type TargetingTestDecision = {
Expand Down Expand Up @@ -36,7 +36,7 @@ export const selectTargetingTest = <T>(
);

if (test) {
const variant: TargetingTestVariant<T> = selectVariant(test, mvtId);
const variant: TargetingTestVariant<T> = selectVariantUsingMVT(test, mvtId);
return {
canShow: variant.canShow(targeting),
deploySchedule: variant.deploySchedule,
Expand Down
4 changes: 2 additions & 2 deletions src/server/tests/banners/bannerSelection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
UserDeviceType,
uiIsDesign,
} from '../../../shared/types';
import { selectVariant } from '../../lib/ab';
import { selectVariantUsingMVT } from '../../lib/ab';
import { historyWithinArticlesViewedSettings } from '../../lib/history';
import { TestVariant } from '../../lib/params';
import {
Expand Down Expand Up @@ -232,7 +232,7 @@ export const selectBannerTest = (
consentStatusMatches(targeting.hasConsented, test.consentStatus) &&
abandonedBasketMatches(test.bannerChannel, targeting.abandonedBasket)
) {
const variant = selectVariant<BannerVariant, BannerTest>(test, targeting.mvtId);
const variant = selectVariantUsingMVT<BannerVariant, BannerTest>(test, targeting.mvtId);

return {
test,
Expand Down
Loading

0 comments on commit 408d71c

Please sign in to comment.