diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 89c7341bcc27..da91c4ef209a 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -976,3 +976,31 @@ def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except assert dask_spec.defaults[:-1] == sklearn_spec.defaults assert dask_spec.args[-1] == 'client' assert dask_spec.defaults[-1] is None + + +@pytest.mark.parametrize( + "methods", + [ + (lgb.DaskLGBMClassifier.fit, lgb.LGBMClassifier.fit), + (lgb.DaskLGBMClassifier.predict, lgb.LGBMClassifier.predict), + (lgb.DaskLGBMClassifier.predict_proba, lgb.LGBMClassifier.predict_proba), + (lgb.DaskLGBMRegressor.fit, lgb.LGBMRegressor.fit), + (lgb.DaskLGBMRegressor.predict, lgb.LGBMRegressor.predict), + (lgb.DaskLGBMRanker.fit, lgb.LGBMRanker.fit), + (lgb.DaskLGBMRanker.predict, lgb.LGBMRanker.predict) + ] +) +def test_dask_methods_and_sklearn_equivalents_have_similar_signatures(methods): + dask_spec = inspect.getfullargspec(methods[0]) + sklearn_spec = inspect.getfullargspec(methods[1]) + dask_params = inspect.signature(methods[0]).parameters + sklearn_params = inspect.signature(methods[1]).parameters + assert dask_spec.args == sklearn_spec.args[:len(dask_spec.args)] + assert dask_spec.varargs == sklearn_spec.varargs + if sklearn_spec.varkw: + assert dask_spec.varkw == sklearn_spec.varkw[:len(dask_spec.varkw)] + assert dask_spec.kwonlyargs == sklearn_spec.kwonlyargs + assert dask_spec.kwonlydefaults == sklearn_spec.kwonlydefaults + for param in dask_spec.args: + error_msg = f"param '{param}' has different default values in the methods" + assert dask_params[param].default == sklearn_params[param].default, error_msg