-
Notifications
You must be signed in to change notification settings - Fork 3.9k
/
Copy pathcompat.py
171 lines (141 loc) · 4.7 KB
/
compat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# coding: utf-8
"""Compatibility library."""
from __future__ import absolute_import
import inspect
import sys
import numpy as np
# Compatibility between Python 2 and Python 3.
# ``>=`` (not ``==``) keeps any later major version on the Python 3 branch;
# the Python 2 branch relies on names (``izip``, ``basestring``, ``long``,
# ``xrange``) that only exist on Python 2.
is_py3 = (sys.version_info[0] >= 3)

if is_py3:
    zip_ = zip
    string_type = str
    numeric_types = (int, float, bool)
    integer_types = (int, )
    range_ = range

    def argc_(func):
        """Count the number of arguments of a function.

        Parameters
        ----------
        func : callable
            Object accepted by ``inspect.signature``.

        Returns
        -------
        int
            Number of entries in ``inspect.signature(func).parameters``.
        """
        return len(inspect.signature(func).parameters)

    def decode_string(bytestring):
        """Decode C bytestring to ordinary string.

        Parameters
        ----------
        bytestring : bytes
            UTF-8 encoded data.

        Returns
        -------
        str
            The decoded text.
        """
        return bytestring.decode('utf-8')
else:
    from itertools import izip as zip_
    string_type = basestring
    numeric_types = (int, long, float, bool)
    integer_types = (int, long)
    range_ = xrange

    def argc_(func):
        """Count the number of arguments of a function.

        Parameters
        ----------
        func : callable
            Object accepted by ``inspect.getargspec``.

        Returns
        -------
        int
            Length of ``inspect.getargspec(func).args``.
        """
        return len(inspect.getargspec(func).args)

    def decode_string(bytestring):
        """Decode C bytestring to ordinary string.

        On this branch the value is returned unchanged.
        """
        return bytestring
"""json"""
try:
    import simplejson as json
except (ImportError, SyntaxError):
    # Importing simplejson on Python 3.2 raises SyntaxError (it uses u'...'
    # Unicode literals), so that case falls back to the standard library
    # module just like a missing package does.
    import json
def json_default_with_numpy(obj):
    """Convert numpy classes to JSON serializable objects.

    Parameters
    ----------
    obj : object
        Value to convert.

    Returns
    -------
    object
        ``obj.tolist()`` for a numpy array, ``obj.item()`` for a numpy
        integer, floating or boolean scalar, and ``obj`` itself for
        anything else.
    """
    numpy_scalar_types = (np.integer, np.floating, np.bool_)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, numpy_scalar_types):
        return obj.item()
    # Anything that is not a numpy array/scalar is handed back untouched.
    return obj
"""pandas"""
try:
    from pandas import Series, DataFrame
except ImportError:
    PANDAS_INSTALLED = False

    class Series(object):
        """Dummy class for pandas.Series."""

        pass

    class DataFrame(object):
        """Dummy class for pandas.DataFrame."""

        pass

    is_dtype_sparse = None
else:
    PANDAS_INSTALLED = True
    # The sparse-dtype helper is imported separately: if it is unavailable,
    # pandas must still be reported as installed and the real Series/DataFrame
    # classes must stay in place instead of being replaced by the dummies.
    try:
        from pandas.api.types import is_sparse as is_dtype_sparse
    except ImportError:
        def is_dtype_sparse(arr):
            """Check whether an array-like or dtype is of a sparse dtype.

            Fallback used when ``pandas.api.types.is_sparse`` cannot be imported.

            Parameters
            ----------
            arr : array-like or dtype
                Object to inspect; its ``dtype`` attribute is used when present.

            Returns
            -------
            bool
                True if the dtype is an instance of ``pandas.SparseDtype``.
            """
            try:
                from pandas import SparseDtype
            except ImportError:
                # NOTE(review): pandas without SparseDtype cannot be checked
                # this way; reporting "not sparse" is a best effort - confirm
                # this is acceptable for the oldest supported pandas.
                return False
            return isinstance(getattr(arr, 'dtype', arr), SparseDtype)
"""matplotlib"""
# Record whether matplotlib can be imported.
try:
    import matplotlib
except ImportError:
    MATPLOTLIB_INSTALLED = False
else:
    MATPLOTLIB_INSTALLED = True
"""graphviz"""
# Record whether the graphviz Python package can be imported.
try:
    import graphviz
except ImportError:
    GRAPHVIZ_INSTALLED = False
else:
    GRAPHVIZ_INSTALLED = True
"""datatable"""
try:
    import datatable
    # Use ``Frame`` when the module has it, otherwise the ``DataTable`` name.
    DataTable = datatable.Frame if hasattr(datatable, "Frame") else datatable.DataTable
    DATATABLE_INSTALLED = True
except ImportError:
    DATATABLE_INSTALLED = False

    class DataTable(object):
        """Dummy class for DataTable."""
"""sklearn"""
try:
    from sklearn.base import BaseEstimator
    from sklearn.base import RegressorMixin, ClassifierMixin
    from sklearn.preprocessing import LabelEncoder
    from sklearn.utils.class_weight import compute_sample_weight
    from sklearn.utils.multiclass import check_classification_targets
    from sklearn.utils.validation import assert_all_finite, check_X_y, check_array

    # Splitters and NotFittedError: try the newer module layout first.
    try:
        from sklearn.model_selection import StratifiedKFold, GroupKFold
        from sklearn.exceptions import NotFittedError
    except ImportError:
        from sklearn.cross_validation import StratifiedKFold, GroupKFold
        from sklearn.utils.validation import NotFittedError

    try:
        from sklearn.utils.validation import _check_sample_weight
    except ImportError:
        from sklearn.utils.validation import check_consistent_length

        def _check_sample_weight(sample_weight, X, dtype=None):
            """Stand-in for scikit-learn versions without ``_check_sample_weight``.

            Only verifies that ``sample_weight`` and ``X`` have consistent
            lengths, then returns ``sample_weight`` unchanged; ``dtype`` is
            accepted but not used.
            """
            check_consistent_length(sample_weight, X)
            return sample_weight
except ImportError:
    SKLEARN_INSTALLED = False
    # Plain ``object`` keeps class definitions that inherit from these bases valid.
    _LGBMModelBase = _LGBMRegressorBase = _LGBMClassifierBase = object
    LGBMNotFittedError = ValueError
    # Helpers that are unavailable without scikit-learn.
    _LGBMLabelEncoder = None
    _LGBMStratifiedKFold = None
    _LGBMGroupKFold = None
    _LGBMCheckXY = None
    _LGBMCheckArray = None
    _LGBMCheckSampleWeight = None
    _LGBMAssertAllFinite = None
    _LGBMCheckClassificationTargets = None
    _LGBMComputeSampleWeight = None
else:
    SKLEARN_INSTALLED = True
    # Base classes.
    _LGBMModelBase = BaseEstimator
    _LGBMRegressorBase = RegressorMixin
    _LGBMClassifierBase = ClassifierMixin
    LGBMNotFittedError = NotFittedError
    # Preprocessing and cross-validation helpers.
    _LGBMLabelEncoder = LabelEncoder
    _LGBMStratifiedKFold = StratifiedKFold
    _LGBMGroupKFold = GroupKFold
    # Validation helpers.
    _LGBMCheckXY = check_X_y
    _LGBMCheckArray = check_array
    _LGBMCheckSampleWeight = _check_sample_weight
    _LGBMAssertAllFinite = assert_all_finite
    _LGBMCheckClassificationTargets = check_classification_targets
    _LGBMComputeSampleWeight = compute_sample_weight
# DeprecationWarning is not shown by default, so a dedicated warning class
# derived from UserWarning is used instead.
class LGBMDeprecationWarning(UserWarning):
    """Custom deprecation warning."""