Add methods to find the next element in an XArray starting from a given index. The methods return a tuple containing the index where the element was found and a reference to the element. The implementation uses the XArray state API via `xas_find` to avoid taking the rcu lock as an exclusive lock is already held by `Guard`. Signed-off-by: Andreas Hindborg --- rust/kernel/xarray.rs | 68 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/rust/kernel/xarray.rs b/rust/kernel/xarray.rs index e654bf56dc97c..656ec897a0c41 100644 --- a/rust/kernel/xarray.rs +++ b/rust/kernel/xarray.rs @@ -251,6 +251,67 @@ pub fn get_mut(&mut self, index: usize) -> Option> { Some(unsafe { T::borrow_mut(ptr.as_ptr()) }) } + fn load_next(&self, index: usize) -> Option<(usize, NonNull)> { + XArrayState::new(self, index).load_next() + } + + /// Finds the next element starting from the given index. + /// + /// # Examples + /// + /// ``` + /// # use kernel::{prelude::*, xarray::{AllocKind, XArray}}; + /// let mut xa = KBox::pin_init(XArray::>::new(AllocKind::Alloc), GFP_KERNEL)?; + /// let mut guard = xa.lock(); + /// + /// guard.store(10, KBox::new(10u32, GFP_KERNEL)?, GFP_KERNEL)?; + /// guard.store(20, KBox::new(20u32, GFP_KERNEL)?, GFP_KERNEL)?; + /// + /// if let Some((found_index, value)) = guard.find_next(11) { + /// assert_eq!(found_index, 20); + /// assert_eq!(*value, 20); + /// } + /// + /// if let Some((found_index, value)) = guard.find_next(5) { + /// assert_eq!(found_index, 10); + /// assert_eq!(*value, 10); + /// } + /// + /// # Ok::<(), kernel::error::Error>(()) + /// ``` + pub fn find_next(&self, index: usize) -> Option<(usize, T::Borrowed<'_>)> { + self.load_next(index) + // SAFETY: `ptr` came from `T::into_foreign`. + .map(|(index, ptr)| (index, unsafe { T::borrow(ptr.as_ptr()) })) + } + + /// Finds the next element starting from the given index, returning a mutable reference. + /// + /// # Examples + /// + /// ``` + /// # use kernel::{prelude::*, xarray::{AllocKind, XArray}}; + /// let mut xa = KBox::pin_init(XArray::>::new(AllocKind::Alloc), GFP_KERNEL)?; + /// let mut guard = xa.lock(); + /// + /// guard.store(10, KBox::new(10u32, GFP_KERNEL)?, GFP_KERNEL)?; + /// guard.store(20, KBox::new(20u32, GFP_KERNEL)?, GFP_KERNEL)?; + /// + /// if let Some((found_index, mut_value)) = guard.find_next_mut(5) { + /// assert_eq!(found_index, 10); + /// *mut_value = 0x99; + /// } + /// + /// assert_eq!(guard.get(10).copied(), Some(0x99)); + /// + /// # Ok::<(), kernel::error::Error>(()) + /// ``` + pub fn find_next_mut(&mut self, index: usize) -> Option<(usize, T::BorrowedMut<'_>)> { + self.load_next(index) + // SAFETY: `ptr` came from `T::into_foreign`. + .map(move |(index, ptr)| (index, unsafe { T::borrow_mut(ptr.as_ptr()) })) + } + /// Removes and returns the element at the given index. pub fn remove(&mut self, index: usize) -> Option { // SAFETY: @@ -354,6 +415,13 @@ fn load(&mut self) -> Option> { let ptr = unsafe { bindings::xas_load(&raw mut self.state) }; NonNull::new(ptr.cast()) } + + fn load_next(&mut self) -> Option<(usize, NonNull)> { + // SAFETY: `self.state` is always valid by the type invariant of + // `XArrayState` and the we hold the xarray lock. + let ptr = unsafe { bindings::xas_find(&raw mut self.state, usize::MAX) }; + NonNull::new(ptr).map(|ptr| (self.state.xa_index, ptr)) + } } // SAFETY: `XArray` has no shared mutable state so it is `Send` iff `T` is `Send`. -- 2.51.2