feat(mssql): implement inference for DATETIME2 and DATETIMEOFFSET
cpcloud committed Jan 20, 2023
1 parent 099d1ec commit aa9f151
Showing 4 changed files with 50 additions and 17 deletions.
14 changes: 14 additions & 0 deletions ibis/backends/base/sql/alchemy/datatypes.py
@@ -426,6 +426,20 @@ def sa_datetime(_, satype, nullable=True, default_timezone='UTC'):
    return dt.Timestamp(timezone=timezone, nullable=nullable)


+@dt.dtype.register(MSDialect, mssql.DATETIMEOFFSET)
+def _datetimeoffset(_, sa_type, nullable=True):
+    if (prec := sa_type.precision) is None:
+        prec = 7
+    return dt.Timestamp(scale=prec, timezone="UTC", nullable=nullable)
+
+
+@dt.dtype.register(MSDialect, mssql.DATETIME2)
+def _datetime2(_, sa_type, nullable=True):
+    if (prec := sa_type.precision) is None:
+        prec = 7
+    return dt.Timestamp(scale=prec, nullable=nullable)
+
+
@dt.dtype.register(PGDialect, sa.ARRAY)
def sa_pg_array(dialect, satype, nullable=True):
    dimensions = satype.dimensions
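These two registrations give SQLAlchemy-reflected MSSQL columns a `Timestamp` whose scale matches the column's fractional-second precision, falling back to SQL Server's default precision of 7 (100 ns ticks) when none is declared. A minimal sketch of how the converters are exercised, assuming an ibis checkout at this commit with the MSSQL extras installed:

```python
import ibis.expr.datatypes as dt
from sqlalchemy.dialects import mssql
from sqlalchemy.dialects.mssql.base import MSDialect

dialect = MSDialect()

# An explicit precision is carried through as the timestamp scale ...
assert dt.dtype(dialect, mssql.DATETIME2(precision=3)) == dt.Timestamp(scale=3)

# ... and an omitted precision falls back to SQL Server's default of 7.
assert dt.dtype(dialect, mssql.DATETIMEOFFSET()) == dt.Timestamp(
    scale=7, timezone="UTC"
)
```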
14 changes: 14 additions & 0 deletions ibis/backends/mssql/datatypes.py
@@ -29,6 +29,10 @@ def _type_from_result_set_info(col: _FieldDescription) -> dt.DataType:
        typ = partial(typ, precision=col["precision"], scale=col['scale'])
    elif typename in ("GEOMETRY", "GEOGRAPHY"):
        typ = partial(typ, geotype=typename.lower())
+    elif typename == 'DATETIME2':
+        typ = partial(typ, scale=col["scale"])
+    elif typename == 'DATETIMEOFFSET':
+        typ = partial(typ, scale=col["scale"], timezone="UTC")
    elif typename == 'FLOAT':
        if col['precision'] <= 24:
            typ = dt.Float32
@@ -98,3 +102,13 @@ def _type_from_result_set_info(col: _FieldDescription) -> dt.DataType:
@to_sqla_type.register(mssql.dialect, tuple(_MSSQL_TYPE_MAP.keys()))
def _simple_types(_, itype):
    return _MSSQL_TYPE_MAP[type(itype)]
+
+
+@to_sqla_type.register(mssql.dialect, dt.Timestamp)
+def _datetime(_, itype):
+    if (precision := itype.scale) is None:
+        precision = 7
+    if itype.timezone is not None:
+        return mssql.DATETIMEOFFSET(precision=precision)
+    else:
+        return mssql.DATETIME2(precision=precision)
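Together the two hunks cover both directions: `_type_from_result_set_info` keeps the scale (plus a UTC timezone for `DATETIMEOFFSET`) when reading column metadata from a result set, and the new `_datetime` registration maps ibis timestamps back to the appropriate SQLAlchemy type. A rough sketch of the ibis-to-SQLAlchemy direction; the import path and the dialect-instance argument are assumptions read off the diff, not a stable API:

```python
import ibis.expr.datatypes as dt
from sqlalchemy.dialects import mssql

from ibis.backends.base.sql.alchemy import to_sqla_type  # assumed import path

# A timezone-naive timestamp becomes DATETIME2 with a matching precision ...
sa_type = to_sqla_type(mssql.dialect(), dt.Timestamp(scale=3))
assert isinstance(sa_type, mssql.DATETIME2) and sa_type.precision == 3

# ... a timezone-aware one becomes DATETIMEOFFSET, and a missing scale
# falls back to precision 7, mirroring SQL Server's own default.
sa_type = to_sqla_type(mssql.dialect(), dt.Timestamp(timezone="UTC"))
assert isinstance(sa_type, mssql.DATETIMEOFFSET) and sa_type.precision == 7
```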
20 changes: 17 additions & 3 deletions ibis/backends/mssql/tests/test_client.py
@@ -28,8 +28,8 @@
    # Date and time
    ('DATE', dt.date),
    ('TIME', dt.time),
-    ('DATETIME2', dt.timestamp),
-    ('DATETIMEOFFSET', dt.timestamp),
+    ('DATETIME2', dt.timestamp(scale=7)),
+    ('DATETIMEOFFSET', dt.timestamp(scale=7, timezone="UTC")),
    ('SMALLDATETIME', dt.timestamp),
    ('DATETIME', dt.timestamp),
    # Characters strings
@@ -54,13 +54,27 @@
    not geospatial_supported, reason="geospatial dependencies not installed"
)

+broken_sqlalchemy_autoload = pytest.mark.xfail(
+    reason="scale not inferred by sqlalchemy autoload"
+)


@pytest.mark.parametrize(
    ("server_type", "expected_type"),
    DB_TYPES
    + [
        param("GEOMETRY", dt.geometry, marks=[skipif_no_geospatial_deps]),
        param("GEOGRAPHY", dt.geography, marks=[skipif_no_geospatial_deps]),
    ]
+    + [
+        param(
+            'DATETIME2(4)', dt.timestamp(scale=4), marks=[broken_sqlalchemy_autoload]
+        ),
+        param(
+            'DATETIMEOFFSET(5)',
+            dt.timestamp(scale=5, timezone="UTC"),
+            marks=[broken_sqlalchemy_autoload],
+        ),
+    ],
    ids=str,
)
@@ -73,9 +87,9 @@ def test_get_schema_from_query(con, server_type, expected_type):
c.execute(sa.text(f"CREATE TABLE {name} (x {server_type})"))
expected_schema = ibis.schema(dict(x=expected_type))
result_schema = con._get_schema_using_query(f"SELECT * FROM {name}")
assert result_schema == expected_schema
t = con.table(raw_name)
assert t.schema() == expected_schema
assert result_schema == expected_schema
finally:
with con.begin() as c:
c.execute(sa.text(f"DROP TABLE IF EXISTS {name}"))
19 changes: 5 additions & 14 deletions ibis/backends/tests/test_temporal.py
@@ -463,7 +463,7 @@ def test_temporal_binop_pandas_timedelta(


@pytest.mark.parametrize("func_name", ["gt", "ge", "lt", "le", "eq", "ne"])
@pytest.mark.notimpl(["bigquery", "mssql"])
@pytest.mark.notimpl(["bigquery"])
def test_timestamp_comparison_filter(backend, con, alltypes, df, func_name):
    ts = pd.Timestamp('20100302', tz="UTC").to_pydatetime()

@@ -490,7 +490,7 @@ def test_timestamp_comparison_filter(backend, con, alltypes, df, func_name):
"ne",
],
)
@pytest.mark.notimpl(["bigquery", "mssql"])
@pytest.mark.notimpl(["bigquery"])
def test_timestamp_comparison_filter_numpy(backend, con, alltypes, df, func_name):
    ts = np.datetime64('2010-03-02 00:00:00.000123')

@@ -993,21 +993,12 @@ def test_large_timestamp(con):
id="ns",
marks=[
pytest.mark.broken(
[
"clickhouse",
"duckdb",
"impala",
"mssql",
"postgres",
"pyspark",
"sqlite",
"trino",
],
["clickhouse", "duckdb", "impala", "pyspark", "trino"],
reason="drivers appear to truncate nanos",
),
pytest.mark.notyet(
["bigquery"],
reason="bigquery doesn't support nanosecond timestamps",
["bigquery", "mssql", "postgres", "sqlite"],
reason="doesn't support nanoseconds",
),
],
),
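The marker reshuffle reflects a capability limit rather than a driver bug: `DATETIME2` and `DATETIMEOFFSET` top out at precision 7 (100 ns ticks), so scale-9 timestamps cannot round-trip through MSSQL at all, which moves it from `broken` to `notyet` alongside bigquery, postgres, and sqlite. A small pandas illustration of the loss:

```python
import pandas as pd

ts = pd.Timestamp("2023-01-20 00:00:00.123456789")  # nanosecond precision

# The finest value DATETIME2(7) can represent is a multiple of 100 ns,
# so the trailing two digits are necessarily dropped.
assert ts.floor("100ns") == pd.Timestamp("2023-01-20 00:00:00.123456700")
assert ts.floor("100ns") != ts
```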
