Skip to content

Commit

Permalink
ENH add snapml solver (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanglef authored Mar 29, 2022
1 parent 26eb702 commit 120e6a7
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 3 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ jobs:

- run: conda info

- name: Install OpenMP
if: matrix.os == 'macos-latest'
run: |
brew install libomp
- name: Install benchopt and its dependencies
run: |
conda info
Expand All @@ -58,6 +63,9 @@ jobs:
- name: Test
run: |
export OMP_NUM_THREADS=1 # see issue 20
benchopt test . --env-name bench_test_env -vl
benchopt test . --env-name bench_test_env -vl --skip-install
Expand Down
6 changes: 3 additions & 3 deletions datasets/news20.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ class Dataset(BaseDataset):
name = "news20"
is_sparse = True

install_cmd = 'conda'
requirements = ['pip:libsvmdata']
install_cmd = "conda"
requirements = ["pip:libsvmdata"]

def __init__(self):
self.X, self.y = None, None

def get_data(self):

if self.X is None:
self.X, self.y = fetch_libsvm('news20')
self.X, self.y = fetch_libsvm("news20.binary")

data = dict(X=self.X, y=self.y)

Expand Down
33 changes: 33 additions & 0 deletions solvers/snapml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from benchopt import BaseSolver, safe_import_context

with safe_import_context() as import_ctx:
from snapml import LogisticRegression
import numpy as np


class Solver(BaseSolver):
name = "snapml"

install_cmd = "conda"
requirements = ["pip:snapml"]

def set_objective(self, X, y, lmbd):
self.X, self.y, self.lmbd = X, y, lmbd

self.clf = LogisticRegression(
fit_intercept=False,
regularizer=self.lmbd,
penalty="l2",
tol=1e-12,
)

def run(self, n_iter):
if n_iter == 0:
self.clf.coef_ = np.zeros(self.X.shape[1])
return

self.clf.max_iter = n_iter
self.clf.fit(self.X, self.y)

def get_result(self):
return self.clf.coef_.flatten()

0 comments on commit 120e6a7

Please sign in to comment.