-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #15 from bmsuisse/executesql
Executesql
- Loading branch information
Showing
7 changed files
with
310 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .sql_connection import TdsConnection | ||
from .bulk_insert import insert_record_batch_to_sql |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import lakeapi2sql._lowlevel as lvd | ||
from lakeapi2sql.utils import prepare_connection_string | ||
|
||
|
||
class TdsConnection: | ||
def __init__(self, connection_string: str, aad_token: str | None = None) -> None: | ||
connection_string, aad_token = await prepare_connection_string(connection_string, aad_token) | ||
self._connection_string = connection_string | ||
self._aad_token = aad_token | ||
|
||
async def __aenter__(self) -> "TdsConnection": | ||
self._connection = await lvd.connect_sql(self.connection_string, self.aad_token) | ||
return self | ||
|
||
async def __aexit__(self, exc_type, exc_value, traceback) -> None: | ||
pass | ||
|
||
async def execute_sql(self, sql: str, arguments: list[str | int | float | bool | None]) -> list[int]: | ||
return await lvd.execute_sql(self._connection, sql, arguments) | ||
|
||
async def execute_sql_with_result(self, sql: str, arguments: list[str | int | float | bool | None]) -> list[int]: | ||
return await lvd.execute_sql_with_result(self._connection, sql, arguments) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import inspect | ||
|
||
|
||
async def prepare_connection_string(connection_string: str, aad_token: str | None) -> tuple[str, str | None]: | ||
if "authentication" in connection_string.lower(): | ||
parts = [(kv[0 : kv.index("=")], kv[kv.index("=") + 1 :]) for kv in connection_string.split(";")] | ||
auth_part = next((p for p in parts if p[0].casefold() == "Authentication".casefold())) | ||
parts.remove(auth_part) | ||
credential = None | ||
auth_method = auth_part[1].lower() | ||
if auth_method in ["ActiveDirectoryDefault".lower()]: | ||
from azure.identity.aio import DefaultAzureCredential | ||
|
||
credential = DefaultAzureCredential() | ||
elif auth_method in ["ActiveDirectoryMSI".lower(), "ActiveDirectoryManagedIdentity".lower()]: | ||
from azure.identity.aio import ManagedIdentityCredential | ||
|
||
client_part = next((p for p in parts if p[0].lower() in ["user", "msiclientid"]), None) | ||
if client_part: | ||
parts.remove(client_part) | ||
credential = ManagedIdentityCredential(client_id=client_part[1] if client_part else None) | ||
elif auth_method == "ActiveDirectoryInteractive".lower(): | ||
from azure.identity import InteractiveBrowserCredential | ||
|
||
credential = InteractiveBrowserCredential() | ||
elif auth_method == "SqlPassword": # that's kind of an no-op | ||
return ";".join((p[0] + "=" + p[1] for p in parts)), None | ||
if credential is not None: | ||
from azure.core.credentials import AccessToken | ||
|
||
res = credential.get_token("https://database.windows.net/.default") | ||
token: AccessToken = await res if inspect.isawaitable(res) else res # type: ignore | ||
aad_token = token.token | ||
return ";".join((p[0] + "=" + p[1] for p in parts)), aad_token | ||
return connection_string, aad_token |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.