Skip to content

Commit

Permalink
Added check on mgzip availability
Browse files Browse the repository at this point in the history
  • Loading branch information
iuliivasilev committed May 14, 2024
1 parent 445bab7 commit db4c420
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions survivors/experiments/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import time
import os
import pickle
import mgzip # Custom

from sklearn.model_selection import StratifiedKFold, ParameterGrid, train_test_split
from sklearn.preprocessing import StandardScaler
Expand All @@ -12,6 +11,11 @@
from .. import metrics as metr
from ..external import LeafModel

try:
import mgzip # Custom
open_file = mgzip.open
except ImportError:
open_file = open

def to_str_from_dict_list(d, stratify):
if isinstance(stratify, str):
Expand Down Expand Up @@ -165,10 +169,11 @@ def f(**kwargs):
if not os.path.exists(name):
print("Fitted from scratch")
est.fit(X_train, y_train)
with mgzip.open(name, 'wb') as out: # Custom

with open_file(name, 'wb') as out: # Custom
pickle.dump(est, out, pickle.HIGHEST_PROTOCOL)

with mgzip.open(name, 'rb') as inp:
with open_file(name, 'rb') as inp:
est = pickle.load(inp)
est.aggreg_func = kwargs.get("aggreg_func", None)
est.tolerance_find_best(kwargs["ens_metric_name"])
Expand Down

0 comments on commit db4c420

Please sign in to comment.