Skip to content

Commit

Permalink
Merge pull request #15 from bmsuisse/executesql
Browse files Browse the repository at this point in the history
Executesql
  • Loading branch information
aersam authored May 2, 2024
2 parents 5bfc904 + 262639f commit f16fc8f
Show file tree
Hide file tree
Showing 7 changed files with 310 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "lake2sql"
version = "0.8.3"
version = "0.9.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
2 changes: 2 additions & 0 deletions lakeapi2sql/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .sql_connection import TdsConnection
from .bulk_insert import insert_record_batch_to_sql
40 changes: 4 additions & 36 deletions lakeapi2sql/bulk_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import pyarrow as pa
from pyarrow.cffi import ffi as arrow_ffi

from lakeapi2sql.utils import prepare_connection_string


class BulkInfoField(TypedDict):
name: str
Expand All @@ -14,48 +16,14 @@ class BulkInfo(TypedDict):
fields: list[BulkInfoField]


async def _prepare_connection_string(connection_string: str, aad_token: str | None) -> tuple[str, str | None]:
if "authentication" in connection_string.lower():
parts = [(kv[0 : kv.index("=")], kv[kv.index("=") + 1 :]) for kv in connection_string.split(";")]
auth_part = next((p for p in parts if p[0].casefold() == "Authentication".casefold()))
parts.remove(auth_part)
credential = None
auth_method = auth_part[1].lower()
if auth_method in ["ActiveDirectoryDefault".lower()]:
from azure.identity.aio import DefaultAzureCredential

credential = DefaultAzureCredential()
elif auth_method in ["ActiveDirectoryMSI".lower(), "ActiveDirectoryManagedIdentity".lower()]:
from azure.identity.aio import ManagedIdentityCredential

client_part = next((p for p in parts if p[0].lower() in ["user", "msiclientid"]), None)
if client_part:
parts.remove(client_part)
credential = ManagedIdentityCredential(client_id=client_part[1] if client_part else None)
elif auth_method == "ActiveDirectoryInteractive".lower():
from azure.identity import InteractiveBrowserCredential

credential = InteractiveBrowserCredential()
elif auth_method == "SqlPassword": # that's kind of an no-op
return ";".join((p[0] + "=" + p[1] for p in parts)), None
if credential is not None:
from azure.core.credentials import AccessToken

res = credential.get_token("https://database.windows.net/.default")
token: AccessToken = await res if inspect.isawaitable(res) else res # type: ignore
aad_token = token.token
return ";".join((p[0] + "=" + p[1] for p in parts)), aad_token
return connection_string, aad_token


async def insert_record_batch_to_sql(
    connection_string: str,
    table_name: str,
    reader: pa.RecordBatchReader,
    col_names: list[str] | None = None,
    aad_token: str | None = None,
):
    """Bulk-insert the batches of an Arrow reader into a SQL table.

    Args:
        connection_string: Connection string; passed through
            ``prepare_connection_string`` before use.
        table_name: Name of the target table.
        reader: Arrow record batch reader providing the rows.
        col_names: Column names forwarded to the low-level insert; an empty
            list is sent when omitted.
        aad_token: Optional AAD token, forwarded to
            ``prepare_connection_string`` together with the connection string.

    Returns:
        Whatever ``lvd.insert_arrow_reader_to_sql`` returns.
    """
    # Prepare exactly once, through the shared helper from lakeapi2sql.utils.
    connection_string, aad_token = await prepare_connection_string(connection_string, aad_token)

    return await lvd.insert_arrow_reader_to_sql(connection_string, reader, table_name, col_names or [], aad_token)

Expand All @@ -68,7 +36,7 @@ async def insert_http_arrow_stream_to_sql(
aad_token: str | None = None,
col_names: list[str] | None = None,
) -> BulkInfo:
connection_string, aad_token = await _prepare_connection_string(connection_string, aad_token)
connection_string, aad_token = await prepare_connection_string(connection_string, aad_token)

return await lvd.insert_arrow_stream_to_sql(
connection_string, table_name, col_names or [], url, basic_auth[0], basic_auth[1], aad_token
Expand Down
22 changes: 22 additions & 0 deletions lakeapi2sql/sql_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import lakeapi2sql._lowlevel as lvd
from lakeapi2sql.utils import prepare_connection_string


class TdsConnection:
def __init__(self, connection_string: str, aad_token: str | None = None) -> None:
connection_string, aad_token = await prepare_connection_string(connection_string, aad_token)
self._connection_string = connection_string
self._aad_token = aad_token

async def __aenter__(self) -> "TdsConnection":
self._connection = await lvd.connect_sql(self.connection_string, self.aad_token)
return self

async def __aexit__(self, exc_type, exc_value, traceback) -> None:
pass

async def execute_sql(self, sql: str, arguments: list[str | int | float | bool | None]) -> list[int]:
return await lvd.execute_sql(self._connection, sql, arguments)

async def execute_sql_with_result(self, sql: str, arguments: list[str | int | float | bool | None]) -> list[int]:
return await lvd.execute_sql_with_result(self._connection, sql, arguments)
35 changes: 35 additions & 0 deletions lakeapi2sql/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import inspect


async def prepare_connection_string(connection_string: str, aad_token: str | None) -> tuple[str, str | None]:
if "authentication" in connection_string.lower():
parts = [(kv[0 : kv.index("=")], kv[kv.index("=") + 1 :]) for kv in connection_string.split(";")]
auth_part = next((p for p in parts if p[0].casefold() == "Authentication".casefold()))
parts.remove(auth_part)
credential = None
auth_method = auth_part[1].lower()
if auth_method in ["ActiveDirectoryDefault".lower()]:
from azure.identity.aio import DefaultAzureCredential

credential = DefaultAzureCredential()
elif auth_method in ["ActiveDirectoryMSI".lower(), "ActiveDirectoryManagedIdentity".lower()]:
from azure.identity.aio import ManagedIdentityCredential

client_part = next((p for p in parts if p[0].lower() in ["user", "msiclientid"]), None)
if client_part:
parts.remove(client_part)
credential = ManagedIdentityCredential(client_id=client_part[1] if client_part else None)
elif auth_method == "ActiveDirectoryInteractive".lower():
from azure.identity import InteractiveBrowserCredential

credential = InteractiveBrowserCredential()
elif auth_method == "SqlPassword": # that's kind of an no-op
return ";".join((p[0] + "=" + p[1] for p in parts)), None
if credential is not None:
from azure.core.credentials import AccessToken

res = credential.get_token("https://database.windows.net/.default")
token: AccessToken = await res if inspect.isawaitable(res) else res # type: ignore
aad_token = token.token
return ";".join((p[0] + "=" + p[1] for p in parts)), aad_token
return connection_string, aad_token
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "maturin"
[project]
name = "lakeapi2sql"
requires-python = ">=3.10"
version = "0.8.4"
version = "0.9.0"
classifiers = [
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
Expand Down
Loading

0 comments on commit f16fc8f

Please sign in to comment.