diff --git a/docs/source/loading.mdx b/docs/source/loading.mdx
index d3b0b8fbf31..16b44259efd 100644
--- a/docs/source/loading.mdx
+++ b/docs/source/loading.mdx
@@ -210,10 +210,19 @@ For example, a table from a SQLite file can be loaded with:
Use a query for a more precise read:
```py
+>>> from sqlite3 import connect
+>>> con = connect(":memory:")
+>>> # db writes ...
>>> from datasets import Dataset
->>> dataset = Dataset.from_sql("SELECT text FROM data_table WHERE length(text) > 100 LIMIT 10", "sqlite:///sqlite_file.db")
+>>> dataset = Dataset.from_sql("SELECT text FROM data_table WHERE length(text) > 100 LIMIT 10", con)
```
+
+You can specify `con` in [`Dataset.from_sql`] as a [URI string](https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls) so that the 🤗 Datasets caching works across sessions.
+
## In-memory data
🤗 Datasets will also allow you to create a [`Dataset`] directly from in-memory data structures like Python dictionaries and Pandas DataFrames.
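Below is a minimal, self-contained sketch of the two ways of passing `con` that this patch documents. The database file name, table name, and rows are made up for illustration, and the URI variant additionally requires `sqlalchemy` to be installed:

```py
import sqlite3

from datasets import Dataset

# Build a throwaway SQLite database to query (names and rows are illustrative).
con = sqlite3.connect("sqlite_file.db")
con.execute("CREATE TABLE IF NOT EXISTS data_table (text TEXT)")
con.executemany("INSERT INTO data_table VALUES (?)", [("a" * 150,), ("short",)])
con.commit()

# 1) Pass the connection object directly: convenient, but the resulting
#    dataset cannot be cached across sessions (see the config-id change below).
ds = Dataset.from_sql("SELECT text FROM data_table WHERE length(text) > 100", con)

# 2) Pass a URI string (requires `sqlalchemy`): lets 🤗 Datasets compute a
#    deterministic cache key that survives across sessions.
ds = Dataset.from_sql("SELECT text FROM data_table WHERE length(text) > 100", "sqlite:///sqlite_file.db")
```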
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index eda0b318128..17a4412693d 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1104,7 +1104,7 @@ def from_text(
@staticmethod
def from_sql(
sql: Union[str, "sqlalchemy.sql.Selectable"],
- con: str,
+ con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"],
features: Optional[Features] = None,
cache_dir: str = None,
keep_in_memory: bool = False,
@@ -1114,7 +1114,8 @@ def from_sql(
Args:
sql (`str` or :obj:`sqlalchemy.sql.Selectable`): SQL query to be executed or a table name.
- con (`str`): A connection URI string used to instantiate a database connection.
+ con (`str` or :obj:`sqlite3.Connection` or :obj:`sqlalchemy.engine.Connection` or :obj:`sqlalchemy.engine.Engine`):
+ A [URI string](https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls) used to instantiate a database connection, or an existing SQLite3/SQLAlchemy connection object.
features (:class:`Features`, optional): Dataset features.
cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data.
keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory.
@@ -1137,7 +1138,7 @@ def from_sql(
```
- `sqlalchemy` needs to be installed to use this function.
+ The returned dataset can only be cached if `con` is specified as a URI string.
"""
from .io.sql import SqlDatasetReader
@@ -4218,8 +4219,8 @@ def to_sql(
Args:
name (`str`): Name of SQL table.
- con (`str` or `sqlite3.Connection` or `sqlalchemy.engine.Connection` or `sqlalchemy.engine.Connection`):
- A database connection URI string or an existing SQLite3/SQLAlchemy connection used to write to a database.
+ con (`str` or :obj:`sqlite3.Connection` or :obj:`sqlalchemy.engine.Connection` or :obj:`sqlalchemy.engine.Engine`):
+ A [URI string](https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls) or a SQLite3/SQLAlchemy connection object used to write to a database.
batch_size (:obj:`int`, optional): Size of the batch to load in memory and write at once.
Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
**sql_writer_kwargs (additional keyword arguments): Parameters to pass to pandas's :function:`Dataframe.to_sql`
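Since the `to_sql` docstring above mirrors the new `from_sql` signature, here is a hedged round-trip sketch; the table name `data_table` and the toy data are assumptions for illustration:

```py
import sqlite3

from datasets import Dataset

ds = Dataset.from_dict({"text": ["foo", "bar"], "label": [0, 1]})

# Write the dataset to an in-memory SQLite database...
con = sqlite3.connect(":memory:")
ds.to_sql("data_table", con)

# ...and read it back through the same connection. Note that with a DBAPI
# connection, pandas only accepts SQL queries, not bare table names.
reloaded = Dataset.from_sql("SELECT * FROM data_table", con)
```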
diff --git a/src/datasets/io/sql.py b/src/datasets/io/sql.py
index 0301908f50c..c88cad49398 100644
--- a/src/datasets/io/sql.py
+++ b/src/datasets/io/sql.py
@@ -18,7 +18,7 @@ class SqlDatasetReader(AbstractDatasetInputStream):
def __init__(
self,
sql: Union[str, "sqlalchemy.sql.Selectable"],
- con: str,
+ con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"],
features: Optional[Features] = None,
cache_dir: str = None,
keep_in_memory: bool = False,
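For context, `Dataset.from_sql` delegates to this reader; the sketch below shows roughly how it is invoked internally. `SqlDatasetReader` is internal API and the query/URI are placeholders, so treat this as illustrative rather than a stable interface:

```py
# Internal API, shown only to explain the widened `con` annotation above;
# this mirrors what `Dataset.from_sql` does under the hood.
from datasets.io.sql import SqlDatasetReader

dataset = SqlDatasetReader(
    "SELECT * FROM data_table",   # str or sqlalchemy.sql.Selectable
    "sqlite:///sqlite_file.db",   # URI string, or now also a connection object
    cache_dir=None,
    keep_in_memory=False,
).read()
```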
diff --git a/src/datasets/packaged_modules/sql/sql.py b/src/datasets/packaged_modules/sql/sql.py
index 25c0178e264..5fe9d74acf1 100644
--- a/src/datasets/packaged_modules/sql/sql.py
+++ b/src/datasets/packaged_modules/sql/sql.py
@@ -12,15 +12,20 @@
if TYPE_CHECKING:
+ import sqlite3
+
import sqlalchemy
+logger = datasets.utils.logging.get_logger(__name__)
+
+
@dataclass
class SqlConfig(datasets.BuilderConfig):
"""BuilderConfig for SQL."""
sql: Union[str, "sqlalchemy.sql.Selectable"] = None
- con: str = None
+ con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"] = None
index_col: Optional[Union[str, List[str]]] = None
coerce_float: bool = True
params: Optional[Union[List, Tuple, Dict]] = None
@@ -34,14 +39,13 @@ def __post_init__(self):
raise ValueError("sql must be specified")
if self.con is None:
raise ValueError("con must be specified")
- if not isinstance(self.con, str):
- raise ValueError(f"con must be a database URI string, but got {self.con} with type {type(self.con)}.")
def create_config_id(
self,
config_kwargs: dict,
custom_features: Optional[datasets.Features] = None,
) -> str:
+ config_kwargs = config_kwargs.copy()
# We need to stringify the Selectable object to make its hash deterministic
# The process of stringifying is explained here: http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html
@@ -51,7 +55,6 @@ def create_config_id(
import sqlalchemy
if isinstance(sql, sqlalchemy.sql.Selectable):
- config_kwargs = config_kwargs.copy()
engine = sqlalchemy.create_engine(config_kwargs["con"].split("://")[0] + "://")
sql_str = str(sql.compile(dialect=engine.dialect))
config_kwargs["sql"] = sql_str
@@ -63,6 +66,13 @@ def create_config_id(
raise TypeError(
f"Supported types for 'sql' are string and sqlalchemy.sql.Selectable but got {type(sql)}: {sql}"
)
+ con = config_kwargs["con"]
+ if not isinstance(con, str):
+ config_kwargs["con"] = id(con)
+ logger.info(
+ f"SQL connection 'con' of type {type(con)} couldn't be hashed properly. To enable hashing, specify 'con' as URI string instead."
+ )
+
return super().create_config_id(config_kwargs, custom_features=custom_features)
@property
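The `id(con)` fallback above is what makes caching session-local. A standalone sketch of the reasoning, using `hashlib` in place of the library's internal hasher (an assumption for illustration):

```py
import sqlite3
from hashlib import sha256

# A connection object has no stable identity across runs: `id(con)` is just
# the object's address in this process, so the derived config id changes on
# every session and the cache is never reused.
con = sqlite3.connect(":memory:")
print(id(con))  # different value on each run

# A URI string, by contrast, hashes deterministically in every session.
uri = "sqlite:///sqlite_file.db"
print(sha256(uri.encode()).hexdigest())  # identical on every run
```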
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 4fe9ea1ea2b..54e1c62806a 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -3356,6 +3356,36 @@ def _check_sql_dataset(dataset, expected_features):
assert dataset.features[feature].dtype == expected_dtype
+@require_sqlalchemy
+@pytest.mark.parametrize("con_type", ["string", "engine"])
+def test_dataset_from_sql_con_type(con_type, sqlite_path, tmp_path):
+ cache_dir = tmp_path / "cache"
+ expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
+ if con_type == "string":
+ con = "sqlite:///" + sqlite_path
+ elif con_type == "engine":
+ import sqlalchemy
+
+ con = sqlalchemy.create_engine("sqlite:///" + sqlite_path)
+ # # https://github.com/huggingface/datasets/issues/2832 needs to be fixed first for this to work
+ # with caplog.at_level(INFO):
+ # dataset = Dataset.from_sql(
+ # "dataset",
+ # con,
+ # cache_dir=cache_dir,
+ # )
+ # if con_type == "string":
+ # assert "couldn't be hashed properly" not in caplog.text
+ # elif con_type == "engine":
+ # assert "couldn't be hashed properly" in caplog.text
+ dataset = Dataset.from_sql(
+ "dataset",
+ con,
+ cache_dir=cache_dir,
+ )
+ _check_sql_dataset(dataset, expected_features)
+
+
@require_sqlalchemy
@pytest.mark.parametrize(
"features",