Commit
Fix various issues (#439)
Co-authored-by: Karan Jariwala <karankjariwala@gmail.com>
knighton and karan6181 committed Sep 18, 2023
1 parent d1bbc34 commit 4ae10fe
Showing 2 changed files with 48 additions and 53 deletions.
91 changes: 43 additions & 48 deletions streaming/base/converters/dataframe_to_mds.py
@@ -54,47 +54,41 @@
 }


-def is_iterable(obj: object) -> bool:
-    """Check if obj is iterable.
-    Args:
-        obj: python object
-    Return:
-        bool: true if obj is iterable false otherwise
-    """
-    return issubclass(type(obj), Iterable)


 def infer_dataframe_schema(dataframe: DataFrame,
                            user_defined_cols: Optional[Dict[str, Any]] = None) -> Optional[Dict]:
     """Retrieve schema to construct a dictionary or do sanity check for MDSWriter.
     Args:
         dataframe (spark dataframe): dataframe to inspect schema
         user_defined_cols (Optional[Dict[str, Any]]): user specified schema for MDSWriter
-    Return:
-        If user_defined_cols is None, return schema_dict (dict): column name and dtypes that supported by MDSWriter
-        Else, return None
-    Exceptions:
-        Any of the datatype found to be unsupported by MDSWriter, then raise ValueError
+    Returns:
+        If user_defined_cols is None, return schema_dict (dict): column name and dtypes that are
+        supported by MDSWriter, else None
+    Raises:
+        ValueError if any of the datatypes are unsupported by MDSWriter.
     """

     def map_spark_dtype(spark_data_type: Any) -> str:
         """Map spark data type to mds supported types.
         Args:
             spark_data_type: https://spark.apache.org/docs/latest/sql-ref-datatypes.html
-        Return:
+        Returns:
             str: corresponding mds datatype for input.
-        Exception:
+        Raises:
             raise ValueError if no mds datatype is found for input type
         """
         mds_type = MAPPING_SPARK_TO_MDS.get(type(spark_data_type), None)
         if mds_type is None:
             raise ValueError(f'{spark_data_type} is not supported by MDSWriter')
         return mds_type

-    if user_defined_cols is not None:  # user has provided schema, we just check if mds supports the dtype
+    # user has provided schema, we just check if mds supports the dtype
+    if user_defined_cols is not None:
         mds_supported_dtypes = {
             mds_type for mds_type in MAPPING_SPARK_TO_MDS.values() if mds_type is not None
         }
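
For orientation, a hedged sketch of how this helper is used; the dtype names in the comments ('int64', 'str') and the import path are assumptions based on this file's location, not guarantees:

    # Illustrative only: assumes a live SparkSession.
    from pyspark.sql import SparkSession
    from streaming.base.converters.dataframe_to_mds import infer_dataframe_schema

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, 'hello')], schema='id long, text string')

    schema = infer_dataframe_schema(df)  # e.g. {'id': 'int64', 'text': 'str'}
    infer_dataframe_schema(df, user_defined_cols=schema)  # validates, returns None
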
@@ -109,8 +103,8 @@ def map_spark_dtype(spark_data_type: Any) -> str:
             mapped_mds_dtype = map_spark_dtype(actual_spark_dtype)
             if user_dtype != mapped_mds_dtype:
                 raise ValueError(
-                    f'Mismatched types: column name `{col_name}` is `{mapped_mds_dtype}` in DataFrame but `{user_dtype}` in user_defined_cols'
-                )
+                    f'Mismatched types: column name `{col_name}` is `{mapped_mds_dtype}` in ' +
+                    f'DataFrame but `{user_dtype}` in user_defined_cols')
         return None

     schema = dataframe.schema
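
The auto-inference path that follows (truncated in this view) walks dataframe.schema; conceptually it reduces to this hedged one-liner, not the file's exact code:

    # Sketch of the inference loop over Spark StructFields.
    schema_dict = {field.name: map_spark_dtype(field.dataType) for field in dataframe.schema}
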
@@ -130,9 +124,8 @@ def do_merge_index(partitions: Iterable, mds_path: Union[str, Tuple[str, str]])
     Args:
         partitions (Iterable): partitions that contain pd.DataFrame
-        mds_path (Union[str, Tuple[str, str]]): (str,str)=(local,remote), str = local or remote based on parse_uri(url) result
-    Return:
-        None
+        mds_path (Union[str, Tuple[str, str]]): (str,str)=(local,remote), str = local or remote
+            based on parse_uri(url) result
     """
     if not partitions:
         logger.warning('No partitions exist, no index merged')
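
To make the merge step concrete, here is a simplified sketch of what merging per-partition indexes amounts to; the directory layout and helper name are invented, and the MDS index format (version 2 with a 'shards' list) is the only assumed contract:

    import json
    import os

    def merge_partition_indexes(root: str) -> None:
        """Sketch: fold each partition's index.json into one root index.json."""
        shards = []
        for subdir in sorted(os.listdir(root)):
            index_path = os.path.join(root, subdir, 'index.json')
            if not os.path.isfile(index_path):
                continue
            with open(index_path) as fp:
                obj = json.load(fp)
            for shard in obj['shards']:
                # Re-root shard file references so they resolve from the top level.
                for key in ('raw_data', 'zip_data'):
                    if shard.get(key):
                        shard[key]['basename'] = os.path.join(subdir, shard[key]['basename'])
                shards.append(shard)
        with open(os.path.join(root, 'index.json'), 'w') as fp:
            json.dump({'version': 2, 'shards': shards}, fp)
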
@@ -174,31 +167,33 @@ def dataframeToMDS(dataframe: DataFrame,
                    udf_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[Any, int]:
     """Execute a spark dataframe to MDS conversion process.
-    This method orchestrates the conversion of a spark dataframe into MDS format by
-    processing the input data, applying a user-defined iterable function if
-    provided, and writing the results to MDS-compatible format. The converted data is saved to mds_path.
+    This method orchestrates the conversion of a spark dataframe into MDS format by processing the
+    input data, applying a user-defined iterable function if provided, and writing the results to
+    an MDS-compatible format. The converted data is saved to mds_path.
     Args:
         dataframe (pyspark.sql.DataFrame): A DataFrame containing Delta Lake data.
         merge_index (bool): Whether to merge MDS index files. Defaults to ``True``.
-        mds_kwargs (dict): Refer to https://docs.mosaicml.com/projects/streaming/en/stable/api_reference/generated/streaming.MDSWriter.html
-        udf_iterable (Callable or None): A user-defined function that returns an iterable over the dataframe. udf_kwargs is the k-v args for the method. Defaults to ``None``.
+        mds_kwargs (dict): Refer to https://docs.mosaicml.com/projects/streaming/en/stable/
+            api_reference/generated/streaming.MDSWriter.html
+        udf_iterable (Callable or None): A user-defined function that returns an iterable over the
+            dataframe. udf_kwargs is the k-v args for the method. Defaults to ``None``.
         udf_kwargs (Dict): Additional keyword arguments to pass to the pandas processing
             function if provided. Defaults to an empty dictionary.
-    Return:
+    Returns:
         mds_path (str or (str,str)): actual local and remote path were used
         fail_count (int): number of records failed to be converted
-    Note:
+    Notes:
         - The method creates a SparkSession if not already available.
         - The 'udf_kwargs' dictionaries can be used to pass additional
           keyword arguments to the udf_iterable.
-        - If udf_iterable is set, schema check will be skipped because the user defined iterable can create new columns. User must make sure they provide correct mds_kwargs[columns]
+        - If udf_iterable is set, schema check will be skipped because the user defined iterable
+          can create new columns. User must make sure they provide correct mds_kwargs[columns]
     """

     def write_mds(iterator: Iterable):

         context = TaskContext.get()

         if context is not None:
@@ -230,9 +225,10 @@ def write_mds(iterator: Iterable):
                 records = udf_iterable(pdf, **udf_kwargs or {})
             else:
                 records = pdf.to_dict('records')
-            assert is_iterable(
-                records
-            ), f'pandas_processing_fn needs to return an iterable instead of a {type(records)}'
+            assert isinstance(
+                records,
+                Iterable), (f'pandas_processing_fn needs to return an iterable instead of a ' +
+                            f'{type(records)}')

             for sample in records:
                 try:
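
For clarity, a hypothetical udf_iterable that satisfies this contract; the function name, columns, and truncation logic are all invented for illustration:

    def truncate_text(pdf, max_len=512):
        # Must yield one dict per sample; keys must match mds_kwargs['columns'].
        for _, row in pdf.iterrows():
            yield {'id': int(row['id']), 'text': str(row['text'])[:max_len]}

    mds_path, fails = dataframeToMDS(
        df,
        mds_kwargs={'out': '/tmp/mds_out', 'columns': {'id': 'int64', 'text': 'str'}},
        udf_iterable=truncate_text,
        udf_kwargs={'max_len': 256},
    )
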
@@ -261,16 +257,15 @@ def write_mds(iterator: Iterable):
     if udf_iterable is not None:
         if 'columns' not in mds_kwargs:
             raise ValueError(
-                f'If udf_iterable is specified, user must provide correct `columns` in the mds_kwargs'
-            )
-        logger.warning(
-            "With udf_iterable defined, it's up to the user's descretion to provide mds_kwargs[columns'"
-        )
+                f'If udf_iterable is specified, user must provide correct `columns` in the ' +
+                f'mds_kwargs')
+        logger.warning("With udf_iterable defined, it's up to the user's discretion to provide " +
+                       "mds_kwargs[columns]'")
     else:
         if 'columns' not in mds_kwargs:
             logger.warning(
-                "User's discretion required: columns arg is missing from mds_kwargs. Will be auto inferred"
-            )
+                "User's discretion required: columns arg is missing from mds_kwargs. Will be " +
+                'auto-inferred')
             mds_kwargs['columns'] = infer_dataframe_schema(dataframe)
             logger.warning(f"Auto inferred schema: {mds_kwargs['columns']}")
         else:
@@ -302,11 +297,11 @@ def write_mds(iterator: Iterable):
     if 'keep_local' in mds_kwargs and mds_kwargs['keep_local'] == False:
         shutil.rmtree(cu.local, ignore_errors=True)

-    summ_fail_count = 0
+    sum_fail_count = 0
     for row in partitions:
-        summ_fail_count += row['fail_count']
+        sum_fail_count += row['fail_count']

-    if summ_fail_count > 0:
+    if sum_fail_count > 0:
         logger.warning(
-            f'Total failed records = {summ_fail_count}\nOverall records {dataframe.count()}')
-    return mds_path, summ_fail_count
+            f'Total failed records = {sum_fail_count}\nOverall records {dataframe.count()}')
+    return mds_path, sum_fail_count
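
Once conversion succeeds, the merged output can be read back with the streaming library; a minimal sketch, assuming the local path used in the sketches above:

    from streaming import StreamingDataset

    dataset = StreamingDataset(local='/tmp/mds_out')
    print(len(dataset), dataset[0])
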
10 changes: 5 additions & 5 deletions tests/base/converters/test_dataframe_to_mds.py
@@ -216,8 +216,8 @@ def test_patch_conversion_local_and_remote(self, dataframe: Any, scheme: str,
                                                manual_integration_dir: Any):
         if not MANUAL_INTEGRATION_TEST:
             pytest.skip(
-                'Overlap with integration tests. But better figure out how to run this test suite with Mock.'
-            )
+                'Overlap with integration tests. But better figure out how to run this test ' +
+                'suite with Mock.')
         mock_local, mock_remote = manual_integration_dir()
         out = (mock_local, mock_remote)
         mds_kwargs = {
@@ -295,9 +295,9 @@ def test_integration_conversion_local_and_remote(self, dataframe: Any,
             assert (os.path.exists(os.path.join(mds_path[0],
                                                 'index.json'))), 'No merged index.json found'
         else:
-            assert not (
-                os.path.exists(os.path.join(mds_path[0], 'index.json'))
-            ), f'merged index is created at {mds_path[0]} when merge_index={merge_index} and keep_local={keep_local}'
+            assert not (os.path.exists(os.path.join(mds_path[0], 'index.json'))), (
+                f'merged index is created at {mds_path[0]} when merge_index={merge_index} and ' +
+                f'keep_local={keep_local}')
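
A condensed, hypothetical variant of the assertion pattern above (the dataframe fixture and column dtypes are assumed; tmp_path is pytest's builtin temporary directory):

    def test_merge_index_writes_root_index(dataframe, tmp_path):
        mds_kwargs = {'out': str(tmp_path), 'columns': {'id': 'int64', 'dept': 'str'}}
        mds_path, _ = dataframeToMDS(dataframe, merge_index=True, mds_kwargs=mds_kwargs)
        local = mds_path[0] if isinstance(mds_path, tuple) else mds_path
        assert os.path.exists(os.path.join(local, 'index.json'))
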

     @pytest.mark.usefixtures('manual_integration_dir')
     def test_integration_conversion_remote_only(self, dataframe: Any, manual_integration_dir: Any):