diff --git a/pandas/_libs/hashtable.pyx b/pandas/_libs/hashtable.pyx index eee287b2c157b3..a4e5bee9a8746a 100644 --- a/pandas/_libs/hashtable.pyx +++ b/pandas/_libs/hashtable.pyx @@ -39,9 +39,6 @@ PyDateTime_IMPORT cdef extern from "Python.h": int PySlice_Check(object) -cdef size_t _INIT_VEC_CAP = 128 - - include "hashtable_class_helper.pxi" include "hashtable_func_helper.pxi" diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx index b4724bc3dd59b7..d1a2a418469760 100644 --- a/pandas/_libs/lib.pyx +++ b/pandas/_libs/lib.pyx @@ -314,6 +314,7 @@ cpdef bint isscalar(object val): - instances of datetime.timedelta - Period - instances of decimal.Decimal + - Interval """ @@ -327,7 +328,8 @@ cpdef bint isscalar(object val): or PyDelta_Check(val) or PyTime_Check(val) or util.is_period_object(val) - or is_decimal(val)) + or is_decimal(val), + or is_interval(val)) def item_from_zerodim(object val): @@ -1965,4 +1967,6 @@ cdef class BlockPlacement: include "reduce.pyx" include "properties.pyx" +include "interval.pyx" +include "intervaltree.pyx" include "inference.pyx" diff --git a/pandas/_libs/src/inference.pyx b/pandas/_libs/src/inference.pyx index 933fc8fb1cc9bb..858a7a29ad8685 100644 --- a/pandas/_libs/src/inference.pyx +++ b/pandas/_libs/src/inference.pyx @@ -347,6 +347,10 @@ def infer_dtype(object _values): if is_period_array(values): return 'period' + elif is_interval(val): + if is_interval_array_fixed_closed(values): + return 'interval' + for i in range(n): val = util.get_value_1d(values, i) if (util.is_integer_object(val) and @@ -742,6 +746,23 @@ cpdef bint is_period_array(ndarray[object] values): return False return null_count != n +cdef inline bint is_interval(object o): + return isinstance(o, Interval) + +def is_interval_array_fixed_closed(ndarray[object] values): + cdef Py_ssize_t i, n = len(values) + cdef str closed + if n == 0: + return False + for i in range(n): + if not is_interval(values[i]): + return False + if i == 0: + closed = values[0].closed + elif closed != values[i].closed: + return False + return True + cdef extern from "parse_helper.h": inline int floatify(object, double *result, int *maybe_int) except -1 diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 6937675603c109..a9b6be864c9d45 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -404,7 +404,6 @@ def value_counts(values, sort=True, ascending=False, normalize=False, cat, bins = cut(values, bins, retbins=True) except TypeError: raise TypeError("bins argument only works with numeric data.") - values = cat.codes if is_extension_type(values) and not is_datetimetz(values): # handle Categorical and sparse, diff --git a/pandas/core/api.py b/pandas/core/api.py index 65253dedb8b539..dbb5e22358c189 100644 --- a/pandas/core/api.py +++ b/pandas/core/api.py @@ -12,6 +12,7 @@ from pandas.core.index import (Index, CategoricalIndex, Int64Index, UInt64Index, RangeIndex, Float64Index, MultiIndex) +from pandas.core.interval import Interval, IntervalIndex from pandas.core.series import Series from pandas.core.frame import DataFrame diff --git a/pandas/core/groupby.py b/pandas/core/groupby.py index 4095a14aa59701..7b26428d0daefe 100644 --- a/pandas/core/groupby.py +++ b/pandas/core/groupby.py @@ -43,8 +43,7 @@ from pandas.core.categorical import Categorical from pandas.core.frame import DataFrame from pandas.core.generic import NDFrame -from pandas.core.index import (Index, MultiIndex, CategoricalIndex, - _ensure_index) +from pandas.core.interval import IntervalIndex from pandas.core.internals import BlockManager, make_block from pandas.core.series import Series from pandas.core.panel import Panel @@ -3059,12 +3058,20 @@ def value_counts(self, normalize=False, sort=True, ascending=False, if bins is None: lab, lev = algorithms.factorize(val, sort=True) else: - cat, bins = cut(val, bins, retbins=True) + raise NotImplementedError('this is broken') + lab, bins = cut(val, bins, retbins=True) # bins[:-1] for backward compat; # o.w. cat.categories could be better - lab, lev, dropna = cat.codes, bins[:-1], False - - sorter = np.lexsort((lab, ids)) + # cat = Categorical(cat) + # lab, lev, dropna = cat.codes, bins[:-1], False + + if (lab.dtype == object + and lib.is_interval_array_fixed_closed(lab[notnull(lab)])): + lab_index = Index(lab) + assert isinstance(lab, IntervalIndex) + sorter = np.lexsort((lab_index.left, lab_index.right, ids)) + else: + sorter = np.lexsort((lab, ids)) ids, lab = ids[sorter], lab[sorter] # group boundaries are where group ids change @@ -3105,12 +3112,13 @@ def value_counts(self, normalize=False, sort=True, ascending=False, acc = rep(d) out /= acc - if sort and bins is None: + if sort: # and bins is None: cat = ids[inc][mask] if dropna else ids[inc] sorter = np.lexsort((out if ascending else -out, cat)) out, labels[-1] = out[sorter], labels[-1][sorter] - if bins is None: + # if bins is None: + if True: mi = MultiIndex(levels=levels, labels=labels, names=names, verify_integrity=False) diff --git a/pandas/core/interval.py b/pandas/core/interval.py new file mode 100644 index 00000000000000..68e07f21367a02 --- /dev/null +++ b/pandas/core/interval.py @@ -0,0 +1,521 @@ +import operator + +import numpy as np +import pandas as pd + +from pandas.core.base import PandasObject, IndexOpsMixin +from pandas.core.common import (_values_from_object, _ensure_platform_int, + notnull, is_datetime_or_timedelta_dtype, + is_integer_dtype, is_float_dtype) +from pandas.core.index import (Index, _ensure_index, default_pprint, + InvalidIndexError, MultiIndex) +from pandas.lib import (Interval, IntervalMixin, IntervalTree, + interval_bounds_to_intervals, + intervals_to_interval_bounds) +from pandas.util.decorators import cache_readonly +import pandas.core.common as com + + +_VALID_CLOSED = set(['left', 'right', 'both', 'neither']) + + +def _get_next_label(label): + dtype = getattr(label, 'dtype', type(label)) + if isinstance(label, (pd.Timestamp, pd.Timedelta)): + dtype = 'datetime64' + if is_datetime_or_timedelta_dtype(dtype): + return label + np.timedelta64(1, 'ns') + elif is_integer_dtype(dtype): + return label + 1 + elif is_float_dtype(dtype): + return np.nextafter(label, np.infty) + else: + raise TypeError('cannot determine next label for type %r' + % type(label)) + + +def _get_prev_label(label): + dtype = getattr(label, 'dtype', type(label)) + if isinstance(label, (pd.Timestamp, pd.Timedelta)): + dtype = 'datetime64' + if is_datetime_or_timedelta_dtype(dtype): + return label - np.timedelta64(1, 'ns') + elif is_integer_dtype(dtype): + return label - 1 + elif is_float_dtype(dtype): + return np.nextafter(label, -np.infty) + else: + raise TypeError('cannot determine next label for type %r' + % type(label)) + + +def _get_interval_closed_bounds(interval): + """ + Given an Interval or IntervalIndex, return the corresponding interval with + closed bounds. + """ + left, right = interval.left, interval.right + if interval.open_left: + left = _get_next_label(left) + if interval.open_right: + right = _get_prev_label(right) + return left, right + + +class IntervalIndex(IntervalMixin, Index): + """ + Immutable Index implementing an ordered, sliceable set. IntervalIndex + represents an Index of intervals that are all closed on the same side. + + .. versionadded:: 0.18 + + Properties + ---------- + left, right : array-like (1-dimensional) + Left and right bounds for each interval. + closed : {'left', 'right', 'both', 'neither'}, optional + Whether the intervals are closed on the left-side, right-side, both or + neither. Defaults to 'right'. + name : object, optional + Name to be stored in the index. + """ + _typ = 'intervalindex' + _comparables = ['name'] + _attributes = ['name', 'closed'] + _allow_index_ops = True + _engine = None # disable it + + def __new__(cls, left, right, closed='right', name=None, fastpath=False): + # TODO: validation + result = IntervalMixin.__new__(cls) + result._left = _ensure_index(left) + result._right = _ensure_index(right) + result._closed = closed + result.name = name + if not fastpath: + result._validate() + result._reset_identity() + return result + + def _validate(self): + """Verify that the IntervalIndex is valid. + """ + # TODO: exclude periods? + if self.closed not in _VALID_CLOSED: + raise ValueError("invalid options for 'closed': %s" % self.closed) + if len(self.left) != len(self.right): + raise ValueError('left and right must have the same length') + left_valid = notnull(self.left) + right_valid = notnull(self.right) + if not (left_valid == right_valid).all(): + raise ValueError('missing values must be missing in the same ' + 'location both left and right sides') + if not (self.left[left_valid] <= self.right[left_valid]).all(): + raise ValueError('left side of interval must be <= right side') + + def _simple_new(cls, values, name=None, **kwargs): + # ensure we don't end up here (this is a superclass method) + raise NotImplementedError + + def _cleanup(self): + pass + + @property + def _engine(self): + raise NotImplementedError + + @cache_readonly + def _tree(self): + return IntervalTree(self.left, self.right, closed=self.closed) + + @property + def _constructor(self): + return type(self).from_intervals + + @classmethod + def from_breaks(cls, breaks, closed='right', name=None): + """ + Construct an IntervalIndex from an array of splits + + Parameters + ---------- + breaks : array-like (1-dimensional) + Left and right bounds for each interval. + closed : {'left', 'right', 'both', 'neither'}, optional + Whether the intervals are closed on the left-side, right-side, both + or neither. Defaults to 'right'. + name : object, optional + Name to be stored in the index. + + Examples + -------- + + >>> IntervalIndex.from_breaks([0, 1, 2, 3]) + IntervalIndex(left=[0, 1, 2], + right=[1, 2, 3], + closed='right') + """ + return cls(breaks[:-1], breaks[1:], closed, name) + + @classmethod + def from_intervals(cls, data, name=None): + """ + Construct an IntervalIndex from a 1d array of Interval objects + + Parameters + ---------- + data : array-like (1-dimensional) + Array of Interval objects. All intervals must be closed on the same + sides. + name : object, optional + Name to be stored in the index. + + Examples + -------- + + >>> IntervalIndex.from_intervals([Interval(0, 1), Interval(1, 2)]) + IntervalIndex(left=[0, 1], + right=[1, 2], + closed='right') + + The generic Index constructor work identically when it infers an array + of all intervals: + + >>> Index([Interval(0, 1), Interval(1, 2)]) + IntervalIndex(left=[0, 1], + right=[1, 2], + closed='right') + """ + data = np.asarray(data) + left, right, closed = intervals_to_interval_bounds(data) + return cls(left, right, closed, name) + + @classmethod + def from_tuples(cls, data, closed='right', name=None): + left = [] + right = [] + for l, r in data: + left.append(l) + right.append(r) + return cls(np.array(left), np.array(right), closed, name) + + def to_tuples(self): + return Index(com._asarray_tuplesafe(zip(self.left, self.right))) + + @cache_readonly + def _multiindex(self): + return MultiIndex.from_arrays([self.left, self.right], + names=['left', 'right']) + + @property + def left(self): + return self._left + + @property + def right(self): + return self._right + + @property + def closed(self): + return self._closed + + def __len__(self): + return len(self.left) + + @cache_readonly + def values(self): + """Returns the IntervalIndex's data as a numpy array of Interval + objects (with dtype='object') + """ + left = np.asarray(self.left) + right = np.asarray(self.right) + return interval_bounds_to_intervals(left, right, self.closed) + + def __array__(self, result=None): + """ the array interface, return my values """ + return self.values + + def __array_wrap__(self, result, context=None): + # we don't want the superclass implementation + return result + + def _array_values(self): + return self.values + + def __reduce__(self): + return self.__class__, (self.left, self.right, self.closed, self.name) + + def _shallow_copy(self, values=None, name=None): + name = name if name is not None else self.name + if values is not None: + return type(self).from_intervals(values, name=name) + else: + return self.copy(name=name) + + def copy(self, deep=False, name=None): + left = self.left.copy(deep=True) if deep else self.left + right = self.right.copy(deep=True) if deep else self.right + name = name if name is not None else self.name + return type(self)(left, right, closed=self.closed, name=name, + fastpath=True) + + @cache_readonly + def dtype(self): + return np.dtype('O') + + @cache_readonly + def mid(self): + """Returns the mid-point of each interval in the index as an array + """ + try: + return Index(0.5 * (self.left.values + self.right.values)) + except TypeError: + # datetime safe version + delta = self.right.values - self.left.values + return Index(self.left.values + 0.5 * delta) + + @cache_readonly + def is_monotonic_increasing(self): + return self._multiindex.is_monotonic_increasing + + @cache_readonly + def is_monotonic_decreasing(self): + return self._multiindex.is_monotonic_decreasing + + @cache_readonly + def is_unique(self): + return self._multiindex.is_unique + + @cache_readonly + def is_non_overlapping_monotonic(self): + # must be increasing (e.g., [0, 1), [1, 2), [2, 3), ... ) + # or decreasing (e.g., [-1, 0), [-2, -1), [-3, -2), ...) + # we already require left <= right + return ((self.right[:-1] <= self.left[1:]).all() or + (self.left[:-1] >= self.right[1:]).all()) + + def _convert_scalar_indexer(self, key, kind=None): + return key + + def _maybe_cast_slice_bound(self, label, side, kind): + return getattr(self, side)._maybe_cast_slice_bound(label, side, kind) + + def _convert_list_indexer(self, keyarr, kind=None): + """ + we are passed a list indexer. + Return our indexer or raise if all of the values are not included in the categories + """ + locs = self.get_indexer(keyarr) + # TODO: handle keyarr if it includes intervals + if (locs == -1).any(): + raise KeyError("a list-indexer must only include existing intervals") + + return locs + + def _check_method(self, method): + if method is not None: + raise NotImplementedError( + 'method %r not yet implemented for IntervalIndex' % method) + + def _searchsorted_monotonic(self, label, side, exclude_label=False): + if not self.is_non_overlapping_monotonic: + raise KeyError('can only get slices from an IntervalIndex if ' + 'bounds are non-overlapping and all monotonic ' + 'increasing or decreasing') + + if isinstance(label, IntervalMixin): + raise NotImplementedError + + if ((side == 'left' and self.left.is_monotonic_increasing) or + (side == 'right' and self.left.is_monotonic_decreasing)): + sub_idx = self.right + if self.open_right or exclude_label: + label = _get_next_label(label) + else: + sub_idx = self.left + if self.open_left or exclude_label: + label = _get_prev_label(label) + + return sub_idx._searchsorted_monotonic(label, side) + + def _get_loc_only_exact_matches(self, key): + return self._multiindex._tuple_index.get_loc(key) + + def _find_non_overlapping_monotonic_bounds(self, key): + if isinstance(key, IntervalMixin): + start = self._searchsorted_monotonic( + key.left, 'left', exclude_label=key.open_left) + stop = self._searchsorted_monotonic( + key.right, 'right', exclude_label=key.open_right) + else: + # scalar + start = self._searchsorted_monotonic(key, 'left') + stop = self._searchsorted_monotonic(key, 'right') + return start, stop + + def get_loc(self, key, method=None): + self._check_method(method) + + original_key = key + + if self.is_non_overlapping_monotonic: + if isinstance(key, Interval): + left = self._maybe_cast_slice_bound(key.left, 'left', None) + right = self._maybe_cast_slice_bound(key.right, 'right', None) + key = Interval(left, right, key.closed) + else: + key = self._maybe_cast_slice_bound(key, 'left', None) + + start, stop = self._find_non_overlapping_monotonic_bounds(key) + + if start + 1 == stop: + return start + elif start < stop: + return slice(start, stop) + else: + raise KeyError(original_key) + + else: + # use the interval tree + if isinstance(key, Interval): + left, right = _get_interval_closed_bounds(key) + return self._tree.get_loc_interval(left, right) + else: + return self._tree.get_loc(key) + + def get_value(self, series, key): + # this method seems necessary for Series.__getitem__ but I have no idea + # what it should actually do here... + loc = self.get_loc(key) # nb. this can't handle slice objects + return series.iloc[loc] + + def get_indexer(self, target, method=None, limit=None, tolerance=None): + self._check_method(method) + target = _ensure_index(target) + + if self.is_non_overlapping_monotonic: + start, stop = self._find_non_overlapping_monotonic_bounds(target) + + start_plus_one = start + 1 + if (start_plus_one < stop).any(): + raise ValueError('indexer corresponds to non-unique elements') + return np.where(start_plus_one == stop, start, -1) + + else: + if isinstance(target, IntervalIndex): + raise NotImplementedError( + 'have not yet implemented get_indexer ' + 'for IntervalIndex indexers') + else: + return self._tree.get_indexer(target) + + def delete(self, loc): + new_left = self.left.delete(loc) + new_right = self.right.delete(loc) + return type(self)(new_left, new_right, self.closed, self.name, + fastpath=True) + + def insert(self, loc, item): + if not isinstance(item, Interval): + raise ValueError('can only insert Interval objects into an ' + 'IntervalIndex') + if not item.closed == self.closed: + raise ValueError('inserted item must be closed on the same side ' + 'as the index') + new_left = self.left.insert(loc, item.left) + new_right = self.right.insert(loc, item.right) + return type(self)(new_left, new_right, self.closed, self.name, + fastpath=True) + + def _as_like_interval_index(self, other, error_msg): + self._assert_can_do_setop(other) + other = _ensure_index(other) + if (not isinstance(other, IntervalIndex) or + self.closed != other.closed): + raise ValueError(error_msg) + return other + + def append(self, other): + msg = ('can only append two IntervalIndex objects that are closed on ' + 'the same side') + other = self._as_like_interval_index(other, msg) + new_left = self.left.append(other.left) + new_right = self.right.append(other.right) + if other.name is not None and other.name != self.name: + name = None + else: + name = self.name + return type(self)(new_left, new_right, self.closed, name, + fastpath=True) + + def take(self, indexer, axis=0): + indexer = com._ensure_platform_int(indexer) + new_left = self.left.take(indexer) + new_right = self.right.take(indexer) + return type(self)(new_left, new_right, self.closed, self.name, + fastpath=True) + + def __contains__(self, key): + try: + self.get_loc(key) + return True + except KeyError: + return False + + def __getitem__(self, value): + left = self.left[value] + right = self.right[value] + if not isinstance(left, Index): + return Interval(left, right, self.closed) + else: + return type(self)(left, right, self.closed, self.name) + + # __repr__ associated methods are based on MultiIndex + + def _format_attrs(self): + attrs = [('left', default_pprint(self.left)), + ('right', default_pprint(self.right)), + ('closed', repr(self.closed))] + if self.name is not None: + attrs.append(('name', default_pprint(self.name))) + return attrs + + def _format_space(self): + return "\n%s" % (' ' * (len(self.__class__.__name__) + 1)) + + def _format_data(self): + return None + + def argsort(self, *args, **kwargs): + return np.lexsort((self.right, self.left)) + + def equals(self, other): + if self.is_(other): + return True + try: + return (self.left.equals(other.left) + and self.right.equals(other.right) + and self.closed == other.closed) + except AttributeError: + return False + + def _setop(op_name): + def func(self, other): + msg = ('can only do set operations between two IntervalIndex ' + 'objects that are closed on the same side') + other = self._as_like_interval_index(other, msg) + result = getattr(self._multiindex, op_name)(other._multiindex) + result_name = self.name if self.name == other.name else None + return type(self).from_tuples(result.values, closed=self.closed, + name=result_name) + return func + + union = _setop('union') + intersection = _setop('intersection') + difference = _setop('difference') + sym_diff = _setop('sym_diff') + + # TODO: arithmetic operations + + +IntervalIndex._add_logical_methods_disabled() diff --git a/pandas/src/generate_intervaltree.py b/pandas/src/generate_intervaltree.py new file mode 100644 index 00000000000000..275a0d40e2433c --- /dev/null +++ b/pandas/src/generate_intervaltree.py @@ -0,0 +1,395 @@ +""" +This file generates `intervaltree.pyx` which is then included in `../lib.pyx` +during building. To regenerate `intervaltree.pyx`, just run: + + `python generate_intervaltree.py`. +""" +from __future__ import print_function +import os +from pandas.compat import StringIO +import numpy as np + + +warning_to_new_contributors = """ +# DO NOT EDIT THIS FILE: This file was autogenerated from +# generate_intervaltree.py, so please edit that file and then run +# `python2 generate_intervaltree.py` to re-generate this file. +""" + +header = r''' +from numpy cimport int64_t, float64_t +from numpy cimport ndarray, PyArray_ArgSort, NPY_QUICKSORT, PyArray_Take +import numpy as np + +cimport cython +cimport numpy as cnp +cnp.import_array() + +from hashtable cimport Int64Vector, Int64VectorData + + +ctypedef fused scalar64_t: + float64_t + int64_t + + +NODE_CLASSES = {} + + +cdef class IntervalTree(IntervalMixin): + """A centered interval tree + + Based off the algorithm described on Wikipedia: + http://en.wikipedia.org/wiki/Interval_tree + """ + cdef: + readonly object left, right, root + readonly str closed + object _left_sorter, _right_sorter + + def __init__(self, left, right, closed='right', leaf_size=100): + """ + Parameters + ---------- + left, right : np.ndarray[ndim=1] + Left and right bounds for each interval. Assumed to contain no + NaNs. + closed : {'left', 'right', 'both', 'neither'}, optional + Whether the intervals are closed on the left-side, right-side, both + or neither. Defaults to 'right'. + leaf_size : int, optional + Parameter that controls when the tree switches from creating nodes + to brute-force search. Tune this parameter to optimize query + performance. + """ + if closed not in ['left', 'right', 'both', 'neither']: + raise ValueError("invalid option for 'closed': %s" % closed) + + left = np.asarray(left) + right = np.asarray(right) + dtype = np.result_type(left, right) + self.left = np.asarray(left, dtype=dtype) + self.right = np.asarray(right, dtype=dtype) + + indices = np.arange(len(left), dtype='int64') + + self.closed = closed + + node_cls = NODE_CLASSES[str(dtype), closed] + self.root = node_cls(self.left, self.right, indices, leaf_size) + + @property + def left_sorter(self): + """How to sort the left labels; this is used for binary search + """ + if self._left_sorter is None: + self._left_sorter = np.argsort(self.left) + return self._left_sorter + + @property + def right_sorter(self): + """How to sort the right labels + """ + if self._right_sorter is None: + self._right_sorter = np.argsort(self.right) + return self._right_sorter + + def get_loc(self, scalar64_t key): + """Return all positions corresponding to intervals that overlap with + the given scalar key + """ + result = Int64Vector() + self.root.query(result, key) + if not result.data.n: + raise KeyError(key) + return result.to_array() + + def _get_partial_overlap(self, key_left, key_right, side): + """Return all positions corresponding to intervals with the given side + falling between the left and right bounds of an interval query + """ + if side == 'left': + values = self.left + sorter = self.left_sorter + else: + values = self.right + sorter = self.right_sorter + key = [key_left, key_right] + i, j = values.searchsorted(key, sorter=sorter) + return sorter[i:j] + + def get_loc_interval(self, key_left, key_right): + """Lookup the intervals enclosed in the given interval bounds + + The given interval is presumed to have closed bounds. + """ + import pandas as pd + left_overlap = self._get_partial_overlap(key_left, key_right, 'left') + right_overlap = self._get_partial_overlap(key_left, key_right, 'right') + enclosing = self.get_loc(0.5 * (key_left + key_right)) + combined = np.concatenate([left_overlap, right_overlap, enclosing]) + uniques = pd.unique(combined) + return uniques + + def get_indexer(self, scalar64_t[:] target): + """Return the positions corresponding to unique intervals that overlap + with the given array of scalar targets. + """ + # TODO: write get_indexer_intervals + cdef: + int64_t old_len, i + Int64Vector result + + result = Int64Vector() + old_len = 0 + for i in range(len(target)): + self.root.query(result, target[i]) + if result.data.n == old_len: + result.append(-1) + elif result.data.n > old_len + 1: + raise KeyError( + 'indexer does not intersect a unique set of intervals') + old_len = result.data.n + return result.to_array() + + def get_indexer_non_unique(self, scalar64_t[:] target): + """Return the positions corresponding to intervals that overlap with + the given array of scalar targets. Non-unique positions are repeated. + """ + cdef: + int64_t old_len, i + Int64Vector result, missing + + result = Int64Vector() + missing = Int64Vector() + old_len = 0 + for i in range(len(target)): + self.root.query(result, target[i]) + if result.data.n == old_len: + result.append(-1) + missing.append(i) + old_len = result.data.n + return result.to_array(), missing.to_array() + + def __repr__(self): + return ('' + % self.root.n_elements) + + +cdef take(ndarray source, ndarray indices): + """Take the given positions from a 1D ndarray + """ + return PyArray_Take(source, indices, 0) + + +cdef sort_values_and_indices(all_values, all_indices, subset): + indices = take(all_indices, subset) + values = take(all_values, subset) + sorter = PyArray_ArgSort(values, 0, NPY_QUICKSORT) + sorted_values = take(values, sorter) + sorted_indices = take(indices, sorter) + return sorted_values, sorted_indices +''' + +# we need specialized nodes and leaves to optimize for different dtype and +# closed values +# unfortunately, fused dtypes can't parameterize attributes on extension types, +# so we're stuck using template generation. + +node_template = r''' +cdef class {dtype_title}Closed{closed_title}IntervalNode: + """Non-terminal node for an IntervalTree + + Categorizes intervals by those that fall to the left, those that fall to + the right, and those that overlap with the pivot. + """ + cdef: + {dtype_title}Closed{closed_title}IntervalNode left_node, right_node + {dtype}_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices + {dtype}_t min_left, max_right + readonly {dtype}_t pivot + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node + + def __init__(self, + ndarray[{dtype}_t, ndim=1] left, + ndarray[{dtype}_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + int64_t leaf_size): + + self.n_elements = len(left) + self.leaf_size = leaf_size + if left.size > 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 + + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) + + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) + + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = len(self.center_left_indices) + + @cython.wraparound(False) + @cython.boundscheck(False) + cdef classify_intervals(self, {dtype}_t[:] left, {dtype}_t[:] right): + """Classify the given intervals based upon whether they fall to the + left, right, or overlap with this node's pivot. + """ + cdef: + Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(self.n_elements): + if right[i] {cmp_right_converse} self.pivot: + left_ind.append(i) + elif self.pivot {cmp_left_converse} left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[{dtype}_t, ndim=1] left, + ndarray[{dtype}_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + """ + left = take(left, subset) + right = take(right, subset) + indices = take(indices, subset) + return {dtype_title}Closed{closed_title}IntervalNode( + left, right, indices, self.leaf_size) + + @cython.wraparound(False) + @cython.boundscheck(False) + @cython.initializedcheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + """Recursively query this node and its sub-nodes for intervals that + overlap with the query point. + """ + cdef: + int64_t[:] indices + {dtype}_t[:] values + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] {cmp_left} point {cmp_right} self.right[i]: + result.append(self.indices[i]) + else: + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. + if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] {cmp_left} point: + break + result.append(indices[i]) + if point {cmp_right} self.left_node.max_right: + self.left_node.query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point {cmp_right} values[i]: + break + result.append(indices[i]) + if self.right_node.min_left {cmp_left} point: + self.right_node.query(result, point) + else: + result.extend(self.center_left_indices) + + def __repr__(self): + if self.is_leaf_node: + return ('<{dtype_title}Closed{closed_title}IntervalNode: ' + '%s elements (terminal)>' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('<{dtype_title}Closed{closed_title}IntervalNode: pivot %s, ' + '%s elements (%s left, %s right, %s overlapping)>' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) + + def counts(self): + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['{dtype}', '{closed}'] = {dtype_title}Closed{closed_title}IntervalNode +''' + + +def generate_node_template(): + output = StringIO() + for dtype in ['float64', 'int64']: + for closed, cmp_left, cmp_right in [ + ('left', '<=', '<'), + ('right', '<', '<='), + ('both', '<=', '<='), + ('neither', '<', '<')]: + cmp_left_converse = '<' if cmp_left == '<=' else '<=' + cmp_right_converse = '<' if cmp_right == '<=' else '<=' + classes = node_template.format(dtype=dtype, + dtype_title=dtype.title(), + closed=closed, + closed_title=closed.title(), + cmp_left=cmp_left, + cmp_right=cmp_right, + cmp_left_converse=cmp_left_converse, + cmp_right_converse=cmp_right_converse) + output.write(classes) + output.write("\n") + return output.getvalue() + + +def generate_cython_file(): + # Put `intervaltree.pyx` in the same directory as this file + directory = os.path.dirname(os.path.realpath(__file__)) + filename = 'intervaltree.pyx' + path = os.path.join(directory, filename) + + with open(path, 'w') as f: + print(warning_to_new_contributors, file=f) + print(header, file=f) + print(generate_node_template(), file=f) + + +if __name__ == '__main__': + generate_cython_file() diff --git a/pandas/src/interval.pyx b/pandas/src/interval.pyx new file mode 100644 index 00000000000000..495730e0fd6a1c --- /dev/null +++ b/pandas/src/interval.pyx @@ -0,0 +1,171 @@ +cimport numpy as np +import numpy as np +import pandas as pd + +cimport cython +import cython + +from cpython.object cimport (Py_EQ, Py_NE, Py_GT, Py_LT, Py_GE, Py_LE, + PyObject_RichCompare) + +import numbers +_VALID_CLOSED = frozenset(['left', 'right', 'both', 'neither']) + +cdef class IntervalMixin: + property closed_left: + def __get__(self): + return self.closed == 'left' or self.closed == 'both' + + property closed_right: + def __get__(self): + return self.closed == 'right' or self.closed == 'both' + + property open_left: + def __get__(self): + return not self.closed_left + + property open_right: + def __get__(self): + return not self.closed_right + + property mid: + def __get__(self): + try: + return 0.5 * (self.left + self.right) + except TypeError: + # datetime safe version + return self.left + 0.5 * (self.right - self.left) + + +cdef _interval_like(other): + return (hasattr(other, 'left') + and hasattr(other, 'right') + and hasattr(other, 'closed')) + + +cdef class Interval(IntervalMixin): + cdef readonly object left, right + cdef readonly str closed + + def __init__(self, left, right, str closed='right'): + # note: it is faster to just do these checks than to use a special + # constructor (__cinit__/__new__) to avoid them + if closed not in _VALID_CLOSED: + raise ValueError("invalid option for 'closed': %s" % closed) + if not left <= right: + raise ValueError('left side of interval must be <= right side') + self.left = left + self.right = right + self.closed = closed + + def __hash__(self): + return hash((self.left, self.right, self.closed)) + + def __contains__(self, key): + if _interval_like(key): + raise TypeError('__contains__ not defined for two intervals') + return ((self.left < key if self.open_left else self.left <= key) and + (key < self.right if self.open_right else key <= self.right)) + + def __richcmp__(self, other, int op): + if hasattr(other, 'ndim'): + # let numpy (or IntervalIndex) handle vectorization + return NotImplemented + + if _interval_like(other): + self_tuple = (self.left, self.right, self.closed) + other_tuple = (other.left, other.right, other.closed) + return PyObject_RichCompare(self_tuple, other_tuple, op) + + # nb. could just return NotImplemented now, but handling this + # explicitly allows us to opt into the Python 3 behavior, even on + # Python 2. + if op == Py_EQ or op == Py_NE: + return NotImplemented + else: + op_str = {Py_LT: '<', Py_LE: '<=', Py_GT: '>', Py_GE: '>='}[op] + raise TypeError('unorderable types: %s() %s %s()' % + (type(self).__name__, op_str, type(other).__name__)) + + def __reduce__(self): + args = (self.left, self.right, self.closed) + return (type(self), args) + + def __repr__(self): + return ('%s(%r, %r, closed=%r)' % + (type(self).__name__, self.left, self.right, self.closed)) + + def __str__(self): + start_symbol = '[' if self.closed_left else '(' + end_symbol = ']' if self.closed_right else ')' + return '%s%s, %s%s' % (start_symbol, self.left, self.right, end_symbol) + + def __add__(self, y): + if isinstance(y, numbers.Number): + return Interval(self.left + y, self.right + y) + elif isinstance(y, Interval) and isinstance(self, numbers.Number): + return Interval(y.left + self, y.right + self) + else: + raise NotImplemented + + def __sub__(self, y): + if isinstance(y, numbers.Number): + return Interval(self.left - y, self.right - y) + else: + raise NotImplemented + + def __mul__(self, y): + if isinstance(y, numbers.Number): + return Interval(self.left * y, self.right * y) + elif isinstance(y, Interval) and isinstance(self, numbers.Number): + return Interval(y.left * self, y.right * self) + else: + return NotImplemented + + def __div__(self, y): + if isinstance(y, numbers.Number): + return Interval(self.left / y, self.right / y) + else: + return NotImplemented + + def __truediv__(self, y): + if isinstance(y, numbers.Number): + return Interval(self.left / y, self.right / y) + else: + return NotImplemented + + def __floordiv__(self, y): + if isinstance(y, numbers.Number): + return Interval(self.left // y, self.right // y) + else: + return NotImplemented + + +@cython.wraparound(False) +@cython.boundscheck(False) +cpdef interval_bounds_to_intervals(np.ndarray left, np.ndarray right, + str closed): + result = np.empty(len(left), dtype=object) + nulls = pd.isnull(left) | pd.isnull(right) + result[nulls] = np.nan + for i in np.flatnonzero(~nulls): + result[i] = Interval(left[i], right[i], closed) + return result + + +@cython.wraparound(False) +@cython.boundscheck(False) +cpdef intervals_to_interval_bounds(np.ndarray intervals): + left = np.empty(len(intervals), dtype=object) + right = np.empty(len(intervals), dtype=object) + cdef str closed = None + for i in range(len(intervals)): + interval = intervals[i] + left[i] = interval.left + right[i] = interval.right + if closed is None: + closed = interval.closed + elif closed != interval.closed: + raise ValueError('intervals must all be closed on the same side') + return left, right, closed + diff --git a/pandas/src/intervaltree.pyx b/pandas/src/intervaltree.pyx new file mode 100644 index 00000000000000..55782c930d4f8c --- /dev/null +++ b/pandas/src/intervaltree.pyx @@ -0,0 +1,1444 @@ + +# DO NOT EDIT THIS FILE: This file was autogenerated from +# generate_intervaltree.py, so please edit that file and then run +# `python2 generate_intervaltree.py` to re-generate this file. + + +from numpy cimport int64_t, float64_t +from numpy cimport ndarray, PyArray_ArgSort, NPY_QUICKSORT, PyArray_Take +import numpy as np + +cimport cython +cimport numpy as cnp +cnp.import_array() + +from hashtable cimport Int64Vector, Int64VectorData + + +ctypedef fused scalar64_t: + float64_t + int64_t + + +NODE_CLASSES = {} + + +cdef class IntervalTree(IntervalMixin): + """A centered interval tree + + Based off the algorithm described on Wikipedia: + http://en.wikipedia.org/wiki/Interval_tree + """ + cdef: + readonly object left, right, root + readonly str closed + object _left_sorter, _right_sorter + + def __init__(self, left, right, closed='right', leaf_size=100): + """ + Parameters + ---------- + left, right : np.ndarray[ndim=1] + Left and right bounds for each interval. Assumed to contain no + NaNs. + closed : {'left', 'right', 'both', 'neither'}, optional + Whether the intervals are closed on the left-side, right-side, both + or neither. Defaults to 'right'. + leaf_size : int, optional + Parameter that controls when the tree switches from creating nodes + to brute-force search. Tune this parameter to optimize query + performance. + """ + if closed not in ['left', 'right', 'both', 'neither']: + raise ValueError("invalid option for 'closed': %s" % closed) + + left = np.asarray(left) + right = np.asarray(right) + dtype = np.result_type(left, right) + self.left = np.asarray(left, dtype=dtype) + self.right = np.asarray(right, dtype=dtype) + + indices = np.arange(len(left), dtype='int64') + + self.closed = closed + + node_cls = NODE_CLASSES[str(dtype), closed] + self.root = node_cls(self.left, self.right, indices, leaf_size) + + @property + def left_sorter(self): + """How to sort the left labels; this is used for binary search + """ + if self._left_sorter is None: + self._left_sorter = np.argsort(self.left) + return self._left_sorter + + @property + def right_sorter(self): + """How to sort the right labels + """ + if self._right_sorter is None: + self._right_sorter = np.argsort(self.right) + return self._right_sorter + + def get_loc(self, scalar64_t key): + """Return all positions corresponding to intervals that overlap with + the given scalar key + """ + result = Int64Vector() + self.root.query(result, key) + if not result.data.n: + raise KeyError(key) + return result.to_array() + + def _get_partial_overlap(self, key_left, key_right, side): + """Return all positions corresponding to intervals with the given side + falling between the left and right bounds of an interval query + """ + if side == 'left': + values = self.left + sorter = self.left_sorter + else: + values = self.right + sorter = self.right_sorter + key = [key_left, key_right] + i, j = values.searchsorted(key, sorter=sorter) + return sorter[i:j] + + def get_loc_interval(self, key_left, key_right): + """Lookup the intervals enclosed in the given interval bounds + + The given interval is presumed to have closed bounds. + """ + import pandas as pd + left_overlap = self._get_partial_overlap(key_left, key_right, 'left') + right_overlap = self._get_partial_overlap(key_left, key_right, 'right') + enclosing = self.get_loc(0.5 * (key_left + key_right)) + combined = np.concatenate([left_overlap, right_overlap, enclosing]) + uniques = pd.unique(combined) + return uniques + + def get_indexer(self, scalar64_t[:] target): + """Return the positions corresponding to unique intervals that overlap + with the given array of scalar targets. + """ + # TODO: write get_indexer_intervals + cdef: + int64_t old_len, i + Int64Vector result + + result = Int64Vector() + old_len = 0 + for i in range(len(target)): + self.root.query(result, target[i]) + if result.data.n == old_len: + result.append(-1) + elif result.data.n > old_len + 1: + raise KeyError( + 'indexer does not intersect a unique set of intervals') + old_len = result.data.n + return result.to_array() + + def get_indexer_non_unique(self, scalar64_t[:] target): + """Return the positions corresponding to intervals that overlap with + the given array of scalar targets. Non-unique positions are repeated. + """ + cdef: + int64_t old_len, i + Int64Vector result, missing + + result = Int64Vector() + missing = Int64Vector() + old_len = 0 + for i in range(len(target)): + self.root.query(result, target[i]) + if result.data.n == old_len: + result.append(-1) + missing.append(i) + old_len = result.data.n + return result.to_array(), missing.to_array() + + def __repr__(self): + return ('' + % self.root.n_elements) + + +cdef take(ndarray source, ndarray indices): + """Take the given positions from a 1D ndarray + """ + return PyArray_Take(source, indices, 0) + + +cdef sort_values_and_indices(all_values, all_indices, subset): + indices = take(all_indices, subset) + values = take(all_values, subset) + sorter = PyArray_ArgSort(values, 0, NPY_QUICKSORT) + sorted_values = take(values, sorter) + sorted_indices = take(indices, sorter) + return sorted_values, sorted_indices + + +cdef class Float64ClosedLeftIntervalNode: + """Non-terminal node for an IntervalTree + + Categorizes intervals by those that fall to the left, those that fall to + the right, and those that overlap with the pivot. + """ + cdef: + Float64ClosedLeftIntervalNode left_node, right_node + float64_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices + float64_t min_left, max_right + readonly float64_t pivot + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node + + def __init__(self, + ndarray[float64_t, ndim=1] left, + ndarray[float64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + int64_t leaf_size): + + self.n_elements = len(left) + self.leaf_size = leaf_size + if left.size > 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 + + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) + + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) + + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = len(self.center_left_indices) + + @cython.wraparound(False) + @cython.boundscheck(False) + cdef classify_intervals(self, float64_t[:] left, float64_t[:] right): + """Classify the given intervals based upon whether they fall to the + left, right, or overlap with this node's pivot. + """ + cdef: + Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(self.n_elements): + if right[i] <= self.pivot: + left_ind.append(i) + elif self.pivot < left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[float64_t, ndim=1] left, + ndarray[float64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + """ + left = take(left, subset) + right = take(right, subset) + indices = take(indices, subset) + return Float64ClosedLeftIntervalNode( + left, right, indices, self.leaf_size) + + @cython.wraparound(False) + @cython.boundscheck(False) + @cython.initializedcheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + """Recursively query this node and its sub-nodes for intervals that + overlap with the query point. + """ + cdef: + int64_t[:] indices + float64_t[:] values + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] <= point < self.right[i]: + result.append(self.indices[i]) + else: + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. + if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] <= point: + break + result.append(indices[i]) + if point < self.left_node.max_right: + self.left_node.query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point < values[i]: + break + result.append(indices[i]) + if self.right_node.min_left <= point: + self.right_node.query(result, point) + else: + result.extend(self.center_left_indices) + + def __repr__(self): + if self.is_leaf_node: + return ('' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) + + def counts(self): + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['float64', 'left'] = Float64ClosedLeftIntervalNode + + +cdef class Float64ClosedRightIntervalNode: + """Non-terminal node for an IntervalTree + + Categorizes intervals by those that fall to the left, those that fall to + the right, and those that overlap with the pivot. + """ + cdef: + Float64ClosedRightIntervalNode left_node, right_node + float64_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices + float64_t min_left, max_right + readonly float64_t pivot + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node + + def __init__(self, + ndarray[float64_t, ndim=1] left, + ndarray[float64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + int64_t leaf_size): + + self.n_elements = len(left) + self.leaf_size = leaf_size + if left.size > 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 + + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) + + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) + + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = len(self.center_left_indices) + + @cython.wraparound(False) + @cython.boundscheck(False) + cdef classify_intervals(self, float64_t[:] left, float64_t[:] right): + """Classify the given intervals based upon whether they fall to the + left, right, or overlap with this node's pivot. + """ + cdef: + Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(self.n_elements): + if right[i] < self.pivot: + left_ind.append(i) + elif self.pivot <= left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[float64_t, ndim=1] left, + ndarray[float64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + """ + left = take(left, subset) + right = take(right, subset) + indices = take(indices, subset) + return Float64ClosedRightIntervalNode( + left, right, indices, self.leaf_size) + + @cython.wraparound(False) + @cython.boundscheck(False) + @cython.initializedcheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + """Recursively query this node and its sub-nodes for intervals that + overlap with the query point. + """ + cdef: + int64_t[:] indices + float64_t[:] values + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] < point <= self.right[i]: + result.append(self.indices[i]) + else: + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. + if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] < point: + break + result.append(indices[i]) + if point <= self.left_node.max_right: + self.left_node.query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point <= values[i]: + break + result.append(indices[i]) + if self.right_node.min_left < point: + self.right_node.query(result, point) + else: + result.extend(self.center_left_indices) + + def __repr__(self): + if self.is_leaf_node: + return ('' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) + + def counts(self): + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['float64', 'right'] = Float64ClosedRightIntervalNode + + +cdef class Float64ClosedBothIntervalNode: + """Non-terminal node for an IntervalTree + + Categorizes intervals by those that fall to the left, those that fall to + the right, and those that overlap with the pivot. + """ + cdef: + Float64ClosedBothIntervalNode left_node, right_node + float64_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices + float64_t min_left, max_right + readonly float64_t pivot + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node + + def __init__(self, + ndarray[float64_t, ndim=1] left, + ndarray[float64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + int64_t leaf_size): + + self.n_elements = len(left) + self.leaf_size = leaf_size + if left.size > 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 + + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) + + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) + + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = len(self.center_left_indices) + + @cython.wraparound(False) + @cython.boundscheck(False) + cdef classify_intervals(self, float64_t[:] left, float64_t[:] right): + """Classify the given intervals based upon whether they fall to the + left, right, or overlap with this node's pivot. + """ + cdef: + Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(self.n_elements): + if right[i] < self.pivot: + left_ind.append(i) + elif self.pivot < left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[float64_t, ndim=1] left, + ndarray[float64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + """ + left = take(left, subset) + right = take(right, subset) + indices = take(indices, subset) + return Float64ClosedBothIntervalNode( + left, right, indices, self.leaf_size) + + @cython.wraparound(False) + @cython.boundscheck(False) + @cython.initializedcheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + """Recursively query this node and its sub-nodes for intervals that + overlap with the query point. + """ + cdef: + int64_t[:] indices + float64_t[:] values + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] <= point <= self.right[i]: + result.append(self.indices[i]) + else: + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. + if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] <= point: + break + result.append(indices[i]) + if point <= self.left_node.max_right: + self.left_node.query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point <= values[i]: + break + result.append(indices[i]) + if self.right_node.min_left <= point: + self.right_node.query(result, point) + else: + result.extend(self.center_left_indices) + + def __repr__(self): + if self.is_leaf_node: + return ('' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) + + def counts(self): + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['float64', 'both'] = Float64ClosedBothIntervalNode + + +cdef class Float64ClosedNeitherIntervalNode: + """Non-terminal node for an IntervalTree + + Categorizes intervals by those that fall to the left, those that fall to + the right, and those that overlap with the pivot. + """ + cdef: + Float64ClosedNeitherIntervalNode left_node, right_node + float64_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices + float64_t min_left, max_right + readonly float64_t pivot + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node + + def __init__(self, + ndarray[float64_t, ndim=1] left, + ndarray[float64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + int64_t leaf_size): + + self.n_elements = len(left) + self.leaf_size = leaf_size + if left.size > 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 + + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) + + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) + + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = len(self.center_left_indices) + + @cython.wraparound(False) + @cython.boundscheck(False) + cdef classify_intervals(self, float64_t[:] left, float64_t[:] right): + """Classify the given intervals based upon whether they fall to the + left, right, or overlap with this node's pivot. + """ + cdef: + Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(self.n_elements): + if right[i] <= self.pivot: + left_ind.append(i) + elif self.pivot <= left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[float64_t, ndim=1] left, + ndarray[float64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + """ + left = take(left, subset) + right = take(right, subset) + indices = take(indices, subset) + return Float64ClosedNeitherIntervalNode( + left, right, indices, self.leaf_size) + + @cython.wraparound(False) + @cython.boundscheck(False) + @cython.initializedcheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + """Recursively query this node and its sub-nodes for intervals that + overlap with the query point. + """ + cdef: + int64_t[:] indices + float64_t[:] values + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] < point < self.right[i]: + result.append(self.indices[i]) + else: + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. + if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] < point: + break + result.append(indices[i]) + if point < self.left_node.max_right: + self.left_node.query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point < values[i]: + break + result.append(indices[i]) + if self.right_node.min_left < point: + self.right_node.query(result, point) + else: + result.extend(self.center_left_indices) + + def __repr__(self): + if self.is_leaf_node: + return ('' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) + + def counts(self): + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['float64', 'neither'] = Float64ClosedNeitherIntervalNode + + +cdef class Int64ClosedLeftIntervalNode: + """Non-terminal node for an IntervalTree + + Categorizes intervals by those that fall to the left, those that fall to + the right, and those that overlap with the pivot. + """ + cdef: + Int64ClosedLeftIntervalNode left_node, right_node + int64_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices + int64_t min_left, max_right + readonly int64_t pivot + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node + + def __init__(self, + ndarray[int64_t, ndim=1] left, + ndarray[int64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + int64_t leaf_size): + + self.n_elements = len(left) + self.leaf_size = leaf_size + if left.size > 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 + + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) + + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) + + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = len(self.center_left_indices) + + @cython.wraparound(False) + @cython.boundscheck(False) + cdef classify_intervals(self, int64_t[:] left, int64_t[:] right): + """Classify the given intervals based upon whether they fall to the + left, right, or overlap with this node's pivot. + """ + cdef: + Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(self.n_elements): + if right[i] <= self.pivot: + left_ind.append(i) + elif self.pivot < left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[int64_t, ndim=1] left, + ndarray[int64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + """ + left = take(left, subset) + right = take(right, subset) + indices = take(indices, subset) + return Int64ClosedLeftIntervalNode( + left, right, indices, self.leaf_size) + + @cython.wraparound(False) + @cython.boundscheck(False) + @cython.initializedcheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + """Recursively query this node and its sub-nodes for intervals that + overlap with the query point. + """ + cdef: + int64_t[:] indices + int64_t[:] values + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] <= point < self.right[i]: + result.append(self.indices[i]) + else: + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. + if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] <= point: + break + result.append(indices[i]) + if point < self.left_node.max_right: + self.left_node.query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point < values[i]: + break + result.append(indices[i]) + if self.right_node.min_left <= point: + self.right_node.query(result, point) + else: + result.extend(self.center_left_indices) + + def __repr__(self): + if self.is_leaf_node: + return ('' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) + + def counts(self): + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['int64', 'left'] = Int64ClosedLeftIntervalNode + + +cdef class Int64ClosedRightIntervalNode: + """Non-terminal node for an IntervalTree + + Categorizes intervals by those that fall to the left, those that fall to + the right, and those that overlap with the pivot. + """ + cdef: + Int64ClosedRightIntervalNode left_node, right_node + int64_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices + int64_t min_left, max_right + readonly int64_t pivot + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node + + def __init__(self, + ndarray[int64_t, ndim=1] left, + ndarray[int64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + int64_t leaf_size): + + self.n_elements = len(left) + self.leaf_size = leaf_size + if left.size > 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 + + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) + + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) + + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = len(self.center_left_indices) + + @cython.wraparound(False) + @cython.boundscheck(False) + cdef classify_intervals(self, int64_t[:] left, int64_t[:] right): + """Classify the given intervals based upon whether they fall to the + left, right, or overlap with this node's pivot. + """ + cdef: + Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(self.n_elements): + if right[i] < self.pivot: + left_ind.append(i) + elif self.pivot <= left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[int64_t, ndim=1] left, + ndarray[int64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + """ + left = take(left, subset) + right = take(right, subset) + indices = take(indices, subset) + return Int64ClosedRightIntervalNode( + left, right, indices, self.leaf_size) + + @cython.wraparound(False) + @cython.boundscheck(False) + @cython.initializedcheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + """Recursively query this node and its sub-nodes for intervals that + overlap with the query point. + """ + cdef: + int64_t[:] indices + int64_t[:] values + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] < point <= self.right[i]: + result.append(self.indices[i]) + else: + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. + if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] < point: + break + result.append(indices[i]) + if point <= self.left_node.max_right: + self.left_node.query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point <= values[i]: + break + result.append(indices[i]) + if self.right_node.min_left < point: + self.right_node.query(result, point) + else: + result.extend(self.center_left_indices) + + def __repr__(self): + if self.is_leaf_node: + return ('' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) + + def counts(self): + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['int64', 'right'] = Int64ClosedRightIntervalNode + + +cdef class Int64ClosedBothIntervalNode: + """Non-terminal node for an IntervalTree + + Categorizes intervals by those that fall to the left, those that fall to + the right, and those that overlap with the pivot. + """ + cdef: + Int64ClosedBothIntervalNode left_node, right_node + int64_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices + int64_t min_left, max_right + readonly int64_t pivot + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node + + def __init__(self, + ndarray[int64_t, ndim=1] left, + ndarray[int64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + int64_t leaf_size): + + self.n_elements = len(left) + self.leaf_size = leaf_size + if left.size > 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 + + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) + + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) + + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = len(self.center_left_indices) + + @cython.wraparound(False) + @cython.boundscheck(False) + cdef classify_intervals(self, int64_t[:] left, int64_t[:] right): + """Classify the given intervals based upon whether they fall to the + left, right, or overlap with this node's pivot. + """ + cdef: + Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(self.n_elements): + if right[i] < self.pivot: + left_ind.append(i) + elif self.pivot < left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[int64_t, ndim=1] left, + ndarray[int64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + """ + left = take(left, subset) + right = take(right, subset) + indices = take(indices, subset) + return Int64ClosedBothIntervalNode( + left, right, indices, self.leaf_size) + + @cython.wraparound(False) + @cython.boundscheck(False) + @cython.initializedcheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + """Recursively query this node and its sub-nodes for intervals that + overlap with the query point. + """ + cdef: + int64_t[:] indices + int64_t[:] values + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] <= point <= self.right[i]: + result.append(self.indices[i]) + else: + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. + if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] <= point: + break + result.append(indices[i]) + if point <= self.left_node.max_right: + self.left_node.query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point <= values[i]: + break + result.append(indices[i]) + if self.right_node.min_left <= point: + self.right_node.query(result, point) + else: + result.extend(self.center_left_indices) + + def __repr__(self): + if self.is_leaf_node: + return ('' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) + + def counts(self): + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['int64', 'both'] = Int64ClosedBothIntervalNode + + +cdef class Int64ClosedNeitherIntervalNode: + """Non-terminal node for an IntervalTree + + Categorizes intervals by those that fall to the left, those that fall to + the right, and those that overlap with the pivot. + """ + cdef: + Int64ClosedNeitherIntervalNode left_node, right_node + int64_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices + int64_t min_left, max_right + readonly int64_t pivot + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node + + def __init__(self, + ndarray[int64_t, ndim=1] left, + ndarray[int64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + int64_t leaf_size): + + self.n_elements = len(left) + self.leaf_size = leaf_size + if left.size > 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 + + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) + + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) + + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = len(self.center_left_indices) + + @cython.wraparound(False) + @cython.boundscheck(False) + cdef classify_intervals(self, int64_t[:] left, int64_t[:] right): + """Classify the given intervals based upon whether they fall to the + left, right, or overlap with this node's pivot. + """ + cdef: + Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(self.n_elements): + if right[i] <= self.pivot: + left_ind.append(i) + elif self.pivot <= left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[int64_t, ndim=1] left, + ndarray[int64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + """ + left = take(left, subset) + right = take(right, subset) + indices = take(indices, subset) + return Int64ClosedNeitherIntervalNode( + left, right, indices, self.leaf_size) + + @cython.wraparound(False) + @cython.boundscheck(False) + @cython.initializedcheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + """Recursively query this node and its sub-nodes for intervals that + overlap with the query point. + """ + cdef: + int64_t[:] indices + int64_t[:] values + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] < point < self.right[i]: + result.append(self.indices[i]) + else: + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. + if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] < point: + break + result.append(indices[i]) + if point < self.left_node.max_right: + self.left_node.query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point < values[i]: + break + result.append(indices[i]) + if self.right_node.min_left < point: + self.right_node.query(result, point) + else: + result.extend(self.center_left_indices) + + def __repr__(self): + if self.is_leaf_node: + return ('' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) + + def counts(self): + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['int64', 'neither'] = Int64ClosedNeitherIntervalNode + + diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index a355dca3029c7a..ba9bd8b1705256 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -4058,6 +4058,49 @@ def test_transform_doesnt_clobber_ints(self): expected = gb2.transform('mean') tm.assert_frame_equal(result, expected) + def test_groupby_categorical_two_columns(self): + + # https://github.com/pydata/pandas/issues/8138 + d = {'cat': pd.Categorical(["a","b","a","b"], categories=["a", "b", "c"], ordered=True), + 'ints': [1, 1, 2, 2],'val': [10, 20, 30, 40]} + test = pd.DataFrame(d) + + # Grouping on a single column + groups_single_key = test.groupby("cat") + res = groups_single_key.agg('mean') + exp = DataFrame({"ints":[1.5,1.5,np.nan], "val":[20,30,np.nan]}, + index=pd.CategoricalIndex(["a", "b", "c"], name="cat")) + tm.assert_frame_equal(res, exp) + + # Grouping on two columns + groups_double_key = test.groupby(["cat","ints"]) + res = groups_double_key.agg('mean') + exp = DataFrame({"val":[10,30,20,40,np.nan,np.nan], + "cat": ["a","a","b","b","c","c"], + "ints": [1,2,1,2,1,2]}).set_index(["cat","ints"]) + tm.assert_frame_equal(res, exp) + + # GH 10132 + for key in [('a', 1), ('b', 2), ('b', 1), ('a', 2)]: + c, i = key + result = groups_double_key.get_group(key) + expected = test[(test.cat == c) & (test.ints == i)] + assert_frame_equal(result, expected) + + d = {'C1': [3, 3, 4, 5], 'C2': [1, 2, 3, 4], 'C3': [10, 100, 200, 34]} + test = pd.DataFrame(d) + values = pd.cut(test['C1'], [1, 2, 3, 6], labels=pd.Categorical(['a', 'b', 'c'])) + values.name = "cat" + groups_double_key = test.groupby([values,'C2']) + + res = groups_double_key.agg('mean') + nan = np.nan + idx = MultiIndex.from_product([['a', 'b', 'c'], [1, 2, 3, 4]], + names=["cat", "C2"]) + exp = DataFrame({"C1":[nan,nan,nan,nan, 3, 3,nan,nan, nan,nan, 4, 5], + "C3":[nan,nan,nan,nan, 10,100,nan,nan, nan,nan,200,34]}, index=idx) + tm.assert_frame_equal(res, exp) + def test_groupby_apply_all_none(self): # Tests to make sure no errors if apply function returns all None # values. Issue 9684. diff --git a/pandas/tests/test_algos.py b/pandas/tests/test_algos.py index 7a3cc3e2c3cd7b..3999e605fe2d23 100644 --- a/pandas/tests/test_algos.py +++ b/pandas/tests/test_algos.py @@ -475,24 +475,24 @@ def test_value_counts(self): arr = np.random.randn(4) factor = cut(arr, 4) - tm.assertIsInstance(factor, Categorical) + # tm.assertIsInstance(factor, n) result = algos.value_counts(factor) - cats = ['(-1.194, -0.535]', '(-0.535, 0.121]', '(0.121, 0.777]', - '(0.777, 1.433]'] - expected_index = CategoricalIndex(cats, cats, ordered=True) - expected = Series([1, 1, 1, 1], index=expected_index) + breaks = [-1.192, -0.535, 0.121, 0.777, 1.433] + expected_index = pd.IntervalIndex.from_breaks(breaks) + expected = Series([1, 1, 1, 1], + index=expected_index) tm.assert_series_equal(result.sort_index(), expected.sort_index()) def test_value_counts_bins(self): s = [1, 2, 3, 4] result = algos.value_counts(s, bins=1) self.assertEqual(result.tolist(), [4]) - self.assertEqual(result.index[0], 0.997) + self.assertEqual(result.index[0], pd.Interval(0.999, 4.0)) result = algos.value_counts(s, bins=2, sort=False) self.assertEqual(result.tolist(), [2, 2]) - self.assertEqual(result.index[0], 0.997) - self.assertEqual(result.index[1], 2.5) + self.assertEqual(result.index.min(), pd.Interval(0.999, 2.5)) + self.assertEqual(result.index.max(), pd.Interval(2.5, 4.0)) def test_value_counts_dtypes(self): result = algos.value_counts([1, 1.]) diff --git a/pandas/tests/test_base.py b/pandas/tests/test_base.py index 68db0d19344b99..02afd12d64f346 100644 --- a/pandas/tests/test_base.py +++ b/pandas/tests/test_base.py @@ -13,7 +13,7 @@ needs_i8_conversion) import pandas.util.testing as tm from pandas import (Series, Index, DatetimeIndex, TimedeltaIndex, PeriodIndex, - Timedelta) + Timedelta, IntervalIndex, Interval) from pandas.compat import StringIO from pandas.compat.numpy import np_array_datetime64_compat from pandas.core.base import PandasDelegate, NoNewAttributesMixin @@ -31,6 +31,7 @@ def test_string_methods_dont_fail(self): unicode(self.container) # noqa def test_tricky_container(self): + import nose if not hasattr(self, 'unicode_container'): pytest.skip('Need unicode_container to test with this') repr(self.unicode_container) @@ -575,10 +576,10 @@ def test_value_counts_bins(self): s1 = Series([1, 1, 2, 3]) res1 = s1.value_counts(bins=1) - exp1 = Series({0.998: 4}) + exp1 = Series({Interval(0.999, 3.0): 4}) tm.assert_series_equal(res1, exp1) res1n = s1.value_counts(bins=1, normalize=True) - exp1n = Series({0.998: 1.0}) + exp1n = Series({Interval(0.999, 3.0): 1.0}) tm.assert_series_equal(res1n, exp1n) if isinstance(s1, Index): @@ -590,17 +591,11 @@ def test_value_counts_bins(self): self.assertEqual(s1.nunique(), 3) res4 = s1.value_counts(bins=4) - exp4 = Series({0.998: 2, - 1.5: 1, - 2.0: 0, - 2.5: 1}, index=[0.998, 2.5, 1.5, 2.0]) + intervals = IntervalIndex.from_breaks([0.999, 1.5, 2.0, 2.5, 3.0]) + exp4 = Series([2, 1, 1], index=intervals.take([0, 3, 1])) tm.assert_series_equal(res4, exp4) res4n = s1.value_counts(bins=4, normalize=True) - exp4n = Series( - {0.998: 0.5, - 1.5: 0.25, - 2.0: 0.0, - 2.5: 0.25}, index=[0.998, 2.5, 1.5, 2.0]) + exp4n = Series([0.5, 0.25, 0.25], index=intervals.take([0, 3, 1])) tm.assert_series_equal(res4n, exp4n) # handle NA's properly diff --git a/pandas/tests/test_categorical.py b/pandas/tests/test_categorical.py index 6c8aeba704c7be..b3c349876c624c 100644 --- a/pandas/tests/test_categorical.py +++ b/pandas/tests/test_categorical.py @@ -17,7 +17,7 @@ import pandas.compat as compat import pandas.util.testing as tm from pandas import (Categorical, Index, Series, DataFrame, PeriodIndex, - Timestamp, CategoricalIndex, isnull) + Timestamp, CategoricalIndex, Interval, isnull) from pandas.compat import range, lrange, u, PY3 from pandas.core.config import option_context @@ -1730,11 +1730,11 @@ def setUp(self): self.factor = Categorical(['a', 'b', 'b', 'a', 'a', 'c', 'c', 'c']) df = DataFrame({'value': np.random.randint(0, 10000, 100)}) - labels = ["{0} - {1}".format(i, i + 499) for i in range(0, 10000, 500)] + labels = [ "{0} - {1}".format(i, i + 499) for i in range(0, 10000, 500) ] + cat_labels = Categorical(labels, labels) df = df.sort_values(by=['value'], ascending=True) - df['value_group'] = pd.cut(df.value, range(0, 10500, 500), right=False, - labels=labels) + df['value_group'] = pd.cut(df.value, range(0, 10500, 500), right=False, labels=cat_labels) self.cat = df def test_dtypes(self): @@ -2160,9 +2160,8 @@ def test_series_functions_no_warnings(self): def test_assignment_to_dataframe(self): # assignment - df = DataFrame({'value': np.array(np.random.randint(0, 10000, 100), - dtype='int32')}) - labels = ["{0} - {1}".format(i, i + 499) for i in range(0, 10000, 500)] + df = DataFrame({'value': np.array(np.random.randint(0, 10000, 100),dtype='int32')}) + labels = Categorical(["{0} - {1}".format(i, i + 499) for i in range(0, 10000, 500)]) df = df.sort_values(by=['value'], ascending=True) s = pd.cut(df.value, range(0, 10500, 500), right=False, labels=labels) @@ -3156,7 +3155,7 @@ def f(x): # GH 9603 df = pd.DataFrame({'a': [1, 0, 0, 0]}) - c = pd.cut(df.a, [0, 1, 2, 3, 4]) + c = pd.cut(df.a, [0, 1, 2, 3, 4], labels=pd.Categorical(list('abcd'))) result = df.groupby(c).apply(len) exp_index = pd.CategoricalIndex(c.values.categories, @@ -3273,7 +3272,7 @@ def test_slicing(self): df = DataFrame({'value': (np.arange(100) + 1).astype('int64')}) df['D'] = pd.cut(df.value, bins=[0, 25, 50, 75, 100]) - expected = Series([11, '(0, 25]'], index=['value', 'D'], name=10) + expected = Series([11, Interval(0, 25)], index=['value','D'], name=10) result = df.iloc[10] tm.assert_series_equal(result, expected) @@ -3283,7 +3282,7 @@ def test_slicing(self): result = df.iloc[10:20] tm.assert_frame_equal(result, expected) - expected = Series([9, '(0, 25]'], index=['value', 'D'], name=8) + expected = Series([9, Interval(0, 25)],index=['value', 'D'], name=8) result = df.loc[8] tm.assert_series_equal(result, expected) diff --git a/pandas/tests/test_interval.py b/pandas/tests/test_interval.py new file mode 100644 index 00000000000000..1b52e2629b38cb --- /dev/null +++ b/pandas/tests/test_interval.py @@ -0,0 +1,591 @@ +from __future__ import division +import numpy as np + +from pandas.core.interval import Interval, IntervalIndex +from pandas.core.index import Index +from pandas.lib import IntervalTree + +import pandas.util.testing as tm +import pandas as pd + + +class TestInterval(tm.TestCase): + def setUp(self): + self.interval = Interval(0, 1) + + def test_properties(self): + self.assertEqual(self.interval.closed, 'right') + self.assertEqual(self.interval.left, 0) + self.assertEqual(self.interval.right, 1) + self.assertEqual(self.interval.mid, 0.5) + + def test_repr(self): + self.assertEqual(repr(self.interval), + "Interval(0, 1, closed='right')") + self.assertEqual(str(self.interval), "(0, 1]") + + interval_left = Interval(0, 1, closed='left') + self.assertEqual(repr(interval_left), + "Interval(0, 1, closed='left')") + self.assertEqual(str(interval_left), "[0, 1)") + + def test_contains(self): + self.assertIn(0.5, self.interval) + self.assertIn(1, self.interval) + self.assertNotIn(0, self.interval) + self.assertRaises(TypeError, lambda: self.interval in self.interval) + + interval = Interval(0, 1, closed='both') + self.assertIn(0, interval) + self.assertIn(1, interval) + + interval = Interval(0, 1, closed='neither') + self.assertNotIn(0, interval) + self.assertIn(0.5, interval) + self.assertNotIn(1, interval) + + def test_equal(self): + self.assertEqual(Interval(0, 1), Interval(0, 1, closed='right')) + self.assertNotEqual(Interval(0, 1), Interval(0, 1, closed='left')) + self.assertNotEqual(Interval(0, 1), 0) + + def test_comparison(self): + with self.assertRaisesRegexp(TypeError, 'unorderable types'): + Interval(0, 1) < 2 + + self.assertTrue(Interval(0, 1) < Interval(1, 2)) + self.assertTrue(Interval(0, 1) < Interval(0, 2)) + self.assertTrue(Interval(0, 1) < Interval(0.5, 1.5)) + self.assertTrue(Interval(0, 1) <= Interval(0, 1)) + self.assertTrue(Interval(0, 1) > Interval(-1, 2)) + self.assertTrue(Interval(0, 1) >= Interval(0, 1)) + + def test_hash(self): + # should not raise + hash(self.interval) + + def test_math_add(self): + expected = Interval(1, 2) + actual = self.interval + 1 + self.assertEqual(expected, actual) + + expected = Interval(1, 2) + actual = 1 + self.interval + self.assertEqual(expected, actual) + + actual = self.interval + actual += 1 + self.assertEqual(expected, actual) + + with self.assertRaises(TypeError): + self.interval + Interval(1, 2) + + def test_math_sub(self): + expected = Interval(-1, 0) + actual = self.interval - 1 + self.assertEqual(expected, actual) + + actual = self.interval + actual -= 1 + self.assertEqual(expected, actual) + + with self.assertRaises(TypeError): + self.interval - Interval(1, 2) + + def test_math_mult(self): + expected = Interval(0, 2) + actual = self.interval * 2 + self.assertEqual(expected, actual) + + expected = Interval(0, 2) + actual = 2 * self.interval + self.assertEqual(expected, actual) + + actual = self.interval + actual *= 2 + self.assertEqual(expected, actual) + + with self.assertRaises(TypeError): + self.interval * Interval(1, 2) + + def test_math_div(self): + expected = Interval(0, 0.5) + actual = self.interval / 2.0 + self.assertEqual(expected, actual) + + actual = self.interval + actual /= 2.0 + self.assertEqual(expected, actual) + + with self.assertRaises(TypeError): + self.interval / Interval(1, 2) + + +class TestIntervalTree(tm.TestCase): + def setUp(self): + self.tree = IntervalTree(np.arange(5), np.arange(5) + 2) + + def test_get_loc(self): + self.assert_numpy_array_equal(self.tree.get_loc(1), [0]) + self.assert_numpy_array_equal(np.sort(self.tree.get_loc(2)), [0, 1]) + with self.assertRaises(KeyError): + self.tree.get_loc(-1) + + def test_get_indexer(self): + self.assert_numpy_array_equal( + self.tree.get_indexer(np.array([1.0, 5.5, 6.5])), [0, 4, -1]) + with self.assertRaises(KeyError): + self.tree.get_indexer(np.array([3.0])) + + def test_get_indexer_non_unique(self): + indexer, missing = self.tree.get_indexer_non_unique( + np.array([1.0, 2.0, 6.5])) + self.assert_numpy_array_equal(indexer[:1], [0]) + self.assert_numpy_array_equal(np.sort(indexer[1:3]), [0, 1]) + self.assert_numpy_array_equal(np.sort(indexer[3:]), [-1]) + self.assert_numpy_array_equal(missing, [2]) + + def test_duplicates(self): + tree = IntervalTree([0, 0, 0], [1, 1, 1]) + self.assert_numpy_array_equal(np.sort(tree.get_loc(0.5)), [0, 1, 2]) + + with self.assertRaises(KeyError): + tree.get_indexer(np.array([0.5])) + + indexer, missing = tree.get_indexer_non_unique(np.array([0.5])) + self.assert_numpy_array_equal(np.sort(indexer), [0, 1, 2]) + self.assert_numpy_array_equal(missing, []) + + def test_get_loc_closed(self): + for closed in ['left', 'right', 'both', 'neither']: + tree = IntervalTree([0], [1], closed=closed) + for p, errors in [(0, tree.open_left), + (1, tree.open_right)]: + if errors: + with self.assertRaises(KeyError): + tree.get_loc(p) + else: + self.assert_numpy_array_equal(tree.get_loc(p), + np.array([0])) + + def test_get_indexer_closed(self): + x = np.arange(1000) + found = x + not_found = -np.ones(1000) + for leaf_size in [1, 10, 100, 10000]: + for closed in ['left', 'right', 'both', 'neither']: + tree = IntervalTree(x, x + 0.5, closed=closed, + leaf_size=leaf_size) + self.assert_numpy_array_equal(found, tree.get_indexer(x + 0.25)) + + expected = found if tree.closed_left else not_found + self.assert_numpy_array_equal(expected, tree.get_indexer(x + 0.0)) + + expected = found if tree.closed_right else not_found + self.assert_numpy_array_equal(expected, tree.get_indexer(x + 0.5)) + + +class TestIntervalIndex(tm.TestCase): + def setUp(self): + self.index = IntervalIndex([0, 1], [1, 2]) + + def test_constructors(self): + expected = self.index + actual = IntervalIndex.from_breaks(np.arange(3), closed='right') + self.assertTrue(expected.equals(actual)) + + alternate = IntervalIndex.from_breaks(np.arange(3), closed='left') + self.assertFalse(expected.equals(alternate)) + + actual = IntervalIndex.from_intervals([Interval(0, 1), Interval(1, 2)]) + self.assertTrue(expected.equals(actual)) + + self.assertRaises(ValueError, IntervalIndex, [0], [1], closed='invalid') + + # TODO: fix all these commented out tests (here and below) + + intervals = [Interval(0, 1), Interval(1, 2, closed='left')] + with self.assertRaises(ValueError): + IntervalIndex.from_intervals(intervals) + + with self.assertRaises(ValueError): + IntervalIndex([0, 10], [3, 5]) + + actual = Index([Interval(0, 1), Interval(1, 2)]) + self.assertIsInstance(actual, IntervalIndex) + self.assertTrue(expected.equals(actual)) + + actual = Index(expected) + self.assertIsInstance(actual, IntervalIndex) + self.assertTrue(expected.equals(actual)) + + # no point in nesting periods in an IntervalIndex + # self.assertRaises(ValueError, IntervalIndex.from_breaks, + # pd.period_range('2000-01-01', periods=3)) + + def test_properties(self): + self.assertEqual(len(self.index), 2) + self.assertEqual(self.index.size, 2) + + self.assert_numpy_array_equal(self.index.left, [0, 1]) + self.assertIsInstance(self.index.left, Index) + + self.assert_numpy_array_equal(self.index.right, [1, 2]) + self.assertIsInstance(self.index.right, Index) + + self.assert_numpy_array_equal(self.index.mid, [0.5, 1.5]) + self.assertIsInstance(self.index.mid, Index) + + self.assertEqual(self.index.closed, 'right') + + expected = np.array([Interval(0, 1), Interval(1, 2)], dtype=object) + self.assert_numpy_array_equal(np.asarray(self.index), expected) + self.assert_numpy_array_equal(self.index.values, expected) + + def test_copy(self): + actual = self.index.copy() + self.assertTrue(actual.equals(self.index)) + + actual = self.index.copy(deep=True) + self.assertTrue(actual.equals(self.index)) + self.assertIsNot(actual.left, self.index.left) + + def test_delete(self): + expected = IntervalIndex.from_breaks([1, 2]) + actual = self.index.delete(0) + self.assertTrue(expected.equals(actual)) + + def test_insert(self): + expected = IntervalIndex.from_breaks(range(4)) + actual = self.index.insert(2, Interval(2, 3)) + self.assertTrue(expected.equals(actual)) + + self.assertRaises(ValueError, self.index.insert, 0, 1) + self.assertRaises(ValueError, self.index.insert, 0, + Interval(2, 3, closed='left')) + + def test_take(self): + actual = self.index.take([0, 1]) + self.assertTrue(self.index.equals(actual)) + + expected = IntervalIndex([0, 0, 1], [1, 1, 2]) + actual = self.index.take([0, 0, 1]) + self.assertTrue(expected.equals(actual)) + + def test_monotonic_and_unique(self): + self.assertTrue(self.index.is_monotonic) + self.assertTrue(self.index.is_unique) + + idx = IntervalIndex.from_tuples([(0, 1), (0.5, 1.5)]) + self.assertTrue(idx.is_monotonic) + self.assertTrue(idx.is_unique) + + idx = IntervalIndex.from_tuples([(0, 1), (2, 3), (1, 2)]) + self.assertFalse(idx.is_monotonic) + self.assertTrue(idx.is_unique) + + idx = IntervalIndex.from_tuples([(0, 2), (0, 2)]) + self.assertFalse(idx.is_unique) + self.assertTrue(idx.is_monotonic) + + def test_repr(self): + expected = ("IntervalIndex(left=[0, 1],\n right=[1, 2]," + "\n closed='right')") + IntervalIndex((0, 1), (1, 2), closed='right') + self.assertEqual(repr(self.index), expected) + + def test_get_loc_value(self): + self.assertRaises(KeyError, self.index.get_loc, 0) + self.assertEqual(self.index.get_loc(0.5), 0) + self.assertEqual(self.index.get_loc(1), 0) + self.assertEqual(self.index.get_loc(1.5), 1) + self.assertEqual(self.index.get_loc(2), 1) + self.assertRaises(KeyError, self.index.get_loc, -1) + self.assertRaises(KeyError, self.index.get_loc, 3) + + idx = IntervalIndex.from_tuples([(0, 2), (1, 3)]) + self.assertEqual(idx.get_loc(0.5), 0) + self.assertEqual(idx.get_loc(1), 0) + self.assert_numpy_array_equal(idx.get_loc(1.5), [0, 1]) + self.assert_numpy_array_equal(np.sort(idx.get_loc(2)), [0, 1]) + self.assertEqual(idx.get_loc(3), 1) + self.assertRaises(KeyError, idx.get_loc, 3.5) + + idx = IntervalIndex([0, 2], [1, 3]) + self.assertRaises(KeyError, idx.get_loc, 1.5) + + def slice_locs_cases(self, breaks): + # TODO: same tests for more index types + index = IntervalIndex.from_breaks([0, 1, 2], closed='right') + self.assertEqual(index.slice_locs(), (0, 2)) + self.assertEqual(index.slice_locs(0, 1), (0, 1)) + self.assertEqual(index.slice_locs(1, 1), (0, 1)) + self.assertEqual(index.slice_locs(0, 2), (0, 2)) + self.assertEqual(index.slice_locs(0.5, 1.5), (0, 2)) + self.assertEqual(index.slice_locs(0, 0.5), (0, 1)) + self.assertEqual(index.slice_locs(start=1), (0, 2)) + self.assertEqual(index.slice_locs(start=1.2), (1, 2)) + self.assertEqual(index.slice_locs(end=1), (0, 1)) + self.assertEqual(index.slice_locs(end=1.1), (0, 2)) + self.assertEqual(index.slice_locs(end=1.0), (0, 1)) + self.assertEqual(*index.slice_locs(-1, -1)) + + index = IntervalIndex.from_breaks([0, 1, 2], closed='neither') + self.assertEqual(index.slice_locs(0, 1), (0, 1)) + self.assertEqual(index.slice_locs(0, 2), (0, 2)) + self.assertEqual(index.slice_locs(0.5, 1.5), (0, 2)) + self.assertEqual(index.slice_locs(1, 1), (1, 1)) + self.assertEqual(index.slice_locs(1, 2), (1, 2)) + + index = IntervalIndex.from_breaks([0, 1, 2], closed='both') + self.assertEqual(index.slice_locs(1, 1), (0, 2)) + self.assertEqual(index.slice_locs(1, 2), (0, 2)) + + def test_slice_locs_int64(self): + self.slice_locs_cases([0, 1, 2]) + + def test_slice_locs_float64(self): + self.slice_locs_cases([0.0, 1.0, 2.0]) + + def slice_locs_decreasing_cases(self, tuples): + index = IntervalIndex.from_tuples(tuples) + self.assertEqual(index.slice_locs(1.5, 0.5), (1, 3)) + self.assertEqual(index.slice_locs(2, 0), (1, 3)) + self.assertEqual(index.slice_locs(2, 1), (1, 3)) + self.assertEqual(index.slice_locs(3, 1.1), (0, 3)) + self.assertEqual(index.slice_locs(3, 3), (0, 2)) + self.assertEqual(index.slice_locs(3.5, 3.3), (0, 1)) + self.assertEqual(index.slice_locs(1, -3), (2, 3)) + self.assertEqual(*index.slice_locs(-1, -1)) + + def test_slice_locs_decreasing_int64(self): + self.slice_locs_cases([(2, 4), (1, 3), (0, 2)]) + + def test_slice_locs_decreasing_float64(self): + self.slice_locs_cases([(2., 4.), (1., 3.), (0., 2.)]) + + def test_slice_locs_fails(self): + index = IntervalIndex.from_tuples([(1, 2), (0, 1), (2, 3)]) + with self.assertRaises(KeyError): + index.slice_locs(1, 2) + + def test_get_loc_interval(self): + self.assertEqual(self.index.get_loc(Interval(0, 1)), 0) + self.assertEqual(self.index.get_loc(Interval(0, 0.5)), 0) + self.assertEqual(self.index.get_loc(Interval(0, 1, 'left')), 0) + self.assertRaises(KeyError, self.index.get_loc, Interval(2, 3)) + self.assertRaises(KeyError, self.index.get_loc, Interval(-1, 0, 'left')) + + def test_get_indexer(self): + actual = self.index.get_indexer([-1, 0, 0.5, 1, 1.5, 2, 3]) + expected = [-1, -1, 0, 0, 1, 1, -1] + self.assert_numpy_array_equal(actual, expected) + + actual = self.index.get_indexer(self.index) + expected = [0, 1] + self.assert_numpy_array_equal(actual, expected) + + index = IntervalIndex.from_breaks([0, 1, 2], closed='left') + actual = index.get_indexer([-1, 0, 0.5, 1, 1.5, 2, 3]) + expected = [-1, 0, 0, 1, 1, -1, -1] + self.assert_numpy_array_equal(actual, expected) + + actual = self.index.get_indexer(index[:1]) + expected = [0] + self.assert_numpy_array_equal(actual, expected) + + self.assertRaises(ValueError, self.index.get_indexer, index) + + def test_get_indexer_subintervals(self): + # return indexers for wholly contained subintervals + target = IntervalIndex.from_breaks(np.linspace(0, 2, 5)) + actual = self.index.get_indexer(target) + expected = [0, 0, 1, 1] + self.assert_numpy_array_equal(actual, expected) + + target = IntervalIndex.from_breaks([0, 0.67, 1.33, 2]) + self.assertRaises(ValueError, self.index.get_indexer, target) + + actual = self.index.get_indexer(target[[0, -1]]) + expected = [0, 1] + self.assert_numpy_array_equal(actual, expected) + + target = IntervalIndex.from_breaks([0, 0.33, 0.67, 1], closed='left') + actual = self.index.get_indexer(target) + expected = [0, 0, 0] + self.assert_numpy_array_equal(actual, expected) + + def test_contains(self): + self.assertNotIn(0, self.index) + self.assertIn(0.5, self.index) + self.assertIn(2, self.index) + + self.assertIn(Interval(0, 1), self.index) + self.assertIn(Interval(0, 2), self.index) + self.assertIn(Interval(0, 0.5), self.index) + self.assertNotIn(Interval(3, 5), self.index) + self.assertNotIn(Interval(-1, 0, closed='left'), self.index) + + def test_non_contiguous(self): + index = IntervalIndex.from_tuples([(0, 1), (2, 3)]) + target = [0.5, 1.5, 2.5] + actual = index.get_indexer(target) + expected = [0, -1, 1] + self.assert_numpy_array_equal(actual, expected) + + self.assertNotIn(1.5, index) + + def test_union(self): + other = IntervalIndex([2], [3]) + expected = IntervalIndex(range(3), range(1, 4)) + actual = self.index.union(other) + self.assertTrue(expected.equals(actual)) + + actual = other.union(self.index) + self.assertTrue(expected.equals(actual)) + + self.assert_numpy_array_equal(self.index.union(self.index), self.index) + self.assert_numpy_array_equal(self.index.union(self.index[:1]), + self.index) + + def test_intersection(self): + other = IntervalIndex.from_breaks([1, 2, 3]) + expected = IntervalIndex.from_breaks([1, 2]) + actual = self.index.intersection(other) + self.assertTrue(expected.equals(actual)) + + self.assert_numpy_array_equal(self.index.intersection(self.index), + self.index) + + def test_difference(self): + self.assert_numpy_array_equal(self.index.difference(self.index[:1]), + self.index[1:]) + + def test_sym_diff(self): + self.assert_numpy_array_equal(self.index[:1].sym_diff(self.index[1:]), + self.index) + + def test_set_operation_errors(self): + self.assertRaises(ValueError, self.index.union, self.index.left) + + other = IntervalIndex.from_breaks([0, 1, 2], closed='neither') + self.assertRaises(ValueError, self.index.union, other) + + def test_isin(self): + actual = self.index.isin(self.index) + self.assert_numpy_array_equal([True, True], actual) + + actual = self.index.isin(self.index[:1]) + self.assert_numpy_array_equal([True, False], actual) + + def test_comparison(self): + actual = Interval(0, 1) < self.index + expected = [False, True] + self.assert_numpy_array_equal(actual, expected) + + actual = Interval(0.5, 1.5) < self.index + expected = [False, True] + self.assert_numpy_array_equal(actual, expected) + actual = self.index > Interval(0.5, 1.5) + self.assert_numpy_array_equal(actual, expected) + + actual = self.index == self.index + expected = [True, True] + self.assert_numpy_array_equal(actual, expected) + actual = self.index <= self.index + self.assert_numpy_array_equal(actual, expected) + actual = self.index >= self.index + self.assert_numpy_array_equal(actual, expected) + + actual = self.index < self.index + expected = [False, False] + self.assert_numpy_array_equal(actual, expected) + actual = self.index > self.index + self.assert_numpy_array_equal(actual, expected) + + actual = self.index == IntervalIndex.from_breaks([0, 1, 2], 'left') + self.assert_numpy_array_equal(actual, expected) + + actual = self.index == self.index.values + self.assert_numpy_array_equal(actual, [True, True]) + actual = self.index.values == self.index + self.assert_numpy_array_equal(actual, [True, True]) + actual = self.index <= self.index.values + self.assert_numpy_array_equal(actual, [True, True]) + actual = self.index != self.index.values + self.assert_numpy_array_equal(actual, [False, False]) + actual = self.index > self.index.values + self.assert_numpy_array_equal(actual, [False, False]) + actual = self.index.values > self.index + self.assert_numpy_array_equal(actual, [False, False]) + + # invalid comparisons + actual = self.index == 0 + self.assert_numpy_array_equal(actual, [False, False]) + actual = self.index == self.index.left + self.assert_numpy_array_equal(actual, [False, False]) + + with self.assertRaisesRegexp(TypeError, 'unorderable types'): + self.index > 0 + with self.assertRaisesRegexp(TypeError, 'unorderable types'): + self.index <= 0 + with self.assertRaises(TypeError): + self.index > np.arange(2) + with self.assertRaises(ValueError): + self.index > np.arange(3) + + def test_missing_values(self): + idx = pd.Index([np.nan, pd.Interval(0, 1), pd.Interval(1, 2)]) + idx2 = pd.IntervalIndex([np.nan, 0, 1], [np.nan, 1, 2]) + assert idx.equals(idx2) + + with tm.assertRaisesRegexp(ValueError, 'both left and right sides'): + pd.IntervalIndex([np.nan, 0, 1], [0, 1, 2]) + + self.assert_numpy_array_equal(pd.isnull(idx), [True, False, False]) + + def test_order(self): + expected = IntervalIndex.from_breaks([1, 2, 3, 4]) + actual = IntervalIndex.from_tuples([(3, 4), (1, 2), (2, 3)]).order() + self.assert_numpy_array_equal(expected, actual) + + def test_datetime(self): + dates = pd.date_range('2000', periods=3) + idx = IntervalIndex.from_breaks(dates) + + self.assert_numpy_array_equal(idx.left, dates[:2]) + self.assert_numpy_array_equal(idx.right, dates[-2:]) + + expected = pd.date_range('2000-01-01T12:00', periods=2) + self.assert_numpy_array_equal(idx.mid, expected) + + self.assertIn('2000-01-01T12', idx) + + target = pd.date_range('1999-12-31T12:00', periods=7, freq='12H') + actual = idx.get_indexer(target) + expected = [-1, -1, 0, 0, 1, 1, -1] + self.assert_numpy_array_equal(actual, expected) + + # def test_math(self): + # # add, subtract, multiply, divide with scalars should be OK + # actual = 2 * self.index + 1 + # expected = IntervalIndex.from_breaks((2 * np.arange(3) + 1)) + # self.assertTrue(expected.equals(actual)) + + # actual = self.index / 2.0 - 1 + # expected = IntervalIndex.from_breaks((np.arange(3) / 2.0 - 1)) + # self.assertTrue(expected.equals(actual)) + + # with self.assertRaises(TypeError): + # # doesn't make sense to add two IntervalIndex objects + # self.index + self.index + + # def test_datetime_math(self): + + # expected = IntervalIndex(pd.date_range('2000-01-02', periods=3)) + # actual = idx + pd.to_timedelta(1, unit='D') + # self.assertTrue(expected.equals(actual)) + + # TODO: other set operations (left join, right join, intersection), + # set operations with conflicting IntervalIndex objects or other dtypes, + # groupby, cut, reset_index... diff --git a/pandas/tests/tools/test_tile.py b/pandas/tests/tools/test_tile.py index 11b242bc06e15b..ca8e58dd853228 100644 --- a/pandas/tests/tools/test_tile.py +++ b/pandas/tests/tools/test_tile.py @@ -3,12 +3,14 @@ import numpy as np from pandas.compat import zip -from pandas import Series, Index, Categorical +from pandas import DataFrame, Series, Index, unique, isnull, Categorical import pandas.util.testing as tm from pandas.util.testing import assertRaisesRegexp import pandas.core.common as com from pandas.core.algorithms import quantile +from pandas.core.categorical import Categorical +from pandas.core.interval import Interval, IntervalIndex from pandas.tools.tile import cut, qcut import pandas.tools.tile as tmod from pandas import to_datetime, DatetimeIndex, Timestamp @@ -27,34 +29,30 @@ def test_bins(self): data = np.array([.2, 1.4, 2.5, 6.2, 9.7, 2.1]) result, bins = cut(data, 3, retbins=True) - exp_codes = np.array([0, 0, 0, 1, 2, 0], dtype=np.int8) - tm.assert_numpy_array_equal(result.codes, exp_codes) - exp = np.array([0.1905, 3.36666667, 6.53333333, 9.7]) - tm.assert_almost_equal(bins, exp) + intervals = IntervalIndex.from_breaks(bins.round(3)) + tm.assert_numpy_array_equal(result, intervals.take([0, 0, 0, 1, 2, 0])) + tm.assert_almost_equal(bins, [0.1905, 3.36666667, 6.53333333, 9.7]) def test_right(self): data = np.array([.2, 1.4, 2.5, 6.2, 9.7, 2.1, 2.575]) result, bins = cut(data, 4, right=True, retbins=True) - exp_codes = np.array([0, 0, 0, 2, 3, 0, 0], dtype=np.int8) - tm.assert_numpy_array_equal(result.codes, exp_codes) - exp = np.array([0.1905, 2.575, 4.95, 7.325, 9.7]) - tm.assert_numpy_array_equal(bins, exp) + intervals = IntervalIndex.from_breaks(bins.round(3)) + tm.assert_numpy_array_equal(result, intervals.take([0, 0, 0, 2, 3, 0, 0])) + tm.assert_almost_equal(bins, [0.1905, 2.575, 4.95, 7.325, 9.7]) def test_noright(self): data = np.array([.2, 1.4, 2.5, 6.2, 9.7, 2.1, 2.575]) result, bins = cut(data, 4, right=False, retbins=True) - exp_codes = np.array([0, 0, 0, 2, 3, 0, 1], dtype=np.int8) - tm.assert_numpy_array_equal(result.codes, exp_codes) - exp = np.array([0.2, 2.575, 4.95, 7.325, 9.7095]) - tm.assert_almost_equal(bins, exp) + intervals = IntervalIndex.from_breaks(bins.round(3), closed='left') + tm.assert_numpy_array_equal(result, intervals.take([0, 0, 0, 2, 3, 0, 1])) + tm.assert_almost_equal(bins, [0.2, 2.575, 4.95, 7.325, 9.7095]) def test_arraylike(self): data = [.2, 1.4, 2.5, 6.2, 9.7, 2.1] result, bins = cut(data, 3, retbins=True) - exp_codes = np.array([0, 0, 0, 1, 2, 0], dtype=np.int8) - tm.assert_numpy_array_equal(result.codes, exp_codes) - exp = np.array([0.1905, 3.36666667, 6.53333333, 9.7]) - tm.assert_almost_equal(bins, exp) + intervals = IntervalIndex.from_breaks(bins.round(3)) + tm.assert_numpy_array_equal(result, intervals.take([0, 0, 0, 1, 2, 0])) + tm.assert_almost_equal(bins, [0.1905, 3.36666667, 6.53333333, 9.7]) def test_bins_not_monotonic(self): data = [.2, 1.4, 2.5, 6.2, 9.7, 2.1] @@ -82,14 +80,13 @@ def test_labels(self): arr = np.tile(np.arange(0, 1.01, 0.1), 4) result, bins = cut(arr, 4, retbins=True) - ex_levels = Index(['(-0.001, 0.25]', '(0.25, 0.5]', '(0.5, 0.75]', - '(0.75, 1]']) - self.assert_index_equal(result.categories, ex_levels) + ex_levels = IntervalIndex.from_breaks([-1e-3, 0.25, 0.5, 0.75, 1]) + self.assert_numpy_array_equal(unique(result), ex_levels) result, bins = cut(arr, 4, retbins=True, right=False) - ex_levels = Index(['[0, 0.25)', '[0.25, 0.5)', '[0.5, 0.75)', - '[0.75, 1.001)']) - self.assert_index_equal(result.categories, ex_levels) + ex_levels = IntervalIndex.from_breaks([0, 0.25, 0.5, 0.75, 1 + 1e-3], + closed='left') + self.assert_numpy_array_equal(unique(result), ex_levels) def test_cut_pass_series_name_to_factor(self): s = Series(np.random.randn(100), name='foo') @@ -101,9 +98,8 @@ def test_label_precision(self): arr = np.arange(0, 0.73, 0.01) result = cut(arr, 4, precision=2) - ex_levels = Index(['(-0.00072, 0.18]', '(0.18, 0.36]', - '(0.36, 0.54]', '(0.54, 0.72]']) - self.assert_index_equal(result.categories, ex_levels) + ex_levels = IntervalIndex.from_breaks([-0.00072, 0.18, 0.36, 0.54, 0.72]) + self.assert_numpy_array_equal(unique(result), ex_levels) def test_na_handling(self): arr = np.arange(0, 0.75, 0.01) @@ -125,17 +121,16 @@ def test_inf_handling(self): data = np.arange(6) data_ser = Series(data, dtype='int64') - result = cut(data, [-np.inf, 2, 4, np.inf]) - result_ser = cut(data_ser, [-np.inf, 2, 4, np.inf]) + bins = [-np.inf, 2, 4, np.inf] + result = cut(data, bins) + result_ser = cut(data_ser, bins) - ex_categories = Index(['(-inf, 2]', '(2, 4]', '(4, inf]']) - - tm.assert_index_equal(result.categories, ex_categories) - tm.assert_index_equal(result_ser.cat.categories, ex_categories) - self.assertEqual(result[5], '(4, inf]') - self.assertEqual(result[0], '(-inf, 2]') - self.assertEqual(result_ser[5], '(4, inf]') - self.assertEqual(result_ser[0], '(-inf, 2]') + ex_uniques = IntervalIndex.from_breaks(bins).values + tm.assert_numpy_array_equal(unique(result), ex_uniques) + self.assertEqual(result[5], Interval(4, np.inf)) + self.assertEqual(result[0], Interval(-np.inf, 2)) + self.assertEqual(result_ser[5], Interval(4, np.inf)) + self.assertEqual(result_ser[0], Interval(-np.inf, 2)) def test_qcut(self): arr = np.random.randn(1000) @@ -158,7 +153,7 @@ def test_qcut_specify_quantiles(self): factor = qcut(arr, [0, .25, .5, .75, 1.]) expected = qcut(arr, 4) - tm.assert_categorical_equal(factor, expected) + self.assert_numpy_array_equal(factor, expected) def test_qcut_all_bins_same(self): assertRaisesRegexp(ValueError, "edges.*unique", qcut, @@ -169,7 +164,7 @@ def test_cut_out_of_bounds(self): result = cut(arr, [-1, 0, 1]) - mask = result.codes == -1 + mask = isnull(result) ex_mask = (arr < -1) | (arr > 1) self.assert_numpy_array_equal(mask, ex_mask) @@ -179,19 +174,21 @@ def test_cut_pass_labels(self): labels = ['Small', 'Medium', 'Large'] result = cut(arr, bins, labels=labels) + exp = ['Medium'] + 4 * ['Small'] + ['Medium', 'Large'] + self.assert_numpy_array_equal(result, exp) - exp = cut(arr, bins) - exp.categories = labels - - tm.assert_categorical_equal(result, exp) + result = cut(arr, bins, labels=Categorical.from_codes([0, 1, 2], labels)) + exp = Categorical.from_codes([1] + 4 * [0] + [1, 2], labels) + self.assertTrue(result.equals(exp)) def test_qcut_include_lowest(self): values = np.arange(10) cats = qcut(values, 4) - ex_levels = ['[0, 2.25]', '(2.25, 4.5]', '(4.5, 6.75]', '(6.75, 9]'] - self.assertTrue((cats.categories == ex_levels).all()) + ex_levels = [Interval(0, 2.25, closed='both'), Interval(2.25, 4.5), + Interval(4.5, 6.75), Interval(6.75, 9)] + self.assert_numpy_array_equal(unique(cats), ex_levels) def test_qcut_nas(self): arr = np.random.randn(100) @@ -200,9 +197,15 @@ def test_qcut_nas(self): result = qcut(arr, 4) self.assertTrue(com.isnull(result[:20]).all()) - def test_label_formatting(self): - self.assertEqual(tmod._trim_zeros('1.000'), '1') + def test_qcut_index(self): + # the result is closed on a different side for the first interval, but + # we should still be able to make an index + result = qcut([0, 2], 2) + index = Index(result) + expected = Index([Interval(0, 1, closed='both'), Interval(1, 2)]) + self.assert_numpy_array_equal(index, expected) + def test_round_frac(self): # it works result = cut(np.arange(11.), 2) @@ -210,10 +213,15 @@ def test_label_formatting(self): # #1979, negative numbers - result = tmod._format_label(-117.9998, precision=3) - self.assertEqual(result, '-118') - result = tmod._format_label(117.9998, precision=3) - self.assertEqual(result, '118') + result = tmod._round_frac(-117.9998, precision=3) + self.assertEqual(result, -118) + result = tmod._round_frac(117.9998, precision=3) + self.assertEqual(result, 118) + + result = tmod._round_frac(117.9998, precision=2) + self.assertEqual(result, 118) + result = tmod._round_frac(0.000123456, precision=2) + self.assertEqual(result, 0.00012) def test_qcut_binning_issues(self): # #1978, 1979 @@ -224,9 +232,9 @@ def test_qcut_binning_issues(self): starts = [] ends = [] - for lev in result.categories: - s, e = lev[1:-1].split(',') - + for lev in np.unique(result): + s = lev.left + e = lev.right self.assertTrue(s != e) starts.append(float(s)) @@ -238,22 +246,20 @@ def test_qcut_binning_issues(self): self.assertTrue(ep < en) self.assertTrue(ep <= sn) - def test_cut_return_categorical(self): - s = Series([0, 1, 2, 3, 4, 5, 6, 7, 8]) - res = cut(s, 3) - exp = Series(Categorical.from_codes([0, 0, 0, 1, 1, 1, 2, 2, 2], - ["(-0.008, 2.667]", - "(2.667, 5.333]", "(5.333, 8]"], - ordered=True)) + def test_cut_return_intervals(self): + s = Series([0,1,2,3,4,5,6,7,8]) + res = cut(s,3) + exp_bins = np.linspace(0, 8, num=4).round(3) + exp_bins[0] -= 0.008 + exp = Series(IntervalIndex.from_breaks(exp_bins).take([0,0,0,1,1,1,2,2,2])) tm.assert_series_equal(res, exp) - def test_qcut_return_categorical(self): - s = Series([0, 1, 2, 3, 4, 5, 6, 7, 8]) - res = qcut(s, [0, 0.333, 0.666, 1]) - exp = Series(Categorical.from_codes([0, 0, 0, 1, 1, 1, 2, 2, 2], - ["[0, 2.664]", - "(2.664, 5.328]", "(5.328, 8]"], - ordered=True)) + def test_qcut_return_intervals(self): + s = Series([0,1,2,3,4,5,6,7,8]) + res = qcut(s,[0,0.333,0.666,1]) + exp_levels = np.array([Interval(0, 2.664, closed='both'), + Interval(2.664, 5.328), Interval(5.328, 8)]) + exp = Series(exp_levels.take([0,0,0,1,1,1,2,2,2])) tm.assert_series_equal(res, exp) def test_series_retbins(self): @@ -421,6 +427,14 @@ def test_datetime_bin(self): result = cut(data, bins=bin_pydatetime) tm.assert_series_equal(Series(result), expected) + result, bins = cut(s, 2, retbins=True, labels=[0, 1]) + tm.assert_numpy_array_equal(result, [0, 0, 1, 1]) + tm.assert_almost_equal(bins, [-0.003, 1.5, 3]) + + result, bins = qcut(s, 2, retbins=True, labels=[0, 1]) + tm.assert_numpy_array_equal(result, [0, 0, 1, 1]) + tm.assert_almost_equal(bins, [0, 1.5, 3]) + def curpath(): pth, _ = os.path.split(os.path.abspath(__file__)) diff --git a/pandas/tools/tile.py b/pandas/tools/tile.py index ccd8c2478e8a51..a48cf060d09a61 100644 --- a/pandas/tools/tile.py +++ b/pandas/tools/tile.py @@ -8,6 +8,8 @@ from pandas.core.api import Series from pandas.core.categorical import Categorical +from pandas.core.index import _ensure_index +from pandas.core.interval import IntervalIndex, Interval import pandas.core.algorithms as algos import pandas.core.nanops as nanops from pandas.compat import zip @@ -17,6 +19,8 @@ import numpy as np +import warnings + def cut(x, bins, right=True, labels=None, retbins=False, precision=3, include_lowest=False): @@ -45,9 +49,9 @@ def cut(x, bins, right=True, labels=None, retbins=False, precision=3, retbins : bool, optional Whether to return the bins or not. Can be useful if bins is given as a scalar. - precision : int + precision : int, optional The precision at which to store and display the bins labels - include_lowest : bool + include_lowest : bool, optional Whether the first interval should be left-inclusive or not. Returns @@ -93,14 +97,17 @@ def cut(x, bins, right=True, labels=None, retbins=False, precision=3, if is_scalar(bins) and bins < 1: raise ValueError("`bins` should be a positive integer.") - sz = x.size + # TODO: IntervalIndex + try: # for array-like + sz = x.size + except AttributeError: + x = np.asarray(x) + sz = x.size if sz == 0: raise ValueError('Cannot cut empty array') - # handle empty arrays. Can't determine range, so use 0-1. - # rng = (0, 1) - else: - rng = (nanops.nanmin(x), nanops.nanmax(x)) + + rng = (nanops.nanmin(x), nanops.nanmax(x)) mn, mx = [mi + 0.0 for mi in rng] if mn == mx: # adjust end points before binning @@ -149,7 +156,7 @@ def qcut(x, q, labels=None, retbins=False, precision=3, duplicates='raise'): retbins : bool, optional Whether to return the bins or not. Can be useful if bins is given as a scalar. - precision : int + precision : int, optional The precision at which to store and display the bins labels duplicates : {default 'raise', 'drop'}, optional If bin edges are not unique, raise ValueError or drop non-uniques. @@ -225,6 +232,8 @@ def _bins_to_cuts(x, bins, right=True, labels=None, if labels is not False: if labels is None: + + # TODO: IntervalIndex increases = 0 while True: try: @@ -239,23 +248,34 @@ def _bins_to_cuts(x, bins, right=True, labels=None, else: break + # + #closed = 'right' if right else 'left' + #precision = _infer_precision(precision, bins) + #breaks = [_round_frac(b, precision) for b in bins] + #labels = IntervalIndex.from_breaks(breaks, closed=closed).values + + #if right and include_lowest: + # labels[0] = Interval(labels[0].left, labels[0].right, + # closed='both') + else: if len(labels) != len(bins) - 1: raise ValueError('Bin labels must be one fewer than ' 'the number of bin edges') - levels = labels - levels = np.asarray(levels, dtype=object) + if not com.is_categorical(labels): + labels = np.asarray(labels) + np.putmask(ids, na_mask, 0) - fac = Categorical(ids - 1, levels, ordered=True, fastpath=True) + result = com.take_nd(labels, ids - 1) + else: - fac = ids - 1 + result = ids - 1 if has_nas: - fac = fac.astype(np.float64) - np.putmask(fac, na_mask, np.nan) - - return fac, bins + result = result.astype(np.float64) + np.putmask(result, na_mask, np.nan) + return result, bins def _format_levels(bins, prec, right=True, include_lowest=False, dtype=None): @@ -265,15 +285,11 @@ def _format_levels(bins, prec, right=True, for a, b in zip(bins, bins[1:]): fa, fb = fmt(a), fmt(b) - if a != b and fa == fb: - raise ValueError('precision too low') - - formatted = '(%s, %s]' % (fa, fb) - - levels.append(formatted) - - if include_lowest: - levels[0] = '[' + levels[0][1:] +def _round_frac(x, precision): + """Round the fractional part of the given number + """ + if not np.isfinite(x) or x == 0: + return x else: levels = ['[%s, %s)' % (fmt(a), fmt(b)) for a, b in zip(bins, bins[1:])] @@ -291,30 +307,11 @@ def _format_label(x, precision=3, dtype=None): return str(x) elif is_float(x): frac, whole = np.modf(x) - sgn = '-' if x < 0 else '' - whole = abs(whole) - if frac != 0.0: - val = fmt_str % frac - - # rounded up or down - if '.' not in val: - if x < 0: - return '%d' % (-whole - 1) - else: - return '%d' % (whole + 1) - - if 'e' in val: - return _trim_zeros(fmt_str % x) - else: - val = _trim_zeros(val) - if '.' in val: - return sgn + '.'.join(('%d' % whole, val.split('.')[1])) - else: # pragma: no cover - return sgn + '.'.join(('%d' % whole, val)) + if whole == 0: + digits = -int(np.floor(np.log10(abs(frac)))) - 1 + precision else: - return sgn + '%0.f' % whole - else: - return str(x) + digits = precision + return np.around(x, digits) def _trim_zeros(x): @@ -388,3 +385,12 @@ def _postprocess_for_cut(fac, bins, retbins, x_is_series, series_index, name): return fac return fac, bins + +def _infer_precision(base_precision, bins): + """Infer an appropriate precision for _round_frac + """ + for precision in range(base_precision, 20): + levels = [_round_frac(b, precision) for b in bins] + if algos.unique(levels).size == bins.size: + return precision + return base_precision # default diff --git a/pandas/util/testing.py b/pandas/util/testing.py index cf76f4ead77e34..a6aeba9221768a 100644 --- a/pandas/util/testing.py +++ b/pandas/util/testing.py @@ -43,9 +43,11 @@ from pandas.computation import expressions as expr -from pandas import (bdate_range, CategoricalIndex, Categorical, DatetimeIndex, - TimedeltaIndex, PeriodIndex, RangeIndex, Index, MultiIndex, +from pandas import (bdate_range, CategoricalIndex, Categorical, IntervalIndex, + DatetimeIndex, TimedeltaIndex, PeriodIndex, RangeIndex, + Index, MultiIndex, Series, DataFrame, Panel, Panel4D) + from pandas.util.decorators import deprecate from pandas.util import libtesting from pandas.io.common import urlopen @@ -1588,6 +1590,11 @@ def makeCategoricalIndex(k=10, n=3, name=None): return CategoricalIndex(np.random.choice(x, k), name=name) +def makeIntervalIndex(k=10, name=None): + """ make a length k IntervalIndex """ + x = np.linspace(0, 100, num=(k + 1)) + return IntervalIndex.from_breaks(x, name=name) + def makeBoolIndex(k=10, name=None): if k == 1: return Index([True], name=name)