Skip to content

Commit

Permalink
Merge pull request #16 from bmsuisse/executesql
Browse files Browse the repository at this point in the history
Bug Fixes and Testing
  • Loading branch information
aersam authored May 3, 2024
2 parents f16fc8f + f3c4de9 commit 36e8a1d
Show file tree
Hide file tree
Showing 22 changed files with 455 additions and 101 deletions.
41 changes: 41 additions & 0 deletions .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
name: Python Test

on:
push:
branches: ["main"]
paths-ignore: ["README.md", "docs", ".github"]
pull_request:
branches: ["main"]
paths-ignore: ["README.md", "docs", ".github"]

jobs:
build:
runs-on: ubuntu-22.04
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}

- name: Setup Rust
uses: actions-rust-lang/setup-rust-toolchain@v1
- name: Install tooling dependencies
run: |
python -m pip install --upgrade pip
pip install maturin
- name: Install Dependencies
run: |
pip install pytest polars pyarrow pytest-asyncio pyright python-dotenv docker pyright cffi
- name: Install Project
run: maturin develop
- name: pytest
shell: bash
run: pytest
- name: Pyright
run: poetry run pyright .
28 changes: 21 additions & 7 deletions lakeapi2sql/sql_connection.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,36 @@
import lakeapi2sql._lowlevel as lvd
from lakeapi2sql.utils import prepare_connection_string
from typing import TypedDict


class TdsColumn(TypedDict):
name: str
column_type: str


class TdsResult(TypedDict):
columns: list[TdsColumn]
rows: list[dict]


class TdsConnection:
def __init__(self, connection_string: str, aad_token: str | None = None) -> None:
connection_string, aad_token = await prepare_connection_string(connection_string, aad_token)
self._connection_string = connection_string
self._aad_token = aad_token

async def __aenter__(self) -> "TdsConnection":
self._connection = await lvd.connect_sql(self.connection_string, self.aad_token)
connection_string, aad_token = await prepare_connection_string(self._connection_string, self._aad_token)

self._connection = await lvd.connect_sql(connection_string, aad_token)
return self

async def __aexit__(self, exc_type, exc_value, traceback) -> None:
async def __aexit__(self, *args, **kwargs) -> None:
pass

async def execute_sql(self, sql: str, arguments: list[str | int | float | bool | None]) -> list[int]:
return await lvd.execute_sql(self._connection, sql, arguments)
async def execute_sql(self, sql: str, arguments: list[str | int | float | bool | None] = None) -> list[int]:
return await lvd.execute_sql(self._connection, sql, arguments or [])

async def execute_sql_with_result(self, sql: str, arguments: list[str | int | float | bool | None]) -> list[int]:
return await lvd.execute_sql_with_result(self._connection, sql, arguments)
async def execute_sql_with_result(
self, sql: str, arguments: list[str | int | float | bool | None] = None
) -> TdsResult:
return await lvd.execute_sql_with_result(self._connection, sql, arguments or [])
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "maturin"
[project]
name = "lakeapi2sql"
requires-python = ">=3.10"
version = "0.9.0"
version = "0.9.1"
classifiers = [
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
Expand Down
57 changes: 40 additions & 17 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use error::LakeApi2SqlError;
use futures::{StreamExt, TryStreamExt};
use pyo3::exceptions::{PyConnectionError, PyIOError, PyTypeError};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyInt, PyList, PyString};
use pyo3::types::{PyDict, PyInt, PyList, PyString, PyTuple};
mod arrow_convert;
pub mod bulk_insert;
pub mod connect;
Expand Down Expand Up @@ -63,7 +63,7 @@ fn into_dict_result<'a>(py: Python<'a>, meta: Option<ResultMetadata>, rows: Vec<
let mut py_rows = PyList::new(
py,
rows.iter().map(|row| {
PyList::new(
PyTuple::new(
py,
row.cells()
.map(|(c, val)| match val {
Expand Down Expand Up @@ -244,7 +244,7 @@ impl ToSql for ValueWrap {

fn to_exec_args(args: Vec<&PyAny>) -> Result<Vec<ValueWrap>, PyErr> {
let mut res: Vec<ValueWrap> = Vec::new();
for i in 0..args.len() - 1 {
for i in 0..args.len() {
let x = args[i];
res.push(ValueWrap(if x.is_none() {
Box::new(Option::<i64>::None) as Box<dyn ToSql>
Expand Down Expand Up @@ -280,27 +280,50 @@ fn execute_sql<'a>(
list2
});
}
let nr_args = args.len();
let tds_args = to_exec_args(args)?;

let mutex = conn.0.clone();
pyo3_asyncio::tokio::future_into_py(py, async move {
let res = mutex
.clone()
.lock()
.await
.execute(
query,
tds_args
.iter()
.map(|x| x.0.borrow() as &dyn ToSql)
.collect::<Vec<&dyn ToSql>>()
.as_slice(),
)
.await;
let res = if nr_args > 0 {
mutex
.clone()
.lock()
.await
.execute(
query,
tds_args
.iter()
.map(|x| x.0.borrow() as &dyn ToSql)
.collect::<Vec<&dyn ToSql>>()
.as_slice(),
)
.await
.map(|x| x.rows_affected().to_owned())
} else {
let arc = mutex.clone();
let lock = arc.lock();
let mut conn = lock.await;
let res = conn.simple_query(query).await;
match res {
Ok(mut stream) => {
let mut row_count: u64 = 0;
while let Some(item) = stream.try_next().await.map_err(|er| {
PyErr::new::<PyIOError, _>(format!("Error executing: {er}"))
})? {
if let QueryItem::Row(_) = item {
row_count += 1;
}
}
Ok(vec![row_count])
}
Err(a) => Err(a),
}
};

match res {
Ok(re) => {
return Ok(into_list(re.rows_affected()));
return Ok(into_list(&re));
}
Err(er) => Err(PyErr::new::<PyIOError, _>(format!("Error executing: {er}"))),
}
Expand Down
28 changes: 0 additions & 28 deletions test/test_insert.py

This file was deleted.

48 changes: 0 additions & 48 deletions test/test_insert_reader.py

This file was deleted.

50 changes: 50 additions & 0 deletions test_server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from pathlib import Path
import docker
from docker.models.containers import Container
from time import sleep
from typing import cast
import docker.errors
import os


def _getenvs():
envs = dict()
with open("test_server/sql_docker.env", "r") as f:
lines = f.readlines()
envs = {
item[0].strip(): item[1].strip()
for item in [line.split("=") for line in lines if len(line.strip()) > 0 and not line.startswith("#")]
}
return envs


def start_mssql_server() -> Container:
client = docker.from_env() # code taken from https://github.com/fsspec/adlfs/blob/main/adlfs/tests/conftest.py#L72
sql_server: Container | None = None
try:
m = cast(Container, client.containers.get("test4sql_lakeapi2sql"))
if m.status == "running":
return m
else:
sql_server = m
except docker.errors.NotFound:
pass

envs = _getenvs()

if sql_server is None:
# using podman: podman run --env-file=TESTS/SQL_DOCKER.ENV --publish=1439:1433 --name=mssql1 chriseaton/adventureworks:light
# podman kill mssql1
sql_server = client.containers.run(
"mcr.microsoft.com/mssql/server:2022-latest",
environment=envs,
detach=True,
name="test4sql_lakeapi2sql",
ports={"1433/tcp": "1444"},
) # type: ignore
assert sql_server is not None
sql_server.start()
print(sql_server.status)
sleep(15)
print("Successfully created sql container...")
return sql_server
4 changes: 4 additions & 0 deletions test_server/sql_docker.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
SA_PASSWORD=MyPass@word4tests
ACCEPT_EULA=Y
MSSQL_PID=Express
MSSQL_SA_PASSWORD=MyPass@word4tests
File renamed without changes.
Loading

0 comments on commit 36e8a1d

Please sign in to comment.