Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tests][dask] add scikit-learn compatibility tests (fixes #3894) #3947

Merged
merged 13 commits into from
Feb 18, 2021
34 changes: 34 additions & 0 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import joblib
import numpy as np
import pandas as pd
import sklearn.utils.estimator_checks as sklearn_checks
from dask.array.utils import assert_eq
from dask.distributed import Client, LocalCluster, default_client, wait
from distributed.utils_test import client, cluster_fixture, gen_cluster, loop
Expand Down Expand Up @@ -1079,3 +1080,36 @@ def test_dask_methods_and_sklearn_equivalents_have_similar_signatures(methods):
for param in dask_spec.args:
error_msg = f"param '{param}' has different default values in the methods"
assert dask_params[param].default == sklearn_params[param].default, error_msg


def sklearn_checks_to_run():
    """Yield the scikit-learn estimator checks to run against the Dask estimators.

    Each name below is looked up on the ``sklearn_checks`` module. A name
    that resolves to nothing (or to a falsy value) is silently skipped
    rather than raising -- presumably so the suite tolerates scikit-learn
    versions that lack a given check (TODO confirm).
    """
    wanted = (
        "check_estimator_get_tags_default_keys",
        "check_get_params_invariance",
        "check_set_params",
    )
    # Lazy lookup: a missing attribute becomes None and is filtered out below.
    resolved = (getattr(sklearn_checks, name, None) for name in wanted)
    yield from (func for func in resolved if func)


def _tested_estimators():
    """Yield one default-constructed instance of each Dask estimator under test."""
    estimator_classes = (lgb.DaskLGBMClassifier, lgb.DaskLGBMRegressor)
    yield from (estimator_cls() for estimator_cls in estimator_classes)


@pytest.mark.parametrize("estimator", _tested_estimators())
@pytest.mark.parametrize("check", sklearn_checks_to_run())
def test_sklearn_integration(estimator, check, client):
    """Run a single scikit-learn estimator check against a single Dask estimator.

    ``estimator`` is an instance produced by ``_tested_estimators`` and
    ``check`` is a callable produced by ``sklearn_checks_to_run``; the check
    is invoked as ``check(<class name>, estimator)``.
    """
    # Configure the estimator before handing it to the check.
    estimator.set_params(local_listen_port=18000, time_out=5)
    check(type(estimator).__name__, estimator)
    # Explicitly shut the Dask client down once the check has completed.
    client.close(timeout=CLIENT_CLOSE_TIMEOUT)


# kept apart from test_sklearn_integration: this check is handed the
# estimator class itself (not yet constructed) instead of an instance
@pytest.mark.parametrize("estimator", list(_tested_estimators()))
def test_parameters_default_constructible(estimator):
    """Run scikit-learn's default-constructibility check on the estimator's class."""
    estimator_cls = estimator.__class__
    sklearn_checks.check_parameters_default_constructible(
        estimator_cls.__name__, estimator_cls
    )