Skip to content

Commit

Permalink
feat(spin_sdk::pg): Improve pg type's conversion in Rust SDK
Browse files Browse the repository at this point in the history
- Add conversion support for: int16, int32, int64, floating32,
floating64, binary, boolean;
- Add support for nullable variants;

Signed-off-by: Konstantin Shabanov <mail@etehtsea.me>
  • Loading branch information
etehtsea committed Oct 29, 2022
1 parent 6dc2f5f commit a1d6939
Show file tree
Hide file tree
Showing 6 changed files with 396 additions and 57 deletions.
16 changes: 4 additions & 12 deletions crates/outbound-pg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,7 @@ fn convert_data_type(pg_type: &Type) -> DbDataType {
Type::INT2 => DbDataType::Int16,
Type::INT4 => DbDataType::Int32,
Type::INT8 => DbDataType::Int64,
Type::TEXT => DbDataType::Str,
Type::VARCHAR => DbDataType::Str,
Type::TEXT | Type::VARCHAR | Type::BPCHAR => DbDataType::Str,
_ => {
tracing::debug!("Couldn't convert Postgres type {} to WIT", pg_type.name(),);
DbDataType::Other
Expand Down Expand Up @@ -208,17 +207,10 @@ fn convert_entry(row: &Row, index: usize) -> Result<DbValue, tokio_postgres::Err
None => DbValue::DbNull,
}
}
&Type::TEXT => {
let value: Option<&str> = row.try_get(index)?;
&Type::TEXT | &Type::VARCHAR | &Type::BPCHAR => {
let value: Option<String> = row.try_get(index)?;
match value {
Some(v) => DbValue::Str(v.to_owned()),
None => DbValue::DbNull,
}
}
&Type::VARCHAR => {
let value: Option<&str> = row.try_get(index)?;
match value {
Some(v) => DbValue::Str(v.to_owned()),
Some(v) => DbValue::Str(v),
None => DbValue::DbNull,
}
}
Expand Down
3 changes: 2 additions & 1 deletion examples/rust-outbound-pg/db/testdata.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ CREATE TABLE articletest (
id integer GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY,
title varchar(40) NOT NULL,
content text NOT NULL,
authorname varchar(40) NOT NULL
authorname varchar(40) NOT NULL,
coauthor text
);

INSERT INTO articletest (title, content, authorname) VALUES
Expand Down
21 changes: 12 additions & 9 deletions examples/rust-outbound-pg/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#![allow(dead_code)]
use anyhow::{anyhow, Result};
use spin_sdk::pg::Decode;
use spin_sdk::{
http::{Request, Response},
http_component,
pg,
http_component, pg,
};

// The environment variable set in `spin.toml` that points to the
Expand All @@ -16,22 +16,25 @@ struct Article {
title: String,
content: String,
authorname: String,
coauthor: Option<String>,
}

impl TryFrom<&pg::Row> for Article {
type Error = anyhow::Error;

fn try_from(row: &pg::Row) -> Result<Self, Self::Error> {
let id: i32 = (&row[0]).try_into()?;
let title: String = (&row[1]).try_into()?;
let content: String = (&row[2]).try_into()?;
let authorname: String = (&row[3]).try_into()?;
let id = i32::decode(&row[0])?;
let title = String::decode(&row[1])?;
let content = String::decode(&row[2])?;
let authorname = String::decode(&row[3])?;
let coauthor = Option::<String>::decode(&row[4])?;

Ok(Self {
id,
title,
content,
authorname,
coauthor,
})
}
}
Expand All @@ -51,7 +54,7 @@ fn process(req: Request) -> Result<Response> {
fn read(_req: Request) -> Result<Response> {
let address = std::env::var(DB_URL_ENV)?;

let sql = "SELECT id, title, content, authorname FROM articletest";
let sql = "SELECT id, title, content, authorname, coauthor FROM articletest";
let rowset = pg::query(&address, sql, &[])
.map_err(|e| anyhow!("Error executing Postgres query: {:?}", e))?;

Expand Down Expand Up @@ -98,7 +101,7 @@ fn write(_req: Request) -> Result<Response> {
let rowset = pg::query(&address, sql, &[])
.map_err(|e| anyhow!("Error executing Postgres query: {:?}", e))?;
let row = &rowset.rows[0];
let count: i64 = (&row[0]).try_into()?;
let count = i64::decode(&row[0])?;
let response = format!("Count: {}\n", count);

Ok(http::Response::builder()
Expand All @@ -116,7 +119,7 @@ fn pg_backend_pid(_req: Request) -> Result<Response> {

let row = &rowset.rows[0];

i32::try_from(&row[0])
i32::decode(&row[0])
};

assert_eq!(get_pid()?, get_pid()?);
Expand Down
168 changes: 153 additions & 15 deletions sdk/rust/src/pg.rs
Original file line number Diff line number Diff line change
@@ -1,47 +1,185 @@

wit_bindgen_rust::import!("../../wit/ephemeral/outbound-pg.wit");

/// Exports the generated outbound Pg items.
pub use outbound_pg::*;

impl TryFrom<&DbValue> for i32 {
type Error = anyhow::Error;
pub trait Decode: Sized {
fn decode(value: &DbValue) -> Result<Self, anyhow::Error>;
}

impl<T> Decode for Option<T>
where
T: Decode,
{
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
match value {
DbValue::DbNull => Ok(None),
v => Ok(Some(T::decode(v)?)),
}
}
}

impl Decode for bool {
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
match value {
DbValue::Boolean(boolean) => Ok(*boolean),
_ => Err(anyhow::anyhow!(
"Expected BOOL from the DB but got {:?}",
value
)),
}
}
}

impl Decode for i16 {
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
match value {
DbValue::Int16(n) => Ok(*n),
_ => Err(anyhow::anyhow!(
"Expected SMALLINT from the DB but got {:?}",
value
)),
}
}
}

fn try_from(value: &DbValue) -> Result<Self, Self::Error> {
impl Decode for i32 {
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
match value {
DbValue::Int32(n) => Ok(*n),
_ => Err(anyhow::anyhow!(
"Expected integer from database but got {:?}",
"Expected INT from the DB but got {:?}",
value
)),
}
}
}

impl TryFrom<&DbValue> for String {
type Error = anyhow::Error;
impl Decode for i64 {
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
match value {
DbValue::Int64(n) => Ok(*n),
_ => Err(anyhow::anyhow!(
"Expected BIGINT from the DB but got {:?}",
value
)),
}
}
}

fn try_from(value: &DbValue) -> Result<Self, Self::Error> {
impl Decode for f32 {
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
match value {
DbValue::Str(s) => Ok(s.to_owned()),
DbValue::Floating32(n) => Ok(*n),
_ => Err(anyhow::anyhow!(
"Expected REAL from the DB but got {:?}",
value
)),
}
}
}

impl Decode for f64 {
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
match value {
DbValue::Floating64(n) => Ok(*n),
_ => Err(anyhow::anyhow!(
"Expected string from the DB but got {:?}",
"Expected DOUBLE PRECISION from the DB but got {:?}",
value
)),
}
}
}

impl TryFrom<&DbValue> for i64 {
type Error = anyhow::Error;
impl Decode for Vec<u8> {
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
match value {
DbValue::Binary(n) => Ok(n.to_owned()),
_ => Err(anyhow::anyhow!(
"Expected BYTEA from the DB but got {:?}",
value
)),
}
}
}

fn try_from(value: &DbValue) -> Result<Self, Self::Error> {
impl Decode for String {
fn decode(value: &DbValue) -> Result<Self, anyhow::Error> {
match value {
DbValue::Int64(n) => Ok(*n),
DbValue::Str(s) => Ok(s.to_owned()),
_ => Err(anyhow::anyhow!(
"Expected integer from the DB but got {:?}",
"Expected CHAR, VARCHAR, TEXT from the DB but got {:?}",
value
)),
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn boolean() {
assert!(bool::decode(&DbValue::Boolean(true)).unwrap());
assert!(bool::decode(&DbValue::Int32(0)).is_err());
assert!(Option::<bool>::decode(&DbValue::DbNull).unwrap().is_none());
}

#[test]
fn int16() {
assert_eq!(i16::decode(&DbValue::Int16(0)).unwrap(), 0);
assert!(i16::decode(&DbValue::Int32(0)).is_err());
assert!(Option::<i16>::decode(&DbValue::DbNull).unwrap().is_none());
}

#[test]
fn int32() {
assert_eq!(i32::decode(&DbValue::Int32(0)).unwrap(), 0);
assert!(i32::decode(&DbValue::Boolean(false)).is_err());
assert!(Option::<i32>::decode(&DbValue::DbNull).unwrap().is_none());
}

#[test]
fn int64() {
assert_eq!(i64::decode(&DbValue::Int64(0)).unwrap(), 0);
assert!(i64::decode(&DbValue::Boolean(false)).is_err());
assert!(Option::<i64>::decode(&DbValue::DbNull).unwrap().is_none());
}

#[test]
fn floating32() {
assert!(f32::decode(&DbValue::Floating32(0.0)).is_ok());
assert!(f32::decode(&DbValue::Boolean(false)).is_err());
assert!(Option::<f32>::decode(&DbValue::DbNull).unwrap().is_none());
}

#[test]
fn floating64() {
assert!(f64::decode(&DbValue::Floating64(0.0)).is_ok());
assert!(f64::decode(&DbValue::Boolean(false)).is_err());
assert!(Option::<f64>::decode(&DbValue::DbNull).unwrap().is_none());
}

#[test]
fn str() {
assert_eq!(
String::decode(&DbValue::Str(String::from("foo"))).unwrap(),
String::from("foo")
);

assert!(String::decode(&DbValue::Int32(0)).is_err());
assert!(Option::<String>::decode(&DbValue::DbNull)
.unwrap()
.is_none());
}

#[test]
fn binary() {
assert!(Vec::<u8>::decode(&DbValue::Binary(vec![0, 0])).is_ok());
assert!(Vec::<u8>::decode(&DbValue::Boolean(false)).is_err());
assert!(Option::<Vec<u8>>::decode(&DbValue::DbNull)
.unwrap()
.is_none());
}
}
4 changes: 3 additions & 1 deletion tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,9 @@ mod integration_tests {
)
.await?;

assert_status(&s, "/test_read_types", 200).await?;
assert_status(&s, "/test_numeric_types", 200).await?;
assert_status(&s, "/test_character_types", 200).await?;
assert_status(&s, "/test_general_types", 200).await?;
assert_status(&s, "/pg_backend_pid", 200).await?;

Ok(())
Expand Down
Loading

0 comments on commit a1d6939

Please sign in to comment.