From 6da5cc5b47dad363215bdc42378b5a2c43c7d72e Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Mon, 18 Dec 2023 01:09:04 +0000 Subject: [PATCH 01/29] add basic spark support to library --- src/sql/command.py | 4 +- src/sql/connection/__init__.py | 2 + src/sql/connection/connection.py | 68 ++++++++++++++++++++++++++++++++ src/sql/run/run.py | 6 +++ 4 files changed, 78 insertions(+), 2 deletions(-) diff --git a/src/sql/command.py b/src/sql/command.py index 7b7bd168b..894ad10e3 100644 --- a/src/sql/command.py +++ b/src/sql/command.py @@ -5,7 +5,7 @@ from sql import parse, exceptions from sql.store import store -from sql.connection import ConnectionManager, is_pep249_compliant +from sql.connection import ConnectionManager, is_pep249_compliant, is_spark from sql.util import validate_nonidentifier_connection @@ -49,7 +49,7 @@ def __init__(self, magic, user_ns, line, cell) -> None: if ( one_arg and self.args.line[0] in user_ns - and (isinstance(user_ns[self.args.line[0]], Engine) or is_dbapi_connection_) + and (isinstance(user_ns[self.args.line[0]], Engine) or is_dbapi_connection_ or is_spark) ): line_for_command = [] add_conn = True diff --git a/src/sql/connection/__init__.py b/src/sql/connection/__init__.py index 7c48e624b..259b549d4 100644 --- a/src/sql/connection/__init__.py +++ b/src/sql/connection/__init__.py @@ -2,7 +2,9 @@ ConnectionManager, SQLAlchemyConnection, DBAPIConnection, + SparkConnectConnection, is_pep249_compliant, + is_spark, PLOOMBER_DOCS_LINK_STR, default_alias_for_engine, ResultSetCollection, diff --git a/src/sql/connection/connection.py b/src/sql/connection/connection.py index 678ca30de..203462082 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -5,6 +5,7 @@ from difflib import get_close_matches import atexit from functools import partial +import pyspark import sqlalchemy from sqlalchemy.engine import Engine @@ -21,6 +22,7 @@ import sqlparse from ploomber_core.exceptions import modify_exceptions +from 
pyspark.sql import SparkSession from sql.store import store from sql.telemetry import telemetry @@ -257,6 +259,8 @@ def set( ) elif is_pep249_compliant(descriptor): cls.current = DBAPIConnection(descriptor, config=config, alias=alias) + elif is_spark(descriptor): + cls.current = SparkConnectConnection(descriptor) else: existing = rough_dict_get(cls.connections, descriptor) if existing and existing.alias == alias: @@ -1060,6 +1064,68 @@ def to_table(self, table_name, data_frame, if_exists, index, schema=None): ) +class SparkConnectConnection(AbstractConnection): + + @telemetry.log_call("SparkConnectConnection", payload=True) + def __init__(self, payload, connection: SparkSession, alias=None, config=None): + self._driver = None + + # TODO: implement the dialect blacklist and add unit tests + self._requires_manual_commit = True if config is None else config.autocommit + + self._connection = connection + self._connection_class_name = type(connection).__name__ + + # calling init from AbstractConnection must be the last thing we do as it + # register the connection + super().__init__(self._connection_class_name) + + # TODO: delete this + self.name = self._connection_class_name + + @property + def dialect(self): + """Returns a string with the SQL dialect name""" + return "spark" + + + def raw_execute(self, query, parameters=None): + """Run the query without any pre-processing""" + return self._connection.sql(query) + + + def _get_database_information(self): + """ + Get the dialect, driver, and database server version info of current + connection + """ + return { + "dialect": self.dialect, + "driver": self._connection_class_name, + "server_version_info": None, + } + + @property + def url(self): + """Returns None since DBAPI connections don't have a url""" + return None + + @property + def connection_sqlalchemy(self): + """ + Raises NotImplementedError since DBAPI connections don't have a SQLAlchemy + connection object + """ + raise NotImplementedError( + "This feature is 
only available for SQLAlchemy connections" + ) + + def to_table(self, table_name, data_frame, if_exists, index, schema=None): + raise exceptions.NotImplementedError( + "--persist/--persist-replace is not available for DBAPI connections" + " (only available for SQLAlchemy connections)" + ) + def _check_if_duckdb_dbapi_connection(conn): """Check if the connection is a native duckdb connection""" # NOTE: duckdb defines df and pl to efficiently convert results to @@ -1153,6 +1219,8 @@ def is_pep249_compliant(conn): return True +def is_spark(ins): + return isinstance(ins,pyspark.sql.connect.session.SparkSession) or isinstance(ins,pyspark.sql.SparkSession) def default_alias_for_engine(engine): if not engine.url.username: diff --git a/src/sql/run/run.py b/src/sql/run/run.py index 11312c450..a55c2c96d 100644 --- a/src/sql/run/run.py +++ b/src/sql/run/run.py @@ -49,6 +49,9 @@ def run_statements(conn, sql, config, parameters=None): if first_word.startswith("\\") and is_postgres_or_redshift(conn.dialect): result = handle_postgres_special(conn, statement) + if is_spark(conn.dialect): + return conn.raw_execute(statement, parameters=parameters) + # regular query else: result = conn.raw_execute(statement, parameters=parameters) @@ -68,6 +71,9 @@ def is_postgres_or_redshift(dialect): """Checks if dialect is postgres or redshift""" return "postgres" in str(dialect) or "redshift" in str(dialect) +def is_spark(dialect): + return "spark" in str(dialect) + def select_df_type(resultset, config): """ From 280b646023037a9cab4313a853b8042b6120f10a Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Mon, 18 Dec 2023 22:00:02 +0000 Subject: [PATCH 02/29] adding tests --- setup.py | 1 + src/sql/connection/connection.py | 34 +++++++++++++++++++-------- src/tests/integration/conftest.py | 8 +++++++ src/tests/test_connection.py | 38 +++++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 2e1b87ad3..7d65357ae 100644 --- a/setup.py 
+++ b/setup.py @@ -35,6 +35,7 @@ "pandas", # previously pinned to 2.0.3 "polars==0.17.2", # 04/18/23 this breaks our CI "pyarrow", + "pyspark" "invoke", "pkgmt", "twine", diff --git a/src/sql/connection/connection.py b/src/sql/connection/connection.py index 203462082..3f3e58f42 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -5,7 +5,6 @@ from difflib import get_close_matches import atexit from functools import partial -import pyspark import sqlalchemy from sqlalchemy.engine import Engine @@ -17,12 +16,18 @@ InternalError, ProgrammingError, ) + +try: + from pyspark.sql.connect.session import SparkSession as CSparkSession + from pyspark.sql import SparkSession +except ModuleNotFoundError: + CSparkSession = None + SparkSession = None from IPython.core.error import UsageError import sqlglot import sqlparse from ploomber_core.exceptions import modify_exceptions -from pyspark.sql import SparkSession from sql.store import store from sql.telemetry import telemetry @@ -1065,9 +1070,15 @@ def to_table(self, table_name, data_frame, if_exists, index, schema=None): class SparkConnectConnection(AbstractConnection): + + is_dbapi_connection = False @telemetry.log_call("SparkConnectConnection", payload=True) - def __init__(self, payload, connection: SparkSession, alias=None, config=None): + def __init__(self, payload, connection, alias=None, config=None): + try: + payload["engine"] = type(connection) + except Exception as e: + payload["engine_parsing_error"] = str(e) self._driver = None # TODO: implement the dialect blacklist and add unit tests @@ -1080,7 +1091,7 @@ def __init__(self, payload, connection: SparkSession, alias=None, config=None): # register the connection super().__init__(self._connection_class_name) - # TODO: delete this + self.name = self._connection_class_name @property @@ -1102,18 +1113,18 @@ def _get_database_information(self): return { "dialect": self.dialect, "driver": self._connection_class_name, - 
"server_version_info": None, + "server_version_info": self._connection.version(), } @property def url(self): - """Returns None since DBAPI connections don't have a url""" + """Returns None since Spark connections don't have a url""" return None @property def connection_sqlalchemy(self): """ - Raises NotImplementedError since DBAPI connections don't have a SQLAlchemy + Raises NotImplementedError since Spark connections don't have a SQLAlchemy connection object """ raise NotImplementedError( @@ -1122,9 +1133,14 @@ def connection_sqlalchemy(self): def to_table(self, table_name, data_frame, if_exists, index, schema=None): raise exceptions.NotImplementedError( - "--persist/--persist-replace is not available for DBAPI connections" + "--persist/--persist-replace is not available for Spark connections" " (only available for SQLAlchemy connections)" ) + + def close(self): + """Close the connection""" + # NOTE: spark is often shared outside sql, allow user to manage closure + pass def _check_if_duckdb_dbapi_connection(conn): """Check if the connection is a native duckdb connection""" @@ -1220,7 +1236,7 @@ def is_pep249_compliant(conn): return True def is_spark(ins): - return isinstance(ins,pyspark.sql.connect.session.SparkSession) or isinstance(ins,pyspark.sql.SparkSession) + return (CSparkSession is not None and isinstance(ins,CSparkSession)) or (SparkSession is not None and isinstance(ins,SparkSession)) def default_alias_for_engine(engine): if not engine.url.username: diff --git a/src/tests/integration/conftest.py b/src/tests/integration/conftest.py index e0a0b5e7f..0a13810f1 100644 --- a/src/tests/integration/conftest.py +++ b/src/tests/integration/conftest.py @@ -2,6 +2,7 @@ from pathlib import Path import shutil import pandas as pd +from pyspark.sql import SparkSession import pytest from sqlalchemy import MetaData, Table, create_engine, text import uuid @@ -287,6 +288,13 @@ def setup_duckDB_native(test_table_name_dict): yield conn conn.close() 
+@pytest.fixture(scope="session") +def setup_spark(): + spark = SparkSession.Builder.master("local").getOrCreate() + yield spark + spark.stop() + + def load_generic_testing_data_duckdb_native(ip, test_table_name_dict): ip.run_cell("import pandas as pd") diff --git a/src/tests/test_connection.py b/src/tests/test_connection.py index 1329eebca..396ec168e 100644 --- a/src/tests/test_connection.py +++ b/src/tests/test_connection.py @@ -13,6 +13,8 @@ from sqlalchemy.engine import Engine from sqlalchemy import exc +import pyspark + from sql.connection import connection as connection_module import sql.connection from sql.connection import ( @@ -21,6 +23,7 @@ ConnectionManager, is_pep249_compliant, default_alias_for_engine, + is_spark, ResultSetCollection, detect_duckdb_summarize_or_select, ) @@ -40,6 +43,10 @@ def mock_database(monkeypatch, cleanup): monkeypatch.setattr(Engine, "connect", Mock()) monkeypatch.setattr(sqlalchemy, "create_engine", Mock()) +@pytest.fixture +def mock_spark(monkeypatch,cleanup): + monkeypatch.setitem(sys.modules, "pyspark.sql.SparkSession", Mock()) + @pytest.fixture def mock_postgres(monkeypatch, cleanup): @@ -456,6 +463,22 @@ def test_properties(mock_postgres): def test_is_pep249_compliant(conn, expected): assert is_pep249_compliant(conn) is expected +@pytest.mark.parametrize( + "descriptor, expected", + [ + [sqlite3.connect(""), False], + [duckdb.connect(""), False], + [create_engine("sqlite://"), False], + [Mock(spec=pyspark.sql.SparkSession), True], + [Mock(spec=pyspark.sql.connect.session.SparkSession), True], + [None, False], + [object(), False], + ["not_a_valid_connection", False], + [0, False], + ] +) +def test_is_spark(descriptor, expected): + assert is_spark(descriptor) is expected def test_close_all(ip_empty, monkeypatch): connections = {} @@ -589,6 +612,21 @@ def test_set_dbapi(monkeypatch, callable_, key): assert connections == {key: conn} assert ConnectionManager.current == conn +@pytest.mark.parametrize( + "spark, key", + [ + 
[Mock(name="SparkSession",spec=pyspark.sql.SparkSession), "Mock"], + [Mock(name="SparkSession",spec=pyspark.sql.connect.session.SparkSession), "Mock"], + ], +) +def test_set_spark(monkeypatch, spark, key): + connections = {} + monkeypatch.setattr(ConnectionManager, "connections", connections) + + conn = ConnectionManager.set(spark, displaycon=False) + + assert connections == {key: conn} + assert ConnectionManager.current == conn def test_set_with_alias(monkeypatch): connections = {} From 145ae6ab02aad58370cdfd632732ccfc630a197c Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Mon, 18 Dec 2023 22:06:33 +0000 Subject: [PATCH 03/29] formatting --- setup.py | 3 +-- src/sql/command.py | 6 +++++- src/sql/connection/connection.py | 17 +++++++++-------- src/sql/run/run.py | 1 + src/tests/integration/conftest.py | 2 +- src/tests/test_connection.py | 16 ++++++++++++---- 6 files changed, 29 insertions(+), 16 deletions(-) diff --git a/setup.py b/setup.py index 7d65357ae..f7f213405 100644 --- a/setup.py +++ b/setup.py @@ -35,8 +35,7 @@ "pandas", # previously pinned to 2.0.3 "polars==0.17.2", # 04/18/23 this breaks our CI "pyarrow", - "pyspark" - "invoke", + "pyspark" "invoke", "pkgmt", "twine", # tests diff --git a/src/sql/command.py b/src/sql/command.py index 894ad10e3..c5a664c73 100644 --- a/src/sql/command.py +++ b/src/sql/command.py @@ -49,7 +49,11 @@ def __init__(self, magic, user_ns, line, cell) -> None: if ( one_arg and self.args.line[0] in user_ns - and (isinstance(user_ns[self.args.line[0]], Engine) or is_dbapi_connection_ or is_spark) + and ( + isinstance(user_ns[self.args.line[0]], Engine) + or is_dbapi_connection_ + or is_spark + ) ): line_for_command = [] add_conn = True diff --git a/src/sql/connection/connection.py b/src/sql/connection/connection.py index 3f3e58f42..54b0b4ff1 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -1070,9 +1070,8 @@ def to_table(self, table_name, data_frame, if_exists, index, schema=None): class 
SparkConnectConnection(AbstractConnection): - is_dbapi_connection = False - + @telemetry.log_call("SparkConnectConnection", payload=True) def __init__(self, payload, connection, alias=None, config=None): try: @@ -1091,20 +1090,17 @@ def __init__(self, payload, connection, alias=None, config=None): # register the connection super().__init__(self._connection_class_name) - self.name = self._connection_class_name - + @property def dialect(self): """Returns a string with the SQL dialect name""" return "spark" - def raw_execute(self, query, parameters=None): """Run the query without any pre-processing""" return self._connection.sql(query) - def _get_database_information(self): """ Get the dialect, driver, and database server version info of current @@ -1136,12 +1132,13 @@ def to_table(self, table_name, data_frame, if_exists, index, schema=None): "--persist/--persist-replace is not available for Spark connections" " (only available for SQLAlchemy connections)" ) - + def close(self): """Close the connection""" # NOTE: spark is often shared outside sql, allow user to manage closure pass + def _check_if_duckdb_dbapi_connection(conn): """Check if the connection is a native duckdb connection""" # NOTE: duckdb defines df and pl to efficiently convert results to @@ -1235,8 +1232,12 @@ def is_pep249_compliant(conn): return True + def is_spark(ins): - return (CSparkSession is not None and isinstance(ins,CSparkSession)) or (SparkSession is not None and isinstance(ins,SparkSession)) + return (CSparkSession is not None and isinstance(ins, CSparkSession)) or ( + SparkSession is not None and isinstance(ins, SparkSession) + ) + def default_alias_for_engine(engine): if not engine.url.username: diff --git a/src/sql/run/run.py b/src/sql/run/run.py index a55c2c96d..f59ef96d2 100644 --- a/src/sql/run/run.py +++ b/src/sql/run/run.py @@ -71,6 +71,7 @@ def is_postgres_or_redshift(dialect): """Checks if dialect is postgres or redshift""" return "postgres" in str(dialect) or "redshift" in 
str(dialect) + def is_spark(dialect): return "spark" in str(dialect) diff --git a/src/tests/integration/conftest.py b/src/tests/integration/conftest.py index 0a13810f1..d92fe4d70 100644 --- a/src/tests/integration/conftest.py +++ b/src/tests/integration/conftest.py @@ -288,12 +288,12 @@ def setup_duckDB_native(test_table_name_dict): yield conn conn.close() + @pytest.fixture(scope="session") def setup_spark(): spark = SparkSession.Builder.master("local").getOrCreate() yield spark spark.stop() - def load_generic_testing_data_duckdb_native(ip, test_table_name_dict): diff --git a/src/tests/test_connection.py b/src/tests/test_connection.py index 396ec168e..b22b03c32 100644 --- a/src/tests/test_connection.py +++ b/src/tests/test_connection.py @@ -43,8 +43,9 @@ def mock_database(monkeypatch, cleanup): monkeypatch.setattr(Engine, "connect", Mock()) monkeypatch.setattr(sqlalchemy, "create_engine", Mock()) + @pytest.fixture -def mock_spark(monkeypatch,cleanup): +def mock_spark(monkeypatch, cleanup): monkeypatch.setitem(sys.modules, "pyspark.sql.SparkSession", Mock()) @@ -463,6 +464,7 @@ def test_properties(mock_postgres): def test_is_pep249_compliant(conn, expected): assert is_pep249_compliant(conn) is expected + @pytest.mark.parametrize( "descriptor, expected", [ @@ -475,11 +477,12 @@ def test_is_pep249_compliant(conn, expected): [object(), False], ["not_a_valid_connection", False], [0, False], - ] + ], ) def test_is_spark(descriptor, expected): assert is_spark(descriptor) is expected + def test_close_all(ip_empty, monkeypatch): connections = {} monkeypatch.setattr(ConnectionManager, "connections", connections) @@ -612,11 +615,15 @@ def test_set_dbapi(monkeypatch, callable_, key): assert connections == {key: conn} assert ConnectionManager.current == conn + @pytest.mark.parametrize( "spark, key", [ - [Mock(name="SparkSession",spec=pyspark.sql.SparkSession), "Mock"], - [Mock(name="SparkSession",spec=pyspark.sql.connect.session.SparkSession), "Mock"], + 
[Mock(name="SparkSession", spec=pyspark.sql.SparkSession), "Mock"], + [ + Mock(name="SparkSession", spec=pyspark.sql.connect.session.SparkSession), + "Mock", + ], ], ) def test_set_spark(monkeypatch, spark, key): @@ -628,6 +635,7 @@ def test_set_spark(monkeypatch, spark, key): assert connections == {key: conn} assert ConnectionManager.current == conn + def test_set_with_alias(monkeypatch): connections = {} monkeypatch.setattr(ConnectionManager, "connections", connections) From 7cd9e6118427e5f9e000630b6163194db47789eb Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Mon, 18 Dec 2023 22:28:16 +0000 Subject: [PATCH 04/29] add spark connection --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f7f213405..28092dd88 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,8 @@ "pandas", # previously pinned to 2.0.3 "polars==0.17.2", # 04/18/23 this breaks our CI "pyarrow", - "pyspark" "invoke", + "pyspark", + "invoke", "pkgmt", "twine", # tests From 0f6a3284858c5371b1ceed700905fde828227186 Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Mon, 18 Dec 2023 22:28:26 +0000 Subject: [PATCH 05/29] add spark connection --- src/tests/integration/test_connection.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/tests/integration/test_connection.py b/src/tests/integration/test_connection.py index 5d867b6e6..aa4c4a8bf 100644 --- a/src/tests/integration/test_connection.py +++ b/src/tests/integration/test_connection.py @@ -8,7 +8,7 @@ import pytest -from sql.connection import SQLAlchemyConnection, DBAPIConnection, ConnectionManager +from sql.connection import SQLAlchemyConnection, DBAPIConnection, ConnectionManager, SparkConnectConnection from sql import _testing from sql.connection import connection @@ -92,6 +92,11 @@ def test_connection_properties(dynamic_db, request, Constructor, alias, dialect) partial(DBAPIConnection, alias="another-alias"), "another-alias", ], + [ + "setup_spark", + 
SparkConnectConnection, + "SparkSession" + ] ], ) def test_connection_identifiers( From 0b7e3bfe6cb54d411dea3e3b0296f0409e2830de Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Mon, 18 Dec 2023 23:36:21 +0000 Subject: [PATCH 06/29] fixed test and formating --- setup.py | 2 +- src/sql/command.py | 2 +- src/tests/integration/test_connection.py | 13 +++++++------ src/tests/test_magic.py | 1 + 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/setup.py b/setup.py index 28092dd88..50c5a19ce 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ "pandas", # previously pinned to 2.0.3 "polars==0.17.2", # 04/18/23 this breaks our CI "pyarrow", - "pyspark", + "pyspark", "invoke", "pkgmt", "twine", diff --git a/src/sql/command.py b/src/sql/command.py index c5a664c73..675420d74 100644 --- a/src/sql/command.py +++ b/src/sql/command.py @@ -52,7 +52,7 @@ def __init__(self, magic, user_ns, line, cell) -> None: and ( isinstance(user_ns[self.args.line[0]], Engine) or is_dbapi_connection_ - or is_spark + or is_spark(user_ns[self.args.line[0]]) ) ): line_for_command = [] diff --git a/src/tests/integration/test_connection.py b/src/tests/integration/test_connection.py index aa4c4a8bf..63dc71ac9 100644 --- a/src/tests/integration/test_connection.py +++ b/src/tests/integration/test_connection.py @@ -8,7 +8,12 @@ import pytest -from sql.connection import SQLAlchemyConnection, DBAPIConnection, ConnectionManager, SparkConnectConnection +from sql.connection import ( + SQLAlchemyConnection, + DBAPIConnection, + ConnectionManager, + SparkConnectConnection, +) from sql import _testing from sql.connection import connection @@ -92,11 +97,7 @@ def test_connection_properties(dynamic_db, request, Constructor, alias, dialect) partial(DBAPIConnection, alias="another-alias"), "another-alias", ], - [ - "setup_spark", - SparkConnectConnection, - "SparkSession" - ] + ["setup_spark", SparkConnectConnection, "SparkSession"], ], ) def test_connection_identifiers( diff --git 
a/src/tests/test_magic.py b/src/tests/test_magic.py index c2dd8f155..0502bfb9a 100644 --- a/src/tests/test_magic.py +++ b/src/tests/test_magic.py @@ -1702,6 +1702,7 @@ def test_persist_uses_error_handling_method(ip, monkeypatch, cell): ip.push({"df": df}) conn = ConnectionManager.current + print(conn) execute_with_error_handling_mock = Mock(wraps=conn._execute_with_error_handling) monkeypatch.setattr( conn, "_execute_with_error_handling", execute_with_error_handling_mock From e60f0ea718f7f18b8c35126c87a62c4a2aad5e15 Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Tue, 19 Dec 2023 00:56:51 +0000 Subject: [PATCH 07/29] added docs --- doc/_toc.yml | 1 + doc/integrations/spark.ipynb | 867 +++++++++++++++++++++++++++++++ src/sql/connection/connection.py | 2 +- 3 files changed, 869 insertions(+), 1 deletion(-) create mode 100644 doc/integrations/spark.ipynb diff --git a/doc/_toc.yml b/doc/_toc.yml index 2d9850b10..f667e807b 100644 --- a/doc/_toc.yml +++ b/doc/_toc.yml @@ -43,6 +43,7 @@ parts: - file: integrations/duckdb-native - file: integrations/compatibility - file: integrations/chdb + - file: integrations/spark - caption: API Reference chapters: diff --git a/doc/integrations/spark.ipynb b/doc/integrations/spark.ipynb new file mode 100644 index 000000000..2e518be11 --- /dev/null +++ b/doc/integrations/spark.ipynb @@ -0,0 +1,867 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Spark\n", + "\n", + "This tutorial will show you how to get a Spark instance up and running locally to test JupySQL. You can run this in a Jupyter notebook. 
We'll use Spark Connect which is the new thin client for Spark" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pre-requisites\n", + "\n", + "To run this tutorial, you need to install following Python packages:" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install jupysql pyspark==3.4.1 arrow pyarrow==12.0.1 pandas --quiet" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Start Spark instance\n", + "\n", + "We fetch the official image, create a new database, and user (this will take a few seconds)." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4a7133e0b70b9435d575ef2eda12644a589ab78ef5800f4d95775377c13ef18c\n" + ] + } + ], + "source": [ + "%%bash\n", + "docker run -p 15002:15002 -p 4040:4040 -d --name spark wh1isper/sparglim-server" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Our database is running, let's load some data!" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load sample data\n", + "\n", + "Now, let's fetch some sample data. 
We'll be using the [NYC taxi dataset](https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page):" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from pyspark.sql.connect.session import SparkSession\n", + "\n", + "spark = SparkSession.builder.remote(\"sc://localhost\").getOrCreate()\n", + "\n", + "df = pd.read_parquet(\n", + " \"https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2021-01.parquet\"\n", + ")\n", + "sparkDf = spark.createDataFrame(df.head(10000))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set eagerEval on to print dataframes" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def __pretty_(self, p, cycle):\n", + " self.show(truncate=False)\n", + "\n", + "\n", + "from pyspark.sql.connect.dataframe import DataFrame\n", + "\n", + "DataFrame._repr_pretty_ = __pretty_\n", + "spark.conf.set(\"spark.sql.repl.eagerEval.enabled\", True)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Add dataset to temporary view to allow querying" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "sparkDf.createOrReplaceTempView(\"taxi\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Query\n", + "\n", + "Now, let's start JupySQL, authenticate, and query the data!" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The sql extension is already loaded. 
To reload it, use:\n", + " %reload_ext sql\n" + ] + } + ], + "source": [ + "%load_ext sql" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%sql spark" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```{important}\n", + "If the cell above fails, you might have some missing packages. Message us on [Slack](https://ploomber.io/community) and we'll help you!\n", + "```" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "List the tables in the database:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+---------+--------+-----------+\n", + "|namespace|viewName|isTemporary|\n", + "+---------+--------+-----------+\n", + "| |taxi |true |\n", + "+---------+--------+-----------+\n", + "\n" + ] + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%sql show views in default" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "List columns in the taxi table:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "root\n", + " |-- VendorID: long (nullable = true)\n", + " |-- tpep_pickup_datetime: timestamp (nullable = true)\n", + " |-- 
tpep_dropoff_datetime: timestamp (nullable = true)\n", + " |-- passenger_count: double (nullable = true)\n", + " |-- trip_distance: double (nullable = true)\n", + " |-- RatecodeID: double (nullable = true)\n", + " |-- store_and_fwd_flag: string (nullable = true)\n", + " |-- PULocationID: long (nullable = true)\n", + " |-- DOLocationID: long (nullable = true)\n", + " |-- payment_type: long (nullable = true)\n", + " |-- fare_amount: double (nullable = true)\n", + " |-- extra: double (nullable = true)\n", + " |-- mta_tax: double (nullable = true)\n", + " |-- tip_amount: double (nullable = true)\n", + " |-- tolls_amount: double (nullable = true)\n", + " |-- improvement_surcharge: double (nullable = true)\n", + " |-- total_amount: double (nullable = true)\n", + " |-- congestion_surcharge: double (nullable = true)\n", + " |-- airport_fee: double (nullable = true)\n", + "\n" + ] + } + ], + "source": [ + "df = %sql select * from taxi\n", + "df.printSchema()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Query our data:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+--------+\n", + "|count(1)|\n", + "+--------+\n", + "|10000 |\n", + "+--------+\n", + "\n" + ] + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT COUNT(*) FROM taxi" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Parameterize queries" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "threshold = 10" + ] + 
}, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+--------+\n", + "|count(1)|\n", + "+--------+\n", + "|9476 |\n", + "+--------+\n", + "\n" + ] + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT COUNT(*) FROM taxi\n", + "WHERE trip_distance < {{threshold}}" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "threshold = 0.5" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+--------+\n", + "|count(1)|\n", + "+--------+\n", + "|642 |\n", + "+--------+\n", + "\n" + ] + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql\n", + "SELECT COUNT(*) FROM taxi\n", + "WHERE trip_distance < {{threshold}}" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## CTEs" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Skipping execution..." 
+ ], + "text/plain": [ + "Skipping execution..." + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%sql --save many_passengers --no-execute\n", + "SELECT *\n", + "FROM taxi\n", + "WHERE passenger_count > 3\n", + "-- remove top 1% outliers for better visualization\n", + "AND trip_distance < 18.93" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+------------------+------------------+------------------+\n", + "|min(trip_distance)|avg(trip_distance)|max(trip_distance)|\n", + "+------------------+------------------+------------------+\n", + "|0.0 |3.1091381872213963|18.46 |\n", + "+------------------+------------------+------------------+\n", + "\n" + ] + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%%sql --save trip_stats --with many_passengers\n", + "SELECT MIN(trip_distance), AVG(trip_distance), MAX(trip_distance)\n", + "FROM many_passengers" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is what JupySQL executes:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WITH `many_passengers` AS (\n", + "SELECT *\n", + "FROM taxi\n", + "WHERE passenger_count > 3\n", + "\n", + "AND trip_distance < 18.93)\n", + "SELECT MIN(trip_distance), AVG(trip_distance), MAX(trip_distance)\n", + "FROM many_passengers\n" + ] + } + ], + "source": [ + "query = %sqlcmd snippets trip_stats\n", + "print(query)" + ] + }, + { + "attachments": {}, + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "## Plotting" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The %sqlplot magic command currently does not directly support the `--schema` option for specifying the schema name. To work around this, you can specify the schema in the SQL query itself." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "result = %sql SELECT trip_distance FROM taxi\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "data = result.toPandas()\n", + "\n", + "plt.hist(data[\"trip_distance\"])\n", + "plt.xlabel(\"Trip Distance\")\n", + "plt.ylabel(\"Frequency\")\n", + "plt.title(\"Histogram of Trip Distance\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "result = %sql SELECT trip_distance FROM taxi\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "data = result.toPandas()\n", + "\n", + "plt.boxplot(data[\"trip_distance\"])\n", + "plt.xlabel(\"Trip Distance\")\n", + "plt.ylabel(\"Value\")\n", + "plt.title(\"Boxplot of Trip Distance\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Persist - Not Supported" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Clean up\n", + "\n", + "To stop and remove the container:" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES\n", + "4a7133e0b70b wh1isper/sparglim-server \"tini -- sparglim-se…\" 4 minutes ago Up 4 minutes 0.0.0.0:4040->4040/tcp, 0.0.0.0:15002->15002/tcp spark\n", + "f019407c6426 docker.dev.slicelife.com/onelogin-aws-assume-role:stable \"onelogin-aws-assume…\" 2 weeks ago Up 2 weeks heuristic_tu\n" + ] + } + ], + "source": [ + "! docker container ls" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%%capture out\n", + "! 
docker container ls --filter ancestor=wh1isper/sparglim-server --quiet" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Container id: 4a7133e0b70b\n" + ] + } + ], + "source": [ + "container_id = out.stdout.strip()\n", + "print(f\"Container id: {container_id}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4a7133e0b70b\n" + ] + } + ], + "source": [ + "! docker container stop {container_id}" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4a7133e0b70b\n" + ] + } + ], + "source": [ + "! docker container rm {container_id}" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "myst": { + "html_meta": { + "description lang=en": "Query data with Spark Connect from Jupyter via JupySQL", + "keywords": "jupyter, sql, jupysql, spark", + "property=og:locale": "en_US" + } + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/sql/connection/connection.py b/src/sql/connection/connection.py index 54b0b4ff1..7262065c7 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -1109,7 +1109,7 @@ def _get_database_information(self): return { "dialect": self.dialect, "driver": 
self._connection_class_name, - "server_version_info": self._connection.version(), + "server_version_info": self._connection.version, } @property From c4acca13ef17395b5f16a091dc3ca0188821b32d Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Tue, 19 Dec 2023 01:09:52 +0000 Subject: [PATCH 08/29] exclude execution --- doc/conf.py | 1 + doc/integrations/spark.ipynb | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/doc/conf.py b/doc/conf.py index 5e1792154..39080d77d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -27,6 +27,7 @@ "integrations/oracle.ipynb", "integrations/snowflake.ipynb", "integrations/redshift.ipynb", + "integrations/spark.ipynb", ] nb_execution_in_temp = True nb_execution_show_tb = True diff --git a/doc/integrations/spark.ipynb b/doc/integrations/spark.ipynb index 2e518be11..481a8da6d 100644 --- a/doc/integrations/spark.ipynb +++ b/doc/integrations/spark.ipynb @@ -831,7 +831,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "jupysql", "language": "python", "name": "python3" }, @@ -854,6 +854,11 @@ "property=og:locale": "en_US" } }, + "vscode": { + "interpreter": { + "hash": "8de7291ac4f217ed756f77e1d71d41823fff9c4ffb13df0a183e9309929ad9aa" + } + }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, From 96846d4ae46eeffe2b8e735bb1c4111e273ccc6b Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Tue, 19 Dec 2023 09:57:14 +0000 Subject: [PATCH 09/29] documentation updates --- doc/integrations/compatibility.md | 15 ++++ doc/integrations/spark.ipynb | 111 +++++++++++++----------------- src/tests/test_magic.py | 1 - 3 files changed, 61 insertions(+), 66 deletions(-) diff --git a/doc/integrations/compatibility.md b/doc/integrations/compatibility.md index 4e6b36432..d26199150 100644 --- a/doc/integrations/compatibility.md +++ b/doc/integrations/compatibility.md @@ -114,4 +114,19 @@ These table reflects the compatibility status of JupySQL `>=0.7` - Listing tables 
with `%sqlcmd tables` ✅ - Listing columns with `%sqlcmd columns` ✅ - Parametrized SQL queries via `{{parameter}}` ✅ +- Interactive SQL queries via `--interact` ✅ + +## Spark + +- Running queries with `%%sql` ✅ +- CTEs with `%%sql --save NAME` ✅ +- Plotting with `%%sqlplot boxplot` ❌ +- Plotting with `%%sqlplot bar` ❌ +- Plotting with `%%sqlplot pie` ❌ +- Plotting with `%%sqlplot histogram` ❌ +- Plotting with `ggplot` API ❌ +- Profiling tables with `%sqlcmd profile` ❌ +- Listing tables with `%sqlcmd tables` ❌ +- Listing columns with `%sqlcmd columns` ❌ +- Parametrized SQL queries via `{{parameter}}` ✅ - Interactive SQL queries via `--interact` ✅ \ No newline at end of file diff --git a/doc/integrations/spark.ipynb b/doc/integrations/spark.ipynb index 481a8da6d..ee2451124 100644 --- a/doc/integrations/spark.ipynb +++ b/doc/integrations/spark.ipynb @@ -7,7 +7,7 @@ "source": [ "# Spark\n", "\n", - "This tutorial will show you how to get a Spark instance up and running locally to test JupySQL. You can run this in a Jupyter notebook. We'll use Spark Connect which is the new thin client for Spark" + "This tutorial will show you how to get a Spark instance up and running locally to integrate with JupySQL. You can run this in a Jupyter notebook. 
We'll use [Spark Connect](https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_connect.html) which is the new thin client for Spark" ] }, { @@ -36,7 +36,7 @@ } ], "source": [ - "%pip install jupysql pyspark==3.4.1 arrow pyarrow==12.0.1 pandas --quiet" + "%pip install jupysql pyspark==3.4.1 arrow pyarrow==12.0.1 pandas grpc-status --quiet" ] }, { @@ -51,7 +51,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 26, "metadata": { "tags": [] }, @@ -60,7 +60,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "4a7133e0b70b9435d575ef2eda12644a589ab78ef5800f4d95775377c13ef18c\n" + "8e3831298eb9dce9d26b2f2d3df5a1a64634a19976f4eaeabf92c45afd478a39\n" ] } ], @@ -89,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 27, "metadata": { "tags": [] }, @@ -110,12 +110,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Set eagerEval on to print dataframes" + "Set [eagerEval](https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html#Viewing-Data) on to print dataframes. This makes Spark print dataframes eagerly in notebook environments, rather than its default lazy execution, which requires .show() to see the data. In Spark 3.4.1 we need to override, as below, but in 3.5.0 it will print in HTML. 
" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -134,12 +134,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Add dataset to temporary view to allow querying" + "Add dataset to temporary view to allow querying:" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 29, "metadata": { "tags": [] }, @@ -160,7 +160,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 30, "metadata": { "tags": [] }, @@ -180,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 31, "metadata": { "tags": [] }, @@ -209,46 +209,34 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 46, "metadata": { "tags": [] }, "outputs": [ { - "data": { - "text/html": [ - "Running query in 'SparkSession'" - ], - "text/plain": [ - "Running query in 'SparkSession'" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+---------+--------+-----------+\n", - "|namespace|viewName|isTemporary|\n", - "+---------+--------+-----------+\n", - "| |taxi |true |\n", - "+---------+--------+-----------+\n", - "\n" + "ename": "NotImplementedError", + "evalue": "This feature is only available for SQLAlchemy connections", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[46], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mget_ipython\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_line_magic\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43msqlcmd\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtables\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/homebrew/anaconda3/envs/jupysql/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2456\u001b[0m, in \u001b[0;36mInteractiveShell.run_line_magic\u001b[0;34m(self, magic_name, line, _stack_depth)\u001b[0m\n\u001b[1;32m 2454\u001b[0m kwargs[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlocal_ns\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_local_scope(stack_depth)\n\u001b[1;32m 2455\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuiltin_trap:\n\u001b[0;32m-> 2456\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2458\u001b[0m \u001b[38;5;66;03m# The code below prevents the output from being displayed\u001b[39;00m\n\u001b[1;32m 2459\u001b[0m \u001b[38;5;66;03m# when using magics with decorator @output_can_be_silenced\u001b[39;00m\n\u001b[1;32m 2460\u001b[0m \u001b[38;5;66;03m# when the last Python token in the expression is a ';'.\u001b[39;00m\n\u001b[1;32m 2461\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mgetattr\u001b[39m(fn, magic\u001b[38;5;241m.\u001b[39mMAGIC_OUTPUT_CAN_BE_SILENCED, \u001b[38;5;28;01mFalse\u001b[39;00m):\n", + "File \u001b[0;32m~/scripts/jupysql/src/sql/magic_cmd.py:104\u001b[0m, in \u001b[0;36mSqlCmdMagic._validate_execute_inputs\u001b[0;34m(self, line)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m command \u001b[38;5;129;01min\u001b[39;00m COMMANDS_SQLALCHEMY_ONLY:\n\u001b[1;32m 102\u001b[0m 
support_only_sql_alchemy_connection(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m%sqlcmd \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcommand\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 104\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcommand\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mothers\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 105\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 106\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m exceptions\u001b[38;5;241m.\u001b[39mUsageError(\n\u001b[1;32m 107\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m%sqlcmd has no command: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcommand\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 108\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mValid commands are: \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 109\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(AVAILABLE_SQLCMD_COMMANDS)\n\u001b[1;32m 110\u001b[0m )\n\u001b[1;32m 111\u001b[0m )\n", + "File \u001b[0;32m~/scripts/jupysql/src/sql/magic_cmd.py:132\u001b[0m, in \u001b[0;36mSqlCmdMagic.execute\u001b[0;34m(self, cmd_name, others, cell, local_ns)\u001b[0m\n\u001b[1;32m 130\u001b[0m cmd \u001b[38;5;241m=\u001b[39m router\u001b[38;5;241m.\u001b[39mget(cmd_name)\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cmd:\n\u001b[0;32m--> 132\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcmd\u001b[49m\u001b[43m(\u001b[49m\u001b[43mothers\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/scripts/jupysql/src/sql/cmd/tables.py:30\u001b[0m, in 
\u001b[0;36mtables\u001b[0;34m(others)\u001b[0m\n\u001b[1;32m 26\u001b[0m parser\u001b[38;5;241m.\u001b[39madd_argument(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m-s\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m--schema\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28mtype\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mstr\u001b[39m, help\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSchema name\u001b[39m\u001b[38;5;124m\"\u001b[39m, required\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 28\u001b[0m args \u001b[38;5;241m=\u001b[39m parser\u001b[38;5;241m.\u001b[39mparse_args(others)\n\u001b[0;32m---> 30\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minspect\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_table_names\u001b[49m\u001b[43m(\u001b[49m\u001b[43mschema\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mschema\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/opt/homebrew/anaconda3/envs/jupysql/lib/python3.10/site-packages/ploomber_core/telemetry/telemetry.py:679\u001b[0m, in \u001b[0;36mTelemetry.log_call.._log_call..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 677\u001b[0m result \u001b[38;5;241m=\u001b[39m func(_payload, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 678\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 679\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 680\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 
681\u001b[0m metadata_error \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 682\u001b[0m \u001b[38;5;66;03m# can we log None to posthog?\u001b[39;00m\n\u001b[1;32m 683\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mgetattr\u001b[39m(e, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype_\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 686\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m_payload,\n\u001b[1;32m 687\u001b[0m }\n", + "File \u001b[0;32m~/scripts/jupysql/src/sql/inspect.py:483\u001b[0m, in \u001b[0;36mget_table_names\u001b[0;34m(schema)\u001b[0m\n\u001b[1;32m 480\u001b[0m \u001b[38;5;129m@telemetry\u001b[39m\u001b[38;5;241m.\u001b[39mlog_call()\n\u001b[1;32m 481\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_table_names\u001b[39m(schema\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 482\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Get table names for a given connection\"\"\"\u001b[39;00m\n\u001b[0;32m--> 483\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mTables\u001b[49m\u001b[43m(\u001b[49m\u001b[43mschema\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/scripts/jupysql/src/sql/inspect.py:38\u001b[0m, in \u001b[0;36mTables.__init__\u001b[0;34m(self, schema, conn)\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, schema\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, conn\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m---> 38\u001b[0m inspector \u001b[38;5;241m=\u001b[39m \u001b[43m_get_inspector\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_table \u001b[38;5;241m=\u001b[39m PrettyTable()\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_table\u001b[38;5;241m.\u001b[39mfield_names \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mName\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n", + "File \u001b[0;32m~/scripts/jupysql/src/sql/inspect.py:21\u001b[0m, in \u001b[0;36m_get_inspector\u001b[0;34m(conn)\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m exceptions\u001b[38;5;241m.\u001b[39mRuntimeError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo active connection\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 20\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m inspect(\u001b[43mConnectionManager\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconnection_sqlalchemy\u001b[49m)\n", + "File \u001b[0;32m~/scripts/jupysql/src/sql/connection/connection.py:1126\u001b[0m, in \u001b[0;36mSparkConnectConnection.connection_sqlalchemy\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1120\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m 1121\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mconnection_sqlalchemy\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 1122\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1123\u001b[0m \u001b[38;5;124;03m Raises NotImplementedError since Spark connections don't have a SQLAlchemy\u001b[39;00m\n\u001b[1;32m 1124\u001b[0m \u001b[38;5;124;03m connection object\u001b[39;00m\n\u001b[1;32m 1125\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1126\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThis feature is only available for 
SQLAlchemy connections\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1128\u001b[0m )\n", + "\u001b[0;31mNotImplementedError\u001b[0m: This feature is only available for SQLAlchemy connections" ] - }, - { - "data": { - "text/plain": [] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ - "%sql show views in default" + "%sqlcmd tables" ] }, { @@ -261,7 +249,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 33, "metadata": { "tags": [] }, @@ -321,7 +309,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 34, "metadata": { "tags": [] }, @@ -354,7 +342,7 @@ "data": { "text/plain": [] }, - "execution_count": 9, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } @@ -374,7 +362,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 35, "metadata": { "tags": [] }, @@ -385,7 +373,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 36, "metadata": { "tags": [] }, @@ -418,7 +406,7 @@ "data": { "text/plain": [] }, - "execution_count": 11, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } @@ -431,7 +419,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 37, "metadata": { "tags": [] }, @@ -442,7 +430,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 38, "metadata": { "tags": [] }, @@ -475,7 +463,7 @@ "data": { "text/plain": [] }, - "execution_count": 14, + "execution_count": 38, "metadata": {}, "output_type": "execute_result" } @@ -496,7 +484,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 39, "metadata": { "tags": [] }, @@ -537,7 +525,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 40, "metadata": { "tags": [] }, @@ -570,7 +558,7 @@ "data": { "text/plain": [] }, - "execution_count": 17, + "execution_count": 40, "metadata": {}, "output_type": "execute_result" } @@ -591,7 
+579,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 41, "metadata": { "tags": [] }, @@ -629,12 +617,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The %sqlplot magic command currently does not directly support the `--schema` option for specifying the schema name. To work around this, you can specify the schema in the SQL query itself." + "The %sqlplot magic command is currently not supported by the Spark connection, but you can still plot like this:" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 45, "metadata": { "tags": [] }, @@ -720,13 +708,6 @@ "plt.show()" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Persist - Not Supported" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -831,7 +812,7 @@ ], "metadata": { "kernelspec": { - "display_name": "jupysql", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, diff --git a/src/tests/test_magic.py b/src/tests/test_magic.py index 0502bfb9a..c2dd8f155 100644 --- a/src/tests/test_magic.py +++ b/src/tests/test_magic.py @@ -1702,7 +1702,6 @@ def test_persist_uses_error_handling_method(ip, monkeypatch, cell): ip.push({"df": df}) conn = ConnectionManager.current - print(conn) execute_with_error_handling_mock = Mock(wraps=conn._execute_with_error_handling) monkeypatch.setattr( conn, "_execute_with_error_handling", execute_with_error_handling_mock From c0723d749a5368a99d740f0b7171f7c82659e11d Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Tue, 19 Dec 2023 10:13:56 +0000 Subject: [PATCH 10/29] adjust doc string for close --- doc/integrations/spark.ipynb | 50 ++++++++++++++++++++------------ src/sql/connection/connection.py | 3 +- 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/doc/integrations/spark.ipynb b/doc/integrations/spark.ipynb index ee2451124..55ea62da5 100644 --- a/doc/integrations/spark.ipynb +++ b/doc/integrations/spark.ipynb @@ -209,34 +209,46 @@ }, { 
"cell_type": "code", - "execution_count": 46, + "execution_count": 47, "metadata": { "tags": [] }, "outputs": [ { - "ename": "NotImplementedError", - "evalue": "This feature is only available for SQLAlchemy connections", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[46], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mget_ipython\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_line_magic\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43msqlcmd\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtables\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/opt/homebrew/anaconda3/envs/jupysql/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2456\u001b[0m, in \u001b[0;36mInteractiveShell.run_line_magic\u001b[0;34m(self, magic_name, line, _stack_depth)\u001b[0m\n\u001b[1;32m 2454\u001b[0m kwargs[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlocal_ns\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_local_scope(stack_depth)\n\u001b[1;32m 2455\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuiltin_trap:\n\u001b[0;32m-> 2456\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2458\u001b[0m \u001b[38;5;66;03m# The code below prevents the output from being 
displayed\u001b[39;00m\n\u001b[1;32m 2459\u001b[0m \u001b[38;5;66;03m# when using magics with decorator @output_can_be_silenced\u001b[39;00m\n\u001b[1;32m 2460\u001b[0m \u001b[38;5;66;03m# when the last Python token in the expression is a ';'.\u001b[39;00m\n\u001b[1;32m 2461\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mgetattr\u001b[39m(fn, magic\u001b[38;5;241m.\u001b[39mMAGIC_OUTPUT_CAN_BE_SILENCED, \u001b[38;5;28;01mFalse\u001b[39;00m):\n", - "File \u001b[0;32m~/scripts/jupysql/src/sql/magic_cmd.py:104\u001b[0m, in \u001b[0;36mSqlCmdMagic._validate_execute_inputs\u001b[0;34m(self, line)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m command \u001b[38;5;129;01min\u001b[39;00m COMMANDS_SQLALCHEMY_ONLY:\n\u001b[1;32m 102\u001b[0m support_only_sql_alchemy_connection(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m%sqlcmd \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcommand\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 104\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcommand\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mothers\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 105\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 106\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m exceptions\u001b[38;5;241m.\u001b[39mUsageError(\n\u001b[1;32m 107\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m%sqlcmd has no command: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcommand\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m. 
\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 108\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mValid commands are: \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m 109\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(AVAILABLE_SQLCMD_COMMANDS)\n\u001b[1;32m 110\u001b[0m )\n\u001b[1;32m 111\u001b[0m )\n", - "File \u001b[0;32m~/scripts/jupysql/src/sql/magic_cmd.py:132\u001b[0m, in \u001b[0;36mSqlCmdMagic.execute\u001b[0;34m(self, cmd_name, others, cell, local_ns)\u001b[0m\n\u001b[1;32m 130\u001b[0m cmd \u001b[38;5;241m=\u001b[39m router\u001b[38;5;241m.\u001b[39mget(cmd_name)\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cmd:\n\u001b[0;32m--> 132\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcmd\u001b[49m\u001b[43m(\u001b[49m\u001b[43mothers\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/scripts/jupysql/src/sql/cmd/tables.py:30\u001b[0m, in \u001b[0;36mtables\u001b[0;34m(others)\u001b[0m\n\u001b[1;32m 26\u001b[0m parser\u001b[38;5;241m.\u001b[39madd_argument(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m-s\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m--schema\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28mtype\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mstr\u001b[39m, help\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSchema name\u001b[39m\u001b[38;5;124m\"\u001b[39m, required\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 28\u001b[0m args \u001b[38;5;241m=\u001b[39m parser\u001b[38;5;241m.\u001b[39mparse_args(others)\n\u001b[0;32m---> 30\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43minspect\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_table_names\u001b[49m\u001b[43m(\u001b[49m\u001b[43mschema\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mschema\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/opt/homebrew/anaconda3/envs/jupysql/lib/python3.10/site-packages/ploomber_core/telemetry/telemetry.py:679\u001b[0m, in \u001b[0;36mTelemetry.log_call.._log_call..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 677\u001b[0m result \u001b[38;5;241m=\u001b[39m func(_payload, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 678\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 679\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 680\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 681\u001b[0m metadata_error \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 682\u001b[0m \u001b[38;5;66;03m# can we log None to posthog?\u001b[39;00m\n\u001b[1;32m 683\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mgetattr\u001b[39m(e, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype_\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 686\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m_payload,\n\u001b[1;32m 687\u001b[0m }\n", - "File \u001b[0;32m~/scripts/jupysql/src/sql/inspect.py:483\u001b[0m, in \u001b[0;36mget_table_names\u001b[0;34m(schema)\u001b[0m\n\u001b[1;32m 480\u001b[0m 
\u001b[38;5;129m@telemetry\u001b[39m\u001b[38;5;241m.\u001b[39mlog_call()\n\u001b[1;32m 481\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_table_names\u001b[39m(schema\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 482\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Get table names for a given connection\"\"\"\u001b[39;00m\n\u001b[0;32m--> 483\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mTables\u001b[49m\u001b[43m(\u001b[49m\u001b[43mschema\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/scripts/jupysql/src/sql/inspect.py:38\u001b[0m, in \u001b[0;36mTables.__init__\u001b[0;34m(self, schema, conn)\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, schema\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, conn\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m---> 38\u001b[0m inspector \u001b[38;5;241m=\u001b[39m \u001b[43m_get_inspector\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_table \u001b[38;5;241m=\u001b[39m PrettyTable()\n\u001b[1;32m 41\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_table\u001b[38;5;241m.\u001b[39mfield_names \u001b[38;5;241m=\u001b[39m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mName\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n", - "File \u001b[0;32m~/scripts/jupysql/src/sql/inspect.py:21\u001b[0m, in \u001b[0;36m_get_inspector\u001b[0;34m(conn)\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m exceptions\u001b[38;5;241m.\u001b[39mRuntimeError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo active connection\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 20\u001b[0m 
\u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m inspect(\u001b[43mConnectionManager\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconnection_sqlalchemy\u001b[49m)\n", - "File \u001b[0;32m~/scripts/jupysql/src/sql/connection/connection.py:1126\u001b[0m, in \u001b[0;36mSparkConnectConnection.connection_sqlalchemy\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1120\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m 1121\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mconnection_sqlalchemy\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 1122\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1123\u001b[0m \u001b[38;5;124;03m Raises NotImplementedError since Spark connections don't have a SQLAlchemy\u001b[39;00m\n\u001b[1;32m 1124\u001b[0m \u001b[38;5;124;03m connection object\u001b[39;00m\n\u001b[1;32m 1125\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1126\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThis feature is only available for SQLAlchemy connections\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1128\u001b[0m )\n", - "\u001b[0;31mNotImplementedError\u001b[0m: This feature is only available for SQLAlchemy connections" + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+---------+--------+-----------+\n", + "|namespace|viewName|isTemporary|\n", + "+---------+--------+-----------+\n", + "| |taxi |true |\n", + "+---------+--------+-----------+\n", + "\n" ] + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 47, + "metadata": 
{}, + "output_type": "execute_result" } ], "source": [ - "%sqlcmd tables" + "%sql show views in default" ] }, { diff --git a/src/sql/connection/connection.py b/src/sql/connection/connection.py index 7262065c7..1c04ecc44 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -1134,8 +1134,7 @@ def to_table(self, table_name, data_frame, if_exists, index, schema=None): ) def close(self): - """Close the connection""" - # NOTE: spark is often shared outside sql, allow user to manage closure + """Override of the abstract close as SparkSession is usually shared with pyspark""" pass From 3de5f615a6f1af942a96f78fd67d70e98bca644f Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Tue, 19 Dec 2023 10:19:38 +0000 Subject: [PATCH 11/29] add generic --- src/sql/connection/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sql/connection/connection.py b/src/sql/connection/connection.py index 1c04ecc44..bf7bc44e0 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -1088,7 +1088,7 @@ def __init__(self, payload, connection, alias=None, config=None): # calling init from AbstractConnection must be the last thing we do as it # register the connection - super().__init__(self._connection_class_name) + super().__init__(alias=alias or self._connection_class_name) self.name = self._connection_class_name From 0b968dc37b6f6925836d43938ac5d5e3ecf37a1a Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Tue, 19 Dec 2023 14:00:12 +0000 Subject: [PATCH 12/29] integrated better with existing functionality --- doc/integrations/compatibility.md | 15 +- doc/integrations/spark.ipynb | 776 +++++++++++++++++++++++++----- src/sql/connection/connection.py | 4 +- src/sql/magic.py | 8 + src/sql/plot.py | 2 +- src/sql/run/run.py | 5 +- src/sql/run/sparkdataframe.py | 42 ++ 7 files changed, 721 insertions(+), 131 deletions(-) create mode 100644 src/sql/run/sparkdataframe.py diff --git 
a/doc/integrations/compatibility.md b/doc/integrations/compatibility.md index d26199150..1771823f9 100644 --- a/doc/integrations/compatibility.md +++ b/doc/integrations/compatibility.md @@ -120,13 +120,14 @@ These table reflects the compatibility status of JupySQL `>=0.7` - Running queries with `%%sql` ✅ - CTEs with `%%sql --save NAME` ✅ -- Plotting with `%%sqlplot boxplot` ❌ -- Plotting with `%%sqlplot bar` ❌ -- Plotting with `%%sqlplot pie` ❌ -- Plotting with `%%sqlplot histogram` ❌ -- Plotting with `ggplot` API ❌ -- Profiling tables with `%sqlcmd profile` ❌ +- Plotting with `%%sqlplot boxplot` ✅ +- Plotting with `%%sqlplot bar` ✅ +- Plotting with `%%sqlplot pie` ✅ +- Plotting with `%%sqlplot histogram` ✅ +- Plotting with `ggplot` ✅ +- Profiling tables with `%sqlcmd profile` ✅ - Listing tables with `%sqlcmd tables` ❌ - Listing columns with `%sqlcmd columns` ❌ - Parametrized SQL queries via `{{parameter}}` ✅ -- Interactive SQL queries via `--interact` ✅ \ No newline at end of file +- Interactive SQL queries via `--interact` ✅ +- Persiting Dataframes via `--persist` ❌ \ No newline at end of file diff --git a/doc/integrations/spark.ipynb b/doc/integrations/spark.ipynb index 55ea62da5..90e296902 100644 --- a/doc/integrations/spark.ipynb +++ b/doc/integrations/spark.ipynb @@ -22,7 +22,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 14, "metadata": { "tags": [] }, @@ -36,7 +36,7 @@ } ], "source": [ - "%pip install jupysql pyspark==3.4.1 arrow pyarrow==12.0.1 pandas grpc-status --quiet" + "%pip install jupysql pyspark==3.4.1 arrow pyarrow==12.0.1 pandas grpcio-status --quiet" ] }, { @@ -51,7 +51,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 15, "metadata": { "tags": [] }, @@ -60,7 +60,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "8e3831298eb9dce9d26b2f2d3df5a1a64634a19976f4eaeabf92c45afd478a39\n" + "12f699ee8e8e35ab10186f3c39024a7e443691bb4213e56ca3c2e90cd80daf1b\n" ] } ], @@ -89,7 +89,7 @@ }, { 
"cell_type": "code", - "execution_count": 27, + "execution_count": 16, "metadata": { "tags": [] }, @@ -115,7 +115,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -139,7 +139,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 18, "metadata": { "tags": [] }, @@ -160,7 +160,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 19, "metadata": { "tags": [] }, @@ -180,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 20, "metadata": { "tags": [] }, @@ -209,10 +209,80 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 22, "metadata": { "tags": [] }, + "outputs": [ + { + "data": { + "text/html": [ + "Running query in 'SparkSession'" + ], + "text/plain": [ + "Running query in 'SparkSession'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
namespaceviewNameisTemporary
taxiTrue
" + ], + "text/plain": [ + "+-----------+----------+-------------+\n", + "| namespace | viewName | isTemporary |\n", + "+-----------+----------+-------------+\n", + "| | taxi | True |\n", + "+-----------+----------+-------------+" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%sql show views in default" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can turn on `lazy_spark` to avoid executing spark plan and return a Spark Dataframe" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "%config SqlMagic.lazy_spark = True" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, "outputs": [ { "data": { @@ -242,7 +312,7 @@ "data": { "text/plain": [] }, - "execution_count": 47, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -251,6 +321,15 @@ "%sql show views in default" ] }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "%config SqlMagic.lazy_spark = False" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -261,7 +340,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 25, "metadata": { "tags": [] }, @@ -308,7 +387,7 @@ ], "source": [ "df = %sql select * from taxi\n", - "df.printSchema()" + "df.sqlaproxy.dataframe.printSchema()" ] }, { @@ -321,7 +400,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 26, "metadata": { "tags": [] }, @@ -338,23 +417,31 @@ "metadata": {}, "output_type": "display_data" }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+--------+\n", - "|count(1)|\n", - "+--------+\n", - "|10000 |\n", - "+--------+\n", - "\n" - ] - }, { "data": { - "text/plain": [] + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
count(1)
10000
" + ], + "text/plain": [ + "+----------+\n", + "| count(1) |\n", + "+----------+\n", + "| 10000 |\n", + "+----------+" + ] }, - "execution_count": 34, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -374,7 +461,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 27, "metadata": { "tags": [] }, @@ -385,7 +472,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 28, "metadata": { "tags": [] }, @@ -402,23 +489,31 @@ "metadata": {}, "output_type": "display_data" }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+--------+\n", - "|count(1)|\n", - "+--------+\n", - "|9476 |\n", - "+--------+\n", - "\n" - ] - }, { "data": { - "text/plain": [] + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
count(1)
9476
" + ], + "text/plain": [ + "+----------+\n", + "| count(1) |\n", + "+----------+\n", + "| 9476 |\n", + "+----------+" + ] }, - "execution_count": 36, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -431,7 +526,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 29, "metadata": { "tags": [] }, @@ -442,7 +537,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 30, "metadata": { "tags": [] }, @@ -459,23 +554,31 @@ "metadata": {}, "output_type": "display_data" }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+--------+\n", - "|count(1)|\n", - "+--------+\n", - "|642 |\n", - "+--------+\n", - "\n" - ] - }, { "data": { - "text/plain": [] + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
count(1)
642
" + ], + "text/plain": [ + "+----------+\n", + "| count(1) |\n", + "+----------+\n", + "| 642 |\n", + "+----------+" + ] }, - "execution_count": 38, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -496,7 +599,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 31, "metadata": { "tags": [] }, @@ -537,7 +640,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 32, "metadata": { "tags": [] }, @@ -554,23 +657,35 @@ "metadata": {}, "output_type": "display_data" }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "+------------------+------------------+------------------+\n", - "|min(trip_distance)|avg(trip_distance)|max(trip_distance)|\n", - "+------------------+------------------+------------------+\n", - "|0.0 |3.1091381872213963|18.46 |\n", - "+------------------+------------------+------------------+\n", - "\n" - ] - }, { "data": { - "text/plain": [] + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
min(trip_distance)avg(trip_distance)max(trip_distance)
0.03.109138187221396318.46
" + ], + "text/plain": [ + "+--------------------+--------------------+--------------------+\n", + "| min(trip_distance) | avg(trip_distance) | max(trip_distance) |\n", + "+--------------------+--------------------+--------------------+\n", + "| 0.0 | 3.1091381872213963 | 18.46 |\n", + "+--------------------+--------------------+--------------------+" + ] }, - "execution_count": 40, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -591,7 +706,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 33, "metadata": { "tags": [] }, @@ -616,6 +731,326 @@ "print(query)" ] }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
Following statistics are not available in\n", + " SparkSession: STD, 25%, 50%, 75%
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
VendorIDtpep_pickup_datetimetpep_dropoff_datetimepassenger_counttrip_distanceRatecodeIDstore_and_fwd_flagPULocationIDDOLocationIDpayment_typefare_amountextramta_taxtip_amounttolls_amountimprovement_surchargetotal_amountcongestion_surchargeairport_fee
count1000010000100001000010000100001000010000100001000010000100001000010000100001000010000100000
unique287668745712436217323042288350418395930
topnan2021-01-01 00:41:192021-01-02 00:00:00nannannanNnannannannannannannannannannannanNone
freqnan47nannannan9808nannannannannannannannannannannan0
mean1.6901nannan1.50803.10021.0712nan158.5551154.72961.381911.88220.82590.48641.78460.22460.294516.96962.1063nan
std0.4625nannan1.13543.59701.0755nan70.928875.25040.555210.84201.11670.10412.43511.27300.057012.50230.9562nan
min1nannan0.00.01.0nan111-100.0-0.5-0.5-1.07-6.12-0.3-100.3-2.5nan
25%1.0000nannan1.00001.04001.0000nan100.000083.00001.00006.00000.00000.50000.00000.00000.300010.30002.5000nan
50%2.0000nannan1.00001.93001.0000nan152.0000151.00001.00008.50000.50000.50001.54000.00000.300013.55002.5000nan
75%2.0000nannan2.00003.60001.0000nan234.0000234.00002.000013.50002.50000.50002.65000.00000.300019.30002.5000nan
max2nannan6.045.9299.0nan2652654121.03.50.580.025.50.3137.762.5nan
" + ], + "text/plain": [ + "+--------+----------+----------------------+-----------------------+-----------------+---------------+------------+--------------------+--------------+--------------+--------------+-------------+--------+---------+------------+--------------+-----------------------+--------------+----------------------+-------------+\n", + "| | VendorID | tpep_pickup_datetime | tpep_dropoff_datetime | passenger_count | trip_distance | RatecodeID | store_and_fwd_flag | PULocationID | DOLocationID | payment_type | fare_amount | extra | mta_tax | tip_amount | tolls_amount | improvement_surcharge | total_amount | congestion_surcharge | airport_fee |\n", + "+--------+----------+----------------------+-----------------------+-----------------+---------------+------------+--------------------+--------------+--------------+--------------+-------------+--------+---------+------------+--------------+-----------------------+--------------+----------------------+-------------+\n", + "| count | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 10000 | 0 |\n", + "| unique | 2 | 8766 | 8745 | 7 | 1243 | 6 | 2 | 173 | 230 | 4 | 228 | 8 | 3 | 504 | 18 | 3 | 959 | 3 | 0 |\n", + "| top | nan | 2021-01-01 00:41:19 | 2021-01-02 00:00:00 | nan | nan | nan | N | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | None |\n", + "| freq | nan | 4 | 7 | nan | nan | nan | 9808 | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | nan | 0 |\n", + "| mean | 1.6901 | nan | nan | 1.5080 | 3.1002 | 1.0712 | nan | 158.5551 | 154.7296 | 1.3819 | 11.8822 | 0.8259 | 0.4864 | 1.7846 | 0.2246 | 0.2945 | 16.9696 | 2.1063 | nan |\n", + "| std | 0.4625 | nan | nan | 1.1354 | 3.5970 | 1.0755 | nan | 70.9288 | 75.2504 | 0.5552 | 10.8420 | 1.1167 | 0.1041 | 2.4351 | 1.2730 | 0.0570 | 12.5023 | 0.9562 | nan |\n", + "| min | 1 | nan | nan | 0.0 | 0.0 | 1.0 | nan | 1 | 1 | 1 | -100.0 | -0.5 | 
-0.5 | -1.07 | -6.12 | -0.3 | -100.3 | -2.5 | nan |\n", + "| 25% | 1.0000 | nan | nan | 1.0000 | 1.0400 | 1.0000 | nan | 100.0000 | 83.0000 | 1.0000 | 6.0000 | 0.0000 | 0.5000 | 0.0000 | 0.0000 | 0.3000 | 10.3000 | 2.5000 | nan |\n", + "| 50% | 2.0000 | nan | nan | 1.0000 | 1.9300 | 1.0000 | nan | 152.0000 | 151.0000 | 1.0000 | 8.5000 | 0.5000 | 0.5000 | 1.5400 | 0.0000 | 0.3000 | 13.5500 | 2.5000 | nan |\n", + "| 75% | 2.0000 | nan | nan | 2.0000 | 3.6000 | 1.0000 | nan | 234.0000 | 234.0000 | 2.0000 | 13.5000 | 2.5000 | 0.5000 | 2.6500 | 0.0000 | 0.3000 | 19.3000 | 2.5000 | nan |\n", + "| max | 2 | nan | nan | 6.0 | 45.92 | 99.0 | nan | 265 | 265 | 4 | 121.0 | 3.5 | 0.5 | 80.0 | 25.5 | 0.3 | 137.76 | 2.5 | nan |\n", + "+--------+----------+----------------------+-----------------------+-----------------+---------------+------------+--------------------+--------------+--------------+--------------+-------------+--------+---------+------------+--------------+-----------------------+--------------+----------------------+-------------+" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%sqlcmd profile -t taxi" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -634,26 +1069,24 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 35, "metadata": { "tags": [] }, "outputs": [ { "data": { - "text/html": [ - "Running query in 'SparkSession'" - ], "text/plain": [ - "Running query in 'SparkSession'" + "" ] }, + "execution_count": 35, "metadata": {}, - "output_type": "display_data" + "output_type": "execute_result" }, { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -663,33 +1096,53 @@ } ], "source": [ - "result = %sql SELECT trip_distance FROM taxi\n", - "\n", - "import matplotlib.pyplot as plt\n", - "\n", - "data = result.toPandas()\n", - "\n", - "plt.hist(data[\"trip_distance\"])\n", - "plt.xlabel(\"Trip Distance\")\n", - "plt.ylabel(\"Frequency\")\n", - "plt.title(\"Histogram of Trip Distance\")\n", - "plt.show()" + "%sqlplot histogram --table taxi --column trip_distance --bins 10" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 36, "metadata": { "tags": [] }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%sqlplot boxplot --table taxi --column trip_distance" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, "outputs": [ { "data": { "text/html": [ - "Running query in 'SparkSession'" + "Removing NULLs, if there exists any from payment_type" ], "text/plain": [ - "Running query in 'SparkSession'" + "Removing NULLs, if there exists any from payment_type" ] }, "metadata": {}, @@ -697,7 +1150,17 @@ }, { "data": { - "image/png": "", + "text/plain": [ + "" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", "text/plain": [ "
" ] @@ -707,17 +1170,88 @@ } ], "source": [ - "result = %sql SELECT trip_distance FROM taxi\n", - "\n", - "import matplotlib.pyplot as plt\n", - "\n", - "data = result.toPandas()\n", - "\n", - "plt.boxplot(data[\"trip_distance\"])\n", - "plt.xlabel(\"Trip Distance\")\n", - "plt.ylabel(\"Value\")\n", - "plt.title(\"Boxplot of Trip Distance\")\n", - "plt.show()" + "%sqlplot bar --table taxi --column payment_type" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Removing NULLs, if there exists any from payment_type" + ], + "text/plain": [ + "Removing NULLs, if there exists any from payment_type" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%sqlplot pie --table taxi --column payment_type" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "from sql.ggplot import ggplot, aes, geom_boxplot, geom_histogram, facet_wrap" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "(ggplot(table=\"taxi\", mapping=aes(x=\"trip_distance\")) + geom_histogram(bins=10))" ] }, { @@ -732,7 +1266,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 41, "metadata": { "tags": [] }, @@ -741,9 +1275,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES\n", - "4a7133e0b70b wh1isper/sparglim-server \"tini -- sparglim-se…\" 4 minutes ago Up 4 minutes 0.0.0.0:4040->4040/tcp, 0.0.0.0:15002->15002/tcp spark\n", - "f019407c6426 docker.dev.slicelife.com/onelogin-aws-assume-role:stable \"onelogin-aws-assume…\" 2 weeks ago Up 2 weeks heuristic_tu\n" + "CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES\n", + "12f699ee8e8e wh1isper/sparglim-server \"tini -- sparglim-se…\" About a minute ago Up About a minute 0.0.0.0:4040->4040/tcp, 0.0.0.0:15002->15002/tcp spark\n", + "f019407c6426 docker.dev.slicelife.com/onelogin-aws-assume-role:stable \"onelogin-aws-assume…\" 2 weeks ago Up 2 weeks heuristic_tu\n" ] } ], @@ -753,7 +1287,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 42, "metadata": { "tags": [] }, @@ -765,7 +1299,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 43, "metadata": { "tags": [] }, @@ -774,7 +1308,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Container id: 4a7133e0b70b\n" + "Container id: 12f699ee8e8e\n" ] } ], @@ -785,7 +1319,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 44, "metadata": { "tags": [] }, @@ -794,7 +1328,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "4a7133e0b70b\n" + "12f699ee8e8e\n" ] } ], @@ -804,7 +1338,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 45, "metadata": { "tags": [] }, @@ -813,7 +1347,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "4a7133e0b70b\n" + "12f699ee8e8e\n" ] } ], diff --git 
a/src/sql/connection/connection.py b/src/sql/connection/connection.py index bf7bc44e0..fb2a0159d 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -17,6 +17,8 @@ ProgrammingError, ) +from sql.run.sparkdataframe import handle_spark_dataframe + try: from pyspark.sql.connect.session import SparkSession as CSparkSession from pyspark.sql import SparkSession @@ -1099,7 +1101,7 @@ def dialect(self): def raw_execute(self, query, parameters=None): """Run the query without any pre-processing""" - return self._connection.sql(query) + return handle_spark_dataframe(self._connection.sql(query)) def _get_database_information(self): """ diff --git a/src/sql/magic.py b/src/sql/magic.py index d34d32e2c..32d231d6b 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -147,6 +147,14 @@ class SqlMagic(Magics, Configurable): config=True, help="Verbosity level. 0=minimal, 1=normal, 2=all", ) + lazy_spark = Bool( + default_value=False, + config=True, + help="Whether to evalute using ResultSet which will " + "cause the plan to execute or just return a Spark " + "DataFrame plan allowing lazy execution but still " + "validating schemas.", + ) named_parameters = Bool( default_value=False, config=True, diff --git a/src/sql/plot.py b/src/sql/plot.py index 10e7e7895..d61f3d2f6 100644 --- a/src/sql/plot.py +++ b/src/sql/plot.py @@ -271,7 +271,7 @@ def _min_max(conn, table, column, with_=None, use_backticks=False): template = Template(template_) query = template.render(table=table, column=column) - + y = conn.execute(query, with_) min_, max_ = conn.execute(query, with_).fetchone() return min_, max_ diff --git a/src/sql/run/run.py b/src/sql/run/run.py index f59ef96d2..3b3c2c719 100644 --- a/src/sql/run/run.py +++ b/src/sql/run/run.py @@ -3,6 +3,7 @@ from sql import exceptions, display from sql.run.resultset import ResultSet from sql.run.pgspecial import handle_postgres_special +from sql.run.sparkdataframe import handle_spark_dataframe # TODO: conn also has 
access to config, we should clean this up to provide a clean @@ -50,7 +51,9 @@ def run_statements(conn, sql, config, parameters=None): result = handle_postgres_special(conn, statement) if is_spark(conn.dialect): - return conn.raw_execute(statement, parameters=parameters) + result = conn.raw_execute(statement, parameters=parameters) + if config.lazy_spark: + return conn.raw_execute(statement, parameters=parameters).dataframe # regular query else: diff --git a/src/sql/run/sparkdataframe.py b/src/sql/run/sparkdataframe.py new file mode 100644 index 000000000..4ac86b35e --- /dev/null +++ b/src/sql/run/sparkdataframe.py @@ -0,0 +1,42 @@ +try: + from pyspark.sql import DataFrame + from pyspark.sql.connect.dataframe import DataFrame as CDataFrame +except ModuleNotFoundError: + DataFrame = None + CDataFrame = None + +from sql import exceptions + + +def handle_spark_dataframe(dataframe, should_cache=False): + """Execute a ResultSet sqlaproxy using pysark module.""" + if not DataFrame and not CDataFrame: + raise exceptions.MissingPackageError("pysark not installed") + + return FakeResultProxy(dataframe, dataframe.columns, should_cache) + + +class FakeResultProxy(object): + """A fake class that pretends to behave like the ResultProxy from + SqlAlchemy. 
+ """ + + dataframe = None + + def __init__(self, dataframe: DataFrame, headers, should_cache): + self.dataframe = dataframe + self.fetchall = dataframe.collect + self.rowcount = dataframe.count + self.keys = lambda: headers + self.returns_rows = True + if should_cache: + self.dataframe.cache() + + def fetchmany(self, size): + return self.dataframe.take(size) + + def fetchone(self): + return self.dataframe.head() + + def close(self): + self.dataframe.unpersist() From e6a25a780baed0ce617482882a958ab448ee800e Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Wed, 20 Dec 2023 00:23:38 +0000 Subject: [PATCH 13/29] finishing integration tests --- CHANGELOG.md | 1 + doc/integrations/compatibility.md | 4 +- doc/integrations/spark.ipynb | 16 ++-- setup.py | 1 + src/sql/_testing.py | 4 + src/sql/connection/connection.py | 12 ++- src/sql/error_handler.py | 4 + src/sql/run/resultset.py | 11 ++- src/sql/run/sparkdataframe.py | 2 +- src/sql/stats.py | 1 - src/sql/util.py | 2 + src/tests/integration/conftest.py | 40 ++++++++- .../integration/test_generic_db_operations.py | 83 ++++++++++++++++++- src/tests/integration/test_stats.py | 22 ++++- 14 files changed, 182 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 69dfdc271..0682b5139 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ # CHANGELOG ## 0.10.6dev +* [Feature] Add Spark Connection as a dialect for Jupysql (#965) * [Fix] Fix error when `%sql` includes a query with negative numbers (#958) ## 0.10.5 (2023-12-11) diff --git a/doc/integrations/compatibility.md b/doc/integrations/compatibility.md index 1771823f9..67d1e9258 100644 --- a/doc/integrations/compatibility.md +++ b/doc/integrations/compatibility.md @@ -120,7 +120,7 @@ These table reflects the compatibility status of JupySQL `>=0.7` - Running queries with `%%sql` ✅ - CTEs with `%%sql --save NAME` ✅ -- Plotting with `%%sqlplot boxplot` ✅ +- Plotting with `%%sqlplot boxplot` ❓ - Plotting with `%%sqlplot bar` ✅ - Plotting with 
`%%sqlplot pie` ✅ - Plotting with `%%sqlplot histogram` ✅ @@ -130,4 +130,4 @@ These table reflects the compatibility status of JupySQL `>=0.7` - Listing columns with `%sqlcmd columns` ❌ - Parametrized SQL queries via `{{parameter}}` ✅ - Interactive SQL queries via `--interact` ✅ -- Persiting Dataframes via `--persist` ❌ \ No newline at end of file +- Persiting Dataframes via `--persist` ✅ \ No newline at end of file diff --git a/doc/integrations/spark.ipynb b/doc/integrations/spark.ipynb index 90e296902..fa9f14c46 100644 --- a/doc/integrations/spark.ipynb +++ b/doc/integrations/spark.ipynb @@ -708,6 +708,7 @@ "cell_type": "code", "execution_count": 33, "metadata": { + "scrolled": true, "tags": [] }, "outputs": [ @@ -731,6 +732,13 @@ "print(query)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Profiling" + ] + }, { "cell_type": "code", "execution_count": 34, @@ -1059,14 +1067,6 @@ "## Plotting" ] }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The %sqlplot magic command currently supported by the spark connection but you can still plot as like this:" - ] - }, { "cell_type": "code", "execution_count": 35, diff --git a/setup.py b/setup.py index 50c5a19ce..e1ceb01d9 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,7 @@ "polars==0.17.2", # 04/18/23 this breaks our CI "pyarrow", "pyspark", + "grpcio-status", "invoke", "pkgmt", "twine", diff --git a/src/sql/_testing.py b/src/sql/_testing.py index 14a5e9675..041994e28 100644 --- a/src/sql/_testing.py +++ b/src/sql/_testing.py @@ -210,6 +210,10 @@ def get_tmp_dir(): "docker_ct": None, "query": {}, }, + "spark": { + "alias": "SparkSession", + "drivername": "SparkSession", + }, "clickhouse": { "drivername": "clickhouse+native", "username": "username", diff --git a/src/sql/connection/connection.py b/src/sql/connection/connection.py index fb2a0159d..8490fe42f 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -1130,9 
+1130,15 @@ def connection_sqlalchemy(self): ) def to_table(self, table_name, data_frame, if_exists, index, schema=None): - raise exceptions.NotImplementedError( - "--persist/--persist-replace is not available for Spark connections" - " (only available for SQLAlchemy connections)" + mode = ( + "overwrite" + if if_exists == "replace" + else "append" + if if_exists == "append" + else "error" + ) + self._connection.createDataFrame(data_frame).write.mode(mode).saveAsTable( + f"{schema}.{table_name}" if schema else table_name ) def close(self): diff --git a/src/sql/error_handler.py b/src/sql/error_handler.py index 0eac9292a..ef1def3f5 100644 --- a/src/sql/error_handler.py +++ b/src/sql/error_handler.py @@ -107,8 +107,12 @@ def handle_exception(error, query=None, short_error=True): if util.is_sqlalchemy_error(error) or util.is_non_sqlalchemy_error(error): detailed_message, error_type = _detailed_message_with_error_type(error, query) if short_error: + print("YEY") _raise_error(error, detailed_message, error_type) else: + print("NEH") _display_error_msg_with_trace(error, detailed_message) else: + print("sadge") + print(str(error)) raise error diff --git a/src/sql/run/resultset.py b/src/sql/run/resultset.py index 8451150aa..15a77c775 100644 --- a/src/sql/run/resultset.py +++ b/src/sql/run/resultset.py @@ -434,7 +434,11 @@ def fetchmany(self, size): raise RuntimeError(f"Error running the query: {str(e)}") from e self.mark_fetching_as_done() return - + # spark doesn't support curser + if hasattr(self._sqlaproxy, "dataframe"): + self._results = [] + self._pretty_table.clear() + print(self._conn) self._extend_results(returned) if len(returned) < size: @@ -458,6 +462,9 @@ def fetch_for_repr_if_needed(self): def fetchall(self): if not self._done_fetching(): + if hasattr(self._sqlaproxy, "dataframe"): + self._results = [] + self._pretty_table.clear() self._extend_results(self.sqlaproxy.fetchall()) self.mark_fetching_as_done() @@ -500,6 +507,8 @@ def _convert_to_data_frame( # maybe 
create accessors in the connection objects? if result_set._conn.is_dbapi_connection: native_connection = result_set.sqlaproxy + elif hasattr(result_set.sqlaproxy, "dataframe"): + return result_set.sqlaproxy.dataframe.toPandas() else: native_connection = result_set._conn._connection.connection diff --git a/src/sql/run/sparkdataframe.py b/src/sql/run/sparkdataframe.py index 4ac86b35e..9f8d96bfa 100644 --- a/src/sql/run/sparkdataframe.py +++ b/src/sql/run/sparkdataframe.py @@ -23,7 +23,7 @@ class FakeResultProxy(object): dataframe = None - def __init__(self, dataframe: DataFrame, headers, should_cache): + def __init__(self, dataframe, headers, should_cache): self.dataframe = dataframe self.fetchall = dataframe.collect self.rowcount = dataframe.count diff --git a/src/sql/stats.py b/src/sql/stats.py index 0fc12e87f..b03252154 100644 --- a/src/sql/stats.py +++ b/src/sql/stats.py @@ -45,7 +45,6 @@ def _summary_stats_one_by_one(conn, table, column, with_=None): other = list(conn.execute(query, with_).fetchone()) keys = ["q1", "med", "q3", "mean", "N"] - return {k: float(v) for k, v in zip(keys, percentiles + other)} diff --git a/src/sql/util.py b/src/sql/util.py index b0d9c0c3b..0956ac68c 100644 --- a/src/sql/util.py +++ b/src/sql/util.py @@ -543,6 +543,8 @@ def is_non_sqlalchemy_error(error): "pyodbc.ProgrammingError", # Clickhouse errors "DB::Exception:", + # Pyspark + "UNRESOLVED_ROUTINE", ] return any(msg in str(error) for msg in specific_db_errors) diff --git a/src/tests/integration/conftest.py b/src/tests/integration/conftest.py index d92fe4d70..a07241971 100644 --- a/src/tests/integration/conftest.py +++ b/src/tests/integration/conftest.py @@ -290,10 +290,46 @@ def setup_duckDB_native(test_table_name_dict): @pytest.fixture(scope="session") -def setup_spark(): - spark = SparkSession.Builder.master("local").getOrCreate() +def setup_spark(test_table_name_dict): + import os + import shutil + + os.environ["PYSPARK_PYTHON"] = os.environ.get("CONDA_PYTHON_EXE") + 
os.environ["PYSPARK_DRIVER_PYTHON"] = os.environ.get("CONDA_PYTHON_EXE") + spark = SparkSession.builder.master("local[1]").enableHiveSupport().getOrCreate() + load_generic_testing_data_spark(spark, test_table_name_dict) yield spark spark.stop() + shutil.rmtree("metastore_db") + shutil.rmtree("spark-warehouse") + os.remove("derby.log") + + +def load_generic_testing_data_spark(spark: SparkSession, test_table_name_dict): + spark.createDataFrame( + pd.DataFrame( + {"taxi_driver_name": ["Eric Ken", "John Smith", "Kevin Kelly"] * 15} + ) + ).createOrReplaceTempView(test_table_name_dict["taxi"]) + spark.createDataFrame( + pd.DataFrame({"x": range(0, 5), "y": range(5, 10)}) + ).createOrReplaceTempView(test_table_name_dict["plot_something"]) + spark.createDataFrame( + pd.DataFrame({"numbers_elements": [1, 2, 3] * 20}) + ).createOrReplaceTempView(test_table_name_dict["numbers"]) + + +@pytest.fixture +def ip_with_spark(ip_empty, setup_spark): + configKey = "spark" + alias = "SparkSession" + + ip_empty.push({"conn": setup_spark}) + # Select database engine, use different sqlite database endpoint + ip_empty.run_cell("%sql " + "conn" + " --alias " + alias) + yield ip_empty + # Disconnect database + ip_empty.run_cell("%sql -x " + alias) def load_generic_testing_data_duckdb_native(ip, test_table_name_dict): diff --git a/src/tests/integration/test_generic_db_operations.py b/src/tests/integration/test_generic_db_operations.py index 47325ed16..2255c4485 100644 --- a/src/tests/integration/test_generic_db_operations.py +++ b/src/tests/integration/test_generic_db_operations.py @@ -21,6 +21,7 @@ "ip_with_Snowflake", "ip_with_oracle", "ip_with_clickhouse", + "ip_with_spark", ] @@ -54,6 +55,7 @@ def mock_log_api(monkeypatch): ("ip_with_clickhouse", "", "LIMIT 3"), ("ip_with_oracle", "", "FETCH FIRST 3 ROWS ONLY"), ("ip_with_MSSQL", "TOP 3", ""), + ("ip_with_spark", "", "LIMIT 3"), ], ) def test_run_query( @@ -93,6 +95,7 @@ def test_run_query( "ip_with_Snowflake", "ip_with_redshift", 
"ip_with_clickhouse", + "ip_with_spark", ], ) def test_handle_multiple_open_result_sets( @@ -151,6 +154,7 @@ def test_handle_multiple_open_result_sets( "No engine for table " ), ), + ("ip_with_spark", "--no-index"), ], ) def test_create_table_with_indexed_df( @@ -218,6 +222,7 @@ def get_connection_count(ip_with_dynamic_db): ("ip_with_MSSQL", 1), ("ip_with_Snowflake", 1), ("ip_with_clickhouse", 1), + ("ip_with_spark", 1), ], ) def test_active_connection_number(ip_with_dynamic_db, expected, request): @@ -273,6 +278,7 @@ def test_close_and_connect( ("ip_with_Snowflake", "snowflake", "snowflake"), ("ip_with_oracle", "oracle", "oracledb"), ("ip_with_clickhouse", "clickhouse", "native"), + ("ip_with_spark", "spark", "SparkSession"), ], ) def test_telemetry_execute_command_has_connection_info( @@ -337,6 +343,7 @@ def test_telemetry_execute_command_has_connection_info( ("ip_with_Snowflake"), ("ip_with_duckDB_native"), ("ip_with_redshift"), + ("ip_with_spark"), pytest.param( "ip_with_MSSQL", marks=pytest.mark.xfail(reason="sqlglot does not support SQL server"), @@ -419,6 +426,9 @@ def test_sqlplot_histogram(ip_with_dynamic_db, cell, request, test_table_name_di reason="Plotting from snippet not working in clickhouse" ), ), + pytest.param( + "ip_with_spark", marks=pytest.mark.xfail(reason=BOX_PLOT_FAIL_REASON) + ), ], ) def test_sqlplot_boxplot(ip_with_dynamic_db, cell, request, test_table_name_dict): @@ -442,6 +452,7 @@ def test_sqlplot_boxplot(ip_with_dynamic_db, cell, request, test_table_name_dict "ip_with_duckDB", "ip_with_redshift", "ip_with_MSSQL", + "ip_with_spark", ], ) def test_sqlplot_bar(ip_with_dynamic_db, request, test_table_name_dict): @@ -464,7 +475,13 @@ def test_sqlplot_bar(ip_with_dynamic_db, request, test_table_name_dict): @pytest.mark.parametrize( "ip_with_dynamic_db", - ["ip_with_postgreSQL", "ip_with_duckDB", "ip_with_redshift", "ip_with_MSSQL"], + [ + "ip_with_postgreSQL", + "ip_with_duckDB", + "ip_with_redshift", + "ip_with_MSSQL", + "ip_with_spark", + 
], ) def test_sqlplot_pie(ip_with_dynamic_db, request, test_table_name_dict): plt.cla() @@ -517,6 +534,10 @@ def test_sqlplot_pie(ip_with_dynamic_db, request, test_table_name_dict): reason="Plotting from snippet not working in clickhouse" ), ), + pytest.param( + "ip_with_spark", + marks=pytest.mark.xfail(reason="Schema not implemented"), + ), ], ) def test_sqlplot_using_schema(ip_with_dynamic_db, request): @@ -569,6 +590,10 @@ def test_sqlplot_using_schema(ip_with_dynamic_db, request): ("ip_with_Snowflake"), ("ip_with_oracle"), ("ip_with_clickhouse"), + pytest.param( + "ip_with_spark", + marks=pytest.mark.xfail(reason="not supported yet for sparkconnections"), + ), ], ) def test_sqlcmd_test(ip_with_dynamic_db, request, test_table_name_dict): @@ -604,6 +629,7 @@ def test_sqlcmd_test(ip_with_dynamic_db, request, test_table_name_dict): ), ("ip_with_oracle"), ("ip_with_clickhouse"), + ("ip_with_spark"), ], ) def test_profile_data_mismatch(ip_with_dynamic_db, request, capsys): @@ -786,6 +812,25 @@ def test_profile_data_mismatch(ip_with_dynamic_db, request, capsys): }, "Following statistics are not available in", ), + ( + "ip_with_spark", + "taxi", + ["taxi_driver_name"], + { + "count": [45], + "mean": [math.nan], + "min": ["Eric Ken"], + "max": ["Kevin Kelly"], + "unique": [3], + "freq": [15], + "top": ["Eric Ken"], + "std": [math.nan], + "25%": [math.nan], + "50%": [math.nan], + "75%": [math.nan], + }, + None, + ), ], ) def test_sqlcmd_profile( @@ -847,6 +892,10 @@ def test_sqlcmd_profile( ("ip_with_MSSQL"), ("ip_with_Snowflake"), ("ip_with_clickhouse"), + pytest.param( + "ip_with_spark", + marks=pytest.mark.xfail(reason="Not Implemented"), + ), ], ) def test_sqlcmd_columns(ip_with_dynamic_db, table, request, test_table_name_dict): @@ -873,6 +922,10 @@ def test_sqlcmd_columns(ip_with_dynamic_db, table, request, test_table_name_dict ("ip_with_MSSQL"), ("ip_with_Snowflake"), ("ip_with_clickhouse"), + pytest.param( + "ip_with_spark", + marks=pytest.mark.xfail(reason="Not 
Implemented"), + ), ], ) def test_sqlcmd_tables(ip_with_dynamic_db, request): @@ -927,6 +980,7 @@ def test_sql_query(ip_with_dynamic_db, cell, request, test_table_name_dict): "ip_with_Snowflake", "ip_with_oracle", "ip_with_clickhouse", + "ip_with_spark", ], ) def test_sql_query_cte(ip_with_dynamic_db, request, test_table_name_dict, cell): @@ -957,6 +1011,10 @@ def test_sql_query_cte(ip_with_dynamic_db, request, test_table_name_dict, cell): "ip_with_clickhouse", marks=pytest.mark.xfail(reason="Not yet implemented"), ), + pytest.param( + "ip_with_spark", + marks=pytest.mark.xfail(reason="Not yet implemented"), + ), ], ) def test_sql_error_suggests_using_cte(ip_with_dynamic_db, request): @@ -987,6 +1045,7 @@ def test_sql_error_suggests_using_cte(ip_with_dynamic_db, request): "ip_with_MSSQL", "ip_with_oracle", "ip_with_clickhouse", + "ip_with_spark", ], ) def test_results_sets_are_closed(ip_with_dynamic_db, request, test_table_name_dict): @@ -1024,6 +1083,7 @@ def test_results_sets_are_closed(ip_with_dynamic_db, request, test_table_name_di "ip_with_MSSQL", "ip_with_oracle", "ip_with_clickhouse", + "ip_with_spark", ], ) @pytest.mark.parametrize( @@ -1150,6 +1210,7 @@ def test_autocommit_retrieve_existing_resultssets_duckdb_from( CREATE_TABLE, marks=pytest.mark.xfail(reason="Not working yet"), ), + ("ip_with_spark", CREATE_TABLE), ], ) def test_autocommit_create_table_single_cell( @@ -1222,6 +1283,7 @@ def test_autocommit_create_table_single_cell( CREATE_TABLE, marks=pytest.mark.xfail(reason="Not working yet"), ), + ("ip_with_spark", CREATE_TABLE), ], ) def test_autocommit_create_table_multiple_cells( @@ -1408,6 +1470,20 @@ def test_autocommit_create_table_multiple_cells( ["Table with name mysnip does not exist!"], "RuntimeError", ), + ( + "ip_with_spark", + "mysnippet", + [ + "Cannot resolve function `not_a_function` on search path", + ], + "RuntimeError", + ), + ( + "ip_with_spark", + "mysnip", + ["Cannot resolve function `not_a_function` on search path"], + 
"RuntimeError", + ), ], ids=[ "no-typo-postgreSQL", @@ -1428,6 +1504,8 @@ def test_autocommit_create_table_multiple_cells( "with-typo-redshift", "no-typo-duckDB-native", "with-typo-duckDB-native", + "no-typo-spark", + "with-typo-spark", ], ) def test_query_snippet_invalid_function_error_message( @@ -1456,7 +1534,7 @@ def test_query_snippet_invalid_function_error_message( # Save result and test error message result_error = excinfo.value.error_type result_msg = str(excinfo.value) - + print(result_msg) assert error_type == result_error assert all(msg in result_msg for msg in error_msgs) @@ -1502,6 +1580,7 @@ def test_query_snippet_invalid_function_error_message( "No engine for table " ), ), + ("ip_with_spark", "--no-index"), ], ) def test_persist_in_schema(ip_with_dynamic_db, args, request, test_table_name_dict): diff --git a/src/tests/integration/test_stats.py b/src/tests/integration/test_stats.py index d8f93439d..01593f8ce 100644 --- a/src/tests/integration/test_stats.py +++ b/src/tests/integration/test_stats.py @@ -1,7 +1,7 @@ import pytest from sql.stats import _summary_stats -from sql.connection import SQLAlchemyConnection +from sql.connection import SQLAlchemyConnection, SparkConnectConnection @pytest.mark.parametrize( @@ -26,3 +26,23 @@ def test_summary_stats(fixture_name, request, test_table_name_dict): "mean": 2.0, "N": 5.0, } + + +@pytest.mark.parametrize( + "fixture_name", + [ + "setup_spark", + ], +) +def test_summary_stats(fixture_name, request, test_table_name_dict): + conn = SparkConnectConnection(request.getfixturevalue(fixture_name)) + table = test_table_name_dict["plot_something"] + column = "x" + + assert _summary_stats(conn, table, column) == { + "q1": 1.0, + "med": 2.0, + "q3": 3.0, + "mean": 2.0, + "N": 5.0, + } From a9716008146949d894d8978929386a5eaaf71674 Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Wed, 20 Dec 2023 00:32:44 +0000 Subject: [PATCH 14/29] pass config and alias correctly --- src/sql/connection/connection.py | 4 +++- 1 file 
changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sql/connection/connection.py b/src/sql/connection/connection.py index 8490fe42f..c68c5acea 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -267,7 +267,9 @@ def set( elif is_pep249_compliant(descriptor): cls.current = DBAPIConnection(descriptor, config=config, alias=alias) elif is_spark(descriptor): - cls.current = SparkConnectConnection(descriptor) + cls.current = SparkConnectConnection( + descriptor, config=config, alias=alias + ) else: existing = rough_dict_get(cls.connections, descriptor) if existing and existing.alias == alias: From a64b2a314626f0f307195c1d93a3793602188993 Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Wed, 20 Dec 2023 01:42:21 +0000 Subject: [PATCH 15/29] fixed issue with backticks and also implemented fake cursor --- src/sql/connection/connection.py | 2 +- src/sql/error_handler.py | 1 + src/sql/plot.py | 9 ++++++++- src/sql/run/sparkdataframe.py | 7 +++++++ src/sql/util.py | 1 + .../integration/test_generic_db_operations.py | 17 ++++------------- 6 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/sql/connection/connection.py b/src/sql/connection/connection.py index c68c5acea..cfdb85c8d 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -1099,7 +1099,7 @@ def __init__(self, payload, connection, alias=None, config=None): @property def dialect(self): """Returns a string with the SQL dialect name""" - return "spark" + return "spark2" def raw_execute(self, query, parameters=None): """Run the query without any pre-processing""" diff --git a/src/sql/error_handler.py b/src/sql/error_handler.py index ef1def3f5..7792ecbd9 100644 --- a/src/sql/error_handler.py +++ b/src/sql/error_handler.py @@ -50,6 +50,7 @@ def _detailed_message_with_error_type(error, query): "error in your sql syntax", "incorrect syntax", "invalid sql", + "syntax_error", ] not_found_substrings = [ r"(\btable with name\b).+(\bdoes 
not exist\b)", diff --git a/src/sql/plot.py b/src/sql/plot.py index d61f3d2f6..b6e8f9ca4 100644 --- a/src/sql/plot.py +++ b/src/sql/plot.py @@ -268,7 +268,7 @@ def _min_max(conn, table, column, with_=None, use_backticks=False): """ if use_backticks: template_ = template_.replace('"', "`") - + table = table.replace('"', "`") template = Template(template_) query = template.render(table=table, column=column) y = conn.execute(query, with_) @@ -628,6 +628,7 @@ def _histogram( if use_backticks: template_ = template_.replace('"', "`") + table = table.replace('"', "`") template = Template(template_) @@ -663,6 +664,7 @@ def _histogram( if use_backticks: template_ = template_.replace('"', "`") + table = table.replace('"', "`") template = Template(template_) @@ -681,6 +683,7 @@ def _histogram( if use_backticks: template_ = template_.replace('"', "`") + table = table.replace('"', "`") template = Template(template_) @@ -835,6 +838,7 @@ def _bar(table, column, with_=None, conn=None): if use_backticks: template_ = template_.replace('"', "`") + table = table.replace('"', "`") template = Template(template_) query = template.render(table=table, x_=x_, height_=height_) @@ -854,6 +858,7 @@ def _bar(table, column, with_=None, conn=None): if use_backticks: template_ = template_.replace('"', "`") + table = table.replace('"', "`") template = Template(template_) query = template.render(table=table, column=column) @@ -1022,6 +1027,7 @@ def _pie(table, column, with_=None, conn=None): """ if use_backticks: template_ = template_.replace('"', "`") + table = table.replace('"', "`") template = Template(template_) query = template.render(table=table, labels_=labels_, size_=size_) @@ -1037,6 +1043,7 @@ def _pie(table, column, with_=None, conn=None): """ if use_backticks: template_ = template_.replace('"', "`") + table = table.replace('"', "`") template = Template(template_) query = template.render(table=table, column=column) diff --git a/src/sql/run/sparkdataframe.py b/src/sql/run/sparkdataframe.py 
index 9f8d96bfa..8d4df5696 100644 --- a/src/sql/run/sparkdataframe.py +++ b/src/sql/run/sparkdataframe.py @@ -28,6 +28,7 @@ def __init__(self, dataframe, headers, should_cache): self.fetchall = dataframe.collect self.rowcount = dataframe.count self.keys = lambda: headers + self.cursor = FakeCursor(headers) self.returns_rows = True if should_cache: self.dataframe.cache() @@ -40,3 +41,9 @@ def fetchone(self): def close(self): self.dataframe.unpersist() + +class FakeCursor(object): + description = None + + def __init__(self,headers) -> None: + self.description = headers \ No newline at end of file diff --git a/src/sql/util.py b/src/sql/util.py index 0956ac68c..4f7fd6d4f 100644 --- a/src/sql/util.py +++ b/src/sql/util.py @@ -545,6 +545,7 @@ def is_non_sqlalchemy_error(error): "DB::Exception:", # Pyspark "UNRESOLVED_ROUTINE", + "PARSE_SYNTAX_ERROR", ] return any(msg in str(error) for msg in specific_db_errors) diff --git a/src/tests/integration/test_generic_db_operations.py b/src/tests/integration/test_generic_db_operations.py index 2255c4485..3bebaecd1 100644 --- a/src/tests/integration/test_generic_db_operations.py +++ b/src/tests/integration/test_generic_db_operations.py @@ -278,7 +278,7 @@ def test_close_and_connect( ("ip_with_Snowflake", "snowflake", "snowflake"), ("ip_with_oracle", "oracle", "oracledb"), ("ip_with_clickhouse", "clickhouse", "native"), - ("ip_with_spark", "spark", "SparkSession"), + ("ip_with_spark", "spark2", "SparkSession"), ], ) def test_telemetry_execute_command_has_connection_info( @@ -534,10 +534,7 @@ def test_sqlplot_pie(ip_with_dynamic_db, request, test_table_name_dict): reason="Plotting from snippet not working in clickhouse" ), ), - pytest.param( - "ip_with_spark", - marks=pytest.mark.xfail(reason="Schema not implemented"), - ), + "ip_with_spark" ], ) def test_sqlplot_using_schema(ip_with_dynamic_db, request): @@ -590,10 +587,7 @@ def test_sqlplot_using_schema(ip_with_dynamic_db, request): ("ip_with_Snowflake"), ("ip_with_oracle"), 
("ip_with_clickhouse"), - pytest.param( - "ip_with_spark", - marks=pytest.mark.xfail(reason="not supported yet for sparkconnections"), - ), + ("ip_with_spark"), ], ) def test_sqlcmd_test(ip_with_dynamic_db, request, test_table_name_dict): @@ -1011,10 +1005,7 @@ def test_sql_query_cte(ip_with_dynamic_db, request, test_table_name_dict, cell): "ip_with_clickhouse", marks=pytest.mark.xfail(reason="Not yet implemented"), ), - pytest.param( - "ip_with_spark", - marks=pytest.mark.xfail(reason="Not yet implemented"), - ), + "ip_with_spark", ], ) def test_sql_error_suggests_using_cte(ip_with_dynamic_db, request): From 0a9b4de3b019cc55f884823b542fd7a9c831e052 Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Wed, 20 Dec 2023 09:17:35 +0000 Subject: [PATCH 16/29] change configuration name --- doc/integrations/spark.ipynb | 4 ++-- src/sql/magic.py | 8 ++++---- src/sql/run/run.py | 2 +- src/sql/run/sparkdataframe.py | 7 ++++--- src/tests/integration/test_generic_db_operations.py | 2 +- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/doc/integrations/spark.ipynb b/doc/integrations/spark.ipynb index fa9f14c46..99e121c18 100644 --- a/doc/integrations/spark.ipynb +++ b/doc/integrations/spark.ipynb @@ -276,7 +276,7 @@ "metadata": {}, "outputs": [], "source": [ - "%config SqlMagic.lazy_spark = True" + "%config SqlMagic.lazy_execution = True" ] }, { @@ -327,7 +327,7 @@ "metadata": {}, "outputs": [], "source": [ - "%config SqlMagic.lazy_spark = False" + "%config SqlMagic.lazy_execution = False" ] }, { diff --git a/src/sql/magic.py b/src/sql/magic.py index 32d231d6b..eddb344c6 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -147,13 +147,13 @@ class SqlMagic(Magics, Configurable): config=True, help="Verbosity level. 
0=minimal, 1=normal, 2=all", ) - lazy_spark = Bool( + lazy_execution = Bool( default_value=False, config=True, help="Whether to evalute using ResultSet which will " - "cause the plan to execute or just return a Spark " - "DataFrame plan allowing lazy execution but still " - "validating schemas.", + "cause the plan to execute or just return a lazily " + "executed plan allowing validating schemas, " + "without expensive compute.", ) named_parameters = Bool( default_value=False, diff --git a/src/sql/run/run.py b/src/sql/run/run.py index 3b3c2c719..f9b1fd964 100644 --- a/src/sql/run/run.py +++ b/src/sql/run/run.py @@ -52,7 +52,7 @@ def run_statements(conn, sql, config, parameters=None): if is_spark(conn.dialect): result = conn.raw_execute(statement, parameters=parameters) - if config.lazy_spark: + if config.lazy_execution: return conn.raw_execute(statement, parameters=parameters).dataframe # regular query diff --git a/src/sql/run/sparkdataframe.py b/src/sql/run/sparkdataframe.py index 8d4df5696..77ef31854 100644 --- a/src/sql/run/sparkdataframe.py +++ b/src/sql/run/sparkdataframe.py @@ -42,8 +42,9 @@ def fetchone(self): def close(self): self.dataframe.unpersist() + class FakeCursor(object): - description = None + description = None - def __init__(self,headers) -> None: - self.description = headers \ No newline at end of file + def __init__(self, headers) -> None: + self.description = headers diff --git a/src/tests/integration/test_generic_db_operations.py b/src/tests/integration/test_generic_db_operations.py index 3bebaecd1..95722758a 100644 --- a/src/tests/integration/test_generic_db_operations.py +++ b/src/tests/integration/test_generic_db_operations.py @@ -534,7 +534,7 @@ def test_sqlplot_pie(ip_with_dynamic_db, request, test_table_name_dict): reason="Plotting from snippet not working in clickhouse" ), ), - "ip_with_spark" + "ip_with_spark", ], ) def test_sqlplot_using_schema(ip_with_dynamic_db, request): From 189e4a0f6e1c9f06733344d9ad526cadff16d5f9 Mon Sep 17 
00:00:00 2001 From: richard gilmore Date: Wed, 20 Dec 2023 10:06:34 +0000 Subject: [PATCH 17/29] fix env variable error integration tests CI --- doc/api/configuration.md | 16 ++++++++++++++++ doc/integrations/spark.ipynb | 2 +- src/tests/integration/conftest.py | 5 +++-- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/doc/api/configuration.md b/doc/api/configuration.md index e2bb114a5..b7945683e 100644 --- a/doc/api/configuration.md +++ b/doc/api/configuration.md @@ -234,6 +234,22 @@ value enables the ones from previous values plus new ones: - `2`: All feedback - Footer to distinguish pandas/polars data frames from JupySQL's result sets +## `lazy_execution` + +Default: `False` + +Return lazy relation to dataset rather than executing through JupySql. + +```{code-cell} ipython3 +%config SqlMagic.lazy_execution = True +df = %sql SELECT * FROM languages +``` + +```{code-cell} ipython3 +%config SqlMagic.lazy_execution = False +res = %sql SELECT * FROM languages +``` + ## `named_parameters` ```{versionadded} 0.9 diff --git a/doc/integrations/spark.ipynb b/doc/integrations/spark.ipynb index 99e121c18..01940eb0f 100644 --- a/doc/integrations/spark.ipynb +++ b/doc/integrations/spark.ipynb @@ -1221,7 +1221,7 @@ "metadata": {}, "outputs": [], "source": [ - "from sql.ggplot import ggplot, aes, geom_boxplot, geom_histogram, facet_wrap" + "from sql.ggplot import ggplot, aes, geom_histogram" ] }, { diff --git a/src/tests/integration/conftest.py b/src/tests/integration/conftest.py index a07241971..0eb81f2b3 100644 --- a/src/tests/integration/conftest.py +++ b/src/tests/integration/conftest.py @@ -293,9 +293,10 @@ def setup_duckDB_native(test_table_name_dict): def setup_spark(test_table_name_dict): import os import shutil + import sys - os.environ["PYSPARK_PYTHON"] = os.environ.get("CONDA_PYTHON_EXE") - os.environ["PYSPARK_DRIVER_PYTHON"] = os.environ.get("CONDA_PYTHON_EXE") + os.environ["PYSPARK_PYTHON"] = sys.executable + os.environ["PYSPARK_DRIVER_PYTHON"] = 
sys.executable spark = SparkSession.builder.master("local[1]").enableHiveSupport().getOrCreate() load_generic_testing_data_spark(spark, test_table_name_dict) yield spark From dc5e208c51b7496f429c2ca1b8612add66bf0dd8 Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Wed, 20 Dec 2023 10:18:17 +0000 Subject: [PATCH 18/29] fixing lint errors --- src/sql/connection/__init__.py | 2 ++ src/sql/connection/connection.py | 3 ++- src/sql/plot.py | 1 - src/sql/run/run.py | 1 - src/tests/integration/conftest.py | 5 ++--- src/tests/integration/test_stats.py | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/sql/connection/__init__.py b/src/sql/connection/__init__.py index 259b549d4..4d9dfb10a 100644 --- a/src/sql/connection/__init__.py +++ b/src/sql/connection/__init__.py @@ -16,7 +16,9 @@ "ConnectionManager", "SQLAlchemyConnection", "DBAPIConnection", + "SparkConnectConnection", "is_pep249_compliant", + "is_spark", "PLOOMBER_DOCS_LINK_STR", "default_alias_for_engine", "ResultSetCollection", diff --git a/src/sql/connection/connection.py b/src/sql/connection/connection.py index cfdb85c8d..9f4495408 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -1144,7 +1144,8 @@ def to_table(self, table_name, data_frame, if_exists, index, schema=None): ) def close(self): - """Override of the abstract close as SparkSession is usually shared with pyspark""" + """Override of the abstract close as SparkSession is usually + shared with pyspark""" pass diff --git a/src/sql/plot.py b/src/sql/plot.py index b6e8f9ca4..d3cecb759 100644 --- a/src/sql/plot.py +++ b/src/sql/plot.py @@ -271,7 +271,6 @@ def _min_max(conn, table, column, with_=None, use_backticks=False): table = table.replace('"', "`") template = Template(template_) query = template.render(table=table, column=column) - y = conn.execute(query, with_) min_, max_ = conn.execute(query, with_).fetchone() return min_, max_ diff --git a/src/sql/run/run.py b/src/sql/run/run.py index 
f9b1fd964..7dcd5049e 100644 --- a/src/sql/run/run.py +++ b/src/sql/run/run.py @@ -3,7 +3,6 @@ from sql import exceptions, display from sql.run.resultset import ResultSet from sql.run.pgspecial import handle_postgres_special -from sql.run.sparkdataframe import handle_spark_dataframe # TODO: conn also has access to config, we should clean this up to provide a clean diff --git a/src/tests/integration/conftest.py b/src/tests/integration/conftest.py index 0eb81f2b3..cad899583 100644 --- a/src/tests/integration/conftest.py +++ b/src/tests/integration/conftest.py @@ -301,8 +301,8 @@ def setup_spark(test_table_name_dict): load_generic_testing_data_spark(spark, test_table_name_dict) yield spark spark.stop() - shutil.rmtree("metastore_db") - shutil.rmtree("spark-warehouse") + shutil.rmtree("metastore_db", ignore_errors=True) + shutil.rmtree("spark-warehouse", ignore_errors=True) os.remove("derby.log") @@ -322,7 +322,6 @@ def load_generic_testing_data_spark(spark: SparkSession, test_table_name_dict): @pytest.fixture def ip_with_spark(ip_empty, setup_spark): - configKey = "spark" alias = "SparkSession" ip_empty.push({"conn": setup_spark}) diff --git a/src/tests/integration/test_stats.py b/src/tests/integration/test_stats.py index 01593f8ce..fe33b18f5 100644 --- a/src/tests/integration/test_stats.py +++ b/src/tests/integration/test_stats.py @@ -34,7 +34,7 @@ def test_summary_stats(fixture_name, request, test_table_name_dict): "setup_spark", ], ) -def test_summary_stats(fixture_name, request, test_table_name_dict): +def test_summary_stats_spark(fixture_name, request, test_table_name_dict): conn = SparkConnectConnection(request.getfixturevalue(fixture_name)) table = test_table_name_dict["plot_something"] column = "x" From a4ef89b98c09c58e46ad9e9e3e3fda8c12513879 Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Wed, 20 Dec 2023 10:29:28 +0000 Subject: [PATCH 19/29] change log formating --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/CHANGELOG.md b/CHANGELOG.md index 0682b5139..a2113753f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,8 @@ # CHANGELOG ## 0.10.6dev -* [Feature] Add Spark Connection as a dialect for Jupysql (#965) +* [Feature] Add Spark Connection as a dialect for Jupysql [#965](https://github.com/ploomber/jupysql/issues/965) +(by [@gilandose](https://github.com/gilandose)) * [Fix] Fix error when `%sql` includes a query with negative numbers (#958) ## 0.10.5 (2023-12-11) From 45e7b6c363c051b2b9c8ed9a5ae3309a9b304e66 Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Wed, 20 Dec 2023 11:01:41 +0000 Subject: [PATCH 20/29] metadata ipynb --- doc/integrations/spark.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/integrations/spark.ipynb b/doc/integrations/spark.ipynb index 01940eb0f..4f150500d 100644 --- a/doc/integrations/spark.ipynb +++ b/doc/integrations/spark.ipynb @@ -1376,8 +1376,8 @@ }, "myst": { "html_meta": { - "description lang=en": "Query a PostgreSQL database from Jupyter via JupySQL", - "keywords": "jupyter, sql, jupysql, postgres", + "description lang=en": "Query using Spark SQL from Jupyter via JupySQL", + "keywords": "jupyter, sql, jupysql, spark", "property=og:locale": "en_US" } }, From 96173de1ae3f3535c328ab262861679192f6f127 Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Thu, 21 Dec 2023 17:13:34 +0000 Subject: [PATCH 21/29] addressing comments --- doc/integrations/compatibility.md | 2 +- src/sql/error_handler.py | 3 --- src/sql/magic.py | 2 +- src/sql/run/resultset.py | 2 +- src/sql/run/run.py | 7 ++----- src/sql/run/sparkdataframe.py | 10 ++++++---- 6 files changed, 11 insertions(+), 15 deletions(-) diff --git a/doc/integrations/compatibility.md b/doc/integrations/compatibility.md index 67d1e9258..d59760a98 100644 --- a/doc/integrations/compatibility.md +++ b/doc/integrations/compatibility.md @@ -130,4 +130,4 @@ These table reflects the compatibility status of JupySQL `>=0.7` - Listing columns with `%sqlcmd 
columns` ❌ - Parametrized SQL queries via `{{parameter}}` ✅ - Interactive SQL queries via `--interact` ✅ -- Persiting Dataframes via `--persist` ✅ \ No newline at end of file +- Persisting Dataframes via `--persist` ✅ \ No newline at end of file diff --git a/src/sql/error_handler.py b/src/sql/error_handler.py index 7792ecbd9..cedc7ade7 100644 --- a/src/sql/error_handler.py +++ b/src/sql/error_handler.py @@ -108,12 +108,9 @@ def handle_exception(error, query=None, short_error=True): if util.is_sqlalchemy_error(error) or util.is_non_sqlalchemy_error(error): detailed_message, error_type = _detailed_message_with_error_type(error, query) if short_error: - print("YEY") _raise_error(error, detailed_message, error_type) else: - print("NEH") _display_error_msg_with_trace(error, detailed_message) else: - print("sadge") print(str(error)) raise error diff --git a/src/sql/magic.py b/src/sql/magic.py index eddb344c6..ac49bbeba 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -150,7 +150,7 @@ class SqlMagic(Magics, Configurable): lazy_execution = Bool( default_value=False, config=True, - help="Whether to evalute using ResultSet which will " + help="Whether to evaluate using ResultSet which will " "cause the plan to execute or just return a lazily " "executed plan allowing validating schemas, " "without expensive compute.", diff --git a/src/sql/run/resultset.py b/src/sql/run/resultset.py index 15a77c775..e1e55deeb 100644 --- a/src/sql/run/resultset.py +++ b/src/sql/run/resultset.py @@ -434,7 +434,7 @@ def fetchmany(self, size): raise RuntimeError(f"Error running the query: {str(e)}") from e self.mark_fetching_as_done() return - # spark doesn't support curser + # spark doesn't support cursor if hasattr(self._sqlaproxy, "dataframe"): self._results = [] self._pretty_table.clear() diff --git a/src/sql/run/run.py b/src/sql/run/run.py index 7dcd5049e..a1e34aa7d 100644 --- a/src/sql/run/run.py +++ b/src/sql/run/run.py @@ -49,14 +49,11 @@ def run_statements(conn, sql, config, 
parameters=None): if first_word.startswith("\\") and is_postgres_or_redshift(conn.dialect): result = handle_postgres_special(conn, statement) - if is_spark(conn.dialect): - result = conn.raw_execute(statement, parameters=parameters) - if config.lazy_execution: - return conn.raw_execute(statement, parameters=parameters).dataframe - # regular query else: result = conn.raw_execute(statement, parameters=parameters) + if is_spark(conn.dialect) and config.lazy_execution: + return result.dataframe if ( config.feedback >= 1 diff --git a/src/sql/run/sparkdataframe.py b/src/sql/run/sparkdataframe.py index 77ef31854..ba8765141 100644 --- a/src/sql/run/sparkdataframe.py +++ b/src/sql/run/sparkdataframe.py @@ -13,10 +13,10 @@ def handle_spark_dataframe(dataframe, should_cache=False): if not DataFrame and not CDataFrame: raise exceptions.MissingPackageError("pysark not installed") - return FakeResultProxy(dataframe, dataframe.columns, should_cache) + return SparkResultProxy(dataframe, dataframe.columns, should_cache) -class FakeResultProxy(object): +class SparkResultProxy(object): """A fake class that pretends to behave like the ResultProxy from SqlAlchemy. 
""" @@ -28,7 +28,7 @@ def __init__(self, dataframe, headers, should_cache): self.fetchall = dataframe.collect self.rowcount = dataframe.count self.keys = lambda: headers - self.cursor = FakeCursor(headers) + self.cursor = SparkCuror(headers) self.returns_rows = True if should_cache: self.dataframe.cache() @@ -43,7 +43,9 @@ def close(self): self.dataframe.unpersist() -class FakeCursor(object): +class SparkCuror(object): + """Clas to extend to give SqlAlchemy Cursor like behaviour""" + description = None def __init__(self, headers) -> None: From df5925e43845e9139218006a1637edbd13f0790c Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Thu, 21 Dec 2023 17:21:32 +0000 Subject: [PATCH 22/29] update changelog --- CHANGELOG.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a2113753f..6574bcd82 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,12 @@ # CHANGELOG -## 0.10.6dev +## 0.10.7dev + * [Feature] Add Spark Connection as a dialect for Jupysql [#965](https://github.com/ploomber/jupysql/issues/965) (by [@gilandose](https://github.com/gilandose)) + +## 0.10.6 (2023-12-21) + * [Fix] Fix error when `%sql` includes a query with negative numbers (#958) ## 0.10.5 (2023-12-11) From 27dca2b9018e84c42d12ff7c3a1090133cc09540 Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Thu, 21 Dec 2023 17:25:49 +0000 Subject: [PATCH 23/29] changelog --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6574bcd82..f533cb816 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,12 +2,12 @@ ## 0.10.7dev -* [Feature] Add Spark Connection as a dialect for Jupysql [#965](https://github.com/ploomber/jupysql/issues/965) +* [Feature] Add Spark Connection as a dialect for Jupysql ([#965](https://github.com/ploomber/jupysql/issues/965)) (by [@gilandose](https://github.com/gilandose)) ## 0.10.6 (2023-12-21) -* [Fix] Fix error when `%sql` includes a query with negative 
numbers (#958) +* [Fix] Fix error when `%sql` includes a query with negative numbers ([#958](https://github.com/ploomber/jupysql/issues/958)) ## 0.10.5 (2023-12-11) From 743dee85f77a799742381809473eba98f7a09444 Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Thu, 21 Dec 2023 20:39:12 +0000 Subject: [PATCH 24/29] fix row count --- src/sql/run/sparkdataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sql/run/sparkdataframe.py b/src/sql/run/sparkdataframe.py index ba8765141..5dc497744 100644 --- a/src/sql/run/sparkdataframe.py +++ b/src/sql/run/sparkdataframe.py @@ -26,7 +26,7 @@ class SparkResultProxy(object): def __init__(self, dataframe, headers, should_cache): self.dataframe = dataframe self.fetchall = dataframe.collect - self.rowcount = dataframe.count + self.rowcount = dataframe.count() self.keys = lambda: headers self.cursor = SparkCuror(headers) self.returns_rows = True From ac20efd7a893b8a6c97d83cebe6fd3d6f2ec6a92 Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Thu, 21 Dec 2023 20:43:21 +0000 Subject: [PATCH 25/29] spelling --- src/sql/run/sparkdataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sql/run/sparkdataframe.py b/src/sql/run/sparkdataframe.py index 5dc497744..d8f2297ba 100644 --- a/src/sql/run/sparkdataframe.py +++ b/src/sql/run/sparkdataframe.py @@ -44,7 +44,7 @@ def close(self): class SparkCuror(object): - """Clas to extend to give SqlAlchemy Cursor like behaviour""" + """Class to extend to give SqlAlchemy Cursor like behaviour""" description = None From 7bf2b5b2b9a4e4427764ab3a4d4a3dae839cb6e8 Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Thu, 21 Dec 2023 20:43:44 +0000 Subject: [PATCH 26/29] spelling --- src/sql/run/sparkdataframe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sql/run/sparkdataframe.py b/src/sql/run/sparkdataframe.py index d8f2297ba..81644b1e2 100644 --- a/src/sql/run/sparkdataframe.py +++ b/src/sql/run/sparkdataframe.py 
@@ -28,7 +28,7 @@ def __init__(self, dataframe, headers, should_cache): self.fetchall = dataframe.collect self.rowcount = dataframe.count() self.keys = lambda: headers - self.cursor = SparkCuror(headers) + self.cursor = SparkCursor(headers) self.returns_rows = True if should_cache: self.dataframe.cache() @@ -43,7 +43,7 @@ def close(self): self.dataframe.unpersist() -class SparkCuror(object): +class SparkCursor(object): """Class to extend to give SqlAlchemy Cursor like behaviour""" description = None From 36aaa1003f50c75188b6f0cf169d3a7157d53645 Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Thu, 21 Dec 2023 22:15:16 +0000 Subject: [PATCH 27/29] remove pypark dev dependency --- setup.py | 4 ++-- src/sql/connection/connection.py | 29 ++++++++++++++-------- src/tests/test_connection.py | 41 +++++++++++++++++++++++--------- 3 files changed, 51 insertions(+), 23 deletions(-) diff --git a/setup.py b/setup.py index 7207b88d8..51dfc5cbc 100644 --- a/setup.py +++ b/setup.py @@ -35,8 +35,6 @@ "pandas", # previously pinned to 2.0.3 "polars==0.17.2", # 04/18/23 this breaks our CI "pyarrow", - "pyspark", - "grpcio-status", "invoke", "pkgmt", "twine", @@ -73,6 +71,8 @@ "redshift-connector", "sqlalchemy-redshift", "clickhouse-sqlalchemy", + "pyspark", + "grpcio-status", ] setup( diff --git a/src/sql/connection/connection.py b/src/sql/connection/connection.py index 9f4495408..8f7d4758e 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -19,12 +19,6 @@ from sql.run.sparkdataframe import handle_spark_dataframe -try: - from pyspark.sql.connect.session import SparkSession as CSparkSession - from pyspark.sql import SparkSession -except ModuleNotFoundError: - CSparkSession = None - SparkSession = None from IPython.core.error import UsageError import sqlglot import sqlparse @@ -1243,10 +1237,25 @@ def is_pep249_compliant(conn): return True -def is_spark(ins): - return (CSparkSession is not None and isinstance(ins, CSparkSession)) or ( - 
SparkSession is not None and isinstance(ins, SparkSession) - ) +def is_spark(conn): + """Check if it is a SparkSession by checking for available methods""" + + sparksession_methods = [ + "table", + "read", + "readStream", + "createDataFrame", + "sql", + "stop", + "catalog", + "version", + ] + for method_name in sparksession_methods: + # Checking whether the connection object has the method + if not hasattr(conn, method_name): + return False + + return True def default_alias_for_engine(engine): diff --git a/src/tests/test_connection.py b/src/tests/test_connection.py index bdcb0fa86..4da0c60f5 100644 --- a/src/tests/test_connection.py +++ b/src/tests/test_connection.py @@ -13,7 +13,6 @@ from sqlalchemy.engine import Engine from sqlalchemy import exc -import pyspark from sql.connection import connection as connection_module import sql.connection @@ -44,9 +43,33 @@ def mock_database(monkeypatch, cleanup): monkeypatch.setattr(sqlalchemy, "create_engine", Mock()) -@pytest.fixture -def mock_spark(monkeypatch, cleanup): - monkeypatch.setitem(sys.modules, "pyspark.sql.SparkSession", Mock()) +def mock_sparksession(): + mock = Mock( + spec=[ + "table", + "read", + "readStream", + "createDataFrame", + "sql", + "stop", + "catalog", + "version", + ] + ) + return mock + + +def mock_not_sparksession(): + mock = Mock( + spec=[ + "read", + "readStream", + "createDataFrame", + "sql", + "verison", + ] + ) + return mock @pytest.fixture @@ -471,8 +494,8 @@ def test_is_pep249_compliant(conn, expected): [sqlite3.connect(""), False], [duckdb.connect(""), False], [create_engine("sqlite://"), False], - [Mock(spec=pyspark.sql.SparkSession), True], - [Mock(spec=pyspark.sql.connect.session.SparkSession), True], + [mock_sparksession(), True], + [mock_not_sparksession(), False], [None, False], [object(), False], ["not_a_valid_connection", False], @@ -619,11 +642,7 @@ def test_set_dbapi(monkeypatch, callable_, key): @pytest.mark.parametrize( "spark, key", [ - [Mock(name="SparkSession", 
spec=pyspark.sql.SparkSession), "Mock"], - [ - Mock(name="SparkSession", spec=pyspark.sql.connect.session.SparkSession), - "Mock", - ], + [mock_sparksession(), "Mock"], ], ) def test_set_spark(monkeypatch, spark, key): From db024fb2149b3b050f6ea0507c4cc36a3da77cfc Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Sat, 23 Dec 2023 20:05:40 +0000 Subject: [PATCH 28/29] review comments --- CHANGELOG.md | 3 +-- doc/api/configuration.md | 4 ++++ noxfile.py | 2 ++ setup.py | 1 + src/sql/error_handler.py | 1 - src/sql/magic.py | 3 ++- src/sql/run/resultset.py | 1 - src/tests/test_connection.py | 3 +-- 8 files changed, 11 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f533cb816..a15174f7f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,7 @@ ## 0.10.7dev -* [Feature] Add Spark Connection as a dialect for Jupysql ([#965](https://github.com/ploomber/jupysql/issues/965)) -(by [@gilandose](https://github.com/gilandose)) +* [Feature] Add Spark Connection as a dialect for Jupysql ([#965](https://github.com/ploomber/jupysql/issues/965)) (by [@gilandose](https://github.com/gilandose)) ## 0.10.6 (2023-12-21) diff --git a/doc/api/configuration.md b/doc/api/configuration.md index b7945683e..254ea712a 100644 --- a/doc/api/configuration.md +++ b/doc/api/configuration.md @@ -236,6 +236,10 @@ value enables the ones from previous values plus new ones: ## `lazy_execution` +```{versionadded} 0.10.7 +This option only works when connecting to Spark +``` + Default: `False` Return lazy relation to dataset rather than executing through JupySql. 
diff --git a/noxfile.py b/noxfile.py index 92c2b8fb6..5a62c0f26 100644 --- a/noxfile.py +++ b/noxfile.py @@ -35,6 +35,8 @@ "pyodbc==4.0.34", "sqlalchemy-pytds", "python-tds", + "pyspark>=3.4.1", + "grpcio-status", ] diff --git a/setup.py b/setup.py index 51dfc5cbc..1c03419f8 100644 --- a/setup.py +++ b/setup.py @@ -71,6 +71,7 @@ "redshift-connector", "sqlalchemy-redshift", "clickhouse-sqlalchemy", + # following two dependencies required for spark "pyspark", "grpcio-status", ] diff --git a/src/sql/error_handler.py b/src/sql/error_handler.py index cedc7ade7..cd1dfb3cd 100644 --- a/src/sql/error_handler.py +++ b/src/sql/error_handler.py @@ -112,5 +112,4 @@ def handle_exception(error, query=None, short_error=True): else: _display_error_msg_with_trace(error, detailed_message) else: - print(str(error)) raise error diff --git a/src/sql/magic.py b/src/sql/magic.py index ac49bbeba..17a2a2a49 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -153,7 +153,8 @@ class SqlMagic(Magics, Configurable): help="Whether to evaluate using ResultSet which will " "cause the plan to execute or just return a lazily " "executed plan allowing validating schemas, " - "without expensive compute.", + "without expensive compute." 
+ "Currently only supported for Spark Connection.", ) named_parameters = Bool( default_value=False, diff --git a/src/sql/run/resultset.py b/src/sql/run/resultset.py index e1e55deeb..4b977a8e2 100644 --- a/src/sql/run/resultset.py +++ b/src/sql/run/resultset.py @@ -438,7 +438,6 @@ def fetchmany(self, size): if hasattr(self._sqlaproxy, "dataframe"): self._results = [] self._pretty_table.clear() - print(self._conn) self._extend_results(returned) if len(returned) < size: diff --git a/src/tests/test_connection.py b/src/tests/test_connection.py index 4da0c60f5..8369ff690 100644 --- a/src/tests/test_connection.py +++ b/src/tests/test_connection.py @@ -48,7 +48,6 @@ def mock_sparksession(): spec=[ "table", "read", - "readStream", "createDataFrame", "sql", "stop", @@ -66,7 +65,7 @@ def mock_not_sparksession(): "readStream", "createDataFrame", "sql", - "verison", + "version", ] ) return mock From cfb5431a909f13d8c946c8ef6161aaf35036ac6a Mon Sep 17 00:00:00 2001 From: richard gilmore Date: Sun, 24 Dec 2023 00:40:56 +0000 Subject: [PATCH 29/29] missed readStream in connection.py --- src/sql/connection/connection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sql/connection/connection.py b/src/sql/connection/connection.py index 8f7d4758e..fc4552aaa 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -1243,7 +1243,6 @@ def is_spark(conn): sparksession_methods = [ "table", "read", - "readStream", "createDataFrame", "sql", "stop",