Skip to content

Commit

Permalink
MTN benchopt 1.5 API (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomMoral authored Sep 18, 2023
1 parent c29a8de commit 01f3fa2
Show file tree
Hide file tree
Showing 11 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion example_config.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
objective-filter:
objective:
- L2 Logistic Regression[lmbd=0.1]
dataset:
- simulated[n_features=500,n_samples=200]
Expand Down
8 changes: 4 additions & 4 deletions objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ def _compute_loss(X, y, lmbd, beta):


class Objective(BaseObjective):
min_benchopt_version = "1.3"
name = "L2 Logistic Regression"
min_benchopt_version = "1.5"

parameters = {
'lmbd': [1., 0.01]
Expand All @@ -27,10 +27,10 @@ def set_data(self, X, y, X_test=None, y_test=None):
msg = "Logistic loss is implemented with y in [-1, 1]"
assert set(self.y) == {-1, 1}, msg

def get_one_solution(self):
return np.zeros((self.X.shape[1]))
def get_one_result(self):
return {'beta': np.zeros((self.X.shape[1]))}

def compute(self, beta):
def evaluate_result(self, beta):
train_loss = _compute_loss(self.X, self.y, self.lmbd, beta)
test_loss = None
if self.X_test is not None:
Expand Down
2 changes: 1 addition & 1 deletion solvers/cd.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,4 @@ def sparse_cd(X_data, X_indices, X_indptr, y, lmbd, L, n_iter,
return w

def get_result(self):
return self.w
return dict(beta=self.w)
2 changes: 1 addition & 1 deletion solvers/chop.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,4 @@ def run(self, n_iter):
self.run_full_batch(n_iter)

def get_result(self):
return self.beta.detach().cpu().numpy().flatten()
return dict(beta=self.beta.detach().cpu().numpy().flatten())
2 changes: 1 addition & 1 deletion solvers/copt.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,4 @@ def step(x):
self.beta = result.x

def get_result(self):
return self.beta.flatten()
return dict(beta=self.beta.flatten())
2 changes: 1 addition & 1 deletion solvers/cuml.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ def run(self, n_iter):
self.clf.fit(self.X, self.y)

def get_result(self):
return self.clf.coef_.to_numpy().flatten()
return dict(beta=self.clf.coef_.to_numpy().flatten())
2 changes: 1 addition & 1 deletion solvers/glmnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,4 @@ def get_result(self):
coefs = np.array(as_matrix(results["beta"], "matrix"))
beta = coefs.flatten()

return beta
return dict(beta=beta)
2 changes: 1 addition & 1 deletion solvers/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ def run(self, n_iter):
self.clf.fit(self.X, self.y)

def get_result(self):
return self.clf.coef_.flatten()
return dict(beta=self.clf.coef_.flatten())
2 changes: 1 addition & 1 deletion solvers/python_gd.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def run(self, n_iter):
self.w = w

def get_result(self):
return self.w
return dict(beta=self.w)

def compute_lipschitz_constant(self):
if not sparse.issparse(self.X):
Expand Down
2 changes: 1 addition & 1 deletion solvers/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ def run(self, n_iter):
self.clf.fit(self.X, self.y)

def get_result(self):
return self.clf.coef_.flatten()
return dict(beta=self.clf.coef_.flatten())
2 changes: 1 addition & 1 deletion solvers/snapml.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ def run(self, n_iter):
self.clf.fit(self.X, self.y)

def get_result(self):
return self.clf.coef_.flatten()
return dict(beta=self.clf.coef_.flatten())

0 comments on commit 01f3fa2

Please sign in to comment.