Livy connection support for Spark SQL models #984

Open · wants to merge 5 commits into main
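This PR adds a `livy` connection method to dbt-spark so that Spark SQL models can run against a Livy interactive session. For orientation before the diff: a target using the new method might be configured roughly as in the sketch below. This is an illustration, not documentation from the PR; the host, credentials, and Livy session parameters are placeholders, and the two new keys map onto the `SparkCredentials` fields added in connections.py.

```python
# Hypothetical profiles.yml target for the new method, written as the
# equivalent Python dict. All values are placeholders; the two new keys
# correspond to the SparkCredentials fields added in this PR.
livy_target = {
    "type": "spark",
    "method": "livy",                        # new SparkConnectionMethod.LIVY
    "host": "https://livy.example.com:8998",
    "user": "dbt_user",
    "password": "********",
    "schema": "analytics",
    "verify_ssl_certificate": True,          # new field, defaults to True
    "livy_session_parameters": {             # new field, sent at session creation
        "driverMemory": "2g",                # standard Livy POST /sessions keys
        "executorCores": 2,
    },
}
```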
49 changes: 47 additions & 2 deletions dbt/adapters/spark/connections.py
@@ -1,4 +1,5 @@
from contextlib import contextmanager
from threading import Lock

from dbt.adapters.contracts.connection import (
AdapterResponse,
@@ -13,6 +14,7 @@

from dbt_common.utils.encoding import DECIMALS
from dbt.adapters.spark import __version__
from dbt.adapters.spark.livysession import LivyConnectionManager, LivySessionConnectionWrapper

try:
from TCLIService.ttypes import TOperationState as ThriftState
@@ -30,7 +32,7 @@
import sqlparams
from dbt_common.dataclass_schema import StrEnum
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union, Tuple, List, Generator, Iterable, Sequence
from typing import Any, Dict, Optional, Union, Tuple, List, Generator, Iterable, Sequence, Hashable

from abc import ABC, abstractmethod

@@ -47,7 +49,7 @@
import time

logger = AdapterLogger("Spark")

lock = Lock()
NUMBERS = DECIMALS + (int, float)


@@ -60,6 +62,7 @@ class SparkConnectionMethod(StrEnum):
HTTP = "http"
ODBC = "odbc"
SESSION = "session"
LIVY = "livy"


@dataclass
@@ -83,6 +86,8 @@ class SparkCredentials(Credentials):
use_ssl: bool = False
server_side_parameters: Dict[str, str] = field(default_factory=dict)
retry_all: bool = False
livy_session_parameters: Dict[str, Any] = field(default_factory=dict)
verify_ssl_certificate: Optional[bool] = True

@classmethod
def __pre_deserialize__(cls, data: Any) -> Any:
@@ -346,6 +351,7 @@ class SparkConnectionManager(SQLConnectionManager):
SPARK_CLUSTER_HTTP_PATH = "/sql/protocolv1/o/{organization}/{cluster}"
SPARK_SQL_ENDPOINT_HTTP_PATH = "/sql/1.0/endpoints/{endpoint}"
SPARK_CONNECTION_URL = "{host}:{port}" + SPARK_CLUSTER_HTTP_PATH
connection_managers: Dict[Hashable, LivyConnectionManager] = {}

@contextmanager
def exception_handler(self, sql: str) -> Generator[None, None, None]:
@@ -527,6 +533,45 @@ def open(cls, connection: Connection) -> Connection:
handle = SessionConnectionWrapper(
Connection(server_side_parameters=creds.server_side_parameters)
)
elif creds.method == SparkConnectionMethod.LIVY:
            # Connect to a Livy interactive session

lock.acquire()
try:
thread_id = cls.get_thread_identifier()

if thread_id not in SparkConnectionManager.connection_managers:
if len(SparkConnectionManager.connection_managers) > 0:
                        # Reuse the first manager so all threads share one Livy session
                        livy_conn_mgr = next(
                            iter(SparkConnectionManager.connection_managers.values())
                        )
                        SparkConnectionManager.connection_managers[thread_id] = livy_conn_mgr
else:
SparkConnectionManager.connection_managers[
thread_id
] = LivyConnectionManager()

handle = LivySessionConnectionWrapper( # type: ignore
SparkConnectionManager.connection_managers[thread_id].connect(
creds.host,
creds.user,
creds.password,
creds.auth,
creds.livy_session_parameters,
creds.verify_ssl_certificate,
)
)
connection.state = ConnectionState.OPEN

            except Exception as ex:
                logger.debug("Connection error: {}".format(ex))
                connection.state = ConnectionState.FAIL
                raise

finally:
lock.release()

else:
raise DbtConfigError(f"invalid credential method: {creds.method}")
break
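Note that the imported `dbt/adapters/spark/livysession.py` module is not included in the hunks shown here. From the call sites above, the surface that `open()` and `cleanup_connections()` depend on is roughly the sketch below; the method names and signatures are inferred from this diff, and the bodies are illustrative stubs only.

```python
# Inferred sketch of the livysession.py interface used above; signatures
# come from the call sites in this diff, bodies are stand-in stubs.
from typing import Any, Dict, Optional


class LivyConnectionManager:
    def connect(
        self,
        host: str,
        user: Optional[str],
        password: Optional[str],
        auth: Optional[str],
        livy_session_parameters: Dict[str, Any],
        verify_ssl_certificate: Optional[bool],
    ) -> Any:
        """Create or join a Livy interactive session and return its handle."""
        raise NotImplementedError

    def delete_session(self) -> None:
        """Tear down the server-side Livy session (used at end of run)."""
        raise NotImplementedError


class LivySessionConnectionWrapper:
    """Adapts a Livy session handle to the cursor-like API dbt expects."""

    def __init__(self, handle: Any) -> None:
        self.handle = handle
```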
13 changes: 13 additions & 0 deletions dbt/adapters/spark/impl.py
@@ -516,6 +516,19 @@ def debug_query(self) -> None:
"""Override for DebugTask method"""
self.execute("select 1 as id")

self.connections.get_thread_connection().handle.close()

def cleanup_connections(self) -> None:
self.connections.cleanup_all()
logger.debug("cleanup_connections")

        # Close all sessions; threads may share a manager, so deduplicate first
        for conn_mgr in set(SparkConnectionManager.connection_managers.values()):
            conn_mgr.delete_session()

        # Reset the connection manager registry
        SparkConnectionManager.connection_managers = {}


# spark does something interesting with joins when both tables have the same
# static values for the join condition and complains that the join condition is
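Taken together, the two files give the Livy method this lifecycle: `open()` creates one Livy session and shares its connection manager across all dbt threads, and `cleanup_connections()` deletes the session once the run finishes. The snippet below is a schematic of that call order under those assumptions; dbt core normally drives these calls itself, and `run_and_teardown` is a hypothetical wrapper added here only for illustration.

```python
# Schematic only: the order of operations the overrides above rely on.
from dbt.adapters.spark.impl import SparkAdapter


def run_and_teardown(adapter: SparkAdapter) -> None:
    adapter.acquire_connection("model_a")  # open(): create or join the shared session
    adapter.execute("select 1 as id")      # statements run through Livy
    adapter.cleanup_connections()          # close handles, delete each unique
                                           # session, and reset the registry
```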