Skip to content

Commit

Permalink
refactor(mysql): use describe temporary table to retrieve ibis schema…
Browse files Browse the repository at this point in the history
… from query
  • Loading branch information
kszucs authored and cpcloud committed Sep 4, 2023
1 parent f5490c3 commit a723637
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 190 deletions.
7 changes: 3 additions & 4 deletions ibis/backends/base/sql/alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,6 @@ def _handle_failed_column_type_inference(
self, table: sa.Table, nulltype_cols: Iterable[str]
) -> sa.Table:
"""Handle cases where SQLAlchemy cannot infer the column types of `table`."""

self.inspector.reflect_table(table, table.columns)

dialect = self.con.dialect
Expand All @@ -565,15 +564,15 @@ def _handle_failed_column_type_inference(
)
)

for colname, type in self._metadata(quoted_name):
for colname, dtype in self._metadata(quoted_name):
if colname in nulltype_cols:
# replace null types discovered by sqlalchemy with non null
# types
table.append_column(
sa.Column(
colname,
self.compiler.translator_class.get_sqla_type(type),
nullable=type.nullable,
self.compiler.translator_class.get_sqla_type(dtype),
nullable=dtype.nullable,
quote=self.compiler.translator_class._quote_column_names,
),
replace_existing=True,
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/base/sql/alchemy/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,6 @@ def to_ibis(cls, typ: sat.TypeEngine, nullable: bool = True) -> dt.DataType:
-------
Ibis type.
"""

if dtype := _from_sqlalchemy_types.get(type(typ)):
return dtype(nullable=nullable)
elif isinstance(typ, sat.Float):
Expand Down
35 changes: 34 additions & 1 deletion ibis/backends/base/sql/glot/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
typecode.BIGDECIMAL: partial(dt.Decimal, 76, 38),
typecode.BIGINT: dt.Int64,
typecode.BINARY: dt.Binary,
typecode.BIT: dt.String,
# typecode.BIT: dt.String,
typecode.BOOLEAN: dt.Boolean,
typecode.CHAR: dt.String,
typecode.DATE: dt.Date,
Expand All @@ -42,6 +42,7 @@
typecode.MEDIUMTEXT: dt.String,
typecode.MONEY: dt.Int64,
typecode.NCHAR: dt.String,
typecode.UUID: dt.UUID,
typecode.NULL: dt.Null,
typecode.NVARCHAR: dt.String,
typecode.OBJECT: partial(dt.Map, dt.string, dt.json),
Expand All @@ -60,6 +61,7 @@
typecode.VARCHAR: dt.String,
typecode.VARIANT: dt.JSON,
typecode.UNIQUEIDENTIFIER: dt.UUID,
typecode.SET: partial(dt.Array, dt.string),
#############################
# Unsupported sqlglot types #
#############################
Expand Down Expand Up @@ -305,6 +307,37 @@ class PostgresType(SqlglotType):
)


class MySQLType(SqlglotType):
dialect = "mysql"

unknown_type_strings = FrozenDict(
{
"year(4)": dt.int8,
"inet6": dt.inet,
}
)

@classmethod
def _from_sqlglot_BIT(cls, nbits: sge.DataTypeParam) -> dt.Integer:
nbits = int(nbits.this.this)
if nbits > 32:
return dt.Int64(nullable=cls.default_nullable)
elif nbits > 16:
return dt.Int32(nullable=cls.default_nullable)
elif nbits > 8:
return dt.Int16(nullable=cls.default_nullable)
else:
return dt.Int8(nullable=cls.default_nullable)

@classmethod
def _from_sqlglot_DATETIME(cls) -> dt.Timestamp:
return dt.Timestamp(nullable=cls.default_nullable)

@classmethod
def _from_sqlglot_TIMESTAMP(cls) -> dt.Timestamp:
return dt.Timestamp(timezone="UTC", nullable=cls.default_nullable)


class DuckDBType(SqlglotType):
dialect = "duckdb"
default_decimal_precision = 18
Expand Down
41 changes: 27 additions & 14 deletions ibis/backends/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@

from __future__ import annotations

import re
import warnings
from typing import TYPE_CHECKING, Literal

import sqlalchemy as sa
from sqlalchemy.dialects import mysql

import ibis.expr.schema as sch
from ibis import util
from ibis.backends.base import CanCreateDatabase
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
from ibis.backends.mysql.compiler import MySQLCompiler
from ibis.backends.mysql.datatypes import MySQLDateTime, _type_from_cursor_info
from ibis.backends.mysql.datatypes import MySQLDateTime, MySQLType

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -146,20 +147,32 @@ def list_databases(self, like: str | None = None) -> list[str]:
databases = self.inspector.get_schema_names()
return self._filter_with_like(databases, like)

def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
if (
re.search(r"^\s*SELECT\s", query, flags=re.MULTILINE | re.IGNORECASE)
is not None
):
query = f"({query})"
def _metadata(self, table: str) -> Iterable[tuple[str, dt.DataType]]:
with self.begin() as con:
result = con.exec_driver_sql(f"DESCRIBE {table}").mappings().all()

for field in result:
name = field["Field"]
type_string = field["Type"]
is_nullable = field["Null"] == "YES"
yield name, MySQLType.from_string(type_string, nullable=is_nullable)

def _get_schema_using_query(self, query: str):
table = f"__ibis_mysql_metadata_{util.guid()}"

with self.begin() as con:
result = con.exec_driver_sql(f"SELECT * FROM {query} _ LIMIT 0")
cursor = result.cursor
yield from (
(field.name, _type_from_cursor_info(descr, field))
for descr, field in zip(cursor.description, cursor._result.fields)
)
con.exec_driver_sql(f"CREATE TEMPORARY TABLE {table} AS {query}")
result = con.exec_driver_sql(f"DESCRIBE {table}").mappings().all()
con.exec_driver_sql(f"DROP TABLE {table}")

fields = {}
for field in result:
name = field["Field"]
type_string = field["Type"]
is_nullable = field["Null"] == "YES"
fields[name] = MySQLType.from_string(type_string, nullable=is_nullable)

return sch.Schema(fields)

def _get_temp_view_definition(
self, name: str, definition: sa.sql.compiler.Compiled
Expand Down
167 changes: 7 additions & 160 deletions ibis/backends/mysql/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,168 +1,11 @@
from __future__ import annotations

from functools import partial

import sqlalchemy.types as sat
from sqlalchemy.dialects import mysql

import ibis.expr.datatypes as dt
from ibis.backends.base.sql.alchemy.datatypes import UUID, AlchemyType

# binary character set
# used to distinguish blob binary vs blob text
MY_CHARSET_BIN = 63


def _type_from_cursor_info(descr, field) -> dt.DataType:
"""Construct an ibis type from MySQL field descr and field result metadata.
This method is complex because the MySQL protocol is complex.
Types are not encoded in a self contained way, meaning you need
multiple pieces of information coming from the result set metadata to
determine the most precise type for a field. Even then, the decoding is
not high fidelity in some cases: UUIDs for example are decoded as
strings, because the protocol does not appear to preserve the logical
type, only the physical type.
"""
from pymysql.connections import TEXT_TYPES

_, type_code, _, _, field_length, scale, _ = descr
flags = _FieldFlags(field.flags)
typename = _type_codes.get(type_code)
if typename is None:
raise NotImplementedError(f"MySQL type code {type_code:d} is not supported")

if typename in ("DECIMAL", "NEWDECIMAL"):
precision = _decimal_length_to_precision(
length=field_length,
scale=scale,
is_unsigned=flags.is_unsigned,
)
typ = partial(_type_mapping[typename], precision=precision, scale=scale)
elif typename == "BIT":
if field_length <= 8:
typ = dt.int8
elif field_length <= 16:
typ = dt.int16
elif field_length <= 32:
typ = dt.int32
elif field_length <= 64:
typ = dt.int64
else:
raise AssertionError("invalid field length for BIT type")
elif flags.is_set:
# sets are limited to strings
typ = dt.Array(dt.string)
elif flags.is_unsigned and flags.is_num:
typ = getattr(dt, f"U{typ.__name__}")
elif type_code in TEXT_TYPES:
# binary text
if field.charsetnr == MY_CHARSET_BIN:
typ = dt.Binary
else:
typ = dt.String
else:
typ = _type_mapping[typename]

# projection columns are always nullable
return typ(nullable=True)


# ported from my_decimal.h:my_decimal_length_to_precision in mariadb
def _decimal_length_to_precision(*, length: int, scale: int, is_unsigned: bool) -> int:
return length - (scale > 0) - (not (is_unsigned or not length))


_type_codes = {
0: "DECIMAL",
1: "TINY",
2: "SHORT",
3: "LONG",
4: "FLOAT",
5: "DOUBLE",
6: "NULL",
7: "TIMESTAMP",
8: "LONGLONG",
9: "INT24",
10: "DATE",
11: "TIME",
12: "DATETIME",
13: "YEAR",
15: "VARCHAR",
16: "BIT",
245: "JSON",
246: "NEWDECIMAL",
247: "ENUM",
248: "SET",
249: "TINY_BLOB",
250: "MEDIUM_BLOB",
251: "LONG_BLOB",
252: "BLOB",
253: "VAR_STRING",
254: "STRING",
255: "GEOMETRY",
}


_type_mapping = {
"DECIMAL": dt.Decimal,
"TINY": dt.Int8,
"SHORT": dt.Int16,
"LONG": dt.Int32,
"FLOAT": dt.Float32,
"DOUBLE": dt.Float64,
"NULL": dt.Null,
"TIMESTAMP": lambda nullable: dt.Timestamp(timezone="UTC", nullable=nullable),
"LONGLONG": dt.Int64,
"INT24": dt.Int32,
"DATE": dt.Date,
"TIME": dt.Time,
"DATETIME": dt.Timestamp,
"YEAR": dt.Int8,
"VARCHAR": dt.String,
"JSON": dt.JSON,
"NEWDECIMAL": dt.Decimal,
"ENUM": dt.String,
"SET": lambda nullable: dt.Array(dt.string, nullable=nullable),
"TINY_BLOB": dt.Binary,
"MEDIUM_BLOB": dt.Binary,
"LONG_BLOB": dt.Binary,
"BLOB": dt.Binary,
"VAR_STRING": dt.String,
"STRING": dt.String,
"GEOMETRY": dt.Geometry,
}


class _FieldFlags:
"""Flags used to disambiguate field types.
Gaps in the flag numbers are because we do not map in flags that are
of no use in determining the field's type, such as whether the field
is a primary key or not.
"""

UNSIGNED = 1 << 5
SET = 1 << 11
NUM = 1 << 15

__slots__ = ("value",)

def __init__(self, value: int) -> None:
self.value = value

@property
def is_unsigned(self) -> bool:
return (self.UNSIGNED & self.value) != 0

@property
def is_set(self) -> bool:
return (self.SET & self.value) != 0

@property
def is_num(self) -> bool:
return (self.NUM & self.value) != 0
from ibis.backends.base.sql.glot.datatypes import MySQLType as SqlglotMySQLType


class MySQLDateTime(mysql.DATETIME):
Expand Down Expand Up @@ -214,7 +57,7 @@ def result_processor(self, *_):
mysql.TIME: dt.Time,
mysql.YEAR: dt.Int8,
MySQLDateTime: dt.Timestamp,
UUID: dt.String,
UUID: dt.UUID,
}


Expand Down Expand Up @@ -247,8 +90,12 @@ def to_ibis(cls, typ, nullable=True):
elif isinstance(typ, mysql.TIMESTAMP):
return dt.Timestamp(timezone="UTC", nullable=nullable)
elif isinstance(typ, mysql.SET):
return dt.Set(dt.string, nullable=nullable)
return dt.Array(dt.string, nullable=nullable)
elif dtype := _from_mysql_types.get(type(typ)):
return dtype(nullable=nullable)
else:
return super().to_ibis(typ, nullable=nullable)

@classmethod
def from_string(cls, type_string, nullable=True):
return SqlglotMySQLType.from_string(type_string, nullable=nullable)
Loading

0 comments on commit a723637

Please sign in to comment.