Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assume commutation when deriving PostgresEq #1261

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions pgrx-examples/custom_types/src/hexint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,4 +208,41 @@ mod tests {
assert_eq!(value, Some("0x2a".into()));
Ok(())
}

// Creates two tables that carry `hexint` columns, inserts one row of default
// values into each, and runs an equality join on their `id` columns.
// The test fails (via the `unwrap` panic) if `Spi::run` reports an error for
// any of the statements.
#[pg_test]
fn test_hash() {
    let setup_and_join = "CREATE TABLE hexintext (
id hexint,
data TEXT
);
CREATE TABLE hexint_duo (
id hexint,
foo_id hexint
);
INSERT INTO hexintext DEFAULT VALUES;
INSERT INTO hexint_duo DEFAULT VALUES;
SELECT *
FROM hexint_duo
JOIN hexintext ON hexint_duo.id = hexintext.id;";

    let outcome = Spi::run(setup_and_join);
    outcome.unwrap();
}

// Creates one table with a `hexint` key plus a text column and another with
// only a `hexint` column, then joins them with `=` on `id`.
// NOTE(review): presumably this guards the commutator declared for the derived
// `=` operator — confirm against the planner behaviour it was written for.
// The test fails (via the `unwrap` panic) if `Spi::run` reports an error.
#[pg_test]
fn test_commutator() {
    let setup_and_join = "CREATE TABLE hexintext (
id hexint,
data TEXT
);
CREATE TABLE just_hexint (
id hexint
);
SELECT *
FROM just_hexint
JOIN hexintext ON just_hexint.id = hexintext.id;";

    let outcome = Spi::run(setup_and_join);
    outcome.unwrap();
}
}
24 changes: 18 additions & 6 deletions pgrx-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use quote::{quote, ToTokens};
use syn::spanned::Spanned;
use syn::{parse_macro_input, Attribute, Data, DeriveInput, Item, ItemImpl};

use operators::{impl_postgres_eq, impl_postgres_hash, impl_postgres_ord};
use operators::{deriving_postgres_eq, deriving_postgres_hash, deriving_postgres_ord};
use pgrx_sql_entity_graph::{
parse_extern_attributes, CodeEnrichment, ExtensionSql, ExtensionSqlFile, ExternArgs,
PgAggregate, PgExtern, PostgresEnum, PostgresType, Schema,
Expand Down Expand Up @@ -976,11 +976,15 @@ enum DogNames {
Optionally accepts the following attributes:

* `sql`: Same arguments as [`#[pgrx(sql = ..)]`](macro@pgrx).

# No bounds?
Unlike some derives, this does not implement a "real" Rust trait, thus
PostgresEq cannot be used in trait bounds, nor can it be manually implemented.
*/
#[proc_macro_derive(PostgresEq, attributes(pgrx))]
pub fn postgres_eq(input: TokenStream) -> TokenStream {
pub fn derive_postgres_eq(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);
impl_postgres_eq(ast).unwrap_or_else(syn::Error::into_compile_error).into()
deriving_postgres_eq(ast).unwrap_or_else(syn::Error::into_compile_error).into()
}

/**
Expand All @@ -1002,11 +1006,15 @@ enum DogNames {
Optionally accepts the following attributes:

* `sql`: Same arguments as [`#[pgrx(sql = ..)]`](macro@pgrx).

# No bounds?
Unlike some derives, this does not implement a "real" Rust trait, thus
PostgresOrd cannot be used in trait bounds, nor can it be manually implemented.
*/
#[proc_macro_derive(PostgresOrd, attributes(pgrx))]
pub fn postgres_ord(input: TokenStream) -> TokenStream {
pub fn derive_postgres_ord(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);
impl_postgres_ord(ast).unwrap_or_else(syn::Error::into_compile_error).into()
deriving_postgres_ord(ast).unwrap_or_else(syn::Error::into_compile_error).into()
}

/**
Expand All @@ -1025,11 +1033,15 @@ enum DogNames {
Optionally accepts the following attributes:

* `sql`: Same arguments as [`#[pgrx(sql = ..)]`](macro@pgrx).

# No bounds?
Unlike some derives, this does not implement a "real" Rust trait, thus
PostgresHash cannot be used in trait bounds, nor can it be manually implemented.
*/
#[proc_macro_derive(PostgresHash, attributes(pgrx))]
pub fn postgres_hash(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as syn::DeriveInput);
impl_postgres_hash(ast).unwrap_or_else(syn::Error::into_compile_error).into()
deriving_postgres_hash(ast).unwrap_or_else(syn::Error::into_compile_error).into()
}

/**
Expand Down
125 changes: 84 additions & 41 deletions pgrx-macros/src/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,88 +14,131 @@ use proc_macro2::Ident;
use quote::{quote, ToTokens};
use syn::DeriveInput;

fn ident_and_type_path(ast: &DeriveInput) -> (&Ident, proc_macro2::TokenStream) {
fn ident_and_path(ast: &DeriveInput) -> (&Ident, proc_macro2::TokenStream) {
let ident = &ast.ident;
let args = parse_postgres_type_args(&ast.attrs);
let type_path = if args.contains(&PostgresTypeAttribute::PgVarlenaInOutFuncs) {
let path = if args.contains(&PostgresTypeAttribute::PgVarlenaInOutFuncs) {
quote! { ::pgrx::datum::PgVarlena<#ident> }
} else {
quote! { #ident }
};
(ident, type_path)
(ident, path)
}

pub(crate) fn impl_postgres_eq(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
pub(crate) fn deriving_postgres_eq(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let mut stream = proc_macro2::TokenStream::new();
let (ident, type_path) = ident_and_type_path(&ast);
stream.extend(eq(ident, &type_path));
stream.extend(ne(ident, &type_path));
let (ident, path) = ident_and_path(&ast);
stream.extend(derive_pg_eq(ident, &path));
stream.extend(derive_pg_ne(ident, &path));

Ok(stream)
}

pub(crate) fn impl_postgres_ord(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
pub(crate) fn deriving_postgres_ord(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let mut stream = proc_macro2::TokenStream::new();
let (ident, type_path) = ident_and_type_path(&ast);
let (ident, path) = ident_and_path(&ast);

stream.extend(lt(ident, &type_path));
stream.extend(gt(ident, &type_path));
stream.extend(le(ident, &type_path));
stream.extend(ge(ident, &type_path));
stream.extend(cmp(ident, &type_path));
stream.extend(derive_pg_lt(ident, &path));
stream.extend(derive_pg_gt(ident, &path));
stream.extend(derive_pg_le(ident, &path));
stream.extend(derive_pg_ge(ident, &path));
stream.extend(derive_pg_cmp(ident, &path));

let sql_graph_entity_item = PostgresOrd::from_derive_input(ast)?;
sql_graph_entity_item.to_tokens(&mut stream);

Ok(stream)
}

pub(crate) fn impl_postgres_hash(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
pub(crate) fn deriving_postgres_hash(ast: DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let mut stream = proc_macro2::TokenStream::new();
let (ident, type_path) = ident_and_type_path(&ast);
let (ident, path) = ident_and_path(&ast);

stream.extend(hash(ident, &type_path));
stream.extend(derive_pg_hash(ident, &path));

let sql_graph_entity_item = PostgresHash::from_derive_input(ast)?;
sql_graph_entity_item.to_tokens(&mut stream);

Ok(stream)
}

pub fn eq(type_name: &Ident, type_path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let pg_name = Ident::new(&format!("{}_eq", type_name).to_lowercase(), type_name.span());
/// Derive a Postgres `=` operator from Rust `==`
///
/// Note this expansion applies a number of assumptions that may not be true:
/// - PartialEq::eq is referentially transparent (immutable and parallel-safe)
/// - PartialEq::ne must reverse PartialEq::eq (negator)
/// - PartialEq::eq is commutative
///
/// Postgres swears that these are just ["optimization hints"], and they can be
/// defined to use regular SQL or PL/pgSQL functions with spurious results.
///
/// However, it is entirely plausible these assumptions actually are venomous.
/// It is deeply unlikely that we can audit the millions of lines of C code in
/// Postgres to confirm that it avoids using these assumptions in a way that
/// would lead to UB or unacceptable behavior from PGRX if Eq is incorrectly
/// implemented, and we have no realistic means of guaranteeing this.
///
/// Further, Postgres adds a disclaimer to these "optimization hints":
///
/// ```text
/// But if you provide them, you must be sure that they are right!
/// Incorrect use of an optimization clause can result in
/// slow queries, subtly wrong output, or other Bad Things.
/// ```
///
/// In practice, most Eq impls are in fact correct, referentially transparent,
/// and commutative. So this note could be for nothing. This signpost is left
/// in order to guide anyone unfortunate enough to be debugging an issue that
/// finally leads them here.
///
/// ["optimization hints"]: https://www.postgresql.org/docs/current/xoper-optimization.html
pub fn derive_pg_eq(name: &Ident, path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let pg_name = Ident::new(&format!("{}_eq", name).to_lowercase(), name.span());
quote! {
#[doc(hidden)]
impl ::pgrx::deriving::PostgresEqRequiresTotalEq for #name {}

#[allow(non_snake_case)]
#[::pgrx::pgrx_macros::pg_operator(immutable, parallel_safe)]
#[::pgrx::pgrx_macros::opname(=)]
#[::pgrx::pgrx_macros::commutator(=)]
#[::pgrx::pgrx_macros::negator(<>)]
#[::pgrx::pgrx_macros::restrict(eqsel)]
#[::pgrx::pgrx_macros::join(eqjoinsel)]
#[::pgrx::pgrx_macros::merges]
#[::pgrx::pgrx_macros::hashes]
fn #pg_name(left: #type_path, right: #type_path) -> bool {
fn #pg_name(left: #path, right: #path) -> bool {
left == right
}
}
}

pub fn ne(type_name: &Ident, type_path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let pg_name = Ident::new(&format!("{}_ne", type_name).to_lowercase(), type_name.span());
/// Derive a Postgres `<>` operator from Rust `!=`
///
/// Note that this expansion applies a number of assumptions that aren't necessarily true:
/// - PartialEq::ne is referentially transparent (immutable and parallel-safe)
/// - PartialEq::eq must reverse PartialEq::ne (negator)
/// - PartialEq::ne is commutative
///
/// See `derive_pg_eq` for the implications of these assumptions.
pub fn derive_pg_ne(name: &Ident, path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let pg_name = Ident::new(&format!("{}_ne", name).to_lowercase(), name.span());
quote! {
#[allow(non_snake_case)]
#[::pgrx::pgrx_macros::pg_operator(immutable, parallel_safe)]
#[::pgrx::pgrx_macros::opname(<>)]
#[::pgrx::pgrx_macros::commutator(<>)]
#[::pgrx::pgrx_macros::negator(=)]
#[::pgrx::pgrx_macros::restrict(neqsel)]
#[::pgrx::pgrx_macros::join(neqjoinsel)]
fn #pg_name(left: #type_path, right: #type_path) -> bool {
fn #pg_name(left: #path, right: #path) -> bool {
left != right
}
}
}

pub fn lt(type_name: &Ident, type_path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let pg_name = Ident::new(&format!("{}_lt", type_name).to_lowercase(), type_name.span());
pub fn derive_pg_lt(name: &Ident, path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let pg_name = Ident::new(&format!("{}_lt", name).to_lowercase(), name.span());
quote! {
#[allow(non_snake_case)]
#[::pgrx::pgrx_macros::pg_operator(immutable, parallel_safe)]
Expand All @@ -104,15 +147,15 @@ pub fn lt(type_name: &Ident, type_path: &proc_macro2::TokenStream) -> proc_macro
#[::pgrx::pgrx_macros::commutator(>)]
#[::pgrx::pgrx_macros::restrict(scalarltsel)]
#[::pgrx::pgrx_macros::join(scalarltjoinsel)]
fn #pg_name(left: #type_path, right: #type_path) -> bool {
fn #pg_name(left: #path, right: #path) -> bool {
left < right
}

}
}

pub fn gt(type_name: &Ident, type_path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let pg_name = Ident::new(&format!("{}_gt", type_name).to_lowercase(), type_name.span());
pub fn derive_pg_gt(name: &Ident, path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let pg_name = Ident::new(&format!("{}_gt", name).to_lowercase(), name.span());
quote! {
#[allow(non_snake_case)]
#[::pgrx::pgrx_macros::pg_operator(immutable, parallel_safe)]
Expand All @@ -121,14 +164,14 @@ pub fn gt(type_name: &Ident, type_path: &proc_macro2::TokenStream) -> proc_macro
#[::pgrx::pgrx_macros::commutator(<)]
#[::pgrx::pgrx_macros::restrict(scalargtsel)]
#[::pgrx::pgrx_macros::join(scalargtjoinsel)]
fn #pg_name(left: #type_path, right: #type_path) -> bool {
fn #pg_name(left: #path, right: #path) -> bool {
left > right
}
}
}

pub fn le(type_name: &Ident, type_path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let pg_name = Ident::new(&format!("{}_le", type_name).to_lowercase(), type_name.span());
pub fn derive_pg_le(name: &Ident, path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let pg_name = Ident::new(&format!("{}_le", name).to_lowercase(), name.span());
quote! {
#[allow(non_snake_case)]
#[::pgrx::pgrx_macros::pg_operator(immutable, parallel_safe)]
Expand All @@ -137,14 +180,14 @@ pub fn le(type_name: &Ident, type_path: &proc_macro2::TokenStream) -> proc_macro
#[::pgrx::pgrx_macros::commutator(>=)]
#[::pgrx::pgrx_macros::restrict(scalarlesel)]
#[::pgrx::pgrx_macros::join(scalarlejoinsel)]
fn #pg_name(left: #type_path, right: #type_path) -> bool {
fn #pg_name(left: #path, right: #path) -> bool {
left <= right
}
}
}

pub fn ge(type_name: &Ident, type_path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let pg_name = Ident::new(&format!("{}_ge", type_name).to_lowercase(), type_name.span());
pub fn derive_pg_ge(name: &Ident, path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let pg_name = Ident::new(&format!("{}_ge", name).to_lowercase(), name.span());
quote! {
#[allow(non_snake_case)]
#[::pgrx::pgrx_macros::pg_operator(immutable, parallel_safe)]
Expand All @@ -153,29 +196,29 @@ pub fn ge(type_name: &Ident, type_path: &proc_macro2::TokenStream) -> proc_macro
#[::pgrx::pgrx_macros::commutator(<=)]
#[::pgrx::pgrx_macros::restrict(scalargesel)]
#[::pgrx::pgrx_macros::join(scalargejoinsel)]
fn #pg_name(left: #type_path, right: #type_path) -> bool {
fn #pg_name(left: #path, right: #path) -> bool {
left >= right
}
}
}

pub fn cmp(type_name: &Ident, type_path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let pg_name = Ident::new(&format!("{}_cmp", type_name).to_lowercase(), type_name.span());
pub fn derive_pg_cmp(name: &Ident, path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let pg_name = Ident::new(&format!("{}_cmp", name).to_lowercase(), name.span());
quote! {
#[allow(non_snake_case)]
#[::pgrx::pgrx_macros::pg_extern(immutable, parallel_safe)]
fn #pg_name(left: #type_path, right: #type_path) -> i32 {
fn #pg_name(left: #path, right: #path) -> i32 {
left.cmp(&right) as i32
}
}
}

pub fn hash(type_name: &Ident, type_path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let pg_name = Ident::new(&format!("{}_hash", type_name).to_lowercase(), type_name.span());
pub fn derive_pg_hash(name: &Ident, path: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
let pg_name = Ident::new(&format!("{}_hash", name).to_lowercase(), name.span());
quote! {
#[allow(non_snake_case)]
#[::pgrx::pgrx_macros::pg_extern(immutable, parallel_safe)]
fn #pg_name(value: #type_path) -> i32 {
fn #pg_name(value: #path) -> i32 {
::pgrx::misc::pgrx_seahash(&value) as i32
}
}
Expand Down
8 changes: 8 additions & 0 deletions pgrx-tests/tests/ui/total_eq_for_postgres_eq.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use pgrx::prelude::*;

#[derive(PartialEq, PostgresEq)]
struct BrokenType {
int: i32,
}

// Compile-fail fixture: `BrokenType` derives `PartialEq` but not `Eq`, so
// `#[derive(PostgresEq)]` is expected to be rejected with the E0277
// "`BrokenType: std::cmp::Eq` is not satisfied" error recorded in
// `total_eq_for_postgres_eq.stderr`.
// NOTE(review): that snapshot pins the struct to line 4, column 8 of this file
// and quotes the line verbatim, so keep commentary below the struct definition.
fn main() {}
16 changes: 16 additions & 0 deletions pgrx-tests/tests/ui/total_eq_for_postgres_eq.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
error[E0277]: the trait bound `BrokenType: std::cmp::Eq` is not satisfied
--> tests/ui/total_eq_for_postgres_eq.rs:4:8
|
4 | struct BrokenType {
| ^^^^^^^^^^ the trait `std::cmp::Eq` is not implemented for `BrokenType`
|
note: required by a bound in `PostgresEqRequiresTotalEq`
--> $WORKSPACE/pgrx/src/deriving.rs
|
| pub trait PostgresEqRequiresTotalEq: Eq {}
| ^^ required by this bound in `PostgresEqRequiresTotalEq`
help: consider annotating `BrokenType` with `#[derive(Eq)]`
|
4 + #[derive(Eq)]
5 | struct BrokenType {
|
3 changes: 3 additions & 0 deletions pgrx/src/deriving.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#![doc(hidden)]

/// Marker trait that the `PostgresEq` derive implements for the deriving type.
///
/// It has no items; the `Eq` supertrait is its entire purpose. The generated
/// `impl PostgresEqRequiresTotalEq for T {}` only compiles when `T: Eq`, so a
/// type that is merely `PartialEq` is rejected at compile time with a
/// "trait bound `T: Eq` is not satisfied" error.
// NOTE(review): the UI test snapshot quotes the declaration below verbatim;
// changing its spelling requires regenerating that snapshot.
pub trait PostgresEqRequiresTotalEq: Eq {}
1 change: 1 addition & 0 deletions pgrx/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub mod atomics;
pub mod bgworkers;
pub mod callbacks;
pub mod datum;
pub mod deriving;
pub mod enum_helper;
pub mod fcinfo;
pub mod ffi;
Expand Down