Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DOCS] Shorten union of Literals #3449

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions daft/.ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ extend = "../.ruff.toml"

[lint]
extend-select = [
"PYI030", # unnecessary-literal-union, derived from flake8-pyi
"TID253", # banned-module-level-imports, derived from flake8-tidy-imports
"TCH" # flake8-type-checking
]
Expand Down
2 changes: 1 addition & 1 deletion daft/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


class _RunnerConfig:
name: ClassVar[Literal["ray"] | Literal["py"] | Literal["native"]]
name: ClassVar[Literal["ray", "py", "native"]]


@dataclasses.dataclass(frozen=True)
Expand Down
2 changes: 1 addition & 1 deletion daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,7 @@ def read_parquet_into_pyarrow(
io_config: IOConfig | None = None,
multithreaded_io: bool | None = None,
coerce_int96_timestamp_unit: PyTimeUnit | None = None,
string_encoding: Literal["utf-8"] | Literal["raw"] = "utf-8",
string_encoding: Literal["utf-8", "raw"] = "utf-8",
file_timeout_ms: int | None = None,
): ...
def read_parquet_into_pyarrow_bulk(
Expand Down
4 changes: 2 additions & 2 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def write_parquet(
self,
root_dir: Union[str, pathlib.Path],
compression: str = "snappy",
write_mode: Union[Literal["append"], Literal["overwrite"]] = "append",
write_mode: Literal["append", "overwrite"] = "append",
partition_cols: Optional[List[ColumnInputType]] = None,
io_config: Optional[IOConfig] = None,
) -> "DataFrame":
Expand Down Expand Up @@ -592,7 +592,7 @@ def write_parquet(
def write_csv(
self,
root_dir: Union[str, pathlib.Path],
write_mode: Union[Literal["append"], Literal["overwrite"]] = "append",
write_mode: Literal["append", "overwrite"] = "append",
partition_cols: Optional[List[ColumnInputType]] = None,
io_config: Optional[IOConfig] = None,
) -> "DataFrame":
Expand Down
12 changes: 4 additions & 8 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,9 +781,7 @@ def shift_right(self, other: Expression) -> Expression:
expr = Expression._to_expression(other)
return Expression._from_pyexpr(self._expr >> expr._expr)

def count(
self, mode: Literal["all"] | Literal["valid"] | Literal["null"] | CountMode = CountMode.Valid
) -> Expression:
def count(self, mode: Literal["all", "valid", "null"] | CountMode = CountMode.Valid) -> Expression:
"""Counts the number of values in the expression.

Args:
Expand Down Expand Up @@ -1302,7 +1300,7 @@ def _override_io_config_max_connections(max_connections: int, io_config: IOConfi
def download(
self,
max_connections: int = 32,
on_error: Literal["raise"] | Literal["null"] = "raise",
on_error: Literal["raise", "null"] = "raise",
io_config: IOConfig | None = None,
use_native_downloader: bool = True,
) -> Expression:
Expand Down Expand Up @@ -3013,9 +3011,7 @@ def value_counts(self) -> Expression:
"""
return Expression._from_pyexpr(native.list_value_counts(self._expr))

def count(
self, mode: Literal["all"] | Literal["valid"] | Literal["null"] | CountMode = CountMode.Valid
) -> Expression:
def count(self, mode: Literal["all", "valid", "null"] | CountMode = CountMode.Valid) -> Expression:
"""Counts the number of elements in each list

Args:
Expand Down Expand Up @@ -3322,7 +3318,7 @@ class ExpressionImageNamespace(ExpressionNamespace):

def decode(
self,
on_error: Literal["raise"] | Literal["null"] = "raise",
on_error: Literal["raise", "null"] = "raise",
mode: str | ImageMode | None = None,
) -> Expression:
"""
Expand Down
2 changes: 1 addition & 1 deletion daft/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _put_fs_in_cache(protocol: str, fs: pafs.FileSystem, io_config: IOConfig | N
class ListingInfo:
path: str
size: int
type: Literal["file"] | Literal["directory"]
type: Literal["file", "directory"]
rows: int | None = None


Expand Down
2 changes: 1 addition & 1 deletion daft/runners/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


class Runner(Generic[PartitionT]):
name: ClassVar[Literal["ray"] | Literal["py"] | Literal["native"]]
name: ClassVar[Literal["ray", "py", "native"]]

def __init__(self) -> None:
self._part_set_cache = self.initialize_partition_set_cache()
Expand Down
2 changes: 1 addition & 1 deletion daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,7 @@ def get(self, key: Series) -> Series:
class SeriesImageNamespace(SeriesNamespace):
def decode(
self,
on_error: Literal["raise"] | Literal["null"] = "raise",
on_error: Literal["raise", "null"] = "raise",
mode: str | ImageMode | None = None,
) -> Series:
raise_on_error = False
Expand Down
2 changes: 1 addition & 1 deletion daft/table/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def read_parquet_into_pyarrow(
io_config: IOConfig | None = None,
multithreaded_io: bool | None = None,
coerce_int96_timestamp_unit: TimeUnit = TimeUnit.ns(),
string_encoding: Literal["utf-8"] | Literal["raw"] = "utf-8",
string_encoding: Literal["utf-8", "raw"] = "utf-8",
file_timeout_ms: int | None = 900_000, # 15 minutes
) -> pa.Table:
fields, metadata, columns, num_rows_read = _read_parquet_into_pyarrow(
Expand Down
4 changes: 2 additions & 2 deletions daft/udf_library/url_udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _worker_thread_initializer() -> None:
thread_local.filesystems_cache = {}


def _download(path: str | None, on_error: Literal["raise"] | Literal["null"]) -> bytes | None:
def _download(path: str | None, on_error: Literal["raise", "null"]) -> bytes | None:
if path is None:
return None
protocol = filesystem.get_protocol_from_path(path)
Expand Down Expand Up @@ -64,7 +64,7 @@ def _warmup_fsspec_registry(urls_pylist: list[str | None]) -> None:
def download_udf(
urls,
max_worker_threads: int = 8,
on_error: Literal["raise"] | Literal["null"] = "raise",
on_error: Literal["raise", "null"] = "raise",
):
"""Downloads the contents of the supplied URLs.

Expand Down