From df71921b21b43f4d8c1be14243d1560c1832e4e4 Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Thu, 28 Nov 2024 02:02:34 -0800 Subject: [PATCH 1/2] Shorten union of Literals --- daft/context.py | 2 +- daft/daft/__init__.pyi | 2 +- daft/dataframe/dataframe.py | 4 ++-- daft/expressions/expressions.py | 12 ++++-------- daft/filesystem.py | 2 +- daft/runners/runner.py | 2 +- daft/series.py | 2 +- daft/table/table.py | 2 +- daft/udf_library/url_udfs.py | 4 ++-- 9 files changed, 14 insertions(+), 18 deletions(-) diff --git a/daft/context.py b/daft/context.py index fa1a61a637..7f9b8b1ae6 100644 --- a/daft/context.py +++ b/daft/context.py @@ -19,7 +19,7 @@ class _RunnerConfig: - name: ClassVar[Literal["ray"] | Literal["py"] | Literal["native"]] + name: ClassVar[Literal["ray", "py", "native"]] @dataclasses.dataclass(frozen=True) diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 6da06bb7f6..6e0e90b48f 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -870,7 +870,7 @@ def read_parquet_into_pyarrow( io_config: IOConfig | None = None, multithreaded_io: bool | None = None, coerce_int96_timestamp_unit: PyTimeUnit | None = None, - string_encoding: Literal["utf-8"] | Literal["raw"] = "utf-8", + string_encoding: Literal["utf-8", "raw"] = "utf-8", file_timeout_ms: int | None = None, ): ... def read_parquet_into_pyarrow_bulk( diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 77528da220..07b20c6d98 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -520,7 +520,7 @@ def write_parquet( self, root_dir: Union[str, pathlib.Path], compression: str = "snappy", - write_mode: Union[Literal["append"], Literal["overwrite"]] = "append", + write_mode: Literal["append", "overwrite"] = "append", partition_cols: Optional[List[ColumnInputType]] = None, io_config: Optional[IOConfig] = None, ) -> "DataFrame": @@ -592,7 +592,7 @@ def write_parquet( def write_csv( self, root_dir: Union[str, pathlib.Path], - write_mode: Union[Literal["append"], Literal["overwrite"]] = "append", + write_mode: Literal["append", "overwrite"] = "append", partition_cols: Optional[List[ColumnInputType]] = None, io_config: Optional[IOConfig] = None, ) -> "DataFrame": diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 80506f4ca5..4f186aaa83 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -781,9 +781,7 @@ def shift_right(self, other: Expression) -> Expression: expr = Expression._to_expression(other) return Expression._from_pyexpr(self._expr >> expr._expr) - def count( - self, mode: Literal["all"] | Literal["valid"] | Literal["null"] | CountMode = CountMode.Valid - ) -> Expression: + def count(self, mode: Literal["all", "valid", "null"] | CountMode = CountMode.Valid) -> Expression: """Counts the number of values in the expression. Args: @@ -1302,7 +1300,7 @@ def _override_io_config_max_connections(max_connections: int, io_config: IOConfi def download( self, max_connections: int = 32, - on_error: Literal["raise"] | Literal["null"] = "raise", + on_error: Literal["raise", "null"] = "raise", io_config: IOConfig | None = None, use_native_downloader: bool = True, ) -> Expression: @@ -3013,9 +3011,7 @@ def value_counts(self) -> Expression: """ return Expression._from_pyexpr(native.list_value_counts(self._expr)) - def count( - self, mode: Literal["all"] | Literal["valid"] | Literal["null"] | CountMode = CountMode.Valid - ) -> Expression: + def count(self, mode: Literal["all", "valid", "null"] | CountMode = CountMode.Valid) -> Expression: """Counts the number of elements in each list Args: @@ -3322,7 +3318,7 @@ class ExpressionImageNamespace(ExpressionNamespace): def decode( self, - on_error: Literal["raise"] | Literal["null"] = "raise", + on_error: Literal["raise", "null"] = "raise", mode: str | ImageMode | None = None, ) -> Expression: """ diff --git a/daft/filesystem.py b/daft/filesystem.py index 564c56ee9b..22438d6365 100644 --- a/daft/filesystem.py +++ b/daft/filesystem.py @@ -44,7 +44,7 @@ def _put_fs_in_cache(protocol: str, fs: pafs.FileSystem, io_config: IOConfig | N class ListingInfo: path: str size: int - type: Literal["file"] | Literal["directory"] + type: Literal["file", "directory"] rows: int | None = None diff --git a/daft/runners/runner.py b/daft/runners/runner.py index 6ffc7f136b..34a81fda6a 100644 --- a/daft/runners/runner.py +++ b/daft/runners/runner.py @@ -20,7 +20,7 @@ class Runner(Generic[PartitionT]): - name: ClassVar[Literal["ray"] | Literal["py"] | Literal["native"]] + name: ClassVar[Literal["ray", "py", "native"]] def __init__(self) -> None: self._part_set_cache = self.initialize_partition_set_cache() diff --git a/daft/series.py b/daft/series.py index 95788d5f7f..0e63c6d6c7 100644 --- a/daft/series.py +++ b/daft/series.py @@ -984,7 +984,7 @@ def get(self, key: Series) -> Series: class SeriesImageNamespace(SeriesNamespace): def decode( self, - on_error: Literal["raise"] | Literal["null"] = "raise", + on_error: Literal["raise", "null"] = "raise", mode: str | ImageMode | None = None, ) -> Series: raise_on_error = False diff --git a/daft/table/table.py b/daft/table/table.py index 66757cce2a..ba0868fdaf 100644 --- a/daft/table/table.py +++ b/daft/table/table.py @@ -562,7 +562,7 @@ def read_parquet_into_pyarrow( io_config: IOConfig | None = None, multithreaded_io: bool | None = None, coerce_int96_timestamp_unit: TimeUnit = TimeUnit.ns(), - string_encoding: Literal["utf-8"] | Literal["raw"] = "utf-8", + string_encoding: Literal["utf-8", "raw"] = "utf-8", file_timeout_ms: int | None = 900_000, # 15 minutes ) -> pa.Table: fields, metadata, columns, num_rows_read = _read_parquet_into_pyarrow( diff --git a/daft/udf_library/url_udfs.py b/daft/udf_library/url_udfs.py index 2a87cdbc59..aa97dc670f 100644 --- a/daft/udf_library/url_udfs.py +++ b/daft/udf_library/url_udfs.py @@ -21,7 +21,7 @@ def _worker_thread_initializer() -> None: thread_local.filesystems_cache = {} -def _download(path: str | None, on_error: Literal["raise"] | Literal["null"]) -> bytes | None: +def _download(path: str | None, on_error: Literal["raise", "null"]) -> bytes | None: if path is None: return None protocol = filesystem.get_protocol_from_path(path) @@ -64,7 +64,7 @@ def _warmup_fsspec_registry(urls_pylist: list[str | None]) -> None: def download_udf( urls, max_worker_threads: int = 8, - on_error: Literal["raise"] | Literal["null"] = "raise", + on_error: Literal["raise", "null"] = "raise", ): """Downloads the contents of the supplied URLs. From 0ce9a5dff03137c38b18032ee490c9e744112a66 Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Thu, 28 Nov 2024 02:10:58 -0800 Subject: [PATCH 2/2] Add Ruff rule --- daft/.ruff.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/daft/.ruff.toml b/daft/.ruff.toml index 1acbd4af4b..ff97e9122b 100644 --- a/daft/.ruff.toml +++ b/daft/.ruff.toml @@ -2,6 +2,7 @@ extend = "../.ruff.toml" [lint] extend-select = [ + "PYI030", # unnecessary-literal-union, derived from flake8-pyi "TID253", # banned-module-level-imports, derived from flake8-tidy-imports "TCH" # flake8-type-checking ]