diff --git a/rust/kernel/rbtree.rs b/rust/kernel/rbtree.rs index 754af0db86b59a..db185b6b3c65b1 100644 --- a/rust/kernel/rbtree.rs +++ b/rust/kernel/rbtree.rs @@ -297,12 +297,19 @@ where /// key/value pair). Returns [`None`] if a node with the same key didn't already exist. /// /// This function always succeeds. - pub fn insert(&mut self, RBTreeNode { node }: RBTreeNode) -> Option> { - let node = Box::into_raw(node); - // SAFETY: `node` is valid at least until we call `Box::from_raw`, which only happens when - // the node is removed or replaced. - let node_links = unsafe { addr_of_mut!((*node).links) }; + pub fn insert(&mut self, node: RBTreeNode) -> Option> { + match self.raw_entry(&node.node.key) { + RawEntry::Occupied(entry) => Some(entry.replace(node)), + RawEntry::Vacant(entry) => { + entry.insert(node); + None + } + } + } + fn raw_entry(&mut self, key: &K) -> RawEntry<'_, K, V> { + let raw_self: *mut RBTree = self; + // The returned `RawEntry` is used to call either `rb_link_node` or `rb_replace_node`. // The parameters of `bindings::rb_link_node` are as follows: // - `node`: A pointer to an uninitialized node being inserted. // - `parent`: A pointer to an existing node in the tree. One of its child pointers must be @@ -321,62 +328,56 @@ where // in the subtree of `parent` that `child_field_of_parent` points at. Once // we find an empty subtree, we can insert the new node using `rb_link_node`. let mut parent = core::ptr::null_mut(); - let mut child_field_of_parent: &mut *mut bindings::rb_node = &mut self.root.rb_node; - while !child_field_of_parent.is_null() { - parent = *child_field_of_parent; + let mut child_field_of_parent: &mut *mut bindings::rb_node = + // SAFETY: `raw_self` is a valid pointer to the `RBTree` (created from `self` above). + unsafe { &mut (*raw_self).root.rb_node }; + while !(*child_field_of_parent).is_null() { + let curr = *child_field_of_parent; + // SAFETY: All links fields we create are in a `Node`. + let node = unsafe { container_of!(curr, Node, links) }; - // We need to determine whether `node` should be the left or right child of `parent`, - // so we will compare with the `key` field of `parent` a.k.a. `this` below. - // - // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` - // point to the links field of `Node` objects. - let this = unsafe { container_of!(parent, Node, links) }; - - // SAFETY: `this` is a non-null node so it is valid by the type invariants. `node` is - // valid until the node is removed. - match unsafe { (*node).key.cmp(&(*this).key) } { - // We would like `node` to be the left child of `parent`. Move to this child to check - // whether we can use it, or continue searching, at the next iteration. - // - // SAFETY: `parent` is a non-null node so it is valid by the type invariants. - Ordering::Less => child_field_of_parent = unsafe { &mut (*parent).rb_left }, - // We would like `node` to be the right child of `parent`. Move to this child to check - // whether we can use it, or continue searching, at the next iteration. - // - // SAFETY: `parent` is a non-null node so it is valid by the type invariants. - Ordering::Greater => child_field_of_parent = unsafe { &mut (*parent).rb_right }, + // SAFETY: `node` is a non-null node so it is valid by the type invariants. + match key.cmp(unsafe { &(*node).key }) { + // SAFETY: `curr` is a non-null node so it is valid by the type invariants. + Ordering::Less => child_field_of_parent = unsafe { &mut (*curr).rb_left }, + // SAFETY: `curr` is a non-null node so it is valid by the type invariants. + Ordering::Greater => child_field_of_parent = unsafe { &mut (*curr).rb_right }, Ordering::Equal => { - // There is an existing node in the tree with this key, and that node is - // `parent`. Thus, we are replacing parent with a new node. - // - // INVARIANT: We are replacing an existing node with a new one, which is valid. - // It remains valid because we "forgot" it with `Box::into_raw`. - // SAFETY: All pointers are non-null and valid. - unsafe { bindings::rb_replace_node(parent, node_links, &mut self.root) }; - - // INVARIANT: The node is being returned and the caller may free it, however, - // it was removed from the tree. So the invariants still hold. - return Some(RBTreeNode { - // SAFETY: `this` was a node in the tree, so it is valid. - node: unsafe { Box::from_raw(this.cast_mut()) }, - }); + return RawEntry::Occupied(OccupiedEntry { + rbtree: self, + node_links: curr, + }) } } + parent = curr; } - // INVARIANT: We are linking in a new node, which is valid. It remains valid because we - // "forgot" it with `Box::into_raw`. - // SAFETY: All pointers are non-null and valid (`*child_field_of_parent` is null, but `child_field_of_parent` is a - // mutable reference). - unsafe { bindings::rb_link_node(node_links, parent, child_field_of_parent) }; + RawEntry::Vacant(RawVacantEntry { + rbtree: raw_self, + parent, + child_field_of_parent, + _phantom: PhantomData, + }) + } - // SAFETY: All pointers are valid. `node` has just been inserted into the tree. - unsafe { bindings::rb_insert_color(node_links, &mut self.root) }; - None + /// Gets the given key's corresponding entry in the map for in-place manipulation. + pub fn entry(&mut self, key: K) -> Entry<'_, K, V> { + match self.raw_entry(&key) { + RawEntry::Occupied(entry) => Entry::Occupied(entry), + RawEntry::Vacant(entry) => Entry::Vacant(VacantEntry { raw: entry, key }), + } } - /// Returns a node with the given key, if one exists. - fn find(&self, key: &K) -> Option>> { + /// Used for accessing the given node, if it exists. + pub fn find_mut(&mut self, key: &K) -> Option> { + match self.raw_entry(key) { + RawEntry::Occupied(entry) => Some(entry), + RawEntry::Vacant(_entry) => None, + } + } + + /// Returns a reference to the value corresponding to the key. + pub fn get(&self, key: &K) -> Option<&V> { let mut node = self.root.rb_node; while !node.is_null() { // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` @@ -388,47 +389,30 @@ where Ordering::Less => unsafe { (*node).rb_left }, // SAFETY: `node` is a non-null node so it is valid by the type invariants. Ordering::Greater => unsafe { (*node).rb_right }, - Ordering::Equal => return NonNull::new(this.cast_mut()), + // SAFETY: `node` is a non-null node so it is valid by the type invariants. + Ordering::Equal => return Some(unsafe { &(*this).value }), } } None } - /// Returns a reference to the value corresponding to the key. - pub fn get(&self, key: &K) -> Option<&V> { - // SAFETY: The `find` return value is a node in the tree, so it is valid. - self.find(key).map(|node| unsafe { &node.as_ref().value }) - } - /// Returns a mutable reference to the value corresponding to the key. pub fn get_mut(&mut self, key: &K) -> Option<&mut V> { - // SAFETY: The `find` return value is a node in the tree, so it is valid. - self.find(key) - .map(|mut node| unsafe { &mut node.as_mut().value }) + self.find_mut(key).map(|node| node.into_mut()) } /// Removes the node with the given key from the tree. /// /// It returns the node that was removed if one exists, or [`None`] otherwise. - fn remove_node(&mut self, key: &K) -> Option> { - let mut node = self.find(key)?; - - // SAFETY: The `find` return value is a node in the tree, so it is valid. - unsafe { bindings::rb_erase(&mut node.as_mut().links, &mut self.root) }; - - // INVARIANT: The node is being returned and the caller may free it, however, it was - // removed from the tree. So the invariants still hold. - Some(RBTreeNode { - // SAFETY: The `find` return value was a node in the tree, so it is valid. - node: unsafe { Box::from_raw(node.as_ptr()) }, - }) + pub fn remove_node(&mut self, key: &K) -> Option> { + self.find_mut(key).map(OccupiedEntry::remove_node) } /// Removes the node with the given key from the tree. /// /// It returns the value that was removed if one exists, or [`None`] otherwise. pub fn remove(&mut self, key: &K) -> Option { - self.remove_node(key).map(|node| node.node.value) + self.find_mut(key).map(OccupiedEntry::remove) } /// Returns a cursor over the tree nodes based on the given key. @@ -1136,6 +1120,182 @@ unsafe impl Send for RBTreeNode {} // [`RBTreeNode`] without synchronization. unsafe impl Sync for RBTreeNode {} +impl RBTreeNode { + /// Drop the key and value, but keep the allocation. + /// + /// It then becomes a reservation that can be re-initialised into a different node (i.e., with + /// a different key and/or value). + /// + /// The existing key and value are dropped in-place as part of this operation, that is, memory + /// may be freed (but only for the key/value; memory for the node itself is kept for reuse). + pub fn into_reservation(self) -> RBTreeNodeReservation { + RBTreeNodeReservation { + node: Box::drop_contents(self.node), + } + } +} + +/// A view into a single entry in a map, which may either be vacant or occupied. +/// +/// This enum is constructed from the [`RBTree::entry`]. +/// +/// [`entry`]: fn@RBTree::entry +pub enum Entry<'a, K, V> { + /// This [`RBTree`] does not have a node with this key. + Vacant(VacantEntry<'a, K, V>), + /// This [`RBTree`] already has a node with this key. + Occupied(OccupiedEntry<'a, K, V>), +} + +/// Like [`Entry`], except that it doesn't have ownership of the key. +enum RawEntry<'a, K, V> { + Vacant(RawVacantEntry<'a, K, V>), + Occupied(OccupiedEntry<'a, K, V>), +} + +/// A view into a vacant entry in a [`RBTree`]. It is part of the [`Entry`] enum. +pub struct VacantEntry<'a, K, V> { + key: K, + raw: RawVacantEntry<'a, K, V>, +} + +/// Like [`VacantEntry`], but doesn't hold on to the key. +/// +/// # Invariants +/// - `parent` may be null if the new node becomes the root. +/// - `child_field_of_parent` is a valid pointer to the left-child or right-child of `parent`. If `parent` is +/// null, it is a pointer to the root of the [`RBTree`]. +struct RawVacantEntry<'a, K, V> { + rbtree: *mut RBTree, + /// The node that will become the parent of the new node if we insert one. + parent: *mut bindings::rb_node, + /// This points to the left-child or right-child field of `parent`, or `root` if `parent` is + /// null. + child_field_of_parent: *mut *mut bindings::rb_node, + _phantom: PhantomData<&'a mut RBTree>, +} + +impl<'a, K, V> RawVacantEntry<'a, K, V> { + /// Inserts the given node into the [`RBTree`] at this entry. + /// + /// The `node` must have a key such that inserting it here does not break the ordering of this + /// [`RBTree`]. + fn insert(self, node: RBTreeNode) -> &'a mut V { + let node = Box::into_raw(node.node); + + // SAFETY: `node` is valid at least until we call `Box::from_raw`, which only happens when + // the node is removed or replaced. + let node_links = unsafe { addr_of_mut!((*node).links) }; + + // INVARIANT: We are linking in a new node, which is valid. It remains valid because we + // "forgot" it with `Box::into_raw`. + // SAFETY: The type invariants of `RawVacantEntry` are exactly the safety requirements of `rb_link_node`. + unsafe { bindings::rb_link_node(node_links, self.parent, self.child_field_of_parent) }; + + // SAFETY: All pointers are valid. `node` has just been inserted into the tree. + unsafe { bindings::rb_insert_color(node_links, addr_of_mut!((*self.rbtree).root)) }; + + // SAFETY: The node is valid until we remove it from the tree. + unsafe { &mut (*node).value } + } +} + +impl<'a, K, V> VacantEntry<'a, K, V> { + /// Inserts the given node into the [`RBTree`] at this entry. + pub fn insert(self, value: V, reservation: RBTreeNodeReservation) -> &'a mut V { + self.raw.insert(reservation.into_node(self.key, value)) + } +} + +/// A view into an occupied entry in a [`RBTree`]. It is part of the [`Entry`] enum. +/// +/// # Invariants +/// - `node_links` is a valid, non-null pointer to a tree node in `self.rbtree` +pub struct OccupiedEntry<'a, K, V> { + rbtree: &'a mut RBTree, + /// The node that this entry corresponds to. + node_links: *mut bindings::rb_node, +} + +impl<'a, K, V> OccupiedEntry<'a, K, V> { + /// Gets a reference to the value in the entry. + pub fn get(&self) -> &V { + // SAFETY: + // - `self.node_links` is a valid pointer to a node in the tree. + // - We have shared access to the underlying tree, and can thus give out a shared reference. + unsafe { &(*container_of!(self.node_links, Node, links)).value } + } + + /// Gets a mutable reference to the value in the entry. + pub fn get_mut(&mut self) -> &mut V { + // SAFETY: + // - `self.node_links` is a valid pointer to a node in the tree. + // - We have exclusive access to the underlying tree, and can thus give out a mutable reference. + unsafe { + &mut (*(container_of!(self.node_links, Node, links) as *mut Node)).value + } + } + + /// Converts the entry into a mutable reference to its value. + /// + /// If you need multiple references to the `OccupiedEntry`, see [`self#get_mut`]. + pub fn into_mut(self) -> &'a mut V { + // SAFETY: + // - `self.node_links` is a valid pointer to a node in the tree. + // - This consumes the `&'a mut RBTree`, therefore it can give out a mutable reference that lives for `'a`. + unsafe { + &mut (*(container_of!(self.node_links, Node, links) as *mut Node)).value + } + } + + /// Remove this entry from the [`RBTree`]. + pub fn remove_node(self) -> RBTreeNode { + // SAFETY: The node is a node in the tree, so it is valid. + unsafe { bindings::rb_erase(self.node_links, &mut self.rbtree.root) }; + + // INVARIANT: The node is being returned and the caller may free it, however, it was + // removed from the tree. So the invariants still hold. + RBTreeNode { + // SAFETY: The node was a node in the tree, but we removed it, so we can convert it + // back into a box. + node: unsafe { + Box::from_raw(container_of!(self.node_links, Node, links) as *mut Node) + }, + } + } + + /// Takes the value of the entry out of the map, and returns it. + pub fn remove(self) -> V { + self.remove_node().node.value + } + + /// Swap the current node for the provided node. + /// + /// The key of both nodes must be equal. + fn replace(self, node: RBTreeNode) -> RBTreeNode { + let node = Box::into_raw(node.node); + + // SAFETY: `node` is valid at least until we call `Box::from_raw`, which only happens when + // the node is removed or replaced. + let new_node_links = unsafe { addr_of_mut!((*node).links) }; + + // SAFETY: This updates the pointers so that `new_node_links` is in the tree where + // `self.node_links` used to be. + unsafe { + bindings::rb_replace_node(self.node_links, new_node_links, &mut self.rbtree.root) + }; + + // SAFETY: + // - `self.node_ptr` produces a valid pointer to a node in the tree. + // - Now that we removed this entry from the tree, we can convert the node to a box. + let old_node = unsafe { + Box::from_raw(container_of!(self.node_links, Node, links) as *mut Node) + }; + + RBTreeNode { node: old_node } + } +} + struct Node { links: bindings::rb_node, key: K,