bpf_sock_destroy() must be invoked from a context where the socket lock is held, but we cannot simply call lock_sock() inside sock_hash_seq_show(), since it's inside an RCU read side critical section and lock_sock() may sleep. We also don't want to hold the bucket lock while running sock_hash_seq_show(), since the BPF program may itself call map_update_elem() on the same socket hash, acquiring the same bucket lock and creating a deadlock. TCP and UDP socket iterators use a batching algorithm to decouple reading the current bucket's contents and running the BPF iterator program for each element in the bucket. This enables sock_hash_seq_show() to acquire the socket lock and lets helpers like bpf_sock_destroy() run safely. One concern with adopting a similar algorithm here is that with later patches in the series, bucket sizes can grow arbitrarily large, or at least as large as max_entries for the map. Naively adopting the same approach risks needing to allocate batches at least this large to cover the largest bucket size in the map. This could in theory be mitigated by placing an upper bound on our batch size and processing a bucket in multiple chunks, but processing in chunks without a reliable way to track progress through the bucket may lead to skipped or repeated elements as described in [1]. This could be solved with an indexing scheme like that described in [2] that associates a monotonically increasing index to new elements added to the head of the bucket, but doing so requires an extra 8 bytes to be added to each element. Not to mention that processing in multiple chunks requires that we seek to our last position multiple times, making iteration over a large bucket less efficient. This patch attempts to improve upon this by using reference counting to make sure that the current element and its descendants are not freed even outside an RCU read-side critical section and even if they're unlinked from the bucket in the meantime. This requires no batching and eliminates the need to seek to our last position on every read(). Note: This also fixes a latent bug in the original logic. Before, sock_hash_seq_start() always called sock_hash_seq_find_next() with prev_elem set to NULL, forcing iteration to start at the first element of the current bucket. This logic works under the assumption that sock_hash_seq_start() is only ever called once for iteration over the socket hash or that no bucket has more than one element; however, when using bpf_seq_write sock_hash_seq_start() and sock_hash_seq_stop() may be called several times as a series of read() calls are made by userspace, and it may be necessary to resume iteration in the middle of a bucket. As is, if iteration tries to resume in a bucket with more than one element it gets stuck, since there is no way to make progress. [1]: https://lore.kernel.org/bpf/Z_xQhm4aLW9UBykJ@t14/ [2]: https://lore.kernel.org/bpf/20250313233615.2329869-1-jrife@google.com/ Signed-off-by: Jordan Rife --- net/core/sock_map.c | 103 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 87 insertions(+), 16 deletions(-) diff --git a/net/core/sock_map.c b/net/core/sock_map.c index 005112ba19fd..9d972069665b 100644 --- a/net/core/sock_map.c +++ b/net/core/sock_map.c @@ -1343,23 +1343,69 @@ const struct bpf_func_proto bpf_msg_redirect_hash_proto = { struct sock_hash_seq_info { struct bpf_map *map; struct bpf_shtab *htab; + struct bpf_shtab_elem *next_elem; u32 bucket_id; }; +static inline bool bpf_shtab_elem_unhashed(struct bpf_shtab_elem *elem) +{ + return READ_ONCE(elem->node.pprev) == LIST_POISON2; +} + +static struct bpf_shtab_elem *sock_hash_seq_hold_next(struct bpf_shtab_elem *elem) +{ + hlist_for_each_entry_from_rcu(elem, node) + /* It's possible that the first element or its descendants were + * unlinked from the bucket's list. Skip any unlinked elements + * until we get back to the main list. + */ + if (!bpf_shtab_elem_unhashed(elem) && + sock_hash_hold_elem(elem)) + return elem; + + return NULL; +} + static void *sock_hash_seq_find_next(struct sock_hash_seq_info *info, struct bpf_shtab_elem *prev_elem) { const struct bpf_shtab *htab = info->htab; + struct bpf_shtab_elem *elem = NULL; struct bpf_shtab_bucket *bucket; - struct bpf_shtab_elem *elem; struct hlist_node *node; + /* RCU is important here. It's possible that a parallel update operation + * unlinks an element while we're handling it. Without rcu_read_lock(), + * this sequence could occur: + * + * 1. sock_hash_seq_find_next() gets to elem but hasn't yet taken a + * reference to it. + * 2. elem is unlinked and sock_hash_put_elem() schedules + * sock_hash_free_elem(): + * call_rcu(&elem->rcu, sock_hash_free_elem); + * 3. sock_hash_free_elem() runs, freeing elem. + * 4. sock_hash_seq_find_next() continues and tries to read elem + * creating a use-after-free. + * + * rcu_read_lock() guarantees that elem won't be freed out from under + * us, and if a parallel update unlinks it then either: + * + * (i) We will take a reference to it before sock_hash_put_elem() + * decrements the reference count thus preventing it from calling + * call_rcu. + * (ii) We will fail to take a reference to it and simply proceed to the + * next element in the list until we find an element that isn't + * currently being removed from the list or reach the end of the + * list. + */ + rcu_read_lock(); /* try to find next elem in the same bucket */ if (prev_elem) { node = rcu_dereference(hlist_next_rcu(&prev_elem->node)); elem = hlist_entry_safe(node, struct bpf_shtab_elem, node); + elem = sock_hash_seq_hold_next(elem); if (elem) - return elem; + goto unlock; /* no more elements, continue in the next bucket */ info->bucket_id++; @@ -1369,28 +1415,47 @@ static void *sock_hash_seq_find_next(struct sock_hash_seq_info *info, bucket = &htab->buckets[info->bucket_id]; node = rcu_dereference(hlist_first_rcu(&bucket->head)); elem = hlist_entry_safe(node, struct bpf_shtab_elem, node); + elem = sock_hash_seq_hold_next(elem); if (elem) - return elem; + goto unlock; } - - return NULL; +unlock: + /* sock_hash_put_elem() will free all elements up until the + * point that either: + * + * (i) It hits elem + * (ii) It hits an unlinked element between prev_elem and elem + * to which another iterator holds a reference. + * + * In case (i), this iterator is responsible for freeing all the + * unlinked but as yet unfreed elements in this chain. In case (ii), it + * is the other iterator's responsibility to free remaining elements + * after that point. The last one out "shuts the door". + */ + if (prev_elem) + sock_hash_put_elem(prev_elem); + rcu_read_unlock(); + return elem; } static void *sock_hash_seq_start(struct seq_file *seq, loff_t *pos) - __acquires(rcu) { struct sock_hash_seq_info *info = seq->private; if (*pos == 0) ++*pos; - /* pairs with sock_hash_seq_stop */ - rcu_read_lock(); - return sock_hash_seq_find_next(info, NULL); + /* info->next_elem may have become unhashed between read()s. If so, skip + * it to avoid inconsistencies where, e.g., an element is deleted from + * the map then appears in the next call to read(). + */ + if (!info->next_elem || bpf_shtab_elem_unhashed(info->next_elem)) + return sock_hash_seq_find_next(info, info->next_elem); + + return info->next_elem; } static void *sock_hash_seq_next(struct seq_file *seq, void *v, loff_t *pos) - __must_hold(rcu) { struct sock_hash_seq_info *info = seq->private; @@ -1399,13 +1464,13 @@ static void *sock_hash_seq_next(struct seq_file *seq, void *v, loff_t *pos) } static int sock_hash_seq_show(struct seq_file *seq, void *v) - __must_hold(rcu) { struct sock_hash_seq_info *info = seq->private; struct bpf_iter__sockmap ctx = {}; struct bpf_shtab_elem *elem = v; struct bpf_iter_meta meta; struct bpf_prog *prog; + int ret; meta.seq = seq; prog = bpf_iter_get_info(&meta, !elem); @@ -1419,17 +1484,21 @@ static int sock_hash_seq_show(struct seq_file *seq, void *v) ctx.sk = elem->sk; } - return bpf_iter_run_prog(prog, &ctx); + if (elem) + lock_sock(elem->sk); + ret = bpf_iter_run_prog(prog, &ctx); + if (elem) + release_sock(elem->sk); + return ret; } static void sock_hash_seq_stop(struct seq_file *seq, void *v) - __releases(rcu) { + struct sock_hash_seq_info *info = seq->private; + if (!v) (void)sock_hash_seq_show(seq, NULL); - - /* pairs with sock_hash_seq_start */ - rcu_read_unlock(); + info->next_elem = v; } static const struct seq_operations sock_hash_seq_ops = { @@ -1454,6 +1523,8 @@ static void sock_hash_fini_seq_private(void *priv_data) { struct sock_hash_seq_info *info = priv_data; + if (info->next_elem) + sock_hash_put_elem(info->next_elem); bpf_map_put_with_uref(info->map); } -- 2.43.0