diff --git a/ristretto_classloader/src/class.rs b/ristretto_classloader/src/class.rs index e2d18f14..8c65c85e 100644 --- a/ristretto_classloader/src/class.rs +++ b/ristretto_classloader/src/class.rs @@ -471,6 +471,11 @@ impl Class { impl PartialEq for Class { fn eq(&self, other: &Self) -> bool { + // Optimization for the case where the two classes are the same instance. + if std::ptr::eq(self, other) { + return true; + } + self.name == other.name && self.class_file == other.class_file && *self.parent.read().expect("parent") == *other.parent.read().expect("parent") diff --git a/ristretto_classloader/src/object.rs b/ristretto_classloader/src/object.rs index eaf3ec15..fb6ef046 100644 --- a/ristretto_classloader/src/object.rs +++ b/ristretto_classloader/src/object.rs @@ -9,7 +9,7 @@ use std::sync::Arc; const JAVA_8: Version = Version::Java8 { minor: 0 }; /// Represents an object in the Ristretto VM. -#[derive(Clone, PartialEq)] +#[derive(Clone)] pub struct Object { class: Arc, fields: HashMap, @@ -186,6 +186,17 @@ impl Display for Object { } } +impl PartialEq for Object { + fn eq(&self, other: &Self) -> bool { + // Compare the references by pointer to determine if they are the same object in + // order to avoid infinite recursion + if std::ptr::eq(self, other) { + return true; + } + self.class == other.class && self.fields == other.fields + } +} + impl TryInto for Object { type Error = crate::Error; diff --git a/ristretto_classloader/src/reference.rs b/ristretto_classloader/src/reference.rs index 22e35fb8..376df13f 100644 --- a/ristretto_classloader/src/reference.rs +++ b/ristretto_classloader/src/reference.rs @@ -7,7 +7,7 @@ use std::fmt::Display; use std::sync::Arc; /// Represents a reference to an object in the Ristretto VM. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub enum Reference { ByteArray(ConcurrentVec), CharArray(ConcurrentVec), @@ -180,33 +180,6 @@ impl Display for Reference { } } -impl PartialEq for Reference { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Reference::ByteArray(a), Reference::ByteArray(b)) => a == b, - (Reference::CharArray(a), Reference::CharArray(b)) => a == b, - (Reference::ShortArray(a), Reference::ShortArray(b)) => a == b, - (Reference::IntArray(a), Reference::IntArray(b)) => a == b, - (Reference::LongArray(a), Reference::LongArray(b)) => a == b, - (Reference::FloatArray(a), Reference::FloatArray(b)) => a == b, - (Reference::DoubleArray(a), Reference::DoubleArray(b)) => a == b, - (Reference::Array(a_class, a_array), Reference::Array(b_class, b_array)) => { - a_class.name() == b_class.name() && a_array == b_array - } - (Reference::Object(a), Reference::Object(b)) => { - // Compare the references by pointer to determine if they are the same object in - // order to avoid infinite recursion - if std::ptr::eq(a, b) { - true - } else { - a == b - } - } - _ => false, - } - } -} - impl From> for Reference { fn from(value: Vec) -> Self { let value: Vec = value.into_iter().map(i8::from).collect();