diff --git a/packages/kernel/src/sync/mutex.rs b/packages/kernel/src/sync/mutex.rs
index bc609a3..8f92aac 100644
--- a/packages/kernel/src/sync/mutex.rs
+++ b/packages/kernel/src/sync/mutex.rs
@@ -1,23 +1,65 @@
-use core::sync::atomic::{AtomicU8, Ordering};
+use core::arch::asm;
+use critical_section::RestoreState;
+use std::{cell::SyncUnsafeCell, hint::spin_loop, sync::atomic::{AtomicBool, AtomicU32, Ordering}};
 
 pub use lock_api::MutexGuard;
 
 pub type Mutex<T> = lock_api::Mutex<RawMutex, T>;
 
-struct MutexState(AtomicU8);
+/// Returns true if the CPU is currently servicing an interrupt, i.e. the
+/// CPSR mode bits are anything other than System mode (0b11111).
+fn in_interrupt() -> bool {
+    unsafe {
+        let cpsr: u32;
+        asm!("mrs {0}, cpsr", out(reg) cpsr);
+        let is_system_mode = (cpsr & 0b11111) == 0b11111;
+        !is_system_mode
+    }
+}
+
+/// Number of mutexes currently locked anywhere in the kernel.
+static MUTEX_COUNTER: AtomicU32 = AtomicU32::new(0);
+/// Safety: this variable must only be accessed inside a critical section.
+static GLOBAL_MUTEX_STATE: SyncUnsafeCell<RestoreState> = SyncUnsafeCell::new(RestoreState::invalid());
+
+struct MutexState(AtomicBool);
 impl MutexState {
     const fn new() -> Self {
-        Self(AtomicU8::new(0))
+        Self(AtomicBool::new(false))
     }
 
     /// Returns true if the lock was acquired.
     fn try_lock(&self) -> bool {
-        self.0
-            .compare_exchange(0, 1, Ordering::Acquire, Ordering::Acquire)
-            .is_ok()
+        unsafe {
+            let state = critical_section::acquire();
+            if self.0.compare_exchange(false, true, Ordering::Acquire, Ordering::Acquire).is_ok() {
+                if MUTEX_COUNTER.fetch_add(1, Ordering::AcqRel) == 0 {
+                    // We're the first mutex to be locked, so save the RestoreState
+                    // to be released once the last mutex is unlocked.
+                    *GLOBAL_MUTEX_STATE.get() = state;
+                } else {
+                    // Another mutex already entered the outermost critical section,
+                    // so release the nested critical section we just entered.
+                    critical_section::release(state);
+                }
+                true
+            } else {
+                // The mutex is already locked, so release the critical section we acquired.
+                critical_section::release(state);
+                false
+            }
+        }
     }
 
-    fn unlock(&self) {
-        self.0.store(0, Ordering::Release);
+    /// Safety: this function must only be called after a successful call to `try_lock`.
+    unsafe fn unlock(&self) {
+        unsafe {
+            if MUTEX_COUNTER.fetch_sub(1, Ordering::AcqRel) == 1 {
+                // We are the last mutex to be unlocked, so exit the critical
+                // section saved by the first successful `try_lock`.
+                critical_section::release(*GLOBAL_MUTEX_STATE.get());
+            }
+            self.0.store(false, Ordering::Release);
+        }
     }
 }
 
@@ -42,20 +84,24 @@ unsafe impl lock_api::RawMutex for RawMutex {
     type GuardMarker = lock_api::GuardSend;
 
     fn lock(&self) {
-        critical_section::with(|_| {
+        if self.state.try_lock() {
+            // Fast path: the lock was acquired on the first attempt.
+        } else if in_interrupt() {
+            // Spinning inside an interrupt handler can never succeed: the
+            // interrupted code holding the lock cannot run to release it.
+            panic!("Deadlock in kernel detected!");
+        } else {
             while !self.state.try_lock() {
-                core::hint::spin_loop();
+                spin_loop();
             }
-        })
+        }
     }
 
     fn try_lock(&self) -> bool {
-        critical_section::with(|_| self.state.try_lock())
+        self.state.try_lock()
    }
 
     unsafe fn unlock(&self) {
-        critical_section::with(|_| {
-            self.state.unlock();
-        })
+        unsafe { self.state.unlock() }
     }
 }
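
For context on how the counting scheme behaves with nested locks, here is a minimal usage sketch, assuming a `kernel::sync::Mutex` re-export (the module path, statics, and function names are illustrative, not part of this diff): interrupts stay masked from the first `lock()` until the last guard drops.

```rust
// Usage sketch (not part of the diff): `Mutex` is the alias defined above;
// the statics and function are hypothetical.
use kernel::sync::Mutex;

static TICKS: Mutex<u64> = Mutex::new(0);
static LOG_LEN: Mutex<usize> = Mutex::new(0);

fn record_tick() {
    // First lock: enters the critical section (interrupts masked), bumps
    // MUTEX_COUNTER to 1, and stashes the RestoreState in GLOBAL_MUTEX_STATE.
    let mut ticks = TICKS.lock();
    // Nested lock: MUTEX_COUNTER goes to 2 and the inner critical section is
    // released immediately, so interrupts are only masked once.
    let mut len = LOG_LEN.lock();
    *ticks += 1;
    *len += 1;
    // Guards drop in reverse order: `len` brings the counter back to 1,
    // `ticks` brings it to 0 and restores the saved interrupt state.
}
```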