Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow connection objects in from_sql + small doc improvement #5091

Merged
merged 7 commits into from
Oct 9, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion docs/source/loading.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,19 @@ For example, a table from a SQLite file can be loaded with:
Use a query for a more precise read:

```py
>>> from sqlite3 import connect
>>> con = connect(":memory")
>>> # db writes ...
>>> from datasets import Dataset
>>> dataset = Dataset.from_sql("SELECT text FROM data_table WHERE length(text) > 100 LIMIT 10", "sqlite:///sqlite_file.db")
>>> dataset = Dataset.from_sql("SELECT text FROM table WHERE length(text) > 100 LIMIT 10", con)
```

<Tip>

To cache the read, specify [`Dataset.from_sql#con`] as a [URI string](https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls).
mariosasko marked this conversation as resolved.
Show resolved Hide resolved

</Tip>

## In-memory data

🤗 Datasets will also allow you to create a [`Dataset`] directly from in-memory data structures like Python dictionaries and Pandas DataFrames.
Expand Down
11 changes: 6 additions & 5 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,7 +1104,7 @@ def from_text(
@staticmethod
def from_sql(
sql: Union[str, "sqlalchemy.sql.Selectable"],
con: str,
con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"],
features: Optional[Features] = None,
cache_dir: str = None,
keep_in_memory: bool = False,
Expand All @@ -1114,7 +1114,8 @@ def from_sql(

Args:
sql (`str` or :obj:`sqlalchemy.sql.Selectable`): SQL query to be executed or a table name.
con (`str`): A connection URI string used to instantiate a database connection.
con (`str` or :obj:`sqlite3.Connection` or :obj:`sqlalchemy.engine.Connection` or :obj:`sqlalchemy.engine.Connection`):
A [URI string](https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls) used to instantiate a database connection or a SQLite3/SQLAlchemy connection object.
features (:class:`Features`, optional): Dataset features.
cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data.
keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory.
Expand All @@ -1137,7 +1138,7 @@ def from_sql(
```

<Tip {warning=true}>
`sqlalchemy` needs to be installed to use this function.
The returned dataset can only be cached if `con` is specified as URI string.
</Tip>
"""
from .io.sql import SqlDatasetReader
Expand Down Expand Up @@ -4218,8 +4219,8 @@ def to_sql(

Args:
name (`str`): Name of SQL table.
con (`str` or `sqlite3.Connection` or `sqlalchemy.engine.Connection` or `sqlalchemy.engine.Connection`):
A database connection URI string or an existing SQLite3/SQLAlchemy connection used to write to a database.
con (`str` or :obj:`sqlite3.Connection` or :obj:`sqlalchemy.engine.Connection` or :obj:`sqlalchemy.engine.Connection`):
A [URI string](https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls) or a SQLite3/SQLAlchemy connection object used to write to a database.
batch_size (:obj:`int`, optional): Size of the batch to load in memory and write at once.
Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
**sql_writer_kwargs (additional keyword arguments): Parameters to pass to pandas's :function:`Dataframe.to_sql`
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class SqlDatasetReader(AbstractDatasetInputStream):
def __init__(
self,
sql: Union[str, "sqlalchemy.sql.Selectable"],
con: str,
con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"],
features: Optional[Features] = None,
cache_dir: str = None,
keep_in_memory: bool = False,
Expand Down
18 changes: 14 additions & 4 deletions src/datasets/packaged_modules/sql/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,20 @@


if TYPE_CHECKING:
import sqlite3

import sqlalchemy


logger = datasets.utils.logging.get_logger(__name__)


@dataclass
class SqlConfig(datasets.BuilderConfig):
"""BuilderConfig for SQL."""

sql: Union[str, "sqlalchemy.sql.Selectable"] = None
con: str = None
con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"] = None
index_col: Optional[Union[str, List[str]]] = None
coerce_float: bool = True
params: Optional[Union[List, Tuple, Dict]] = None
Expand All @@ -34,14 +39,13 @@ def __post_init__(self):
raise ValueError("sql must be specified")
if self.con is None:
raise ValueError("con must be specified")
if not isinstance(self.con, str):
raise ValueError(f"con must be a database URI string, but got {self.con} with type {type(self.con)}.")

def create_config_id(
self,
config_kwargs: dict,
custom_features: Optional[datasets.Features] = None,
) -> str:
config_kwargs = config_kwargs.copy()
# We need to stringify the Selectable object to make its hash deterministic

# The process of stringifying is explained here: http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html
Expand All @@ -51,7 +55,6 @@ def create_config_id(
import sqlalchemy

if isinstance(sql, sqlalchemy.sql.Selectable):
config_kwargs = config_kwargs.copy()
engine = sqlalchemy.create_engine(config_kwargs["con"].split("://")[0] + "://")
sql_str = str(sql.compile(dialect=engine.dialect))
config_kwargs["sql"] = sql_str
Expand All @@ -63,6 +66,13 @@ def create_config_id(
raise TypeError(
f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}"
)
con = config_kwargs["con"]
if not isinstance(con, str):
config_kwargs["con"] = id(con)
logger.info(
f"'con' of type {type(con)} couldn't be hashed properly. To enable hashing, specify 'con' as URI string instead."
mariosasko marked this conversation as resolved.
Show resolved Hide resolved
)

return super().create_config_id(config_kwargs, custom_features=custom_features)

@property
Expand Down
30 changes: 30 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3356,6 +3356,36 @@ def _check_sql_dataset(dataset, expected_features):
assert dataset.features[feature].dtype == expected_dtype


@require_sqlalchemy
@pytest.mark.parametrize("con_type", ["string", "engine"])
def test_dataset_from_sql_con_type(con_type, sqlite_path, tmp_path):
cache_dir = tmp_path / "cache"
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
if con_type == "string":
con = "sqlite:///" + sqlite_path
elif con_type == "engine":
import sqlalchemy

con = sqlalchemy.create_engine("sqlite:///" + sqlite_path)
# # https://github.com/huggingface/datasets/issues/2832 needs to be fixed first for this to work
# with caplog.at_level(INFO):
# dataset = Dataset.from_sql(
# "dataset",
# con,
# cache_dir=cache_dir,
# )
# if con_type == "string":
# assert "couldn't be hashed properly" not in caplog.text
# elif con_type == "engine":
# assert "couldn't be hashed properly" in caplog.text
dataset = Dataset.from_sql(
"dataset",
con,
cache_dir=cache_dir,
)
_check_sql_dataset(dataset, expected_features)


@require_sqlalchemy
@pytest.mark.parametrize(
"features",
Expand Down