[FEAT] Allow user provided schema and schema inference length for read_sql #2676

Merged 2 commits on Aug 20, 2024
1 change: 1 addition & 0 deletions daft/daft.pyi
@@ -1035,6 +1035,7 @@ class PySchema:
def _repr_html_(self) -> str: ...
def _truncated_table_html(self) -> str: ...
def _truncated_table_string(self) -> str: ...
def apply_hints(self, hints: PySchema) -> PySchema: ...

class PyExpr:
def alias(self, name: str) -> PyExpr: ...
18 changes: 17 additions & 1 deletion daft/io/_sql.py
@@ -1,12 +1,13 @@
# isort: dont-add-import: from __future__ import annotations


from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union

from daft import context, from_pydict
from daft.api_annotations import PublicAPI
from daft.daft import PythonStorageConfig, ScanOperatorHandle, StorageConfig
from daft.dataframe import DataFrame
from daft.datatype import DataType
from daft.logical.builder import LogicalPlanBuilder
from daft.sql.sql_connection import SQLConnection
from daft.sql.sql_scan import SQLScanOperator
@@ -22,6 +23,9 @@ def read_sql(
partition_col: Optional[str] = None,
num_partitions: Optional[int] = None,
disable_pushdowns_to_sql: bool = False,
infer_schema: bool = True,
infer_schema_length: int = 10,
schema: Optional[Dict[str, DataType]] = None,
) -> DataFrame:
"""Create a DataFrame from the results of a SQL query.

@@ -32,6 +36,10 @@
num_partitions (Optional[int]): Number of partitions to read the data into,
defaults to None, which lets Daft determine the number of partitions.
disable_pushdowns_to_sql (bool): Whether to disable pushdowns to the SQL query, defaults to False
infer_schema (bool): Whether to turn on schema inference, defaults to True. If set to False, the schema parameter must be provided.
infer_schema_length (int): The number of rows to scan when inferring the schema, defaults to 10. If infer_schema is False, this parameter is ignored. It is also ignored when Daft is able to use ConnectorX to infer the schema, as ConnectorX is an Arrow-backed driver.
schema (Optional[Dict[str, DataType]]): A mapping of column names to datatypes. If infer_schema is False, this schema is used as the definitive schema for the data, otherwise it is used as a schema hint that is applied after the schema is inferred.
This can be useful if the types can be more precisely determined than what the inference can provide (e.g., if a column can be declared as a fixed-size list rather than a list).

Returns:
DataFrame: Dataframe containing the results of the query
@@ -86,6 +94,11 @@ def read_sql(
if num_partitions is not None and partition_col is None:
raise ValueError("Failed to execute sql: partition_col must be specified when num_partitions is specified")

if not infer_schema and schema is None:
raise ValueError(
"Cannot read DataFrame with infer_schema=False and schema=None, please provide a schema or set infer_schema=True"
)

io_config = context.get_context().daft_planning_config.default_io_config
storage_config = StorageConfig.python(PythonStorageConfig(io_config))

Expand All @@ -95,6 +108,9 @@ def read_sql(
sql_conn,
storage_config,
disable_pushdowns_to_sql,
infer_schema,
infer_schema_length,
schema,
partition_col=partition_col,
num_partitions=num_partitions,
)
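For context, a minimal usage sketch of the new parameters. The connection URL, table name, and column types below are illustrative, not taken from this PR; only the keyword arguments are new here.

```python
import daft
from daft.datatype import DataType

# Provide the schema yourself and skip inference entirely:
df = daft.read_sql(
    "SELECT * FROM events",
    "postgresql://user:pass@localhost:5432/mydb",
    infer_schema=False,
    schema={"id": DataType.int64(), "name": DataType.string()},
)

# Or keep inference on, widen the sample, and hint a single column
# on top of the inferred schema:
df_hinted = daft.read_sql(
    "SELECT * FROM events",
    "postgresql://user:pass@localhost:5432/mydb",
    infer_schema_length=100,
    schema={"id": DataType.int64()},
)
```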
3 changes: 3 additions & 0 deletions daft/logical/schema.py
@@ -142,6 +142,9 @@ def _truncated_table_html(self) -> str:
def _truncated_table_string(self) -> str:
return self._schema._truncated_table_string()

def apply_hints(self, hints: Schema) -> Schema:
return Schema._from_pyschema(self._schema.apply_hints(hints._schema))

def union(self, other: Schema) -> Schema:
if not isinstance(other, Schema):
raise ValueError(f"Expected Schema, got other: {type(other)}")
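Schema.apply_hints is the Python surface for the hint behaviour described in the read_sql docstring. A small sketch, assuming `_get_schema_from_dict` (the internal helper imported in sql_scan.py below) builds a Schema from a name-to-dtype mapping and that hints override the dtype of matching columns:

```python
from daft.datatype import DataType
from daft.io.common import _get_schema_from_dict  # internal helper, also used in sql_scan.py below

inferred = _get_schema_from_dict({"id": DataType.int64(), "created_at": DataType.string()})
hints = _get_schema_from_dict({"created_at": DataType.date()})

# Columns present in both schemas take the dtype from `hints`; other columns are unchanged.
resolved = inferred.apply_hints(hints)
print(resolved)  # expected: id stays Int64, created_at becomes Date
```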
73 changes: 68 additions & 5 deletions daft/sql/sql_connection.py
@@ -6,6 +6,8 @@

import pyarrow as pa

from daft.logical.schema import Schema

if TYPE_CHECKING:
from sqlalchemy.engine import Connection

@@ -48,10 +50,68 @@ def from_connection_factory(cls, conn_factory: Callable[[], Connection]) -> SQLC
except Exception as e:
raise ValueError(f"Unexpected error while calling the connection factory: {e}") from e

def read(self, sql: str) -> pa.Table:
def read_schema(self, sql: str, infer_schema_length: int) -> Schema:
if self._should_use_connectorx():
sql = self.construct_sql_query(sql, limit=0)
else:
sql = self.construct_sql_query(sql, limit=infer_schema_length)
table = self._execute_sql_query(sql)
schema = Schema.from_pyarrow_schema(table.schema)
return schema

def read(
self,
sql: str,
projection: list[str] | None = None,
limit: int | None = None,
predicate: str | None = None,
partition_bounds: tuple[str, str] | None = None,
) -> pa.Table:
sql = self.construct_sql_query(sql, projection, predicate, limit, partition_bounds)
return self._execute_sql_query(sql)

def _execute_sql_query(self, sql: str) -> pa.Table:
def construct_sql_query(
self,
sql: str,
projection: list[str] | None = None,
predicate: str | None = None,
limit: int | None = None,
partition_bounds: tuple[str, str] | None = None,
) -> str:
import sqlglot

target_dialect = self.dialect
# sqlglot does not support "postgresql" dialect, it only supports "postgres"
if target_dialect == "postgresql":
target_dialect = "postgres"
# sqlglot does not recognize "mssql" as a dialect, it instead recognizes "tsql", which is the SQL dialect for Microsoft SQL Server
elif target_dialect == "mssql":
target_dialect = "tsql"

if not any(target_dialect == supported_dialect.value for supported_dialect in sqlglot.Dialects):
raise ValueError(
f"Unsupported dialect: {target_dialect}, please refer to the documentation for supported dialects."
)

query = sqlglot.subquery(sql, "subquery")

if projection is not None:
query = query.select(*projection)
else:
query = query.select("*")

if predicate is not None:
query = query.where(predicate)

if partition_bounds is not None:
query = query.where(partition_bounds[0]).where(partition_bounds[1])

if limit is not None:
query = query.limit(limit)

return query.sql(dialect=target_dialect)

def _should_use_connectorx(self) -> bool:
# Supported DBs extracted from here https://github.com/sfu-db/connector-x/tree/7b3147436b7e20b96691348143d605e2249d6119?tab=readme-ov-file#sources
connectorx_supported_dbs = {
"postgres",
@@ -67,9 +127,12 @@ def _execute_sql_query(self, sql: str) -> pa.Table:

if isinstance(self.conn, str):
if self.dialect in connectorx_supported_dbs and self.driver == "":
return self._execute_sql_query_with_connectorx(sql)
else:
return self._execute_sql_query_with_sqlalchemy(sql)
return True
return False

def _execute_sql_query(self, sql: str) -> pa.Table:
if self._should_use_connectorx():
return self._execute_sql_query_with_connectorx(sql)
else:
return self._execute_sql_query_with_sqlalchemy(sql)

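construct_sql_query builds the final query with sqlglot by wrapping the user's SQL in a subquery and layering projection, predicate, and limit on top. A standalone sketch of that construction; the printed output is approximate:

```python
import sqlglot

# Wrap the user's query as a subquery, then push selections down onto it,
# mirroring SQLConnection.construct_sql_query above.
query = (
    sqlglot.subquery("SELECT * FROM events", "subquery")
    .select("id", "value")
    .where("value > 0")
    .limit(10)
)
print(query.sql(dialect="postgres"))
# Roughly: SELECT id, value FROM (SELECT * FROM events) AS subquery WHERE value > 0 LIMIT 10
```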
87 changes: 31 additions & 56 deletions daft/sql/sql_scan.py
@@ -16,7 +16,9 @@
ScanTask,
StorageConfig,
)
from daft.datatype import DataType
from daft.expressions.expressions import lit
from daft.io.common import _get_schema_from_dict
from daft.io.scan import PartitionField, ScanOperator
from daft.logical.schema import Schema
from daft.sql.sql_connection import SQLConnection
@@ -37,6 +39,9 @@ def __init__(
conn: SQLConnection,
storage_config: StorageConfig,
disable_pushdowns_to_sql: bool,
infer_schema: bool,
infer_schema_length: int,
schema: dict[str, DataType] | None,
partition_col: str | None = None,
num_partitions: int | None = None,
) -> None:
@@ -47,7 +52,7 @@
self._disable_pushdowns_to_sql = disable_pushdowns_to_sql
self._partition_col = partition_col
self._num_partitions = num_partitions
self._schema = self._attempt_schema_read()
self._schema = self._attempt_schema_read(infer_schema, infer_schema_length, schema)

def schema(self) -> Schema:
return self._schema
@@ -106,11 +111,21 @@ def can_absorb_limit(self) -> bool:
def can_absorb_select(self) -> bool:
return False

def _attempt_schema_read(self) -> Schema:
sql = self._construct_sql_query(limit=1)
pa_table = self.conn.read(sql)
schema = Schema.from_pyarrow_schema(pa_table.schema)
return schema
def _attempt_schema_read(
self,
infer_schema: bool,
infer_schema_length: int,
schema: dict[str, DataType] | None,
) -> Schema:
# If schema is provided and user turned off schema inference, use the provided schema
if schema is not None and not infer_schema:
return _get_schema_from_dict(schema)

# Else, attempt schema inference then apply the schema hint if provided
inferred_schema = self.conn.read_schema(self.sql, infer_schema_length)
if schema is not None:
return inferred_schema.apply_hints(_get_schema_from_dict(schema))
return inferred_schema

def _get_size_estimates(self) -> tuple[int, float, int]:
total_rows = self._get_num_rows()
@@ -124,8 +139,7 @@ def _get_size_estimates(self) -> tuple[int, float, int]:
return total_rows, total_size, num_scan_tasks

def _get_num_rows(self) -> int:
sql = self._construct_sql_query(projection=["COUNT(*)"])
pa_table = self.conn.read(sql)
pa_table = self.conn.read(self.sql, projection=["COUNT(*)"])

if pa_table.num_rows != 1:
raise RuntimeError(
@@ -142,13 +156,13 @@ def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, Part
try:
# Try to get percentiles using percentile_cont
percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]
sql = self._construct_sql_query(
pa_table = self.conn.read(
self.sql,
projection=[
f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}"
for i, percentile in enumerate(percentiles)
]
],
)
pa_table = self.conn.read(sql)
return pa_table, PartitionBoundStrategy.PERCENTILE

except RuntimeError as e:
@@ -158,13 +172,13 @@
e,
)

sql = self._construct_sql_query(
pa_table = self.conn.read(
self.sql,
projection=[
f"MIN({self._partition_col}) AS min",
f"MAX({self._partition_col}) AS max",
]
],
)
pa_table = self.conn.read(sql)
return pa_table, PartitionBoundStrategy.MIN_MAX

def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[Any], PartitionBoundStrategy]:
@@ -226,14 +240,15 @@ def _construct_scan_task(
)

if apply_pushdowns_to_sql:
sql = self._construct_sql_query(
sql = self.conn.construct_sql_query(
self.sql,
projection=pushdowns.columns,
predicate=predicate_sql,
limit=pushdowns.limit,
partition_bounds=partition_bounds,
)
else:
sql = self._construct_sql_query(partition_bounds=partition_bounds)
sql = self.conn.construct_sql_query(self.sql, partition_bounds=partition_bounds)

file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(sql, self.conn))
return ScanTask.sql_scan_task(
Expand All @@ -246,43 +261,3 @@ def _construct_scan_task(
pushdowns=pushdowns if not apply_pushdowns_to_sql else None,
stats=stats,
)

def _construct_sql_query(
self,
projection: list[str] | None = None,
predicate: str | None = None,
limit: int | None = None,
partition_bounds: tuple[str, str] | None = None,
) -> str:
import sqlglot

target_dialect = self.conn.dialect
# sqlglot does not support "postgresql" dialect, it only supports "postgres"
if target_dialect == "postgresql":
target_dialect = "postgres"
# sqlglot does not recognize "mssql" as a dialect, it instead recognizes "tsql", which is the SQL dialect for Microsoft SQL Server
elif target_dialect == "mssql":
target_dialect = "tsql"

if not any(target_dialect == supported_dialect.value for supported_dialect in sqlglot.Dialects):
raise ValueError(
f"Unsupported dialect: {target_dialect}, please refer to the documentation for supported dialects."
)

query = sqlglot.subquery(self.sql, "subquery")

if projection is not None:
query = query.select(*projection)
else:
query = query.select("*")

if predicate is not None:
query = query.where(predicate)

if partition_bounds is not None:
query = query.where(partition_bounds[0]).where(partition_bounds[1])

if limit is not None:
query = query.limit(limit)

return query.sql(dialect=target_dialect)
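The method removed above now lives on SQLConnection as construct_sql_query, so the scan operator simply delegates to the connection. A self-contained sketch of calling it directly, assuming a SQLAlchemy connection factory and an illustrative SQLite database; the query, columns, and bounds are made up for the example:

```python
from sqlalchemy import create_engine

from daft.sql.sql_connection import SQLConnection

# Hypothetical SQLite database; from_connection_factory and construct_sql_query
# are the APIs shown in the sql_connection.py diff above.
conn = SQLConnection.from_connection_factory(
    lambda: create_engine("sqlite:///my.db").connect()
)
sql = conn.construct_sql_query(
    "SELECT * FROM events",
    projection=["id", "value"],
    predicate="value > 0",
    limit=1000,
    partition_bounds=("id >= 0", "id < 5000"),
)
```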
5 changes: 5 additions & 0 deletions src/daft-core/src/python/schema.rs
@@ -92,6 +92,11 @@ impl PySchema
pub fn _truncated_table_string(&self) -> PyResult<String> {
Ok(self.schema.truncated_table_string())
}

pub fn apply_hints(&self, hints: &PySchema) -> PyResult<PySchema> {
let new_schema = Arc::new(self.schema.apply_hints(&hints.schema)?);
Ok(new_schema.into())
}
}

impl_bincode_py_state_serialization!(PySchema);
2 changes: 1 addition & 1 deletion tests/integration/sql/conftest.py
@@ -44,7 +44,7 @@ def generated_data(request: pytest.FixtureRequest) -> pd.DataFrame:
"time_col": [
(datetime.combine(datetime.today(), time(0, 0)) + timedelta(minutes=x)).time() for x in range(200)
],
"null_col": [None if i % 2 == 1 else "not_null" for i in range(num_rows)],
"null_col": [None if i % 2 == 0 else "not_null" for i in range(num_rows)],
"non_uniformly_distributed_col": [1 for _ in range(num_rows)],
}
return pd.DataFrame(data)