diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py index 5bae271947fd..76db913aeb02 100644 --- a/py-polars/polars/io/database.py +++ b/py-polars/polars/io/database.py @@ -280,23 +280,29 @@ def _from_rows( if hasattr(self.result, "fetchall"): if self.driver_name == "sqlalchemy": - column_names = ( - list(self.result._metadata.keys) - if hasattr(self.result, "_metadata") - else [desc[0] for desc in self.result.cursor.description] - ) + if hasattr(self.result, "cursor"): + cursor_desc = {d[0]: d[1] for d in self.result.cursor.description} + elif hasattr(self.result, "_metadata"): + cursor_desc = {k: None for k in self.result._metadata.keys} + else: + msg = f"Unable to determine metadata from query result; {self.result!r}" + raise ValueError(msg) else: - column_names = [desc[0] for desc in self.result.description] + cursor_desc = {d[0]: d[1] for d in self.result.description} + + # TODO: refine types based on the cursor description's type_code, + # if/where available? (for now, we just read the column names) + result_columns = list(cursor_desc) frames = ( DataFrame( data=rows, - schema=column_names, + schema=result_columns, schema_overrides=schema_overrides, orient="row", ) for rows in ( - self._fetchmany_rows(self.result, batch_size) + list(self._fetchmany_rows(self.result, batch_size)) if iter_batches else [self._fetchall_rows(self.result)] # type: ignore[list-item] ) @@ -458,9 +464,10 @@ def read_database( # noqa: D417 be a suitable "Selectable", otherwise it is expected to be a string). connection An instantiated connection (or cursor/client object) that the query can be - executed against. Can also pass a valid ODBC connection string, starting with - "Driver=", in which case the `arrow-odbc` package will be used to establish - the connection and return Arrow-native data to Polars. + executed against. Can also pass a valid ODBC connection string, identified as + such if it contains the string "Driver=", in which case the `arrow-odbc` + package will be used to establish the connection and return Arrow-native data + to Polars. iter_batches Return an iterator of DataFrames, where each DataFrame represents a batch of data returned by the query; this can be useful for processing large resultsets @@ -560,7 +567,7 @@ def read_database( # noqa: D417 """ # noqa: W505 if isinstance(connection, str): # check for odbc connection string - if re.sub(r"\s", "", connection[:20]).lower().startswith("driver="): + if re.search(r"\bdriver\s*=\s*{[^}]+?}", connection, re.IGNORECASE): try: import arrow_odbc # noqa: F401 except ModuleNotFoundError: