Skip to content

Commit

Permalink
Better errors for illegal PyO3 FieldValue values. (#676)
Browse files Browse the repository at this point in the history
  • Loading branch information
obi1kenobi authored Sep 19, 2024
1 parent b31a090 commit d10ef03
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 26 deletions.
5 changes: 4 additions & 1 deletion pytrustfall/src/shim.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::{collections::BTreeMap, sync::Arc};

use pyo3::{
exceptions::PyStopIteration, prelude::*, types::PyIterator, types::PyTuple, wrap_pyfunction,
exceptions::PyStopIteration,
prelude::*,
types::{PyIterator, PyTuple},
wrap_pyfunction,
};
use trustfall_core::{
frontend::{error::FrontendError, parse},
Expand Down
95 changes: 71 additions & 24 deletions pytrustfall/src/value.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use std::sync::Arc;
use std::{fmt::Display, sync::Arc};

use pyo3::{exceptions::PyTypeError, prelude::*, types::PyList};
use pyo3::{exceptions::PyValueError, prelude::*, types::PyList};

use crate::errors::QueryArgumentsError;

// TODO: apply https://pyo3.rs/v0.22.3/conversions/traits#deriving-frompyobject-for-enums
#[derive(Debug, Clone)]
pub(crate) enum FieldValue {
Null,
Expand All @@ -18,6 +15,38 @@ pub(crate) enum FieldValue {
List(Vec<FieldValue>),
}

impl FieldValue {
#[inline]
pub(crate) fn is_null(&self) -> bool {
matches!(self, FieldValue::Null)
}
}

impl Display for FieldValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FieldValue::Null => write!(f, "null"),
FieldValue::Int64(v) => write!(f, "{v}"),
FieldValue::Uint64(v) => write!(f, "{v}"),
FieldValue::Float64(v) => write!(f, "{v}"),
FieldValue::String(v) => write!(f, "\"{v}\""),
FieldValue::Boolean(v) => write!(f, "{v}"),
FieldValue::Enum(v) => write!(f, "{v}"),
FieldValue::List(v) => {
write!(f, "[")?;
let mut iter = v.iter();
if let Some(next) = iter.next() {
write!(f, "{next}")?;
}
for elem in iter {
write!(f, ", {elem}")?;
}
write!(f, "]")
}
}
}
}

impl IntoPy<Py<PyAny>> for FieldValue {
fn into_py(self, py: Python<'_>) -> Py<PyAny> {
match self {
Expand Down Expand Up @@ -46,37 +75,55 @@ impl<'py> pyo3::FromPyObject<'py> for FieldValue {
} else if let Ok(inner) = value.extract::<u64>() {
Ok(FieldValue::Uint64(inner))
} else if let Ok(inner) = value.extract::<f64>() {
// TODO: disallow and error on nan and infinite values
Ok(FieldValue::Float64(inner))
if inner.is_finite() {
Ok(FieldValue::Float64(inner))
} else {
Err(PyValueError::new_err(format!(
"{inner} is not a valid query argument value: \
float values may not be NaN or infinity"
)))
}
} else if let Ok(inner) = value.extract::<String>() {
Ok(FieldValue::String(inner.into()))
} else if let Ok(list) = value.downcast::<PyList>() {
let converted = list.iter().map(|element| element.extract::<FieldValue>()).try_fold(
vec![],
|mut acc, item| {
if let Ok(value) = item {
acc.push(value);
Some(acc)
} else {
None
}
},
);
let mut converted = Vec::with_capacity(list.len());
for element in list.iter() {
let value = element.extract::<FieldValue>()?;
converted.push(value);
}

// TODO: handle conversion errors properly
if let Some(inner_values) = converted {
Ok(FieldValue::List(inner_values))
} else {
Err(PyErr::new::<PyTypeError, &str>("first"))
// Ensure all non-null items in the list are of the same type.
let mut iter = converted.iter();
let first_non_null = loop {
let Some(next) = iter.next() else { break None };
if !next.is_null() {
break Some(next);
}
};
if let Some(first) = first_non_null {
let expected = std::mem::discriminant(first);
for other in iter {
if !other.is_null() {
let next_discriminant = std::mem::discriminant(other);
if expected != next_discriminant {
return Err(PyValueError::new_err(format!(
"Found elements of different (non-null) types in the same list, \
which is not allowed: {first} {other}"
)));
}
}
}
}

Ok(FieldValue::List(converted))
} else {
let repr = value.repr();
let display = repr
.as_ref()
.map_err(|_| ())
.and_then(|x| x.to_str().map_err(|_| ()))
.unwrap_or("<repr unavailable>");
Err(QueryArgumentsError::new_err(format!(
Err(PyValueError::new_err(format!(
"Value {display} of type {} is not supported by Trustfall",
value.get_type()
)))
Expand Down
4 changes: 3 additions & 1 deletion pytrustfall/trustfall/tests/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,9 @@ def test_unrepresentable_field_value(self) -> None:
"required": object(),
}

self.assertRaises(QueryArgumentsError, execute_query, NumbersAdapter(), SCHEMA, query, args)
self.assertRaises(
ValueError, execute_query, NumbersAdapter(), SCHEMA, query, args
)

def test_bad_query_input_type(self) -> None:
query = 123
Expand Down

0 comments on commit d10ef03

Please sign in to comment.