From fe7ba2456a5202b28a2418fb44e145319ce6dd0c Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 13 Aug 2023 05:53:59 -0400 Subject: [PATCH] refactor(datatypes): use sqlglot for parsing backend specific types BREAKING CHANGE: The minimum version of `sqlglot` is now 17.12.0, to support much faster and more robust backend type parsing. --- ibis/backends/clickhouse/datatypes.py | 239 ++++++------------- ibis/backends/druid/datatypes.py | 30 +-- ibis/backends/duckdb/datatypes.py | 91 +------ ibis/backends/duckdb/tests/test_datatypes.py | 38 +-- ibis/backends/postgres/__init__.py | 24 +- ibis/backends/postgres/datatypes.py | 106 +++----- ibis/backends/trino/datatypes.py | 116 ++------- ibis/formats/parser.py | 206 ++++++++++++++++ poetry.lock | 8 +- pyproject.toml | 2 +- requirements-dev.txt | 2 +- 11 files changed, 370 insertions(+), 492 deletions(-) create mode 100644 ibis/formats/parser.py diff --git a/ibis/backends/clickhouse/datatypes.py b/ibis/backends/clickhouse/datatypes.py index 2f8ba0981f97..be378798316a 100644 --- a/ibis/backends/clickhouse/datatypes.py +++ b/ibis/backends/clickhouse/datatypes.py @@ -1,192 +1,101 @@ from __future__ import annotations import functools -from functools import partial -from typing import Literal +from typing import TYPE_CHECKING, Literal, Mapping -import parsy +import sqlglot as sg +from sqlglot.expressions import ColumnDef, DataType import ibis import ibis.expr.datatypes as dt -from ibis.common.parsing import ( - COMMA, - FIELD, - LPAREN, - NUMBER, - PRECISION, - RAW_NUMBER, - RAW_STRING, - RPAREN, - SCALE, - SPACES, - spaceless_string, -) +from ibis.common.collections import FrozenDict +from ibis.formats.parser import TypeParser + +if TYPE_CHECKING: + from sqlglot.expressions import DataTypeSize, Expression def _bool_type() -> Literal["Bool", "UInt8", "Int8"]: return getattr(getattr(ibis.options, "clickhouse", None), "bool_type", "Bool") -def parse(text: str) -> dt.DataType: - 
datetime64_args = LPAREN.then( - parsy.seq( - scale=parsy.decimal_digit.map(int).optional(), - timezone=COMMA.then(RAW_STRING).optional(), - ) - ).skip(RPAREN) - - datetime64 = spaceless_string("datetime64").then( - datetime64_args.optional(default={}).combine_dict( - partial(dt.Timestamp, nullable=False) - ) - ) - - datetime = spaceless_string("datetime").then( - parsy.seq( - timezone=LPAREN.then(RAW_STRING).skip(RPAREN).optional() - ).combine_dict(partial(dt.Timestamp, nullable=False)) - ) - - primitive = ( - datetime64 - | datetime - | spaceless_string("null", "nothing").result(dt.null) - | spaceless_string("bigint", "int64").result(dt.Int64(nullable=False)) - | spaceless_string("double", "float64").result(dt.Float64(nullable=False)) - | spaceless_string("float32", "float").result(dt.Float32(nullable=False)) - | spaceless_string("smallint", "int16", "int2").result(dt.Int16(nullable=False)) - | spaceless_string("date32", "date").result(dt.Date(nullable=False)) - | spaceless_string("time").result(dt.Time(nullable=False)) - | spaceless_string("tinyint", "int8", "int1").result(dt.Int8(nullable=False)) - | spaceless_string("boolean", "bool").result(dt.Boolean(nullable=False)) - | spaceless_string("integer", "int32", "int4", "int").result( - dt.Int32(nullable=False) - ) - | spaceless_string("uint64").result(dt.UInt64(nullable=False)) - | spaceless_string("uint32").result(dt.UInt32(nullable=False)) - | spaceless_string("uint16").result(dt.UInt16(nullable=False)) - | spaceless_string("uint8").result(dt.UInt8(nullable=False)) - | spaceless_string("uuid").result(dt.UUID(nullable=False)) - | spaceless_string( - "longtext", - "mediumtext", - "tinytext", - "text", - "longblob", - "mediumblob", - "tinyblob", - "blob", - "varchar", - "char", - "string", - ).result(dt.String(nullable=False)) - ) - - ty = parsy.forward_declaration() - - nullable = ( - spaceless_string("nullable") - .then(LPAREN) - .then(ty.map(lambda ty: ty.copy(nullable=True))) - .skip(RPAREN) - ) - - 
fixed_string = ( - spaceless_string("fixedstring") - .then(LPAREN) - .then(NUMBER) - .then(RPAREN) - .result(dt.String(nullable=False)) - ) - - decimal = ( - spaceless_string("decimal", "numeric") - .then(LPAREN) - .then( - parsy.seq(precision=PRECISION.skip(COMMA), scale=SCALE).combine_dict( - partial(dt.Decimal(nullable=False)) - ) - ) - .skip(RPAREN) - ) +class ClickHouseTypeParser(TypeParser): + __slots__ = () - array = spaceless_string("array").then( - LPAREN.then(ty.map(partial(dt.Array, nullable=False))).skip(RPAREN) - ) + dialect = "clickhouse" + default_decimal_precision = None + default_decimal_scale = None + default_nullable = False - map = ( - spaceless_string("map") - .then(LPAREN) - .then(parsy.seq(ty, COMMA.then(ty)).combine(partial(dt.Map, nullable=False))) - .skip(RPAREN) + short_circuit: Mapping[str, dt.DataType] = FrozenDict( + { + "IPv4": dt.INET(nullable=default_nullable), + "IPv6": dt.INET(nullable=default_nullable), + "Object('json')": dt.JSON(nullable=default_nullable), + "Array(Null)": dt.Array(dt.null, nullable=default_nullable), + "Array(Nothing)": dt.Array(dt.null, nullable=default_nullable), + } ) - at_least_one_space = parsy.regex(r"\s+") - - nested = ( - spaceless_string("nested") - .then(LPAREN) - .then( - parsy.seq(SPACES.then(FIELD.skip(at_least_one_space)), ty) - .combine(lambda field, ty: (field, dt.Array(ty, nullable=False))) - .sep_by(COMMA) - .map(partial(dt.Struct.from_tuples, nullable=False)) - ) - .skip(RPAREN) - ) - - struct = ( - spaceless_string("tuple") - .then(LPAREN) - .then( - parsy.seq( - SPACES.then(FIELD.skip(at_least_one_space).optional()), - ty, + @classmethod + def _get_DATETIME( + cls, first: DataTypeSize | None = None, second: DataTypeSize | None = None + ) -> dt.Timestamp: + if first is not None and second is not None: + scale = first + timezone = second + elif first is not None and second is None: + timezone, scale = ( + (first, second) if first.this.is_string else (second, first) ) - .sep_by(COMMA) - .map( 
- lambda field_names_types: dt.Struct.from_tuples( - [ - (field_name if field_name is not None else f"f{i:d}", typ) - for i, (field_name, typ) in enumerate(field_names_types) - ], - nullable=False, + else: + scale = first + timezone = second + return cls._get_TIMESTAMP(scale=scale, timezone=timezone) + + @classmethod + def _get_DATETIME64( + cls, scale: DataTypeSize | None = None, timezone: DataTypeSize | None = None + ) -> dt.Timestamp: + return cls._get_TIMESTAMP(scale=scale, timezone=timezone) + + @classmethod + def _get_NULLABLE(cls, inner_type: DataType) -> dt.DataType: + return cls._get_type(inner_type).copy(nullable=True) + + @classmethod + def _get_LOWCARDINALITY(cls, inner_type: DataType) -> dt.DataType: + return cls._get_type(inner_type) + + @classmethod + def _get_NESTED(cls, *fields: DataType) -> dt.Struct: + return dt.Struct( + { + field.name: dt.Array( + cls._get_type(field.args["kind"]), nullable=cls.default_nullable ) - ) + for field in fields + }, + nullable=cls.default_nullable, ) - .skip(RPAREN) - ) - enum_value = SPACES.then(RAW_STRING).skip(spaceless_string("=")).then(RAW_NUMBER) + @classmethod + def _get_STRUCT(cls, *fields: Expression) -> dt.Struct: + types = {} - lowcardinality = ( - spaceless_string("lowcardinality").then(LPAREN).then(ty).skip(RPAREN) - ) + for i, field in enumerate(fields): + if isinstance(field, ColumnDef): + inner_type = field.args["kind"] + name = field.name + else: + inner_type = sg.parse_one(str(field), into=DataType, read="clickhouse") + name = f"f{i:d}" - enum = ( - spaceless_string("enum") - .then(RAW_NUMBER) - .then(LPAREN) - .then(enum_value.sep_by(COMMA)) - .skip(RPAREN) - .result(dt.String(nullable=False)) - ) + types[name] = cls._get_type(inner_type) + return dt.Struct(types, nullable=cls.default_nullable) - ty.become( - nullable - | nested - | primitive - | fixed_string - | decimal - | array - | map - | struct - | enum - | lowcardinality - | spaceless_string("IPv4", "IPv6").result(dt.INET(nullable=False)) - | 
spaceless_string("Object('json')", "JSON").result(dt.JSON(nullable=False)) - ) - return ty.parse(text) + +parse = ClickHouseTypeParser.parse @functools.singledispatch diff --git a/ibis/backends/druid/datatypes.py b/ibis/backends/druid/datatypes.py index c760131bc0e6..9932380b1a96 100644 --- a/ibis/backends/druid/datatypes.py +++ b/ibis/backends/druid/datatypes.py @@ -1,6 +1,7 @@ from __future__ import annotations -import parsy +from typing import Mapping + import sqlalchemy as sa import sqlalchemy.types as sat from dateutil.parser import parse as timestamp_parse @@ -8,11 +9,8 @@ import ibis.expr.datatypes as dt from ibis.backends.base.sql.alchemy.datatypes import AlchemyType -from ibis.common.parsing import ( - LANGLE, - RANGLE, - spaceless_string, -) +from ibis.common.collections import FrozenDict +from ibis.formats.parser import TypeParser class DruidDateTime(sat.TypeDecorator): @@ -59,23 +57,15 @@ def _smallint(element, compiler, **kw): return "SMALLINT" -def parse(text: str) -> dt.DataType: - """Parse a Druid type into an ibis data type.""" - primitive = ( - spaceless_string("string").result(dt.string) - | spaceless_string("double").result(dt.float64) - | spaceless_string("float").result(dt.float32) - | spaceless_string("long").result(dt.int64) - | spaceless_string("json").result(dt.json) - ) +class DruidTypeParser(TypeParser): + __slots__ = () - ty = parsy.forward_declaration() + # druid doesn't have a sophisticated type system and hive is close enough + dialect = "hive" + short_circuit: Mapping[str, dt.DataType] = FrozenDict({"complex": dt.json}) - json = spaceless_string("complex").then(LANGLE).then(ty).skip(RANGLE) - array = spaceless_string("array").then(LANGLE).then(ty.map(dt.Array)).skip(RANGLE) - ty.become(primitive | array | json) - return ty.parse(text) +parse = DruidTypeParser.parse class DruidType(AlchemyType): diff --git a/ibis/backends/duckdb/datatypes.py b/ibis/backends/duckdb/datatypes.py index 5b055829eea8..7ab44095c996 100644 --- 
a/ibis/backends/duckdb/datatypes.py +++ b/ibis/backends/duckdb/datatypes.py @@ -1,98 +1,25 @@ from __future__ import annotations import duckdb_engine.datatypes as ducktypes -import parsy import sqlalchemy.dialects.postgresql as psql -import toolz import ibis.expr.datatypes as dt from ibis.backends.base.sql.alchemy.datatypes import AlchemyType -from ibis.common.parsing import ( - COMMA, - FIELD, - LBRACKET, - LPAREN, - PRECISION, - RAW_STRING, - RBRACKET, - RPAREN, - SCALE, - spaceless, - spaceless_string, -) +from ibis.formats.parser import TypeParser -def parse(text: str, default_decimal_parameters=(18, 3)) -> dt.DataType: - """Parse a DuckDB type into an ibis data type.""" - primitive = ( - spaceless_string("interval").result(dt.Interval("us")) - | spaceless_string("hugeint", "int128").result(dt.Decimal(38, 0)) - | spaceless_string("bigint", "int8", "long").result(dt.int64) - | spaceless_string("boolean", "bool", "logical").result(dt.boolean) - | spaceless_string("blob", "bytea", "binary", "varbinary").result(dt.binary) - | spaceless_string("double", "float8").result(dt.float64) - | spaceless_string("real", "float4", "float").result(dt.float32) - | spaceless_string("smallint", "int2", "short").result(dt.int16) - | spaceless_string( - "timestamp with time zone", "timestamp_tz", "datetime" - ).result(dt.Timestamp(timezone="UTC")) - | spaceless_string("timestamp_sec", "timestamp_s").result( - dt.Timestamp(timezone="UTC", scale=0) - ) - | spaceless_string("timestamp_ms").result(dt.Timestamp(timezone="UTC", scale=3)) - | spaceless_string("timestamp_us").result(dt.Timestamp(timezone="UTC", scale=6)) - | spaceless_string("timestamp_ns").result(dt.Timestamp(timezone="UTC", scale=9)) - | spaceless_string("timestamp").result(dt.Timestamp(timezone="UTC")) - | spaceless_string("date").result(dt.date) - | spaceless_string("time").result(dt.time) - | spaceless_string("tinyint", "int1").result(dt.int8) - | spaceless_string("integer", "int4", "int", "signed").result(dt.int32) - 
| spaceless_string("ubigint").result(dt.uint64) - | spaceless_string("usmallint").result(dt.uint16) - | spaceless_string("uinteger").result(dt.uint32) - | spaceless_string("utinyint").result(dt.uint8) - | spaceless_string("uuid").result(dt.uuid) - | spaceless_string("varchar", "char", "bpchar", "text", "string").result( - dt.string - ) - | spaceless_string("json").result(dt.json) - | spaceless_string("null").result(dt.null) - ) +class DuckDBTypeParser(TypeParser): + __slots__ = () - decimal = spaceless_string("decimal", "numeric").then( - parsy.seq(LPAREN.then(PRECISION), COMMA.then(SCALE).skip(RPAREN)) - .optional(default_decimal_parameters) - .combine(dt.Decimal) - ) - - brackets = spaceless(LBRACKET).then(spaceless(RBRACKET)) - - ty = parsy.forward_declaration() - non_pg_array_type = parsy.forward_declaration() - - pg_array = parsy.seq(non_pg_array_type, brackets.at_least(1).map(len)).combine( - lambda value_type, n: toolz.nth(n, toolz.iterate(dt.Array, value_type)) - ) - - map = ( - spaceless_string("map") - .then(LPAREN) - .then(parsy.seq(ty, COMMA.then(ty)).combine(dt.Map)) - .skip(RPAREN) - ) + dialect = "duckdb" + default_decimal_precision = 18 + default_decimal_scale = 3 + default_interval_precision = "us" - field = spaceless(parsy.alt(FIELD, RAW_STRING)) + fallback = {"INTERVAL": dt.Interval(default_interval_precision)} - struct = ( - spaceless_string("struct") - .then(LPAREN) - .then(parsy.seq(field, ty).sep_by(COMMA).map(dt.Struct.from_tuples)) - .skip(RPAREN) - ) - non_pg_array_type.become(primitive | decimal | map | struct) - ty.become(pg_array | non_pg_array_type) - return ty.parse(text) +parse = DuckDBTypeParser.parse _from_duckdb_types = { diff --git a/ibis/backends/duckdb/tests/test_datatypes.py b/ibis/backends/duckdb/tests/test_datatypes.py index 15c1020c9470..e11a9c8ea1e2 100644 --- a/ibis/backends/duckdb/tests/test_datatypes.py +++ b/ibis/backends/duckdb/tests/test_datatypes.py @@ -19,51 +19,30 @@ param(typ, expected, id=typ.lower()) for typ, 
expected in [ ("BIGINT", dt.int64), - ("INT8", dt.int64), - ("LONG", dt.int64), ("BOOLEAN", dt.boolean), - ("BOOL", dt.boolean), - ("LOGICAL", dt.boolean), ("BLOB", dt.binary), - ("BYTEA", dt.binary), - ("BINARY", dt.binary), - ("VARBINARY", dt.binary), ("DATE", dt.date), ("DOUBLE", dt.float64), - ("FLOAT8", dt.float64), - ("NUMERIC", dt.Decimal(18, 3)), - ("DECIMAL", dt.Decimal(18, 3)), ("DECIMAL(10, 3)", dt.Decimal(10, 3)), ("INTEGER", dt.int32), - ("INT4", dt.int32), - ("INT", dt.int32), - ("SIGNED", dt.int32), ("INTERVAL", dt.Interval("us")), - ("REAL", dt.float32), - ("FLOAT4", dt.float32), ("FLOAT", dt.float32), ("SMALLINT", dt.int16), - ("INT2", dt.int16), - ("SHORT", dt.int16), ("TIME", dt.time), - ("TIMESTAMP", dt.Timestamp("UTC")), - ("DATETIME", dt.Timestamp("UTC")), + ("TIME WITH TIME ZONE", dt.time), + ("TIMESTAMP", dt.timestamp), + ("TIMESTAMP WITH TIME ZONE", dt.Timestamp("UTC")), ("TINYINT", dt.int8), - ("INT1", dt.int8), ("UBIGINT", dt.uint64), ("UINTEGER", dt.uint32), ("USMALLINT", dt.uint16), ("UTINYINT", dt.uint8), ("UUID", dt.uuid), ("VARCHAR", dt.string), - ("CHAR", dt.string), - ("BPCHAR", dt.string), - ("TEXT", dt.string), - ("STRING", dt.string), ("INTEGER[]", dt.Array(dt.int32)), - ("MAP(STRING, BIGINT)", dt.Map(dt.string, dt.int64)), + ("MAP(VARCHAR, BIGINT)", dt.Map(dt.string, dt.int64)), ( - "STRUCT(a INT, b TEXT, c MAP(TEXT, FLOAT8[])[])", + "STRUCT(a INTEGER, b VARCHAR, c MAP(VARCHAR, DOUBLE[])[])", dt.Struct( dict( a=dt.int32, @@ -73,15 +52,8 @@ ), ), ("INTEGER[][]", dt.Array(dt.Array(dt.int32))), - ("TIMESTAMP_TZ", dt.Timestamp("UTC")), - ("TIMESTAMP_SEC", dt.Timestamp("UTC", scale=0)), - ("TIMESTAMP_S", dt.Timestamp("UTC", scale=0)), - ("TIMESTAMP_MS", dt.Timestamp("UTC", scale=3)), - ("TIMESTAMP_US", dt.Timestamp("UTC", scale=6)), - ("TIMESTAMP_NS", dt.Timestamp("UTC", scale=9)), ("JSON", dt.json), ("HUGEINT", dt.Decimal(38, 0)), - ("INT128", dt.Decimal(38, 0)), ] ], ) diff --git a/ibis/backends/postgres/__init__.py 
b/ibis/backends/postgres/__init__.py index cb57814cb7a9..cdea9ed43cda 100644 --- a/ibis/backends/postgres/__init__.py +++ b/ibis/backends/postgres/__init__.py @@ -13,7 +13,7 @@ from ibis import util from ibis.backends.base.sql.alchemy import AlchemyCanCreateSchema, BaseAlchemyBackend from ibis.backends.postgres.compiler import PostgreSQLCompiler -from ibis.backends.postgres.datatypes import _get_type +from ibis.backends.postgres.datatypes import parse from ibis.common.exceptions import InvalidDecoratorError if TYPE_CHECKING: @@ -176,20 +176,20 @@ def function(self, name: str, *, schema: str | None = None) -> Callable: def split_name_type(arg: str) -> tuple[str, dt.DataType]: name, typ = arg.split(" ", 1) - return name, _get_type(typ) + return name, parse(typ) with self.begin() as con: rows = con.execute(query).mappings().fetchall() - if not rows: - name = f"{schema}.{name}" if schema else name - raise exc.MissingUDFError(name) - elif len(rows) > 1: - raise exc.AmbiguousUDFError(name) + if not rows: + name = f"{schema}.{name}" if schema else name + raise exc.MissingUDFError(name) + elif len(rows) > 1: + raise exc.AmbiguousUDFError(name) - [row] = rows - return_type = _get_type(row["return_type"]) - signature = list(map(split_name_type, row["signature"])) + [row] = rows + return_type = parse(row["return_type"]) + signature = list(map(split_name_type, row["signature"])) # dummy callable def fake_func(*args, **kwargs): @@ -263,9 +263,7 @@ def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]: with self.begin() as con: con.exec_driver_sql(f"CREATE TEMPORARY VIEW {name} AS {query}") try: - yield from ( - (col, _get_type(typestr)) for col, typestr in con.execute(text) - ) + yield from ((col, parse(typestr)) for col, typestr in con.execute(text)) finally: con.exec_driver_sql(f"DROP VIEW IF EXISTS {name}") diff --git a/ibis/backends/postgres/datatypes.py b/ibis/backends/postgres/datatypes.py index 5f187606a6f4..855e34867284 100644 --- 
a/ibis/backends/postgres/datatypes.py +++ b/ibis/backends/postgres/datatypes.py @@ -1,93 +1,43 @@ from __future__ import annotations -import parsy +from typing import Mapping + import sqlalchemy as sa import sqlalchemy.dialects.postgresql as psql import sqlalchemy.types as sat -import toolz import ibis.expr.datatypes as dt from ibis.backends.base.sql.alchemy.datatypes import AlchemyType -from ibis.common.parsing import ( - COMMA, - LBRACKET, - LPAREN, - PRECISION, - RBRACKET, - RPAREN, - SCALE, - spaceless, - spaceless_string, -) - -_BRACKETS = "[]" - - -def _parse_numeric( - text: str, default_decimal_parameters: tuple[int | None, int | None] = (None, None) -) -> dt.DataType: - decimal = spaceless_string("decimal", "numeric").then( - parsy.seq(LPAREN.then(PRECISION.skip(COMMA)), SCALE.skip(RPAREN)) - .optional(default_decimal_parameters) - .combine(dt.Decimal) +from ibis.common.collections import FrozenDict +from ibis.formats.parser import TypeParser + + +class PostgresTypeParser(TypeParser): + __slots__ = () + + dialect = "postgres" + default_interval_precision = "s" + + short_circuit: Mapping[str, dt.DataType] = FrozenDict( + { + "vector": dt.unknown, + "tsvector": dt.unknown, + "line": dt.linestring, + "line[]": dt.Array(dt.linestring), + "polygon": dt.polygon, + "polygon[]": dt.Array(dt.polygon), + "point": dt.point, + "point[]": dt.Array(dt.point), + "macaddr": dt.macaddr, + "macaddr[]": dt.Array(dt.macaddr), + "macaddr8": dt.macaddr, + "macaddr8[]": dt.Array(dt.macaddr), + } ) - brackets = spaceless(LBRACKET).then(spaceless(RBRACKET)) - pg_array = parsy.seq(decimal, brackets.at_least(1).map(len)).combine( - lambda value_type, n: toolz.nth(n, toolz.iterate(dt.Array, value_type)) - ) +parse = PostgresTypeParser.parse - ty = pg_array | decimal - return ty.parse(text) - - -# TODO(kszucs): rename to dtype_from_postgres_typeinfo or parse_postgres_typeinfo -def _get_type(typestr: str) -> dt.DataType: - is_array = typestr.endswith(_BRACKETS) - if (typ := 
_type_mapping.get(typestr.replace(_BRACKETS, ""))) is not None: - return dt.Array(typ) if is_array else typ - try: - return _parse_numeric(typestr) - except parsy.ParseError: - # postgres can have arbitrary types unknown to ibis - return dt.unknown - - -_type_mapping = { - "bigint": dt.int64, - "boolean": dt.bool, - "bytea": dt.binary, - "character varying": dt.string, - "character": dt.string, - "character(1)": dt.string, - "date": dt.date, - "double precision": dt.float64, - "geography": dt.geography, - "geometry": dt.geometry, - "inet": dt.inet, - "integer": dt.int32, - "interval": dt.Interval("s"), - "json": dt.json, - "jsonb": dt.json, - "line": dt.linestring, - "macaddr": dt.macaddr, - "macaddr8": dt.macaddr, - "numeric": dt.decimal, - "point": dt.point, - "polygon": dt.polygon, - "real": dt.float32, - "smallint": dt.int16, - "text": dt.string, - # NB: this isn't correct because we're losing the "with time zone" - # information (ibis doesn't have time type that is time-zone aware), but we - # try to do _something_ here instead of failing - "time with time zone": dt.time, - "time without time zone": dt.time, - "timestamp with time zone": dt.Timestamp("UTC"), - "timestamp without time zone": dt.timestamp, - "uuid": dt.uuid, -} _from_postgres_types = { psql.DOUBLE_PRECISION: dt.Float64, diff --git a/ibis/backends/trino/datatypes.py b/ibis/backends/trino/datatypes.py index fb82cbc3222a..d3254c1bd478 100644 --- a/ibis/backends/trino/datatypes.py +++ b/ibis/backends/trino/datatypes.py @@ -1,9 +1,7 @@ from __future__ import annotations -from functools import partial -from typing import Any +from typing import Any, Mapping -import parsy import sqlalchemy.types as sat import trino.client from sqlalchemy.ext.compiler import compiles @@ -12,18 +10,8 @@ import ibis.expr.datatypes as dt from ibis.backends.base.sql.alchemy.datatypes import AlchemyType -from ibis.common.parsing import ( - COMMA, - FIELD, - LENGTH, - LPAREN, - PRECISION, - RPAREN, - SCALE, - TEMPORAL_SCALE, 
- spaceless, - spaceless_string, -) +from ibis.common.collections import FrozenDict +from ibis.formats.parser import TypeParser class ROW(_ROW): @@ -84,99 +72,37 @@ def _floating(element, compiler, **kw): return type(element).__name__.upper() -def parse( - text: str, - default_decimal_parameters: tuple[int, int] = (18, 3), - default_temporal_scale: int = 3, # trino defaults to millisecond scale -) -> dt.DataType: - """Parse a Trino type into an ibis data type.""" - - timestamp = ( - spaceless_string("timestamp") - .then( - parsy.seq( - scale=LPAREN.then(TEMPORAL_SCALE) - .skip(RPAREN) - .optional(default_temporal_scale), - timezone=spaceless_string("with time zone").result("UTC").optional(), - ).optional(dict(scale=default_temporal_scale, timezone=None)) - ) - .combine_dict(partial(dt.Timestamp)) - ) - - primitive = ( - spaceless_string("interval year to month").result(dt.Interval(unit="M")) - | spaceless_string("interval day to second").result(dt.Interval(unit="ms")) - | spaceless_string("bigint").result(dt.int64) - | spaceless_string("boolean").result(dt.boolean) - | spaceless_string("varbinary").result(dt.binary) - | spaceless_string("double").result(dt.float64) - | spaceless_string("real").result(dt.float32) - | spaceless_string("smallint").result(dt.int16) - | spaceless_string("date").result(dt.date) - | spaceless_string("tinyint").result(dt.int8) - | spaceless_string("integer").result(dt.int32) - | spaceless_string("uuid").result(dt.uuid) - | spaceless_string("json").result(dt.json) - | spaceless_string("ipaddress").result(dt.inet) - ) - - varchar = ( - spaceless_string("varchar", "char") - .then(LPAREN.then(LENGTH).skip(RPAREN).optional()) - .result(dt.string) - ) - - decimal = spaceless_string("decimal", "numeric").then( - parsy.seq(LPAREN.then(PRECISION).skip(COMMA), SCALE.skip(RPAREN)) - .optional(default_decimal_parameters) - .combine(dt.Decimal) - ) - - time = ( - spaceless_string("time").then( - parsy.seq( - scale=LPAREN.then(TEMPORAL_SCALE) - 
.skip(RPAREN) - .optional(default_temporal_scale), - timezone=spaceless_string("with time zone").result("UTC").optional(), - ).optional(dict(scale=default_temporal_scale, timezone=None)) - ) - # TODO: support time with precision - .result(dt.time) - ) +class TrinoTypeParser(TypeParser): + __slots__ = () - ty = parsy.forward_declaration() + dialect = "trino" - array = spaceless_string("array").then(LPAREN).then(ty).skip(RPAREN).map(dt.Array) - map = spaceless_string("map").then( - parsy.seq(LPAREN.then(ty).skip(COMMA), ty.skip(RPAREN)).combine(dt.Map) - ) + default_decimal_precision = 18 + default_decimal_scale = 3 + default_temporal_scale = 3 - struct = ( - spaceless_string("row") - .then(LPAREN) - .then(parsy.seq(spaceless(FIELD), ty).sep_by(COMMA).map(dt.Struct.from_tuples)) - .skip(RPAREN) + short_circuit: Mapping[str, dt.DataType] = FrozenDict( + { + "INTERVAL YEAR TO MONTH": dt.Interval("M"), + "INTERVAL DAY TO SECOND": dt.Interval("ms"), + } ) - ty.become(primitive | timestamp | time | varchar | decimal | array | map | struct) - return ty.parse(text) - -_from_trino_types = { - DOUBLE: dt.Float64, - sat.REAL: dt.Float32, - JSON: dt.JSON, -} +parse = TrinoTypeParser.parse class TrinoType(AlchemyType): dialect = "trino" + source_types = { + DOUBLE: dt.Float64, + sat.REAL: dt.Float32, + JSON: dt.JSON, + } @classmethod def to_ibis(cls, typ, nullable=True): - if dtype := _from_trino_types.get(type(typ)): + if dtype := cls.source_types.get(type(typ)): return dtype(nullable=nullable) elif isinstance(typ, sat.NUMERIC): return dt.Decimal(typ.precision or 18, typ.scale or 3, nullable=nullable) diff --git a/ibis/formats/parser.py b/ibis/formats/parser.py new file mode 100644 index 000000000000..1e0bffc6b064 --- /dev/null +++ b/ibis/formats/parser.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +import abc +from typing import Mapping + +import sqlglot as sg +from sqlglot.expressions import ColumnDef, DataType, DataTypeSize + +import ibis.common.exceptions as 
exc +import ibis.expr.datatypes as dt +from ibis.common.collections import FrozenDict + +SQLGLOT_TYPE_TO_IBIS_TYPE = { + DataType.Type.BIGDECIMAL: dt.Decimal(76, 38), + DataType.Type.BIGINT: dt.int64, + DataType.Type.BINARY: dt.binary, + DataType.Type.BIT: dt.string, + DataType.Type.BOOLEAN: dt.boolean, + DataType.Type.CHAR: dt.string, + DataType.Type.DATE: dt.date, + DataType.Type.DOUBLE: dt.float64, + DataType.Type.ENUM: dt.string, + DataType.Type.ENUM8: dt.string, + DataType.Type.ENUM16: dt.string, + DataType.Type.FLOAT: dt.float32, + DataType.Type.FIXEDSTRING: dt.string, + DataType.Type.GEOMETRY: dt.geometry, + DataType.Type.GEOGRAPHY: dt.geography, + DataType.Type.HSTORE: dt.Map(dt.string, dt.string), + DataType.Type.INET: dt.inet, + DataType.Type.INT128: dt.Decimal(38, 0), + DataType.Type.INT256: dt.Decimal(76, 0), + DataType.Type.INT: dt.int32, + DataType.Type.IPADDRESS: dt.inet, + DataType.Type.JSON: dt.json, + DataType.Type.JSONB: dt.json, + DataType.Type.LONGBLOB: dt.binary, + DataType.Type.LONGTEXT: dt.string, + DataType.Type.MEDIUMBLOB: dt.binary, + DataType.Type.MEDIUMTEXT: dt.string, + DataType.Type.MONEY: dt.int64, + DataType.Type.NCHAR: dt.string, + DataType.Type.NULL: dt.null, + DataType.Type.NVARCHAR: dt.string, + DataType.Type.OBJECT: dt.Map(dt.string, dt.json), + DataType.Type.SMALLINT: dt.int16, + DataType.Type.SMALLMONEY: dt.int32, + DataType.Type.TEXT: dt.string, + DataType.Type.TIME: dt.time, + DataType.Type.TIMETZ: dt.time, + DataType.Type.TINYINT: dt.int8, + DataType.Type.UBIGINT: dt.uint64, + DataType.Type.UINT: dt.uint32, + DataType.Type.USMALLINT: dt.uint16, + DataType.Type.UTINYINT: dt.uint8, + DataType.Type.UUID: dt.uuid, + DataType.Type.VARBINARY: dt.binary, + DataType.Type.VARCHAR: dt.string, + DataType.Type.VARIANT: dt.json, + DataType.Type.UNIQUEIDENTIFIER: dt.uuid, + ############################# + # Unsupported sqlglot types # + ############################# + # BIGSERIAL = auto() + # DATETIME64 = auto() # clickhouse + # ENUM = 
auto() + # INT4RANGE = auto() + # INT4MULTIRANGE = auto() + # INT8RANGE = auto() + # INT8MULTIRANGE = auto() + # NUMRANGE = auto() + # NUMMULTIRANGE = auto() + # TSRANGE = auto() + # TSMULTIRANGE = auto() + # TSTZRANGE = auto() + # TSTZMULTIRANGE = auto() + # DATERANGE = auto() + # DATEMULTIRANGE = auto() + # HLLSKETCH = auto() + # IMAGE = auto() + # IPPREFIX = auto() + # ROWVERSION = auto() + # SERIAL = auto() + # SET = auto() + # SMALLSERIAL = auto() + # SUPER = auto() + # TIMESTAMPLTZ = auto() + # UNKNOWN = auto() # Sentinel value, useful for type annotation + # UINT128 = auto() + # UINT256 = auto() + # USERDEFINED = "USER-DEFINED" + # XML = auto() +} + + +class TypeParser(abc.ABC): + __slots__ = () + + @property + @abc.abstractmethod + def dialect(self) -> str: + """The dialect this parser is for.""" + + default_nullable = True + """Default nullability when not specified.""" + + default_decimal_precision: int | None = None + """Default decimal precision when not specified.""" + + default_decimal_scale: int | None = None + """Default decimal scale when not specified.""" + + default_temporal_scale: int | None = None + """Default temporal scale when not specified.""" + + default_interval_precision: str | None = None + """Default interval precision when not specified.""" + + short_circuit: Mapping[str, dt.DataType] = FrozenDict() + """Default short-circuit mapping of SQL string types to ibis types.""" + + @classmethod + def parse(cls, text: str) -> dt.DataType: + """Parse a type string into an ibis data type.""" + short_circuit = cls.short_circuit + if dtype := short_circuit.get(text, short_circuit.get(text.upper())): + return dtype + return cls._get_type(sg.parse_one(text, into=DataType, read=cls.dialect)) + + @classmethod + def _get_ARRAY(cls, value_type: DataType) -> dt.Array: + return dt.Array(cls._get_type(value_type), nullable=cls.default_nullable) + + @classmethod + def _get_MAP(cls, key_type: DataType, value_type: DataType) -> dt.Map: + return dt.Map( + 
cls._get_type(key_type), + cls._get_type(value_type), + nullable=cls.default_nullable, + ) + + @classmethod + def _get_STRUCT(cls, *fields: ColumnDef) -> dt.Struct: + return dt.Struct( + {field.name: cls._get_type(field.args["kind"]) for field in fields}, + nullable=cls.default_nullable, + ) + + @classmethod + def _get_TIMESTAMP( + cls, scale: DataTypeSize | None = None, timezone: DataTypeSize | None = None + ) -> dt.Timestamp: + return dt.Timestamp( + timezone=timezone if timezone is None else timezone.this.this, + scale=cls.default_temporal_scale if scale is None else int(scale.this.this), + nullable=cls.default_nullable, + ) + + @classmethod + def _get_TIMESTAMPTZ(cls, scale: DataTypeSize | None = None) -> dt.Timestamp: + return dt.Timestamp( + timezone="UTC", + scale=cls.default_temporal_scale if scale is None else int(scale.this.this), + nullable=cls.default_nullable, + ) + + @classmethod + def _get_DATETIME(cls, scale: DataTypeSize | None = None) -> dt.Timestamp: + return dt.Timestamp( + timezone="UTC", + scale=cls.default_temporal_scale if scale is None else int(scale.this.this), + nullable=cls.default_nullable, + ) + + @classmethod + def _get_INTERVAL(cls, precision: DataTypeSize | None = None) -> dt.Interval: + if precision is None: + precision = cls.default_interval_precision + return dt.Interval(str(precision), nullable=cls.default_nullable) + + @classmethod + def _get_DECIMAL( + cls, precision: DataTypeSize | None = None, scale: DataTypeSize | None = None + ) -> dt.Decimal: + if precision is None: + precision = cls.default_decimal_precision + else: + precision = int(precision.this.this) + + if scale is None: + scale = cls.default_decimal_scale + else: + scale = int(scale.this.this) + + return dt.Decimal(precision, scale, nullable=cls.default_nullable) + + @classmethod + def _get_type(cls, parse_result: DataType) -> dt.DataType: + typ = parse_result.this + + if (result := SQLGLOT_TYPE_TO_IBIS_TYPE.get(typ)) is not None: + return 
result.copy(nullable=cls.default_nullable) + elif (method := getattr(cls, f"_get_{typ.name}", None)) is not None: + return method(*parse_result.expressions) + else: + raise exc.IbisTypeError(f"Unknown type: {typ}") diff --git a/poetry.lock b/poetry.lock index a721c56b86a5..b472cd08fd3c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4910,14 +4910,14 @@ sqlalchemy = ">=1.0.0" [[package]] name = "sqlglot" -version = "17.9.1" +version = "17.12.0" description = "An easily customizable SQL parser and transpiler" category = "main" optional = false python-versions = "*" files = [ - {file = "sqlglot-17.9.1-py3-none-any.whl", hash = "sha256:eac3243fdffdf96aeff16b6c4bb070e2d2c70d1e4ac865cb8bfa96b7cf3e6611"}, - {file = "sqlglot-17.9.1.tar.gz", hash = "sha256:2a72ec4078f12debbb3e3aad9f5e1ac0591c9729f5b6becfdf45b64b48b41217"}, + {file = "sqlglot-17.12.0-py3-none-any.whl", hash = "sha256:60209561d0f66c53a8aff56362e87a955d6f4b0f7b25704cd301157583dc0230"}, + {file = "sqlglot-17.12.0.tar.gz", hash = "sha256:524ceb67a82408bb393de2ec85c7cf3a9f2e1f9c9f508fa4e4497d766ed06fa9"}, ] [package.extras] @@ -5524,4 +5524,4 @@ visualization = ["graphviz"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "c5d8989f669d9c243b4c440598a161a242289c672fa236d2aa64b6a5608dd45d" +content-hash = "e2ec7c3d1b7c78261b7c6b2dc7ea6f9c74da3453759ea3f461515e572f0ba891" diff --git a/pyproject.toml b/pyproject.toml index 2c05b50cdf07..552c83c38ff4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ pyarrow = ">=2,<13" python-dateutil = ">=2.8.2,<3" pytz = ">=2022.7" rich = ">=12.4.4,<14" -sqlglot = ">=10.4.3,<18" +sqlglot = ">=17.12.0,<18" toolz = ">=0.11,<1" typing-extensions = ">=4.3.0,<5" black = { version = ">=22.1.0,<24", optional = true } diff --git a/requirements-dev.txt b/requirements-dev.txt index a60d0ca1b7ce..06b6c4b16703 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -196,7 +196,7 @@ sortedcontainers==2.4.0 ; python_version >= "3.9" and 
python_version < "4.0" soupsieve==2.4.1 ; python_version >= "3.9" and python_version < "4.0" sqlalchemy-views==0.3.2 ; python_version >= "3.9" and python_version < "4.0" sqlalchemy==1.4.49 ; python_version >= "3.9" and python_version < "4.0" -sqlglot==17.9.1 ; python_version >= "3.9" and python_version < "4.0" +sqlglot==17.12.0 ; python_version >= "3.9" and python_version < "4.0" stack-data==0.6.2 ; python_version >= "3.9" and python_version < "4.0" stdlib-list==0.9.0 ; python_version >= "3.9" and python_version < "4.0" termcolor==2.3.0 ; python_version >= "3.9" and python_version < "4.0"