Add support for Spark Connect (SQL models) #899

Open · wants to merge 2 commits into main
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231004-191452.yaml
@@ -0,0 +1,6 @@
kind: Features
body: Add support for Spark Connect
time: 2023-10-04T19:14:52.858895+03:00
custom:
Author: vakarisbk
Issue: "899"
1 change: 1 addition & 0 deletions .github/workflows/integration.yml
@@ -57,6 +57,7 @@ jobs:
- "databricks_sql_endpoint"
- "databricks_cluster"
- "databricks_http_cluster"
- "spark_connect"

env:
DBT_INVOCATION_ENV: github-actions
29 changes: 28 additions & 1 deletion dagger/run_dbt_spark_tests.py
@@ -85,6 +85,29 @@ def get_spark_container(client: dagger.Client) -> (dagger.Service, str):
return spark_ctr, "spark_db"


def get_spark_connect_container(client: dagger.Client) -> (dagger.Service, str):
spark_ctr_base = (
client.container()
.from_("spark:3.5.0-scala2.12-java17-ubuntu")
.with_exec(
[
"/opt/spark/bin/spark-submit",
"--class",
"org.apache.spark.sql.connect.service.SparkConnectServer",
"--conf",
"spark.sql.catalogImplementation=hive",
"--packages",
"org.apache.spark:spark-connect_2.12:3.5.0",
"--conf",
"spark.jars.ivy=/tmp",
]
)
.with_exposed_port(15002)
.as_service()
)
return spark_ctr_base, "localhost"


async def test_spark(test_args):
async with dagger.Connection(dagger.Config(log_output=sys.stderr)) as client:
test_profile = test_args.profile
@@ -133,7 +156,11 @@ async def test_spark(test_args):
)

elif test_profile == "spark_session":
tst_container = tst_container.with_exec(["pip", "install", "pyspark"])
tst_container = tst_container.with_exec(["apt-get", "install", "openjdk-17-jre", "-y"])

elif test_profile == "spark_connect":
spark_ctr, spark_host = get_spark_connect_container(client)
tst_container = tst_container.with_service_binding(alias=spark_host, service=spark_ctr)
tst_container = tst_container.with_exec(["apt-get", "install", "openjdk-17-jre", "-y"])

tst_container = tst_container.with_(env_variables(TESTING_ENV_VARS))
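For orientation, the service started above exposes the Spark Connect gRPC endpoint on port 15002, and any pyspark >= 3.4 client can reach it through the remote builder. A minimal client sketch, assuming a server is already listening on localhost:15002 (as in the container above):

```python
from pyspark.sql import SparkSession

# Attach to a running Spark Connect server over gRPC instead of starting a local JVM.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

# Run a trivial query to confirm the round trip works.
spark.sql("SELECT 1 AS ok").show()
```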
62 changes: 61 additions & 1 deletion dbt/adapters/spark/connections.py
@@ -60,6 +60,7 @@ class SparkConnectionMethod(StrEnum):
HTTP = "http"
ODBC = "odbc"
SESSION = "session"
CONNECT = "connect"


@dataclass
@@ -154,6 +155,21 @@ def __post_init__(self) -> None:
f"ImportError({e.msg})"
) from e

if self.method == SparkConnectionMethod.CONNECT:
try:
import pyspark # noqa: F401 F811
import grpc # noqa: F401
import pyarrow # noqa: F401
import pandas # noqa: F401
except ImportError as e:
raise dbt.exceptions.DbtRuntimeError(
f"{self.method} connection method requires "
"additional dependencies. \n"
"Install the additional required dependencies with "
"`pip install dbt-spark[connect]`\n\n"
f"ImportError({e.msg})"
) from e

if self.method != SparkConnectionMethod.SESSION:
self.host = self.host.rstrip("/")

@@ -524,8 +540,52 @@ def open(cls, connection: Connection) -> Connection:
SessionConnectionWrapper,
)

# Pass session type (session or connect) into SessionConnectionWrapper
handle = SessionConnectionWrapper(
Connection(
conn_method=creds.method,
conn_url="localhost",
server_side_parameters=creds.server_side_parameters,
)
)
elif creds.method == SparkConnectionMethod.CONNECT:
# Create the url

host = creds.host
port = creds.port
token = creds.token
use_ssl = creds.use_ssl
user = creds.user

# URL Format: sc://localhost:15002/;user_id=str;token=str;use_ssl=bool
base_url = host if host.startswith("sc://") else f"sc://{host}"
base_url += f":{port}"

url_extensions = []
if user:
url_extensions.append(f"user_id={user}")
if use_ssl:
url_extensions.append(f"use_ssl={use_ssl}")
if token:
url_extensions.append(f"token={token}")

conn_url = base_url + "/;" + ";".join(url_extensions) if url_extensions else base_url

logger.debug("connection url: {}".format(conn_url))

from .session import ( # noqa: F401
Connection,
SessionConnectionWrapper,
)

# Pass session type (session or connect) into SessionConnectionWrapper
handle = SessionConnectionWrapper(
Connection(
conn_method=creds.method,
conn_url=conn_url,
server_side_parameters=creds.server_side_parameters,
)
)
else:
raise DbtConfigError(f"invalid credential method: {creds.method}")
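To make the URL assembly above concrete, here is a hypothetical standalone helper (not part of this diff) that mirrors the same logic and shows the strings it produces:

```python
from typing import Optional


def build_connect_url(
    host: str,
    port: int,
    user: Optional[str] = None,
    token: Optional[str] = None,
    use_ssl: bool = False,
) -> str:
    """Assemble a Spark Connect URL: sc://host:port/;user_id=...;token=...;use_ssl=..."""
    if not host.startswith("sc://"):
        host = f"sc://{host}"
    url = f"{host}:{port}"
    params = []
    if user:
        params.append(f"user_id={user}")
    if use_ssl:
        params.append(f"use_ssl={use_ssl}")
    if token:
        params.append(f"token={token}")
    if params:
        url += "/;" + ";".join(params)
    return url


print(build_connect_url("localhost", 15002))
# sc://localhost:15002
print(build_connect_url("spark.example.com", 443, user="dbt", use_ssl=True))
# sc://spark.example.com:443/;user_id=dbt;use_ssl=True
```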
35 changes: 29 additions & 6 deletions dbt/adapters/spark/session.py
@@ -6,7 +6,7 @@
from types import TracebackType
from typing import Any, Dict, List, Optional, Tuple, Union, Sequence

from dbt.adapters.spark.connections import SparkConnectionWrapper
from dbt.adapters.spark.connections import SparkConnectionMethod, SparkConnectionWrapper
from dbt.adapters.events.logging import AdapterLogger
from dbt_common.utils.encoding import DECIMALS
from dbt_common.exceptions import DbtRuntimeError
@@ -27,9 +27,17 @@ class Cursor:
https://github.com/mkleehammer/pyodbc/wiki/Cursor
"""

def __init__(
self,
*,
conn_method: SparkConnectionMethod,
conn_url: str,
server_side_parameters: Optional[Dict[str, Any]] = None,
) -> None:
self._df: Optional[DataFrame] = None
self._rows: Optional[List[Row]] = None
self.conn_method: SparkConnectionMethod = conn_method
self.conn_url: str = conn_url
self.server_side_parameters = server_side_parameters or {}

def __enter__(self) -> Cursor:
@@ -113,12 +121,15 @@ def execute(self, sql: str, *parameters: Any) -> None:
if len(parameters) > 0:
sql = sql % parameters

builder = SparkSession.builder

for parameter, value in self.server_side_parameters.items():
builder = builder.config(parameter, value)

if self.conn_method == SparkConnectionMethod.CONNECT:
spark_session = builder.remote(self.conn_url).getOrCreate()
elif self.conn_method == SparkConnectionMethod.SESSION:
spark_session = builder.enableHiveSupport().getOrCreate()

try:
self._df = spark_session.sql(sql)
@@ -175,7 +186,15 @@ class Connection:
https://github.com/mkleehammer/pyodbc/wiki/Connection
"""

def __init__(
self,
*,
conn_method: SparkConnectionMethod,
conn_url: str,
server_side_parameters: Optional[Dict[Any, str]] = None,
) -> None:
self.conn_method = conn_method
self.conn_url = conn_url
self.server_side_parameters = server_side_parameters or {}

def cursor(self) -> Cursor:
@@ -187,7 +206,11 @@ def cursor(self) -> Cursor:
out : Cursor
The cursor.
"""
return Cursor(
conn_method=self.conn_method,
conn_url=self.conn_url,
server_side_parameters=self.server_side_parameters,
)


class SessionConnectionWrapper(SparkConnectionWrapper):
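A minimal usage sketch of the updated wrapper classes, assuming the argument names introduced in this diff and a Spark Connect server on localhost:15002; the Cursor calls shown (execute/fetchall/close) follow the pyodbc-style interface this module documents:

```python
from dbt.adapters.spark.connections import SparkConnectionMethod
from dbt.adapters.spark.session import Connection

# Point the session-module Connection at a Spark Connect endpoint
# instead of an in-process SparkSession.
conn = Connection(
    conn_method=SparkConnectionMethod.CONNECT,
    conn_url="sc://localhost:15002",
    server_side_parameters={"spark.sql.shuffle.partitions": "8"},
)

cursor = conn.cursor()            # carries conn_method/conn_url into execute()
cursor.execute("SELECT 1 AS ok")  # takes the builder.remote(conn_url) branch
print(cursor.fetchall())
cursor.close()
```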
1 change: 1 addition & 0 deletions dev-requirements.txt
@@ -1,6 +1,7 @@
# install latest changes in dbt-core
# TODO: how to automate switching from develop to version branches?
git+https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-tests-adapter
git+https://github.com/dbt-labs/dbt-core.git#subdirectory=core

# if version 1.x or greater -> pin to major version
# if version 0.x -> pin to minor
3 changes: 3 additions & 0 deletions requirements.txt
@@ -6,5 +6,8 @@ sqlparams>=3.0.0
thrift>=0.13.0
sqlparse>=0.4.2 # not directly required, pinned by Snyk to avoid a vulnerability

#spark-connect
pyspark[connect]>=3.5.0,<4

types-PyYAML
types-python-dateutil
6 changes: 5 additions & 1 deletion setup.py
@@ -49,7 +49,10 @@ def _get_plugin_version_dict():
"thrift>=0.11.0,<0.17.0",
]
session_extras = ["pyspark>=3.0.0,<4.0.0"]
connect_extras = [
"pyspark[connect]>=3.5.0<4.0.0",
]
all_extras = odbc_extras + pyhive_extras + session_extras + connect_extras

setup(
name=package_name,
@@ -71,6 +74,7 @@ def _get_plugin_version_dict():
"ODBC": odbc_extras,
"PyHive": pyhive_extras,
"session": session_extras,
"connect": connect_extras,
"all": all_extras,
},
zip_safe=False,
12 changes: 7 additions & 5 deletions tests/conftest.py
@@ -30,6 +30,8 @@ def dbt_profile_target(request):
target = databricks_http_cluster_target()
elif profile_type == "spark_session":
target = spark_session_target()
elif profile_type == "spark_connect":
target = spark_connect_target()
else:
raise ValueError(f"Invalid profile type '{profile_type}'")
return target
@@ -95,11 +97,11 @@ def databricks_http_cluster_target():


def spark_session_target():
return {"type": "spark", "host": "localhost", "method": "session"}


def spark_connect_target():
return {"type": "spark", "host": "localhost", "port": 15002, "method": "connect"}


@pytest.fixture(autouse=True)
2 changes: 1 addition & 1 deletion tests/functional/adapter/dbt_clone/test_dbt_clone.py
@@ -14,7 +14,7 @@
)


@pytest.mark.skip_profile("apache_spark", "spark_session")
@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect")
class TestSparkBigqueryClonePossible(BaseClonePossible):
@pytest.fixture(scope="class")
def models(self):
@@ -5,7 +5,7 @@
)


@pytest.mark.skip_profile("spark_session", "apache_spark")
@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect")
class TestMergeExcludeColumns(BaseMergeExcludeColumns):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -32,7 +32,7 @@ def project_config_update(self):
}


@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session")
@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session", "spark_connect")
class TestInsertOverwriteOnSchemaChange(IncrementalOnSchemaChangeIgnoreFail):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -45,7 +45,7 @@ def project_config_update(self):
}


@pytest.mark.skip_profile("apache_spark", "spark_session")
@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect")
class TestDeltaOnSchemaChange(BaseIncrementalOnSchemaChangeSetup):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -27,7 +27,7 @@
"""


@pytest.mark.skip_profile("spark_session", "apache_spark")
@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect")
class TestIncrementalPredicatesMergeSpark(BaseIncrementalPredicates):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -46,7 +46,7 @@ def models(self):
}


@pytest.mark.skip_profile("spark_session", "apache_spark")
@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect")
class TestPredicatesMergeSpark(BaseIncrementalPredicates):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -2,7 +2,7 @@
from dbt.tests.adapter.incremental.test_incremental_unique_id import BaseIncrementalUniqueKey


@pytest.mark.skip_profile("spark_session", "apache_spark")
@pytest.mark.skip_profile("spark_session", "apache_spark", "spark_connect")
class TestUniqueKeySpark(BaseIncrementalUniqueKey):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -103,7 +103,11 @@ def run_and_test(self, project):
check_relations_equal(project.adapter, ["merge_update_columns", "expected_partial_upsert"])

@pytest.mark.skip_profile(
"apache_spark", "databricks_http_cluster", "databricks_sql_endpoint", "spark_session"
"apache_spark",
"databricks_http_cluster",
"databricks_sql_endpoint",
"spark_session",
"spark_connect",
)
def test_delta_strategies(self, project):
self.run_and_test(project)
6 changes: 3 additions & 3 deletions tests/functional/adapter/persist_docs/test_persist_docs.py
@@ -15,7 +15,7 @@
)


@pytest.mark.skip_profile("apache_spark", "spark_session")
@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect")
class TestPersistDocsDeltaTable:
@pytest.fixture(scope="class")
def models(self):
@@ -78,7 +78,7 @@ def test_delta_comments(self, project):
assert result[2].startswith("Some stuff here and then a call to")


@pytest.mark.skip_profile("apache_spark", "spark_session")
@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect")
class TestPersistDocsDeltaView:
@pytest.fixture(scope="class")
def models(self):
@@ -120,7 +120,7 @@ def test_delta_comments(self, project):
assert result[2] is None


@pytest.mark.skip_profile("apache_spark", "spark_session")
@pytest.mark.skip_profile("apache_spark", "spark_session", "spark_connect")
class TestPersistDocsMissingColumn:
@pytest.fixture(scope="class")
def project_config_update(self):