Skip to content

Commit

Permalink
[tests] [python] add test for non-serializable callback (#4741)
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS authored Oct 29, 2021
1 parent e10cbd2 commit 798dc1d
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/python_package_test/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@
decreasing_generator = itertools.count(0, -1)


class UnpicklableCallback:
def __reduce__(self):
raise Exception("This class in not picklable")

def __call__(self, env):
env.model.set_attr(attr_set_inside_callback=str(env.iteration * 10))


def custom_asymmetric_obj(y_true, y_pred):
residual = (y_true - y_pred).astype(np.float64)
grad = np.where(residual < 0, -2 * 10.0 * residual, -2 * residual)
Expand Down Expand Up @@ -427,6 +435,18 @@ def test_joblib():
np.testing.assert_allclose(pred_origin, pred_pickle)


def test_non_serializable_objects_in_callbacks(tmp_path):
unpicklable_callback = UnpicklableCallback()

with pytest.raises(Exception, match="This class in not picklable"):
joblib.dump(unpicklable_callback, tmp_path / 'tmp.joblib')

X, y = load_boston(return_X_y=True)
gbm = lgb.LGBMRegressor(n_estimators=5)
gbm.fit(X, y, callbacks=[unpicklable_callback])
assert gbm.booster_.attr('attr_set_inside_callback') == '40'


def test_random_state_object():
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
Expand Down

0 comments on commit 798dc1d

Please sign in to comment.