From b70479385c78e288c360a55d40af194677875c58 Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Thu, 29 Jun 2023 10:28:50 -0600 Subject: [PATCH] fix --- src/sql/connection.py | 2 +- src/sql/inspect.py | 4 ++++ src/sql/run.py | 2 +- src/tests/test_magic_cmd.py | 18 ++++++++++++------ 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/sql/connection.py b/src/sql/connection.py index 58352f797..f151921fe 100644 --- a/src/sql/connection.py +++ b/src/sql/connection.py @@ -653,7 +653,7 @@ def execute(self, query, with_=None): Executes SQL query on a given connection """ query = self._prepare_query(query, with_) - return self.session.execute(query) + return self.engine.execute(query) atexit.register(Connection.close_all, verbose=True) diff --git a/src/sql/inspect.py b/src/sql/inspect.py index 1d7cfc103..ada6b16b3 100644 --- a/src/sql/inspect.py +++ b/src/sql/inspect.py @@ -239,11 +239,15 @@ def __init__(self, table_name, schema=None) -> None: columns_query_result = sql.run.raw_run( Connection.current, f"SELECT * FROM {table_name} WHERE 1=0" ) + if Connection.is_custom_connection(): columns = [i[0] for i in columns_query_result.description] else: columns = columns_query_result.keys() + # TODO: abstract it internally + columns_query_result.close() + table_stats = dict({}) columns_to_include_in_report = set() columns_with_styles = [] diff --git a/src/sql/run.py b/src/sql/run.py index 229334cc7..9e492420e 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -662,7 +662,7 @@ def _first_word(sql): def raw_run(conn, sql): - return conn.session.execute(sqlalchemy.sql.text(sql)) + return conn.engine.execute(sqlalchemy.sql.text(sql)) class PrettyTable(prettytable.PrettyTable): diff --git a/src/tests/test_magic_cmd.py b/src/tests/test_magic_cmd.py index d9b7ef97b..ea524741b 100644 --- a/src/tests/test_magic_cmd.py +++ b/src/tests/test_magic_cmd.py @@ -1,11 +1,10 @@ +import sqlite3 import sys import math import pytest from IPython.core.error import UsageError from pathlib import Path -from sqlalchemy import create_engine -from sql.connection import Connection from sql.store import store from sql.inspect import _is_numeric @@ -112,8 +111,15 @@ def test_tables(ip): def test_tables_with_schema(ip, tmp_empty): - conn = Connection(engine=create_engine("sqlite:///my.db")) - conn.execute("CREATE TABLE numbers (some_number FLOAT)") + # TODO: why does this fail? + # ip.run_cell( + # """%%sql sqlite:///my.db + # CREATE TABLE numbers (some_number FLOAT) + # """ + # ) + + with sqlite3.connect("my.db") as conn: + conn.execute("CREATE TABLE numbers (some_number FLOAT)") ip.run_cell( """%%sql @@ -150,8 +156,8 @@ def test_columns(ip, cmd, cols): def test_columns_with_schema(ip, tmp_empty): - conn = Connection(engine=create_engine("sqlite:///my.db")) - conn.execute("CREATE TABLE numbers (some_number FLOAT)") + with sqlite3.connect("my.db") as conn: + conn.execute("CREATE TABLE numbers (some_number FLOAT)") ip.run_cell( """%%sql