diff --git a/zenoh/src/api/session.rs b/zenoh/src/api/session.rs index fafa3cb460..2bf9c85851 100644 --- a/zenoh/src/api/session.rs +++ b/zenoh/src/api/session.rs @@ -11,6 +11,8 @@ // Contributors: // ZettaScale Zenoh Team, // +#[cfg(feature = "test")] +use std::ops::DerefMut; use std::{ collections::{hash_map::Entry, HashMap}, convert::TryInto, @@ -525,10 +527,49 @@ pub trait Undeclarable: UndeclarableSealed {} impl Undeclarable for T where T: UndeclarableSealed {} +#[cfg(feature = "test")] +pub(crate) struct InvalidatableSessionState { + inner: Option, +} + +#[cfg(feature = "test")] +impl InvalidatableSessionState { + fn new(state: SessionState) -> Self { + Self { inner: Some(state) } + } + + fn invalidate(&mut self) { + let _ = self.inner.take(); + } +} + +#[cfg(feature = "test")] +impl DerefMut for InvalidatableSessionState { + fn deref_mut(&mut self) -> &mut Self::Target { + self.inner + .as_mut() + .expect("referring to invalidated (closed) session state") + } +} + +#[cfg(feature = "test")] +impl Deref for InvalidatableSessionState { + type Target = SessionState; + + fn deref(&self) -> &Self::Target { + self.inner + .as_ref() + .expect("referring to invalidated (closed) session state") + } +} + pub(crate) struct SessionInner { /// See [`WeakSession`] doc weak_counter: Mutex, pub(crate) runtime: Runtime, + #[cfg(feature = "test")] + pub(crate) state: RwLock, + #[cfg(not(feature = "test"))] pub(crate) state: RwLock, pub(crate) id: u16, owns_runtime: bool, @@ -686,11 +727,15 @@ impl Session { let publisher_qos = config.0.qos().publication().clone(); let namespace = config.0.namespace().clone(); drop(config); - let state = RwLock::new(SessionState::new( + let state = SessionState::new( aggregated_subscribers, aggregated_publishers, publisher_qos.into(), - )); + ); + #[cfg(feature = "test")] + let state = InvalidatableSessionState::new(state); + let state = RwLock::new(state); + let session = Session(Arc::new(SessionInner { weak_counter: Mutex::new(0), runtime: runtime.clone(), @@ -3185,27 +3230,33 @@ impl Closee for Arc { primitives.send_close(); } - // defer the cleanup of internal data structures by taking them out of the locked state - // this is needed because callbacks may contain entities which need to acquire the - // lock to be dropped, so callback must be dropped without the lock held - let mut state = zwrite!(self.state); - let _queryables = std::mem::take(&mut state.queryables); - let _subscribers = std::mem::take(&mut state.subscribers); - let _liveliness_subscribers = std::mem::take(&mut state.liveliness_subscribers); - let _local_resources = std::mem::take(&mut state.local_resources); - let _remote_resources = std::mem::take(&mut state.remote_resources); - drop(state); - #[cfg(feature = "unstable")] { - // the lock from the outer scope cannot be reused because the declared variables - // would be undeclared at the end of the block, with the lock held, and we want - // to avoid that; so we reacquire the lock in the block - // anyway, it doesn't really matter, and this code will be cleaned up when the APIs - // will be stabilized. + // defer the cleanup of internal data structures by taking them out of the locked state + // this is needed because callbacks may contain entities which need to acquire the + // lock to be dropped, so callback must be dropped without the lock held let mut state = zwrite!(self.state); - let _matching_listeners = std::mem::take(&mut state.matching_listeners); + let _queryables = std::mem::take(&mut state.queryables); + let _subscribers = std::mem::take(&mut state.subscribers); + let _liveliness_subscribers = std::mem::take(&mut state.liveliness_subscribers); + let _local_resources = std::mem::take(&mut state.local_resources); + let _remote_resources = std::mem::take(&mut state.remote_resources); + let _queries = std::mem::take(&mut state.queries); drop(state); + #[cfg(feature = "unstable")] + { + // the lock from the outer scope cannot be reused because the declared variables + // would be undeclared at the end of the block, with the lock held, and we want + // to avoid that; so we reacquire the lock in the block + // anyway, it doesn't really matter, and this code will be cleaned up when the APIs + // will be stabilized. + let mut state = zwrite!(self.state); + let _queriers = std::mem::take(&mut state.queriers); + let _matching_listeners = std::mem::take(&mut state.matching_listeners); + drop(state); + } } + #[cfg(feature = "test")] + zwrite!(self.state).invalidate(); } }