Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions libkernel/src/sync/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ pub mod mpsc;
pub mod mutex;
pub mod once_lock;
pub mod per_cpu;
pub mod rwlock;
pub mod spinlock;
pub mod waker_set;
44 changes: 44 additions & 0 deletions libkernel/src/sync/mutex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,47 @@ impl<T: ?Sized, CPU: CpuOps> DerefMut for AsyncMutexGuard<'_, T, CPU> {

unsafe impl<T: ?Sized + Send, CPU: CpuOps> Send for Mutex<T, CPU> {}
unsafe impl<T: ?Sized + Send, CPU: CpuOps> Sync for Mutex<T, CPU> {}

impl<CPU: CpuOps> Mutex<(), CPU> {
/// Acquires the mutex lock without caring about the data.
pub(crate) fn acquire(&self) -> MutexAcquireFuture<'_, CPU> {
MutexAcquireFuture { mutex: self }
}

/// Releases the mutex lock without caring about the data.
///
/// # Safety
/// The caller must ensure that they have previously called [`Self::acquire()`].
pub(crate) unsafe fn release(&self) {
let mut state = self.state.lock_save_irq();

if let Some(next_waker) = state.waiters.pop_front() {
next_waker.wake();
}

state.is_locked = false;
}
}

/// A future that resolves to a locked mutex
pub struct MutexAcquireFuture<'a, CPU: CpuOps> {
mutex: &'a Mutex<(), CPU>,
}

impl<CPU: CpuOps> Future for MutexAcquireFuture<'_, CPU> {
type Output = ();

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut state = self.mutex.state.lock_save_irq();

if !state.is_locked {
state.is_locked = true;
Poll::Ready(())
} else {
if state.waiters.iter().all(|w| !w.will_wake(cx.waker())) {
state.waiters.push_back(cx.waker().clone());
}
Poll::Pending
}
}
}
128 changes: 128 additions & 0 deletions libkernel/src/sync/rwlock.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
use super::spinlock::SpinLockIrq;
use crate::CpuOps;
use crate::sync::mutex::Mutex;
use core::cell::UnsafeCell;
use core::ops::{Deref, DerefMut};

struct RwlockState<CPU: CpuOps> {
num_readers: SpinLockIrq<usize, CPU>,
writer_lock: Mutex<(), CPU>,
}

/// An asynchronous, rwlock primitive.
///
/// This rwlock can be used to protect shared data across asynchronous tasks.
/// `lock()` returns a future that resolves to a guard. When the guard is
/// dropped, the lock is released.
pub struct Rwlock<T: ?Sized, CPU: CpuOps> {
state: RwlockState<CPU>,
data: UnsafeCell<T>,
}

/// A guard that provides read-only access to the data in an `AsyncRwlock`.
///
/// When an `AsyncRwlockReadGuard` is dropped, it automatically decreases the
/// read count and wakes up the next task if necessary.
#[must_use = "if unused, the Rwlock will immediately unlock"]
pub struct AsyncRwlockReadGuard<'a, T: ?Sized, CPU: CpuOps> {
rwlock: &'a Rwlock<T, CPU>,
}

/// A guard that provides exclusive access to the data in an `AsyncRwlock`.
///
/// When an `AsyncRwlockWriteGuard` is dropped, it automatically releases the lock and
/// wakes up the next task.
#[must_use = "if unused, the Rwlock will immediately unlock"]
pub struct AsyncRwlockWriteGuard<'a, T: ?Sized, CPU: CpuOps> {
rwlock: &'a Rwlock<T, CPU>,
}

impl<T, CPU: CpuOps> Rwlock<T, CPU> {
/// Creates a new asynchronous rwlock in an unlocked state.
pub fn new(data: T) -> Self {
Self {
state: RwlockState {
num_readers: SpinLockIrq::new(0),
writer_lock: Mutex::new(()),
},
data: UnsafeCell::new(data),
}
}

/// Consumes the rwlock, returning the underlying data.
///
/// This is safe because consuming `self` guarantees no other code can
/// access the rwlock.
pub fn into_inner(self) -> T {
self.data.into_inner()
}
}

impl<T: ?Sized, CPU: CpuOps> Rwlock<T, CPU> {
/// Acquires rwlock read.
///
/// Returns a guard asynchronously. The guard is released when the
/// returned [`AsyncRwlockReadGuard`] is dropped.
pub async fn read(&self) -> AsyncRwlockReadGuard<'_, T, CPU> {
let mut num_readers = self.state.num_readers.lock_save_irq();
*num_readers += 1;
if *num_readers == 1 {
self.state.writer_lock.acquire().await;
}
AsyncRwlockReadGuard { rwlock: self }
}

/// Acquires rwlock write.
///
/// Returns a guard asynchronously. The guard is released when the
/// returned [`AsyncRwlockWriteGuard`] is dropped.
pub async fn write(&self) -> AsyncRwlockWriteGuard<'_, T, CPU> {
self.state.writer_lock.acquire().await;
AsyncRwlockWriteGuard { rwlock: self }
}
}

impl<T: ?Sized, CPU: CpuOps> Drop for AsyncRwlockReadGuard<'_, T, CPU> {
fn drop(&mut self) {
let mut num_readers = self.rwlock.state.num_readers.lock_save_irq();
*num_readers -= 1;
if *num_readers == 0 {
unsafe { self.rwlock.state.writer_lock.release() };
}
}
}

impl<T: ?Sized, CPU: CpuOps> Deref for AsyncRwlockReadGuard<'_, T, CPU> {
type Target = T;
fn deref(&self) -> &T {
// SAFETY: This is safe because the existence of this guard guarantees
// we have read access to the data without any writers.
unsafe { &*self.rwlock.data.get() }
}
}

impl<T: ?Sized, CPU: CpuOps> Drop for AsyncRwlockWriteGuard<'_, T, CPU> {
fn drop(&mut self) {
unsafe { self.rwlock.state.writer_lock.release() };
}
}

impl<T: ?Sized, CPU: CpuOps> Deref for AsyncRwlockWriteGuard<'_, T, CPU> {
type Target = T;
fn deref(&self) -> &T {
// SAFETY: This is safe because the existence of this guard guarantees
// we have exclusive access to the data.
unsafe { &*self.rwlock.data.get() }
}
}

impl<T: ?Sized, CPU: CpuOps> DerefMut for AsyncRwlockWriteGuard<'_, T, CPU> {
fn deref_mut(&mut self) -> &mut T {
// SAFETY: This is safe because the existence of this guard guarantees
// we have exclusive access to the data.
unsafe { &mut *self.rwlock.data.get() }
}
}

unsafe impl<T: ?Sized + Send, CPU: CpuOps> Send for Rwlock<T, CPU> {}
unsafe impl<T: ?Sized + Send, CPU: CpuOps> Sync for Rwlock<T, CPU> {}
8 changes: 8 additions & 0 deletions src/sync/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ pub mod per_cpu;
pub type SpinLock<T> = libkernel::sync::spinlock::SpinLockIrq<T, ArchImpl>;
pub type Mutex<T> = libkernel::sync::mutex::Mutex<T, ArchImpl>;
pub type AsyncMutexGuard<'a, T> = libkernel::sync::mutex::AsyncMutexGuard<'a, T, ArchImpl>;
#[expect(dead_code)]
pub type Rwlock<T> = libkernel::sync::rwlock::Rwlock<T, ArchImpl>;
#[expect(dead_code)]
pub type AsyncRwlockReadGuard<'a, T> =
libkernel::sync::rwlock::AsyncRwlockReadGuard<'a, T, ArchImpl>;
#[expect(dead_code)]
pub type AsyncRwlockWriteGuard<'a, T> =
libkernel::sync::rwlock::AsyncRwlockWriteGuard<'a, T, ArchImpl>;
pub type OnceLock<T> = libkernel::sync::once_lock::OnceLock<T, ArchImpl>;
pub type CondVar<T> = libkernel::sync::condvar::CondVar<T, ArchImpl>;
// pub type Reciever<T> = libkernel::sync::mpsc::Reciever<T, ArchImpl>;
Expand Down