Skip to content

Commit

Permalink
refactor schema to be more lenient with a trie
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Nov 3, 2022
1 parent fd0efba commit 55bae2e
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 305 deletions.
2 changes: 1 addition & 1 deletion sqlglot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
subquery,
)
from sqlglot.expressions import table_ as table
from sqlglot.expressions import union
from sqlglot.expressions import to_column, to_table, union
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
Expand Down
4 changes: 4 additions & 0 deletions sqlglot/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class OptimizeError(SqlglotError):
pass


class SchemaError(SqlglotError):
    """Error raised for schema-related failures.

    Within the schema module this is raised when a schema has an invalid
    shape, when a table cannot be resolved to a schema entry, or when a
    schema type string cannot be converted.
    """


def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str:
msg = [str(e) for e in errors[:maximum]]
remaining = len(errors) - maximum
Expand Down
177 changes: 105 additions & 72 deletions sqlglot/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
import typing as t

from sqlglot import expressions as exp
from sqlglot.errors import OptimizeError
from sqlglot.errors import SchemaError
from sqlglot.helper import csv_reader
from sqlglot.trie import in_trie, new_trie

if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.types import StructType

ColumnMapping = t.Union[t.Dict, str, StructType, t.List]

TABLE_ARGS = ("this", "db", "catalog")


class Schema(abc.ABC):
"""Abstract base class for database schemas"""
Expand Down Expand Up @@ -80,14 +83,12 @@ def __init__(
dialect: t.Optional[str] = None,
) -> None:
self.schema = schema or {}
self.visible = visible
self.visible = visible or {}
self.schema_trie = self._build_trie(self.schema)
self.visible_trie = self._build_trie(self.visible)
self.dialect = dialect
self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {}
self.supported_table_args: t.List | t.Tuple[str, ...] = []
self.forbidden_table_args: t.Set[str] = set()

if self.schema:
self._initialize_supported_args()
self._supported_table_args: t.Tuple[str, ...] = tuple()

@classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
Expand All @@ -98,7 +99,28 @@ def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
)

def copy(self, **kwargs) -> MappingSchema:
return MappingSchema(**{"schema": self.schema.copy(), **kwargs}) # type: ignore
return MappingSchema(
**{
"schema": self.schema.copy(),
"visible": self.visible.copy(),
"dialect": self.dialect,
**kwargs,
} # type: ignore
)

@property
def supported_table_args(self):
    """Return the table argument names implied by the schema's nesting depth.

    The result is computed on first access with a non-empty schema and then
    stored in ``self._supported_table_args``; later accesses return the
    stored tuple as long as it is non-empty.

    Returns:
        A prefix of ``TABLE_ARGS`` of length ``depth - 1`` for depths 2..4,
        or an empty tuple for an empty / depth-1 schema.

    Raises:
        SchemaError: if the schema depth is outside the supported range.
    """
    # Already resolved, or nothing to inspect yet: hand back whatever is stored.
    if self._supported_table_args or not self.schema:
        return self._supported_table_args

    depth = _dict_depth(self.schema)

    if not depth or depth == 1:  # {}
        resolved = tuple()
    elif 2 <= depth <= 4:
        resolved = TABLE_ARGS[: depth - 1]
    else:
        raise SchemaError(f"Invalid schema shape. Depth: {depth}")

    self._supported_table_args = resolved
    return self._supported_table_args

def add_table(
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
Expand All @@ -110,81 +132,81 @@ def add_table(
table: the `Table` expression instance or string representing the table.
column_mapping: a column mapping that describes the structure of the table.
"""
table = exp.to_table(table) # type: ignore
self._validate_table(table) # type: ignore
table_ = self._ensure_table(table)
column_mapping = ensure_column_mapping(column_mapping)
table_args = [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)] # type: ignore
existing_column_mapping = _nested_get(
self.schema, *zip(self.supported_table_args, table_args), raise_on_missing=False
)
schema = self.find_schema(table_, raise_on_missing=False)

if existing_column_mapping and not column_mapping:
if schema and not column_mapping:
return

_nested_set(
self.schema,
[table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)], # type: ignore
list(reversed(self.table_parts(table_))),
column_mapping,
)
self._initialize_supported_args()
self.schema_trie = self._build_trie(self.schema)

def _get_table_args_from_table(self, table: exp.Table) -> t.Tuple[str, ...]:
if table.args.get("catalog") is not None:
return "catalog", "db", "this"
if table.args.get("db") is not None:
return "db", "this"
return ("this",)
def _ensure_table(self, table: exp.Table | str) -> exp.Table:
table_ = exp.to_table(table)

def _validate_table(self, table: exp.Table) -> None:
if not self.supported_table_args and isinstance(table, exp.Table):
return
for forbidden in self.forbidden_table_args:
if table.text(forbidden):
raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
for expected in self.supported_table_args:
if not table.text(expected):
raise ValueError(f"Table is expected to have {expected}. Received: {table.sql()} ")
if not table_:
raise SchemaError(f"Not a valid table '{table}'")

return table_

def table_parts(self, table: exp.Table) -> t.List[str]:
    """Collect the non-empty name parts of ``table``.

    Parts are read via ``table.text`` in ``TABLE_ARGS`` order and any part
    whose text is empty is left out.
    """
    parts: t.List[str] = []
    for arg in TABLE_ARGS:
        if not table.text(arg):
            continue
        parts.append(table.text(arg))
    return parts

def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
table = exp.to_table(table) # type: ignore
if not isinstance(table.this, exp.Identifier): # type: ignore
table_ = self._ensure_table(table)

if not isinstance(table_.this, exp.Identifier):
return fs_get(table) # type: ignore

args = tuple(table.text(p) for p in self.supported_table_args) # type: ignore
schema = self.find_schema(table_)

for forbidden in self.forbidden_table_args:
if table.text(forbidden): # type: ignore
raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}") # type: ignore
if schema is None:
raise SchemaError(f"Could not find table schema {table}")

columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args))) # type: ignore
if not only_visible or not self.visible:
return columns
return list(schema)

visible_schema = self.find_schema(table_, trie=self.visible_trie)
return [col for col in schema if col in visible_schema] # type: ignore

visible = _nested_get(self.visible, *zip(self.supported_table_args, args)) # type: ignore
return [col for col in columns if col in visible] # type: ignore
def find_schema(
    self, table: exp.Table, trie=None, raise_on_missing=True
) -> t.Optional[t.Dict[str, str]]:
    """Resolve ``table`` to its entry in the schema mapping.

    The table's parts (truncated to the number of supported table args) are
    looked up in ``trie`` (falling back to ``self.schema_trie`` when ``trie``
    is falsy). Depending on the lookup result:

    * ``0``: nothing matched -> error or ``None``.
    * ``1``: the remaining sub-trie is flattened; if exactly one path is
      left it is appended to the parts, otherwise the lookup is ambiguous
      -> error or ``None``.
    * anything else: the parts are used as they are.

    NOTE(review): the final lookup goes through ``self._nested_get`` without
    a mapping argument, so it always reads from ``self.schema`` even when a
    different ``trie`` was supplied -- confirm this is intended by callers.

    Args:
        table: the table expression to resolve.
        trie: optional trie to search instead of ``self.schema_trie``.
        raise_on_missing: raise instead of returning ``None`` on failure.

    Raises:
        SchemaError: when no schema or an ambiguous schema is found and
            ``raise_on_missing`` is true.
    """
    parts = self.table_parts(table)[0 : len(self.supported_table_args)]
    value, subtrie = in_trie(trie or self.schema_trie, parts)

    if value == 0:
        if not raise_on_missing:
            return None
        raise SchemaError(f"Cannot find schema for {table}.")

    if value == 1:
        candidates = flatten_schema(subtrie)
        if len(candidates) != 1:
            # Built before the flag check, exactly as the lookup has always done.
            message = ", ".join(".".join(candidate) for candidate in candidates)
            if raise_on_missing:
                raise SchemaError(f"Ambiguous schema for {table}: {message}.")
            return None
        parts.extend(candidates[0])

    return self._nested_get(parts, raise_on_missing=raise_on_missing)

def get_column_type(
self, table: exp.Table | str, column: exp.Column | str
) -> exp.DataType.Type:
try:
column_name = column if isinstance(column, str) else column.name
if isinstance(table, exp.Table):
supported_table_args = self.supported_table_args or self._get_table_args_from_table(
table
)
table_args = [table.text(p) for p in supported_table_args]
table_schema = _nested_get(
self.schema,
*zip(supported_table_args, table_args),
raise_on_missing=False,
)
else:
table_schema = self.schema.get(table, {})

column_name = column if isinstance(column, str) else column.name
table_ = exp.to_table(table)
if table_:
table_schema = self.find_schema(table_)
schema_type = table_schema.get(column_name).upper() # type: ignore
return self._convert_type(schema_type)
except:
raise OptimizeError(f"Failed to get type for column {column_name}")
raise SchemaError(f"Could not convert table '{table}'")

def _convert_type(self, schema_type: str) -> exp.DataType.Type:
"""
Expand All @@ -202,23 +224,21 @@ def _convert_type(self, schema_type: str) -> exp.DataType.Type:
schema_type, into=exp.DataType, dialect=self.dialect
).this
except AttributeError:
raise OptimizeError(f"Failed to convert type {schema_type}")
raise SchemaError(f"Failed to convert type {schema_type}")

return self._type_mapping_cache[schema_type]

def _initialize_supported_args(self) -> None:
if not self.supported_table_args:
depth = _dict_depth(self.schema)

all_args = ["this", "db", "catalog"]
if not depth or depth == 1: # {}
self.supported_table_args = []
elif 2 <= depth <= 4:
self.supported_table_args = tuple(reversed(all_args[: depth - 1]))
else:
raise OptimizeError(f"Invalid schema shape. Depth: {depth}")
def _build_trie(self, schema):
    """Build a trie from ``schema`` keyed by each flattened path in reverse order."""
    table_paths = flatten_schema(schema)
    return new_trie(reversed(path) for path in table_paths)

self.forbidden_table_args = {"catalog", "db", "this"} - set(self.supported_table_args)
def _nested_get(
    self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
) -> t.Optional[t.Any]:
    """Look up ``parts`` in a nested mapping via the module-level ``_nested_get``.

    Args:
        parts: table name parts; they are reversed before being paired with
            ``self.supported_table_args``.
        d: mapping to search; ``self.schema`` is used when this is falsy.
        raise_on_missing: forwarded unchanged to the module-level helper.
    """
    source = d or self.schema
    path = zip(self.supported_table_args, reversed(parts))
    return _nested_get(source, *path, raise_on_missing=raise_on_missing)


def ensure_schema(schema: t.Any) -> Schema:
Expand Down Expand Up @@ -247,6 +267,19 @@ def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
raise ValueError(f"Invalid mapping provided: {type(mapping)}")


def flatten_schema(schema, keys=None):
    """Flatten a nested schema mapping into a list of key paths.

    The depth reported by ``_dict_depth`` decides how each entry is handled:
    at depth 3 or more the value is flattened recursively with the current
    key appended to the prefix; at depth exactly 2 the path (prefix plus
    key) is recorded; at lower depths nothing is recorded.

    Args:
        schema: the mapping to flatten.
        keys: path prefix accumulated by enclosing calls.

    Returns:
        A list of key paths, each a list.
    """
    prefix = keys or []
    depth = _dict_depth(schema)
    paths = []

    for name, child in schema.items():
        path = prefix + [name]
        if depth >= 3:
            paths.extend(flatten_schema(child, path))
        elif depth == 2:
            paths.append(path)

    return paths


def fs_get(table: exp.Table) -> t.List[str]:
name = table.this.name

Expand Down
6 changes: 4 additions & 2 deletions sqlglot/trie.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import typing as t

key = t.Sequence[t.Hashable]

def new_trie(keywords: t.Iterable[str]) -> t.Dict:

def new_trie(keywords: t.Iterable[key]) -> t.Dict:
"""
Creates a new trie out of a collection of keywords.
Expand Down Expand Up @@ -30,7 +32,7 @@ def new_trie(keywords: t.Iterable[str]) -> t.Dict:
return trie


def in_trie(trie: t.Dict, key: str) -> t.Tuple[int, t.Dict]:
def in_trie(trie: t.Dict, key: key) -> t.Tuple[int, t.Dict]:
"""
Checks whether a key is in a trie.
Expand Down
Loading

0 comments on commit 55bae2e

Please sign in to comment.