Skip to content

Commit

Permalink
Fix lints
Browse files Browse the repository at this point in the history
  • Loading branch information
XiaohanZhangCMU committed Oct 10, 2023
1 parent bde4641 commit 24f567d
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 25 deletions.
9 changes: 8 additions & 1 deletion streaming/base/converters/dataframe_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,16 @@ def dataframeToMDS(dataframe: DataFrame,
mds_kwargs: Optional[Dict[str, Any]] = None,
udf_iterable: Optional[Callable] = None,
udf_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[Any, int]:
logger.warning("This signature is deprecated. Use dataframe_to_mds with the same arguments going forward.")
"""Deprecated API Signature.
To be replaced by dataframe_to_mds
"""
logger.warning(
'This signature is deprecated. Use dataframe_to_mds with the same arguments going forward.'
)
return dataframe_to_mds(dataframe, merge_index, mds_kwargs, udf_iterable, udf_kwargs)


def dataframe_to_mds(dataframe: DataFrame,
merge_index: bool = True,
mds_kwargs: Optional[Dict[str, Any]] = None,
Expand Down
9 changes: 6 additions & 3 deletions streaming/base/storage/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,12 +593,15 @@ def list_objects(self, prefix: Optional[str] = None) -> Optional[List[str]]:
response_complete = True
return object_names
except Exception as e:
if isinstance(e, oci.exceptions.ServiceError):
if isinstance(e, oci.exceptions.ServiceError): # type: ignore
if e.status == 404: # type: ignore
if e.code == 'ObjectNotFound': # type: ignore
raise FileNotFoundError(f'Object {bucket_name}/{prefix} not found. {e.message}') from e # type: ignore
raise FileNotFoundError(
f'Object {bucket_name}/{prefix} not found. {e.message}' # type: ignore
) from e # type: ignore
if e.code == 'BucketNotFound': # type: ignore
raise ValueError(f'Bucket {bucket_name} not found. {e.message}') from e # type: ignore
raise ValueError(
f'Bucket {bucket_name} not found. {e.message}') from e # type: ignore
raise e
raise e
return []
Expand Down
2 changes: 1 addition & 1 deletion streaming/base/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str:


def merge_index(*args: Any, **kwargs: Any):
"""Redirect to one of two merge_index functions based on arguments"""
"""Redirect to one of two merge_index functions based on arguments."""
if isinstance(args[0], list) and len(args) + len(kwargs) in [2, 3, 4]:
return _merge_index_from_list(*args, **kwargs)
elif (isinstance(args[0], str) or
Expand Down
2 changes: 1 addition & 1 deletion tests/base/converters/test_integratoin_dataframe_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]:
)
print(f'Deleted {len(response.data.objects)} objects with prefix: {MY_PREFIX}')

except :
except:
print('tear down oci test folder failed, continue...')


Expand Down
32 changes: 13 additions & 19 deletions tests/test_integration_util.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,17 @@
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0

import json
import os
import shutil
import tempfile
import time
import urllib.parse
from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Tuple

import pytest

from streaming.base.constant import RESUME
from streaming.base.shared.prefix import _get_path
from streaming.base.storage.download import download_file
from streaming.base.storage.upload import CloudUploader
from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, get_list_arg,
merge_index, number_abbrev_to_int, retry)

from streaming.base.util import merge_index
from tests.test_util import integrity_check

MY_PREFIX = 'train_' + str(time.time())
Expand Down Expand Up @@ -75,7 +68,7 @@ def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]:
if objects_to_delete:
s3.delete_objects(Bucket=MY_BUCKET['s3://'],
Delete={'Objects': objects_to_delete})
except :
except:
print('tear down s3 test folder failed, continue....')

try:
Expand All @@ -97,11 +90,10 @@ def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]:
)
print(f'Deleted {len(response.data.objects)} objects with prefix: {MY_PREFIX}')

except :
except:
print('tear down oci test folder failed, continue...')



@pytest.mark.parametrize('scheme', ['oci://', 'gs://', 's3://', 'dbfs:/Volumes'])
@pytest.mark.parametrize('index_file_urls_pattern', [4, 5])
@pytest.mark.parametrize('out_format', ['remote', 'local', 'tuple'])
Expand Down Expand Up @@ -329,14 +321,16 @@ def test_merge_index_from_root_remote(manual_integration_dir: Any, out_format: s


@pytest.mark.parametrize('scheme', ['dbfs:/Volumes'])
@pytest.mark.parametrize('out_format', ['remote']) # , 'tuple'])
@pytest.mark.parametrize('n_partitions', [3]) # , 2, 3, 4])
@pytest.mark.parametrize('keep_local', [False]) # , True])
def test_uc_volume(manual_integration_dir: Any, out_format: str,
n_partitions: int, keep_local: bool, scheme: str):
@pytest.mark.parametrize('out_format', ['remote']) # , 'tuple'])
@pytest.mark.parametrize('n_partitions', [3]) # , 2, 3, 4])
@pytest.mark.parametrize('keep_local', [False]) # , True])
def test_uc_volume(manual_integration_dir: Any, out_format: str, n_partitions: int,
keep_local: bool, scheme: str):
from decimal import Decimal

from pyspark.sql import SparkSession
from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType

from streaming.base.converters import dataframeToMDS

if out_format == 'remote':
Expand All @@ -362,6 +356,6 @@ def test_uc_volume(manual_integration_dir: Any, out_format: str,

mds_path, _ = dataframeToMDS(df, merge_index=True, mds_kwargs=mds_kwargs)

with pytest.raises(NotImplementedError, match=f'DatabricksUnityCatalogUploader.list_objects is not implemented.*'):
with pytest.raises(NotImplementedError,
match=f'DatabricksUnityCatalogUploader.list_objects is not implemented.*'):
merge_index(mds_path, keep_local=keep_local)

0 comments on commit 24f567d

Please sign in to comment.