diff --git a/ibis/backends/base/sql/alchemy/__init__.py b/ibis/backends/base/sql/alchemy/__init__.py index 7d17f7ef799d..a6dec302f7c3 100644 --- a/ibis/backends/base/sql/alchemy/__init__.py +++ b/ibis/backends/base/sql/alchemy/__init__.py @@ -553,7 +553,6 @@ def _handle_failed_column_type_inference( self, table: sa.Table, nulltype_cols: Iterable[str] ) -> sa.Table: """Handle cases where SQLAlchemy cannot infer the column types of `table`.""" - self.inspector.reflect_table(table, table.columns) dialect = self.con.dialect @@ -565,15 +564,15 @@ def _handle_failed_column_type_inference( ) ) - for colname, type in self._metadata(quoted_name): + for colname, dtype in self._metadata(quoted_name): if colname in nulltype_cols: # replace null types discovered by sqlalchemy with non null # types table.append_column( sa.Column( colname, - self.compiler.translator_class.get_sqla_type(type), - nullable=type.nullable, + self.compiler.translator_class.get_sqla_type(dtype), + nullable=dtype.nullable, quote=self.compiler.translator_class._quote_column_names, ), replace_existing=True, diff --git a/ibis/backends/base/sql/alchemy/datatypes.py b/ibis/backends/base/sql/alchemy/datatypes.py index b80599b1e43f..e866f48ea695 100644 --- a/ibis/backends/base/sql/alchemy/datatypes.py +++ b/ibis/backends/base/sql/alchemy/datatypes.py @@ -284,7 +284,6 @@ def to_ibis(cls, typ: sat.TypeEngine, nullable: bool = True) -> dt.DataType: ------- Ibis type. 
""" - if dtype := _from_sqlalchemy_types.get(type(typ)): return dtype(nullable=nullable) elif isinstance(typ, sat.Float): diff --git a/ibis/backends/base/sql/glot/datatypes.py b/ibis/backends/base/sql/glot/datatypes.py index 4dc5e9bf65fd..a9393b7fe9d3 100644 --- a/ibis/backends/base/sql/glot/datatypes.py +++ b/ibis/backends/base/sql/glot/datatypes.py @@ -16,7 +16,7 @@ typecode.BIGDECIMAL: partial(dt.Decimal, 76, 38), typecode.BIGINT: dt.Int64, typecode.BINARY: dt.Binary, - typecode.BIT: dt.String, + # typecode.BIT: dt.String, typecode.BOOLEAN: dt.Boolean, typecode.CHAR: dt.String, typecode.DATE: dt.Date, @@ -42,6 +42,7 @@ typecode.MEDIUMTEXT: dt.String, typecode.MONEY: dt.Int64, typecode.NCHAR: dt.String, + typecode.UUID: dt.UUID, typecode.NULL: dt.Null, typecode.NVARCHAR: dt.String, typecode.OBJECT: partial(dt.Map, dt.string, dt.json), @@ -60,6 +61,7 @@ typecode.VARCHAR: dt.String, typecode.VARIANT: dt.JSON, typecode.UNIQUEIDENTIFIER: dt.UUID, + typecode.SET: partial(dt.Array, dt.string), ############################# # Unsupported sqlglot types # ############################# @@ -305,6 +307,37 @@ class PostgresType(SqlglotType): ) +class MySQLType(SqlglotType): + dialect = "mysql" + + unknown_type_strings = FrozenDict( + { + "year(4)": dt.int8, + "inet6": dt.inet, + } + ) + + @classmethod + def _from_sqlglot_BIT(cls, nbits: sge.DataTypeParam) -> dt.Integer: + nbits = int(nbits.this.this) + if nbits > 32: + return dt.Int64(nullable=cls.default_nullable) + elif nbits > 16: + return dt.Int32(nullable=cls.default_nullable) + elif nbits > 8: + return dt.Int16(nullable=cls.default_nullable) + else: + return dt.Int8(nullable=cls.default_nullable) + + @classmethod + def _from_sqlglot_DATETIME(cls) -> dt.Timestamp: + return dt.Timestamp(nullable=cls.default_nullable) + + @classmethod + def _from_sqlglot_TIMESTAMP(cls) -> dt.Timestamp: + return dt.Timestamp(timezone="UTC", nullable=cls.default_nullable) + + class DuckDBType(SqlglotType): dialect = "duckdb" 
default_decimal_precision = 18 diff --git a/ibis/backends/mysql/__init__.py b/ibis/backends/mysql/__init__.py index 47639cf41a97..1ed548315f14 100644 --- a/ibis/backends/mysql/__init__.py +++ b/ibis/backends/mysql/__init__.py @@ -2,17 +2,18 @@ from __future__ import annotations -import re import warnings from typing import TYPE_CHECKING, Literal import sqlalchemy as sa from sqlalchemy.dialects import mysql +import ibis.expr.schema as sch +from ibis import util from ibis.backends.base import CanCreateDatabase from ibis.backends.base.sql.alchemy import BaseAlchemyBackend from ibis.backends.mysql.compiler import MySQLCompiler -from ibis.backends.mysql.datatypes import MySQLDateTime, _type_from_cursor_info +from ibis.backends.mysql.datatypes import MySQLDateTime, MySQLType if TYPE_CHECKING: from collections.abc import Iterable @@ -146,20 +147,32 @@ def list_databases(self, like: str | None = None) -> list[str]: databases = self.inspector.get_schema_names() return self._filter_with_like(databases, like) - def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]: - if ( - re.search(r"^\s*SELECT\s", query, flags=re.MULTILINE | re.IGNORECASE) - is not None - ): - query = f"({query})" + def _metadata(self, table: str) -> Iterable[tuple[str, dt.DataType]]: + with self.begin() as con: + result = con.exec_driver_sql(f"DESCRIBE {table}").mappings().all() + + for field in result: + name = field["Field"] + type_string = field["Type"] + is_nullable = field["Null"] == "YES" + yield name, MySQLType.from_string(type_string, nullable=is_nullable) + + def _get_schema_using_query(self, query: str): + table = f"__ibis_mysql_metadata_{util.guid()}" with self.begin() as con: - result = con.exec_driver_sql(f"SELECT * FROM {query} _ LIMIT 0") - cursor = result.cursor - yield from ( - (field.name, _type_from_cursor_info(descr, field)) - for descr, field in zip(cursor.description, cursor._result.fields) - ) + con.exec_driver_sql(f"CREATE TEMPORARY TABLE {table} AS {query}") + 
result = con.exec_driver_sql(f"DESCRIBE {table}").mappings().all() + con.exec_driver_sql(f"DROP TABLE {table}") + + fields = {} + for field in result: + name = field["Field"] + type_string = field["Type"] + is_nullable = field["Null"] == "YES" + fields[name] = MySQLType.from_string(type_string, nullable=is_nullable) + + return sch.Schema(fields) def _get_temp_view_definition( self, name: str, definition: sa.sql.compiler.Compiled diff --git a/ibis/backends/mysql/datatypes.py b/ibis/backends/mysql/datatypes.py index 805d25adbdbc..9b35423a6d1b 100644 --- a/ibis/backends/mysql/datatypes.py +++ b/ibis/backends/mysql/datatypes.py @@ -1,168 +1,11 @@ from __future__ import annotations -from functools import partial - import sqlalchemy.types as sat from sqlalchemy.dialects import mysql import ibis.expr.datatypes as dt from ibis.backends.base.sql.alchemy.datatypes import UUID, AlchemyType - -# binary character set -# used to distinguish blob binary vs blob text -MY_CHARSET_BIN = 63 - - -def _type_from_cursor_info(descr, field) -> dt.DataType: - """Construct an ibis type from MySQL field descr and field result metadata. - - This method is complex because the MySQL protocol is complex. - - Types are not encoded in a self contained way, meaning you need - multiple pieces of information coming from the result set metadata to - determine the most precise type for a field. Even then, the decoding is - not high fidelity in some cases: UUIDs for example are decoded as - strings, because the protocol does not appear to preserve the logical - type, only the physical type. 
- """ - from pymysql.connections import TEXT_TYPES - - _, type_code, _, _, field_length, scale, _ = descr - flags = _FieldFlags(field.flags) - typename = _type_codes.get(type_code) - if typename is None: - raise NotImplementedError(f"MySQL type code {type_code:d} is not supported") - - if typename in ("DECIMAL", "NEWDECIMAL"): - precision = _decimal_length_to_precision( - length=field_length, - scale=scale, - is_unsigned=flags.is_unsigned, - ) - typ = partial(_type_mapping[typename], precision=precision, scale=scale) - elif typename == "BIT": - if field_length <= 8: - typ = dt.int8 - elif field_length <= 16: - typ = dt.int16 - elif field_length <= 32: - typ = dt.int32 - elif field_length <= 64: - typ = dt.int64 - else: - raise AssertionError("invalid field length for BIT type") - elif flags.is_set: - # sets are limited to strings - typ = dt.Array(dt.string) - elif flags.is_unsigned and flags.is_num: - typ = getattr(dt, f"U{typ.__name__}") - elif type_code in TEXT_TYPES: - # binary text - if field.charsetnr == MY_CHARSET_BIN: - typ = dt.Binary - else: - typ = dt.String - else: - typ = _type_mapping[typename] - - # projection columns are always nullable - return typ(nullable=True) - - -# ported from my_decimal.h:my_decimal_length_to_precision in mariadb -def _decimal_length_to_precision(*, length: int, scale: int, is_unsigned: bool) -> int: - return length - (scale > 0) - (not (is_unsigned or not length)) - - -_type_codes = { - 0: "DECIMAL", - 1: "TINY", - 2: "SHORT", - 3: "LONG", - 4: "FLOAT", - 5: "DOUBLE", - 6: "NULL", - 7: "TIMESTAMP", - 8: "LONGLONG", - 9: "INT24", - 10: "DATE", - 11: "TIME", - 12: "DATETIME", - 13: "YEAR", - 15: "VARCHAR", - 16: "BIT", - 245: "JSON", - 246: "NEWDECIMAL", - 247: "ENUM", - 248: "SET", - 249: "TINY_BLOB", - 250: "MEDIUM_BLOB", - 251: "LONG_BLOB", - 252: "BLOB", - 253: "VAR_STRING", - 254: "STRING", - 255: "GEOMETRY", -} - - -_type_mapping = { - "DECIMAL": dt.Decimal, - "TINY": dt.Int8, - "SHORT": dt.Int16, - "LONG": dt.Int32, - 
"FLOAT": dt.Float32, - "DOUBLE": dt.Float64, - "NULL": dt.Null, - "TIMESTAMP": lambda nullable: dt.Timestamp(timezone="UTC", nullable=nullable), - "LONGLONG": dt.Int64, - "INT24": dt.Int32, - "DATE": dt.Date, - "TIME": dt.Time, - "DATETIME": dt.Timestamp, - "YEAR": dt.Int8, - "VARCHAR": dt.String, - "JSON": dt.JSON, - "NEWDECIMAL": dt.Decimal, - "ENUM": dt.String, - "SET": lambda nullable: dt.Array(dt.string, nullable=nullable), - "TINY_BLOB": dt.Binary, - "MEDIUM_BLOB": dt.Binary, - "LONG_BLOB": dt.Binary, - "BLOB": dt.Binary, - "VAR_STRING": dt.String, - "STRING": dt.String, - "GEOMETRY": dt.Geometry, -} - - -class _FieldFlags: - """Flags used to disambiguate field types. - - Gaps in the flag numbers are because we do not map in flags that are - of no use in determining the field's type, such as whether the field - is a primary key or not. - """ - - UNSIGNED = 1 << 5 - SET = 1 << 11 - NUM = 1 << 15 - - __slots__ = ("value",) - - def __init__(self, value: int) -> None: - self.value = value - - @property - def is_unsigned(self) -> bool: - return (self.UNSIGNED & self.value) != 0 - - @property - def is_set(self) -> bool: - return (self.SET & self.value) != 0 - - @property - def is_num(self) -> bool: - return (self.NUM & self.value) != 0 +from ibis.backends.base.sql.glot.datatypes import MySQLType as SqlglotMySQLType class MySQLDateTime(mysql.DATETIME): @@ -214,7 +57,7 @@ def result_processor(self, *_): mysql.TIME: dt.Time, mysql.YEAR: dt.Int8, MySQLDateTime: dt.Timestamp, - UUID: dt.String, + UUID: dt.UUID, } @@ -247,8 +90,12 @@ def to_ibis(cls, typ, nullable=True): elif isinstance(typ, mysql.TIMESTAMP): return dt.Timestamp(timezone="UTC", nullable=nullable) elif isinstance(typ, mysql.SET): - return dt.Set(dt.string, nullable=nullable) + return dt.Array(dt.string, nullable=nullable) elif dtype := _from_mysql_types.get(type(typ)): return dtype(nullable=nullable) else: return super().to_ibis(typ, nullable=nullable) + + @classmethod + def from_string(cls, type_string, 
nullable=True): +        return SqlglotMySQLType.from_string(type_string, nullable=nullable) diff --git a/ibis/backends/mysql/tests/test_client.py b/ibis/backends/mysql/tests/test_client.py index c65471c3ea44..660495680ce4 100644 --- a/ibis/backends/mysql/tests/test_client.py +++ b/ibis/backends/mysql/tests/test_client.py @@ -20,8 +20,8 @@ ("boolean", dt.int8), ("smallint", dt.int16), ("int2", dt.int16), - ("mediumint", dt.int32), - ("int3", dt.int32), + # ("mediumint", dt.int32), => https://github.com/tobymao/sqlglot/issues/2109 + # ("int3", dt.int32), => https://github.com/tobymao/sqlglot/issues/2109 ("int", dt.int32), ("int4", dt.int32), ("integer", dt.int32), @@ -52,11 +52,15 @@ # mariadb doesn't have a distinct json type ("json", dt.string), ("enum('small', 'medium', 'large')", dt.string), - ("inet6", dt.string), + # con.table(name) first parses the type correctly as ibis inet using sqlglot, + # then converts those types to sqlalchemy types and builds a sqlalchemy table + # to get the ibis schema back from the sqlalchemy types; sqlalchemy doesn't + # support inet6, so the type eventually ends up as string + # ("inet6", dt.inet), ("set('a', 'b', 'c', 'd')", dt.Array(dt.string)), ("mediumblob", dt.binary), ("blob", dt.binary), - ("uuid", dt.string), + ("uuid", dt.uuid), ] @@ -70,16 +74,19 @@ def test_get_schema_from_query(con, mysql_type, expected_type): raw_name = ibis.util.guid() name = con._quote(raw_name) + expected_schema = ibis.schema(dict(x=expected_type)) + # temporary tables get cleaned up by the db when the session ends, so we # don't need to explicitly drop the table with con.begin() as c: c.exec_driver_sql(f"CREATE TEMPORARY TABLE {name} (x {mysql_type})") - expected_schema = ibis.schema(dict(x=expected_type)) - t = con.table(raw_name) + result_schema = con._get_schema_using_query(f"SELECT * FROM {name}") - assert t.schema() == expected_schema assert result_schema == expected_schema + t = con.table(raw_name) + assert t.schema() == expected_schema + 
@pytest.mark.parametrize("coltype", ["TINYBLOB", "MEDIUMBLOB", "BLOB", "LONGBLOB"]) def test_blob_type(con, coltype): diff --git a/ibis/backends/sqlite/__init__.py b/ibis/backends/sqlite/__init__.py index a7d0ac4b44d9..cb58a207d9e4 100644 --- a/ibis/backends/sqlite/__init__.py +++ b/ibis/backends/sqlite/__init__.py @@ -179,7 +179,7 @@ def _current_schema(self) -> str | None: return self.current_database def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]: - view = f"__ibis_sqlite_metadata{util.guid()}" + view = f"__ibis_sqlite_metadata_{util.guid()}" with self.begin() as con: if query in self.list_tables(): diff --git a/ibis/backends/trino/tests/test_datatypes.py b/ibis/backends/trino/tests/test_datatypes.py index ce2bfd07b325..85c435e117d3 100644 --- a/ibis/backends/trino/tests/test_datatypes.py +++ b/ibis/backends/trino/tests/test_datatypes.py @@ -4,7 +4,7 @@ from pytest import param import ibis.expr.datatypes as dt -from ibis.backends.trino.datatypes import parse +from ibis.backends.trino.datatypes import TrinoType dtypes = [ ("interval year to month", dt.Interval(unit="M")), @@ -65,4 +65,4 @@ [param(trino_type, ibis_type, id=trino_type) for trino_type, ibis_type in dtypes], ) def test_parse(trino_type, ibis_type): - assert parse(trino_type) == ibis_type + assert TrinoType.from_string(trino_type) == ibis_type