Skip to content

Commit

Permalink
Cleanup pandas support
Browse files Browse the repository at this point in the history
  • Loading branch information
sinhrks committed Nov 12, 2015
1 parent 4fb6153 commit 25c4fbd
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 57 deletions.
2 changes: 2 additions & 0 deletions python-package/conv_rst.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# pylint: disable=invalid-name, exec-used
"""Convert README.md to README.rst for PyPI"""

from pypandoc import convert

# Render the Markdown README as reStructuredText, then write the result
# next to it so the package has a README in the format PyPI expects.
rst_text = convert('python-package/README.md', 'rst')
with open('python-package/README.rst', 'w') as out_file:
    out_file.write(rst_text)
47 changes: 47 additions & 0 deletions python-package/xgboost/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# coding: utf-8
# pylint: disable=unused-import, invalid-name
"""For compatibility"""

from __future__ import absolute_import

import sys


# True when the running interpreter is Python 3.
PY3 = sys.version_info[0] == 3

# Tuple of the base string type(s) for the running interpreter, usable
# directly as the second argument of isinstance().
if not PY3:
    # pylint: disable=invalid-name
    STRING_TYPES = (basestring,)
else:
    # pylint: disable=invalid-name, redefined-builtin
    STRING_TYPES = (str,)

# pandas is an optional dependency: when it cannot be imported, a placeholder
# class takes the DataFrame name so the name can still be imported and used
# in isinstance() checks.
try:
    from pandas import DataFrame
except ImportError:
    PANDAS_INSTALLED = False

    class DataFrame(object):
        """ dummy for pandas.DataFrame """
else:
    PANDAS_INSTALLED = True

# sklearn is an optional dependency: expose its base classes when it is
# available, otherwise substitute plain ``object`` so that classes deriving
# from these names can still be defined.
try:
    from sklearn.base import BaseEstimator
    from sklearn.base import RegressorMixin, ClassifierMixin
    from sklearn.preprocessing import LabelEncoder
    SKLEARN_INSTALLED = True

    XGBModelBase = BaseEstimator
    XGBRegressorBase = RegressorMixin
    XGBClassifierBase = ClassifierMixin
except ImportError:
    SKLEARN_INSTALLED = False

    # used for compatibility without sklearn
    XGBModelBase = object
    XGBClassifierBase = object
    XGBRegressorBase = object
    # sklearn.py imports LabelEncoder from this module unconditionally, so the
    # name has to exist here as well; without it that import raises
    # ImportError whenever sklearn is missing.
    LabelEncoder = object
75 changes: 35 additions & 40 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from __future__ import absolute_import

import os
import sys
import ctypes
import collections

Expand All @@ -13,20 +12,12 @@

from .libpath import find_lib_path

from .compat import STRING_TYPES, PY3, DataFrame

class XGBoostError(Exception):
    """Error thrown by xgboost trainer."""

PY3 = (sys.version_info[0] == 3)

if PY3:
# pylint: disable=invalid-name, redefined-builtin
STRING_TYPES = str,
else:
# pylint: disable=invalid-name
STRING_TYPES = basestring,


def from_pystr_to_cstr(data):
"""Convert a list of Python str to C pointer
Expand Down Expand Up @@ -138,42 +129,49 @@ def c_array(ctype, values):
return (ctype * len(values))(*values)


def _maybe_from_pandas(data, label, feature_names, feature_types):
""" Extract internal data from pd.DataFrame """
try:
import pandas as pd
except ImportError:
return data, label, feature_names, feature_types

if not isinstance(data, pd.DataFrame):
return data, label, feature_names, feature_types
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
'float16': 'float', 'float32': 'float', 'float64': 'float',
'bool': 'i'}


mapper = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
'float16': 'float', 'float32': 'float', 'float64': 'float',
'bool': 'i'}
def _maybe_pandas_data(data, feature_names, feature_types):
""" Extract internal data from pd.DataFrame for DMatrix data """

if not isinstance(data, DataFrame):
return data, feature_names, feature_types

data_dtypes = data.dtypes
if not all(dtype.name in (mapper.keys()) for dtype in data_dtypes):
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes):
raise ValueError('DataFrame.dtypes for data must be int, float or bool')

if label is not None:
if isinstance(label, pd.DataFrame):
label_dtypes = label.dtypes
if not all(dtype.name in (mapper.keys()) for dtype in label_dtypes):
raise ValueError('DataFrame.dtypes for label must be int, float or bool')
else:
label = label.values.astype('float')

if feature_names is None:
feature_names = data.columns.format()

if feature_types is None:
feature_types = [mapper[dtype.name] for dtype in data_dtypes]
feature_types = [PANDAS_DTYPE_MAPPER[dtype.name] for dtype in data_dtypes]

data = data.values.astype('float')

return data, label, feature_names, feature_types
return data, feature_names, feature_types


def _maybe_pandas_label(label):
""" Extract internal data from pd.DataFrame for DMatrix label """

if isinstance(label, DataFrame):
if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')

label_dtypes = label.dtypes
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in label_dtypes):
raise ValueError('DataFrame.dtypes for label must be int, float or bool')
else:
label = label.values.astype('float')
# pd.Series can be passed to xgb as it is

return label

class DMatrix(object):
"""Data Matrix used in XGBoost.
Expand Down Expand Up @@ -216,13 +214,10 @@ def __init__(self, data, label=None, missing=0.0,
self.handle = None
return

klass = getattr(getattr(data, '__class__', None), '__name__', None)
if klass == 'DataFrame':
# check the class name first to avoid an unnecessary pandas import
data, label, feature_names, feature_types = _maybe_from_pandas(data,
label,
feature_names,
feature_types)
data, feature_names, feature_types = _maybe_pandas_data(data,
feature_names,
feature_types)
label = _maybe_pandas_label(label)

if isinstance(data, STRING_TYPES):
self.handle = ctypes.c_void_p()
Expand Down
20 changes: 3 additions & 17 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,9 @@
from .core import Booster, DMatrix, XGBoostError
from .training import train

try:
from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin, ClassifierMixin
from sklearn.preprocessing import LabelEncoder
SKLEARN_INSTALLED = True
except ImportError:
SKLEARN_INSTALLED = False

# used for compatibility without sklearn
XGBModelBase = object
XGBClassifierBase = object
XGBRegressorBase = object

if SKLEARN_INSTALLED:
XGBModelBase = BaseEstimator
XGBRegressorBase = RegressorMixin
XGBClassifierBase = ClassifierMixin
from .compat import (SKLEARN_INSTALLED, XGBModelBase,
XGBClassifierBase, XGBRegressorBase, LabelEncoder)


class XGBModel(XGBModelBase):
# pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name
Expand Down
45 changes: 45 additions & 0 deletions tests/python/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,51 @@ def test_pandas(self):
assert dm.num_row() == 2
assert dm.num_col() == 3

df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
dummies = pd.get_dummies(df)
# B A_X A_Y A_Z
# 0 1 1 0 0
# 1 2 0 1 0
# 2 3 0 0 1
result, _, _ = xgb.core._maybe_pandas_data(dummies, None, None)
exp = np.array([[ 1., 1., 0., 0.],
[ 2., 0., 1., 0.],
[ 3., 0., 0., 1.]])
np.testing.assert_array_equal(result, exp)

dm = xgb.DMatrix(dummies)
assert dm.feature_names == ['B', 'A_X', 'A_Y', 'A_Z']
assert dm.feature_types == ['int', 'float', 'float', 'float']
assert dm.num_row() == 3
assert dm.num_col() == 4

df = pd.DataFrame({'A=1': [1, 2, 3], 'A=2': [4, 5, 6]})
dm = xgb.DMatrix(df)
assert dm.feature_names == ['A=1', 'A=2']
assert dm.feature_types == ['int', 'int']
assert dm.num_row() == 3
assert dm.num_col() == 2

def test_pandas_label(self):
    """DataFrame labels: invalid frames are rejected, a valid one is converted."""
    import pandas as pd

    to_label = xgb.core._maybe_pandas_label

    # label must be a single column
    multi_col = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
    self.assertRaises(ValueError, to_label, multi_col)

    # label must be supported dtype
    object_col = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)})
    self.assertRaises(ValueError, to_label, object_col)

    # a single integer column is returned as an (n, 1) float array
    int_col = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
    expected = np.array([[1.], [2.], [3.]], dtype=float)
    np.testing.assert_array_equal(to_label(int_col), expected)

    # the same frame can be given to DMatrix as its label
    dm = xgb.DMatrix(np.random.randn(3, 2), label=int_col)
    assert dm.num_row() == 3
    assert dm.num_col() == 2


def test_load_file_invalid(self):

self.assertRaises(ValueError, xgb.Booster,
Expand Down

0 comments on commit 25c4fbd

Please sign in to comment.