Creating Pydantic objects in Rust and passing to the interpreter. #1364

Open
PSU3D0 opened this issue Jul 9, 2024 · 4 comments
@PSU3D0

PSU3D0 commented Jul 9, 2024

What's the best way to do this?

I'd like to avoid passing JSON via PyO3 to Python and only then creating the model.

Use case:

I am moving bounding box processing logic in my library Docprompt into Rust. Documents can have tens of thousands of bounding boxes, so even small per-object overhead adds up.

Thank you for the help!

@samuelcolvin
Member

In the general case you can do this efficiently with PyO3.

But I'm not sure how pydantic comes into it?

If you're looking to implement your own types in Rust and then validate them with pydantic on the Rust side, I believe that's not currently possible. It's something we'd like to support one day, but I'm not sure how we'd go about it.
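
For reference, here is a minimal sketch of that general case, assuming pyo3 0.21's Bound API (the BBox, make_boxes, and bbox_ext names are illustrative, not from pydantic or Docprompt): build plain Rust values, expose them as #[pyclass] instances, and return a whole batch to the interpreter in one call, with no JSON round-trip.

use pyo3::prelude::*;

#[pyclass]
#[derive(Clone, Debug)]
struct BBox {
    #[pyo3(get)]
    x0: f32,
    #[pyo3(get)]
    top: f32,
    #[pyo3(get)]
    x1: f32,
    #[pyo3(get)]
    bottom: f32,
}

// Build many boxes Rust-side and return them to Python as a list of BBox
// instances in a single call.
#[pyfunction]
fn make_boxes(n: usize) -> Vec<BBox> {
    (0..n)
        .map(|i| {
            let f = i as f32 / n.max(1) as f32;
            BBox { x0: f, top: f, x1: f, bottom: f }
        })
        .collect()
}

#[pymodule]
fn bbox_ext(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<BBox>()?;
    m.add_function(wrap_pyfunction!(make_boxes, m)?)?;
    Ok(())
}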

@PSU3D0
Author

PSU3D0 commented Jul 9, 2024

Thanks for the reply! Apologies, as this may be more of a general PyO3 question, but I wondered since the core pydantic library is Rust-side.

So we currently represent a bounding box in Python space:

from pydantic import BaseModel

# BoundedFloat is a project-specific annotated float constrained to [0, 1],
# defined elsewhere in Docprompt.

class NormBBox(BaseModel):
    """
    Represents a normalized bounding box with each value in the range [0, 1]

    Where x1 > x0 and bottom > top
    """
    x0: BoundedFloat
    top: BoundedFloat
    x1: BoundedFloat
    bottom: BoundedFloat

    def as_tuple(self):
        return (self.x0, self.top, self.x1, self.bottom)

    def __getitem__(self, index):
        # Lots of if statements to prevent new allocations
        if index < 0 or index > 3:
            raise IndexError("Index out of range")

        if index == 0:
            return self.x0
        elif index == 1:
            return self.top
        elif index == 2:
            return self.x1
        elif index == 3:
            return self.bottom

    def __eq__(self, other):
        if not isinstance(other, NormBBox):
            return False

        return self.as_tuple() == other.as_tuple()

In PyO3, to implement this we could:

use pyo3::prelude::*;

#[pyclass]
#[derive(Clone, Debug)]
struct NormBBox {
    x0: f32,
    top: f32,
    x1: f32,
    bottom: f32,
}

#[pymethods]
impl NormBBox {
...
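    // A hypothetical sketch of what the elided methods could look like,
    // mirroring the Python class above (pyo3 0.21 Bound API assumed;
    // this is illustrative, not code from Docprompt):
    #[new]
    fn new(x0: f32, top: f32, x1: f32, bottom: f32) -> Self {
        NormBBox { x0, top, x1, bottom }
    }

    fn as_tuple(&self) -> (f32, f32, f32, f32) {
        (self.x0, self.top, self.x1, self.bottom)
    }

    fn __getitem__(&self, index: isize) -> PyResult<f32> {
        match index {
            0 => Ok(self.x0),
            1 => Ok(self.top),
            2 => Ok(self.x1),
            3 => Ok(self.bottom),
            _ => Err(pyo3::exceptions::PyIndexError::new_err("Index out of range")),
        }
    }

    fn __eq__(&self, other: &Bound<'_, PyAny>) -> bool {
        match other.extract::<NormBBox>() {
            Ok(other) => self.as_tuple() == other.as_tuple(),
            Err(_) => false,
        }
    }
}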

How would we retain NormBBox being a subclass of BaseModel? Are validation/serialization utilities available in Rust space?

@davidhewitt
Contributor

To subclass BaseModel from Rust would require something like PyO3/pyo3#991. It's very much in the "future feature" category at the moment, I'm afraid.

@Jocelyn-Gas

I made "Pydantic-like" objects from rust by "emulating" some of the features.

My requirements were:

  • Rust-exposed classes had to be usable as field types within Python-defined models. For that, we need to implement __get_pydantic_core_schema__.

  • As some of the models defined in Rust will also be used as-is and not nested in other models, the public methods I used the most had to match the ones from Pydantic.

For that I created an attribute macro:

use quote::quote;
use syn::{parse_macro_input, Item, ItemStruct};

// `get_import_statements` and `get_methods` are helpers defined elsewhere in
// the proc-macro crate; `get_methods` is shown below.
pub fn pydantic(
    attr: proc_macro::TokenStream,
    item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let imports = get_import_statements(attr.into());

    let input = parse_macro_input!(item as Item);

    let struct_name = match &input {
        Item::Struct(ItemStruct { ident, .. }) => ident,
        _ => panic!("The pydantic macro can only be used on structs"),
    };

    let methods = get_methods(struct_name);

    // Generate the code to be added
    let expanded = quote! {
        #imports
        
        // Original struct
        #input

        // Pydantic methods
        #methods
    };

    // Convert the expanded code back into a TokenStream
    proc_macro::TokenStream::from(expanded)
}

Where the get_methods function looks like:

use proc_macro2::TokenStream;
use quote::quote;
use syn::Ident;

pub fn get_methods(struct_name: &Ident) -> TokenStream {
    quote! {

        impl #struct_name {
            pub fn from_json_string(json_str: &str) -> Result<Self, serde_json::Error> {
                serde_json::from_str(json_str)
            }
        }

        #[pymethods]
        impl #struct_name {
            #[classmethod]
            fn __get_pydantic_core_schema__<'a>(
                cls: &'a Bound<'a, PyType>,
                source_type: &'a Bound<'a, PyAny>,
                handler: &'a Bound<'a, PyAny>,
            ) -> PyResult<Bound<'a, PyAny>> {
                let py = cls.py();

                let schema_generator = PyModule::from_code_bound(
                    py,
                    r#"
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema

def generate_schema(
    cls,
    source_type,
    handler,
) -> core_schema.CoreSchema:
    def validate_from_dict_or_json_str(value: dict | str):
        return cls.model_validate(value)

    from_dict_schema = core_schema.chain_schema(
        [
            core_schema.no_info_plain_validator_function(
                validate_from_dict_or_json_str
            ),
        ]
    )
    return core_schema.json_or_python_schema(
        json_schema=from_dict_schema,
        python_schema=core_schema.union_schema(
            [
                core_schema.is_instance_schema(cls),
                from_dict_schema,
            ]
        ),
        serialization=core_schema.plain_serializer_function_ser_schema(
            lambda instance: instance.model_dump(),
        ),
    )

                    "#, "schema_generator.py", "schema_generator"
                )?;
                let function = schema_generator.getattr("generate_schema")?;

                let result = function.call1((cls, source_type, handler))?;
                Ok(result)
            }


            #[pyo3(signature = (*, mode = "python", indent = None))]
            pub fn model_dump(&self, mode: &str, indent: Option<usize>) -> PyResult<PyObject> {
                return Python::with_gil(|py| match mode {
                    "json" => match self.model_dump_json(indent) {
                        Ok(json) => Ok(PyString::new_bound(py, &json).into()),
                        Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                            format!("Error converting to JSON: {}", e),
                        )),
                    },
                    "python" => self.model_dump_python(),
                    _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                        "Unsupported mode: {}",
                        mode
                    ))),
                });
            }

            #[pyo3(signature = (*, indent = None))]
            fn model_dump_json(&self, indent: Option<usize>) -> PyResult<String> {
                match indent {
                    Some(indent) => {
                        let mut buf = Vec::new();
                        let byte_string = self.create_byte_string(indent);
                        let formatter =
                            serde_json::ser::PrettyFormatter::with_indent(&byte_string);
                        let mut ser =
                            serde_json::Serializer::with_formatter(&mut buf, formatter);
                        match self.serialize(&mut ser) {
                            Ok(_) => Ok(String::from_utf8(buf).unwrap()),
                            Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                                format!("Error converting to JSON: {}", e),
                            )),
                        }
                    }
                    None => match serde_json::to_string(self) {
                        Ok(json) => Ok(json),
                        Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                            format!("Error converting to JSON: {}", e),
                        )),
                    },
                }
            }

            fn model_dump_python(&self) -> PyResult<PyObject> {
                return Python::with_gil(|py| {
                    let binding = py.import_bound("json").unwrap().getattr("loads").unwrap();
                    let json_str = self.model_dump_json(None).unwrap();
                    let json_str_ref: &str = &json_str;
                    Ok(binding.call1((json_str_ref,)).unwrap().into())
                });
            }
            fn create_byte_string(&self, indent: usize) -> Vec<u8> {
                let result = format!("{}", " ".repeat(indent));
                result.into_bytes()
            }

            #[classmethod]
            fn model_validate(cls: &Bound<PyType>, data: PyObject) -> PyResult<Self> {
                let py = cls.py();
                match data.extract::<String>(py) {
                    Ok(data) => {
                        return serde_json::from_str(&data).map_err(|e| {
                            PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                                "Error validating data: {}",
                                e
                            ))
                        });
                    }
                    Err(_) => {
                        let dumper = py.import_bound("json").unwrap().getattr("dumps").unwrap();
                        return serde_json::from_str(dumper.call1((data,)).unwrap().extract()?)
                            .map_err(|e| {
                                PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                                    "Error validating data: {}",
                                    e
                                ))
                            });
                    }
                }
            }
        }
    }
}

It definitely feels very hacky, but it worked for my use-case.

The conversion from Python dicts and JSON strings is done by using Python's json module and serde's de/serialization.
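
For context, any struct this attribute is applied to has to satisfy what the generated methods assume: serde's Serialize/Deserialize (used by from_json_string, model_dump_json, and model_validate) plus PyO3's #[pyclass] (so the generated #[pymethods] block compiles). Here is a minimal sketch of such a target struct, reusing the NormBBox fields from earlier in the thread as a stand-in; the #[pydantic(...)] invocation itself is omitted because the format of its argument (the import block consumed by get_import_statements) isn't shown above.

use pyo3::prelude::*;
use serde::{Deserialize, Serialize};

// Illustrative target struct for the attribute macro; the derives below are
// the minimum the generated methods rely on.
#[pyclass]
#[derive(Serialize, Deserialize, Clone, Debug)]
struct NormBBox {
    #[pyo3(get)]
    x0: f32,
    #[pyo3(get)]
    top: f32,
    #[pyo3(get)]
    x1: f32,
    #[pyo3(get)]
    bottom: f32,
}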

@sydney-runkle added the enhancement (New feature or request) label on Aug 16, 2024