From 456cd44879c32bac7f8a798cb8e7e5851e94b4ec Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 26 Mar 2022 17:05:29 -0400 Subject: [PATCH] feat(mysql): implement _get_schema_from_query --- ibis/backends/mysql/__init__.py | 12 ++ ibis/backends/mysql/datatypes.py | 173 +++++++++++++++++++++++ ibis/backends/mysql/tests/test_client.py | 94 ++++++------ 3 files changed, 234 insertions(+), 45 deletions(-) create mode 100644 ibis/backends/mysql/datatypes.py diff --git a/ibis/backends/mysql/__init__.py b/ibis/backends/mysql/__init__.py index 910b6a75d855..fb47b11b8492 100644 --- a/ibis/backends/mysql/__init__.py +++ b/ibis/backends/mysql/__init__.py @@ -10,9 +10,11 @@ import sqlalchemy.dialects.mysql as mysql import ibis.expr.datatypes as dt +import ibis.expr.schema as sch from ibis.backends.base.sql.alchemy import BaseAlchemyBackend from .compiler import MySQLCompiler +from .datatypes import _type_from_cursor_info class Backend(BaseAlchemyBackend): @@ -121,6 +123,16 @@ def begin(self): query = "SET @@session.time_zone = '{}'" bind.execute(query.format(previous_timezone)) + def _get_schema_using_query(self, query: str) -> sch.Schema: + """Infer the schema of `query`.""" + result = self.con.execute(f"SELECT * FROM ({query}) _ LIMIT 0") + cursor = result.cursor + fields = [ + (field.name, _type_from_cursor_info(descr, field)) + for descr, field in zip(cursor.description, cursor._result.fields) + ] + return sch.Schema.from_tuples(fields) + # TODO(kszucs): unsigned integers diff --git a/ibis/backends/mysql/datatypes.py b/ibis/backends/mysql/datatypes.py new file mode 100644 index 000000000000..367647327638 --- /dev/null +++ b/ibis/backends/mysql/datatypes.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +from functools import partial + +import ibis.expr.datatypes as dt + +# binary character set +# used to distinguish blob binary vs blob text +MY_CHARSET_BIN = 63 + + +def _type_from_cursor_info(descr, field) -> dt.DataType: + """Construct an ibis type from MySQL field descr and field result metadata. + + This method is complex because the MySQL protocol is complex. + + Types are not encoded in a self contained way, meaning you need + multiple pieces of information coming from the result set metadata to + determine the most precise type for a field. Even then, the decoding is + not high fidelity in some cases: UUIDs for example are decoded as + strings, because the protocol does not appear to preserve the logical + type, only the physical type. + """ + from pymysql.connections import TEXT_TYPES + + _, type_code, _, _, field_length, scale, _ = descr + flags = _FieldFlags(field.flags) + typename = _type_codes.get(type_code) + if typename is None: + raise NotImplementedError( + f"MySQL type code {type_code:d} is not supported" + ) + + typ = _type_mapping[typename] + + if typename in ("DECIMAL", "NEWDECIMAL"): + precision = _decimal_length_to_precision( + length=field_length, + scale=scale, + is_unsigned=flags.is_unsigned, + ) + typ = partial(typ, precision=precision, scale=scale) + elif typename == "BIT": + if field_length <= 8: + typ = dt.int8 + elif field_length <= 16: + typ = dt.int16 + elif field_length <= 32: + typ = dt.int32 + elif field_length <= 64: + typ = dt.int64 + else: + assert False, "invalid field length for BIT type" + else: + if flags.is_set: + # sets are limited to strings + typ = dt.Set(dt.string) + elif flags.is_unsigned and flags.is_num: + typ = getattr(dt, f"U{typ.__name__}") + elif type_code in TEXT_TYPES: + # binary text + if field.charsetnr == MY_CHARSET_BIN: + typ = dt.Binary + else: + typ = dt.String + + # projection columns are always nullable + return typ(nullable=True) + + +# ported from my_decimal.h:my_decimal_length_to_precision in mariadb +def _decimal_length_to_precision( + *, + length: int, + scale: int, + is_unsigned: bool, +) -> int: + return length - (scale > 0) - (not (is_unsigned or not length)) + + +_type_codes = { + 0: "DECIMAL", + 1: "TINY", + 2: "SHORT", + 3: "LONG", + 4: "FLOAT", + 5: "DOUBLE", + 6: "NULL", + 7: "TIMESTAMP", + 8: "LONGLONG", + 9: "INT24", + 10: "DATE", + 11: "TIME", + 12: "DATETIME", + 13: "YEAR", + 15: "VARCHAR", + 16: "BIT", + 245: "JSON", + 246: "NEWDECIMAL", + 247: "ENUM", + 248: "SET", + 249: "TINY_BLOB", + 250: "MEDIUM_BLOB", + 251: "LONG_BLOB", + 252: "BLOB", + 253: "VAR_STRING", + 254: "STRING", + 255: "GEOMETRY", +} + + +_type_mapping = { + "DECIMAL": dt.Decimal, + "TINY": dt.Int8, + "SHORT": dt.Int16, + "LONG": dt.Int32, + "FLOAT": dt.Float32, + "DOUBLE": dt.Float64, + "NULL": dt.Null, + "TIMESTAMP": lambda nullable: dt.Timestamp( + timezone="UTC", + nullable=nullable, + ), + "LONGLONG": dt.Int64, + "INT24": dt.Int32, + "DATE": dt.Date, + "TIME": dt.Time, + "DATETIME": dt.Timestamp, + "YEAR": dt.Int16, + "VARCHAR": dt.String, + "BIT": dt.Int8, + "JSON": dt.JSON, + "NEWDECIMAL": dt.Decimal, + "ENUM": dt.String, + "SET": lambda nullable: dt.Set(dt.string, nullable=nullable), + "TINY_BLOB": dt.Binary, + "MEDIUM_BLOB": dt.Binary, + "LONG_BLOB": dt.Binary, + "BLOB": dt.Binary, + "VAR_STRING": dt.String, + "STRING": dt.String, + "GEOMETRY": dt.Geometry, +} + + +class _FieldFlags: + """Flags used to disambiguate field types. + + Gaps in the flag numbers are because we do not map in flags that are of no + use in determining the field's type, such as whether the field is a primary + key or not. + """ + + UNSIGNED = 1 << 5 + SET = 1 << 11 + NUM = 1 << 15 + + __slots__ = ("value",) + + def __init__(self, value: int) -> None: + self.value = value + + @property + def is_unsigned(self) -> bool: + return (self.UNSIGNED & self.value) != 0 + + @property + def is_set(self) -> bool: + return (self.SET & self.value) != 0 + + @property + def is_num(self) -> bool: + return (self.NUM & self.value) != 0 diff --git a/ibis/backends/mysql/tests/test_client.py b/ibis/backends/mysql/tests/test_client.py index 703a3cf28c4e..c6bcd357e319 100644 --- a/ibis/backends/mysql/tests/test_client.py +++ b/ibis/backends/mysql/tests/test_client.py @@ -4,60 +4,64 @@ import ibis import ibis.expr.datatypes as dt +MYSQL_TYPES = [ + ("tinyint", dt.int8), + ("int1", dt.int8), + ("boolean", dt.int8), + ("smallint", dt.int16), + ("int2", dt.int16), + ("mediumint", dt.int32), + ("int3", dt.int32), + ("int", dt.int32), + ("int4", dt.int32), + ("integer", dt.int32), + ("bigint", dt.int64), + ("decimal", dt.Decimal(10, 0)), + ("decimal(5, 2)", dt.Decimal(5, 2)), + ("dec", dt.Decimal(10, 0)), + ("numeric", dt.Decimal(10, 0)), + ("fixed", dt.Decimal(10, 0)), + ("float", dt.float32), + ("double", dt.float64), + ("timestamp", dt.Timestamp("UTC")), + ("date", dt.date), + ("time", dt.time), + ("datetime", dt.timestamp), + ("year", dt.int16), + ("char(32)", dt.string), + ("char byte", dt.binary), + ("varchar(42)", dt.string), + ("mediumtext", dt.string), + ("text", dt.string), + ("binary(42)", dt.binary), + ("varbinary(42)", dt.binary), + ("bit(1)", dt.int8), + ("bit(9)", dt.int16), + ("bit(17)", dt.int32), + ("bit(33)", dt.int64), + # mariadb doesn't have a distinct json type + ("json", dt.string), + ("enum('small', 'medium', 'large')", dt.string), + ("inet6", dt.string), + ("set('a', 'b', 'c', 'd')", dt.Set(dt.string)), + ("mediumblob", dt.binary), + ("blob", dt.binary), + ("uuid", dt.string), +] + @pytest.mark.parametrize( ("mysql_type", "expected_type"), [ - param(mysql_type, ibis_type, id=mysql_type.lower()) - for mysql_type, ibis_type in [ - ("tinyint", dt.int8), - ("int1", dt.int8), - ("boolean", dt.int8), - ("smallint", dt.int16), - ("int2", dt.int16), - ("mediumint", dt.int32), - ("int3", dt.int32), - ("int", dt.int32), - ("int4", dt.int32), - ("integer", dt.int32), - ("bigint", dt.int64), - ("decimal", dt.Decimal(10, 0)), - ("decimal(5, 2)", dt.Decimal(5, 2)), - ("dec", dt.Decimal(10, 0)), - ("numeric", dt.Decimal(10, 0)), - ("fixed", dt.Decimal(10, 0)), - ("float", dt.float32), - ("double", dt.float64), - ("timestamp", dt.Timestamp("UTC")), - ("date", dt.date), - ("time", dt.time), - ("datetime", dt.timestamp), - ("year", dt.int16), - ("char(32)", dt.string), - ("char byte", dt.binary), - ("varchar(42)", dt.string), - ("mediumtext", dt.string), - ("text", dt.string), - ("binary(42)", dt.binary), - ("varbinary(42)", dt.binary), - ("bit(1)", dt.int8), - ("bit(9)", dt.int16), - ("bit(17)", dt.int32), - ("bit(33)", dt.int64), - # mariadb doesn't have a distinct json type - ("json", dt.string), - ("enum('small', 'medium', 'large')", dt.string), - ("inet6", dt.string), - ("set('a', 'b', 'c', 'd')", dt.Set(dt.string)), - ("mediumblob", dt.binary), - ("blob", dt.binary), - ("uuid", dt.string), - ] + param(mysql_type, ibis_type, id=mysql_type) + for mysql_type, ibis_type in MYSQL_TYPES ], ) def test_get_schema_from_query(con, mysql_type, expected_type): raw_name = ibis.util.guid() name = con.con.dialect.identifier_preparer.quote_identifier(raw_name) + # temporary tables get cleaned up by the db when the session ends, so we + # don't need to explicitly drop the table con.raw_sql( f"CREATE TEMPORARY TABLE {name} (x {mysql_type}, y {mysql_type})" )