Integration test for dataframe_to_mds and merge_index #465

Closed
Changes from all commits

Commits (62)
859f908
First commit
XiaohanZhangCMU Sep 26, 2023
4575174
Add a naive mds datasts
XiaohanZhangCMU Sep 26, 2023
6b31640
fix lints
XiaohanZhangCMU Sep 26, 2023
b4a0ff7
Fix
XiaohanZhangCMU Sep 26, 2023
afe0399
Merge branch 'merge_index_util' into use_merge_index_in_dataframeToMDS
XiaohanZhangCMU Sep 26, 2023
2f37037
Change dataframeToMDS API to use merge_util helper
XiaohanZhangCMU Sep 26, 2023
ed9e8d0
Fix unit tests
XiaohanZhangCMU Sep 26, 2023
af9b6dd
Fix tests
XiaohanZhangCMU Sep 26, 2023
c45ceb9
Fix lints
XiaohanZhangCMU Sep 26, 2023
8b7db39
Address a few comments
XiaohanZhangCMU Sep 28, 2023
091518c
update
XiaohanZhangCMU Sep 29, 2023
b0219c4
updates
XiaohanZhangCMU Sep 29, 2023
8109eaf
Merge retry PR. fix conflicts
XiaohanZhangCMU Sep 29, 2023
1678844
Address comments
XiaohanZhangCMU Sep 29, 2023
400050e
update unit tests
XiaohanZhangCMU Sep 29, 2023
72430be
Update tests
XiaohanZhangCMU Sep 29, 2023
69857d4
unit tests + pre-commit ok
XiaohanZhangCMU Sep 29, 2023
bf2d4ef
Add list objects for oci, gs, s3
XiaohanZhangCMU Sep 29, 2023
cf0fe95
fix tests
XiaohanZhangCMU Sep 29, 2023
5e0a1b3
Fix lints
XiaohanZhangCMU Sep 29, 2023
cbcb352
list_objects returns only basename
XiaohanZhangCMU Sep 30, 2023
81c3b88
Fix lints
XiaohanZhangCMU Sep 30, 2023
a1185d3
fix bugs in list_objects
XiaohanZhangCMU Sep 30, 2023
1972302
updates
XiaohanZhangCMU Oct 2, 2023
38f495f
Fix lints
XiaohanZhangCMU Oct 2, 2023
a5a4ffc
use new list_objects
XiaohanZhangCMU Oct 3, 2023
00e6a8c
Fix lints
XiaohanZhangCMU Oct 3, 2023
22429cf
remove
XiaohanZhangCMU Oct 3, 2023
5301563
Add merge_index
XiaohanZhangCMU Oct 4, 2023
36dff13
remove materialized test dataset
XiaohanZhangCMU Oct 4, 2023
691a588
Change do_merge_index to merge_index_from_list
XiaohanZhangCMU Oct 4, 2023
05a8f32
Fix lints
XiaohanZhangCMU Oct 4, 2023
ed48f86
Change merge_index to auto_merge_index to avoid duplicate naming
XiaohanZhangCMU Oct 5, 2023
34799a8
update pytest yaml
XiaohanZhangCMU Oct 5, 2023
85a4a6d
update
XiaohanZhangCMU Oct 5, 2023
8a4d43b
update
XiaohanZhangCMU Oct 5, 2023
5f1a63b
Fix lints
XiaohanZhangCMU Oct 5, 2023
320fa8d
Make merge_index a wrapper
XiaohanZhangCMU Oct 6, 2023
22a9cc4
add print
XiaohanZhangCMU Oct 6, 2023
f8afb66
Change fail msg for missing local file and invalid remote url
XiaohanZhangCMU Oct 6, 2023
e9d82a1
update msg
XiaohanZhangCMU Oct 6, 2023
e0e0343
remove print
XiaohanZhangCMU Oct 6, 2023
6b0e6d8
Fix lints
XiaohanZhangCMU Oct 6, 2023
a95e34b
Add warning msg for exist_ok=True
XiaohanZhangCMU Oct 7, 2023
2de66e2
Address comments
XiaohanZhangCMU Oct 7, 2023
8e6df9c
fix lints
XiaohanZhangCMU Oct 7, 2023
36b4369
Turn off manual integratin
XiaohanZhangCMU Oct 7, 2023
828e74d
Address comments
XiaohanZhangCMU Oct 9, 2023
8e616d8
Update
XiaohanZhangCMU Oct 9, 2023
282973a
updates
XiaohanZhangCMU Oct 10, 2023
ebacc87
Fix lints
XiaohanZhangCMU Oct 10, 2023
9228f34
Merge branch 'main' into merge_index_util
XiaohanZhangCMU Oct 10, 2023
90dccce
remove integration tests
XiaohanZhangCMU Oct 10, 2023
9f2dd04
Fix lints
XiaohanZhangCMU Oct 10, 2023
013a97b
Add specific exceptions to oci list_objects
XiaohanZhangCMU Oct 10, 2023
2c214c8
Fix comments
XiaohanZhangCMU Oct 10, 2023
224cba6
Add deprecated warning for dataframeToMDS
XiaohanZhangCMU Oct 10, 2023
d806cbf
Fix remote url for /Volume
XiaohanZhangCMU Oct 10, 2023
b200267
Add integration tests to dataframe_to_mds and merge_index utility
XiaohanZhangCMU Oct 10, 2023
bde4641
Test databricks-sdk 0.10.0
XiaohanZhangCMU Oct 10, 2023
24f567d
Fix lints
XiaohanZhangCMU Oct 10, 2023
946ff7c
change back to 0.8.0
XiaohanZhangCMU Oct 10, 2023
15 changes: 8 additions & 7 deletions .github/workflows/pytest.yaml
@@ -47,10 +47,11 @@ jobs:
       id: tests
       run: |
         set -ex
-        pytest --splits 7 --group 1 --cov-fail-under=10
-        pytest --splits 7 --group 2 --cov-fail-under=10
-        pytest --splits 7 --group 3 --cov-fail-under=10
-        pytest --splits 7 --group 4 --cov-fail-under=10
-        pytest --splits 7 --group 5 --cov-fail-under=10
-        pytest --splits 7 --group 6 --cov-fail-under=10
-        pytest --splits 7 --group 7 --cov-fail-under=10
+        pytest --splits 8 --group 1 --cov-fail-under=10
+        pytest --splits 8 --group 2 --cov-fail-under=10
+        pytest --splits 8 --group 3 --cov-fail-under=10
+        pytest --splits 8 --group 4 --cov-fail-under=10
+        pytest --splits 8 --group 5 --cov-fail-under=10
+        pytest --splits 8 --group 6 --cov-fail-under=10
+        pytest --splits 8 --group 7 --cov-fail-under=10
+        pytest --splits 8 --group 8 --cov-fail-under=10
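
For context, the --splits and --group flags above come from the pytest-split plugin, which partitions the test suite into N groups of roughly equal duration and runs only the requested group per invocation. As an illustrative sketch (not part of this PR), the eight invocations could equally be written as a loop:

for i in $(seq 1 8); do
    # Each invocation runs one of the 8 roughly equal-duration test groups.
    pytest --splits 8 --group "$i" --cov-fail-under=10
done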
98 changes: 40 additions & 58 deletions streaming/base/converters/dataframe_to_mds.py
@@ -3,16 +3,16 @@

 """A utility to convert spark dataframe to MDS."""

-import json
 import logging
 import os
 import shutil
 from collections.abc import Iterable
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple

 import pandas as pd

 from streaming.base.util import get_import_exception_message
+from streaming.base.util import merge_index as do_merge_index

 try:
     from pyspark import TaskContext
@@ -119,52 +119,26 @@ def map_spark_dtype(spark_data_type: Any) -> str:
     return schema_dict


-def do_merge_index(partitions: Iterable, mds_path: Union[str, Tuple[str, str]]) -> None:
-    """Merge index.json from partitions into one for streaming.
-
-    Args:
-        partitions (Iterable): partitions that contain pd.DataFrame
-        mds_path (Union[str, Tuple[str, str]]): (str,str)=(local,remote), str = local or remote
-            based on parse_uri(url) result
-    """
-    if not partitions:
-        logger.warning('No partitions exist, no index merged')
-        return
-
-    shards = []
-
-    for row in partitions:
-        mds_partition_index = f'{row.mds_path}/{get_index_basename()}'
-        mds_partition_basename = os.path.basename(row.mds_path)
-        obj = json.load(open(mds_partition_index))
-        for i in range(len(obj['shards'])):
-            shard = obj['shards'][i]
-            for key in ('raw_data', 'zip_data'):
-                if shard.get(key):
-                    basename = shard[key]['basename']
-                    obj['shards'][i][key]['basename'] = os.path.join(mds_partition_basename,
-                                                                     basename)
-        shards += obj['shards']
-
-    obj = {
-        'version': 2,
-        'shards': shards,
-    }
-
-    if isinstance(mds_path, str):
-        mds_index = os.path.join(mds_path, get_index_basename())
-    else:
-        mds_index = os.path.join(mds_path[0], get_index_basename())
-
-    with open(mds_index, 'w') as out:
-        json.dump(obj, out)
-
-
 def dataframeToMDS(dataframe: DataFrame,
                    merge_index: bool = True,
                    mds_kwargs: Optional[Dict[str, Any]] = None,
                    udf_iterable: Optional[Callable] = None,
                    udf_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[Any, int]:
+    """Deprecated API Signature.
+
+    To be replaced by dataframe_to_mds
+    """
+    logger.warning(
+        'This signature is deprecated. Use dataframe_to_mds with the same arguments going forward.'
+    )
+    return dataframe_to_mds(dataframe, merge_index, mds_kwargs, udf_iterable, udf_kwargs)
+
+
+def dataframe_to_mds(dataframe: DataFrame,
+                     merge_index: bool = True,
+                     mds_kwargs: Optional[Dict[str, Any]] = None,
+                     udf_iterable: Optional[Callable] = None,
+                     udf_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[Any, int]:
     """Execute a spark dataframe to MDS conversion process.

     This method orchestrates the conversion of a spark dataframe into MDS format by processing the
@@ -194,19 +168,20 @@ def dataframeToMDS(dataframe: DataFrame,
     """

     def write_mds(iterator: Iterable):
+        """Worker node writes iterable to MDS datasets locally."""
         context = TaskContext.get()

         if context is not None:
             id = context.taskAttemptId()
         else:
             raise RuntimeError('TaskContext.get() returns None')

-        if isinstance(mds_path, str):  # local
-            output = os.path.join(mds_path, f'{id}')
-            out_file_path = output
+        if mds_path[1] == '':  # only local
+            output = os.path.join(mds_path[0], f'{id}')
+            partition_path = (output, '')
         else:
             output = (os.path.join(mds_path[0], f'{id}'), os.path.join(mds_path[1], f'{id}'))
-            out_file_path = output[0]
+            partition_path = output

         if mds_kwargs:
             kwargs = mds_kwargs.copy()
@@ -215,7 +190,7 @@ def write_mds(iterator: Iterable):
             kwargs = {}

         if merge_index:
-            kwargs['keep_local'] = True  # need to keep local to do merge
+            kwargs['keep_local'] = True  # need to keep workers' locals to do merge

         count = 0

@@ -237,10 +212,17 @@ def write_mds(iterator: Iterable):
                 raise RuntimeError(f'failed to write sample: {sample}') from ex
             count += 1

-        yield pd.concat(
-            [pd.Series([out_file_path], name='mds_path'),
-             pd.Series([count], name='fail_count')],
-            axis=1)
+        yield pd.concat([
+            pd.Series([os.path.join(partition_path[0], get_index_basename())],
+                      name='mds_path_local'),
+            pd.Series([
+                os.path.join(partition_path[1], get_index_basename())
+                if partition_path[1] != '' else ''
+            ],
+                      name='mds_path_remote'),
+            pd.Series([count], name='fail_count')
+        ],
+                  axis=1)

     if dataframe is None or dataframe.isEmpty():
         raise ValueError(f'Input dataframe is None or Empty!')
@@ -275,25 +257,25 @@ def write_mds(iterator: Iterable):
     keep_local = False if 'keep_local' not in mds_kwargs else mds_kwargs['keep_local']
     cu = CloudUploader.get(out, keep_local=keep_local)

-    # Fix output format as mds_path: Tuple => remote Str => local only
+    # Fix output format as mds_path: Tuple(local, remote)
     if cu.remote is None:
-        mds_path = cu.local
+        mds_path = (cu.local, '')
     else:
         mds_path = (cu.local, cu.remote)

     # Prepare partition schema
     result_schema = StructType([
-        StructField('mds_path', StringType(), False),
+        StructField('mds_path_local', StringType(), False),
+        StructField('mds_path_remote', StringType(), False),
         StructField('fail_count', IntegerType(), False)
     ])
     partitions = dataframe.mapInPandas(func=write_mds, schema=result_schema).collect()

     if merge_index:
-        do_merge_index(partitions, mds_path)
+        index_files = [(row['mds_path_local'], row['mds_path_remote']) for row in partitions]
+        do_merge_index(index_files, out, keep_local=keep_local, download_timeout=60)

     if cu.remote is not None:
-        if merge_index:
-            cu.upload_file(get_index_basename())
         if 'keep_local' in mds_kwargs and mds_kwargs['keep_local'] == False:
             shutil.rmtree(cu.local, ignore_errors=True)
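For a concrete picture of the new flow, here is a minimal usage sketch. It assumes dataframe_to_mds is importable from streaming.base.converters (as the file path suggests); the output path and column spec are hypothetical, and mds_kwargs is forwarded to the underlying MDSWriter on each worker:

import pandas as pd
from pyspark.sql import SparkSession

from streaming.base.converters import dataframe_to_mds

spark = SparkSession.builder.getOrCreate()

# Hypothetical toy dataframe; any Spark dataframe whose columns map to MDS
# encodings works.
df = spark.createDataFrame(
    pd.DataFrame({
        'id': [str(i) for i in range(100)],
        'text': ['hello world'] * 100,
    }))

# 'out' may be a local path, a remote URL, or a (local, remote) tuple, per
# CloudUploader.get; merge_index=True merges the per-partition index.json
# files into one after all partitions are written.
mds_kwargs = {'out': '/tmp/mds_out', 'columns': {'id': 'str', 'text': 'str'}}
dataframe_to_mds(df, merge_index=True, mds_kwargs=mds_kwargs)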
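The merge_index utility imported at the top of the diff can also be invoked directly. A sketch based only on the call site above, with hypothetical paths: each tuple pairs a partition's local index.json with its remote counterpart ('' when the dataset is local-only):

from streaming.base.util import merge_index

index_files = [
    ('/tmp/mds_out/0/index.json', ''),
    ('/tmp/mds_out/1/index.json', ''),
]
# Mirrors do_merge_index(index_files, out, keep_local=keep_local,
# download_timeout=60) in the diff above.
merge_index(index_files, '/tmp/mds_out', keep_local=True, download_timeout=60)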