From ff26fb8209929e56a846329a258d697a0516b523 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Mon, 14 Aug 2023 20:15:36 +0200 Subject: [PATCH] feat(tests): support defining datatype nullability for hypothesis strategies --- ibis/formats/tests/test_numpy.py | 9 +- ibis/tests/strategies.py | 282 +++++++++++++++++++------------ ibis/tests/test_strategies.py | 55 ++++-- 3 files changed, 220 insertions(+), 126 deletions(-) diff --git a/ibis/formats/tests/test_numpy.py b/ibis/formats/tests/test_numpy.py index 8f6242c09217..af994b273094 100644 --- a/ibis/formats/tests/test_numpy.py +++ b/ibis/formats/tests/test_numpy.py @@ -71,18 +71,21 @@ def test_non_roundtripable_bytes_type(numpy_type): @h.given( - ibst.null_dtype | ibst.variadic_dtypes | ibst.decimal_dtype() | ibst.struct_dtypes() + ibst.null_dtype + | ibst.variadic_dtypes() + | ibst.decimal_dtypes() + | ibst.struct_dtypes() ) def test_variadic_to_numpy(ibis_type): assert NumpyType.from_ibis(ibis_type) == np.dtype("object") -@h.given(ibst.date_dtype | ibst.timestamp_dtype) +@h.given(ibst.date_dtype() | ibst.timestamp_dtype()) def test_date_to_numpy(ibis_type): assert NumpyType.from_ibis(ibis_type) == np.dtype("datetime64[ns]") -@h.given(ibst.time_dtype) +@h.given(ibst.time_dtype()) def test_time_to_numpy(ibis_type): assert NumpyType.from_ibis(ibis_type) == np.dtype("timedelta64[ns]") diff --git a/ibis/tests/strategies.py b/ibis/tests/strategies.py index a7145cc9fb7e..d01bfc2beb4d 100644 --- a/ibis/tests/strategies.py +++ b/ibis/tests/strategies.py @@ -10,73 +10,150 @@ import ibis.expr.schema as sch from ibis.common.temporal import IntervalUnit -# pyarrow also has hypothesis strategies various pyarrow objects +# Strategies for generating ibis datatypes -# Strategies for generating datatypes - -nullable = st.booleans() +_nullable = st.booleans() null_dtype = st.just(dt.null) -boolean_dtype = st.builds(dt.Boolean, nullable=nullable) -int8_dtype = st.builds(dt.Int8, nullable=nullable) -int16_dtype = st.builds(dt.Int16, nullable=nullable) -int32_dtype = st.builds(dt.Int32, nullable=nullable) -int64_dtype = st.builds(dt.Int64, nullable=nullable) -uint8_dtype = st.builds(dt.UInt8, nullable=nullable) -uint16_dtype = st.builds(dt.UInt16, nullable=nullable) -uint32_dtype = st.builds(dt.UInt32, nullable=nullable) -uint64_dtype = st.builds(dt.UInt64, nullable=nullable) -float16_dtype = st.builds(dt.Float16, nullable=nullable) -float32_dtype = st.builds(dt.Float32, nullable=nullable) -float64_dtype = st.builds(dt.Float64, nullable=nullable) + +def boolean_dtype(nullable=_nullable): + return st.builds(dt.Boolean, nullable=nullable) + + +def signed_integer_dtypes(nullable=_nullable): + return st.one_of( + st.builds(dt.Int8, nullable=nullable), + st.builds(dt.Int16, nullable=nullable), + st.builds(dt.Int32, nullable=nullable), + st.builds(dt.Int64, nullable=nullable), + ) + + +def unsigned_integer_dtypes(nullable=_nullable): + return st.one_of( + st.builds(dt.UInt8, nullable=nullable), + st.builds(dt.UInt16, nullable=nullable), + st.builds(dt.UInt32, nullable=nullable), + st.builds(dt.UInt64, nullable=nullable), + ) + + +def integer_dtypes(nullable=_nullable): + return st.one_of( + signed_integer_dtypes(nullable=nullable), + unsigned_integer_dtypes(nullable=nullable), + ) + + +def floating_dtypes(nullable=_nullable): + return st.one_of( + st.builds(dt.Float16, nullable=nullable), + st.builds(dt.Float32, nullable=nullable), + st.builds(dt.Float64, nullable=nullable), + ) @st.composite -def decimal_dtype(draw): +def decimal_dtypes(draw, nullable=_nullable): number = st.integers(min_value=1, max_value=38) precision, scale = draw(number), draw(number) h.assume(precision >= scale) return dt.Decimal(precision, scale, nullable=draw(nullable)) -signed_integer_dtypes = st.one_of(int8_dtype, int16_dtype, int32_dtype, int64_dtype) -unsigned_integer_dtypes = st.one_of( - uint8_dtype, uint16_dtype, uint32_dtype, uint64_dtype -) -integer_dtypes = st.one_of(signed_integer_dtypes, unsigned_integer_dtypes) -floating_dtypes = st.one_of(float16_dtype, float32_dtype, float64_dtype) -numeric_dtypes = st.one_of(integer_dtypes, floating_dtypes, decimal_dtype()) - -date_dtype = st.builds(dt.Date, nullable=nullable) -time_dtype = st.builds(dt.Time, nullable=nullable) -timestamp_dtype = st.builds( - dt.Timestamp, timezone=st.none() | tzst.timezones().map(str), nullable=nullable -) -interval_unit = st.sampled_from(list(IntervalUnit)) -interval_dtype = st.builds(dt.Interval, unit=interval_unit, nullable=nullable) -temporal_dtypes = st.one_of( - date_dtype, - time_dtype, - timestamp_dtype, - # interval_dtype -) - -primitive_dtypes = st.one_of( - null_dtype, - boolean_dtype, - integer_dtypes, - floating_dtypes, - date_dtype, - time_dtype, -) - - -def array_dtypes(item_strategy=primitive_dtypes): +def numeric_dtypes(nullable=_nullable): + return st.one_of( + integer_dtypes(nullable=nullable), + floating_dtypes(nullable=nullable), + decimal_dtypes(nullable=nullable), + ) + + +def string_dtype(nullable=_nullable): + return st.builds(dt.String, nullable=nullable) + + +def binary_dtype(nullable=_nullable): + return st.builds(dt.Binary, nullable=nullable) + + +def json_dtype(nullable=_nullable): + return st.builds(dt.JSON, nullable=nullable) + + +def inet_dtype(nullable=_nullable): + return st.builds(dt.INET, nullable=nullable) + + +def macaddr_dtype(nullable=_nullable): + return st.builds(dt.MACADDR, nullable=nullable) + + +def uuid_dtype(nullable=_nullable): + return st.builds(dt.UUID, nullable=nullable) + + +def string_like_dtypes(nullable=_nullable): + return st.one_of( + string_dtype(nullable=nullable), + binary_dtype(nullable=nullable), + json_dtype(nullable=nullable), + inet_dtype(nullable=nullable), + macaddr_dtype(nullable=nullable), + uuid_dtype(nullable=nullable), + ) + + +def date_dtype(nullable=_nullable): + return st.builds(dt.Date, nullable=nullable) + + +def time_dtype(nullable=_nullable): + return st.builds(dt.Time, nullable=nullable) + + +_timezone = st.none() | tzst.timezones().map(str) +_interval = st.sampled_from(list(IntervalUnit)) + + +def timestamp_dtype(timezone=_timezone, nullable=_nullable): + return st.builds(dt.Timestamp, timezone=timezone, nullable=nullable) + + +def interval_dtype(interval=_interval, nullable=_nullable): + return st.builds(dt.Interval, unit=interval, nullable=nullable) + + +def temporal_dtypes(timezone=_timezone, interval=_interval, nullable=_nullable): + return st.one_of( + date_dtype(nullable=nullable), + time_dtype(nullable=nullable), + timestamp_dtype(timezone=timezone, nullable=nullable), + ) + + +def primitive_dtypes(nullable=_nullable): + return st.one_of( + null_dtype, + boolean_dtype(nullable=nullable), + integer_dtypes(nullable=nullable), + floating_dtypes(nullable=nullable), + date_dtype(nullable=nullable), + time_dtype(nullable=nullable), + ) + + +_item_strategy = primitive_dtypes() + + +def array_dtypes(item_strategy=_item_strategy, nullable=_nullable): return st.builds(dt.Array, value_type=item_strategy, nullable=nullable) -def map_dtypes(key_strategy=primitive_dtypes, value_strategy=primitive_dtypes): +def map_dtypes( + key_strategy=_item_strategy, value_strategy=_item_strategy, nullable=_nullable +): return st.builds( dt.Map, key_type=key_strategy, value_type=value_strategy, nullable=nullable ) @@ -85,8 +162,9 @@ def map_dtypes(key_strategy=primitive_dtypes, value_strategy=primitive_dtypes): @st.composite def struct_dtypes( draw, - item_strategy=primitive_dtypes, + item_strategy=_item_strategy, num_fields=st.integers(min_value=0, max_value=20), # noqa: B008 + nullable=_nullable, ): num_fields = draw(num_fields) names = draw(st.lists(st.text(), min_size=num_fields, max_size=num_fields)) @@ -95,66 +173,60 @@ def struct_dtypes( return dt.Struct(fields, nullable=draw(nullable)) -point_dtype = st.builds(dt.Point, nullable=nullable) -linestring_dtype = st.builds(dt.LineString, nullable=nullable) -polygon_dtype = st.builds(dt.Polygon, nullable=nullable) -multipoint_dtype = st.builds(dt.MultiPoint, nullable=nullable) -multilinestring_dtype = st.builds(dt.MultiLineString, nullable=nullable) -multipolygon_dtype = st.builds(dt.MultiPolygon, nullable=nullable) -geometry_dtype = st.builds( - dt.GeoSpatial, geotype=st.just("geometry"), nullable=nullable -) -geography_dtype = st.builds( - dt.GeoSpatial, geotype=st.just("geography"), nullable=nullable -) -geospatial_dtypes = st.one_of( - point_dtype, - linestring_dtype, - polygon_dtype, - multipoint_dtype, - multilinestring_dtype, - multipolygon_dtype, - geometry_dtype, - geography_dtype, -) - -string_dtype = st.builds(dt.String, nullable=nullable) -binary_dtype = st.builds(dt.Binary, nullable=nullable) -json_dtype = st.builds(dt.JSON, nullable=nullable) -inet_dtype = st.builds(dt.INET, nullable=nullable) -macaddr_dtype = st.builds(dt.MACADDR, nullable=nullable) -uuid_dtype = st.builds(dt.UUID, nullable=nullable) - -variadic_dtypes = st.one_of( - string_dtype, - binary_dtype, - json_dtype, - inet_dtype, - macaddr_dtype, - array_dtypes(), - map_dtypes(), -) - -all_dtypes = st.deferred( - lambda: ( - primitive_dtypes - | interval_dtype - | uuid_dtype - | geospatial_dtypes - | variadic_dtypes - | struct_dtypes() - | array_dtypes(all_dtypes) - | map_dtypes(all_dtypes, all_dtypes) - | struct_dtypes(all_dtypes) +def geometry_dtypes(nullable=_nullable): + return st.builds(dt.GeoSpatial, geotype=st.just("geometry"), nullable=nullable) + + +def geography_dtypes(nullable=_nullable): + return st.builds(dt.GeoSpatial, geotype=st.just("geography"), nullable=nullable) + + +def geospatial_dtypes(nullable=_nullable): + return st.one_of( + st.builds(dt.Point, nullable=nullable), + st.builds(dt.LineString, nullable=nullable), + st.builds(dt.Polygon, nullable=nullable), + st.builds(dt.MultiPoint, nullable=nullable), + st.builds(dt.MultiLineString, nullable=nullable), + st.builds(dt.MultiPolygon, nullable=nullable), + geometry_dtypes(nullable=nullable), + geography_dtypes(nullable=nullable), + ) + + +def variadic_dtypes(nullable=_nullable): + return st.one_of( + string_dtype(nullable=nullable), + binary_dtype(nullable=nullable), + json_dtype(nullable=nullable), + array_dtypes(nullable=nullable), + map_dtypes(nullable=nullable), + ) + + +def all_dtypes(nullable=_nullable): + recursive = st.deferred( + lambda: ( + primitive_dtypes(nullable=nullable) + | string_like_dtypes(nullable=nullable) + | temporal_dtypes(nullable=nullable) + | interval_dtype(nullable=nullable) + | geospatial_dtypes(nullable=nullable) + | variadic_dtypes(nullable=nullable) + | struct_dtypes(nullable=nullable) + | array_dtypes(recursive, nullable=nullable) + | map_dtypes(recursive, recursive, nullable=nullable) + | struct_dtypes(recursive, nullable=nullable) + ) ) -) + return recursive # Strategies for generating schema @st.composite -def schema(draw, item_strategy=primitive_dtypes, max_size=20): +def schema(draw, item_strategy=_item_strategy, max_size=20): num_fields = draw(st.integers(min_value=0, max_value=max_size)) names = draw( st.lists(st.text(), min_size=num_fields, max_size=num_fields, unique=True) diff --git a/ibis/tests/test_strategies.py b/ibis/tests/test_strategies.py index 167faa6534e3..40eeaa5c8c64 100644 --- a/ibis/tests/test_strategies.py +++ b/ibis/tests/test_strategies.py @@ -1,6 +1,7 @@ from __future__ import annotations import hypothesis as h +import hypothesis.strategies as st import numpy as np import ibis.expr.datatypes as dt @@ -16,76 +17,94 @@ def test_null_dtype(dtype): assert dtype.nullable is True -@h.given(its.boolean_dtype) +@h.given(its.boolean_dtype()) def test_boolean_dtype(dtype): assert isinstance(dtype, dt.Boolean) assert dtype.is_boolean() is True -@h.given(its.signed_integer_dtypes) +@h.given(its.signed_integer_dtypes()) def test_signed_integer_dtype(dtype): assert isinstance(dtype, dt.SignedInteger) assert dtype.is_integer() is True -@h.given(its.unsigned_integer_dtypes) +@h.given(its.unsigned_integer_dtypes()) def test_unsigned_integer_dtype(dtype): assert isinstance(dtype, dt.UnsignedInteger) assert dtype.is_integer() is True -@h.given(its.floating_dtypes) +@h.given(its.floating_dtypes()) def test_floating_dtype(dtype): assert isinstance(dtype, dt.Floating) assert dtype.is_floating() is True -@h.given(its.numeric_dtypes) +@h.given(its.numeric_dtypes()) def test_numeric_dtype(dtype): assert isinstance(dtype, dt.Numeric) assert dtype.is_numeric() is True -@h.given(its.timestamp_dtype) +@h.given(its.numeric_dtypes(nullable=st.just(True))) +def test_numeric_dtypes_nullable(dtype): + assert dtype.nullable is True + assert dtype.is_numeric() is True + + +@h.given(its.numeric_dtypes(nullable=st.just(False))) +def test_numeric_dtypes_non_nullable(dtype): + assert dtype.nullable is False + assert dtype.is_numeric() is True + + +@h.given(its.timestamp_dtype()) def test_timestamp_dtype(dtype): assert isinstance(dtype, dt.Timestamp) assert isinstance(dtype.timezone, (type(None), str)) assert dtype.is_timestamp() is True -@h.given(its.interval_dtype) +@h.given(its.interval_dtype()) def test_interval_dtype(dtype): assert isinstance(dtype, dt.Interval) assert dtype.is_interval() is True -@h.given(its.temporal_dtypes) +@h.given(its.temporal_dtypes()) def test_temporal_dtype(dtype): assert isinstance(dtype, dt.Temporal) assert dtype.is_temporal() is True -@h.given(its.primitive_dtypes) +@h.given(its.primitive_dtypes()) def test_primitive_dtype(dtype): assert isinstance(dtype, dt.Primitive) assert dtype.is_primitive() is True -@h.given(its.array_dtypes(its.primitive_dtypes)) +@h.given(its.geospatial_dtypes()) +def test_geospatial_dtype(dtype): + assert isinstance(dtype, dt.GeoSpatial) + assert dtype.is_geospatial() is True + + +@h.given(its.array_dtypes(its.primitive_dtypes())) def test_array_dtype(dtype): assert isinstance(dtype, dt.Array) assert isinstance(dtype.value_type, dt.Primitive) assert dtype.is_array() is True -@h.given(its.array_dtypes(its.array_dtypes(its.primitive_dtypes))) +@h.given(its.array_dtypes(its.array_dtypes(its.primitive_dtypes()))) def test_array_array_dtype(dtype): assert isinstance(dtype, dt.Array) assert isinstance(dtype.value_type, dt.Array) assert isinstance(dtype.value_type.value_type, dt.Primitive) -@h.given(its.map_dtypes(its.primitive_dtypes, its.boolean_dtype)) +@h.given(its.map_dtypes(its.primitive_dtypes(), its.boolean_dtype())) def test_map_dtype(dtype): assert isinstance(dtype, dt.Map) assert isinstance(dtype.key_type, dt.Primitive) @@ -100,20 +119,20 @@ def test_struct_dtype(dtype): assert dtype.is_struct() is True -@h.given(its.struct_dtypes(its.variadic_dtypes)) +@h.given(its.struct_dtypes(its.variadic_dtypes())) def test_struct_variadic_dtype(dtype): assert isinstance(dtype, dt.Struct) assert all(t.is_variadic() for t in dtype.types) assert dtype.is_struct() is True -@h.given(its.variadic_dtypes) +@h.given(its.variadic_dtypes()) def test_variadic_dtype(dtype): assert isinstance(dtype, dt.Variadic) assert dtype.is_variadic() is True -@h.given(its.all_dtypes) +@h.given(its.all_dtypes()) def test_all_dtypes(dtype): assert isinstance(dtype, dt.DataType) @@ -126,14 +145,14 @@ def test_schema(schema): assert len(set(schema.names)) == len(schema.names) -@h.given(its.schema(its.array_dtypes(its.numeric_dtypes))) +@h.given(its.schema(its.array_dtypes(its.numeric_dtypes()))) def test_schema_array_dtype(schema): assert isinstance(schema, sch.Schema) assert all(t.is_array() for t in schema.types) assert all(isinstance(n, str) for n in schema.names) -@h.given(its.primitive_dtypes) +@h.given(its.primitive_dtypes()) def test_primitive_dtypes_to_pandas(dtype): assert isinstance(dtype.to_pandas(), np.dtype) @@ -144,7 +163,7 @@ def test_schema_to_pandas(schema): assert len(pandas_schema) == len(schema) -@h.given(its.memtable(its.schema(its.integer_dtypes, max_size=5))) +@h.given(its.memtable(its.schema(its.integer_dtypes(), max_size=5))) def test_memtable(memtable): assert isinstance(memtable, ir.TableExpr) assert isinstance(memtable.schema(), sch.Schema)