Skip to content

Commit

Permalink
feat(mysql): implement _get_schema_from_query
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Mar 30, 2022
1 parent 1c8d484 commit 456cd44
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 45 deletions.
12 changes: 12 additions & 0 deletions ibis/backends/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
import sqlalchemy.dialects.mysql as mysql

import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend

from .compiler import MySQLCompiler
from .datatypes import _type_from_cursor_info


class Backend(BaseAlchemyBackend):
Expand Down Expand Up @@ -121,6 +123,16 @@ def begin(self):
query = "SET @@session.time_zone = '{}'"
bind.execute(query.format(previous_timezone))

def _get_schema_using_query(self, query: str) -> sch.Schema:
"""Infer the schema of `query`."""
result = self.con.execute(f"SELECT * FROM ({query}) _ LIMIT 0")
cursor = result.cursor
fields = [
(field.name, _type_from_cursor_info(descr, field))
for descr, field in zip(cursor.description, cursor._result.fields)
]
return sch.Schema.from_tuples(fields)


# TODO(kszucs): unsigned integers

Expand Down
173 changes: 173 additions & 0 deletions ibis/backends/mysql/datatypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
from __future__ import annotations

from functools import partial

import ibis.expr.datatypes as dt

# binary character set
# used to distinguish blob binary vs blob text
MY_CHARSET_BIN = 63


def _type_from_cursor_info(descr, field) -> dt.DataType:
"""Construct an ibis type from MySQL field descr and field result metadata.
This method is complex because the MySQL protocol is complex.
Types are not encoded in a self contained way, meaning you need
multiple pieces of information coming from the result set metadata to
determine the most precise type for a field. Even then, the decoding is
not high fidelity in some cases: UUIDs for example are decoded as
strings, because the protocol does not appear to preserve the logical
type, only the physical type.
"""
from pymysql.connections import TEXT_TYPES

_, type_code, _, _, field_length, scale, _ = descr
flags = _FieldFlags(field.flags)
typename = _type_codes.get(type_code)
if typename is None:
raise NotImplementedError(
f"MySQL type code {type_code:d} is not supported"
)

typ = _type_mapping[typename]

if typename in ("DECIMAL", "NEWDECIMAL"):
precision = _decimal_length_to_precision(
length=field_length,
scale=scale,
is_unsigned=flags.is_unsigned,
)
typ = partial(typ, precision=precision, scale=scale)
elif typename == "BIT":
if field_length <= 8:
typ = dt.int8
elif field_length <= 16:
typ = dt.int16
elif field_length <= 32:
typ = dt.int32
elif field_length <= 64:
typ = dt.int64
else:
assert False, "invalid field length for BIT type"
else:
if flags.is_set:
# sets are limited to strings
typ = dt.Set(dt.string)
elif flags.is_unsigned and flags.is_num:
typ = getattr(dt, f"U{typ.__name__}")
elif type_code in TEXT_TYPES:
# binary text
if field.charsetnr == MY_CHARSET_BIN:
typ = dt.Binary
else:
typ = dt.String

# projection columns are always nullable
return typ(nullable=True)


# ported from my_decimal.h:my_decimal_length_to_precision in mariadb
def _decimal_length_to_precision(
*,
length: int,
scale: int,
is_unsigned: bool,
) -> int:
return length - (scale > 0) - (not (is_unsigned or not length))


_type_codes = {
0: "DECIMAL",
1: "TINY",
2: "SHORT",
3: "LONG",
4: "FLOAT",
5: "DOUBLE",
6: "NULL",
7: "TIMESTAMP",
8: "LONGLONG",
9: "INT24",
10: "DATE",
11: "TIME",
12: "DATETIME",
13: "YEAR",
15: "VARCHAR",
16: "BIT",
245: "JSON",
246: "NEWDECIMAL",
247: "ENUM",
248: "SET",
249: "TINY_BLOB",
250: "MEDIUM_BLOB",
251: "LONG_BLOB",
252: "BLOB",
253: "VAR_STRING",
254: "STRING",
255: "GEOMETRY",
}


_type_mapping = {
"DECIMAL": dt.Decimal,
"TINY": dt.Int8,
"SHORT": dt.Int16,
"LONG": dt.Int32,
"FLOAT": dt.Float32,
"DOUBLE": dt.Float64,
"NULL": dt.Null,
"TIMESTAMP": lambda nullable: dt.Timestamp(
timezone="UTC",
nullable=nullable,
),
"LONGLONG": dt.Int64,
"INT24": dt.Int32,
"DATE": dt.Date,
"TIME": dt.Time,
"DATETIME": dt.Timestamp,
"YEAR": dt.Int16,
"VARCHAR": dt.String,
"BIT": dt.Int8,
"JSON": dt.JSON,
"NEWDECIMAL": dt.Decimal,
"ENUM": dt.String,
"SET": lambda nullable: dt.Set(dt.string, nullable=nullable),
"TINY_BLOB": dt.Binary,
"MEDIUM_BLOB": dt.Binary,
"LONG_BLOB": dt.Binary,
"BLOB": dt.Binary,
"VAR_STRING": dt.String,
"STRING": dt.String,
"GEOMETRY": dt.Geometry,
}


class _FieldFlags:
"""Flags used to disambiguate field types.
Gaps in the flag numbers are because we do not map in flags that are of no
use in determining the field's type, such as whether the field is a primary
key or not.
"""

UNSIGNED = 1 << 5
SET = 1 << 11
NUM = 1 << 15

__slots__ = ("value",)

def __init__(self, value: int) -> None:
self.value = value

@property
def is_unsigned(self) -> bool:
return (self.UNSIGNED & self.value) != 0

@property
def is_set(self) -> bool:
return (self.SET & self.value) != 0

@property
def is_num(self) -> bool:
return (self.NUM & self.value) != 0
94 changes: 49 additions & 45 deletions ibis/backends/mysql/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,60 +4,64 @@
import ibis
import ibis.expr.datatypes as dt

MYSQL_TYPES = [
("tinyint", dt.int8),
("int1", dt.int8),
("boolean", dt.int8),
("smallint", dt.int16),
("int2", dt.int16),
("mediumint", dt.int32),
("int3", dt.int32),
("int", dt.int32),
("int4", dt.int32),
("integer", dt.int32),
("bigint", dt.int64),
("decimal", dt.Decimal(10, 0)),
("decimal(5, 2)", dt.Decimal(5, 2)),
("dec", dt.Decimal(10, 0)),
("numeric", dt.Decimal(10, 0)),
("fixed", dt.Decimal(10, 0)),
("float", dt.float32),
("double", dt.float64),
("timestamp", dt.Timestamp("UTC")),
("date", dt.date),
("time", dt.time),
("datetime", dt.timestamp),
("year", dt.int16),
("char(32)", dt.string),
("char byte", dt.binary),
("varchar(42)", dt.string),
("mediumtext", dt.string),
("text", dt.string),
("binary(42)", dt.binary),
("varbinary(42)", dt.binary),
("bit(1)", dt.int8),
("bit(9)", dt.int16),
("bit(17)", dt.int32),
("bit(33)", dt.int64),
# mariadb doesn't have a distinct json type
("json", dt.string),
("enum('small', 'medium', 'large')", dt.string),
("inet6", dt.string),
("set('a', 'b', 'c', 'd')", dt.Set(dt.string)),
("mediumblob", dt.binary),
("blob", dt.binary),
("uuid", dt.string),
]


@pytest.mark.parametrize(
("mysql_type", "expected_type"),
[
param(mysql_type, ibis_type, id=mysql_type.lower())
for mysql_type, ibis_type in [
("tinyint", dt.int8),
("int1", dt.int8),
("boolean", dt.int8),
("smallint", dt.int16),
("int2", dt.int16),
("mediumint", dt.int32),
("int3", dt.int32),
("int", dt.int32),
("int4", dt.int32),
("integer", dt.int32),
("bigint", dt.int64),
("decimal", dt.Decimal(10, 0)),
("decimal(5, 2)", dt.Decimal(5, 2)),
("dec", dt.Decimal(10, 0)),
("numeric", dt.Decimal(10, 0)),
("fixed", dt.Decimal(10, 0)),
("float", dt.float32),
("double", dt.float64),
("timestamp", dt.Timestamp("UTC")),
("date", dt.date),
("time", dt.time),
("datetime", dt.timestamp),
("year", dt.int16),
("char(32)", dt.string),
("char byte", dt.binary),
("varchar(42)", dt.string),
("mediumtext", dt.string),
("text", dt.string),
("binary(42)", dt.binary),
("varbinary(42)", dt.binary),
("bit(1)", dt.int8),
("bit(9)", dt.int16),
("bit(17)", dt.int32),
("bit(33)", dt.int64),
# mariadb doesn't have a distinct json type
("json", dt.string),
("enum('small', 'medium', 'large')", dt.string),
("inet6", dt.string),
("set('a', 'b', 'c', 'd')", dt.Set(dt.string)),
("mediumblob", dt.binary),
("blob", dt.binary),
("uuid", dt.string),
]
param(mysql_type, ibis_type, id=mysql_type)
for mysql_type, ibis_type in MYSQL_TYPES
],
)
def test_get_schema_from_query(con, mysql_type, expected_type):
raw_name = ibis.util.guid()
name = con.con.dialect.identifier_preparer.quote_identifier(raw_name)
# temporary tables get cleaned up by the db when the session ends, so we
# don't need to explicitly drop the table
con.raw_sql(
f"CREATE TEMPORARY TABLE {name} (x {mysql_type}, y {mysql_type})"
)
Expand Down

0 comments on commit 456cd44

Please sign in to comment.