Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modularize default load and save argument handling #15

Merged
merged 24 commits into from
Jul 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
b9bf25d
Modularize default load and save argument handling
deepyaman Jun 8, 2019
ba18548
Suppress ``super-init-not-called`` pylint messages
deepyaman Jun 9, 2019
41b40b2
Copy default args to prevent accidental mutation
deepyaman Jun 14, 2019
c10a654
Restore ``super().__init__`` given default arg fix
deepyaman Jun 14, 2019
bf2643f
Merge branch 'develop' into fix/default-args
deepyaman Jul 2, 2019
e83502c
Refactor abstract base class modification as mixin
deepyaman Jul 2, 2019
63fda57
Homogenize default load and save argument handling
deepyaman Jul 3, 2019
0505773
Demarcate load and save argument handling :dragon:
deepyaman Jul 3, 2019
a93abf2
Cover load and save argument handling :paw_prints:
deepyaman Jul 3, 2019
4226c2e
Add tests to cover load/save argument conditionals
deepyaman Jul 3, 2019
a17ae9e
Fix non-ASCII characters in legal header :pencil2:
deepyaman Jul 3, 2019
f7b2373
Remove load/save defaults from ``AbstractDataSet``
deepyaman Jul 7, 2019
124d663
Call ``super().__init__`` in mix-in implementation
deepyaman Jul 9, 2019
d3c7153
Fix MRO when subclassing ``DefaultArgumentsMixIn``
deepyaman Jul 9, 2019
da10346
Merge branch 'fix/default-args' of https://github.com/deepyaman/kedro…
deepyaman Jul 10, 2019
cac0c78
Copy default argument dicts with ``copy.deepcopy``
deepyaman Jul 10, 2019
681beb0
Merge branch 'develop' into fix/default-args
deepyaman Jul 10, 2019
0d31b7c
Merge branch 'develop' of https://github.com/quantumblacklabs/kedro i…
deepyaman Jul 10, 2019
473d725
Merge branch 'develop' into fix/default-args
deepyaman Jul 10, 2019
2a575d6
Merge branch 'develop' into fix/default-args
deepyaman Jul 10, 2019
5896daa
Annotate types for default load and save arguments
deepyaman Jul 10, 2019
3931744
Revert "Annotate types for default load and save arguments"
deepyaman Jul 10, 2019
b2e4c1c
Annotate types for default load and save arguments
deepyaman Jul 10, 2019
184d9f7
Merge branch 'develop' into fix/default-args
deepyaman Jul 18, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions kedro/contrib/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,5 @@
`kedro.io` module (e.g. additional ``AbstractDataSet``s and
extensions/alternative ``DataCatalog``s).
"""

from .core import DefaultArgumentsMixIn # NOQA
11 changes: 5 additions & 6 deletions kedro/contrib/io/azure/csv_blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@
import pandas as pd
from azure.storage.blob import BlockBlobService

from kedro.contrib.io import DefaultArgumentsMixIn
from kedro.io import AbstractDataSet


class CSVBlobDataSet(AbstractDataSet):
class CSVBlobDataSet(DefaultArgumentsMixIn, AbstractDataSet):
"""``CSVBlobDataSet`` loads and saves csv files in Microsoft's Azure
blob storage. It uses azure storage SDK to read and write in azure and
pandas to handle the csv file locally.
Expand All @@ -61,6 +62,8 @@ class CSVBlobDataSet(AbstractDataSet):
>>> assert data.equals(reloaded)
"""

DEFAULT_SAVE_ARGS = {"index": False}

def _describe(self) -> Dict[str, Any]:
return dict(
filepath=self._filepath,
Expand Down Expand Up @@ -106,16 +109,12 @@ def __init__(
All defaults are preserved, except "index", which is set to False.

"""
default_save_args = {"index": False}
self._save_args = (
{**default_save_args, **save_args} if save_args else default_save_args
)
self._load_args = load_args if load_args else {}
self._filepath = filepath
self._container_name = container_name
self._credentials = credentials if credentials else {}
self._blob_to_text_args = blob_to_text_args if blob_to_text_args else {}
self._blob_from_text_args = blob_from_text_args if blob_from_text_args else {}
super().__init__(load_args, save_args)

def _load(self) -> pd.DataFrame:
blob_service = BlockBlobService(**self._credentials)
Expand Down
16 changes: 3 additions & 13 deletions kedro/contrib/io/bioinformatics/sequence_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@

from Bio import SeqIO

from kedro.contrib.io import DefaultArgumentsMixIn
from kedro.io import AbstractDataSet


class BioSequenceLocalDataSet(AbstractDataSet):
class BioSequenceLocalDataSet(DefaultArgumentsMixIn, AbstractDataSet):
"""``BioSequenceLocalDataSet`` loads and saves data to a sequence file.

Example:
Expand Down Expand Up @@ -95,18 +96,7 @@ def __init__(

"""
self._filepath = filepath
default_load_args = {} # type: Dict[str, Any]
default_save_args = {} # type: Dict[str, Any]
self._load_args = (
{**default_load_args, **load_args}
if load_args is not None
else default_load_args
)
self._save_args = (
{**default_save_args, **save_args}
if save_args is not None
else default_save_args
)
super().__init__(load_args, save_args)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer calling super at the top of the constructor, so the subclass would overwrite stuff from the parent, as a "specialisation" of the superclass.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer calling super at the top of the constructor, so the subclass would overwrite stuff from the parent, as a "specialisation" of the superclass.

Fair argument. I just left it in the same place where default arguments were previously handled (as close to the original as I could), but that makes sense.


def _load(self) -> List:
    # Parse the sequence file at self._filepath with Bio.SeqIO, forwarding
    # the merged load arguments; materialize the parser's lazy iterator
    # into a list of records.
    return list(SeqIO.parse(self._filepath, **self._load_args))
Expand Down
53 changes: 53 additions & 0 deletions kedro/contrib/io/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2018-2019 QuantumBlack Visual Analytics Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
# or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module extends the set of classes ``kedro.io.core`` provides."""

import copy
from typing import Any, Dict, Optional


# pylint: disable=too-few-public-methods
class DefaultArgumentsMixIn:
    """Mix-in that merges user-supplied load/save arguments over class defaults.

    Subclasses declare their defaults via ``DEFAULT_LOAD_ARGS`` and
    ``DEFAULT_SAVE_ARGS``; the merged results land in ``self._load_args``
    and ``self._save_args``.
    """

    DEFAULT_LOAD_ARGS = {}  # type: Dict[str, Any]
    DEFAULT_SAVE_ARGS = {}  # type: Dict[str, Any]

    def __init__(
        self,
        load_args: Optional[Dict[str, Any]] = None,
        save_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        # Cooperative init so the mix-in plays nicely in a multiple-
        # inheritance chain (e.g. alongside AbstractDataSet subclasses).
        super().__init__()
        # Deep-copy the class-level defaults so instances can never
        # mutate the shared dicts by accident.
        merged_load = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
        merged_load.update(load_args or {})
        self._load_args = merged_load
        merged_save = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
        merged_save.update(save_args or {})
        self._save_args = merged_save
6 changes: 3 additions & 3 deletions kedro/contrib/io/pyspark/spark_data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.utils import AnalysisException

from kedro.contrib.io import DefaultArgumentsMixIn
from kedro.io import AbstractDataSet


class SparkDataSet(AbstractDataSet):
class SparkDataSet(DefaultArgumentsMixIn, AbstractDataSet):
"""``SparkDataSet`` loads and saves Spark data frames.

Example:
Expand Down Expand Up @@ -106,8 +107,7 @@ def __init__(

self._filepath = filepath
self._file_format = file_format
self._load_args = load_args if load_args is not None else {}
self._save_args = save_args if save_args is not None else {}
super().__init__(load_args, save_args)

@staticmethod
def _get_spark():
Expand Down
6 changes: 3 additions & 3 deletions kedro/contrib/io/pyspark/spark_jdbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@

from pyspark.sql import DataFrame, SparkSession

from kedro.contrib.io import DefaultArgumentsMixIn
from kedro.io import AbstractDataSet, DataSetError

__all__ = ["SparkJDBCDataSet"]


class SparkJDBCDataSet(AbstractDataSet):
class SparkJDBCDataSet(DefaultArgumentsMixIn, AbstractDataSet):
"""``SparkJDBCDataSet`` loads data from a database table accessible
via JDBC URL url and connection properties and saves the content of
a PySpark DataFrame to an external database table via JDBC. It uses
Expand Down Expand Up @@ -140,8 +141,7 @@ def __init__(

self._url = url
self._table = table
self._load_args = load_args if load_args is not None else {}
self._save_args = save_args if save_args is not None else {}
super().__init__(load_args, save_args)
deepyaman marked this conversation as resolved.
Show resolved Hide resolved

# Update properties in load_args and save_args with credentials.
if credentials is not None:
Expand Down
24 changes: 12 additions & 12 deletions kedro/io/csv_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
underlying functionality is supported by pandas, so it supports all
allowed pandas options for loading and saving csv files.
"""
import copy
from pathlib import Path
from typing import Any, Dict

Expand Down Expand Up @@ -61,6 +62,9 @@ class CSVLocalDataSet(AbstractVersionedDataSet):

"""

DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any]
DEFAULT_SAVE_ARGS = {"index": False} # type: Dict[str, Any]

def __init__(
self,
filepath: str,
Expand All @@ -87,18 +91,14 @@ def __init__(
attribute is None, save version will be autogenerated.
"""
super().__init__(Path(filepath), version)
default_save_args = {"index": False} # type: Dict[str, Any]
default_load_args = {} # type: Dict[str, Any]
self._load_args = (
{**default_load_args, **load_args}
if load_args is not None
else default_load_args
)
self._save_args = (
{**default_save_args, **save_args}
if save_args is not None
else default_save_args
)

# Handle default load and save arguments
self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

def _load(self) -> pd.DataFrame:
load_path = Path(self._get_load_path())
Expand Down
21 changes: 14 additions & 7 deletions kedro/io/csv_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"""``CSVS3DataSet`` loads and saves data to a file in S3. It uses s3fs
to read and write from S3 and pandas to handle the csv file.
"""
from copy import deepcopy
import copy
from pathlib import PurePosixPath
from typing import Any, Dict, Optional

Expand Down Expand Up @@ -62,6 +62,9 @@ class CSVS3DataSet(AbstractVersionedDataSet):
>>> assert data.equals(reloaded)
"""

DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any]
DEFAULT_SAVE_ARGS = {"index": False} # type: Dict[str, Any]

# pylint: disable=too-many-arguments
def __init__(
self,
Expand Down Expand Up @@ -94,21 +97,25 @@ def __init__(
attribute is None, save version will be autogenerated.

"""
_credentials = deepcopy(credentials) or {}
_credentials = copy.deepcopy(credentials) or {}
_s3 = S3FileSystem(client_kwargs=_credentials)
super().__init__(
PurePosixPath("{}/{}".format(bucket_name, filepath)),
version,
exists_function=_s3.exists,
glob_function=_s3.glob,
)
default_save_args = {"index": False}
self._save_args = (
{**default_save_args, **save_args} if save_args else default_save_args
)
self._load_args = load_args if load_args else {}
self._bucket_name = bucket_name
self._credentials = _credentials

# Handle default load and save arguments
self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

self._s3 = _s3

def _describe(self) -> Dict[str, Any]:
Expand Down
25 changes: 12 additions & 13 deletions kedro/io/excel_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
underlying functionality is supported by pandas, so it supports all
allowed pandas options for loading and saving Excel files.
"""
import copy
from pathlib import Path
from typing import Any, Dict, Union

Expand Down Expand Up @@ -61,6 +62,9 @@ class ExcelLocalDataSet(AbstractVersionedDataSet):

"""

DEFAULT_LOAD_ARGS = {"engine": "xlrd"}
DEFAULT_SAVE_ARGS = {"index": False}

def _describe(self) -> Dict[str, Any]:
return dict(
filepath=self._filepath,
Expand Down Expand Up @@ -105,21 +109,16 @@ def __init__(

"""
super().__init__(Path(filepath), version)
default_save_args = {"index": False}
default_load_args = {"engine": "xlrd"}

self._load_args = (
{**default_load_args, **load_args}
if load_args is not None
else default_load_args
)
self._save_args = (
{**default_save_args, **save_args}
if save_args is not None
else default_save_args
)
self._engine = engine

# Handle default load and save arguments
self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

def _load(self) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]:
    # Resolve the (possibly versioned) load path, then delegate to pandas.
    # read_excel returns a dict of DataFrames when multiple sheets are
    # requested via load_args, hence the Union return annotation.
    load_path = Path(self._get_load_path())
    return pd.read_excel(load_path, **self._load_args)
Expand Down
24 changes: 12 additions & 12 deletions kedro/io/hdf_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
underlying functionality is supported by pandas, so it supports all
allowed pandas options for loading and saving hdf files.
"""
import copy
from pathlib import Path
from typing import Any, Dict

Expand Down Expand Up @@ -63,6 +64,9 @@ class HDFLocalDataSet(AbstractVersionedDataSet):

"""

DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any]
DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any]

# pylint: disable=too-many-arguments
def __init__(
self,
Expand Down Expand Up @@ -93,19 +97,15 @@ def __init__(

"""
super().__init__(Path(filepath), version)
default_load_args = {} # type: Dict[str, Any]
default_save_args = {} # type: Dict[str, Any]
self._key = key
self._load_args = (
{**default_load_args, **load_args}
if load_args is not None
else default_load_args
)
self._save_args = (
{**default_load_args, **save_args}
if save_args is not None
else default_save_args
)

# Handle default load and save arguments
self._load_args = copy.deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = copy.deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

def _load(self) -> pd.DataFrame:
load_path = Path(self._get_load_path())
Expand Down
Loading