Skip to content

Commit

Permalink
feat(bindings/python): convert query result types
Browse files Browse the repository at this point in the history
  • Loading branch information
everpcpc committed Oct 26, 2023
1 parent 0d0dd21 commit 6114964
Show file tree
Hide file tree
Showing 9 changed files with 190 additions and 120 deletions.
2 changes: 1 addition & 1 deletion bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ license = { workspace = true }
authors = { workspace = true }

[lib]
crate-type = ["cdylib"]
crate-type = ["cdylib", "rlib"]
name = "databend_driver"
doc = false

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa
from ._databend_driver import *

__all__ = _databend_driver.__all__
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

class AsyncDatabendDriver:
def __init__(self, dsn: str): ... # NOQA
async def exec(self, sql: str) -> int: ... # NOQA
# flake8: noqa
class AsyncDatabendConnection:
async def exec(self, sql: str): ...
async def query_row(self, sql: str): ...

# flake8: noqa
class AsyncDatabendClient:
def __init__(self, dsn: str): ...
async def get_conn(self) -> AsyncDatabendConnection: ...
10 changes: 5 additions & 5 deletions bindings/python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
[build-system]
build-backend = "maturin"
requires = ["maturin>=1.0,<2.0"]

[project]
classifiers = [
"Programming Language :: Rust",
Expand All @@ -25,4 +21,8 @@ Repository = "https://github.com/datafuselabs/bendsql"
[tool.maturin]
features = ["pyo3/extension-module"]
module-name = "databend_driver._databend_driver"
python-source = "python"
python-source = "package"

[build-system]
build-backend = "maturin"
requires = ["maturin>=1.0,<2.0"]
51 changes: 0 additions & 51 deletions bindings/python/src/asyncio.rs

This file was deleted.

167 changes: 139 additions & 28 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,49 +12,160 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod asyncio;

use crate::asyncio::*;

use databend_driver::{Client, Connection};

use pyo3::create_exception;
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use std::sync::Arc;
use pyo3::types::{PyDict, PyList, PyTuple};
use pyo3_asyncio::tokio::future_into_py;

create_exception!(
databend_client,
Error,
PyException,
"databend_client related errors"
);

#[derive(Clone)]
pub struct Connector {
pub connector: FusedConnector,
#[pymodule]
fn _databend_driver(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<AsyncDatabendClient>()?;
m.add_class::<AsyncDatabendConnection>()?;
Ok(())
}

pub type FusedConnector = Arc<dyn Connection>;

// For bindings
impl Connector {
pub fn new_connector(dsn: &str) -> Result<Box<Self>, Error> {
let client = Client::new(dsn.to_string());
let conn = futures::executor::block_on(client.get_conn()).unwrap();
let r = Self {
connector: FusedConnector::from(conn),
};
Ok(Box::new(r))
#[pyclass(module = "databend_driver")]
pub struct AsyncDatabendClient(databend_driver::Client);

#[pymethods]
impl AsyncDatabendClient {
#[new]
#[pyo3(signature = (dsn))]
pub fn new(dsn: String) -> PyResult<Self> {
let client = databend_driver::Client::new(dsn);
Ok(Self(client))
}

pub fn get_conn<'p>(&'p self, py: Python<'p>) -> PyResult<&'p PyAny> {
let this = self.0.clone();
future_into_py(py, async move {
let conn = this.get_conn().await.unwrap();
Ok(AsyncDatabendConnection(conn))
})
}
}

fn build_connector(dsn: &str) -> PyResult<Connector> {
let conn = Connector::new_connector(dsn).unwrap();
Ok(*conn)
#[pyclass(module = "databend_driver")]
pub struct AsyncDatabendConnection(Box<dyn databend_driver::Connection>);

#[pymethods]
impl AsyncDatabendConnection {
pub fn exec<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<&'p PyAny> {
let this = self.0.clone();
future_into_py(py, async move {
let res = this.exec(&sql).await.unwrap();
Ok(res)
})
}

pub fn query_row<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<&'p PyAny> {
let this = self.0.clone();
future_into_py(py, async move {
let row = this.query_row(&sql).await.unwrap();
let row = row.unwrap();
Ok(Row(row))
})
}
}

#[pymodule]
fn _databend_driver(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<AsyncDatabendDriver>()?;
Ok(())
#[pyclass(module = "databend_driver")]
pub struct Row(databend_driver::Row);

#[pymethods]
impl Row {
pub fn values<'p>(&'p self, py: Python<'p>) -> PyResult<PyObject> {
let res = PyTuple::new(
py,
self.0
.values()
.into_iter()
.map(|v| Value(v.clone()).into_py(py)), // FIXME: do not clone
);
Ok(res.into_py(py))
}
}

pub struct Value(databend_driver::Value);

impl IntoPy<PyObject> for Value {
fn into_py(self, py: Python<'_>) -> PyObject {
match self.0 {
databend_driver::Value::Null => py.None(),
databend_driver::Value::EmptyArray => {
let list = PyList::empty(py);
list.into_py(py)
}
databend_driver::Value::EmptyMap => {
let dict = PyDict::new(py);
dict.into_py(py)
}
databend_driver::Value::Boolean(b) => b.into_py(py),
databend_driver::Value::String(s) => s.into_py(py),
databend_driver::Value::Number(n) => {
let v = NumberValue(n);
v.into_py(py)
}
databend_driver::Value::Timestamp(_) => {
let s = self.0.to_string();
s.into_py(py)
}
databend_driver::Value::Date(_) => {
let s = self.0.to_string();
s.into_py(py)
}
databend_driver::Value::Array(inner) => {
let list = PyList::new(py, inner.into_iter().map(|v| Value(v).into_py(py)));
list.into_py(py)
}
databend_driver::Value::Map(inner) => {
let dict = PyDict::new(py);
for (k, v) in inner {
dict.set_item(Value(k).into_py(py), Value(v).into_py(py))
.unwrap();
}
dict.into_py(py)
}
databend_driver::Value::Tuple(inner) => {
let tuple = PyTuple::new(py, inner.into_iter().map(|v| Value(v).into_py(py)));
tuple.into_py(py)
}
databend_driver::Value::Bitmap(s) => s.into_py(py),
databend_driver::Value::Variant(s) => s.into_py(py),
}
}
}

pub struct NumberValue(databend_driver::NumberValue);

impl IntoPy<PyObject> for NumberValue {
fn into_py(self, py: Python<'_>) -> PyObject {
match self.0 {
databend_driver::NumberValue::Int8(i) => i.into_py(py),
databend_driver::NumberValue::Int16(i) => i.into_py(py),
databend_driver::NumberValue::Int32(i) => i.into_py(py),
databend_driver::NumberValue::Int64(i) => i.into_py(py),
databend_driver::NumberValue::UInt8(i) => i.into_py(py),
databend_driver::NumberValue::UInt16(i) => i.into_py(py),
databend_driver::NumberValue::UInt32(i) => i.into_py(py),
databend_driver::NumberValue::UInt64(i) => i.into_py(py),
databend_driver::NumberValue::Float32(i) => i.into_py(py),
databend_driver::NumberValue::Float64(i) => i.into_py(py),
databend_driver::NumberValue::Decimal128(_, _) => {
let s = self.0.to_string();
s.into_py(py)
}
databend_driver::NumberValue::Decimal256(_, _) => {
let s = self.0.to_string();
s.into_py(py)
}
}
}
}
22 changes: 0 additions & 22 deletions bindings/python/tests/binding.feature

This file was deleted.

1 change: 1 addition & 0 deletions bindings/python/tests/binding.feature
41 changes: 33 additions & 8 deletions bindings/python/tests/steps/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,47 @@
import databend_driver


@given("A new Databend-Driver Async Connector")
@given("A new Databend Driver Client")
@async_run_until_complete
async def _(context):
dsn = os.getenv(
"TEST_DATABEND_DSN", "databend+http://root:root@localhost:8000/?sslmode=disable"
)
context.ad = databend_driver.AsyncDatabendDriver(dsn)
client = databend_driver.AsyncDatabendClient(dsn)
context.conn = await client.get_conn()


@when('Async exec "{sql}"')
@when("Create a test table")
@async_run_until_complete
async def _(context, sql):
await context.ad.exec(sql)
async def _(context):
# TODO:
pass


@then("Select string {input} should be equal to {output}")
@async_run_until_complete
async def _(context, input, output):
row = await context.conn.query_row(f"SELECT '{input}'")
value = row.values()[0]
assert output == value


@then('The select "{select_sql}" should run')
@then("Select numbers should iterate all rows")
@async_run_until_complete
async def _(context, select_sql):
await context.ad.exec(select_sql)
async def _(context):
# TODO:
pass


@then("Insert and Select should be equal")
@async_run_until_complete
async def _(context):
# TODO:
pass


@then("Stream load and Select should be equal")
@async_run_until_complete
async def _(context):
# TODO:
pass
1 change: 1 addition & 0 deletions driver/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use databend_sql::value::{NumberValue, Value};

use crate::rest_api::RestAPIConnection;

#[derive(Clone)]
pub struct Client {
dsn: String,
}
Expand Down

0 comments on commit 6114964

Please sign in to comment.