Skip to content

Commit

Permalink
feat(tests): support defining datatype nullability for hypothesis str…
Browse files Browse the repository at this point in the history
…ategies
  • Loading branch information
kszucs authored and jcrist committed Aug 15, 2023
1 parent 94712f4 commit ff26fb8
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 126 deletions.
9 changes: 6 additions & 3 deletions ibis/formats/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,21 @@ def test_non_roundtripable_bytes_type(numpy_type):


@h.given(
ibst.null_dtype | ibst.variadic_dtypes | ibst.decimal_dtype() | ibst.struct_dtypes()
ibst.null_dtype
| ibst.variadic_dtypes()
| ibst.decimal_dtypes()
| ibst.struct_dtypes()
)
def test_variadic_to_numpy(ibis_type):
    """Check that the given ibis type converts to numpy's ``object`` dtype."""
    object_dtype = np.dtype("object")
    assert NumpyType.from_ibis(ibis_type) == object_dtype


@h.given(ibst.date_dtype | ibst.timestamp_dtype)
@h.given(ibst.date_dtype() | ibst.timestamp_dtype())
def test_date_to_numpy(ibis_type):
    """Check that date and timestamp types convert to ``datetime64[ns]``."""
    expected = np.dtype("datetime64[ns]")
    assert NumpyType.from_ibis(ibis_type) == expected


@h.given(ibst.time_dtype)
@h.given(ibst.time_dtype())
def test_time_to_numpy(ibis_type):
    """Check that time types convert to ``timedelta64[ns]``."""
    expected = np.dtype("timedelta64[ns]")
    assert NumpyType.from_ibis(ibis_type) == expected

Expand Down
282 changes: 177 additions & 105 deletions ibis/tests/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,73 +10,150 @@
import ibis.expr.schema as sch
from ibis.common.temporal import IntervalUnit

# pyarrow also has hypothesis strategies for various pyarrow objects
# Strategies for generating ibis datatypes

# Strategies for generating datatypes

nullable = st.booleans()
# Shared default strategy used for the ``nullable`` argument of the datatype
# strategy functions in this module.
_nullable = st.booleans()

# Strategy that always yields the ``dt.null`` datatype.
null_dtype = st.just(dt.null)
# Module-level strategies: each builds one ibis datatype class via st.builds,
# taking its ``nullable`` argument from the ``nullable`` strategy.
# NOTE(review): ``boolean_dtype`` is also defined as a function further down in
# this listing, and the integer/float classes are rebuilt inside the
# ``*_dtypes`` functions — these assignments look like the superseded
# pre-change definitions; confirm before relying on them.
boolean_dtype = st.builds(dt.Boolean, nullable=nullable)

int8_dtype = st.builds(dt.Int8, nullable=nullable)
int16_dtype = st.builds(dt.Int16, nullable=nullable)
int32_dtype = st.builds(dt.Int32, nullable=nullable)
int64_dtype = st.builds(dt.Int64, nullable=nullable)
uint8_dtype = st.builds(dt.UInt8, nullable=nullable)
uint16_dtype = st.builds(dt.UInt16, nullable=nullable)
uint32_dtype = st.builds(dt.UInt32, nullable=nullable)
uint64_dtype = st.builds(dt.UInt64, nullable=nullable)
float16_dtype = st.builds(dt.Float16, nullable=nullable)
float32_dtype = st.builds(dt.Float32, nullable=nullable)
float64_dtype = st.builds(dt.Float64, nullable=nullable)

def boolean_dtype(nullable=_nullable):
    """Return a strategy that builds ``dt.Boolean`` datatypes.

    ``nullable`` is the strategy handed to ``st.builds`` for the datatype's
    ``nullable`` argument.
    """
    strategy = st.builds(dt.Boolean, nullable=nullable)
    return strategy


def signed_integer_dtypes(nullable=_nullable):
    """Return a strategy choosing among the Int8/Int16/Int32/Int64 datatypes.

    ``nullable`` is the strategy passed to every ``st.builds`` call for the
    datatype's ``nullable`` argument.
    """
    signed_types = (dt.Int8, dt.Int16, dt.Int32, dt.Int64)
    members = [st.builds(typ, nullable=nullable) for typ in signed_types]
    return st.one_of(*members)


def unsigned_integer_dtypes(nullable=_nullable):
    """Return a strategy choosing among the UInt8/UInt16/UInt32/UInt64 datatypes.

    ``nullable`` is the strategy passed to every ``st.builds`` call for the
    datatype's ``nullable`` argument.
    """
    unsigned_types = (dt.UInt8, dt.UInt16, dt.UInt32, dt.UInt64)
    members = [st.builds(typ, nullable=nullable) for typ in unsigned_types]
    return st.one_of(*members)


def integer_dtypes(nullable=_nullable):
    """Return a strategy covering both signed and unsigned integer datatypes.

    ``nullable`` is forwarded unchanged to the signed and unsigned strategies.
    """
    signed = signed_integer_dtypes(nullable=nullable)
    unsigned = unsigned_integer_dtypes(nullable=nullable)
    return st.one_of(signed, unsigned)


def floating_dtypes(nullable=_nullable):
    """Return a strategy choosing among the Float16/Float32/Float64 datatypes.

    ``nullable`` is the strategy passed to every ``st.builds`` call for the
    datatype's ``nullable`` argument.
    """
    float_types = (dt.Float16, dt.Float32, dt.Float64)
    members = [st.builds(typ, nullable=nullable) for typ in float_types]
    return st.one_of(*members)


@st.composite
def decimal_dtype(draw):
def decimal_dtypes(draw, nullable=_nullable):
    """Draw a ``dt.Decimal`` datatype whose scale does not exceed its precision.

    Takes ``draw`` as its first argument, i.e. it is written to be wrapped by
    ``st.composite``. Precision and scale are each drawn from 1..38; examples
    where the scale is larger than the precision are rejected via ``h.assume``.
    ``nullable`` is a strategy that is drawn from last.
    """
    digits = st.integers(min_value=1, max_value=38)
    # Draw order matters for reproducibility: precision, then scale.
    precision = draw(digits)
    scale = draw(digits)
    h.assume(scale <= precision)
    is_nullable = draw(nullable)
    return dt.Decimal(precision, scale, nullable=is_nullable)


# Module-level unions of the individual ``*_dtype`` strategy objects.
# NOTE(review): most of these names (signed_integer_dtypes, integer_dtypes,
# date_dtype, primitive_dtypes, ...) are also defined as functions elsewhere in
# this listing — these assignments look like the superseded pre-change
# definitions; confirm before relying on them.
signed_integer_dtypes = st.one_of(int8_dtype, int16_dtype, int32_dtype, int64_dtype)
unsigned_integer_dtypes = st.one_of(
uint8_dtype, uint16_dtype, uint32_dtype, uint64_dtype
)
integer_dtypes = st.one_of(signed_integer_dtypes, unsigned_integer_dtypes)
floating_dtypes = st.one_of(float16_dtype, float32_dtype, float64_dtype)
numeric_dtypes = st.one_of(integer_dtypes, floating_dtypes, decimal_dtype())

# Temporal strategies; timestamps draw either no timezone or a timezone name.
date_dtype = st.builds(dt.Date, nullable=nullable)
time_dtype = st.builds(dt.Time, nullable=nullable)
timestamp_dtype = st.builds(
dt.Timestamp, timezone=st.none() | tzst.timezones().map(str), nullable=nullable
)
interval_unit = st.sampled_from(list(IntervalUnit))
interval_dtype = st.builds(dt.Interval, unit=interval_unit, nullable=nullable)
# interval_dtype is deliberately left out of the union (commented out below).
temporal_dtypes = st.one_of(
date_dtype,
time_dtype,
timestamp_dtype,
# interval_dtype
)

# Primitive datatypes: null, boolean, integers, floats, date and time.
primitive_dtypes = st.one_of(
null_dtype,
boolean_dtype,
integer_dtypes,
floating_dtypes,
date_dtype,
time_dtype,
)


def array_dtypes(item_strategy=primitive_dtypes):
def numeric_dtypes(nullable=_nullable):
    """Return a strategy covering integer, floating point and decimal datatypes.

    ``nullable`` is forwarded unchanged to each of the three sub-strategies.
    """
    members = [
        integer_dtypes(nullable=nullable),
        floating_dtypes(nullable=nullable),
        decimal_dtypes(nullable=nullable),
    ]
    return st.one_of(*members)


def string_dtype(nullable=_nullable):
    """Return a strategy that builds ``dt.String`` datatypes.

    ``nullable`` is the strategy handed to ``st.builds`` for the datatype's
    ``nullable`` argument.
    """
    strategy = st.builds(dt.String, nullable=nullable)
    return strategy


def binary_dtype(nullable=_nullable):
    """Return a strategy that builds ``dt.Binary`` datatypes.

    ``nullable`` is the strategy handed to ``st.builds`` for the datatype's
    ``nullable`` argument.
    """
    strategy = st.builds(dt.Binary, nullable=nullable)
    return strategy


def json_dtype(nullable=_nullable):
    """Return a strategy that builds ``dt.JSON`` datatypes.

    ``nullable`` is the strategy handed to ``st.builds`` for the datatype's
    ``nullable`` argument.
    """
    strategy = st.builds(dt.JSON, nullable=nullable)
    return strategy


def inet_dtype(nullable=_nullable):
    """Return a strategy that builds ``dt.INET`` datatypes.

    ``nullable`` is the strategy handed to ``st.builds`` for the datatype's
    ``nullable`` argument.
    """
    strategy = st.builds(dt.INET, nullable=nullable)
    return strategy


def macaddr_dtype(nullable=_nullable):
    """Return a strategy that builds ``dt.MACADDR`` datatypes.

    ``nullable`` is the strategy handed to ``st.builds`` for the datatype's
    ``nullable`` argument.
    """
    strategy = st.builds(dt.MACADDR, nullable=nullable)
    return strategy


def uuid_dtype(nullable=_nullable):
    """Return a strategy that builds ``dt.UUID`` datatypes.

    ``nullable`` is the strategy handed to ``st.builds`` for the datatype's
    ``nullable`` argument.
    """
    strategy = st.builds(dt.UUID, nullable=nullable)
    return strategy


def string_like_dtypes(nullable=_nullable):
    """Return a strategy over string, binary, JSON, INET, MACADDR and UUID types.

    ``nullable`` is forwarded unchanged to each member strategy.
    """
    factories = (
        string_dtype,
        binary_dtype,
        json_dtype,
        inet_dtype,
        macaddr_dtype,
        uuid_dtype,
    )
    members = [make(nullable=nullable) for make in factories]
    return st.one_of(*members)


def date_dtype(nullable=_nullable):
    """Return a strategy that builds ``dt.Date`` datatypes.

    ``nullable`` is the strategy handed to ``st.builds`` for the datatype's
    ``nullable`` argument.
    """
    strategy = st.builds(dt.Date, nullable=nullable)
    return strategy


def time_dtype(nullable=_nullable):
    """Return a strategy that builds ``dt.Time`` datatypes.

    ``nullable`` is the strategy handed to ``st.builds`` for the datatype's
    ``nullable`` argument.
    """
    strategy = st.builds(dt.Time, nullable=nullable)
    return strategy


# Default timezone strategy: either None or a timezone from ``tzst.timezones()``
# converted with ``str``.
_timezone = st.none() | tzst.timezones().map(str)
# Default interval-unit strategy: samples from the members of ``IntervalUnit``.
_interval = st.sampled_from(list(IntervalUnit))


def timestamp_dtype(timezone=_timezone, nullable=_nullable):
    """Return a strategy that builds ``dt.Timestamp`` datatypes.

    ``timezone`` and ``nullable`` are the strategies handed to ``st.builds``
    for the datatype's arguments of the same names.
    """
    build_kwargs = {"timezone": timezone, "nullable": nullable}
    return st.builds(dt.Timestamp, **build_kwargs)


def interval_dtype(interval=_interval, nullable=_nullable):
    """Return a strategy that builds ``dt.Interval`` datatypes.

    ``interval`` is the strategy supplying the datatype's ``unit`` argument;
    ``nullable`` supplies its ``nullable`` argument.
    """
    build_kwargs = {"unit": interval, "nullable": nullable}
    return st.builds(dt.Interval, **build_kwargs)


def temporal_dtypes(timezone=_timezone, interval=_interval, nullable=_nullable):
    """Return a strategy over date, time and timestamp datatypes.

    ``timezone`` is forwarded to the timestamp strategy and ``nullable`` to all
    three members. ``interval`` is accepted but not used by this function:
    interval datatypes are not part of the returned union.
    """
    dates = date_dtype(nullable=nullable)
    times = time_dtype(nullable=nullable)
    timestamps = timestamp_dtype(timezone=timezone, nullable=nullable)
    return st.one_of(dates, times, timestamps)


def primitive_dtypes(nullable=_nullable):
    """Return a strategy over null, boolean, integer, float, date and time types.

    ``nullable`` is forwarded to every member except ``null_dtype``, which
    takes no argument.
    """
    factories = (
        boolean_dtype,
        integer_dtypes,
        floating_dtypes,
        date_dtype,
        time_dtype,
    )
    members = [null_dtype]
    members.extend(make(nullable=nullable) for make in factories)
    return st.one_of(*members)


# Default element strategy for the container datatype strategies (array, map,
# struct) and for schemas: primitive datatypes with default nullability.
_item_strategy = primitive_dtypes()


def array_dtypes(item_strategy=_item_strategy, nullable=_nullable):
    """Return a strategy that builds ``dt.Array`` datatypes.

    ``item_strategy`` supplies the array's ``value_type``; ``nullable``
    supplies its ``nullable`` argument.
    """
    build_kwargs = {"value_type": item_strategy, "nullable": nullable}
    return st.builds(dt.Array, **build_kwargs)


def map_dtypes(key_strategy=primitive_dtypes, value_strategy=primitive_dtypes):
def map_dtypes(
    key_strategy=_item_strategy, value_strategy=_item_strategy, nullable=_nullable
):
    """Return a strategy that builds ``dt.Map`` datatypes.

    ``key_strategy`` and ``value_strategy`` supply the map's ``key_type`` and
    ``value_type``; ``nullable`` supplies its ``nullable`` argument.
    """
    build_kwargs = {
        "key_type": key_strategy,
        "value_type": value_strategy,
        "nullable": nullable,
    }
    return st.builds(dt.Map, **build_kwargs)
Expand All @@ -85,8 +162,9 @@ def map_dtypes(key_strategy=primitive_dtypes, value_strategy=primitive_dtypes):
@st.composite
def struct_dtypes(
draw,
item_strategy=primitive_dtypes,
item_strategy=_item_strategy,
num_fields=st.integers(min_value=0, max_value=20), # noqa: B008
nullable=_nullable,
):
num_fields = draw(num_fields)
names = draw(st.lists(st.text(), min_size=num_fields, max_size=num_fields))
Expand All @@ -95,66 +173,60 @@ def struct_dtypes(
return dt.Struct(fields, nullable=draw(nullable))


# Module-level geospatial strategies, one per geospatial datatype class.
# NOTE(review): names such as geospatial_dtypes, string_dtype and
# variadic_dtypes are also defined as functions elsewhere in this listing —
# these assignments look like the superseded pre-change definitions; confirm
# before relying on them.
point_dtype = st.builds(dt.Point, nullable=nullable)
linestring_dtype = st.builds(dt.LineString, nullable=nullable)
polygon_dtype = st.builds(dt.Polygon, nullable=nullable)
multipoint_dtype = st.builds(dt.MultiPoint, nullable=nullable)
multilinestring_dtype = st.builds(dt.MultiLineString, nullable=nullable)
multipolygon_dtype = st.builds(dt.MultiPolygon, nullable=nullable)
# GeoSpatial is built with a fixed geotype via st.just.
geometry_dtype = st.builds(
dt.GeoSpatial, geotype=st.just("geometry"), nullable=nullable
)
geography_dtype = st.builds(
dt.GeoSpatial, geotype=st.just("geography"), nullable=nullable
)
geospatial_dtypes = st.one_of(
point_dtype,
linestring_dtype,
polygon_dtype,
multipoint_dtype,
multilinestring_dtype,
multipolygon_dtype,
geometry_dtype,
geography_dtype,
)

# String-like strategies.
string_dtype = st.builds(dt.String, nullable=nullable)
binary_dtype = st.builds(dt.Binary, nullable=nullable)
json_dtype = st.builds(dt.JSON, nullable=nullable)
inet_dtype = st.builds(dt.INET, nullable=nullable)
macaddr_dtype = st.builds(dt.MACADDR, nullable=nullable)
uuid_dtype = st.builds(dt.UUID, nullable=nullable)

# Union of string-like types plus arrays and maps with default item strategies.
variadic_dtypes = st.one_of(
string_dtype,
binary_dtype,
json_dtype,
inet_dtype,
macaddr_dtype,
array_dtypes(),
map_dtypes(),
)

all_dtypes = st.deferred(
lambda: (
primitive_dtypes
| interval_dtype
| uuid_dtype
| geospatial_dtypes
| variadic_dtypes
| struct_dtypes()
| array_dtypes(all_dtypes)
| map_dtypes(all_dtypes, all_dtypes)
| struct_dtypes(all_dtypes)
def geometry_dtypes(nullable=_nullable):
    """Return a strategy building ``dt.GeoSpatial`` with geotype ``"geometry"``.

    ``nullable`` is the strategy handed to ``st.builds`` for the datatype's
    ``nullable`` argument.
    """
    geotype = st.just("geometry")
    return st.builds(dt.GeoSpatial, geotype=geotype, nullable=nullable)


def geography_dtypes(nullable=_nullable):
    """Return a strategy building ``dt.GeoSpatial`` with geotype ``"geography"``.

    ``nullable`` is the strategy handed to ``st.builds`` for the datatype's
    ``nullable`` argument.
    """
    geotype = st.just("geography")
    return st.builds(dt.GeoSpatial, geotype=geotype, nullable=nullable)


def geospatial_dtypes(nullable=_nullable):
    """Return a strategy over all geospatial datatypes defined in this module.

    Covers Point, LineString, Polygon, MultiPoint, MultiLineString and
    MultiPolygon, followed by the generic geometry and geography types.
    ``nullable`` is forwarded unchanged to every member.
    """
    concrete_types = (
        dt.Point,
        dt.LineString,
        dt.Polygon,
        dt.MultiPoint,
        dt.MultiLineString,
        dt.MultiPolygon,
    )
    members = [st.builds(geo, nullable=nullable) for geo in concrete_types]
    members.append(geometry_dtypes(nullable=nullable))
    members.append(geography_dtypes(nullable=nullable))
    return st.one_of(*members)


def variadic_dtypes(nullable=_nullable):
    """Return a strategy over string, binary, JSON, array and map datatypes.

    Arrays and maps use their default item strategies; ``nullable`` is
    forwarded unchanged to every member.
    """
    factories = (string_dtype, binary_dtype, json_dtype, array_dtypes, map_dtypes)
    members = [make(nullable=nullable) for make in factories]
    return st.one_of(*members)


def all_dtypes(nullable=_nullable):
    """Return a recursive strategy over the datatype strategies of this module.

    The result unions the primitive, string-like, temporal, interval,
    geospatial, variadic and struct strategies with array, map and struct
    strategies whose element types are drawn from the result itself.
    ``nullable`` is forwarded unchanged to every member strategy.
    """
    # st.deferred delays evaluating the lambda, so ``recursive`` can refer to
    # itself for the nested array/map/struct element types.
    recursive = st.deferred(
        lambda: (
            primitive_dtypes(nullable=nullable)
            | string_like_dtypes(nullable=nullable)
            | temporal_dtypes(nullable=nullable)
            | interval_dtype(nullable=nullable)
            | geospatial_dtypes(nullable=nullable)
            | variadic_dtypes(nullable=nullable)
            | struct_dtypes(nullable=nullable)
            | array_dtypes(recursive, nullable=nullable)
            | map_dtypes(recursive, recursive, nullable=nullable)
            | struct_dtypes(recursive, nullable=nullable)
        )
    )
    return recursive


# Strategies for generating schema


@st.composite
def schema(draw, item_strategy=primitive_dtypes, max_size=20):
def schema(draw, item_strategy=_item_strategy, max_size=20):
num_fields = draw(st.integers(min_value=0, max_value=max_size))
names = draw(
st.lists(st.text(), min_size=num_fields, max_size=num_fields, unique=True)
Expand Down
Loading

0 comments on commit ff26fb8

Please sign in to comment.