Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add multiclass support for hyperopt utl #1682

Merged
merged 7 commits into from
Jan 11, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions utl/vw-hyperopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def __init__(self, train_set, holdout_set, command, max_evals=100,
self.max_evals = max_evals
self.searcher = searcher
self.is_regression = is_regression
self.labels_clf_count = 0

self.trials = Trials()
self.current_trial = 0
Expand Down Expand Up @@ -219,11 +220,16 @@ def get_hyperparam_string(self, **kwargs):
def compose_vw_train_command(self):
data_part = ('vw -d %s -f %s --holdout_off -c '
% (self.train_set, self.train_model))
if self.labels_clf_count > 2: # multiclass, should take probabilities
data_part += ('--oaa %s --loss_function=logistic --probabilities '
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to support other multiclass methods in addition to --oaa,
e.g. --ect, --recall_tree, --wap (with and without ldf), etc.
They all have a similar interface (each expects the number of classes as its argument).

% (self.labels_clf_count))
self.train_command = ' '.join([data_part, self.param_suffix])

def compose_vw_validate_command(self):
    """Build the ``vw`` test-mode command and store it in ``self.validate_command``.

    The command runs ``vw -t`` on ``self.holdout_set`` with the model read
    from ``self.train_model`` and writes predictions to ``self.holdout_pred``.
    When more than two distinct class labels were counted
    (``self.labels_clf_count > 2``), the logistic loss and ``--probabilities``
    flags are appended so that per-class probabilities are written.
    """
    command = 'vw -t -d %s -i %s -p %s --holdout_off -c' \
        % (self.holdout_set, self.train_model, self.holdout_pred)
    if self.labels_clf_count > 2:  # multiclass
        command += ' --loss_function=logistic --probabilities'
    self.validate_command = command

def fit_vw(self):
Expand All @@ -236,34 +242,38 @@ def validate_vw(self):
self.logger.info("executing the following command (validation): %s" % self.validate_command)
subprocess.call(shlex.split(self.validate_command))

def get_y_true_train(self):
    """Load the true training labels into ``self.y_true_train``.

    Reads ``self.train_set`` line by line and parses the first
    whitespace-separated token that precedes the ``|`` feature separator as
    the label. Whitespace-only lines are skipped. When ``self.is_regression``
    is false, labels are rescaled with ``(label + 1) / 2`` (so -1/1 become
    0.0/1.0).

    Raises:
        ValueError: if a label token is not numeric.
        IndexError: if a non-blank line has no token before ``|``.
    """
    self.logger.info("loading true train class labels...")
    labels = []
    # The context manager guarantees the file is closed even if parsing fails.
    with open(self.train_set, 'r') as train_file:
        for line in train_file:
            if not line.strip():
                continue
            # Parse the whole first token rather than a fixed two-character
            # slice, so labels such as "100", "0.5" or "-12.25" are read in full.
            label_token = line.split('|', 1)[0].split()[0]
            labels.append(float(label_token))
    if not self.is_regression:
        labels = [(label + 1.) / 2 for label in labels]
    self.y_true_train = labels
    self.logger.info("train length: %d" % len(self.y_true_train))

def get_y_true_holdout(self):
    """Load the true holdout labels into ``self.y_true_holdout``.

    The first whitespace-separated token of every line of
    ``self.holdout_set`` is parsed as a float label. For classification
    (``self.is_regression`` false) the number of distinct raw labels is
    stored in ``self.labels_clf_count``:

    * more than two labels -> multiclass; labels are kept as read, and only
      the ``'logistic'`` outer loss function is accepted;
    * at most two labels -> binary; labels are mapped with
      ``int((label + 1) / 2)`` (so -1/1 become 0/1).

    Raises:
        KeyError: multiclass data combined with a non-logistic
            ``self.outer_loss_function``.
    """
    self.logger.info("loading true holdout class labels...")
    # The context manager guarantees the file is closed even if parsing fails.
    with open(self.holdout_set, 'r') as holdout_file:
        self.y_true_holdout = [float(line.split()[0]) for line in holdout_file]
    if not self.is_regression:
        # Count distinct labels on the RAW values: applying the binary
        # mapping first would collapse e.g. 1/2/3 into 1/1/2 and hide the
        # fact that the problem is multiclass.
        self.labels_clf_count = len(set(self.y_true_holdout))
        if self.labels_clf_count > 2 and self.outer_loss_function != 'logistic':
            raise KeyError('Only logistic loss function is available for multiclass clf')
        if self.labels_clf_count <= 2:
            self.y_true_holdout = [int((i + 1.) / 2) for i in self.y_true_holdout]
    self.logger.info("holdout length: %d" % len(self.y_true_holdout))

def validation_metric_vw(self):
v = open('%s' % self.holdout_pred, 'r')
def get_y_pred_holdout(self):
    """Read the predictions stored in ``self.holdout_pred``.

    Returns:
        A list with one entry per line of the predictions file:

        * multiclass (``self.labels_clf_count > 2``): a list of floats taken
          from the ``label:probability`` tokens on the line, in file order;
        * otherwise: the first token on the line as a float.
    """
    is_multiclass = self.labels_clf_count > 2
    y_pred_holdout = []
    with open(self.holdout_pred, 'r') as pred_file:
        for line in pred_file:
            tokens = line.split()
            if is_multiclass:
                # Keep only "label:probability" tokens so that a trailing
                # token without ':' (e.g. an example tag) does not break
                # the parse.
                y_pred_holdout.append(
                    [float(tok.split(':')[1]) for tok in tokens if ':' in tok])
            else:
                y_pred_holdout.append(float(tokens[0]))
    return y_pred_holdout

def validation_metric_vw(self):
y_pred_holdout = self.get_y_pred_holdout()

if self.outer_loss_function == 'logistic':
y_pred_holdout_proba = [1. / (1 + exp(-i)) for i in y_pred_holdout]
if self.labels_clf_count > 2:
y_pred_holdout_proba = y_pred_holdout
else:
y_pred_holdout_proba = [1. / (1 + exp(-i)) for i in y_pred_holdout]
loss = log_loss(self.y_true_holdout, y_pred_holdout_proba)

elif self.outer_loss_function == 'squared':
Expand Down