Skip to content

Commit

Permalink
[FEAT] [SQL] Enable SQL query to run on caller's scoped variables (#2864)
Browse files Browse the repository at this point in the history
This builds a catalog from the Python globals and locals visible to
the caller at the point where the `sql` query function is called. If a
catalog is also supplied, its tables are final: they take precedence over
caller variables of the same name. With `register_globals=False` no caller
variables are used, so the supplied catalog must contain all necessary tables.

resolves #2740
  • Loading branch information
amitschang authored Sep 23, 2024
1 parent 29be743 commit fd42281
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 1 deletion.
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1277,6 +1277,7 @@ class PyCatalog:
@staticmethod
def new() -> PyCatalog: ...
def register_table(self, name: str, logical_plan_builder: LogicalPlanBuilder) -> None: ...
def copy_from(self, other: PyCatalog) -> None: ...

class PySeries:
@staticmethod
Expand Down
39 changes: 38 additions & 1 deletion daft/sql/sql.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
# isort: dont-add-import: from __future__ import annotations

import inspect
from typing import Optional, overload

from daft.api_annotations import PublicAPI
from daft.context import get_context
from daft.daft import PyCatalog as _PyCatalog
from daft.daft import sql as _sql
from daft.daft import sql_expr as _sql_expr
from daft.dataframe import DataFrame
from daft.exceptions import DaftCoreException
from daft.expressions import Expression
from daft.logical.builder import LogicalPlanBuilder

Expand All @@ -28,24 +32,57 @@ def __init__(self, tables: dict) -> None:
def __str__(self) -> str:
return str(self._catalog)

def _copy_from(self, other: "SQLCatalog") -> None:
    """Merge the tables registered in ``other`` into this catalog, in place.

    Delegates to ``copy_from`` on the wrapped native catalog, which is
    documented to prefer the tables from ``other`` when a name exists in both.

    Args:
        other (SQLCatalog): Catalog whose tables are copied into this one.
    """
    source_catalog = other._catalog
    self._catalog.copy_from(source_catalog)


@PublicAPI
def sql_expr(sql: str) -> Expression:
    """Create an Expression from a SQL expression string.

    The string is handed to the native ``sql_expr`` binding and the returned
    handle is wrapped in a Python-level ``Expression``.

    Args:
        sql (str): SQL expression text to convert.

    Returns:
        Expression: Expression wrapping the native result.
    """
    py_expr = _sql_expr(sql)
    return Expression._from_pyexpr(py_expr)


@overload
def sql(sql: str) -> DataFrame: ...


@overload
def sql(sql: str, catalog: SQLCatalog, register_globals: bool = ...) -> DataFrame: ...


@PublicAPI
def sql(sql: str, catalog: SQLCatalog) -> DataFrame:
def sql(sql: str, catalog: Optional[SQLCatalog] = None, register_globals: bool = True) -> DataFrame:
"""Create a DataFrame from an SQL query.
EXPERIMENTAL: This feature is early in development and will change.
Args:
sql (str): SQL query to execute
catalog (SQLCatalog, optional): Catalog of tables to use in the query.
Defaults to None, in which case a catalog will be built from variables
in the caller's scope.
register_globals (bool, optional): Whether to incorporate the caller's
global and local variables into the supplied catalog, in which case a copy
of the catalog will be made and the original not modified. Defaults to True.
Returns:
DataFrame: Dataframe containing the results of the query
"""
if register_globals:
try:
# Caller is back from func, analytics, annotation
caller_frame = inspect.currentframe().f_back.f_back.f_back # type: ignore
caller_vars = {**caller_frame.f_globals, **caller_frame.f_locals} # type: ignore
except AttributeError as exc:
# some interpreters might not implement currentframe; all reasonable
# errors above should be AttributeError
raise DaftCoreException("Cannot get caller environment, please provide a catalog") from exc
catalog_ = SQLCatalog({k: v for k, v in caller_vars.items() if isinstance(v, DataFrame)})
if catalog is not None:
catalog_._copy_from(catalog)
catalog = catalog_
elif catalog is None:
raise DaftCoreException("Must supply a catalog if register_globals is False")

planning_config = get_context().daft_planning_config

_py_catalog = catalog._catalog
Expand Down
7 changes: 7 additions & 0 deletions src/daft-sql/src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ impl SQLCatalog {
pub fn get_table(&self, name: &str) -> Option<LogicalPlanRef> {
self.tables.get(name).cloned()
}

/// Merges every table registered in `other` into this catalog.
///
/// Entries are cloned, so `other` is left untouched. When a table name is
/// present in both catalogs, the plan from `other` replaces the one in `self`.
pub fn copy_from(&mut self, other: &SQLCatalog) {
    let incoming = other
        .tables
        .iter()
        .map(|(table_name, table_plan)| (table_name.clone(), table_plan.clone()));
    self.tables.extend(incoming);
}
}

impl Default for SQLCatalog {
Expand Down
5 changes: 5 additions & 0 deletions src/daft-sql/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ impl PyCatalog {
self.catalog.register_table(name, plan);
}

/// Copy from another catalog, using tables from other in case of conflict.
///
/// Thin wrapper around the inner catalog: forwards to its `copy_from`,
/// passing a borrow of `other`'s inner catalog. Only `self` is mutated.
pub fn copy_from(&mut self, other: &PyCatalog) {
self.catalog.copy_from(&other.catalog);
}

/// __str__ to print the catalog's tables
fn __str__(&self) -> String {
format!("{:?}", self.catalog)
Expand Down
50 changes: 50 additions & 0 deletions tests/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

import daft
from daft.exceptions import DaftCoreException
from daft.sql.sql import SQLCatalog
from tests.assets import TPCH_QUERIES

Expand Down Expand Up @@ -149,3 +150,52 @@ def test_sql_count_star():
actual = df2.collect().to_pydict()
expected = df.agg(daft.col("b").count()).collect().to_pydict()
assert actual == expected


# Module-level DataFrame used by the tests below to check how daft.sql picks
# up (or ignores) variables from the caller's global scope. The variable name
# doubles as the SQL table name in the queries, so it must not be renamed.
GLOBAL_DF = daft.from_pydict({"n": [1, 2, 3]})


def test_sql_function_sees_caller_tables():
# With no catalog argument, daft.sql resolves table names against DataFrame
# variables visible from this function; the table names in the queries must
# therefore match the Python variable names exactly.
# sees the globals
df = daft.sql("SELECT * FROM GLOBAL_DF")
assert df.collect().to_pydict() == GLOBAL_DF.collect().to_pydict()
# sees the locals
df_copy = daft.sql("SELECT * FROM df")
assert df.collect().to_pydict() == df_copy.collect().to_pydict()


def test_sql_function_locals_shadow_globals():
# The local binding overrides the module-level GLOBAL_DF when caller
# variables are collected (locals are merged over globals). Because the local
# value is not a DataFrame it is filtered out, so no table with that name is
# registered and the query is expected to fail.
GLOBAL_DF = None # noqa: F841
with pytest.raises(Exception, match="Table not found"):
daft.sql("SELECT * FROM GLOBAL_DF")


def test_sql_function_globals_are_added_to_catalog():
# register_globals defaults to True, so even with an explicit catalog the
# module-level GLOBAL_DF is still resolvable next to the catalog's own "df".
df = daft.from_pydict({"n": [1], "x": [2]})
res = daft.sql("SELECT * FROM GLOBAL_DF g JOIN df d USING (n)", catalog=SQLCatalog({"df": df}))
joined = GLOBAL_DF.join(df, on="n")
assert res.collect().to_pydict() == joined.collect().to_pydict()


def test_sql_function_catalog_is_final():
# When a catalog table and a caller variable share a name, the catalog entry
# must win: the query has to return the catalog's data, not GLOBAL_DF's.
df = daft.from_pydict({"a": [1]})
# sanity check to ensure validity of below test
assert df.collect().to_pydict() != GLOBAL_DF.collect().to_pydict()
res = daft.sql("SELECT * FROM GLOBAL_DF", catalog=SQLCatalog({"GLOBAL_DF": df}))
assert res.collect().to_pydict() == df.collect().to_pydict()


def test_sql_function_register_globals():
# With register_globals=False only the supplied (here empty) catalog is used,
# so the module-level GLOBAL_DF must not be visible to the query.
with pytest.raises(Exception, match="Table not found"):
daft.sql("SELECT * FROM GLOBAL_DF", SQLCatalog({}), register_globals=False)


def test_sql_function_requires_catalog_or_globals():
# Disabling register_globals without passing a catalog leaves no source of
# tables at all; daft.sql is expected to reject the call up front.
with pytest.raises(Exception, match="Must supply a catalog"):
daft.sql("SELECT * FROM GLOBAL_DF", register_globals=False)


def test_sql_function_raises_when_cant_get_frame(monkeypatch):
# Simulates an interpreter without stack-frame support: currentframe()
# returning None makes the frame attribute lookups fail, which daft.sql is
# expected to surface as a DaftCoreException asking for an explicit catalog.
monkeypatch.setattr("inspect.currentframe", lambda: None)
with pytest.raises(DaftCoreException, match="Cannot get caller environment"):
daft.sql("SELECT * FROM df")

0 comments on commit fd42281

Please sign in to comment.