From fd0de7200f6173f6a1542b0b8bc770fa5df6115c Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Mon, 26 Aug 2024 18:41:05 +0530 Subject: [PATCH 01/51] added a base table implementation Signed-off-by: Minura Punchihewa --- .../databricks/base_table_dataset.py | 170 ++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 kedro-datasets/kedro_datasets/databricks/base_table_dataset.py diff --git a/kedro-datasets/kedro_datasets/databricks/base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/base_table_dataset.py new file mode 100644 index 000000000..8af3e50f1 --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/base_table_dataset.py @@ -0,0 +1,170 @@ +"""``BaseTableDataset`` implementation used to add the base for +``ManagedTableDataset`` and ``ExternalTableDataset``. +""" +from __future__ import annotations + +import logging +import re +from dataclasses import dataclass +from typing import Any + +import pandas as pd +from kedro.io.core import ( + AbstractVersionedDataset, + DatasetError, + Version, + VersionNotFoundError, +) +from pyspark.sql import DataFrame +from pyspark.sql.types import StructType +from pyspark.sql.utils import AnalysisException, ParseException + +from kedro_datasets.spark.spark_dataset import _get_spark + +logger = logging.getLogger(__name__) +pd.DataFrame.iteritems = pd.DataFrame.items + + +@dataclass(frozen=True) +class BaseTable: + """Stores the definition of a base table. + + Acts as a base class for `ManagedTable` and `ExternalTable`. + """ + + # regex for tables, catalogs and schemas + _NAMING_REGEX = r"\b[0-9a-zA-Z_-]{1,}\b" + _VALID_WRITE_MODES = ["overwrite", "append"] + _VALID_DATAFRAME_TYPES = ["spark", "pandas"] + _VALID_FORMATS = ["delta", "parquet", "csv"] + format: str + database: str + catalog: str | None + table: str + write_mode: str | None + dataframe_type: str + primary_key: str | list[str] | None + owner_group: str | None + partition_columns: str | list[str] | None + json_schema: dict[str, Any] | None = None + + def __post_init__(self): + """Run validation methods if declared. + + The validation method can be a simple check + that raises DatasetError. + + The validation is performed by calling a function with the signature + `validate_(self, value) -> raises DatasetError`. + """ + for name in self.__dataclass_fields__.keys(): + method = getattr(self, f"_validate_{name}", None) + if method: + method() + + def _validate_format(self): + """Validates the format of the table. + + Raises: + DatasetError: If an invalid `format` is passed. + """ + if self.format not in self._VALID_FORMATS: + valid_formats = ", ".join(self._VALID_FORMATS) + raise DatasetError( + f"Invalid `format` provided: {self.format}. " + f"`format` must be one of: {valid_formats}" + ) + + def _validate_table(self): + """Validates table name. + + Raises: + DatasetError: If the table name does not conform to naming constraints. + """ + if not re.fullmatch(self._NAMING_REGEX, self.table): + raise DatasetError("table does not conform to naming") + + def _validate_database(self): + """Validates database name. + + Raises: + DatasetError: If the dataset name does not conform to naming constraints. + """ + if not re.fullmatch(self._NAMING_REGEX, self.database): + raise DatasetError("database does not conform to naming") + + def _validate_catalog(self): + """Validates catalog name. + + Raises: + DatasetError: If the catalog name does not conform to naming constraints. 
+ """ + if self.catalog: + if not re.fullmatch(self._NAMING_REGEX, self.catalog): + raise DatasetError("catalog does not conform to naming") + + def _validate_write_mode(self): + """Validates the write mode. + + Raises: + DatasetError: If an invalid `write_mode` is passed. + """ + if ( + self.write_mode is not None + and self.write_mode not in self._VALID_WRITE_MODES + ): + valid_modes = ", ".join(self._VALID_WRITE_MODES) + raise DatasetError( + f"Invalid `write_mode` provided: {self.write_mode}. " + f"`write_mode` must be one of: {valid_modes}" + ) + + def _validate_dataframe_type(self): + """Validates the dataframe type. + + Raises: + DatasetError: If an invalid `dataframe_type` is passed + """ + if self.dataframe_type not in self._VALID_DATAFRAME_TYPES: + valid_types = ", ".join(self._VALID_DATAFRAME_TYPES) + raise DatasetError(f"`dataframe_type` must be one of {valid_types}") + + def _validate_primary_key(self): + """Validates the primary key of the table. + + Raises: + DatasetError: If no `primary_key` is specified. + """ + if self.primary_key is None or len(self.primary_key) == 0: + if self.write_mode == "upsert": + raise DatasetError( + f"`primary_key` must be provided for" + f"`write_mode` {self.write_mode}" + ) + + def full_table_location(self) -> str | None: + """Returns the full table location. + + Returns: + str | None : table location in the format catalog.database.table or None if database and table aren't defined + """ + full_table_location = None + if self.catalog and self.database and self.table: + full_table_location = f"`{self.catalog}`.`{self.database}`.`{self.table}`" + elif self.database and self.table: + full_table_location = f"`{self.database}`.`{self.table}`" + return full_table_location + + def schema(self) -> StructType | None: + """Returns the Spark schema of the table if it exists. + + Returns: + StructType: + """ + schema = None + try: + if self.json_schema is not None: + schema = StructType.fromJson(self.json_schema) + except (KeyError, ValueError) as exc: + raise DatasetError(exc) from exc + return schema \ No newline at end of file From 0392aba016c92d32dfc784b49c17f9a27fa3190d Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Mon, 26 Aug 2024 23:00:51 +0530 Subject: [PATCH 02/51] added a base table dataset implementation Signed-off-by: Minura Punchihewa --- .../databricks/base_table_dataset.py | 210 +++++++++++++++++- 1 file changed, 209 insertions(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/base_table_dataset.py index 8af3e50f1..0c60b742a 100644 --- a/kedro-datasets/kedro_datasets/databricks/base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/base_table_dataset.py @@ -167,4 +167,212 @@ def schema(self) -> StructType | None: schema = StructType.fromJson(self.json_schema) except (KeyError, ValueError) as exc: raise DatasetError(exc) from exc - return schema \ No newline at end of file + return schema + + +class BaseTableDataset(AbstractVersionedDataset): + """``BaseTableDataset`` loads and saves data into managed delta tables or external tables on Databricks. + Load and save can be in Spark or Pandas dataframes, specified in dataframe_type. + + This dataaset is not meant to be used directly. It is a base class for ``ManagedTableDataset`` and ``ExternalTableDataset``. 
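+
+    As an illustration only (hypothetical subclass, not part of this module), a
+    concrete dataset is expected to build its table definition in ``_create_table``:
+
+    .. code-block:: python
+
+        class MyTableDataset(BaseTableDataset):
+            def _create_table(self, **kwargs) -> BaseTable:
+                # Map the dataset arguments onto the frozen table definition.
+                return BaseTable(
+                    format="delta",
+                    table=kwargs["table"],
+                    database=kwargs["database"],
+                    catalog=kwargs["catalog"],
+                    write_mode=kwargs["write_mode"],
+                    dataframe_type=kwargs["dataframe_type"],
+                    primary_key=kwargs.get("primary_key"),
+                    owner_group=kwargs["owner_group"],
+                    partition_columns=kwargs["partition_columns"],
+                    json_schema=kwargs["schema"],
+                )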
+ """ + + # this dataset cannot be used with ``ParallelRunner``, + # therefore it has the attribute ``_SINGLE_PROCESS = True`` + # for parallelism within a Spark pipeline please consider + # using ``ThreadRunner`` instead + _SINGLE_PROCESS = True + + def __init__( # noqa: PLR0913 + self, + *, + table: str, + catalog: str | None = None, + database: str = "default", + write_mode: str | None = None, + dataframe_type: str = "spark", + # the following parameters are used by project hooks + # to create or update table properties + schema: dict[str, Any] | None = None, + partition_columns: list[str] | None = None, + owner_group: str | None = None, + metadata: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Creates a new instance of ``BaseTableDataset``. + + Args: + table: the name of the table + catalog: the name of the catalog in Unity. + Defaults to None. + database: the name of the database. + (also referred to as schema). Defaults to "default". + write_mode: the mode to write the data into the table. If not + present, the data set is read-only. + Options are:["overwrite", "append", "upsert"]. + "upsert" mode requires primary_key field to be populated. + Defaults to None. + dataframe_type: "pandas" or "spark" dataframe. + Defaults to "spark". + schema: the schema of the table in JSON form. + Dataframes will be truncated to match the schema if provided. + Used by the hooks to create the table if the schema is provided + Defaults to None. + partition_columns: the columns to use for partitioning the table. + Used by the hooks. Defaults to None. + owner_group: if table access control is enabled in your workspace, + specifying owner_group will transfer ownership of the table and database to + this owner. All databases should have the same owner_group. Defaults to None. + metadata: Any arbitrary metadata. + This is ignored by Kedro, but may be consumed by users or external plugins. + Raises: + DatasetError: Invalid configuration supplied (through BaseTable validation). + """ + + self._table = self._create_table( + table=table, + catalog=catalog, + database=database, + write_mode=write_mode, + dataframe_type=dataframe_type, + schema=schema, + partition_columns=partition_columns, + owner_group=owner_group, + **kwargs, + ) + + self.metadata = metadata + self.kwargs = kwargs + + super().__init__( + filepath=None, # type: ignore[arg-type] + exists_function=self._exists, # type: ignore[arg-type] + ) + + def _create_table(self, **kwargs: Any) -> None: + """Creates a table object and assign it to the _table attribute. + + Args: + **kwargs: Arguments to pass to the table object. + """ + raise NotImplementedError + + def _load(self) -> DataFrame | pd.DataFrame: + """Loads the data from the table location defined in the init. + (spark|pandas dataframe) + + Returns: + Union[DataFrame, pd.DataFrame]: Returns a dataframe + in the format defined in the init + """ + data = _get_spark().table(self._table.full_table_location()) + if self._table.dataframe_type == "pandas": + data = data.toPandas() + return data + + def _save(self, data: DataFrame | pd.DataFrame) -> None: + """Saves the data based on the write_mode and dataframe_type in the init. + If write_mode is pandas, Spark dataframe is created first. + If schema is provided, data is matched to schema before saving + (columns will be sorted and truncated). + + Args: + data (Any): Spark or pandas dataframe to save to the table location + """ + if self._table.write_mode is None: + raise DatasetError( + "'save' can not be used in read-only mode. 
" + f"Change 'write_mode' value to {', '.join(self._table._VALID_WRITE_MODES)}" + ) + # filter columns specified in schema and match their ordering + schema = self._table.schema() + if schema: + cols = schema.fieldNames() + if self._table.dataframe_type == "pandas": + data = _get_spark().createDataFrame( + data.loc[:, cols], schema=self._table.schema() + ) + else: + data = data.select(*cols) + elif self._table.dataframe_type == "pandas": + data = _get_spark().createDataFrame(data) + + method = getattr(self, f"_save_{self._table.write_mode}", None) + + if method is None: + raise DatasetError( + f"Invalid `write_mode` provided: {self._table.write_mode}. " + f"`write_mode` must be one of: {self._table._VALID_WRITE_MODES}" + ) + + method(data) + + def _save_append(self, data: DataFrame) -> None: + """Saves the data to the table by appending it + to the location defined in the init. + + Args: + data (DataFrame): the Spark dataframe to append to the table. + """ + data.write.format(self._table.format).mode("append").saveAsTable( + self._table.full_table_location() or "" + ) + + def _save_overwrite(self, data: DataFrame) -> None: + """Overwrites the data in the table with the data provided. + (this is the default save mode) + + Args: + data (DataFrame): the Spark dataframe to overwrite the table with. + """ + table = data.write.format(self._table.format) + if self._table.write_mode == "overwrite": + table = table.mode("overwrite").option( + "overwriteSchema", "true" + ) + table.saveAsTable(self._table.full_table_location() or "") + + def _describe(self) -> dict[str, str | list | None]: + """Returns a description of the instance of the dataset. + + Returns: + Dict[str, str]: Dict with the details of the dataset + """ + return { + "catalog": self._table.catalog, + "database": self._table.database, + "table": self._table.table, + "write_mode": self._table.write_mode, + "dataframe_type": self._table.dataframe_type, + "owner_group": self._table.owner_group, + "partition_columns": self._table.partition_columns, + **self.kwargs + } + + def _exists(self) -> bool: + """Checks to see if the table exists. + + Returns: + bool: boolean of whether the table defined + in the dataset instance exists in the Spark session. + """ + if self._table.catalog: + try: + _get_spark().sql(f"USE CATALOG `{self._table.catalog}`") + except (ParseException, AnalysisException) as exc: + logger.warning( + "catalog %s not found or unity not enabled. 
Error message: %s", + self._table.catalog, + exc, + ) + try: + return ( + _get_spark() + .sql(f"SHOW TABLES IN `{self._table.database}`") + .filter(f"tableName = '{self._table.table}'") + .count() + > 0 + ) + except (ParseException, AnalysisException) as exc: + logger.warning("error occured while trying to find table: %s", exc) + return False \ No newline at end of file From 38c470f062a64f01aacfae633df386785af0971e Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Mon, 26 Aug 2024 23:02:49 +0530 Subject: [PATCH 03/51] renamed the module with the base classes Signed-off-by: Minura Punchihewa --- .../databricks/{base_table_dataset.py => _base_table_dataset.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename kedro-datasets/kedro_datasets/databricks/{base_table_dataset.py => _base_table_dataset.py} (100%) diff --git a/kedro-datasets/kedro_datasets/databricks/base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py similarity index 100% rename from kedro-datasets/kedro_datasets/databricks/base_table_dataset.py rename to kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py From d424829cf63f05409b4f8ea0a1815105395f831d Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Mon, 26 Aug 2024 23:06:48 +0530 Subject: [PATCH 04/51] refactored ManagedTable using BaseTable Signed-off-by: Minura Punchihewa --- .../databricks/managed_table_dataset.py | 121 +----------------- 1 file changed, 2 insertions(+), 119 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 677db0d56..6a8a02d55 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -20,136 +20,19 @@ from pyspark.sql.utils import AnalysisException, ParseException from kedro_datasets.spark.spark_dataset import _get_spark +from kedro_datasets.databricks._base_table_dataset import BaseTable, BaseTableDataset logger = logging.getLogger(__name__) pd.DataFrame.iteritems = pd.DataFrame.items @dataclass(frozen=True) -class ManagedTable: +class ManagedTable(BaseTable): """Stores the definition of a managed table""" - # regex for tables, catalogs and schemas - _NAMING_REGEX = r"\b[0-9a-zA-Z_-]{1,}\b" _VALID_WRITE_MODES = ["overwrite", "upsert", "append"] - _VALID_DATAFRAME_TYPES = ["spark", "pandas"] - database: str - catalog: str | None - table: str - write_mode: str | None - dataframe_type: str primary_key: str | list[str] | None owner_group: str | None - partition_columns: str | list[str] | None - json_schema: dict[str, Any] | None = None - - def __post_init__(self): - """Run validation methods if declared. - - The validation method can be a simple check - that raises DatasetError. - - The validation is performed by calling a function with the signature - `validate_(self, value) -> raises DatasetError`. - """ - for name in self.__dataclass_fields__.keys(): - method = getattr(self, f"_validate_{name}", None) - if method: - method() - - def _validate_table(self): - """Validates table name - - Raises: - DatasetError: If the table name does not conform to naming constraints. - """ - if not re.fullmatch(self._NAMING_REGEX, self.table): - raise DatasetError("table does not conform to naming") - - def _validate_database(self): - """Validates database name - - Raises: - DatasetError: If the dataset name does not conform to naming constraints. 
- """ - if not re.fullmatch(self._NAMING_REGEX, self.database): - raise DatasetError("database does not conform to naming") - - def _validate_catalog(self): - """Validates catalog name - - Raises: - DatasetError: If the catalog name does not conform to naming constraints. - """ - if self.catalog: - if not re.fullmatch(self._NAMING_REGEX, self.catalog): - raise DatasetError("catalog does not conform to naming") - - def _validate_write_mode(self): - """Validates the write mode - - Raises: - DatasetError: If an invalid `write_mode` is passed. - """ - if ( - self.write_mode is not None - and self.write_mode not in self._VALID_WRITE_MODES - ): - valid_modes = ", ".join(self._VALID_WRITE_MODES) - raise DatasetError( - f"Invalid `write_mode` provided: {self.write_mode}. " - f"`write_mode` must be one of: {valid_modes}" - ) - - def _validate_dataframe_type(self): - """Validates the dataframe type - - Raises: - DatasetError: If an invalid `dataframe_type` is passed - """ - if self.dataframe_type not in self._VALID_DATAFRAME_TYPES: - valid_types = ", ".join(self._VALID_DATAFRAME_TYPES) - raise DatasetError(f"`dataframe_type` must be one of {valid_types}") - - def _validate_primary_key(self): - """Validates the primary key of the table - - Raises: - DatasetError: If no `primary_key` is specified. - """ - if self.primary_key is None or len(self.primary_key) == 0: - if self.write_mode == "upsert": - raise DatasetError( - f"`primary_key` must be provided for" - f"`write_mode` {self.write_mode}" - ) - - def full_table_location(self) -> str | None: - """Returns the full table location - - Returns: - str | None : table location in the format catalog.database.table or None if database and table aren't defined - """ - full_table_location = None - if self.catalog and self.database and self.table: - full_table_location = f"`{self.catalog}`.`{self.database}`.`{self.table}`" - elif self.database and self.table: - full_table_location = f"`{self.database}`.`{self.table}`" - return full_table_location - - def schema(self) -> StructType | None: - """Returns the Spark schema of the table if it exists - - Returns: - StructType: - """ - schema = None - try: - if self.json_schema is not None: - schema = StructType.fromJson(self.json_schema) - except (KeyError, ValueError) as exc: - raise DatasetError(exc) from exc - return schema class ManagedTableDataset(AbstractVersionedDataset): From 7091986d856ddac812851bd58fd970e475c9b948 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Mon, 26 Aug 2024 23:11:29 +0530 Subject: [PATCH 05/51] refactored ManagedTableDataset using BaseTableDataset Signed-off-by: Minura Punchihewa --- .../databricks/managed_table_dataset.py | 162 +----------------- 1 file changed, 8 insertions(+), 154 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 6a8a02d55..da8215f35 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -35,7 +35,7 @@ class ManagedTable(BaseTable): owner_group: str | None -class ManagedTableDataset(AbstractVersionedDataset): +class ManagedTableDataset(BaseTableDataset): """``ManagedTableDataset`` loads and saves data into managed delta tables on Databricks. Load and save can be in Spark or Pandas dataframes, specified in dataframe_type. 
When saving data, you can specify one of three modes: overwrite(default), append, @@ -92,13 +92,6 @@ class ManagedTableDataset(AbstractVersionedDataset): >>> reloaded = dataset.load() >>> assert Row(name="Bob", age=12) in reloaded.take(4) """ - - # this dataset cannot be used with ``ParallelRunner``, - # therefore it has the attribute ``_SINGLE_PROCESS = True`` - # for parallelism within a Spark pipeline please consider - # using ``ThreadRunner`` instead - _SINGLE_PROCESS = True - def __init__( # noqa: PLR0913 self, *, @@ -149,81 +142,20 @@ def __init__( # noqa: PLR0913 Raises: DatasetError: Invalid configuration supplied (through ManagedTable validation) """ - - self._table = ManagedTable( - database=database, - catalog=catalog, + super().__init__( table=table, + catalog=catalog, + database=database, write_mode=write_mode, dataframe_type=dataframe_type, + version=version, + schema=schema, + partition_columns=partition_columns, + metadata=metadata, primary_key=primary_key, owner_group=owner_group, - partition_columns=partition_columns, - json_schema=schema, ) - self._version = version - self.metadata = metadata - - super().__init__( - filepath=None, # type: ignore[arg-type] - version=version, - exists_function=self._exists, # type: ignore[arg-type] - ) - - def _load(self) -> DataFrame | pd.DataFrame: - """Loads the version of data in the format defined in the init - (spark|pandas dataframe) - - Raises: - VersionNotFoundError: if the version defined in - the init doesn't exist - - Returns: - Union[DataFrame, pd.DataFrame]: Returns a dataframe - in the format defined in the init - """ - if self._version and self._version.load >= 0: - try: - data = ( - _get_spark() - .read.format("delta") - .option("versionAsOf", self._version.load) - .table(self._table.full_table_location()) - ) - except Exception as exc: - raise VersionNotFoundError(self._version.load) from exc - else: - data = _get_spark().table(self._table.full_table_location()) - if self._table.dataframe_type == "pandas": - data = data.toPandas() - return data - - def _save_append(self, data: DataFrame) -> None: - """Saves the data to the table by appending it - to the location defined in the init - - Args: - data (DataFrame): the Spark dataframe to append to the table - """ - data.write.format("delta").mode("append").saveAsTable( - self._table.full_table_location() or "" - ) - - def _save_overwrite(self, data: DataFrame) -> None: - """Overwrites the data in the table with the data provided. - (this is the default save mode) - - Args: - data (DataFrame): the Spark dataframe to overwrite the table with. - """ - delta_table = data.write.format("delta") - if self._table.write_mode == "overwrite": - delta_table = delta_table.mode("overwrite").option( - "overwriteSchema", "true" - ) - delta_table.saveAsTable(self._table.full_table_location() or "") - def _save_upsert(self, update_data: DataFrame) -> None: """Upserts the data by joining on primary_key columns or column. If table doesn't exist at save, the data is inserted to a new table. @@ -263,81 +195,3 @@ def _save_upsert(self, update_data: DataFrame) -> None: else: self._save_append(update_data) - def _save(self, data: DataFrame | pd.DataFrame) -> None: - """Saves the data based on the write_mode and dataframe_type in the init. - If write_mode is pandas, Spark dataframe is created first. - If schema is provided, data is matched to schema before saving - (columns will be sorted and truncated). 
- - Args: - data (Any): Spark or pandas dataframe to save to the table location - """ - if self._table.write_mode is None: - raise DatasetError( - "'save' can not be used in read-only mode. " - "Change 'write_mode' value to `overwrite`, `upsert` or `append`." - ) - # filter columns specified in schema and match their ordering - schema = self._table.schema() - if schema: - cols = schema.fieldNames() - if self._table.dataframe_type == "pandas": - data = _get_spark().createDataFrame( - data.loc[:, cols], schema=self._table.schema() - ) - else: - data = data.select(*cols) - elif self._table.dataframe_type == "pandas": - data = _get_spark().createDataFrame(data) - if self._table.write_mode == "overwrite": - self._save_overwrite(data) - elif self._table.write_mode == "upsert": - self._save_upsert(data) - elif self._table.write_mode == "append": - self._save_append(data) - - def _describe(self) -> dict[str, str | list | None]: - """Returns a description of the instance of ManagedTableDataset - - Returns: - Dict[str, str]: Dict with the details of the dataset - """ - return { - "catalog": self._table.catalog, - "database": self._table.database, - "table": self._table.table, - "write_mode": self._table.write_mode, - "dataframe_type": self._table.dataframe_type, - "primary_key": self._table.primary_key, - "version": str(self._version), - "owner_group": self._table.owner_group, - "partition_columns": self._table.partition_columns, - } - - def _exists(self) -> bool: - """Checks to see if the table exists - - Returns: - bool: boolean of whether the table defined - in the dataset instance exists in the Spark session - """ - if self._table.catalog: - try: - _get_spark().sql(f"USE CATALOG `{self._table.catalog}`") - except (ParseException, AnalysisException) as exc: - logger.warning( - "catalog %s not found or unity not enabled. 
Error message: %s", - self._table.catalog, - exc, - ) - try: - return ( - _get_spark() - .sql(f"SHOW TABLES IN `{self._table.database}`") - .filter(f"tableName = '{self._table.table}'") - .count() - > 0 - ) - except (ParseException, AnalysisException) as exc: - logger.warning("error occured while trying to find table: %s", exc) - return False From 081d5d1061c0fe991d326956bce6c029c56a4af1 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Mon, 26 Aug 2024 23:14:20 +0530 Subject: [PATCH 06/51] removed primary_key attr from BaseTable Signed-off-by: Minura Punchihewa --- kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index 0c60b742a..cb6774633 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -43,7 +43,6 @@ class BaseTable: table: str write_mode: str | None dataframe_type: str - primary_key: str | list[str] | None owner_group: str | None partition_columns: str | list[str] | None json_schema: dict[str, Any] | None = None From e2cd580905f9a89e21e0fd3cd03a2d955e5b5465 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Mon, 26 Aug 2024 23:17:51 +0530 Subject: [PATCH 07/51] updated the format attrs of ManagedTable Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/managed_table_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index da8215f35..2c1afb0fb 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -31,8 +31,9 @@ class ManagedTable(BaseTable): """Stores the definition of a managed table""" _VALID_WRITE_MODES = ["overwrite", "upsert", "append"] + _VALID_FORMATS = ["delta"] + format: str = "delta" primary_key: str | list[str] | None - owner_group: str | None class ManagedTableDataset(BaseTableDataset): @@ -194,4 +195,3 @@ def _save_upsert(self, update_data: DataFrame) -> None: _get_spark().sql(upsert_sql) else: self._save_append(update_data) - From 33814be4c300952490fa4f0707ffbcbab3bb65a3 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Mon, 26 Aug 2024 23:24:16 +0530 Subject: [PATCH 08/51] implemented the _load() method of ManagedTableDataset Signed-off-by: Minura Punchihewa --- .../databricks/managed_table_dataset.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 2c1afb0fb..afb11a0a0 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -157,6 +157,35 @@ def __init__( # noqa: PLR0913 owner_group=owner_group, ) + self._version = version + + def _load(self) -> DataFrame | pd.DataFrame: + """Loads the version of data in the format defined in the init + (spark|pandas dataframe) + + Raises: + VersionNotFoundError: if the version defined in + the init doesn't exist + + Returns: + Union[DataFrame, pd.DataFrame]: Returns a dataframe + in the format defined in the init + """ + if self._version and self._version.load >= 0: + try: + data = ( + _get_spark() + 
.read.format("delta") + .option("versionAsOf", self._version.load) + .table(self._table.full_table_location()) + ) + except Exception as exc: + raise VersionNotFoundError(self._version.load) from exc + else: + data = super()._load() + + return data + def _save_upsert(self, update_data: DataFrame) -> None: """Upserts the data by joining on primary_key columns or column. If table doesn't exist at save, the data is inserted to a new table. From 6f91dc8258fdf98662bf36eb6472c4b426ba2686 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Mon, 26 Aug 2024 23:26:05 +0530 Subject: [PATCH 09/51] removed unnecessary imports Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/managed_table_dataset.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index afb11a0a0..c46fdeacf 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -4,23 +4,19 @@ from __future__ import annotations import logging -import re from dataclasses import dataclass from typing import Any import pandas as pd from kedro.io.core import ( - AbstractVersionedDataset, DatasetError, Version, VersionNotFoundError, ) from pyspark.sql import DataFrame -from pyspark.sql.types import StructType -from pyspark.sql.utils import AnalysisException, ParseException -from kedro_datasets.spark.spark_dataset import _get_spark from kedro_datasets.databricks._base_table_dataset import BaseTable, BaseTableDataset +from kedro_datasets.spark.spark_dataset import _get_spark logger = logging.getLogger(__name__) pd.DataFrame.iteritems = pd.DataFrame.items From 410cf4a139c0c95fca11fe28c7a80a64b9169d46 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Tue, 27 Aug 2024 16:32:10 +0530 Subject: [PATCH 10/51] reorganized the attrs of BaseTable Signed-off-by: Minura Punchihewa --- .../databricks/_base_table_dataset.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index cb6774633..d80729135 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -5,15 +5,13 @@ import logging import re -from dataclasses import dataclass -from typing import Any +from dataclasses import dataclass, field +from typing import Any, ClassVar, List import pandas as pd from kedro.io.core import ( AbstractVersionedDataset, - DatasetError, - Version, - VersionNotFoundError, + DatasetError ) from pyspark.sql import DataFrame from pyspark.sql.types import StructType @@ -25,18 +23,18 @@ pd.DataFrame.iteritems = pd.DataFrame.items -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class BaseTable: """Stores the definition of a base table. Acts as a base class for `ManagedTable` and `ExternalTable`. 
""" - # regex for tables, catalogs and schemas - _NAMING_REGEX = r"\b[0-9a-zA-Z_-]{1,}\b" - _VALID_WRITE_MODES = ["overwrite", "append"] - _VALID_DATAFRAME_TYPES = ["spark", "pandas"] - _VALID_FORMATS = ["delta", "parquet", "csv"] + _NAMING_REGEX: ClassVar[str] = r"\b[0-9a-zA-Z_-]{1,}\b" + _VALID_WRITE_MODES: ClassVar[List[str]] = field(default=["overwrite", "append"]) + _VALID_DATAFRAME_TYPES: ClassVar[List[str]] = field(default=["spark", "pandas"]) + _VALID_FORMATS: ClassVar[List[str]] = field(default=["delta", "parquet", "csv"]) + format: str database: str catalog: str | None @@ -188,7 +186,7 @@ def __init__( # noqa: PLR0913 table: str, catalog: str | None = None, database: str = "default", - write_mode: str | None = None, + write_mode: str | None = "overwrite", dataframe_type: str = "spark", # the following parameters are used by project hooks # to create or update table properties @@ -245,14 +243,18 @@ def __init__( # noqa: PLR0913 super().__init__( filepath=None, # type: ignore[arg-type] + version=kwargs.get("version"), exists_function=self._exists, # type: ignore[arg-type] ) - def _create_table(self, **kwargs: Any) -> None: + def _create_table(self, **kwargs: Any) -> BaseTable: """Creates a table object and assign it to the _table attribute. Args: **kwargs: Arguments to pass to the table object. + + Returns: + BaseTable: the table object. """ raise NotImplementedError From 1bf0b8b21677b556468baf50866abe0661f1f955 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Tue, 27 Aug 2024 16:33:20 +0530 Subject: [PATCH 11/51] implemented create_table() in ManagedTableDataset Signed-off-by: Minura Punchihewa --- .../databricks/managed_table_dataset.py | 36 +++++++++++++++---- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index c46fdeacf..d576946df 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -4,8 +4,8 @@ from __future__ import annotations import logging -from dataclasses import dataclass -from typing import Any +from dataclasses import dataclass, field +from typing import Any, ClassVar, List import pandas as pd from kedro.io.core import ( @@ -22,13 +22,13 @@ pd.DataFrame.iteritems = pd.DataFrame.items -@dataclass(frozen=True) +@dataclass(frozen=True, kw_only=True) class ManagedTable(BaseTable): """Stores the definition of a managed table""" - _VALID_WRITE_MODES = ["overwrite", "upsert", "append"] - _VALID_FORMATS = ["delta"] - format: str = "delta" + _VALID_WRITE_MODES: ClassVar[List[str]] = field(default=["overwrite", "upsert", "append"]) + _VALID_FORMATS: ClassVar[List[str]] = field(default=["delta"]) + primary_key: str | list[str] | None @@ -95,7 +95,7 @@ def __init__( # noqa: PLR0913 table: str, catalog: str | None = None, database: str = "default", - write_mode: str | None = None, + write_mode: str | None = "overwrite", dataframe_type: str = "spark", primary_key: str | list[str] | None = None, version: Version | None = None, @@ -155,6 +155,28 @@ def __init__( # noqa: PLR0913 self._version = version + def _create_table(self, **kwargs: Any) -> ManagedTable: + """Creates a new ManagedTable instance with the provided kwargs. 
+ + Args: + **kwargs: the parameters to create the table with + + Returns: + ManagedTable: the new ManagedTable instance + """ + return ManagedTable( + table=kwargs["table"], + catalog=kwargs["catalog"], + database=kwargs["database"], + write_mode=kwargs["write_mode"], + dataframe_type=kwargs["dataframe_type"], + json_schema=kwargs["schema"], + partition_columns=kwargs["partition_columns"], + owner_group=kwargs["owner_group"], + primary_key=kwargs["primary_key"], + format="delta" + ) + def _load(self) -> DataFrame | pd.DataFrame: """Loads the version of data in the format defined in the init (spark|pandas dataframe) From ee7e07302b658ed2dc8d60c92efd9a8694114da5 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Tue, 27 Aug 2024 23:44:31 +0530 Subject: [PATCH 12/51] added the version attr to BaseTableDataset and updated load() Signed-off-by: Minura Punchihewa --- .../databricks/_base_table_dataset.py | 27 ++++++++++++--- .../databricks/managed_table_dataset.py | 34 ++----------------- 2 files changed, 25 insertions(+), 36 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index d80729135..d60e67e7f 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -11,6 +11,8 @@ import pandas as pd from kedro.io.core import ( AbstractVersionedDataset, + Version, + VersionNotFoundError, DatasetError ) from pyspark.sql import DataFrame @@ -35,10 +37,10 @@ class BaseTable: _VALID_DATAFRAME_TYPES: ClassVar[List[str]] = field(default=["spark", "pandas"]) _VALID_FORMATS: ClassVar[List[str]] = field(default=["delta", "parquet", "csv"]) - format: str database: str catalog: str | None table: str + format: str write_mode: str | None dataframe_type: str owner_group: str | None @@ -188,6 +190,7 @@ def __init__( # noqa: PLR0913 database: str = "default", write_mode: str | None = "overwrite", dataframe_type: str = "spark", + version: Version | None = None, # the following parameters are used by project hooks # to create or update table properties schema: dict[str, Any] | None = None, @@ -239,11 +242,12 @@ def __init__( # noqa: PLR0913 ) self.metadata = metadata + self._version = version self.kwargs = kwargs super().__init__( filepath=None, # type: ignore[arg-type] - version=kwargs.get("version"), + version=version, exists_function=self._exists, # type: ignore[arg-type] ) @@ -259,14 +263,29 @@ def _create_table(self, **kwargs: Any) -> BaseTable: raise NotImplementedError def _load(self) -> DataFrame | pd.DataFrame: - """Loads the data from the table location defined in the init. 
+ """Loads the version of data in the format defined in the init (spark|pandas dataframe) + Raises: + VersionNotFoundError: if the version defined in + the init doesn't exist + Returns: Union[DataFrame, pd.DataFrame]: Returns a dataframe in the format defined in the init """ - data = _get_spark().table(self._table.full_table_location()) + if self._version and self._version.load >= 0: + try: + data = ( + _get_spark() + .read.format("delta") + .option("versionAsOf", self._version.load) + .table(self._table.full_table_location()) + ) + except Exception as exc: + raise VersionNotFoundError(self._version.load) from exc + else: + data = _get_spark().table(self._table.full_table_location()) if self._table.dataframe_type == "pandas": data = data.toPandas() return data diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index d576946df..bfcc68185 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -10,8 +10,7 @@ import pandas as pd from kedro.io.core import ( DatasetError, - Version, - VersionNotFoundError, + Version ) from pyspark.sql import DataFrame @@ -33,7 +32,7 @@ class ManagedTable(BaseTable): class ManagedTableDataset(BaseTableDataset): - """``ManagedTableDataset`` loads and saves data into managed delta tables on Databricks. + """``ManagedTableDataset`` loads and saves data into managed delta tables in Databricks. Load and save can be in Spark or Pandas dataframes, specified in dataframe_type. When saving data, you can specify one of three modes: overwrite(default), append, or upsert. Upsert requires you to specify the primary_column parameter which @@ -153,8 +152,6 @@ def __init__( # noqa: PLR0913 owner_group=owner_group, ) - self._version = version - def _create_table(self, **kwargs: Any) -> ManagedTable: """Creates a new ManagedTable instance with the provided kwargs. @@ -177,33 +174,6 @@ def _create_table(self, **kwargs: Any) -> ManagedTable: format="delta" ) - def _load(self) -> DataFrame | pd.DataFrame: - """Loads the version of data in the format defined in the init - (spark|pandas dataframe) - - Raises: - VersionNotFoundError: if the version defined in - the init doesn't exist - - Returns: - Union[DataFrame, pd.DataFrame]: Returns a dataframe - in the format defined in the init - """ - if self._version and self._version.load >= 0: - try: - data = ( - _get_spark() - .read.format("delta") - .option("versionAsOf", self._version.load) - .table(self._table.full_table_location()) - ) - except Exception as exc: - raise VersionNotFoundError(self._version.load) from exc - else: - data = super()._load() - - return data - def _save_upsert(self, update_data: DataFrame) -> None: """Upserts the data by joining on primary_key columns or column. If table doesn't exist at save, the data is inserted to a new table. 
From d87f4ed07aea4585c2f0e8aa77f1866c9a0e8029 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Thu, 29 Aug 2024 16:28:30 +0530 Subject: [PATCH 13/51] updated the base and managed datasets with all attrs Signed-off-by: Minura Punchihewa --- .../databricks/_base_table_dataset.py | 85 +++++++++++++++-- .../databricks/managed_table_dataset.py | 91 +++++++------------ 2 files changed, 109 insertions(+), 67 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index d60e67e7f..ffe3ea75f 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -33,18 +33,19 @@ class BaseTable: """ # regex for tables, catalogs and schemas _NAMING_REGEX: ClassVar[str] = r"\b[0-9a-zA-Z_-]{1,}\b" - _VALID_WRITE_MODES: ClassVar[List[str]] = field(default=["overwrite", "append"]) + _VALID_WRITE_MODES: ClassVar[List[str]] = field(default=["overwrite", "upsert", "append"]) _VALID_DATAFRAME_TYPES: ClassVar[List[str]] = field(default=["spark", "pandas"]) _VALID_FORMATS: ClassVar[List[str]] = field(default=["delta", "parquet", "csv"]) database: str catalog: str | None table: str - format: str write_mode: str | None dataframe_type: str + primary_key: str | list[str] | None owner_group: str | None partition_columns: str | list[str] | None + format: str = "delta", json_schema: dict[str, Any] | None = None def __post_init__(self): @@ -188,16 +189,17 @@ def __init__( # noqa: PLR0913 table: str, catalog: str | None = None, database: str = "default", + format: str = "delta", write_mode: str | None = "overwrite", dataframe_type: str = "spark", + primary_key: str | list[str] | None = None, version: Version | None = None, # the following parameters are used by project hooks # to create or update table properties schema: dict[str, Any] | None = None, partition_columns: list[str] | None = None, owner_group: str | None = None, - metadata: dict[str, Any] | None = None, - **kwargs: Any, + metadata: dict[str, Any] | None = None ) -> None: """Creates a new instance of ``BaseTableDataset``. @@ -214,6 +216,8 @@ def __init__( # noqa: PLR0913 Defaults to None. dataframe_type: "pandas" or "spark" dataframe. Defaults to "spark". + primary_key: the primary key of the table. + Can be in the form of a list. Defaults to None. schema: the schema of the table in JSON form. Dataframes will be truncated to match the schema if provided. 
Used by the hooks to create the table if the schema is provided @@ -233,17 +237,17 @@ def __init__( # noqa: PLR0913 table=table, catalog=catalog, database=database, + format=format, write_mode=write_mode, dataframe_type=dataframe_type, - schema=schema, + primary_key=primary_key, + json_schema=schema, partition_columns=partition_columns, owner_group=owner_group, - **kwargs, ) self.metadata = metadata self._version = version - self.kwargs = kwargs super().__init__( filepath=None, # type: ignore[arg-type] @@ -251,11 +255,32 @@ def __init__( # noqa: PLR0913 exists_function=self._exists, # type: ignore[arg-type] ) - def _create_table(self, **kwargs: Any) -> BaseTable: + def _create_table( + self, + table: str, + catalog: str | None, + database: str, + format: str, + write_mode: str | None, + dataframe_type: str, + primary_key: str | list[str] | None, + json_schema: dict[str, Any] | None, + partition_columns: list[str] | None, + owner_group: str | None + ) -> BaseTable: """Creates a table object and assign it to the _table attribute. Args: - **kwargs: Arguments to pass to the table object. + table: The name of the table. + catalog: The catalog of the table. + database: The database of the table. + format: The format of the table. + write_mode: The write mode for the table. + dataframe_type: The type of dataframe. + primary_key: The primary key of the table. + json_schema: The JSON schema of the table. + partition_columns: The partition columns of the table. + owner_group: The owner group of the table. Returns: BaseTable: the table object. @@ -352,6 +377,45 @@ def _save_overwrite(self, data: DataFrame) -> None: ) table.saveAsTable(self._table.full_table_location() or "") + def _save_upsert(self, update_data: DataFrame) -> None: + """Upserts the data by joining on primary_key columns or column. + If table doesn't exist at save, the data is inserted to a new table. + + Args: + update_data (DataFrame): the Spark dataframe to upsert + """ + if self._exists(): + base_data = _get_spark().table(self._table.full_table_location()) + base_columns = base_data.columns + update_columns = update_data.columns + + if set(update_columns) != set(base_columns): + raise DatasetError( + f"Upsert requires tables to have identical columns. " + f"Delta table {self._table.full_table_location()} " + f"has columns: {base_columns}, whereas " + f"dataframe has columns {update_columns}" + ) + + where_expr = "" + if isinstance(self._table.primary_key, str): + where_expr = ( + f"base.{self._table.primary_key}=update.{self._table.primary_key}" + ) + elif isinstance(self._table.primary_key, list): + where_expr = " AND ".join( + f"base.{col}=update.{col}" for col in self._table.primary_key + ) + + update_data.createOrReplaceTempView("update") + _get_spark().conf.set("fullTableAddress", self._table.full_table_location()) + _get_spark().conf.set("whereExpr", where_expr) + upsert_sql = """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr} + WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *""" + _get_spark().sql(upsert_sql) + else: + self._save_append(update_data) + def _describe(self) -> dict[str, str | list | None]: """Returns a description of the instance of the dataset. 
@@ -364,9 +428,10 @@ def _describe(self) -> dict[str, str | list | None]: "table": self._table.table, "write_mode": self._table.write_mode, "dataframe_type": self._table.dataframe_type, + "primary_key": self._table.primary_key, + "version": str(self._version), "owner_group": self._table.owner_group, "partition_columns": self._table.partition_columns, - **self.kwargs } def _exists(self) -> bool: diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index bfcc68185..21d4fc520 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -9,13 +9,10 @@ import pandas as pd from kedro.io.core import ( - DatasetError, Version ) -from pyspark.sql import DataFrame from kedro_datasets.databricks._base_table_dataset import BaseTable, BaseTableDataset -from kedro_datasets.spark.spark_dataset import _get_spark logger = logging.getLogger(__name__) pd.DataFrame.iteritems = pd.DataFrame.items @@ -25,11 +22,8 @@ class ManagedTable(BaseTable): """Stores the definition of a managed table""" - _VALID_WRITE_MODES: ClassVar[List[str]] = field(default=["overwrite", "upsert", "append"]) _VALID_FORMATS: ClassVar[List[str]] = field(default=["delta"]) - primary_key: str | list[str] | None - class ManagedTableDataset(BaseTableDataset): """``ManagedTableDataset`` loads and saves data into managed delta tables in Databricks. @@ -152,63 +146,46 @@ def __init__( # noqa: PLR0913 owner_group=owner_group, ) - def _create_table(self, **kwargs: Any) -> ManagedTable: - """Creates a new ManagedTable instance with the provided kwargs. + def _create_table( + self, + table: str, + catalog: str | None, + database: str, + format: str, + write_mode: str | None, + dataframe_type: str, + primary_key: str | list[str] | None, + json_schema: dict[str, Any] | None, + partition_columns: list[str] | None, + owner_group: str | None + ) -> ManagedTable: + """Creates a new ManagedTable instance with the provided attributes. Args: - **kwargs: the parameters to create the table with + table: The name of the table. + catalog: The catalog of the table. + database: The database of the table. + format: The format of the table. + write_mode: The write mode for the table. + dataframe_type: The type of dataframe. + primary_key: The primary key of the table. + json_schema: The JSON schema of the table. + partition_columns: The partition columns of the table. + owner_group: The owner group of the table. Returns: ManagedTable: the new ManagedTable instance """ return ManagedTable( - table=kwargs["table"], - catalog=kwargs["catalog"], - database=kwargs["database"], - write_mode=kwargs["write_mode"], - dataframe_type=kwargs["dataframe_type"], - json_schema=kwargs["schema"], - partition_columns=kwargs["partition_columns"], - owner_group=kwargs["owner_group"], - primary_key=kwargs["primary_key"], - format="delta" + table=table, + catalog=catalog, + database=database, + write_mode=write_mode, + dataframe_type=dataframe_type, + json_schema=json_schema, + partition_columns=partition_columns, + owner_group=owner_group, + primary_key=primary_key, + format=format ) - def _save_upsert(self, update_data: DataFrame) -> None: - """Upserts the data by joining on primary_key columns or column. - If table doesn't exist at save, the data is inserted to a new table. 
- - Args: - update_data (DataFrame): the Spark dataframe to upsert - """ - if self._exists(): - base_data = _get_spark().table(self._table.full_table_location()) - base_columns = base_data.columns - update_columns = update_data.columns - - if set(update_columns) != set(base_columns): - raise DatasetError( - f"Upsert requires tables to have identical columns. " - f"Delta table {self._table.full_table_location()} " - f"has columns: {base_columns}, whereas " - f"dataframe has columns {update_columns}" - ) - - where_expr = "" - if isinstance(self._table.primary_key, str): - where_expr = ( - f"base.{self._table.primary_key}=update.{self._table.primary_key}" - ) - elif isinstance(self._table.primary_key, list): - where_expr = " AND ".join( - f"base.{col}=update.{col}" for col in self._table.primary_key - ) - - update_data.createOrReplaceTempView("update") - _get_spark().conf.set("fullTableAddress", self._table.full_table_location()) - _get_spark().conf.set("whereExpr", where_expr) - upsert_sql = """MERGE INTO ${fullTableAddress} base USING update ON ${whereExpr} - WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *""" - _get_spark().sql(upsert_sql) - else: - self._save_append(update_data) From a4feeff0cf72be6e25fdfae4cc27319c4ed74cb1 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Thu, 29 Aug 2024 16:44:17 +0530 Subject: [PATCH 14/51] updated the supported formats Signed-off-by: Minura Punchihewa --- kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index ffe3ea75f..89674872c 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -35,7 +35,7 @@ class BaseTable: _NAMING_REGEX: ClassVar[str] = r"\b[0-9a-zA-Z_-]{1,}\b" _VALID_WRITE_MODES: ClassVar[List[str]] = field(default=["overwrite", "upsert", "append"]) _VALID_DATAFRAME_TYPES: ClassVar[List[str]] = field(default=["spark", "pandas"]) - _VALID_FORMATS: ClassVar[List[str]] = field(default=["delta", "parquet", "csv"]) + _VALID_FORMATS: ClassVar[List[str]] = field(default=["delta", "parquet", "csv", "json", "orc", "avro", "text"]) database: str catalog: str | None From 0db5639db57753ac4e55ca9b6c2ac0d8895b3e64 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Thu, 29 Aug 2024 16:44:50 +0530 Subject: [PATCH 15/51] added external table and external table dataset implementations Signed-off-by: Minura Punchihewa --- .../databricks/external_table_dataset.py | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 kedro-datasets/kedro_datasets/databricks/external_table_dataset.py diff --git a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py new file mode 100644 index 000000000..cb29cadcd --- /dev/null +++ b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py @@ -0,0 +1,69 @@ +"""``ExternalTableDataset`` implementation to access external tables +in Databricks. 
+""" +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any + +import pandas as pd + +from kedro_datasets.databricks._base_table_dataset import BaseTable, BaseTableDataset + +logger = logging.getLogger(__name__) +pd.DataFrame.iteritems = pd.DataFrame.items + + +@dataclass(frozen=True, kw_only=True) +class ExternalTable(BaseTable): + """Stores the definition of an external table""" + + +class ExternalTableDataset(BaseTableDataset): + """``ExternalTableDataset`` loads and saves data into external tables in Databricks. + Load and save can be in Spark or Pandas dataframes, specified in dataframe_type. + """ + + def _create_table( + self, + table: str, + catalog: str | None, + database: str, + format: str, + write_mode: str | None, + dataframe_type: str, + primary_key: str | list[str] | None, + json_schema: dict[str, Any] | None, + partition_columns: list[str] | None, + owner_group: str | None + ) -> ExternalTable: + """Creates a new ExternalTable instance with the provided attributes. + + Args: + table: The name of the table. + catalog: The catalog of the table. + database: The database of the table. + format: The format of the table. + write_mode: The write mode for the table. + dataframe_type: The type of dataframe. + primary_key: The primary key of the table. + json_schema: The JSON schema of the table. + partition_columns: The partition columns of the table. + owner_group: The owner group of the table. + + Returns: + ExternalTable: the new ExternalTable instance + """ + return ExternalTable( + table=table, + catalog=catalog, + database=database, + write_mode=write_mode, + dataframe_type=dataframe_type, + json_schema=json_schema, + partition_columns=partition_columns, + owner_group=owner_group, + primary_key=primary_key, + format=format + ) \ No newline at end of file From e19eba48c571883bddc1ec4c9e3d49e98253afb3 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Thu, 29 Aug 2024 18:04:58 +0530 Subject: [PATCH 16/51] added a val func to check for format when using upsert mode Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/_base_table_dataset.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index 89674872c..c5a5e037b 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -118,6 +118,18 @@ def _validate_write_mode(self): f"Invalid `write_mode` provided: {self.write_mode}. " f"`write_mode` must be one of: {valid_modes}" ) + + def _validate_format_for_upsert(self) -> None: + """Validates the format for upserts. + + Raises: + DatasetError: If the format is not supported for upserts, i.e. not 'delta'. + """ + if self.write_mode == "upsert" and self.format != "delta": + raise DatasetError( + f"Format '{self.format}' is not supported for upserts. " + f"Please use 'delta' format." + ) def _validate_dataframe_type(self): """Validates the dataframe type. 
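
Illustrative sketch (not part of the patch series): the validator added in the
patch above ties the "upsert" write mode to the Delta format, since the upsert
implementation shown earlier relies on a MERGE INTO statement. The rule it
encodes is roughly the following; the helper name is hypothetical, not the
library's API:

    def check_upsert_format(write_mode: str, table_format: str) -> None:
        # Upserts are implemented with Delta's MERGE INTO, so reject other formats.
        if write_mode == "upsert" and table_format != "delta":
            raise ValueError(
                f"Format '{table_format}' is not supported for upserts. Please use 'delta'."
            )

    check_upsert_format("append", "parquet")    # allowed
    check_upsert_format("upsert", "delta")      # allowed
    check_upsert_format("upsert", "parquet")    # raises ValueError
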
From 2c46978e9ea05811e56bf73674373db0e7b01006 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Wed, 4 Sep 2024 00:20:50 +0530 Subject: [PATCH 17/51] imported the ExternalTableDataset into the main pkg Signed-off-by: Minura Punchihewa --- kedro-datasets/kedro_datasets/databricks/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/__init__.py b/kedro-datasets/kedro_datasets/databricks/__init__.py index 7f7ad7235..5bd62be7c 100644 --- a/kedro-datasets/kedro_datasets/databricks/__init__.py +++ b/kedro-datasets/kedro_datasets/databricks/__init__.py @@ -6,8 +6,12 @@ # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 ManagedTableDataset: Any +ExternalTableDataset: Any __getattr__, __dir__, __all__ = lazy.attach( __name__, - submod_attrs={"managed_table_dataset": ["ManagedTableDataset"]}, + submod_attrs={ + "managed_table_dataset": ["ManagedTableDataset"], + "external_table_dataset": ["ExternalTableDataset"], + }, ) From 2777ea1847a82ec8d9582a5d6e4f0d6f54a21726 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Wed, 4 Sep 2024 00:30:18 +0530 Subject: [PATCH 18/51] improved the docstrings in the code Signed-off-by: Minura Punchihewa --- .../databricks/_base_table_dataset.py | 30 +++++++++---------- .../databricks/external_table_dataset.py | 4 +-- .../databricks/managed_table_dataset.py | 4 +-- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index c5a5e037b..5cda3a44b 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -29,7 +29,7 @@ class BaseTable: """Stores the definition of a base table. - Acts as a base class for `ManagedTable` and `ExternalTable`. + Acts as the base class for `ManagedTable` and `ExternalTable`. """ # regex for tables, catalogs and schemas _NAMING_REGEX: ClassVar[str] = r"\b[0-9a-zA-Z_-]{1,}\b" @@ -135,7 +135,7 @@ def _validate_dataframe_type(self): """Validates the dataframe type. Raises: - DatasetError: If an invalid `dataframe_type` is passed + DatasetError: If an invalid `dataframe_type` is passed. """ if self.dataframe_type not in self._VALID_DATAFRAME_TYPES: valid_types = ", ".join(self._VALID_DATAFRAME_TYPES) @@ -158,7 +158,7 @@ def full_table_location(self) -> str | None: """Returns the full table location. Returns: - str | None : table location in the format catalog.database.table or None if database and table aren't defined + str | None : table location in the format catalog.database.table or None if database and table aren't defined. """ full_table_location = None if self.catalog and self.database and self.table: @@ -171,7 +171,7 @@ def schema(self) -> StructType | None: """Returns the Spark schema of the table if it exists. Returns: - StructType: + StructType: the schema of the table. """ schema = None try: @@ -186,13 +186,13 @@ class BaseTableDataset(AbstractVersionedDataset): """``BaseTableDataset`` loads and saves data into managed delta tables or external tables on Databricks. Load and save can be in Spark or Pandas dataframes, specified in dataframe_type. - This dataaset is not meant to be used directly. It is a base class for ``ManagedTableDataset`` and ``ExternalTableDataset``. + This dataset is not meant to be used directly. It is a base class for ``ManagedTableDataset`` and ``ExternalTableDataset``. 
""" - # this dataset cannot be used with ``ParallelRunner``, + # datasets that inherit from this class cannot be used with ``ParallelRunner``, # therefore it has the attribute ``_SINGLE_PROCESS = True`` # for parallelism within a Spark pipeline please consider - # using ``ThreadRunner`` instead + # using ``ThreadRunner`` instead. _SINGLE_PROCESS = True def __init__( # noqa: PLR0913 @@ -301,15 +301,15 @@ def _create_table( def _load(self) -> DataFrame | pd.DataFrame: """Loads the version of data in the format defined in the init - (spark|pandas dataframe) + (spark|pandas dataframe). Raises: VersionNotFoundError: if the version defined in - the init doesn't exist + the init doesn't exist. Returns: Union[DataFrame, pd.DataFrame]: Returns a dataframe - in the format defined in the init + in the format defined in the init. """ if self._version and self._version.load >= 0: try: @@ -334,7 +334,7 @@ def _save(self, data: DataFrame | pd.DataFrame) -> None: (columns will be sorted and truncated). Args: - data (Any): Spark or pandas dataframe to save to the table location + data (Any): Spark or pandas dataframe to save to the table location. """ if self._table.write_mode is None: raise DatasetError( @@ -376,8 +376,8 @@ def _save_append(self, data: DataFrame) -> None: ) def _save_overwrite(self, data: DataFrame) -> None: - """Overwrites the data in the table with the data provided. - (this is the default save mode) + """Overwrites the data in the table with the data provided + (this is the default save mode). Args: data (DataFrame): the Spark dataframe to overwrite the table with. @@ -394,7 +394,7 @@ def _save_upsert(self, update_data: DataFrame) -> None: If table doesn't exist at save, the data is inserted to a new table. Args: - update_data (DataFrame): the Spark dataframe to upsert + update_data (DataFrame): the Spark dataframe to upsert. """ if self._exists(): base_data = _get_spark().table(self._table.full_table_location()) @@ -432,7 +432,7 @@ def _describe(self) -> dict[str, str | list | None]: """Returns a description of the instance of the dataset. Returns: - Dict[str, str]: Dict with the details of the dataset + Dict[str, str]: Dict with the details of the dataset. """ return { "catalog": self._table.catalog, diff --git a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py index cb29cadcd..f7590e3d9 100644 --- a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py @@ -17,7 +17,7 @@ @dataclass(frozen=True, kw_only=True) class ExternalTable(BaseTable): - """Stores the definition of an external table""" + """Stores the definition of an external table.""" class ExternalTableDataset(BaseTableDataset): @@ -53,7 +53,7 @@ def _create_table( owner_group: The owner group of the table. Returns: - ExternalTable: the new ExternalTable instance + ExternalTable: The new ExternalTable instance. 
""" return ExternalTable( table=table, diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 21d4fc520..fcc3d8831 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -20,7 +20,7 @@ @dataclass(frozen=True, kw_only=True) class ManagedTable(BaseTable): - """Stores the definition of a managed table""" + """Stores the definition of a managed table.""" _VALID_FORMATS: ClassVar[List[str]] = field(default=["delta"]) @@ -174,7 +174,7 @@ def _create_table( owner_group: The owner group of the table. Returns: - ManagedTable: the new ManagedTable instance + ManagedTable: The new ManagedTable instance. """ return ManagedTable( table=table, From 6e1b17e2f65042374477c413d8e89e351f07c0ed Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Wed, 4 Sep 2024 00:31:05 +0530 Subject: [PATCH 19/51] added format to the _describe() Signed-off-by: Minura Punchihewa --- kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index 5cda3a44b..cd2c6cad7 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -438,6 +438,7 @@ def _describe(self) -> dict[str, str | list | None]: "catalog": self._table.catalog, "database": self._table.database, "table": self._table.table, + "format": self._table.format, "write_mode": self._table.write_mode, "dataframe_type": self._table.dataframe_type, "primary_key": self._table.primary_key, From 039f9e093e9a44c12d4933a08a9017dbfa43eb17 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Wed, 4 Sep 2024 14:36:41 +0530 Subject: [PATCH 20/51] updated the save methods to incorporate partition columns Signed-off-by: Minura Punchihewa --- .../databricks/_base_table_dataset.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index cd2c6cad7..1f2297189 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -371,9 +371,14 @@ def _save_append(self, data: DataFrame) -> None: Args: data (DataFrame): the Spark dataframe to append to the table. """ - data.write.format(self._table.format).mode("append").saveAsTable( - self._table.full_table_location() or "" - ) + if self._table.partition_columns: + data.write.format(self._table.format).mode("append").partitionBy( + *self._table.partition_columns + ).saveAsTable(self._table.full_table_location() or "") + else: + data.write.format(self._table.format).mode("append").saveAsTable( + self._table.full_table_location() or "" + ) def _save_overwrite(self, data: DataFrame) -> None: """Overwrites the data in the table with the data provided @@ -382,12 +387,16 @@ def _save_overwrite(self, data: DataFrame) -> None: Args: data (DataFrame): the Spark dataframe to overwrite the table with. 
""" - table = data.write.format(self._table.format) - if self._table.write_mode == "overwrite": - table = table.mode("overwrite").option( + if self._table.partition_columns: + data.write.format(self._table.format).mode("overwrite").partitionBy( + *self._table.partition_columns + ).option( "overwriteSchema", "true" + ).saveAsTable(self._table.full_table_location() or "") + else: + data.write.format(self._table.format).mode("overwrite").saveAsTable( + self._table.full_table_location() or "" ) - table.saveAsTable(self._table.full_table_location() or "") def _save_upsert(self, update_data: DataFrame) -> None: """Upserts the data by joining on primary_key columns or column. From fb87133d88aca2cef163e7917a76eb5c364f686c Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Wed, 4 Sep 2024 17:26:08 +0530 Subject: [PATCH 21/51] reverted the default write_mode back to None Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/_base_table_dataset.py | 2 +- .../kedro_datasets/databricks/managed_table_dataset.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index 1f2297189..a7d147b99 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -202,7 +202,7 @@ def __init__( # noqa: PLR0913 catalog: str | None = None, database: str = "default", format: str = "delta", - write_mode: str | None = "overwrite", + write_mode: str | None = None, dataframe_type: str = "spark", primary_key: str | list[str] | None = None, version: Version | None = None, diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index fcc3d8831..9daf88ada 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -28,7 +28,7 @@ class ManagedTable(BaseTable): class ManagedTableDataset(BaseTableDataset): """``ManagedTableDataset`` loads and saves data into managed delta tables in Databricks. Load and save can be in Spark or Pandas dataframes, specified in dataframe_type. - When saving data, you can specify one of three modes: overwrite(default), append, + When saving data, you can specify one of three modes: overwrite, append, or upsert. Upsert requires you to specify the primary_column parameter which will be used as part of the join condition. This dataset works best with the databricks kedro starter. 
That starter comes with hooks that allow this @@ -88,7 +88,7 @@ def __init__( # noqa: PLR0913 table: str, catalog: str | None = None, database: str = "default", - write_mode: str | None = "overwrite", + write_mode: str | None = None, dataframe_type: str = "spark", primary_key: str | list[str] | None = None, version: Version | None = None, From b7a5c331e7c8de4026f7b14fe07a6c185f137fd8 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Wed, 4 Sep 2024 17:34:28 +0530 Subject: [PATCH 22/51] extended the _validate_write_mode() func to include formats Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/_base_table_dataset.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index a7d147b99..2b2eeeb51 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -118,13 +118,8 @@ def _validate_write_mode(self): f"Invalid `write_mode` provided: {self.write_mode}. " f"`write_mode` must be one of: {valid_modes}" ) - - def _validate_format_for_upsert(self) -> None: - """Validates the format for upserts. - Raises: - DatasetError: If the format is not supported for upserts, i.e. not 'delta'. - """ + # Upserts are only supported for delta tables. if self.write_mode == "upsert" and self.format != "delta": raise DatasetError( f"Format '{self.format}' is not supported for upserts. " From 577dd91563b5e7b745f54e0233ba070996bb2dd2 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sun, 8 Sep 2024 11:21:59 +0530 Subject: [PATCH 23/51] updated the save() logic to work with single or multiple partition cols Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/_base_table_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index 2b2eeeb51..1fedd2bb3 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -211,7 +211,7 @@ def __init__( # noqa: PLR0913 """Creates a new instance of ``BaseTableDataset``. Args: - table: the name of the table + table: the name of the table. catalog: the name of the catalog in Unity. Defaults to None. database: the name of the database. @@ -227,7 +227,7 @@ def __init__( # noqa: PLR0913 Can be in the form of a list. Defaults to None. schema: the schema of the table in JSON form. Dataframes will be truncated to match the schema if provided. - Used by the hooks to create the table if the schema is provided + Used by the hooks to create the table if the schema is provided. Defaults to None. partition_columns: the columns to use for partitioning the table. Used by the hooks. Defaults to None. 
@@ -368,7 +368,7 @@ def _save_append(self, data: DataFrame) -> None: """ if self._table.partition_columns: data.write.format(self._table.format).mode("append").partitionBy( - *self._table.partition_columns + *self._table.partition_columns if isinstance(self._table.partition_columns, list) else self._table.partition_columns ).saveAsTable(self._table.full_table_location() or "") else: data.write.format(self._table.format).mode("append").saveAsTable( @@ -384,7 +384,7 @@ def _save_overwrite(self, data: DataFrame) -> None: """ if self._table.partition_columns: data.write.format(self._table.format).mode("overwrite").partitionBy( - *self._table.partition_columns + *self._table.partition_columns if isinstance(self._table.partition_columns, list) else self._table.partition_columns ).option( "overwriteSchema", "true" ).saveAsTable(self._table.full_table_location() or "") From 6e729e346ac430c7f1960fff003e4c86b0cd1b44 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sun, 8 Sep 2024 12:59:29 +0530 Subject: [PATCH 24/51] updated the docstrings for the datasets with missing attrs Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/_base_table_dataset.py | 12 ++++++++---- .../databricks/external_table_dataset.py | 4 ++-- .../databricks/managed_table_dataset.py | 10 +++++----- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index 1fedd2bb3..a8f34bfe0 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -216,6 +216,9 @@ def __init__( # noqa: PLR0913 Defaults to None. database: the name of the database. (also referred to as schema). Defaults to "default". + format: the format of the table. + Applicable only for external tables. + Defaults to "delta". write_mode: the mode to write the data into the table. If not present, the data set is read-only. Options are:["overwrite", "append", "upsert"]. @@ -225,6 +228,8 @@ def __init__( # noqa: PLR0913 Defaults to "spark". primary_key: the primary key of the table. Can be in the form of a list. Defaults to None. + version: kedro.io.core.Version instance to load the data. + Defaults to None. schema: the schema of the table in JSON form. Dataframes will be truncated to match the schema if provided. Used by the hooks to create the table if the schema is provided. @@ -237,9 +242,8 @@ def __init__( # noqa: PLR0913 metadata: Any arbitrary metadata. This is ignored by Kedro, but may be consumed by users or external plugins. Raises: - DatasetError: Invalid configuration supplied (through BaseTable validation). + DatasetError: Invalid configuration supplied (through ``BaseTable`` validation). """ - self._table = self._create_table( table=table, catalog=catalog, @@ -275,7 +279,7 @@ def _create_table( partition_columns: list[str] | None, owner_group: str | None ) -> BaseTable: - """Creates a table object and assign it to the _table attribute. + """Creates a ``BaseTable`` instance with the provided attributes. Args: table: The name of the table. @@ -290,7 +294,7 @@ def _create_table( owner_group: The owner group of the table. Returns: - BaseTable: the table object. + ``BaseTable``: The new ``BaseTable`` instance. 
""" raise NotImplementedError diff --git a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py index f7590e3d9..8b6f858b9 100644 --- a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py @@ -38,7 +38,7 @@ def _create_table( partition_columns: list[str] | None, owner_group: str | None ) -> ExternalTable: - """Creates a new ExternalTable instance with the provided attributes. + """Creates a new ``ExternalTable`` instance with the provided attributes. Args: table: The name of the table. @@ -53,7 +53,7 @@ def _create_table( owner_group: The owner group of the table. Returns: - ExternalTable: The new ExternalTable instance. + ``ExternalTable``: The new ``ExternalTable`` instance. """ return ExternalTable( table=table, diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 9daf88ada..49ad4f470 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -102,7 +102,7 @@ def __init__( # noqa: PLR0913 """Creates a new instance of ``ManagedTableDataset``. Args: - table: the name of the table + table: the name of the table. catalog: the name of the catalog in Unity. Defaults to None. database: the name of the database. @@ -120,7 +120,7 @@ def __init__( # noqa: PLR0913 Defaults to None. schema: the schema of the table in JSON form. Dataframes will be truncated to match the schema if provided. - Used by the hooks to create the table if the schema is provided + Used by the hooks to create the table if the schema is provided. Defaults to None. partition_columns: the columns to use for partitioning the table. Used by the hooks. Defaults to None. @@ -130,7 +130,7 @@ def __init__( # noqa: PLR0913 metadata: Any arbitrary metadata. This is ignored by Kedro, but may be consumed by users or external plugins. Raises: - DatasetError: Invalid configuration supplied (through ManagedTable validation) + DatasetError: Invalid configuration supplied (through ``ManagedTable`` validation). """ super().__init__( table=table, @@ -159,7 +159,7 @@ def _create_table( partition_columns: list[str] | None, owner_group: str | None ) -> ManagedTable: - """Creates a new ManagedTable instance with the provided attributes. + """Creates a new ``ManagedTable`` instance with the provided attributes. Args: table: The name of the table. @@ -174,7 +174,7 @@ def _create_table( owner_group: The owner group of the table. Returns: - ManagedTable: The new ManagedTable instance. + ``ManagedTable``: The new ``ManagedTable`` instance. 
""" return ManagedTable( table=table, From cbb85148eccbb9445801f493775985a54493539a Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sun, 8 Sep 2024 13:08:33 +0530 Subject: [PATCH 25/51] introduced a location attr for creating ext tables Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/_base_table_dataset.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index a8f34bfe0..644d8be9d 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -41,6 +41,7 @@ class BaseTable: catalog: str | None table: str write_mode: str | None + location: str | None dataframe_type: str primary_key: str | list[str] | None owner_group: str | None @@ -198,6 +199,7 @@ def __init__( # noqa: PLR0913 database: str = "default", format: str = "delta", write_mode: str | None = None, + location: str | None = None, dataframe_type: str = "spark", primary_key: str | list[str] | None = None, version: Version | None = None, @@ -224,6 +226,10 @@ def __init__( # noqa: PLR0913 Options are:["overwrite", "append", "upsert"]. "upsert" mode requires primary_key field to be populated. Defaults to None. + location: the location of the table. + Applicable only for external tables. + Should be a valid path in an external location that has already been created. + Defaults to None. dataframe_type: "pandas" or "spark" dataframe. Defaults to "spark". primary_key: the primary key of the table. @@ -250,6 +256,7 @@ def __init__( # noqa: PLR0913 database=database, format=format, write_mode=write_mode, + location=location, dataframe_type=dataframe_type, primary_key=primary_key, json_schema=schema, @@ -273,6 +280,7 @@ def _create_table( database: str, format: str, write_mode: str | None, + location: str | None, dataframe_type: str, primary_key: str | list[str] | None, json_schema: dict[str, Any] | None, From 4a2dc6e1740bad4aa36e239798cd16a7c4416646 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sun, 8 Sep 2024 14:15:05 +0530 Subject: [PATCH 26/51] updated the save funcs to incorporate the locations attr Signed-off-by: Minura Punchihewa --- .../databricks/_base_table_dataset.py | 52 ++++++++++++------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index 644d8be9d..d34d10fbe 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -16,6 +16,7 @@ DatasetError ) from pyspark.sql import DataFrame +from pyspark.sql.readwriter import DataFrameWriter from pyspark.sql.types import StructType from pyspark.sql.utils import AnalysisException, ParseException @@ -378,14 +379,11 @@ def _save_append(self, data: DataFrame) -> None: Args: data (DataFrame): the Spark dataframe to append to the table. 
""" - if self._table.partition_columns: - data.write.format(self._table.format).mode("append").partitionBy( - *self._table.partition_columns if isinstance(self._table.partition_columns, list) else self._table.partition_columns - ).saveAsTable(self._table.full_table_location() or "") - else: - data.write.format(self._table.format).mode("append").saveAsTable( - self._table.full_table_location() or "" - ) + writer = data.write.format(self._table.format).mode("append") + + writer = self._add_common_options_to_writer(writer) + + writer.saveAsTable(self._table.full_table_location() or "") def _save_overwrite(self, data: DataFrame) -> None: """Overwrites the data in the table with the data provided @@ -394,16 +392,13 @@ def _save_overwrite(self, data: DataFrame) -> None: Args: data (DataFrame): the Spark dataframe to overwrite the table with. """ - if self._table.partition_columns: - data.write.format(self._table.format).mode("overwrite").partitionBy( - *self._table.partition_columns if isinstance(self._table.partition_columns, list) else self._table.partition_columns - ).option( - "overwriteSchema", "true" - ).saveAsTable(self._table.full_table_location() or "") - else: - data.write.format(self._table.format).mode("overwrite").saveAsTable( - self._table.full_table_location() or "" - ) + writer = data.write.format(self._table.format).mode("overwrite").option( + "overwriteSchema", "true" + ) + + writer = self._add_common_options_to_writer(writer) + + writer.saveAsTable(self._table.full_table_location() or "") def _save_upsert(self, update_data: DataFrame) -> None: """Upserts the data by joining on primary_key columns or column. @@ -489,4 +484,23 @@ def _exists(self) -> bool: ) except (ParseException, AnalysisException) as exc: logger.warning("error occured while trying to find table: %s", exc) - return False \ No newline at end of file + return False + + def _add_common_options_to_writer(self, writer: DataFrameWriter) -> DataFrameWriter: + """Adds options to the writer based on the table properties. + + Args: + writer (DataFrameWriter): The DataFrameWriter instance. + + Returns: + DataFrameWriter: The DataFrameWriter instance with the options added. + """ + if self._table.partition_columns: + writer.partitionBy( + *self._table.partition_columns if isinstance(self._table.partition_columns, list) else self._table.partition_columns + ) + + if self._table.location: + writer.option("path", self._table.location) + + return writer \ No newline at end of file From 1734f4487ad59f3580cb938246a68abbc2c39bec Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sun, 8 Sep 2024 14:17:21 +0530 Subject: [PATCH 27/51] moved the func to check if table exists to BaseTable Signed-off-by: Minura Punchihewa --- .../databricks/_base_table_dataset.py | 50 +++++++++++-------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index d34d10fbe..a31496775 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -177,7 +177,34 @@ def schema(self) -> StructType | None: except (KeyError, ValueError) as exc: raise DatasetError(exc) from exc return schema - + + def exists(self) -> bool: + """Checks to see if the table exists. + + Returns: + bool: boolean of whether the table exists in the Spark session. 
+ """ + if self.catalog: + try: + _get_spark().sql(f"USE CATALOG `{self.catalog}`") + except (ParseException, AnalysisException) as exc: + logger.warning( + "catalog %s not found or unity not enabled. Error message: %s", + self.catalog, + exc, + ) + try: + return ( + _get_spark() + .sql(f"SHOW TABLES IN `{self.database}`") + .filter(f"tableName = '{self.table}'") + .count() + > 0 + ) + except (ParseException, AnalysisException) as exc: + logger.warning("error occured while trying to find table: %s", exc) + return False + class BaseTableDataset(AbstractVersionedDataset): """``BaseTableDataset`` loads and saves data into managed delta tables or external tables on Databricks. @@ -465,26 +492,7 @@ def _exists(self) -> bool: bool: boolean of whether the table defined in the dataset instance exists in the Spark session. """ - if self._table.catalog: - try: - _get_spark().sql(f"USE CATALOG `{self._table.catalog}`") - except (ParseException, AnalysisException) as exc: - logger.warning( - "catalog %s not found or unity not enabled. Error message: %s", - self._table.catalog, - exc, - ) - try: - return ( - _get_spark() - .sql(f"SHOW TABLES IN `{self._table.database}`") - .filter(f"tableName = '{self._table.table}'") - .count() - > 0 - ) - except (ParseException, AnalysisException) as exc: - logger.warning("error occured while trying to find table: %s", exc) - return False + return self._table.exists() def _add_common_options_to_writer(self, writer: DataFrameWriter) -> DataFrameWriter: """Adds options to the writer based on the table properties. From b03bee89ff976388a626acd8b2340d678cffa20a Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sun, 8 Sep 2024 14:26:48 +0530 Subject: [PATCH 28/51] added a val func to check if location is provided if table does not exist Signed-off-by: Minura Punchihewa --- .../databricks/_base_table_dataset.py | 10 +++++----- .../databricks/external_table_dataset.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index a31496775..7a850c7a4 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -11,9 +11,9 @@ import pandas as pd from kedro.io.core import ( AbstractVersionedDataset, + DatasetError, Version, - VersionNotFoundError, - DatasetError + VersionNotFoundError ) from pyspark.sql import DataFrame from pyspark.sql.readwriter import DataFrameWriter @@ -84,7 +84,7 @@ def _validate_table(self): DatasetError: If the table name does not conform to naming constraints. """ if not re.fullmatch(self._NAMING_REGEX, self.table): - raise DatasetError("table does not conform to naming") + raise DatasetError("Table does not conform to naming") def _validate_database(self): """Validates database name. @@ -93,7 +93,7 @@ def _validate_database(self): DatasetError: If the dataset name does not conform to naming constraints. """ if not re.fullmatch(self._NAMING_REGEX, self.database): - raise DatasetError("database does not conform to naming") + raise DatasetError("Database does not conform to naming") def _validate_catalog(self): """Validates catalog name. 
@@ -103,7 +103,7 @@ def _validate_catalog(self): """ if self.catalog: if not re.fullmatch(self._NAMING_REGEX, self.catalog): - raise DatasetError("catalog does not conform to naming") + raise DatasetError("Catalog does not conform to naming") def _validate_write_mode(self): """Validates the write mode. diff --git a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py index 8b6f858b9..453d538d9 100644 --- a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py @@ -8,6 +8,10 @@ from typing import Any import pandas as pd +import pandas as pd +from kedro.io.core import ( + DatasetError +) from kedro_datasets.databricks._base_table_dataset import BaseTable, BaseTableDataset @@ -19,6 +23,18 @@ class ExternalTable(BaseTable): """Stores the definition of an external table.""" + def _validate_existence_of_table(self) -> None: + """Validates that a location is provided if the table does not exist. + + Raises: + DatasetError: If the table does not exist and no location is provided. + """ + if not self.exists() and not self.location: + raise DatasetError( + "If the external table does not exists, the `location` parameter must be provided. " + "This should be valid path in an external location that has already been created." + ) + class ExternalTableDataset(BaseTableDataset): """``ExternalTableDataset`` loads and saves data into external tables in Databricks. From dc3550e74659228161956e161e84aaf54a476962 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sun, 8 Sep 2024 14:31:31 +0530 Subject: [PATCH 29/51] moved the val func for checking if write_mode supported to ExternalTable Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/_base_table_dataset.py | 7 ------- .../databricks/external_table_dataset.py | 12 ++++++++++++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index 7a850c7a4..6188ff69e 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -121,13 +121,6 @@ def _validate_write_mode(self): f"`write_mode` must be one of: {valid_modes}" ) - # Upserts are only supported for delta tables. - if self.write_mode == "upsert" and self.format != "delta": - raise DatasetError( - f"Format '{self.format}' is not supported for upserts. " - f"Please use 'delta' format." - ) - def _validate_dataframe_type(self): """Validates the dataframe type. diff --git a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py index 453d538d9..94e285dd0 100644 --- a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py @@ -34,6 +34,18 @@ def _validate_existence_of_table(self) -> None: "If the external table does not exists, the `location` parameter must be provided. " "This should be valid path in an external location that has already been created." ) + + def _validate_write_mode_for_format(self) -> None: + """Validates that the write mode is compatible with the format. + + Raises: + DatasetError: If the write mode is not compatible with the format. 
+ """ + if self.write_mode == "upsert" and self.format != "delta": + raise DatasetError( + f"Format '{self.format}' is not supported for upserts. " + f"Please use 'delta' format." + ) class ExternalTableDataset(BaseTableDataset): From 78eee4d931efe11463c205b1e510429a4fdce09e Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sat, 21 Sep 2024 23:40:08 +0530 Subject: [PATCH 30/51] removed the func for adding options to writer for better readability Signed-off-by: Minura Punchihewa --- .../databricks/_base_table_dataset.py | 35 ++++++++----------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index 6188ff69e..ad27d70a3 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -401,7 +401,13 @@ def _save_append(self, data: DataFrame) -> None: """ writer = data.write.format(self._table.format).mode("append") - writer = self._add_common_options_to_writer(writer) + if self._table.partition_columns: + writer.partitionBy( + *self._table.partition_columns if isinstance(self._table.partition_columns, list) else self._table.partition_columns + ) + + if self._table.location: + writer.option("path", self._table.location) writer.saveAsTable(self._table.full_table_location() or "") @@ -416,7 +422,13 @@ def _save_overwrite(self, data: DataFrame) -> None: "overwriteSchema", "true" ) - writer = self._add_common_options_to_writer(writer) + if self._table.partition_columns: + writer.partitionBy( + *self._table.partition_columns if isinstance(self._table.partition_columns, list) else self._table.partition_columns + ) + + if self._table.location: + writer.option("path", self._table.location) writer.saveAsTable(self._table.full_table_location() or "") @@ -486,22 +498,3 @@ def _exists(self) -> bool: in the dataset instance exists in the Spark session. """ return self._table.exists() - - def _add_common_options_to_writer(self, writer: DataFrameWriter) -> DataFrameWriter: - """Adds options to the writer based on the table properties. - - Args: - writer (DataFrameWriter): The DataFrameWriter instance. - - Returns: - DataFrameWriter: The DataFrameWriter instance with the options added. - """ - if self._table.partition_columns: - writer.partitionBy( - *self._table.partition_columns if isinstance(self._table.partition_columns, list) else self._table.partition_columns - ) - - if self._table.location: - writer.option("path", self._table.location) - - return writer \ No newline at end of file From 09a8cb260a1a203e70e2abb13790d97c59a148cf Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sat, 21 Sep 2024 23:50:20 +0530 Subject: [PATCH 31/51] added a validation check for overwrites on ext tables Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/external_table_dataset.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py index 94e285dd0..750a36909 100644 --- a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py @@ -46,6 +46,12 @@ def _validate_write_mode_for_format(self) -> None: f"Format '{self.format}' is not supported for upserts. " f"Please use 'delta' format." 
) + + if self.write_mode == "overwrite" and self.format != "delta" and not self.location: + raise DatasetError( + f"Format '{self.format}' is supported for overwrites only if the location is provided. " + f"Please provide a valid path in an external location." + ) class ExternalTableDataset(BaseTableDataset): From 7e493c26d4d76b3a982250da8a2917e934c0891f Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sat, 21 Sep 2024 23:53:23 +0530 Subject: [PATCH 32/51] implemented the _save_overwrite() func for ext tables Signed-off-by: Minura Punchihewa --- .../databricks/external_table_dataset.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py index 750a36909..812fd4af3 100644 --- a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py @@ -12,6 +12,7 @@ from kedro.io.core import ( DatasetError ) +from pyspark.sql import DataFrame from kedro_datasets.databricks._base_table_dataset import BaseTable, BaseTableDataset @@ -100,4 +101,25 @@ def _create_table( owner_group=owner_group, primary_key=primary_key, format=format - ) \ No newline at end of file + ) + + def _save_overwrite(self, data: DataFrame) -> None: + """Overwrites the data in the table with the data provided + (this is the default save mode). + + Args: + data (DataFrame): the Spark dataframe to overwrite the table with. + """ + writer = data.write.format(self._table.format).mode("overwrite").option( + "overwriteSchema", "true" + ) + + if self._table.partition_columns: + writer.partitionBy( + *self._table.partition_columns if isinstance(self._table.partition_columns, list) else self._table.partition_columns + ) + + if self._table.location: + writer.option("path", self._table.location) + + writer.save(self._table.location) \ No newline at end of file From 6343226f905eb930ad4afc1ccc9fcbeeb1a65e80 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sat, 21 Sep 2024 23:54:25 +0530 Subject: [PATCH 33/51] removed mentions of a default write mode Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/_base_table_dataset.py | 3 +-- .../kedro_datasets/databricks/external_table_dataset.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index ad27d70a3..c497d73e7 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -412,8 +412,7 @@ def _save_append(self, data: DataFrame) -> None: writer.saveAsTable(self._table.full_table_location() or "") def _save_overwrite(self, data: DataFrame) -> None: - """Overwrites the data in the table with the data provided - (this is the default save mode). + """Overwrites the data in the table with the data provided. Args: data (DataFrame): the Spark dataframe to overwrite the table with. 
diff --git a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py index 812fd4af3..ed99effa1 100644 --- a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py @@ -104,8 +104,7 @@ def _create_table( ) def _save_overwrite(self, data: DataFrame) -> None: - """Overwrites the data in the table with the data provided - (this is the default save mode). + """Overwrites the data in the table with the data provided. Args: data (DataFrame): the Spark dataframe to overwrite the table with. From ccc60c465d4480164eda9d18073243912bcdffa3 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sun, 22 Sep 2024 00:24:49 +0530 Subject: [PATCH 34/51] improved the docstrings Signed-off-by: Minura Punchihewa --- .../databricks/_base_table_dataset.py | 36 ++++++------- .../databricks/external_table_dataset.py | 54 ++++++++++++++++++- .../databricks/managed_table_dataset.py | 16 +++--- 3 files changed, 79 insertions(+), 27 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index c497d73e7..14e906128 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -148,7 +148,7 @@ def full_table_location(self) -> str | None: """Returns the full table location. Returns: - str | None : table location in the format catalog.database.table or None if database and table aren't defined. + str | None : Table location in the format catalog.database.table or None if database and table aren't defined. """ full_table_location = None if self.catalog and self.database and self.table: @@ -161,7 +161,7 @@ def schema(self) -> StructType | None: """Returns the Spark schema of the table if it exists. Returns: - StructType: the schema of the table. + StructType: The schema of the table. """ schema = None try: @@ -175,7 +175,7 @@ def exists(self) -> bool: """Checks to see if the table exists. Returns: - bool: boolean of whether the table exists in the Spark session. + bool: Boolean of whether the table exists in the Spark session. """ if self.catalog: try: @@ -234,36 +234,36 @@ def __init__( # noqa: PLR0913 """Creates a new instance of ``BaseTableDataset``. Args: - table: the name of the table. - catalog: the name of the catalog in Unity. + table: The name of the table. + catalog: The name of the catalog in Unity. Defaults to None. - database: the name of the database. + database: The name of the database. (also referred to as schema). Defaults to "default". - format: the format of the table. + format: The format of the table. Applicable only for external tables. Defaults to "delta". - write_mode: the mode to write the data into the table. If not + write_mode: The mode to write the data into the table. If not present, the data set is read-only. Options are:["overwrite", "append", "upsert"]. "upsert" mode requires primary_key field to be populated. Defaults to None. - location: the location of the table. + location: The location of the table. Applicable only for external tables. Should be a valid path in an external location that has already been created. Defaults to None. dataframe_type: "pandas" or "spark" dataframe. Defaults to "spark". - primary_key: the primary key of the table. + primary_key: The primary key of the table. Can be in the form of a list. Defaults to None. 
version: kedro.io.core.Version instance to load the data. Defaults to None. - schema: the schema of the table in JSON form. + schema: The schema of the table in JSON form. Dataframes will be truncated to match the schema if provided. Used by the hooks to create the table if the schema is provided. Defaults to None. - partition_columns: the columns to use for partitioning the table. + partition_columns: The columns to use for partitioning the table. Used by the hooks. Defaults to None. - owner_group: if table access control is enabled in your workspace, + owner_group: If table access control is enabled in your workspace, specifying owner_group will transfer ownership of the table and database to this owner. All databases should have the same owner_group. Defaults to None. metadata: Any arbitrary metadata. @@ -332,7 +332,7 @@ def _load(self) -> DataFrame | pd.DataFrame: (spark|pandas dataframe). Raises: - VersionNotFoundError: if the version defined in + VersionNotFoundError: If the version defined in the init doesn't exist. Returns: @@ -397,7 +397,7 @@ def _save_append(self, data: DataFrame) -> None: to the location defined in the init. Args: - data (DataFrame): the Spark dataframe to append to the table. + data (DataFrame): The Spark dataframe to append to the table. """ writer = data.write.format(self._table.format).mode("append") @@ -415,7 +415,7 @@ def _save_overwrite(self, data: DataFrame) -> None: """Overwrites the data in the table with the data provided. Args: - data (DataFrame): the Spark dataframe to overwrite the table with. + data (DataFrame): The Spark dataframe to overwrite the table with. """ writer = data.write.format(self._table.format).mode("overwrite").option( "overwriteSchema", "true" @@ -436,7 +436,7 @@ def _save_upsert(self, update_data: DataFrame) -> None: If table doesn't exist at save, the data is inserted to a new table. Args: - update_data (DataFrame): the Spark dataframe to upsert. + update_data (DataFrame): The Spark dataframe to upsert. """ if self._exists(): base_data = _get_spark().table(self._table.full_table_location()) @@ -493,7 +493,7 @@ def _exists(self) -> bool: """Checks to see if the table exists. Returns: - bool: boolean of whether the table defined + bool: Boolean of whether the table defined in the dataset instance exists in the Spark session. """ return self._table.exists() diff --git a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py index ed99effa1..431f24f90 100644 --- a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py @@ -58,6 +58,58 @@ def _validate_write_mode_for_format(self) -> None: class ExternalTableDataset(BaseTableDataset): """``ExternalTableDataset`` loads and saves data into external tables in Databricks. Load and save can be in Spark or Pandas dataframes, specified in dataframe_type. + + Example usage for the + `YAML API `_: + + .. code-block:: yaml + + names_and_ages@spark: + type: databricks.ExternalTableDataset + format: parquet + table: names_and_ages + + names_and_ages@pandas: + type: databricks.ExternalTableDataset + format: parquet + table: names_and_ages + dataframe_type: pandas + + Example usage for the + `Python API `_: + + .. 
code-block:: pycon + >>> from kedro_datasets.databricks import ExternalTableDataset + >>> from pyspark.sql import SparkSession + >>> from pyspark.sql.types import IntegerType, Row, StringType, StructField, StructType + >>> import importlib_metadata + >>> + >>> DELTA_VERSION = importlib_metadata.version("delta-spark") + >>> schema = StructType( + ... [StructField("name", StringType(), True), StructField("age", IntegerType(), True)] + ... ) + >>> data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] + >>> spark_df = ( + ... SparkSession.builder.config( + ... "spark.jars.packages", f"io.delta:delta-core_2.12:{DELTA_VERSION}" + ... ) + ... .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") + ... .config( + ... "spark.sql.catalog.spark_catalog", + ... "org.apache.spark.sql.delta.catalog.DeltaCatalog", + ... ) + ... .getOrCreate() + ... .createDataFrame(data, schema) + ... ) + >>> dataset = ExternalTableDataset( + ... table="names_and_ages", + ... write_mode="overwrite", + ... location="abfss://container@storageaccount.dfs.core.windows.net/depts/cust" + ... ) + >>> dataset.save(spark_df) + >>> reloaded = dataset.load() + >>> assert Row(name="Bob", age=12) in reloaded.take(4) """ def _create_table( @@ -107,7 +159,7 @@ def _save_overwrite(self, data: DataFrame) -> None: """Overwrites the data in the table with the data provided. Args: - data (DataFrame): the Spark dataframe to overwrite the table with. + data (DataFrame): The Spark dataframe to overwrite the table with. """ writer = data.write.format(self._table.format).mode("overwrite").option( "overwriteSchema", "true" diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 49ad4f470..2f4362e58 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -102,29 +102,29 @@ def __init__( # noqa: PLR0913 """Creates a new instance of ``ManagedTableDataset``. Args: - table: the name of the table. - catalog: the name of the catalog in Unity. + table: The name of the table. + catalog: The name of the catalog in Unity. Defaults to None. - database: the name of the database. + database: The name of the database. (also referred to as schema). Defaults to "default". - write_mode: the mode to write the data into the table. If not + write_mode: The mode to write the data into the table. If not present, the data set is read-only. Options are:["overwrite", "append", "upsert"]. "upsert" mode requires primary_key field to be populated. Defaults to None. dataframe_type: "pandas" or "spark" dataframe. Defaults to "spark". - primary_key: the primary key of the table. + primary_key: The primary key of the table. Can be in the form of a list. Defaults to None. version: kedro.io.core.Version instance to load the data. Defaults to None. - schema: the schema of the table in JSON form. + schema: The schema of the table in JSON form. Dataframes will be truncated to match the schema if provided. Used by the hooks to create the table if the schema is provided. Defaults to None. - partition_columns: the columns to use for partitioning the table. + partition_columns: The columns to use for partitioning the table. Used by the hooks. Defaults to None. 
- owner_group: if table access control is enabled in your workspace, + owner_group: If table access control is enabled in your workspace, specifying owner_group will transfer ownership of the table and database to this owner. All databases should have the same owner_group. Defaults to None. metadata: Any arbitrary metadata. From 99202eeb28c03215626af770687c3526e6211004 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sun, 22 Sep 2024 00:37:28 +0530 Subject: [PATCH 35/51] fixed lint issues Signed-off-by: Minura Punchihewa --- kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py | 2 +- .../kedro_datasets/databricks/external_table_dataset.py | 2 +- .../kedro_datasets/databricks/managed_table_dataset.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index 14e906128..072b23d6e 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -294,7 +294,7 @@ def __init__( # noqa: PLR0913 exists_function=self._exists, # type: ignore[arg-type] ) - def _create_table( + def _create_table( # noqa: PLR0913 self, table: str, catalog: str | None, diff --git a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py index 431f24f90..7418f4bd5 100644 --- a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py @@ -112,7 +112,7 @@ class ExternalTableDataset(BaseTableDataset): >>> assert Row(name="Bob", age=12) in reloaded.take(4) """ - def _create_table( + def _create_table( # noqa: PLR0913 self, table: str, catalog: str | None, diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 2f4362e58..b87720675 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -146,7 +146,7 @@ def __init__( # noqa: PLR0913 owner_group=owner_group, ) - def _create_table( + def _create_table( # noqa: PLR0913 self, table: str, catalog: str | None, From 31f385a10007ccb673cfa678547d58f2bd2aee68 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Mon, 23 Sep 2024 00:03:57 +0530 Subject: [PATCH 36/51] fixed a couple of bugs in the Table classes Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/_base_table_dataset.py | 2 +- .../kedro_datasets/databricks/external_table_dataset.py | 4 +++- .../kedro_datasets/databricks/managed_table_dataset.py | 4 +++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index 072b23d6e..b70f83ba9 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -26,7 +26,7 @@ pd.DataFrame.iteritems = pd.DataFrame.items -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True) class BaseTable: """Stores the definition of a base table. 
diff --git a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py index 7418f4bd5..ca563e3de 100644 --- a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py @@ -20,7 +20,7 @@ pd.DataFrame.iteritems = pd.DataFrame.items -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True) class ExternalTable(BaseTable): """Stores the definition of an external table.""" @@ -119,6 +119,7 @@ def _create_table( # noqa: PLR0913 database: str, format: str, write_mode: str | None, + location: str | None, dataframe_type: str, primary_key: str | list[str] | None, json_schema: dict[str, Any] | None, @@ -147,6 +148,7 @@ def _create_table( # noqa: PLR0913 catalog=catalog, database=database, write_mode=write_mode, + location=location, dataframe_type=dataframe_type, json_schema=json_schema, partition_columns=partition_columns, diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index b87720675..3df07bc99 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -18,7 +18,7 @@ pd.DataFrame.iteritems = pd.DataFrame.items -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True) class ManagedTable(BaseTable): """Stores the definition of a managed table.""" @@ -153,6 +153,7 @@ def _create_table( # noqa: PLR0913 database: str, format: str, write_mode: str | None, + location: str | None, dataframe_type: str, primary_key: str | list[str] | None, json_schema: dict[str, Any] | None, @@ -181,6 +182,7 @@ def _create_table( # noqa: PLR0913 catalog=catalog, database=database, write_mode=write_mode, + location=location, dataframe_type=dataframe_type, json_schema=json_schema, partition_columns=partition_columns, From 5016dee89b5607b3151a12895f4257e65bc8f21e Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Mon, 23 Sep 2024 00:12:32 +0530 Subject: [PATCH 37/51] updated the _save_overwrite() logic for ext tables Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/external_table_dataset.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py index ca563e3de..bc6403ead 100644 --- a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py @@ -175,4 +175,7 @@ def _save_overwrite(self, data: DataFrame) -> None: if self._table.location: writer.option("path", self._table.location) - writer.save(self._table.location) \ No newline at end of file + if self._table.format == "delta": + writer.saveAsTable(self._table.full_table_location() or "") + else: + writer.save(self._table.location) \ No newline at end of file From 77afc501cde865fe582690061a80347ececb211c Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Tue, 24 Sep 2024 11:14:42 +0530 Subject: [PATCH 38/51] renamed the val funcs of the ext tables Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/external_table_dataset.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py index 
bc6403ead..fc83dcc80 100644 --- a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py @@ -24,7 +24,7 @@ class ExternalTable(BaseTable): """Stores the definition of an external table.""" - def _validate_existence_of_table(self) -> None: + def _validate_location(self) -> None: """Validates that a location is provided if the table does not exist. Raises: @@ -36,12 +36,14 @@ def _validate_existence_of_table(self) -> None: "This should be valid path in an external location that has already been created." ) - def _validate_write_mode_for_format(self) -> None: + def _validate_write_mode(self) -> None: """Validates that the write mode is compatible with the format. Raises: DatasetError: If the write mode is not compatible with the format. """ + super()._validate_write_mode() + if self.write_mode == "upsert" and self.format != "delta": raise DatasetError( f"Format '{self.format}' is not supported for upserts. " From e7ba0e36f5c19746869c253c43ef44938f988b77 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Tue, 24 Sep 2024 11:28:34 +0530 Subject: [PATCH 39/51] updated _save_overwrite() of ext tables to handle no existing tables Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/external_table_dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py index fc83dcc80..ca963282c 100644 --- a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py @@ -174,10 +174,11 @@ def _save_overwrite(self, data: DataFrame) -> None: *self._table.partition_columns if isinstance(self._table.partition_columns, list) else self._table.partition_columns ) - if self._table.location: - writer.option("path", self._table.location) + if self._table.format == "delta" or (not self._table.exists()): + if self._table.location: + writer.option("path", self._table.location) - if self._table.format == "delta": writer.saveAsTable(self._table.full_table_location() or "") + else: writer.save(self._table.location) \ No newline at end of file From d6ba42611d6d8ba0bbde75e3625ff4ec1c824c3d Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Wed, 25 Sep 2024 23:59:47 +0530 Subject: [PATCH 40/51] fixed bug in supporting string partition cols Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/_base_table_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index b70f83ba9..ffd811de5 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -403,7 +403,7 @@ def _save_append(self, data: DataFrame) -> None: if self._table.partition_columns: writer.partitionBy( - *self._table.partition_columns if isinstance(self._table.partition_columns, list) else self._table.partition_columns + *self._table.partition_columns if isinstance(self._table.partition_columns, list) else [self._table.partition_columns] ) if self._table.location: @@ -423,7 +423,7 @@ def _save_overwrite(self, data: DataFrame) -> None: if self._table.partition_columns: writer.partitionBy( - *self._table.partition_columns if isinstance(self._table.partition_columns, list) else 
self._table.partition_columns + *self._table.partition_columns if isinstance(self._table.partition_columns, list) else [self._table.partition_columns] ) if self._table.location: From 73dc1ebf032c2da0ecf77ae7196c8825468c0f9f Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Tue, 1 Oct 2024 20:21:31 +0530 Subject: [PATCH 41/51] removed the external table dataset Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/__init__.py | 6 +- .../databricks/external_table_dataset.py | 184 ------------------ 2 files changed, 1 insertion(+), 189 deletions(-) delete mode 100644 kedro-datasets/kedro_datasets/databricks/external_table_dataset.py diff --git a/kedro-datasets/kedro_datasets/databricks/__init__.py b/kedro-datasets/kedro_datasets/databricks/__init__.py index 5bd62be7c..7f7ad7235 100644 --- a/kedro-datasets/kedro_datasets/databricks/__init__.py +++ b/kedro-datasets/kedro_datasets/databricks/__init__.py @@ -6,12 +6,8 @@ # https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901 ManagedTableDataset: Any -ExternalTableDataset: Any __getattr__, __dir__, __all__ = lazy.attach( __name__, - submod_attrs={ - "managed_table_dataset": ["ManagedTableDataset"], - "external_table_dataset": ["ExternalTableDataset"], - }, + submod_attrs={"managed_table_dataset": ["ManagedTableDataset"]}, ) diff --git a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py deleted file mode 100644 index ca963282c..000000000 --- a/kedro-datasets/kedro_datasets/databricks/external_table_dataset.py +++ /dev/null @@ -1,184 +0,0 @@ -"""``ExternalTableDataset`` implementation to access external tables -in Databricks. -""" -from __future__ import annotations - -import logging -from dataclasses import dataclass -from typing import Any - -import pandas as pd -import pandas as pd -from kedro.io.core import ( - DatasetError -) -from pyspark.sql import DataFrame - -from kedro_datasets.databricks._base_table_dataset import BaseTable, BaseTableDataset - -logger = logging.getLogger(__name__) -pd.DataFrame.iteritems = pd.DataFrame.items - - -@dataclass(frozen=True) -class ExternalTable(BaseTable): - """Stores the definition of an external table.""" - - def _validate_location(self) -> None: - """Validates that a location is provided if the table does not exist. - - Raises: - DatasetError: If the table does not exist and no location is provided. - """ - if not self.exists() and not self.location: - raise DatasetError( - "If the external table does not exists, the `location` parameter must be provided. " - "This should be valid path in an external location that has already been created." - ) - - def _validate_write_mode(self) -> None: - """Validates that the write mode is compatible with the format. - - Raises: - DatasetError: If the write mode is not compatible with the format. - """ - super()._validate_write_mode() - - if self.write_mode == "upsert" and self.format != "delta": - raise DatasetError( - f"Format '{self.format}' is not supported for upserts. " - f"Please use 'delta' format." - ) - - if self.write_mode == "overwrite" and self.format != "delta" and not self.location: - raise DatasetError( - f"Format '{self.format}' is supported for overwrites only if the location is provided. " - f"Please provide a valid path in an external location." - ) - - -class ExternalTableDataset(BaseTableDataset): - """``ExternalTableDataset`` loads and saves data into external tables in Databricks. 
- Load and save can be in Spark or Pandas dataframes, specified in dataframe_type. - - Example usage for the - `YAML API `_: - - .. code-block:: yaml - - names_and_ages@spark: - type: databricks.ExternalTableDataset - format: parquet - table: names_and_ages - - names_and_ages@pandas: - type: databricks.ExternalTableDataset - format: parquet - table: names_and_ages - dataframe_type: pandas - - Example usage for the - `Python API `_: - - .. code-block:: pycon - >>> from kedro_datasets.databricks import ExternalTableDataset - >>> from pyspark.sql import SparkSession - >>> from pyspark.sql.types import IntegerType, Row, StringType, StructField, StructType - >>> import importlib_metadata - >>> - >>> DELTA_VERSION = importlib_metadata.version("delta-spark") - >>> schema = StructType( - ... [StructField("name", StringType(), True), StructField("age", IntegerType(), True)] - ... ) - >>> data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] - >>> spark_df = ( - ... SparkSession.builder.config( - ... "spark.jars.packages", f"io.delta:delta-core_2.12:{DELTA_VERSION}" - ... ) - ... .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") - ... .config( - ... "spark.sql.catalog.spark_catalog", - ... "org.apache.spark.sql.delta.catalog.DeltaCatalog", - ... ) - ... .getOrCreate() - ... .createDataFrame(data, schema) - ... ) - >>> dataset = ExternalTableDataset( - ... table="names_and_ages", - ... write_mode="overwrite", - ... location="abfss://container@storageaccount.dfs.core.windows.net/depts/cust" - ... ) - >>> dataset.save(spark_df) - >>> reloaded = dataset.load() - >>> assert Row(name="Bob", age=12) in reloaded.take(4) - """ - - def _create_table( # noqa: PLR0913 - self, - table: str, - catalog: str | None, - database: str, - format: str, - write_mode: str | None, - location: str | None, - dataframe_type: str, - primary_key: str | list[str] | None, - json_schema: dict[str, Any] | None, - partition_columns: list[str] | None, - owner_group: str | None - ) -> ExternalTable: - """Creates a new ``ExternalTable`` instance with the provided attributes. - - Args: - table: The name of the table. - catalog: The catalog of the table. - database: The database of the table. - format: The format of the table. - write_mode: The write mode for the table. - dataframe_type: The type of dataframe. - primary_key: The primary key of the table. - json_schema: The JSON schema of the table. - partition_columns: The partition columns of the table. - owner_group: The owner group of the table. - - Returns: - ``ExternalTable``: The new ``ExternalTable`` instance. - """ - return ExternalTable( - table=table, - catalog=catalog, - database=database, - write_mode=write_mode, - location=location, - dataframe_type=dataframe_type, - json_schema=json_schema, - partition_columns=partition_columns, - owner_group=owner_group, - primary_key=primary_key, - format=format - ) - - def _save_overwrite(self, data: DataFrame) -> None: - """Overwrites the data in the table with the data provided. - - Args: - data (DataFrame): The Spark dataframe to overwrite the table with. 
- """ - writer = data.write.format(self._table.format).mode("overwrite").option( - "overwriteSchema", "true" - ) - - if self._table.partition_columns: - writer.partitionBy( - *self._table.partition_columns if isinstance(self._table.partition_columns, list) else self._table.partition_columns - ) - - if self._table.format == "delta" or (not self._table.exists()): - if self._table.location: - writer.option("path", self._table.location) - - writer.saveAsTable(self._table.full_table_location() or "") - - else: - writer.save(self._table.location) \ No newline at end of file From fef688f3c8d75bf831c1d2ac91b763b2d3bd7231 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Tue, 1 Oct 2024 20:36:42 +0530 Subject: [PATCH 42/51] removed irrelevant attrs from describe() for managed tables Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/_base_table_dataset.py | 1 + .../databricks/managed_table_dataset.py | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index ffd811de5..700e1cf53 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -481,6 +481,7 @@ def _describe(self) -> dict[str, str | list | None]: "database": self._table.database, "table": self._table.table, "format": self._table.format, + "location": self._table.location, "write_mode": self._table.write_mode, "dataframe_type": self._table.dataframe_type, "primary_key": self._table.primary_key, diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index 3df07bc99..c39f2c156 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -190,4 +190,16 @@ def _create_table( # noqa: PLR0913 primary_key=primary_key, format=format ) + + def _describe(self) -> dict[str, str | list | None]: + """Returns a description of the instance of the dataset. + + Returns: + Dict[str, str]: Dict with the details of the dataset. + """ + description = super()._describe() + del description["format"] + del description["location"] + + return description From 8ee9ea82c98351ae0b6281e53fed2c85ef127e22 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Tue, 1 Oct 2024 21:50:41 +0530 Subject: [PATCH 43/51] preserved order of args Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/managed_table_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index c39f2c156..fd4a3964e 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -133,9 +133,9 @@ def __init__( # noqa: PLR0913 DatasetError: Invalid configuration supplied (through ``ManagedTable`` validation). 
""" super().__init__( - table=table, - catalog=catalog, database=database, + catalog=catalog, + table=table, write_mode=write_mode, dataframe_type=dataframe_type, version=version, From 4cd7c0f3d3ce7bcec754efb76b4b011f9a19f309 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 4 Oct 2024 19:42:52 +0530 Subject: [PATCH 44/51] updated the tests for base table dataset and managed table dataset Signed-off-by: Minura Punchihewa --- .../databricks/test_base_table_dataset.py | 493 ++++++++++++++++++ .../databricks/test_managed_table_dataset.py | 305 ----------- 2 files changed, 493 insertions(+), 305 deletions(-) create mode 100644 kedro-datasets/tests/databricks/test_base_table_dataset.py diff --git a/kedro-datasets/tests/databricks/test_base_table_dataset.py b/kedro-datasets/tests/databricks/test_base_table_dataset.py new file mode 100644 index 000000000..6f0182474 --- /dev/null +++ b/kedro-datasets/tests/databricks/test_base_table_dataset.py @@ -0,0 +1,493 @@ +import pandas as pd +import pytest +from kedro.io.core import DatasetError, Version, VersionNotFoundError +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.types import IntegerType, StringType, StructField, StructType + +from kedro_datasets.databricks._base_table_dataset import BaseTableDataset + + +@pytest.fixture +def sample_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def upsert_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Alex", 32), ("Evan", 23)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def mismatched_upsert_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", IntegerType(), True), + ] + ) + + data = [("Alex", 32, 174), ("Evan", 23, 166)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def subset_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", IntegerType(), True), + ] + ) + + data = [("Alex", 32, 174), ("Evan", 23, 166)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def subset_pandas_df(): + return pd.DataFrame( + {"name": ["Alex", "Evan"], "age": [32, 23], "height": [174, 166]} + ) + + +@pytest.fixture +def subset_expected_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Alex", 32), ("Evan", 23)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def sample_pandas_df(): + return pd.DataFrame( + {"name": ["Alex", "Bob", "Clarke", "Dave"], "age": [31, 12, 65, 29]} + ) + + +@pytest.fixture +def append_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Evan", 23), ("Frank", 13)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def expected_append_spark_df(spark_session: SparkSession): + schema = 
StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [ + ("Alex", 31), + ("Bob", 12), + ("Clarke", 65), + ("Dave", 29), + ("Evan", 23), + ("Frank", 13), + ] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def expected_upsert_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [ + ("Alex", 32), + ("Bob", 12), + ("Clarke", 65), + ("Dave", 29), + ("Evan", 23), + ] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [ + ("Alex", 31), + ("Alex", 32), + ("Bob", 12), + ("Clarke", 65), + ("Dave", 29), + ("Evan", 23), + ] + + return spark_session.createDataFrame(data, schema) + + +class TestBaseTableDataset: + def test_full_table(self): + unity_ds = BaseTableDataset(catalog="test", database="test", table="test") + assert unity_ds._table.full_table_location() == "`test`.`test`.`test`" + + unity_ds = BaseTableDataset( + catalog="test-test", database="test", table="test" + ) + assert unity_ds._table.full_table_location() == "`test-test`.`test`.`test`" + + unity_ds = BaseTableDataset(database="test", table="test") + assert unity_ds._table.full_table_location() == "`test`.`test`" + + unity_ds = BaseTableDataset(table="test") + assert unity_ds._table.full_table_location() == "`default`.`test`" + + with pytest.raises(TypeError): + BaseTableDataset() + + def test_describe(self): + unity_ds = BaseTableDataset(table="test") + assert unity_ds._describe() == { + "catalog": None, + "database": "default", + "table": "test", + "format": "delta", + "location": None, + "write_mode": None, + "dataframe_type": "spark", + "primary_key": None, + "version": "None", + "owner_group": None, + "partition_columns": None, + } + + def test_invalid_write_mode(self): + with pytest.raises(DatasetError): + BaseTableDataset(table="test", write_mode="invalid") + + def test_dataframe_type(self): + with pytest.raises(DatasetError): + BaseTableDataset(table="test", dataframe_type="invalid") + + def test_missing_primary_key_upsert(self): + with pytest.raises(DatasetError): + BaseTableDataset(table="test", write_mode="upsert") + + def test_invalid_table_name(self): + with pytest.raises(DatasetError): + BaseTableDataset(table="invalid!") + + def test_invalid_database(self): + with pytest.raises(DatasetError): + BaseTableDataset(table="test", database="invalid!") + + def test_invalid_catalog(self): + with pytest.raises(DatasetError): + BaseTableDataset(table="test", catalog="invalid!") + + def test_schema(self): + unity_ds = BaseTableDataset( + table="test", + schema={ + "fields": [ + { + "metadata": {}, + "name": "name", + "nullable": True, + "type": "string", + }, + { + "metadata": {}, + "name": "age", + "nullable": True, + "type": "integer", + }, + ], + "type": "struct", + }, + ) + expected_schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + assert unity_ds._table.schema() == expected_schema + + def test_invalid_schema(self): + with pytest.raises(DatasetError): + BaseTableDataset( + table="test", + schema={ + "fields": [ + { + "invalid": "schema", + } + ], + "type": "struct", + }, + )._table.schema() + + def 
test_catalog_exists(self): + unity_ds = BaseTableDataset( + catalog="test", database="invalid", table="test_not_there" + ) + assert not unity_ds._exists() + + def test_table_does_not_exist(self): + unity_ds = BaseTableDataset(database="invalid", table="test_not_there") + assert not unity_ds._exists() + + def test_save_default(self, sample_spark_df: DataFrame): + unity_ds = BaseTableDataset(database="test", table="test_save") + with pytest.raises(DatasetError): + unity_ds.save(sample_spark_df) + + def test_save_schema_spark( + self, subset_spark_df: DataFrame, subset_expected_df: DataFrame + ): + unity_ds = BaseTableDataset( + database="test", + table="test_save_spark_schema", + schema={ + "fields": [ + { + "metadata": {}, + "name": "name", + "nullable": True, + "type": "string", + }, + { + "metadata": {}, + "name": "age", + "nullable": True, + "type": "integer", + }, + ], + "type": "struct", + }, + write_mode="overwrite", + ) + unity_ds.save(subset_spark_df) + saved_table = unity_ds.load() + assert subset_expected_df.exceptAll(saved_table).count() == 0 + + def test_save_schema_pandas( + self, subset_pandas_df: pd.DataFrame, subset_expected_df: DataFrame + ): + unity_ds = BaseTableDataset( + database="test", + table="test_save_pd_schema", + schema={ + "fields": [ + { + "metadata": {}, + "name": "name", + "nullable": True, + "type": "string", + }, + { + "metadata": {}, + "name": "age", + "nullable": True, + "type": "integer", + }, + ], + "type": "struct", + }, + write_mode="overwrite", + dataframe_type="pandas", + ) + unity_ds.save(subset_pandas_df) + saved_ds = BaseTableDataset( + database="test", + table="test_save_pd_schema", + ) + saved_table = saved_ds.load() + assert subset_expected_df.exceptAll(saved_table).count() == 0 + + def test_save_overwrite( + self, sample_spark_df: DataFrame, append_spark_df: DataFrame + ): + unity_ds = BaseTableDataset( + database="test", table="test_save", write_mode="overwrite" + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + overwritten_table = unity_ds.load() + + assert append_spark_df.exceptAll(overwritten_table).count() == 0 + + def test_save_append( + self, + sample_spark_df: DataFrame, + append_spark_df: DataFrame, + expected_append_spark_df: DataFrame, + ): + unity_ds = BaseTableDataset( + database="test", table="test_save_append", write_mode="append" + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + appended_table = unity_ds.load() + + assert expected_append_spark_df.exceptAll(appended_table).count() == 0 + + def test_save_upsert( + self, + sample_spark_df: DataFrame, + upsert_spark_df: DataFrame, + expected_upsert_spark_df: DataFrame, + ): + unity_ds = BaseTableDataset( + database="test", + table="test_save_upsert", + write_mode="upsert", + primary_key="name", + ) + unity_ds.save(sample_spark_df) + unity_ds.save(upsert_spark_df) + + upserted_table = unity_ds.load() + + assert expected_upsert_spark_df.exceptAll(upserted_table).count() == 0 + + def test_save_upsert_multiple_primary( + self, + sample_spark_df: DataFrame, + upsert_spark_df: DataFrame, + expected_upsert_multiple_primary_spark_df: DataFrame, + ): + unity_ds = BaseTableDataset( + database="test", + table="test_save_upsert_multiple", + write_mode="upsert", + primary_key=["name", "age"], + ) + unity_ds.save(sample_spark_df) + unity_ds.save(upsert_spark_df) + + upserted_table = unity_ds.load() + + assert ( + expected_upsert_multiple_primary_spark_df.exceptAll(upserted_table).count() + == 0 + ) + + def test_save_upsert_mismatched_columns( + 
self, + sample_spark_df: DataFrame, + mismatched_upsert_spark_df: DataFrame, + ): + unity_ds = BaseTableDataset( + database="test", + table="test_save_upsert_mismatch", + write_mode="upsert", + primary_key="name", + ) + unity_ds.save(sample_spark_df) + with pytest.raises(DatasetError): + unity_ds.save(mismatched_upsert_spark_df) + + def test_load_spark(self, sample_spark_df: DataFrame): + unity_ds = BaseTableDataset( + database="test", table="test_load_spark", write_mode="overwrite" + ) + unity_ds.save(sample_spark_df) + + delta_ds = BaseTableDataset(database="test", table="test_load_spark") + delta_table = delta_ds.load() + + assert ( + isinstance(delta_table, DataFrame) + and delta_table.exceptAll(sample_spark_df).count() == 0 + ) + + def test_load_spark_no_version(self, sample_spark_df: DataFrame): + unity_ds = BaseTableDataset( + database="test", table="test_load_spark", write_mode="overwrite" + ) + unity_ds.save(sample_spark_df) + + delta_ds = BaseTableDataset( + database="test", table="test_load_spark", version=Version(2, None) + ) + with pytest.raises(VersionNotFoundError): + _ = delta_ds.load() + + def test_load_version(self, sample_spark_df: DataFrame, append_spark_df: DataFrame): + unity_ds = BaseTableDataset( + database="test", table="test_load_version", write_mode="append" + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + loaded_ds = BaseTableDataset( + database="test", table="test_load_version", version=Version(0, None) + ) + loaded_df = loaded_ds.load() + + assert loaded_df.exceptAll(sample_spark_df).count() == 0 + + def test_load_pandas(self, sample_pandas_df: pd.DataFrame): + unity_ds = BaseTableDataset( + database="test", + table="test_load_pandas", + dataframe_type="pandas", + write_mode="overwrite", + ) + unity_ds.save(sample_pandas_df) + + pandas_ds = BaseTableDataset( + database="test", table="test_load_pandas", dataframe_type="pandas" + ) + pandas_df = pandas_ds.load().sort_values("name", ignore_index=True) + + assert isinstance(pandas_df, pd.DataFrame) and pandas_df.equals( + sample_pandas_df + ) diff --git a/kedro-datasets/tests/databricks/test_managed_table_dataset.py b/kedro-datasets/tests/databricks/test_managed_table_dataset.py index 03a85d27e..aac95cd7a 100644 --- a/kedro-datasets/tests/databricks/test_managed_table_dataset.py +++ b/kedro-datasets/tests/databricks/test_managed_table_dataset.py @@ -170,23 +170,6 @@ def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession): class TestManagedTableDataset: - def test_full_table(self): - unity_ds = ManagedTableDataset(catalog="test", database="test", table="test") - assert unity_ds._table.full_table_location() == "`test`.`test`.`test`" - - unity_ds = ManagedTableDataset( - catalog="test-test", database="test", table="test" - ) - assert unity_ds._table.full_table_location() == "`test-test`.`test`.`test`" - - unity_ds = ManagedTableDataset(database="test", table="test") - assert unity_ds._table.full_table_location() == "`test`.`test`" - - unity_ds = ManagedTableDataset(table="test") - assert unity_ds._table.full_table_location() == "`default`.`test`" - - with pytest.raises(TypeError): - ManagedTableDataset() def test_describe(self): unity_ds = ManagedTableDataset(table="test") @@ -201,291 +184,3 @@ def test_describe(self): "owner_group": None, "partition_columns": None, } - - def test_invalid_write_mode(self): - with pytest.raises(DatasetError): - ManagedTableDataset(table="test", write_mode="invalid") - - def test_dataframe_type(self): - with pytest.raises(DatasetError): - 
ManagedTableDataset(table="test", dataframe_type="invalid") - - def test_missing_primary_key_upsert(self): - with pytest.raises(DatasetError): - ManagedTableDataset(table="test", write_mode="upsert") - - def test_invalid_table_name(self): - with pytest.raises(DatasetError): - ManagedTableDataset(table="invalid!") - - def test_invalid_database(self): - with pytest.raises(DatasetError): - ManagedTableDataset(table="test", database="invalid!") - - def test_invalid_catalog(self): - with pytest.raises(DatasetError): - ManagedTableDataset(table="test", catalog="invalid!") - - def test_schema(self): - unity_ds = ManagedTableDataset( - table="test", - schema={ - "fields": [ - { - "metadata": {}, - "name": "name", - "nullable": True, - "type": "string", - }, - { - "metadata": {}, - "name": "age", - "nullable": True, - "type": "integer", - }, - ], - "type": "struct", - }, - ) - expected_schema = StructType( - [ - StructField("name", StringType(), True), - StructField("age", IntegerType(), True), - ] - ) - assert unity_ds._table.schema() == expected_schema - - def test_invalid_schema(self): - with pytest.raises(DatasetError): - ManagedTableDataset( - table="test", - schema={ - "fields": [ - { - "invalid": "schema", - } - ], - "type": "struct", - }, - )._table.schema() - - def test_catalog_exists(self): - unity_ds = ManagedTableDataset( - catalog="test", database="invalid", table="test_not_there" - ) - assert not unity_ds._exists() - - def test_table_does_not_exist(self): - unity_ds = ManagedTableDataset(database="invalid", table="test_not_there") - assert not unity_ds._exists() - - def test_save_default(self, sample_spark_df: DataFrame): - unity_ds = ManagedTableDataset(database="test", table="test_save") - with pytest.raises(DatasetError): - unity_ds.save(sample_spark_df) - - def test_save_schema_spark( - self, subset_spark_df: DataFrame, subset_expected_df: DataFrame - ): - unity_ds = ManagedTableDataset( - database="test", - table="test_save_spark_schema", - schema={ - "fields": [ - { - "metadata": {}, - "name": "name", - "nullable": True, - "type": "string", - }, - { - "metadata": {}, - "name": "age", - "nullable": True, - "type": "integer", - }, - ], - "type": "struct", - }, - write_mode="overwrite", - ) - unity_ds.save(subset_spark_df) - saved_table = unity_ds.load() - assert subset_expected_df.exceptAll(saved_table).count() == 0 - - def test_save_schema_pandas( - self, subset_pandas_df: pd.DataFrame, subset_expected_df: DataFrame - ): - unity_ds = ManagedTableDataset( - database="test", - table="test_save_pd_schema", - schema={ - "fields": [ - { - "metadata": {}, - "name": "name", - "nullable": True, - "type": "string", - }, - { - "metadata": {}, - "name": "age", - "nullable": True, - "type": "integer", - }, - ], - "type": "struct", - }, - write_mode="overwrite", - dataframe_type="pandas", - ) - unity_ds.save(subset_pandas_df) - saved_ds = ManagedTableDataset( - database="test", - table="test_save_pd_schema", - ) - saved_table = saved_ds.load() - assert subset_expected_df.exceptAll(saved_table).count() == 0 - - def test_save_overwrite( - self, sample_spark_df: DataFrame, append_spark_df: DataFrame - ): - unity_ds = ManagedTableDataset( - database="test", table="test_save", write_mode="overwrite" - ) - unity_ds.save(sample_spark_df) - unity_ds.save(append_spark_df) - - overwritten_table = unity_ds.load() - - assert append_spark_df.exceptAll(overwritten_table).count() == 0 - - def test_save_append( - self, - sample_spark_df: DataFrame, - append_spark_df: DataFrame, - expected_append_spark_df: 
DataFrame, - ): - unity_ds = ManagedTableDataset( - database="test", table="test_save_append", write_mode="append" - ) - unity_ds.save(sample_spark_df) - unity_ds.save(append_spark_df) - - appended_table = unity_ds.load() - - assert expected_append_spark_df.exceptAll(appended_table).count() == 0 - - def test_save_upsert( - self, - sample_spark_df: DataFrame, - upsert_spark_df: DataFrame, - expected_upsert_spark_df: DataFrame, - ): - unity_ds = ManagedTableDataset( - database="test", - table="test_save_upsert", - write_mode="upsert", - primary_key="name", - ) - unity_ds.save(sample_spark_df) - unity_ds.save(upsert_spark_df) - - upserted_table = unity_ds.load() - - assert expected_upsert_spark_df.exceptAll(upserted_table).count() == 0 - - def test_save_upsert_multiple_primary( - self, - sample_spark_df: DataFrame, - upsert_spark_df: DataFrame, - expected_upsert_multiple_primary_spark_df: DataFrame, - ): - unity_ds = ManagedTableDataset( - database="test", - table="test_save_upsert_multiple", - write_mode="upsert", - primary_key=["name", "age"], - ) - unity_ds.save(sample_spark_df) - unity_ds.save(upsert_spark_df) - - upserted_table = unity_ds.load() - - assert ( - expected_upsert_multiple_primary_spark_df.exceptAll(upserted_table).count() - == 0 - ) - - def test_save_upsert_mismatched_columns( - self, - sample_spark_df: DataFrame, - mismatched_upsert_spark_df: DataFrame, - ): - unity_ds = ManagedTableDataset( - database="test", - table="test_save_upsert_mismatch", - write_mode="upsert", - primary_key="name", - ) - unity_ds.save(sample_spark_df) - with pytest.raises(DatasetError): - unity_ds.save(mismatched_upsert_spark_df) - - def test_load_spark(self, sample_spark_df: DataFrame): - unity_ds = ManagedTableDataset( - database="test", table="test_load_spark", write_mode="overwrite" - ) - unity_ds.save(sample_spark_df) - - delta_ds = ManagedTableDataset(database="test", table="test_load_spark") - delta_table = delta_ds.load() - - assert ( - isinstance(delta_table, DataFrame) - and delta_table.exceptAll(sample_spark_df).count() == 0 - ) - - def test_load_spark_no_version(self, sample_spark_df: DataFrame): - unity_ds = ManagedTableDataset( - database="test", table="test_load_spark", write_mode="overwrite" - ) - unity_ds.save(sample_spark_df) - - delta_ds = ManagedTableDataset( - database="test", table="test_load_spark", version=Version(2, None) - ) - with pytest.raises(VersionNotFoundError): - _ = delta_ds.load() - - def test_load_version(self, sample_spark_df: DataFrame, append_spark_df: DataFrame): - unity_ds = ManagedTableDataset( - database="test", table="test_load_version", write_mode="append" - ) - unity_ds.save(sample_spark_df) - unity_ds.save(append_spark_df) - - loaded_ds = ManagedTableDataset( - database="test", table="test_load_version", version=Version(0, None) - ) - loaded_df = loaded_ds.load() - - assert loaded_df.exceptAll(sample_spark_df).count() == 0 - - def test_load_pandas(self, sample_pandas_df: pd.DataFrame): - unity_ds = ManagedTableDataset( - database="test", - table="test_load_pandas", - dataframe_type="pandas", - write_mode="overwrite", - ) - unity_ds.save(sample_pandas_df) - - pandas_ds = ManagedTableDataset( - database="test", table="test_load_pandas", dataframe_type="pandas" - ) - pandas_df = pandas_ds.load().sort_values("name", ignore_index=True) - - assert isinstance(pandas_df, pd.DataFrame) and pandas_df.equals( - sample_pandas_df - ) From 99828171624281529588e0e967c4658d221ca65e Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 4 Oct 2024 19:48:43 +0530 
Subject: [PATCH 45/51] initialized the base table in the base table dataset Signed-off-by: Minura Punchihewa --- .../databricks/_base_table_dataset.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index 700e1cf53..cba1798a3 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -325,7 +325,19 @@ def _create_table( # noqa: PLR0913 Returns: ``BaseTable``: The new ``BaseTable`` instance. """ - raise NotImplementedError + return BaseTable( + table=table, + catalog=catalog, + database=database, + format=format, + write_mode=write_mode, + location=location, + dataframe_type=dataframe_type, + json_schema=json_schema, + partition_columns=partition_columns, + owner_group=owner_group, + primary_key=primary_key, + ) def _load(self) -> DataFrame | pd.DataFrame: """Loads the version of data in the format defined in the init From 5a05a3468962a2147ae5b754dc23f3739d2b3cda Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 4 Oct 2024 19:57:58 +0530 Subject: [PATCH 46/51] fixed lint issues Signed-off-by: Minura Punchihewa --- .../databricks/_base_table_dataset.py | 49 ++++++++++++------- .../databricks/managed_table_dataset.py | 16 +++--- .../databricks/test_base_table_dataset.py | 4 +- .../databricks/test_managed_table_dataset.py | 4 +- 4 files changed, 39 insertions(+), 34 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index cba1798a3..de893f633 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -6,17 +6,16 @@ import logging import re from dataclasses import dataclass, field -from typing import Any, ClassVar, List +from typing import Any, ClassVar import pandas as pd from kedro.io.core import ( AbstractVersionedDataset, DatasetError, Version, - VersionNotFoundError + VersionNotFoundError, ) from pyspark.sql import DataFrame -from pyspark.sql.readwriter import DataFrameWriter from pyspark.sql.types import StructType from pyspark.sql.utils import AnalysisException, ParseException @@ -29,14 +28,19 @@ @dataclass(frozen=True) class BaseTable: """Stores the definition of a base table. - + Acts as the base class for `ManagedTable` and `ExternalTable`. 
""" + # regex for tables, catalogs and schemas _NAMING_REGEX: ClassVar[str] = r"\b[0-9a-zA-Z_-]{1,}\b" - _VALID_WRITE_MODES: ClassVar[List[str]] = field(default=["overwrite", "upsert", "append"]) - _VALID_DATAFRAME_TYPES: ClassVar[List[str]] = field(default=["spark", "pandas"]) - _VALID_FORMATS: ClassVar[List[str]] = field(default=["delta", "parquet", "csv", "json", "orc", "avro", "text"]) + _VALID_WRITE_MODES: ClassVar[list[str]] = field( + default=["overwrite", "upsert", "append"] + ) + _VALID_DATAFRAME_TYPES: ClassVar[list[str]] = field(default=["spark", "pandas"]) + _VALID_FORMATS: ClassVar[list[str]] = field( + default=["delta", "parquet", "csv", "json", "orc", "avro", "text"] + ) database: str catalog: str | None @@ -47,7 +51,7 @@ class BaseTable: primary_key: str | list[str] | None owner_group: str | None partition_columns: str | list[str] | None - format: str = "delta", + format: str = ("delta",) json_schema: dict[str, Any] | None = None def __post_init__(self): @@ -229,7 +233,7 @@ def __init__( # noqa: PLR0913 schema: dict[str, Any] | None = None, partition_columns: list[str] | None = None, owner_group: str | None = None, - metadata: dict[str, Any] | None = None + metadata: dict[str, Any] | None = None, ) -> None: """Creates a new instance of ``BaseTableDataset``. @@ -306,7 +310,7 @@ def _create_table( # noqa: PLR0913 primary_key: str | list[str] | None, json_schema: dict[str, Any] | None, partition_columns: list[str] | None, - owner_group: str | None + owner_group: str | None, ) -> BaseTable: """Creates a ``BaseTable`` instance with the provided attributes. @@ -315,6 +319,7 @@ def _create_table( # noqa: PLR0913 catalog: The catalog of the table. database: The database of the table. format: The format of the table. + location: The location of the table. write_mode: The write mode for the table. dataframe_type: The type of dataframe. primary_key: The primary key of the table. @@ -338,7 +343,7 @@ def _create_table( # noqa: PLR0913 owner_group=owner_group, primary_key=primary_key, ) - + def _load(self) -> DataFrame | pd.DataFrame: """Loads the version of data in the format defined in the init (spark|pandas dataframe). @@ -366,7 +371,7 @@ def _load(self) -> DataFrame | pd.DataFrame: if self._table.dataframe_type == "pandas": data = data.toPandas() return data - + def _save(self, data: DataFrame | pd.DataFrame) -> None: """Saves the data based on the write_mode and dataframe_type in the init. If write_mode is pandas, Spark dataframe is created first. @@ -401,9 +406,9 @@ def _save(self, data: DataFrame | pd.DataFrame) -> None: f"Invalid `write_mode` provided: {self._table.write_mode}. " f"`write_mode` must be one of: {self._table._VALID_WRITE_MODES}" ) - + method(data) - + def _save_append(self, data: DataFrame) -> None: """Saves the data to the table by appending it to the location defined in the init. @@ -415,7 +420,9 @@ def _save_append(self, data: DataFrame) -> None: if self._table.partition_columns: writer.partitionBy( - *self._table.partition_columns if isinstance(self._table.partition_columns, list) else [self._table.partition_columns] + *self._table.partition_columns + if isinstance(self._table.partition_columns, list) + else [self._table.partition_columns] ) if self._table.location: @@ -429,13 +436,17 @@ def _save_overwrite(self, data: DataFrame) -> None: Args: data (DataFrame): The Spark dataframe to overwrite the table with. 
""" - writer = data.write.format(self._table.format).mode("overwrite").option( - "overwriteSchema", "true" + writer = ( + data.write.format(self._table.format) + .mode("overwrite") + .option("overwriteSchema", "true") ) - + if self._table.partition_columns: writer.partitionBy( - *self._table.partition_columns if isinstance(self._table.partition_columns, list) else [self._table.partition_columns] + *self._table.partition_columns + if isinstance(self._table.partition_columns, list) + else [self._table.partition_columns] ) if self._table.location: diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index fd4a3964e..6185a498b 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -5,12 +5,10 @@ import logging from dataclasses import dataclass, field -from typing import Any, ClassVar, List +from typing import Any, ClassVar import pandas as pd -from kedro.io.core import ( - Version -) +from kedro.io.core import Version from kedro_datasets.databricks._base_table_dataset import BaseTable, BaseTableDataset @@ -22,7 +20,7 @@ class ManagedTable(BaseTable): """Stores the definition of a managed table.""" - _VALID_FORMATS: ClassVar[List[str]] = field(default=["delta"]) + _VALID_FORMATS: ClassVar[list[str]] = field(default=["delta"]) class ManagedTableDataset(BaseTableDataset): @@ -82,6 +80,7 @@ class ManagedTableDataset(BaseTableDataset): >>> reloaded = dataset.load() >>> assert Row(name="Bob", age=12) in reloaded.take(4) """ + def __init__( # noqa: PLR0913 self, *, @@ -158,7 +157,7 @@ def _create_table( # noqa: PLR0913 primary_key: str | list[str] | None, json_schema: dict[str, Any] | None, partition_columns: list[str] | None, - owner_group: str | None + owner_group: str | None, ) -> ManagedTable: """Creates a new ``ManagedTable`` instance with the provided attributes. @@ -188,9 +187,9 @@ def _create_table( # noqa: PLR0913 partition_columns=partition_columns, owner_group=owner_group, primary_key=primary_key, - format=format + format=format, ) - + def _describe(self) -> dict[str, str | list | None]: """Returns a description of the instance of the dataset. 
@@ -202,4 +201,3 @@ def _describe(self) -> dict[str, str | list | None]: del description["location"] return description - diff --git a/kedro-datasets/tests/databricks/test_base_table_dataset.py b/kedro-datasets/tests/databricks/test_base_table_dataset.py index 6f0182474..dddfe794a 100644 --- a/kedro-datasets/tests/databricks/test_base_table_dataset.py +++ b/kedro-datasets/tests/databricks/test_base_table_dataset.py @@ -174,9 +174,7 @@ def test_full_table(self): unity_ds = BaseTableDataset(catalog="test", database="test", table="test") assert unity_ds._table.full_table_location() == "`test`.`test`.`test`" - unity_ds = BaseTableDataset( - catalog="test-test", database="test", table="test" - ) + unity_ds = BaseTableDataset(catalog="test-test", database="test", table="test") assert unity_ds._table.full_table_location() == "`test-test`.`test`.`test`" unity_ds = BaseTableDataset(database="test", table="test") diff --git a/kedro-datasets/tests/databricks/test_managed_table_dataset.py b/kedro-datasets/tests/databricks/test_managed_table_dataset.py index aac95cd7a..c3cc623f4 100644 --- a/kedro-datasets/tests/databricks/test_managed_table_dataset.py +++ b/kedro-datasets/tests/databricks/test_managed_table_dataset.py @@ -1,7 +1,6 @@ import pandas as pd import pytest -from kedro.io.core import DatasetError, Version, VersionNotFoundError -from pyspark.sql import DataFrame, SparkSession +from pyspark.sql import SparkSession from pyspark.sql.types import IntegerType, StringType, StructField, StructType from kedro_datasets.databricks import ManagedTableDataset @@ -170,7 +169,6 @@ def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession): class TestManagedTableDataset: - def test_describe(self): unity_ds = ManagedTableDataset(table="test") assert unity_ds._describe() == { From c5d68eaf14b50f476b3f08fc0b26de6398c76de6 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 4 Oct 2024 20:17:55 +0530 Subject: [PATCH 47/51] fixed an incorrect type hint Signed-off-by: Minura Punchihewa --- kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index de893f633..f7c59a72f 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -51,7 +51,7 @@ class BaseTable: primary_key: str | list[str] | None owner_group: str | None partition_columns: str | list[str] | None - format: str = ("delta",) + format: str = "delta" json_schema: dict[str, Any] | None = None def __post_init__(self): From b6dc5b9c45145980fb1ead1bec9f5fe19396723f Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Mon, 7 Oct 2024 11:00:42 +0530 Subject: [PATCH 48/51] removed redundant check on save() Signed-off-by: Minura Punchihewa --- .../kedro_datasets/databricks/_base_table_dataset.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py index f7c59a72f..0a9e54d99 100644 --- a/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/_base_table_dataset.py @@ -401,12 +401,6 @@ def _save(self, data: DataFrame | pd.DataFrame) -> None: method = getattr(self, f"_save_{self._table.write_mode}", None) - if method is None: - raise DatasetError( - f"Invalid 
`write_mode` provided: {self._table.write_mode}. " - f"`write_mode` must be one of: {self._table._VALID_WRITE_MODES}" - ) - method(data) def _save_append(self, data: DataFrame) -> None: From c5ea9d90d5ddb7930230f2f514733b39a9c83ba7 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Mon, 7 Oct 2024 11:41:15 +0530 Subject: [PATCH 49/51] added the missing unit tests Signed-off-by: Minura Punchihewa --- .../databricks/test_base_table_dataset.py | 84 +++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/kedro-datasets/tests/databricks/test_base_table_dataset.py b/kedro-datasets/tests/databricks/test_base_table_dataset.py index dddfe794a..0ad78e451 100644 --- a/kedro-datasets/tests/databricks/test_base_table_dataset.py +++ b/kedro-datasets/tests/databricks/test_base_table_dataset.py @@ -1,3 +1,4 @@ +import os import pandas as pd import pytest from kedro.io.core import DatasetError, Version, VersionNotFoundError @@ -169,6 +170,11 @@ def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession): return spark_session.createDataFrame(data, schema) +@pytest.fixture +def external_location(): + return os.environ.get("DATABRICKS_EXTERNAL_LOCATION", "s3://bucket/test_save_external") + + class TestBaseTableDataset: def test_full_table(self): unity_ds = BaseTableDataset(catalog="test", database="test", table="test") @@ -226,6 +232,10 @@ def test_invalid_catalog(self): with pytest.raises(DatasetError): BaseTableDataset(table="test", catalog="invalid!") + def test_invalid_format(self): + with pytest.raises(DatasetError): + BaseTableDataset(table="test", format="invalid") + def test_schema(self): unity_ds = BaseTableDataset( table="test", @@ -360,6 +370,41 @@ def test_save_overwrite( assert append_spark_df.exceptAll(overwritten_table).count() == 0 + def test_save_overwrite_partitioned( + self, sample_spark_df: DataFrame, append_spark_df: DataFrame + ): + unity_ds = BaseTableDataset( + database="test", + table="test_save_partitioned", + write_mode="overwrite", + partition_columns=["name"], + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + overwritten_table = unity_ds.load() + + assert append_spark_df.exceptAll(overwritten_table).count() == 0 + + def test_save_overwrite_external( + self, + sample_spark_df: DataFrame, + append_spark_df: DataFrame, + external_location: str, + ): + unity_ds = BaseTableDataset( + database="test", + table="test_save_external", + write_mode="overwrite", + location=external_location, + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + overwritten_table = unity_ds.load() + + assert append_spark_df.exceptAll(overwritten_table).count() == 0 + def test_save_append( self, sample_spark_df: DataFrame, @@ -376,6 +421,45 @@ def test_save_append( assert expected_append_spark_df.exceptAll(appended_table).count() == 0 + def test_save_append_partitioned( + self, + sample_spark_df: DataFrame, + append_spark_df: DataFrame, + expected_append_spark_df: DataFrame, + ): + unity_ds = BaseTableDataset( + database="test", + table="test_save_append_partitioned", + write_mode="append", + partition_columns=["name"], + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + appended_table = unity_ds.load() + + assert expected_append_spark_df.exceptAll(appended_table).count() == 0 + + def test_save_append_external( + self, + sample_spark_df: DataFrame, + append_spark_df: DataFrame, + expected_append_spark_df: DataFrame, + external_location: str, + ): + unity_ds = BaseTableDataset( + database="test", + 
table="test_save_append_external", + write_mode="append", + location=external_location, + ) + unity_ds.save(sample_spark_df) + unity_ds.save(append_spark_df) + + appended_table = unity_ds.load() + + assert expected_append_spark_df.exceptAll(appended_table).count() == 0 + def test_save_upsert( self, sample_spark_df: DataFrame, From 0558d96441acf6f2ef04c3eb9e5900b7f6db7344 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Mon, 7 Oct 2024 11:56:58 +0530 Subject: [PATCH 50/51] fixed the tests for saving external tables Signed-off-by: Minura Punchihewa --- kedro-datasets/tests/databricks/test_base_table_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/kedro-datasets/tests/databricks/test_base_table_dataset.py b/kedro-datasets/tests/databricks/test_base_table_dataset.py index 0ad78e451..e4b601a05 100644 --- a/kedro-datasets/tests/databricks/test_base_table_dataset.py +++ b/kedro-datasets/tests/databricks/test_base_table_dataset.py @@ -172,7 +172,7 @@ def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession): @pytest.fixture def external_location(): - return os.environ.get("DATABRICKS_EXTERNAL_LOCATION", "s3://bucket/test_save_external") + return os.environ.get("DATABRICKS_EXTERNAL_LOCATION") class TestBaseTableDataset: @@ -396,7 +396,7 @@ def test_save_overwrite_external( database="test", table="test_save_external", write_mode="overwrite", - location=external_location, + location=f"{external_location}/test_save_external", ) unity_ds.save(sample_spark_df) unity_ds.save(append_spark_df) @@ -451,7 +451,7 @@ def test_save_append_external( database="test", table="test_save_append_external", write_mode="append", - location=external_location, + location=f"{external_location}/test_save_append_external", ) unity_ds.save(sample_spark_df) unity_ds.save(append_spark_df) From d598d79ed3edb030cb9112f7d33c4cd49adf9225 Mon Sep 17 00:00:00 2001 From: Ankita Katiyar Date: Thu, 10 Oct 2024 15:13:45 +0100 Subject: [PATCH 51/51] lint Signed-off-by: Ankita Katiyar --- kedro-datasets/tests/databricks/test_base_table_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/kedro-datasets/tests/databricks/test_base_table_dataset.py b/kedro-datasets/tests/databricks/test_base_table_dataset.py index e4b601a05..5cc88e8df 100644 --- a/kedro-datasets/tests/databricks/test_base_table_dataset.py +++ b/kedro-datasets/tests/databricks/test_base_table_dataset.py @@ -1,4 +1,5 @@ import os + import pandas as pd import pytest from kedro.io.core import DatasetError, Version, VersionNotFoundError