diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index cc643ae57f..fb38bb3e4a 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -1572,16 +1572,34 @@ impl SslContextBuilder { /// /// This can be used to provide data to callbacks registered with the context. Use the /// `SslContext::new_ex_index` method to create an `Index`. + // FIXME should return a result #[corresponds(SSL_CTX_set_ex_data)] pub fn set_ex_data(&mut self, index: Index, data: T) { self.set_ex_data_inner(index, data); } fn set_ex_data_inner(&mut self, index: Index, data: T) -> *mut c_void { + match self.ex_data_mut(index) { + Some(v) => { + *v = data; + (v as *mut T).cast() + } + _ => unsafe { + let data = Box::into_raw(Box::new(data)) as *mut c_void; + ffi::SSL_CTX_set_ex_data(self.as_ptr(), index.as_raw(), data); + data + }, + } + } + + fn ex_data_mut(&mut self, index: Index) -> Option<&mut T> { unsafe { - let data = Box::into_raw(Box::new(data)) as *mut c_void; - ffi::SSL_CTX_set_ex_data(self.as_ptr(), index.as_raw(), data); - data + let data = ffi::SSL_CTX_get_ex_data(self.as_ptr(), index.as_raw()); + if data.is_null() { + None + } else { + Some(&mut *data.cast()) + } } } @@ -2965,15 +2983,19 @@ impl SslRef { /// /// This can be used to provide data to callbacks registered with the context. Use the /// `Ssl::new_ex_index` method to create an `Index`. + // FIXME should return a result #[corresponds(SSL_set_ex_data)] pub fn set_ex_data(&mut self, index: Index, data: T) { - unsafe { - let data = Box::new(data); - ffi::SSL_set_ex_data( - self.as_ptr(), - index.as_raw(), - Box::into_raw(data) as *mut c_void, - ); + match self.ex_data_mut(index) { + Some(v) => *v = data, + None => unsafe { + let data = Box::new(data); + ffi::SSL_set_ex_data( + self.as_ptr(), + index.as_raw(), + Box::into_raw(data) as *mut c_void, + ); + }, } } diff --git a/openssl/src/ssl/test/mod.rs b/openssl/src/ssl/test/mod.rs index 1fc9ba6b48..412c4a5dc6 100644 --- a/openssl/src/ssl/test/mod.rs +++ b/openssl/src/ssl/test/mod.rs @@ -10,7 +10,7 @@ use std::net::UdpSocket; use std::net::{SocketAddr, TcpListener, TcpStream}; use std::path::Path; use std::process::{Child, ChildStdin, Command, Stdio}; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::thread; use std::time::Duration; @@ -1638,3 +1638,50 @@ fn set_security_level() { let ssl = ssl; assert_eq!(4, ssl.security_level()); } + +#[test] +fn ssl_ctx_ex_data_leak() { + static DROPS: AtomicUsize = AtomicUsize::new(0); + + struct DropTest; + + impl Drop for DropTest { + fn drop(&mut self) { + DROPS.fetch_add(1, Ordering::Relaxed); + } + } + + let idx = SslContext::new_ex_index().unwrap(); + + let mut ctx = SslContext::builder(SslMethod::tls()).unwrap(); + ctx.set_ex_data(idx, DropTest); + ctx.set_ex_data(idx, DropTest); + assert_eq!(DROPS.load(Ordering::Relaxed), 1); + + drop(ctx); + assert_eq!(DROPS.load(Ordering::Relaxed), 2); +} + +#[test] +fn ssl_ex_data_leak() { + static DROPS: AtomicUsize = AtomicUsize::new(0); + + struct DropTest; + + impl Drop for DropTest { + fn drop(&mut self) { + DROPS.fetch_add(1, Ordering::Relaxed); + } + } + + let idx = Ssl::new_ex_index().unwrap(); + + let ctx = SslContext::builder(SslMethod::tls()).unwrap().build(); + let mut ssl = Ssl::new(&ctx).unwrap(); + ssl.set_ex_data(idx, DropTest); + ssl.set_ex_data(idx, DropTest); + assert_eq!(DROPS.load(Ordering::Relaxed), 1); + + drop(ssl); + assert_eq!(DROPS.load(Ordering::Relaxed), 2); +}