diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index b83af9d0e7d..3a0a2bc74e3 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -410,7 +410,6 @@ struct PyClassEnum<'a> { ident: &'a syn::Ident, // The underlying #[repr] of the enum, used to implement __int__ and __richcmp__. // This matters when the underlying representation may not fit in `isize`. - #[allow(unused, dead_code)] repr_type: syn::Ident, variants: Vec>, } @@ -522,7 +521,68 @@ fn impl_enum_class( } }; - let default_impls = gen_default_slot_impls(cls, vec![default_repr_impl]); + let repr_type = &enum_.repr_type; + + let default_int = { + // This implementation allows us to convert &T to #repr_type without implementing `Copy` + let variants_to_int = variants.iter().map(|variant| { + let variant_name = variant.ident; + quote! { #cls::#variant_name => #cls::#variant_name as #repr_type, } + }); + quote! { + #[doc(hidden)] + #[allow(non_snake_case)] + #[pyo3(name = "__int__")] + fn __pyo3__int__(&self) -> #repr_type { + match self { + #(#variants_to_int)* + } + } + } + }; + + let default_richcmp = { + let variants_eq = variants.iter().map(|variant| { + let variant_name = variant.ident; + quote! { + (#cls::#variant_name, #cls::#variant_name) => + Ok(true.to_object(py)), + } + }); + quote! { + #[doc(hidden)] + #[allow(non_snake_case)] + #[pyo3(name = "__richcmp__")] + fn __pyo3__richcmp__( + &self, + py: _pyo3::Python, + other: &_pyo3::PyAny, + op: _pyo3::basic::CompareOp + ) -> _pyo3::PyResult<_pyo3::PyObject> { + use _pyo3::conversion::ToPyObject; + use ::core::result::Result::*; + match op { + _pyo3::basic::CompareOp::Eq => { + if let Ok(i) = other.extract::<#repr_type>() { + let self_val = self.__pyo3__int__(); + return Ok((self_val == i).to_object(py)); + } + let other = other.extract::<_pyo3::PyRef>()?; + let other = &*other; + match (self, other) { + #(#variants_eq)* + _ => Ok(false.to_object(py)), + } + } + _ => Ok(py.NotImplemented()), + } + } + } + }; + + let default_impls = + gen_default_slot_impls(cls, vec![default_repr_impl, default_richcmp, default_int]); + Ok(quote! { const _: () = { use #krate as _pyo3; diff --git a/tests/test_enum.rs b/tests/test_enum.rs index ef8c36cf844..9f15bf23e5a 100644 --- a/tests/test_enum.rs +++ b/tests/test_enum.rs @@ -14,12 +14,11 @@ pub enum MyEnum { #[test] fn test_enum_class_attr() { - let gil = Python::acquire_gil(); - let py = gil.python(); - let my_enum = py.get_type::(); - py_assert!(py, my_enum, "getattr(my_enum, 'Variant', None) is not None"); - py_assert!(py, my_enum, "getattr(my_enum, 'foobar', None) is None"); - py_run!(py, my_enum, "my_enum.Variant = None"); + Python::with_gil(|py| { + let my_enum = py.get_type::(); + let var = Py::new(py, MyEnum::Variant).unwrap(); + py_assert!(py, my_enum var, "my_enum.Variant == var"); + }) } #[pyfunction] @@ -28,7 +27,6 @@ fn return_enum() -> MyEnum { } #[test] -#[ignore] // need to implement __eq__ fn test_return_enum() { let gil = Python::acquire_gil(); let py = gil.python(); @@ -44,14 +42,24 @@ fn enum_arg(e: MyEnum) { } #[test] -#[ignore] // need to implement __eq__ fn test_enum_arg() { - let gil = Python::acquire_gil(); - let py = gil.python(); - let f = wrap_pyfunction!(enum_arg)(py).unwrap(); - let mynum = py.get_type::(); + Python::with_gil(|py| { + let f = wrap_pyfunction!(enum_arg)(py).unwrap(); + let mynum = py.get_type::(); + + py_run!(py, f mynum, "f(mynum.OtherVariant)") + }) +} - py_run!(py, f mynum, "f(mynum.Variant)") +#[test] +fn test_enum_eq() { + Python::with_gil(|py| { + let var1 = Py::new(py, MyEnum::Variant).unwrap(); + let var2 = Py::new(py, MyEnum::Variant).unwrap(); + let other_var = Py::new(py, MyEnum::OtherVariant).unwrap(); + py_assert!(py, var1 var2, "var1 == var2"); + py_assert!(py, var1 other_var, "var1 != other_var"); + }) } #[test] @@ -85,6 +93,63 @@ fn test_custom_discriminant() { }) } +#[test] +fn test_enum_to_int() { + Python::with_gil(|py| { + let one = Py::new(py, CustomDiscriminant::One).unwrap(); + py_assert!(py, one, "int(one) == 1"); + let v = Py::new(py, MyEnum::Variant).unwrap(); + let v_value = MyEnum::Variant as isize; + py_run!(py, v v_value, "int(v) == v_value"); + }) +} + +#[test] +fn test_enum_compare_int() { + Python::with_gil(|py| { + let one = Py::new(py, CustomDiscriminant::One).unwrap(); + py_run!( + py, + one, + r#" + assert one == 1 + assert 1 == one + assert one != 2 + "# + ) + }) +} + +#[pyclass] +#[repr(u8)] +enum SmallEnum { + V = 1, +} + +#[test] +fn test_enum_compare_int_no_throw_when_overflow() { + Python::with_gil(|py| { + let v = Py::new(py, SmallEnum::V).unwrap(); + py_assert!(py, v, "v != 1<<30") + }) +} + +#[pyclass] +#[repr(usize)] +enum BigEnum { + V = usize::MAX, +} + +#[test] +fn test_big_enum_no_overflow() { + Python::with_gil(|py| { + let usize_max = usize::MAX; + let v = Py::new(py, BigEnum::V).unwrap(); + py_assert!(py, usize_max v, "v == usize_max"); + py_assert!(py, usize_max v, "int(v) == usize_max"); + }) +} + #[pyclass] #[repr(u16, align(8))] enum TestReprParse {