[pyspark] support typing for model.py #9156
Conversation
Seems the failure
# xgboost types
XGB_ESTIMATOR = Union[XGBClassifier, XGBRanker, XGBRegressor]
Use XGBModel class instead. It's the parent class for all referenced classes here.
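A minimal sketch of what that could look like, assuming the alias is only used where any scikit-learn-style estimator is acceptable. The alias name XgbEstimatorT and the helper with_n_estimators are illustrative, not from the PR:

```python
from typing import TypeVar

from xgboost.sklearn import XGBModel  # parent class of XGBClassifier, XGBRanker, XGBRegressor

# A TypeVar bound to the shared base class replaces the Union and
# preserves the concrete subclass type through function signatures.
XgbEstimatorT = TypeVar("XgbEstimatorT", bound=XGBModel)


def with_n_estimators(est: XgbEstimatorT, n: int) -> XgbEstimatorT:
    # Accepts any XGBModel subclass; the caller gets the same type back.
    est.set_params(n_estimators=n)
    return est
```

A bound TypeVar also keeps the concrete subclass type flowing through signatures, which a plain Union would lose.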
SPARK_XGB_ESTIMATOR = Union[SparkXGBClassifier, SparkXGBRanker, SparkXGBRegressor]
SPARK_XGB_ESTIMATOR_TYPE = Union[
    Type[SparkXGBClassifier], Type[SparkXGBRanker], Type[SparkXGBRegressor]
]
SPARK_XGB_MODEL = Union[
    SparkXGBClassifierModel, SparkXGBRegressorModel, SparkXGBRankerModel
]
SPARK_XGB_MODEL_TYPE = Union[
    Type[SparkXGBClassifierModel],
    Type[SparkXGBRegressorModel],
    Type[SparkXGBRankerModel],
]

SPARK_XGB_INSTANCE = TypeVar(
    "SPARK_XGB_INSTANCE", SPARK_XGB_ESTIMATOR, SPARK_XGB_MODEL
)
- Please consider the inheritance structure here instead of using Union.
- Please use CamelCase for types.
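For illustration, a hedged sketch of what both points might look like if the shared base class could be imported here without a cycle (the CamelCase name SparkXgbEstimatorT is made up for this example; the follow-up comment below explains why the direct import is not straightforward):

```python
from typing import TypeVar

# Assumption for this sketch: the private base class can be imported from
# core.py without triggering a circular import (see the discussion below).
from .core import _SparkXGBEstimator

# A CamelCase type name bound to the inheritance root, instead of a Union
# over every concrete estimator subclass.
SparkXgbEstimatorT = TypeVar("SparkXgbEstimatorT", bound=_SparkXGBEstimator)
```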
Yeah. For example, if I replace
SPARK_XGB_ESTIMATOR = Union[SparkXGBClassifier, SparkXGBRanker, SparkXGBRegressor]
with
from .core import _SparkXGBEstimator
SparkXGBEstimator = _SparkXGBEstimator
it complains with the error below:
xgboost/spark/model.py:41: error: Module "xgboost.spark.core" has no attribute "_SparkXGBEstimator" [attr-defined]
xgboost/spark/model.py:57: error: Variable "xgboost.spark.model.SparkXGBEstimator" is not valid as a type [valid-type]
@trivialfis any thoughts on that?
I think it's the same issue one would encounter in any typed language: the modules need to be written in a structured way. Python's duck typing hides the issue; now that we want to annotate the code, some refactoring needs to be done.
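As a hedged illustration of that kind of restructuring (the module name _base.py and the class names below are hypothetical, not what this PR does), the general idea is to move the shared base classes into a leaf module that imports nothing from the rest of the package:

```python
# Hypothetical layout, for illustration only:
#
#   xgboost/spark/
#       _base.py      <- shared base classes / type variables; imports nothing
#                        from the rest of xgboost.spark
#       core.py       <- imports from _base.py
#       model.py      <- imports from _base.py (not from core.py or __init__.py)
#       estimator.py  <- imports from core.py
#       __init__.py   <- re-exports the public names from estimator.py
#
# _base.py (sketch):
from typing import TypeVar

from pyspark.ml import Estimator, Model


class _SparkXGBBaseEstimator(Estimator):
    """Shared parent for the Spark estimators; safe to import from anywhere."""


class _SparkXGBBaseModel(Model):
    """Shared parent for the fitted Spark models."""


# Because every import points "down" to _base.py, annotating model.py no
# longer pulls in the package __init__, and the cycle disappears.
SparkXgbEstimatorT = TypeVar("SparkXgbEstimatorT", bound=_SparkXGBBaseEstimator)
SparkXgbModelT = TypeVar("SparkXgbModelT", bound=_SparkXGBBaseModel)
```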
from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, MLReader, MLWriter
from pyspark.sql import SparkSession

from xgboost.core import Booster

from .utils import get_class_name, get_logger


def _get_or_create_tmp_dir():

if TYPE_CHECKING:
Why type checking only? These are hard dependencies of the spark module; you can import them freely. The spark module is not part of the package import:
from .training import cv, train
If I put the imports outside of TYPE_CHECKING, it runs into a circular import issue:
../../../anaconda3/envs/xgboost-dev/lib/python3.9/importlib/__init__.py:127: in import_module
return _bootstrap._gcd_import(name[level:], package, level)
tests/test_distributed/test_with_spark/test_spark_local_cluster.py:23: in <module>
from xgboost.spark import SparkXGBClassifier, SparkXGBRegressor
python-package/xgboost/spark/__init__.py:8: in <module>
from .estimator import (
python-package/xgboost/spark/estimator.py:9: in <module>
from .core import ( # type: ignore
python-package/xgboost/spark/core.py:52: in <module>
from .model import (
python-package/xgboost/spark/model.py:28: in <module>
from . import (
E ImportError: cannot import name 'SparkXGBClassifier' from partially initialized module 'xgboost.spark' (most likely due to a circular import) (/home/bobwang/work.d/ml/xgboost/python-package/xgboost/spark/__init__.py)
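For reference, this is roughly what the TYPE_CHECKING pattern buys: the concrete classes are imported only for the type checker, and the annotation is written as a string, so nothing in the cycle is resolved at runtime. The helper below is a minimal sketch, not code from the PR:

```python
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Only evaluated by mypy/static analysis, never at runtime, so the
    # circular import shown above is never triggered.
    from xgboost.spark import SparkXGBClassifier, SparkXGBRanker, SparkXGBRegressor


def _describe(est: "Union[SparkXGBClassifier, SparkXGBRanker, SparkXGBRegressor]") -> str:
    # The annotation stays a string at runtime; only the object itself is used.
    return type(est).__name__
```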
@trivialfis any thoughts on that?
Please ignore it. It's the mighty internet error.
Currently, pyspark runs into a circular import issue when typing is enabled for model.py, so this PR refactors pyspark a little bit to avoid it. I will create a follow-up PR to support typing for core.py.
This PR adds typing for model.py of xgboost.spark; the follow-up PR will add typing for core.py of xgboost.spark.