diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py index c742e6311864..ebc860851c5e 100644 --- a/ibis/backends/datafusion/__init__.py +++ b/ibis/backends/datafusion/__init__.py @@ -125,7 +125,7 @@ def _get_schema_using_query(self, query: str) -> sch.Schema: src = sge.Create( this=table, kind="VIEW", - expression=sg.parse_one(query, read="datafusion"), + expression=sg.parse_one(query, read=self.dialect), properties=sge.Properties(expressions=[sge.TemporaryProperty()]), ) @@ -537,13 +537,13 @@ def make_gen(): # convert the renamed + casted columns into a record batch pa.RecordBatch.from_struct_array( # rename columns to match schema because datafusion lowercases things - pa.RecordBatch.from_arrays(batch.columns, names=names) + pa.RecordBatch.from_arrays(batch.to_pyarrow().columns, names=names) # cast the struct array to the desired types to work around # https://github.com/apache/arrow-datafusion-python/issues/534 .to_struct_array() .cast(struct_schema, safe=False) ) - for batch in frame.collect() + for batch in frame.execute_stream() ) return pa.ipc.RecordBatchReader.from_batches(schema.to_pyarrow(), make_gen()) @@ -628,7 +628,8 @@ def create_table( ) ) elif obj is not None: - _read_in_memory(obj, name, self, overwrite=overwrite) + table_ident = sg.table(name, db=database, quoted=quoted).sql(self.dialect) + _read_in_memory(obj, table_ident, self, overwrite=overwrite) return self.table(name, database=database) else: query = None @@ -687,7 +688,7 @@ def truncate_table( table_loc = self._warn_and_create_table_loc(database, schema) catalog, db = self._to_catalog_db_tuple(table_loc) - ident = sg.table(name, db=db, catalog=catalog).sql(self.name) + ident = sg.table(name, db=db, catalog=catalog).sql(self.dialect) with self._safe_raw_sql(sge.delete(ident)): pass diff --git a/ibis/backends/datafusion/tests/test_register.py b/ibis/backends/datafusion/tests/test_register.py index 03bdd620641a..16a82973c7fa 100644 --- a/ibis/backends/datafusion/tests/test_register.py +++ b/ibis/backends/datafusion/tests/test_register.py @@ -50,3 +50,9 @@ def test_register_dataset(conn): with pytest.warns(FutureWarning, match="v9.1"): conn.register(dataset, "my_table") assert conn.table("my_table").x.sum().execute() == 6 + + +def test_create_table_with_uppercase_name(conn): + tab = pa.table({"x": [1, 2, 3]}) + conn.create_table("MY_TABLE", tab) + assert conn.table("MY_TABLE").x.sum().execute() == 6