Contextual Bandit Zeroth Order Optimization
CBZO is a contextual bandit-style algorithm for multi-dimensional, continuous action spaces. It learns policies via Zeroth-Order Optimization: continuous optimization techniques that use gradient estimators requiring only values of the function. The variant of CBZO currently implemented in VW works in the 1-dimensional action space setting and can learn either constant or linear policies. The algorithm has an optimal regret bound when the cost function is smooth and convex.
Let the policy to be learnt by the algorithm, which maps contexts to actions, be $\pi_w(x)$, where $x$ is the context.

As mentioned earlier, the two variants of $\pi_w$ currently implemented are:

- $\pi_w(x) = w_0$ (constant policy)
- $\pi_w(x) = w^\top x + w_0$ (linear policy)

where $w$ denotes the model weights (including the bias $w_0$) and $x \in \mathbb{R}^d$ is the context feature vector.

Assuming $L$ to be the cost function, it is minimized via gradient descent:

$$w \leftarrow w - \eta \, \widehat{\nabla}_w L(\pi_w(x))$$

Using the 1-point gradient estimator and the chain rule, we arrive at:

$$\widehat{\nabla}_w L(\pi_w(x)) = \frac{1}{\delta} \, L(\pi_w(x) + \delta u) \, u \, \nabla_w \pi_w(x)$$

where $u$ is drawn from a uniform pmf over $\{-1, +1\}$ and $\nabla_w \pi_w(x)$ can be computed precisely since $\pi_w$ is known.

Importantly, by using the fact that $\pi_w$ is known, we remove the dependence of sample complexity on the parameter dimension $d$ while suffering only a dependence on the action dimension; the gap can be huge, especially when $d$ is very large.

Implementation note: a small interval around $\pi_w(x) + \delta u$, one for each value of $u \in \{-1, +1\}$, is loaded into the VW prediction, and the sampled action is passed back in the label.
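As a sanity check on the estimator above, here is a small standalone sketch (not part of VW) that Monte-Carlo-averages the 1-point estimate for a smooth convex cost and compares it to the true derivative. The cost function, evaluation point, and radius are made up for illustration:

```python
import numpy as np

def one_point_grad(L, a, delta, n_samples=200000, seed=0):
    """Monte-Carlo average of the 1-point gradient estimate of dL/da:
    E_u[(1/delta) * L(a + delta*u) * u], with u uniform over {-1, +1}."""
    rng = np.random.RandomState(seed)
    u = rng.choice([-1.0, 1.0], size=n_samples)
    return np.mean(L(a + delta * u) * u / delta)

# Smooth convex cost L(a) = (a - 2)^2, whose true derivative at a = 0 is -4
est = one_point_grad(lambda a: (a - 2.0) ** 2, a=0.0, delta=0.1)
print(est)  # close to -4
```

For this quadratic cost the estimator is unbiased, so the average converges to the exact derivative; with a larger radius $\delta$ the estimator has lower variance but higher bias on non-quadratic costs.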
The `--cbzo` option enables usage of this algorithm. It is written as a new base learner and thus doesn't make use of other reductions in VW.
- The label type is `VW::cb_continuous::continuous_label`, the same one `cats_pdf` uses.
- The prediction type is `VW::continuous_actions::probability_density_function`, the same one `cats_pdf` uses.
- The data format is single-line, similar to CATS.
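For reference, a labelled example in this single-line format carries the label `ca <action>:<cost>:<pdf_value>` followed by the features. The feature names and numeric values below are hypothetical:

```
ca 1.5:0.7:25.0 | c1:0.43 c2:1.12
```

Here `1.5` is the continuous action that was taken, `0.7` the observed cost, and `25.0` the probability density with which the action was sampled.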
- `--policy [arg]`: one of `constant`, `linear`. Defaults to `linear`
- `-l [arg]`: learning rate $\eta$
- `--radius [arg]`: the radius of exploration $\delta$. Defaults to `0.1`
- `--l1 [arg]`: $\ell_1$ lambda (L1 regularization)
- `--l2 [arg]`: $\ell_2$ lambda (L2 regularization)
- `--no_bias_regularization`: don't regularize the bias term
```shell
vw -d [data_file] --cbzo --policy linear -l 0.01 --radius 0.1
```
However, the natural way to use `cbzo` is interactively, not by learning from a data file. In fact, producing a data file that is useful for learning is unnatural, because the action in a label depends on the prediction made earlier on the unlabelled example with the same context. Given below is an example of using `cbzo` interactively through VW's Python API.
```python
from vowpalwabbit import pyvw
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def main():
    vw = pyvw.vw('--cbzo --policy linear -l 0.01 --radius 0.1 --quiet')
    costs_progress = []
    for _ in range(500):
        # Get context from environment
        ctx = Environment.next_ctx()

        # Determine what action to take for this context by calling predict()
        # on an unlabelled example (which includes the context)
        ex = vw.parse(' | c1:{} c2:{}'.format(ctx[0], ctx[1]), labelType=vw.lContinuous)
        pred = vw.predict(ex)
        vw.finish_example(ex)

        # A list (length=2) of pdf segments is returned. Sample an
        # action from them
        action, pdf_value = sample_action(pred)

        # Get the cost for the action and pass it to VW by creating
        # a labelled example
        cost = Environment.cost(ctx, action)
        ex = vw.parse('ca {}:{}:{} | c1:{} c2:{}'.format(action, cost, pdf_value, ctx[0], ctx[1]),
                      labelType=vw.lContinuous)
        vw.learn(ex)
        vw.finish_example(ex)

        costs_progress.append(cost)

    vw.finish()
    plot(costs_progress)


class Environment:
    """A synthetic environment where the optimal action is linear in the
    context and the cost is the absolute loss function."""
    wopt, bopt = [-3, 5], 0.5
    rs = np.random.RandomState(0)

    @staticmethod
    def next_ctx():
        return Environment.rs.normal(1, 1, size=(2,))

    @staticmethod
    def cost(ctx, action):
        optimal_action = np.dot(Environment.wopt, ctx) + Environment.bopt
        return abs(optimal_action - action)


def sample_action(pred):
    """Samples an action from the prediction.
    pred is of the form [(left1, right1, pdf_value1), (left2, right2, pdf_value2)]"""
    # the next line is equivalent to p = [0.5, 0.5]
    p = np.array([pred[i][2] * (pred[i][1] - pred[i][0]) for i in range(2)])
    idx = np.random.choice(2, p=p / p.sum())
    return np.random.uniform(pred[idx][0], pred[idx][1]), pred[idx][2]


def plot(costs_progress):
    _, ax = plt.subplots()
    costs_progress = pd.Series(costs_progress).rolling(10).mean()
    ax.plot(costs_progress, marker='.', markersize=2, linewidth=1)
    ax.set_xlabel('Iterations')
    ax.set_ylabel('Cost (rolling mean)')
    ax.grid()
    plt.savefig('progress.png')


if __name__ == '__main__':
    np.random.seed(0)
    main()
```
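To illustrate what `sample_action` consumes, here is a standalone run on a hypothetical prediction. The two intervals, their width, and the pdf value `25.0` (= 0.5 / 0.02) are made up, but mirror the `[(left, right, pdf_value), ...]` shape described above:

```python
import numpy as np

def sample_action(pred):
    # pred is [(left1, right1, pdf_value1), (left2, right2, pdf_value2)];
    # the probability mass of each segment is pdf_value * width
    p = np.array([pred[i][2] * (pred[i][1] - pred[i][0]) for i in range(2)])
    idx = np.random.choice(2, p=p / p.sum())
    return np.random.uniform(pred[idx][0], pred[idx][1]), pred[idx][2]

# Two hypothetical intervals of width 0.02 around pi_w(x) +/- radius,
# each carrying half the probability mass (25.0 * 0.02 == 0.5)
pred = [(0.89, 0.91, 25.0), (1.09, 1.11, 25.0)]
np.random.seed(0)
action, pdf_value = sample_action(pred)
print(action, pdf_value)  # action falls inside one of the two intervals
```

Each segment carries equal mass here, so the sampler picks one of the two intervals uniformly at random and then draws the action uniformly within it.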
Progress of the algorithm can be seen in the saved figure (`progress.png`): the rolling mean of the cost over iterations.
By passing the `--audit` option, it can be confirmed that the weights learnt by `cbzo` are close to the optimal/target ones used by the synthetic environment.