Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(python): Fix read_database(…,iter_batches=True) type annotations #19832

Merged
merged 2 commits into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 6 additions & 15 deletions py-polars/polars/io/database/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,17 @@

from polars.io.database._arrow_registry import ArrowDriverProperties

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.expression import Selectable

from polars import DataFrame
from polars._typing import ConnectionOrCursor, Cursor, SchemaDict

try:
from sqlalchemy.sql.expression import Selectable
except ImportError:
Selectable: TypeAlias = Any # type: ignore[no-redef]

from sqlalchemy.sql.elements import TextClause

_INVALID_QUERY_TYPES = {
"ALTER",
"ANALYZE",
Expand Down Expand Up @@ -207,7 +198,7 @@ def _from_arrow(
iter_batches: bool,
schema_overrides: SchemaDict | None,
infer_schema_length: int | None,
) -> DataFrame | Iterable[DataFrame] | None:
) -> DataFrame | Iterator[DataFrame] | None:
"""Return resultset data in Arrow format for frame init."""
from polars import DataFrame

Expand Down Expand Up @@ -253,7 +244,7 @@ def _from_rows(
iter_batches: bool,
schema_overrides: SchemaDict | None,
infer_schema_length: int | None,
) -> DataFrame | Iterable[DataFrame] | None:
) -> DataFrame | Iterator[DataFrame] | None:
"""Return resultset data row-wise for frame init."""
from polars import DataFrame

Expand Down Expand Up @@ -529,7 +520,7 @@ def to_polars(
batch_size: int | None = None,
schema_overrides: SchemaDict | None = None,
infer_schema_length: int | None = N_INFER_DEFAULT,
) -> DataFrame | Iterable[DataFrame]:
) -> DataFrame | Iterator[DataFrame]:
"""
Convert the result set to a DataFrame.

Expand Down
11 changes: 0 additions & 11 deletions py-polars/polars/io/database/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,11 @@
from polars.dependencies import import_optional

if TYPE_CHECKING:
import sys
from collections.abc import Coroutine

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

from polars import DataFrame
from polars._typing import SchemaDict

try:
from sqlalchemy.sql.expression import Selectable
except ImportError:
Selectable: TypeAlias = Any # type: ignore[no-redef]


def _run_async(co: Coroutine[Any, Any, Any]) -> Any:
"""Run asynchronous code as if it was synchronous."""
Expand Down
22 changes: 6 additions & 16 deletions py-polars/polars/io/database/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,14 @@
from polars.io.database._executor import ConnectionExecutor

if TYPE_CHECKING:
import sys
from collections.abc import Iterable
from collections.abc import Iterator

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.expression import Selectable

from polars import DataFrame
from polars._typing import ConnectionOrCursor, DbReadEngine, SchemaDict

try:
from sqlalchemy.sql.expression import Selectable
except ImportError:
Selectable: TypeAlias = Any # type: ignore[no-redef]

from sqlalchemy.sql.elements import TextClause


@overload
def read_database(
Expand All @@ -51,7 +41,7 @@ def read_database(
schema_overrides: SchemaDict | None = ...,
infer_schema_length: int | None = ...,
execute_options: dict[str, Any] | None = ...,
) -> Iterable[DataFrame]: ...
) -> Iterator[DataFrame]: ...


@overload
Expand All @@ -64,7 +54,7 @@ def read_database(
schema_overrides: SchemaDict | None = ...,
infer_schema_length: int | None = ...,
execute_options: dict[str, Any] | None = ...,
) -> DataFrame | Iterable[DataFrame]: ...
) -> DataFrame | Iterator[DataFrame]: ...


def read_database(
Expand All @@ -76,7 +66,7 @@ def read_database(
schema_overrides: SchemaDict | None = None,
infer_schema_length: int | None = N_INFER_DEFAULT,
execute_options: dict[str, Any] | None = None,
) -> DataFrame | Iterable[DataFrame]:
) -> DataFrame | Iterator[DataFrame]:
"""
Read the results of a SQL query into a DataFrame, given a connection object.

Expand Down