Updates all examples to use the new data loader functionality
elijahbenizzy committed Apr 4, 2023
1 parent dde909f commit f87311f
Showing 5 changed files with 31 additions and 16 deletions.
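The migration pattern, in a minimal sketch (the function names here are illustrative, but the decorator call mirrors the diffs below): @load_from.csv moves the pd.read_csv call out of the function body and onto a decorator, where source("location") binds the path to an upstream input named "location" and value(";") passes the separator as a literal.

import pandas as pd

from hamilton.function_modifiers import load_from, source, value


# Before this commit: the function both loads and transforms.
def raw_data_before(location: str) -> pd.DataFrame:
    return pd.read_csv(location, sep=";")


# After: loading is declared on the decorator, and the function
# receives the already-loaded frame.
@load_from.csv(location=source("location"), sep=value(";"))
def raw_data_after(df: pd.DataFrame) -> pd.DataFrame:
    return df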
11 changes: 6 additions & 5 deletions examples/data_quality/pandera/data_loaders.py
@@ -13,7 +13,7 @@

 import pandas as pd

-from hamilton.function_modifiers import config, extract_columns
+from hamilton.function_modifiers import config, extract_columns, load_from, source, value

 data_columns = [
     "id",
@@ -51,12 +51,13 @@ def _sanitize_columns(df_columns: List[str]) -> List[str]:

 @config.when_not_in(execution=["dask", "spark"])
 @extract_columns(*data_columns)
-def raw_data__base(location: str) -> pd.DataFrame:
+@load_from.csv(location=source("location"), sep=value(";"))
+def raw_data__base(df: pd.DataFrame) -> pd.DataFrame:
     """Extracts the raw data, renames the columns to be valid python variable names, and assigns an index.
     :param location: the location to load from
     :return:
     """
-    df = pd.read_csv(location, sep=";")
+
     # rename columns to be valid hamilton names -- and lower case it
     df.columns = _sanitize_columns(df.columns)
     # create proper index -- ID-Month-Day;
@@ -97,13 +98,13 @@ def raw_data__dask(location: str, block_size: str = "10KB") -> pd.DataFrame:

 @config.when(execution="spark")
 @extract_columns("index_col", *data_columns)
-def raw_data__spark(location: str) -> pd.DataFrame:
+@load_from.csv(location=source("location"), sep=value(";"))
+def raw_data__spark(df: pd.DataFrame) -> pd.DataFrame:
     """Extracts the raw data, renames the columns to be valid python variable names, and assigns an index.
     :param location: the location to load from
     :param number_partitions: number of partitions to partition the data for dask.
     :return: dask dataframe
     """
-    df = pd.read_csv(location, sep=";")
     # rename columns to be valid hamilton names -- and lower case it
     df.columns = _sanitize_columns(df.columns)
     # create proper index -- ID-Month-Day;
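For context, a sketch of how a module like this might be executed — source("location") is satisfied by an input passed to the driver at runtime (the file name and requested column are illustrative):

from hamilton import driver

import data_loaders  # the example module shown in the diff above

# An empty config selects raw_data__base via @config.when_not_in(execution=["dask", "spark"]).
dr = driver.Driver({}, data_loaders)
# "id" is one of the columns exposed by @extract_columns; the "location" input
# feeds source("location") on the @load_from.csv decorator.
df = dr.execute(["id"], inputs={"location": "absenteeism_data.csv"})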
10 changes: 5 additions & 5 deletions examples/data_quality/simple/data_loaders.py
@@ -11,7 +11,7 @@

 import pandas as pd

-from hamilton.function_modifiers import config, extract_columns
+from hamilton.function_modifiers import config, extract_columns, load_from, source, value

 data_columns = [
     "id",
@@ -49,12 +49,12 @@ def _sanitize_columns(df_columns: List[str]) -> List[str]:

 @config.when_not_in(execution=["dask", "spark"])
 @extract_columns(*data_columns)
-def raw_data__base(location: str) -> pd.DataFrame:
+@load_from.csv(location=source("location"), sep=value(";"))
+def raw_data__base(df: pd.DataFrame) -> pd.DataFrame:
     """Extracts the raw data, renames the columns to be valid python variable names, and assigns an index.
     :param location: the location to load from
     :return:
     """
-    df = pd.read_csv(location, sep=";")
     # rename columns to be valid hamilton names -- and lower case it
     df.columns = _sanitize_columns(df.columns)
     # create proper index -- ID-Month-Day;
@@ -95,13 +95,13 @@

 @config.when(execution="spark")
 @extract_columns("index_col", *data_columns)
-def raw_data__spark(location: str) -> pd.DataFrame:
+@load_from.csv(location=source("location"), sep=value(";"))
+def raw_data__spark(df: pd.DataFrame) -> pd.DataFrame:
     """Extracts the raw data, renames the columns to be valid python variable names, and assigns an index.
     :param location: the location to load from
     :param number_partitions: number of partitions to partition the data for dask.
     :return: dask dataframe
     """
-    df = pd.read_csv(location, sep=";")
     # rename columns to be valid hamilton names -- and lower case it
     df.columns = _sanitize_columns(df.columns)
     # create proper index -- ID-Month-Day;
6 changes: 3 additions & 3 deletions
@@ -8,7 +8,7 @@

 import pandas as pd

-from hamilton.function_modifiers import extract_columns
+from hamilton.function_modifiers import extract_columns, load_from, source, value

 # full set of available columns from the data source
 data_columns = [
@@ -46,12 +46,12 @@ def _sanitize_columns(df_columns: List[str]) -> List[str]:


 @extract_columns(*data_columns)
-def raw_data(location: str) -> pd.DataFrame:
+@load_from.csv(location=source("location"), sep=value(";"))
+def raw_data(df: pd.DataFrame) -> pd.DataFrame:
     """Extracts the raw data, renames the columns to be valid python variable names, and assigns an index.
     :param location: the location to load from
     :return:
     """
-    df = pd.read_csv(location, sep=";")
     # rename columns to be valid hamilton names -- and lower case it
     df.columns = _sanitize_columns(df.columns)
     # create proper index -- ID-Month-Day - to be able to join features appropriately.
3 changes: 2 additions & 1 deletion
@@ -8,7 +8,7 @@

 import pandas as pd

-from hamilton.function_modifiers import extract_columns
+from hamilton.function_modifiers import extract_columns, load_from, source, value

 # full set of available columns from the data source
 data_columns = [
@@ -46,6 +46,7 @@ def _sanitize_columns(df_columns: List[str]) -> List[str]:


 @extract_columns(*data_columns)
+@load_from.csv(location=source("location"), sep=value(";"))
 def raw_data(location: str) -> pd.DataFrame:
     """Extracts the raw data, renames the columns to be valid python variable names, and assigns an index.
     :param location: the location to load from
17 changes: 15 additions & 2 deletions hamilton/plugins/pandas_extensions.py
@@ -70,13 +70,26 @@ class CSVDataAdapter(DataFrameDataLoader):
     """

     path: str
+    sep: str = None
+
+    def _get_loading_kwargs(self):
+        kwargs = {}
+        if self.sep is not None:
+            kwargs["sep"] = self.sep
+        return kwargs
+
+    def _get_saving_kwargs(self):
+        kwargs = {"index": False}
+        if self.sep is not None:
+            kwargs["sep"] = self.sep
+        return kwargs

     def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
-        data.to_csv(self.path, index=False)
+        data.to_csv(self.path, **self._get_saving_kwargs())
         return utils.get_file_loading_metadata(self.path)

     def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]:
-        df = pd.read_csv(self.path)
+        df = pd.read_csv(self.path, **self._get_loading_kwargs())
         metadata = utils.get_file_loading_metadata(self.path)
         return df, metadata
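The adapter change follows an only-forward-when-set pattern: sep defaults to None, and the new _get_loading_kwargs / _get_saving_kwargs helpers pass it to pandas only when a value was supplied, so existing callers keep pandas' default comma separator. A rough sketch of the resulting behavior (the constructions below are illustrative):

import pandas as pd

from hamilton.plugins.pandas_extensions import CSVDataAdapter

# sep left as None: save/load behave exactly as before this commit.
default_adapter = CSVDataAdapter(path="data.csv")

# sep=";": both save_data and load_data now forward sep=";" to pandas,
# so a frame round-trips cleanly through a semicolon-delimited file.
semi_adapter = CSVDataAdapter(path="data_semicolon.csv", sep=";")
semi_adapter.save_data(pd.DataFrame({"a": [1, 2], "b": [3, 4]}))
df, metadata = semi_adapter.load_data(pd.DataFrame)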
