Add ability to read-write to SQL databases. #4928

Merged: 24 commits, Oct 3, 2022

Changes from 23 commits

Commits (24)
f76b87c
Add ability to read-write to SQL databases.
Dref360 Sep 3, 2022
5747ad6
Fix issue where pandas<1.4.0 doesn't return the number of rows
Dref360 Sep 3, 2022
3811a5e
Fix issue where connections were not closed properly
Dref360 Sep 3, 2022
27d56b7
Apply suggestions from code review
Dref360 Sep 5, 2022
e9af3cf
Change according to reviews
Dref360 Sep 5, 2022
87eeb1a
Change according to reviews
Dref360 Sep 17, 2022
70e57c7
Merge main
Dref360 Sep 17, 2022
c3597c9
Inherit from AbstractDatasetInputStream in SqlDatasetReader
Dref360 Sep 17, 2022
61cf29a
Revert typing in SqlDatasetReader as we do not support Connection
Dref360 Sep 18, 2022
453f2c3
Align API with Pandas/Dask
mariosasko Sep 21, 2022
5410f51
Update tests
mariosasko Sep 21, 2022
3c128be
Update docs
mariosasko Sep 21, 2022
40268ae
Update some more tests
mariosasko Sep 21, 2022
7830d91
Merge branch 'main' of github.com:huggingface/datasets into HF-3094/i…
mariosasko Sep 21, 2022
dc005df
Missing comma
mariosasko Sep 21, 2022
a3c39d9
Small docs fix
mariosasko Sep 21, 2022
7c4999e
Style
mariosasko Sep 21, 2022
920de97
Update src/datasets/arrow_dataset.py
mariosasko Sep 23, 2022
9ecdb1f
Update src/datasets/packaged_modules/sql/sql.py
mariosasko Sep 23, 2022
27c9674
Address some comments
mariosasko Sep 23, 2022
ad20c27
Merge branch 'HF-3094/io_sql' of github.com:Dref360/datasets into HF-…
mariosasko Sep 23, 2022
81ad0e4
Address the rest
mariosasko Sep 23, 2022
3714fb0
Improve tests
mariosasko Sep 23, 2022
f3610c8
sqlalchemy required tip
mariosasko Oct 3, 2022
18 changes: 18 additions & 0 deletions docs/source/loading.mdx
@@ -196,6 +196,24 @@ To load remote Parquet files via HTTP, pass the URLs instead:
>>> wiki = load_dataset("parquet", data_files=data_files, split="train")
```

### SQL

Read database contents with [`Dataset.from_sql`]. Both table names and queries are supported.
@Dref360 (Contributor, Author), Oct 3, 2022:

Something like that?

Suggested change:
Read database contents with [`Dataset.from_sql`]. Both table names and queries are supported.
Requires [`sqlalchemy`](https://www.sqlalchemy.org/).

Collaborator reply:

I decided to add a tip to the from_sql docstring instead, but thanks anyway :).


For example, a table from a SQLite file can be loaded with:

```py
>>> from datasets import Dataset
>>> dataset = Dataset.from_sql("data_table", "sqlite:///sqlite_file.db")
```

Use a query for a more precise read:

```py
>>> from datasets import Dataset
>>> dataset = Dataset.from_sql("SELECT text FROM data_table WHERE length(text) > 100 LIMIT 10", "sqlite:///sqlite_file.db")
```
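
The reverse direction works the same way. As a minimal sketch (reusing the SQLite URI above and assuming `sqlalchemy` is installed), a processed dataset can be written back with [`Dataset.to_sql`]:

```py
>>> # "filtered_table" is an illustrative table name
>>> dataset.to_sql("filtered_table", "sqlite:///sqlite_file.db")
```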

## In-memory data

🤗 Datasets will also allow you to create a [`Dataset`] directly from in-memory data structures like Python dictionaries and Pandas DataFrames.
4 changes: 4 additions & 0 deletions docs/source/package_reference/loading_methods.mdx
@@ -65,6 +65,10 @@ load_dataset("csv", data_dir="path/to/data/dir", sep="\t")

[[autodoc]] datasets.packaged_modules.parquet.ParquetConfig

### SQL

[[autodoc]] datasets.packaged_modules.sql.SqlConfig

### Images

[[autodoc]] datasets.packaged_modules.imagefolder.ImageFolderConfig
2 changes: 2 additions & 0 deletions docs/source/package_reference/main_classes.mdx
@@ -58,6 +58,7 @@ The base class [`Dataset`] implements a Dataset backed by an Apache Arrow table.
- to_dict
- to_json
- to_parquet
- to_sql
- add_faiss_index
- add_faiss_index_from_external_arrays
- save_faiss_index
@@ -90,6 +91,7 @@ The base class [`Dataset`] implements a Dataset backed by an Apache Arrow table.
- from_json
- from_parquet
- from_text
- from_sql
- prepare_for_task
- align_labels_with_mapping

1 change: 1 addition & 0 deletions docs/source/process.mdx
@@ -609,6 +609,7 @@ Want to save your dataset to a cloud storage provider? Read our [Cloud Storage](
| CSV | [`Dataset.to_csv`] |
| JSON | [`Dataset.to_json`] |
| Parquet | [`Dataset.to_parquet`] |
| SQL | [`Dataset.to_sql`] |
| In-memory Python object | [`Dataset.to_pandas`] or [`Dataset.to_dict`] |
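
For the newly added SQL row, a comparable hedged sketch (the table name and SQLite URI are illustrative, not taken from this guide):

```py
>>> dataset.to_sql("data", "sqlite:///my_dataset.db")
```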

For example, export your dataset to a CSV file like this:
1 change: 1 addition & 0 deletions setup.py
@@ -155,6 +155,7 @@
"scipy",
"sentencepiece", # for bleurt
"seqeval",
"sqlalchemy",
"tldextract",
# to speed up pip backtracking
"toml>=0.10.1",
87 changes: 87 additions & 0 deletions src/datasets/arrow_dataset.py
@@ -111,6 +111,10 @@


if TYPE_CHECKING:
import sqlite3

import sqlalchemy

from .dataset_dict import DatasetDict

logger = logging.get_logger(__name__)
@@ -1092,6 +1096,52 @@ def from_text(
path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
).read()

@staticmethod
def from_sql(
sql: Union[str, "sqlalchemy.sql.Selectable"],
con: str,
features: Optional[Features] = None,
cache_dir: str = None,
keep_in_memory: bool = False,
**kwargs,
):
"""Create Dataset from SQL query or database table.

Args:
sql (`str` or :obj:`sqlalchemy.sql.Selectable`): SQL query to be executed or a table name.
con (`str`): A connection URI string used to instantiate a database connection.
features (:class:`Features`, optional): Dataset features.
cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data.
keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory.
**kwargs (additional keyword arguments): Keyword arguments to be passed to :class:`SqlConfig`.

Returns:
:class:`Dataset`

Example:

```py
>>> # Fetch a database table
>>> ds = Dataset.from_sql("test_data", "postgres:///db_name")
>>> # Execute a SQL query on the table
>>> ds = Dataset.from_sql("SELECT sentence FROM test_data", "postgres:///db_name")
>>> # Use a Selectable object to specify the query
>>> from sqlalchemy import select, text
>>> stmt = select([text("sentence")]).select_from(text("test_data"))
>>> ds = Dataset.from_sql(stmt, "postgres:///db_name")
```
"""
from .io.sql import SqlDatasetReader

return SqlDatasetReader(
sql,
con,
features=features,
cache_dir=cache_dir,
keep_in_memory=keep_in_memory,
**kwargs,
).read()
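
As a hedged aside on the `features` argument documented above (the column name, table, and URI are illustrative, not taken from the PR):

```py
>>> from datasets import Dataset, Features, Value
>>> features = Features({"sentence": Value("string")})
>>> ds = Dataset.from_sql("SELECT sentence FROM test_data", "sqlite:///db.sqlite", features=features)
```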

def __del__(self):
if hasattr(self, "_data"):
del self._data
@@ -4098,6 +4148,43 @@ def to_parquet(

return ParquetDatasetWriter(self, path_or_buf, batch_size=batch_size, **parquet_writer_kwargs).write()

def to_sql(
self,
name: str,
con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"],
batch_size: Optional[int] = None,
**sql_writer_kwargs,
) -> int:
"""Exports the dataset to a SQL database.

Args:
name (`str`): Name of SQL table.
con (`str` or `sqlite3.Connection` or `sqlalchemy.engine.Connection` or `sqlalchemy.engine.Engine`):
A database connection URI string or an existing SQLite3/SQLAlchemy connection used to write to a database.
batch_size (:obj:`int`, optional): Size of the batch to load in memory and write at once.
Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`.
**sql_writer_kwargs (additional keyword arguments): Parameters to pass to pandas's :func:`pandas.DataFrame.to_sql`.

Returns:
int: The number of records written.

Example:

```py
>>> # con provided as a connection URI string
>>> ds.to_sql("data", "sqlite:///my_own_db.sql")
>>> # con provided as a sqlite3 connection object
>>> import sqlite3
>>> con = sqlite3.connect("my_own_db.sql")
>>> with con:
... ds.to_sql("data", con)
```
"""
# Dynamic import to avoid circular dependency
from .io.sql import SqlDatasetWriter

return SqlDatasetWriter(self, name, con, batch_size=batch_size, **sql_writer_kwargs).write()
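
Because `**sql_writer_kwargs` is forwarded to pandas' `DataFrame.to_sql`, an SQLAlchemy engine and pandas keyword arguments can be combined; a hedged sketch (the URI and options are illustrative):

```py
>>> from sqlalchemy import create_engine
>>> engine = create_engine("sqlite:///my_own_db.sql")
>>> ds.to_sql("data", engine, batch_size=1000, index=False, if_exists="replace")
```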

def _push_parquet_shards_to_hub(
self,
repo_id: str,
3 changes: 3 additions & 0 deletions src/datasets/config.py
@@ -125,6 +125,9 @@
logger.info("Disabling Apache Beam because USE_BEAM is set to False")


# Optional tools for data loading
SQLALCHEMY_AVAILABLE = importlib.util.find_spec("sqlalchemy") is not None

# Optional tools for feature decoding
PIL_AVAILABLE = importlib.util.find_spec("PIL") is not None

129 changes: 129 additions & 0 deletions src/datasets/io/sql.py
@@ -0,0 +1,129 @@
import multiprocessing
from typing import TYPE_CHECKING, Optional, Union

from .. import Dataset, Features, config
from ..formatting import query_table
from ..packaged_modules.sql.sql import Sql
from ..utils import logging
from .abc import AbstractDatasetInputStream


if TYPE_CHECKING:
import sqlite3

import sqlalchemy


class SqlDatasetReader(AbstractDatasetInputStream):
def __init__(
self,
sql: Union[str, "sqlalchemy.sql.Selectable"],
con: str,
features: Optional[Features] = None,
cache_dir: str = None,
keep_in_memory: bool = False,
**kwargs,
):
super().__init__(features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs)
self.builder = Sql(
cache_dir=cache_dir,
features=features,
sql=sql,
con=con,
**kwargs,
)

def read(self):
download_config = None
download_mode = None
ignore_verifications = False
use_auth_token = None
base_path = None

self.builder.download_and_prepare(
download_config=download_config,
download_mode=download_mode,
ignore_verifications=ignore_verifications,
# try_from_hf_gcs=try_from_hf_gcs,
base_path=base_path,
use_auth_token=use_auth_token,
)

# Build dataset for splits
dataset = self.builder.as_dataset(
split="train", ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
)
return dataset


class SqlDatasetWriter:
def __init__(
self,
dataset: Dataset,
name: str,
con: Union[str, "sqlalchemy.engine.Connection", "sqlalchemy.engine.Engine", "sqlite3.Connection"],
batch_size: Optional[int] = None,
num_proc: Optional[int] = None,
**to_sql_kwargs,
):

if num_proc is not None and num_proc <= 0:
raise ValueError(f"num_proc {num_proc} must be an integer > 0.")

self.dataset = dataset
self.name = name
self.con = con
self.batch_size = batch_size if batch_size else config.DEFAULT_MAX_BATCH_SIZE
self.num_proc = num_proc
self.to_sql_kwargs = to_sql_kwargs

def write(self) -> int:
_ = self.to_sql_kwargs.pop("sql", None)
_ = self.to_sql_kwargs.pop("con", None)

written = self._write(**self.to_sql_kwargs)
return written

def _batch_sql(self, args):
offset, to_sql_kwargs = args
to_sql_kwargs = {**to_sql_kwargs, "if_exists": "append"} if offset > 0 else to_sql_kwargs
batch = query_table(
table=self.dataset.data,
key=slice(offset, offset + self.batch_size),
indices=self.dataset._indices,
)
df = batch.to_pandas()
num_rows = df.to_sql(self.name, self.con, **to_sql_kwargs)
return num_rows or len(df)

def _write(self, **to_sql_kwargs) -> int:
"""Writes the pyarrow table as SQL to a database.

Caller is responsible for opening and closing the SQL connection.
"""
written = 0

if self.num_proc is None or self.num_proc == 1:
for offset in logging.tqdm(
range(0, len(self.dataset), self.batch_size),
unit="ba",
disable=not logging.is_progress_bar_enabled(),
desc="Creating SQL from Arrow format",
):
written += self._batch_sql((offset, to_sql_kwargs))
else:
num_rows, batch_size = len(self.dataset), self.batch_size
with multiprocessing.Pool(self.num_proc) as pool:
for num_rows in logging.tqdm(
pool.imap(
self._batch_sql,
[(offset, to_sql_kwargs) for offset in range(0, num_rows, batch_size)],
),
total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
unit="ba",
disable=not logging.is_progress_bar_enabled(),
desc="Creating SQL from Arrow format",
):
written += num_rows

return written
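
A hedged end-to-end sketch exercising the classes above directly (they are normally reached through `Dataset.from_sql` / `Dataset.to_sql`; the file and table names are illustrative and `sqlalchemy` is assumed to be installed):

```py
import sqlite3

from datasets import Dataset
from datasets.io.sql import SqlDatasetReader, SqlDatasetWriter

ds = Dataset.from_dict({"text": ["foo", "bar"], "label": [0, 1]})

# Write using a sqlite3 connection object; the caller opens and closes it.
with sqlite3.connect("tmp.db") as con:
    written = SqlDatasetWriter(ds, "data", con, batch_size=1000).write()
con.close()

# Read back with a connection URI string (the reader only accepts URI strings).
reloaded = SqlDatasetReader("data", "sqlite:///tmp.db").read()
```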
1 change: 1 addition & 0 deletions src/datasets/packaged_modules/__init__.py
@@ -9,6 +9,7 @@
from .json import json
from .pandas import pandas
from .parquet import parquet
from .sql import sql # noqa F401
from .text import text


18 changes: 9 additions & 9 deletions src/datasets/packaged_modules/csv/csv.py
@@ -70,8 +70,8 @@ def __post_init__(self):
self.names = self.column_names

@property
def read_csv_kwargs(self):
read_csv_kwargs = dict(
def pd_read_csv_kwargs(self):
pd_read_csv_kwargs = dict(
sep=self.sep,
header=self.header,
names=self.names,
@@ -112,16 +112,16 @@ def read_csv_kwargs(self):

# some kwargs must not be passed if they don't have a default value
# some others are deprecated and we can also not pass them if they are the default value
for read_csv_parameter in _PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS + _PANDAS_READ_CSV_DEPRECATED_PARAMETERS:
if read_csv_kwargs[read_csv_parameter] == getattr(CsvConfig(), read_csv_parameter):
del read_csv_kwargs[read_csv_parameter]
for pd_read_csv_parameter in _PANDAS_READ_CSV_NO_DEFAULT_PARAMETERS + _PANDAS_READ_CSV_DEPRECATED_PARAMETERS:
if pd_read_csv_kwargs[pd_read_csv_parameter] == getattr(CsvConfig(), pd_read_csv_parameter):
del pd_read_csv_kwargs[pd_read_csv_parameter]

# Remove 1.3 new arguments
if not (datasets.config.PANDAS_VERSION.major >= 1 and datasets.config.PANDAS_VERSION.minor >= 3):
for read_csv_parameter in _PANDAS_READ_CSV_NEW_1_3_0_PARAMETERS:
del read_csv_kwargs[read_csv_parameter]
for pd_read_csv_parameter in _PANDAS_READ_CSV_NEW_1_3_0_PARAMETERS:
del pd_read_csv_kwargs[pd_read_csv_parameter]

return read_csv_kwargs
return pd_read_csv_kwargs


class Csv(datasets.ArrowBasedBuilder):
@@ -172,7 +172,7 @@ def _generate_tables(self, files):
else None
)
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.read_csv_kwargs)
csv_file_reader = pd.read_csv(file, iterator=True, dtype=dtype, **self.config.pd_read_csv_kwargs)
try:
for batch_idx, df in enumerate(csv_file_reader):
pa_table = pa.Table.from_pandas(df)