Use reference counting to decide when to free socket hash elements instead of freeing them immediately after they are unlinked from a bucket's list. In the next patch this is essential, allowing socket hash iterators to hold a reference to a `struct bpf_shtab_elem` outside of an RCU read-side critical section. sock_hash_put_elem() follows the list, scheduling elements to be freed until it hits an element where the reference count is two or greater. This does nothing yet; in this patch the loop will never iterate more than once, since we always take a reference to the next element in sock_hash_unlink_elem() before calling sock_hash_put_elem(), and in general, the reference count to any element is always one except during these transitions. However, in the next patch it's possible for an iterator to hold a reference to an element that has been unlinked from a bucket's list. In this context, sock_hash_put_elem() may free several unlinked elements up until the point where it finds an element that is still in the bucket's list. Signed-off-by: Jordan Rife --- net/core/sock_map.c | 67 +++++++++++++++++++++++++++++++++------------ 1 file changed, 50 insertions(+), 17 deletions(-) diff --git a/net/core/sock_map.c b/net/core/sock_map.c index 5947b38e4f8b..005112ba19fd 100644 --- a/net/core/sock_map.c +++ b/net/core/sock_map.c @@ -847,6 +847,7 @@ struct bpf_shtab_elem { u32 hash; struct sock *sk; struct hlist_node node; + refcount_t ref; u8 key[]; }; @@ -906,11 +907,46 @@ static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key) return elem ? elem->sk : NULL; } -static void sock_hash_free_elem(struct bpf_shtab *htab, - struct bpf_shtab_elem *elem) +static void sock_hash_free_elem(struct rcu_head *rcu_head) { + struct bpf_shtab_elem *elem = container_of(rcu_head, + struct bpf_shtab_elem, rcu); + + /* Matches sock_hold() in sock_hash_alloc_elem(). 
*/ + sock_put(elem->sk); + kfree(elem); +} + +static void sock_hash_put_elem(struct bpf_shtab_elem *elem) +{ + while (elem && refcount_dec_and_test(&elem->ref)) { + call_rcu(&elem->rcu, sock_hash_free_elem); + elem = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&elem->node)), + struct bpf_shtab_elem, node); + } +} + +static bool sock_hash_hold_elem(struct bpf_shtab_elem *elem) +{ + return refcount_inc_not_zero(&elem->ref); +} + +static void sock_hash_unlink_elem(struct bpf_shtab *htab, + struct bpf_shtab_elem *elem) +{ + struct bpf_shtab_elem *elem_next; + + elem_next = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&elem->node)), + struct bpf_shtab_elem, node); + hlist_del_rcu(&elem->node); + sock_map_unref(elem->sk, elem); + /* Take a reference to the next element first to make sure it's not + * freed by the call to sock_hash_put_elem(). + */ + if (elem_next) + sock_hash_hold_elem(elem_next); + sock_hash_put_elem(elem); atomic_dec(&htab->count); - kfree_rcu(elem, rcu); } static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk, @@ -930,11 +966,8 @@ static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk, spin_lock_bh(&bucket->lock); elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash, elem->key, map->key_size); - if (elem_probe && elem_probe == elem) { - hlist_del_rcu(&elem->node); - sock_map_unref(elem->sk, elem); - sock_hash_free_elem(htab, elem); - } + if (elem_probe && elem_probe == elem) + sock_hash_unlink_elem(htab, elem); spin_unlock_bh(&bucket->lock); } @@ -952,9 +985,7 @@ static long sock_hash_delete_elem(struct bpf_map *map, void *key) spin_lock_bh(&bucket->lock); elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size); if (elem) { - hlist_del_rcu(&elem->node); - sock_map_unref(elem->sk, elem); - sock_hash_free_elem(htab, elem); + sock_hash_unlink_elem(htab, elem); ret = 0; } spin_unlock_bh(&bucket->lock); @@ -985,6 +1016,11 @@ static struct bpf_shtab_elem *sock_hash_alloc_elem(struct 
bpf_shtab *htab, memcpy(new->key, key, key_size); new->sk = sk; new->hash = hash; + refcount_set(&new->ref, 1); + /* Matches sock_put() in sock_hash_free_elem(). Ensure that sk is not + * freed until elem is. + */ + sock_hold(sk); return new; } @@ -1038,11 +1074,8 @@ static int sock_hash_update_common(struct bpf_map *map, void *key, * concurrent search will find it before old elem. */ hlist_add_head_rcu(&elem_new->node, &bucket->head); - if (elem) { - hlist_del_rcu(&elem->node); - sock_map_unref(elem->sk, elem); - sock_hash_free_elem(htab, elem); - } + if (elem) + sock_hash_unlink_elem(htab, elem); spin_unlock_bh(&bucket->lock); return 0; out_unlock: @@ -1182,7 +1215,7 @@ static void sock_hash_free(struct bpf_map *map) rcu_read_unlock(); release_sock(elem->sk); sock_put(elem->sk); - sock_hash_free_elem(htab, elem); + call_rcu(&elem->rcu, sock_hash_free_elem); } cond_resched(); } -- 2.43.0 bpf_sock_destroy() must be invoked from a context where the socket lock is held, but we cannot simply call lock_sock() inside sock_hash_seq_show(), since it's inside an RCU read side critical section and lock_sock() may sleep. We also don't want to hold the bucket lock while running sock_hash_seq_show(), since the BPF program may itself call map_update_elem() on the same socket hash, acquiring the same bucket lock and creating a deadlock. TCP and UDP socket iterators use a batching algorithm to decouple reading the current bucket's contents and running the BPF iterator program for each element in the bucket. This enables sock_hash_seq_show() to acquire the socket lock and lets helpers like bpf_sock_destroy() run safely. One concern with adopting a similar algorithm here is that with later patches in the series, bucket sizes can grow arbitrarily large, or at least as large as max_entries for the map. Naively adopting the same approach risks needing to allocate batches at least this large to cover the largest bucket size in the map. 
This could in theory be mitigated by placing an upper bound on our batch size and processing a bucket in multiple chunks, but processing in chunks without a reliable way to track progress through the bucket may lead to skipped or repeated elements, as described in [1]. This could be solved with an indexing scheme like the one described in [2], which associates a monotonically increasing index with each new element added to the head of the bucket, but doing so requires adding an extra 8 bytes to each element. Furthermore, processing a bucket in multiple chunks requires seeking to our last position multiple times, making iteration over a large bucket less efficient. This patch instead uses reference counting to make sure that the current element and its successors in the bucket's list are not freed, even outside an RCU read-side critical section and even if they are unlinked from the bucket in the meantime. This requires no batching and eliminates the need to seek to our last position on every read(). Note: This also fixes a latent bug in the original logic. Before, sock_hash_seq_start() always called sock_hash_seq_find_next() with prev_elem set to NULL, forcing iteration to restart at the first element of the current bucket. This logic works only if sock_hash_seq_start() is called once per iteration over the socket hash or if no bucket has more than one element; however, when using bpf_seq_write(), sock_hash_seq_start() and sock_hash_seq_stop() may be called several times as userspace makes a series of read() calls, and it may be necessary to resume iteration in the middle of a bucket. As is, if iteration tries to resume in a bucket with more than one element, it gets stuck because there is no way to make progress.
[1]: https://lore.kernel.org/bpf/Z_xQhm4aLW9UBykJ@t14/ [2]: https://lore.kernel.org/bpf/20250313233615.2329869-1-jrife@google.com/ Signed-off-by: Jordan Rife --- net/core/sock_map.c | 103 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 87 insertions(+), 16 deletions(-) diff --git a/net/core/sock_map.c b/net/core/sock_map.c index 005112ba19fd..9d972069665b 100644 --- a/net/core/sock_map.c +++ b/net/core/sock_map.c @@ -1343,23 +1343,69 @@ const struct bpf_func_proto bpf_msg_redirect_hash_proto = { struct sock_hash_seq_info { struct bpf_map *map; struct bpf_shtab *htab; + struct bpf_shtab_elem *next_elem; u32 bucket_id; }; +static inline bool bpf_shtab_elem_unhashed(struct bpf_shtab_elem *elem) +{ + return READ_ONCE(elem->node.pprev) == LIST_POISON2; +} + +static struct bpf_shtab_elem *sock_hash_seq_hold_next(struct bpf_shtab_elem *elem) +{ + hlist_for_each_entry_from_rcu(elem, node) + /* It's possible that the first element or its descendants were + * unlinked from the bucket's list. Skip any unlinked elements + * until we get back to the main list. + */ + if (!bpf_shtab_elem_unhashed(elem) && + sock_hash_hold_elem(elem)) + return elem; + + return NULL; +} + static void *sock_hash_seq_find_next(struct sock_hash_seq_info *info, struct bpf_shtab_elem *prev_elem) { const struct bpf_shtab *htab = info->htab; + struct bpf_shtab_elem *elem = NULL; struct bpf_shtab_bucket *bucket; - struct bpf_shtab_elem *elem; struct hlist_node *node; + /* RCU is important here. It's possible that a parallel update operation + * unlinks an element while we're handling it. Without rcu_read_lock(), + * this sequence could occur: + * + * 1. sock_hash_seq_find_next() gets to elem but hasn't yet taken a + * reference to it. + * 2. elem is unlinked and sock_hash_put_elem() schedules + * sock_hash_free_elem(): + * call_rcu(&elem->rcu, sock_hash_free_elem); + * 3. sock_hash_free_elem() runs, freeing elem. + * 4. 
sock_hash_seq_find_next() continues and tries to read elem + * creating a use-after-free. + * + * rcu_read_lock() guarantees that elem won't be freed out from under + * us, and if a parallel update unlinks it then either: + * + * (i) We will take a reference to it before sock_hash_put_elem() + * decrements the reference count thus preventing it from calling + * call_rcu. + * (ii) We will fail to take a reference to it and simply proceed to the + * next element in the list until we find an element that isn't + * currently being removed from the list or reach the end of the + * list. + */ + rcu_read_lock(); /* try to find next elem in the same bucket */ if (prev_elem) { node = rcu_dereference(hlist_next_rcu(&prev_elem->node)); elem = hlist_entry_safe(node, struct bpf_shtab_elem, node); + elem = sock_hash_seq_hold_next(elem); if (elem) - return elem; + goto unlock; /* no more elements, continue in the next bucket */ info->bucket_id++; @@ -1369,28 +1415,47 @@ static void *sock_hash_seq_find_next(struct sock_hash_seq_info *info, bucket = &htab->buckets[info->bucket_id]; node = rcu_dereference(hlist_first_rcu(&bucket->head)); elem = hlist_entry_safe(node, struct bpf_shtab_elem, node); + elem = sock_hash_seq_hold_next(elem); if (elem) - return elem; + goto unlock; } - - return NULL; +unlock: + /* sock_hash_put_elem() will free all elements up until the + * point that either: + * + * (i) It hits elem + * (ii) It hits an unlinked element between prev_elem and elem + * to which another iterator holds a reference. + * + * In case (i), this iterator is responsible for freeing all the + * unlinked but as yet unfreed elements in this chain. In case (ii), it + * is the other iterator's responsibility to free remaining elements + * after that point. The last one out "shuts the door". 
+ */ + if (prev_elem) + sock_hash_put_elem(prev_elem); + rcu_read_unlock(); + return elem; } static void *sock_hash_seq_start(struct seq_file *seq, loff_t *pos) - __acquires(rcu) { struct sock_hash_seq_info *info = seq->private; if (*pos == 0) ++*pos; - /* pairs with sock_hash_seq_stop */ - rcu_read_lock(); - return sock_hash_seq_find_next(info, NULL); + /* info->next_elem may have become unhashed between read()s. If so, skip + * it to avoid inconsistencies where, e.g., an element is deleted from + * the map then appears in the next call to read(). + */ + if (!info->next_elem || bpf_shtab_elem_unhashed(info->next_elem)) + return sock_hash_seq_find_next(info, info->next_elem); + + return info->next_elem; } static void *sock_hash_seq_next(struct seq_file *seq, void *v, loff_t *pos) - __must_hold(rcu) { struct sock_hash_seq_info *info = seq->private; @@ -1399,13 +1464,13 @@ static void *sock_hash_seq_next(struct seq_file *seq, void *v, loff_t *pos) } static int sock_hash_seq_show(struct seq_file *seq, void *v) - __must_hold(rcu) { struct sock_hash_seq_info *info = seq->private; struct bpf_iter__sockmap ctx = {}; struct bpf_shtab_elem *elem = v; struct bpf_iter_meta meta; struct bpf_prog *prog; + int ret; meta.seq = seq; prog = bpf_iter_get_info(&meta, !elem); @@ -1419,17 +1484,21 @@ static int sock_hash_seq_show(struct seq_file *seq, void *v) ctx.sk = elem->sk; } - return bpf_iter_run_prog(prog, &ctx); + if (elem) + lock_sock(elem->sk); + ret = bpf_iter_run_prog(prog, &ctx); + if (elem) + release_sock(elem->sk); + return ret; } static void sock_hash_seq_stop(struct seq_file *seq, void *v) - __releases(rcu) { + struct sock_hash_seq_info *info = seq->private; + if (!v) (void)sock_hash_seq_show(seq, NULL); - - /* pairs with sock_hash_seq_start */ - rcu_read_unlock(); + info->next_elem = v; } static const struct seq_operations sock_hash_seq_ops = { @@ -1454,6 +1523,8 @@ static void sock_hash_fini_seq_private(void *priv_data) { struct sock_hash_seq_info *info = priv_data; 
+ if (info->next_elem) + sock_hash_put_elem(info->next_elem); bpf_map_put_with_uref(info->map); } -- 2.43.0 Similar to socket hash iterators, decouple reading from processing to enable bpf_iter_run_prog to run while holding the socket lock and take a reference to the current socket to ensure that it isn't freed outside of the RCU read-side critical section. Signed-off-by: Jordan Rife --- net/core/sock_map.c | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/net/core/sock_map.c b/net/core/sock_map.c index 9d972069665b..f33bfce96b9e 100644 --- a/net/core/sock_map.c +++ b/net/core/sock_map.c @@ -723,30 +723,39 @@ static void *sock_map_seq_lookup_elem(struct sock_map_seq_info *info) if (unlikely(info->index >= info->map->max_entries)) return NULL; + rcu_read_lock(); info->sk = __sock_map_lookup_elem(info->map, info->index); + if (info->sk) + sock_hold(info->sk); + rcu_read_unlock(); /* can't return sk directly, since that might be NULL */ return info; } +static void sock_map_seq_put_elem(struct sock_map_seq_info *info) +{ + if (info->sk) { + sock_put(info->sk); + info->sk = NULL; + } +} + static void *sock_map_seq_start(struct seq_file *seq, loff_t *pos) - __acquires(rcu) { struct sock_map_seq_info *info = seq->private; if (*pos == 0) ++*pos; - /* pairs with sock_map_seq_stop */ - rcu_read_lock(); return sock_map_seq_lookup_elem(info); } static void *sock_map_seq_next(struct seq_file *seq, void *v, loff_t *pos) - __must_hold(rcu) { struct sock_map_seq_info *info = seq->private; + sock_map_seq_put_elem(info); ++*pos; ++info->index; @@ -754,12 +763,12 @@ static void *sock_map_seq_next(struct seq_file *seq, void *v, loff_t *pos) } static int sock_map_seq_show(struct seq_file *seq, void *v) - __must_hold(rcu) { struct sock_map_seq_info *info = seq->private; struct bpf_iter__sockmap ctx = {}; struct bpf_iter_meta meta; struct bpf_prog *prog; + int ret; meta.seq = seq; prog = bpf_iter_get_info(&meta, !v); @@ -773,17 +782,23 
@@ static int sock_map_seq_show(struct seq_file *seq, void *v) ctx.sk = info->sk; } - return bpf_iter_run_prog(prog, &ctx); + if (ctx.sk) + lock_sock(ctx.sk); + ret = bpf_iter_run_prog(prog, &ctx); + if (ctx.sk) + release_sock(ctx.sk); + + return ret; } static void sock_map_seq_stop(struct seq_file *seq, void *v) - __releases(rcu) { + struct sock_map_seq_info *info = seq->private; + if (!v) (void)sock_map_seq_show(seq, NULL); - /* pairs with sock_map_seq_start */ - rcu_read_unlock(); + sock_map_seq_put_elem(info); } static const struct seq_operations sock_map_seq_ops = { -- 2.43.0 Allow sk to be passed to bpf_sock_destroy() by marking it as PTR_TRUSTED. Signed-off-by: Jordan Rife --- net/core/sock_map.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/net/core/sock_map.c b/net/core/sock_map.c index f33bfce96b9e..20b0627b1eb1 100644 --- a/net/core/sock_map.c +++ b/net/core/sock_map.c @@ -2065,7 +2065,7 @@ static struct bpf_iter_reg sock_map_iter_reg = { { offsetof(struct bpf_iter__sockmap, key), PTR_TO_BUF | PTR_MAYBE_NULL | MEM_RDONLY }, { offsetof(struct bpf_iter__sockmap, sk), - PTR_TO_BTF_ID_OR_NULL }, + PTR_TO_BTF_ID_OR_NULL | PTR_TRUSTED }, }, }; -- 2.43.0 Expand the suite of tests for bpf_sock_destroy() to include invocation from socket map and socket hash iterators. 
Signed-off-by: Jordan Rife --- .../selftests/bpf/prog_tests/sock_destroy.c | 119 +++++++++++++++--- .../selftests/bpf/progs/sock_destroy_prog.c | 63 ++++++++++ 2 files changed, 164 insertions(+), 18 deletions(-) diff --git a/tools/testing/selftests/bpf/prog_tests/sock_destroy.c b/tools/testing/selftests/bpf/prog_tests/sock_destroy.c index 9c11938fe597..829e6fb9c1d8 100644 --- a/tools/testing/selftests/bpf/prog_tests/sock_destroy.c +++ b/tools/testing/selftests/bpf/prog_tests/sock_destroy.c @@ -8,13 +8,20 @@ #define TEST_NS "sock_destroy_netns" -static void start_iter_sockets(struct bpf_program *prog) +static void start_iter_sockets(struct bpf_program *prog, struct bpf_map *map) { + DECLARE_LIBBPF_OPTS(bpf_iter_attach_opts, opts); + union bpf_iter_link_info linfo = {}; struct bpf_link *link; char buf[50] = {}; int iter_fd, len; - link = bpf_program__attach_iter(prog, NULL); + if (map) + linfo.map.map_fd = bpf_map__fd(map); + opts.link_info = &linfo; + opts.link_info_len = sizeof(linfo); + + link = bpf_program__attach_iter(prog, &opts); if (!ASSERT_OK_PTR(link, "attach_iter")) return; @@ -32,7 +39,22 @@ static void start_iter_sockets(struct bpf_program *prog) bpf_link__destroy(link); } -static void test_tcp_client(struct sock_destroy_prog *skel) +static int insert_socket(struct bpf_map *socks, int fd, __u32 key) +{ + int map_fd = bpf_map__fd(socks); + __s64 sfd = fd; + int ret; + + ret = bpf_map_update_elem(map_fd, &key, &sfd, BPF_NOEXIST); + if (!ASSERT_OK(ret, "map_update")) + return -1; + + return 0; +} + +static void test_tcp_client(struct sock_destroy_prog *skel, + struct bpf_program *prog, + struct bpf_map *socks) { int serv = -1, clien = -1, accept_serv = -1, n; @@ -52,8 +74,17 @@ static void test_tcp_client(struct sock_destroy_prog *skel) if (!ASSERT_EQ(n, 1, "client send")) goto cleanup; + if (socks) { + if (!ASSERT_OK(insert_socket(socks, clien, 0), + "insert_socket")) + goto cleanup; + if (!ASSERT_OK(insert_socket(socks, serv, 1), + "insert_socket")) + 
goto cleanup; + } + /* Run iterator program that destroys connected client sockets. */ - start_iter_sockets(skel->progs.iter_tcp6_client); + start_iter_sockets(prog, socks); n = send(clien, "t", 1, 0); if (!ASSERT_LT(n, 0, "client_send on destroyed socket")) @@ -69,7 +100,9 @@ static void test_tcp_client(struct sock_destroy_prog *skel) close(serv); } -static void test_tcp_server(struct sock_destroy_prog *skel) +static void test_tcp_server(struct sock_destroy_prog *skel, + struct bpf_program *prog, + struct bpf_map *socks) { int serv = -1, clien = -1, accept_serv = -1, n, serv_port; @@ -93,8 +126,17 @@ static void test_tcp_server(struct sock_destroy_prog *skel) if (!ASSERT_EQ(n, 1, "client send")) goto cleanup; + if (socks) { + if (!ASSERT_OK(insert_socket(socks, clien, 0), + "insert_socket")) + goto cleanup; + if (!ASSERT_OK(insert_socket(socks, accept_serv, 1), + "insert_socket")) + goto cleanup; + } + /* Run iterator program that destroys server sockets. */ - start_iter_sockets(skel->progs.iter_tcp6_server); + start_iter_sockets(prog, socks); n = send(clien, "t", 1, 0); if (!ASSERT_LT(n, 0, "client_send on destroyed socket")) @@ -110,7 +152,9 @@ static void test_tcp_server(struct sock_destroy_prog *skel) close(serv); } -static void test_udp_client(struct sock_destroy_prog *skel) +static void test_udp_client(struct sock_destroy_prog *skel, + struct bpf_program *prog, + struct bpf_map *socks) { int serv = -1, clien = -1, n = 0; @@ -126,8 +170,17 @@ static void test_udp_client(struct sock_destroy_prog *skel) if (!ASSERT_EQ(n, 1, "client send")) goto cleanup; + if (socks) { + if (!ASSERT_OK(insert_socket(socks, clien, 0), + "insert_socket")) + goto cleanup; + if (!ASSERT_OK(insert_socket(socks, serv, 1), + "insert_socket")) + goto cleanup; + } + /* Run iterator program that destroys sockets. 
*/ - start_iter_sockets(skel->progs.iter_udp6_client); + start_iter_sockets(prog, socks); n = send(clien, "t", 1, 0); if (!ASSERT_LT(n, 0, "client_send on destroyed socket")) @@ -143,11 +196,14 @@ static void test_udp_client(struct sock_destroy_prog *skel) close(serv); } -static void test_udp_server(struct sock_destroy_prog *skel) +static void test_udp_server(struct sock_destroy_prog *skel, + struct bpf_program *prog, + struct bpf_map *socks) { int *listen_fds = NULL, n, i, serv_port; unsigned int num_listens = 5; char buf[1]; + __u32 key; /* Start reuseport servers. */ listen_fds = start_reuseport_server(AF_INET6, SOCK_DGRAM, @@ -159,8 +215,15 @@ static void test_udp_server(struct sock_destroy_prog *skel) goto cleanup; skel->bss->serv_port = (__be16) serv_port; + if (socks) + for (key = 0; key < num_listens; key++) + if (!ASSERT_OK(insert_socket(socks, listen_fds[key], + key), + "insert_socket")) + goto cleanup; + /* Run iterator program that destroys server sockets. */ - start_iter_sockets(skel->progs.iter_udp6_server); + start_iter_sockets(prog, socks); for (i = 0; i < num_listens; ++i) { n = read(listen_fds[i], buf, sizeof(buf)); @@ -200,14 +263,34 @@ void test_sock_destroy(void) if (!ASSERT_OK_PTR(nstoken, "open_netns")) goto cleanup; - if (test__start_subtest("tcp_client")) - test_tcp_client(skel); - if (test__start_subtest("tcp_server")) - test_tcp_server(skel); - if (test__start_subtest("udp_client")) - test_udp_client(skel); - if (test__start_subtest("udp_server")) - test_udp_server(skel); + if (test__start_subtest("tcp_client")) { + test_tcp_client(skel, skel->progs.iter_tcp6_client, NULL); + test_tcp_client(skel, skel->progs.iter_sockmap_client, + skel->maps.sock_map); + test_tcp_client(skel, skel->progs.iter_sockmap_client, + skel->maps.sock_hash); + } + if (test__start_subtest("tcp_server")) { + test_tcp_server(skel, skel->progs.iter_tcp6_server, NULL); + test_tcp_server(skel, skel->progs.iter_sockmap_server, + skel->maps.sock_map); + 
test_tcp_server(skel, skel->progs.iter_sockmap_server, + skel->maps.sock_hash); + } + if (test__start_subtest("udp_client")) { + test_udp_client(skel, skel->progs.iter_udp6_client, NULL); + test_udp_client(skel, skel->progs.iter_sockmap_client, + skel->maps.sock_map); + test_udp_client(skel, skel->progs.iter_sockmap_client, + skel->maps.sock_hash); + } + if (test__start_subtest("udp_server")) { + test_udp_server(skel, skel->progs.iter_udp6_server, NULL); + test_udp_server(skel, skel->progs.iter_sockmap_server, + skel->maps.sock_map); + test_udp_server(skel, skel->progs.iter_sockmap_server, + skel->maps.sock_hash); + } RUN_TESTS(sock_destroy_prog_fail); diff --git a/tools/testing/selftests/bpf/progs/sock_destroy_prog.c b/tools/testing/selftests/bpf/progs/sock_destroy_prog.c index 9e0bf7a54cec..d91f75190bbf 100644 --- a/tools/testing/selftests/bpf/progs/sock_destroy_prog.c +++ b/tools/testing/selftests/bpf/progs/sock_destroy_prog.c @@ -24,6 +24,20 @@ struct { __type(value, __u64); } udp_conn_sockets SEC(".maps"); +struct { + __uint(type, BPF_MAP_TYPE_SOCKMAP); + __uint(max_entries, 5); + __type(key, __u32); + __type(value, __u64); +} sock_map SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_SOCKHASH); + __uint(max_entries, 5); + __type(key, __u32); + __type(value, __u64); +} sock_hash SEC(".maps"); + SEC("cgroup/connect6") int sock_connect(struct bpf_sock_addr *ctx) { @@ -142,4 +156,53 @@ int iter_udp6_server(struct bpf_iter__udp *ctx) return 0; } +SEC("iter/sockmap") +int iter_sockmap_client(struct bpf_iter__sockmap *ctx) +{ + __u64 sock_cookie = 0, *val; + struct sock *sk = ctx->sk; + __u32 *key = ctx->key; + __u32 zero = 0; + + if (!key || !sk) + return 0; + + sock_cookie = bpf_get_socket_cookie(sk); + val = bpf_map_lookup_elem(&udp_conn_sockets, &zero); + if (val && *val == sock_cookie) + goto destroy; + val = bpf_map_lookup_elem(&tcp_conn_sockets, &zero); + if (val && *val == sock_cookie) + goto destroy; + goto out; +destroy: + bpf_sock_destroy((struct 
sock_common *)sk); +out: + return 0; +} + +SEC("iter/sockmap") +int iter_sockmap_server(struct bpf_iter__sockmap *ctx) +{ + struct sock *sk = ctx->sk; + struct tcp6_sock *tcp_sk; + struct udp6_sock *udp_sk; + __u32 *key = ctx->key; + + if (!key || !sk) + return 0; + + tcp_sk = bpf_skc_to_tcp6_sock(sk); + if (tcp_sk && tcp_sk->tcp.inet_conn.icsk_inet.inet_sport == serv_port) + goto destroy; + udp_sk = bpf_skc_to_udp6_sock(sk); + if (udp_sk && udp_sk->udp.inet.inet_sport == serv_port) + goto destroy; + goto out; +destroy: + bpf_sock_destroy((struct sock_common *)sk); +out: + return 0; +} + char _license[] SEC("license") = "GPL"; -- 2.43.0 Enable control over which keys are bucketed together in the hash by allowing users to specify the number of bytes from the key that should be used to determine the bucket hash. Example: ``` struct ipv4_sockets_tuple { union v4addr address; __be32 port; __sock_cookie cookie; } __packed; struct { __uint(type, BPF_MAP_TYPE_SOCKHASH); __uint(max_entries, 1 << 20); /* ~1 million */ __uint(map_extra, offsetof(struct ipv4_sockets_tuple, cookie)); __type(key, struct ipv4_sockets_tuple); __type(value, __u64); } sockets SEC(".maps"); ``` This allows you to bucket all keys that share a common prefix together, for example, to place all sockets connected to a single backend in the same bucket. This is complemented by a change later in this series that allows users to specify a key prefix filter when creating a socket hash iterator. Note: On 64-bit architectures, struct bpf_shtab_elem currently contains a four-byte hole between hash and sk, so place bucket_hash there.
Signed-off-by: Jordan Rife --- kernel/bpf/syscall.c | 1 + net/core/sock_map.c | 57 ++++++++++++++++++++++++++++---------------- 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c index 3f178a0f8eb1..f5992e588fc7 100644 --- a/kernel/bpf/syscall.c +++ b/kernel/bpf/syscall.c @@ -1371,6 +1371,7 @@ static int map_create(union bpf_attr *attr, bool kernel) if (attr->map_type != BPF_MAP_TYPE_BLOOM_FILTER && attr->map_type != BPF_MAP_TYPE_ARENA && + attr->map_type != BPF_MAP_TYPE_SOCKHASH && attr->map_extra != 0) return -EINVAL; diff --git a/net/core/sock_map.c b/net/core/sock_map.c index 20b0627b1eb1..51930f24d2f9 100644 --- a/net/core/sock_map.c +++ b/net/core/sock_map.c @@ -860,6 +860,7 @@ const struct bpf_map_ops sock_map_ops = { struct bpf_shtab_elem { struct rcu_head rcu; u32 hash; + u32 bucket_hash; struct sock *sk; struct hlist_node node; refcount_t ref; @@ -878,11 +879,14 @@ struct bpf_shtab { u32 elem_size; struct sk_psock_progs progs; atomic_t count; + u32 hash_len; }; -static inline u32 sock_hash_bucket_hash(const void *key, u32 len) +static inline void sock_hash_elem_hash(const void *key, u32 *bucket_hash, + u32 *hash, u32 hash_len, u32 key_size) { - return jhash(key, len, 0); + *bucket_hash = jhash(key, hash_len, 0); + *hash = hash_len == key_size ? 
*bucket_hash : jhash(key, key_size, 0); } static struct bpf_shtab_bucket *sock_hash_select_bucket(struct bpf_shtab *htab, @@ -909,14 +913,15 @@ sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key, static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key) { struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map); - u32 key_size = map->key_size, hash; + u32 key_size = map->key_size, bucket_hash, hash; struct bpf_shtab_bucket *bucket; struct bpf_shtab_elem *elem; WARN_ON_ONCE(!rcu_read_lock_held()); - hash = sock_hash_bucket_hash(key, key_size); - bucket = sock_hash_select_bucket(htab, hash); + sock_hash_elem_hash(key, &bucket_hash, &hash, htab->hash_len, + map->key_size); + bucket = sock_hash_select_bucket(htab, bucket_hash); elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size); return elem ? elem->sk : NULL; @@ -972,7 +977,7 @@ static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk, struct bpf_shtab_bucket *bucket; WARN_ON_ONCE(!rcu_read_lock_held()); - bucket = sock_hash_select_bucket(htab, elem->hash); + bucket = sock_hash_select_bucket(htab, elem->bucket_hash); /* elem may be deleted in parallel from the map, but access here * is okay since it's going away only after RCU grace period. 
@@ -989,13 +994,14 @@ static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk, static long sock_hash_delete_elem(struct bpf_map *map, void *key) { struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map); - u32 hash, key_size = map->key_size; + u32 bucket_hash, hash, key_size = map->key_size; struct bpf_shtab_bucket *bucket; struct bpf_shtab_elem *elem; int ret = -ENOENT; - hash = sock_hash_bucket_hash(key, key_size); - bucket = sock_hash_select_bucket(htab, hash); + sock_hash_elem_hash(key, &bucket_hash, &hash, htab->hash_len, + map->key_size); + bucket = sock_hash_select_bucket(htab, bucket_hash); spin_lock_bh(&bucket->lock); elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size); @@ -1009,7 +1015,8 @@ static long sock_hash_delete_elem(struct bpf_map *map, void *key) static struct bpf_shtab_elem *sock_hash_alloc_elem(struct bpf_shtab *htab, void *key, u32 key_size, - u32 hash, struct sock *sk, + u32 bucket_hash, u32 hash, + struct sock *sk, struct bpf_shtab_elem *old) { struct bpf_shtab_elem *new; @@ -1031,6 +1038,7 @@ static struct bpf_shtab_elem *sock_hash_alloc_elem(struct bpf_shtab *htab, memcpy(new->key, key, key_size); new->sk = sk; new->hash = hash; + new->bucket_hash = bucket_hash; refcount_set(&new->ref, 1); /* Matches sock_put() in sock_hash_free_elem(). Ensure that sk is not * freed until elem is. 
@@ -1043,7 +1051,7 @@ static int sock_hash_update_common(struct bpf_map *map, void *key, struct sock *sk, u64 flags) { struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map); - u32 key_size = map->key_size, hash; + u32 key_size = map->key_size, bucket_hash, hash; struct bpf_shtab_elem *elem, *elem_new; struct bpf_shtab_bucket *bucket; struct sk_psock_link *link; @@ -1065,8 +1073,9 @@ static int sock_hash_update_common(struct bpf_map *map, void *key, psock = sk_psock(sk); WARN_ON_ONCE(!psock); - hash = sock_hash_bucket_hash(key, key_size); - bucket = sock_hash_select_bucket(htab, hash); + sock_hash_elem_hash(key, &bucket_hash, &hash, htab->hash_len, + map->key_size); + bucket = sock_hash_select_bucket(htab, bucket_hash); spin_lock_bh(&bucket->lock); elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size); @@ -1078,7 +1087,8 @@ static int sock_hash_update_common(struct bpf_map *map, void *key, goto out_unlock; } - elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem); + elem_new = sock_hash_alloc_elem(htab, key, key_size, bucket_hash, hash, + sk, elem); if (IS_ERR(elem_new)) { ret = PTR_ERR(elem_new); goto out_unlock; @@ -1105,15 +1115,16 @@ static int sock_hash_get_next_key(struct bpf_map *map, void *key, void *key_next) { struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map); + u32 bucket_hash, hash, key_size = map->key_size; struct bpf_shtab_elem *elem, *elem_next; - u32 hash, key_size = map->key_size; struct hlist_head *head; int i = 0; if (!key) goto find_first_elem; - hash = sock_hash_bucket_hash(key, key_size); - head = &sock_hash_select_bucket(htab, hash)->head; + sock_hash_elem_hash(key, &bucket_hash, &hash, htab->hash_len, + map->key_size); + head = &sock_hash_select_bucket(htab, bucket_hash)->head; elem = sock_hash_lookup_elem_raw(head, hash, key, key_size); if (!elem) goto find_first_elem; @@ -1125,7 +1136,7 @@ static int sock_hash_get_next_key(struct bpf_map *map, void *key, return 0; } - i = hash & 
(htab->buckets_num - 1); + i = bucket_hash & (htab->buckets_num - 1); i++; find_first_elem: for (; i < htab->buckets_num; i++) { @@ -1150,7 +1161,11 @@ static struct bpf_map *sock_hash_alloc(union bpf_attr *attr) attr->key_size == 0 || (attr->value_size != sizeof(u32) && attr->value_size != sizeof(u64)) || - attr->map_flags & ~SOCK_CREATE_FLAG_MASK) + attr->map_flags & ~SOCK_CREATE_FLAG_MASK || + /* The lower 32 bits of map_extra specify the number of bytes in + * the key to hash. + */ + attr->map_extra & ~U32_MAX) return ERR_PTR(-EINVAL); if (attr->key_size > MAX_BPF_STACK) return ERR_PTR(-E2BIG); @@ -1164,8 +1179,10 @@ static struct bpf_map *sock_hash_alloc(union bpf_attr *attr) htab->buckets_num = roundup_pow_of_two(htab->map.max_entries); htab->elem_size = sizeof(struct bpf_shtab_elem) + round_up(htab->map.key_size, 8); + htab->hash_len = attr->map_extra ?: attr->key_size; if (htab->buckets_num == 0 || - htab->buckets_num > U32_MAX / sizeof(struct bpf_shtab_bucket)) { + htab->buckets_num > U32_MAX / sizeof(struct bpf_shtab_bucket) || + htab->hash_len > attr->key_size) { err = -EINVAL; goto free_htab; } -- 2.43.0 Complementing the change to socket hashes that allows users to bucket keys sharing the same prefix together, add a key prefix filter to socket hash iterators so that they traverse only the sockets in the bucket whose keys match the provided prefix. Together, the bucketing control and the key prefix filter allow efficient iteration over a set of sockets whose keys share a common prefix, without needing to iterate through every key in every bucket to find the ones we're interested in.
Signed-off-by: Jordan Rife --- include/linux/bpf.h | 4 ++ include/uapi/linux/bpf.h | 7 ++++ net/core/sock_map.c | 67 ++++++++++++++++++++++++++++++---- tools/include/uapi/linux/bpf.h | 7 ++++ 4 files changed, 78 insertions(+), 7 deletions(-) diff --git a/include/linux/bpf.h b/include/linux/bpf.h index 8f6e87f0f3a8..1c7bb1fb3a80 100644 --- a/include/linux/bpf.h +++ b/include/linux/bpf.h @@ -2632,6 +2632,10 @@ struct bpf_iter_aux_info { enum bpf_iter_task_type type; u32 pid; } task; + struct { + void *key_prefix; + u32 key_prefix_len; + } sockhash; }; typedef int (*bpf_iter_attach_target_t)(struct bpf_prog *prog, diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h index 233de8677382..22761dea4635 100644 --- a/include/uapi/linux/bpf.h +++ b/include/uapi/linux/bpf.h @@ -124,6 +124,13 @@ enum bpf_cgroup_iter_order { union bpf_iter_link_info { struct { __u32 map_fd; + union { + /* Parameters for socket hash iterators. */ + struct { + __aligned_u64 key_prefix; /* key prefix filter */ + __u32 key_prefix_len; /* key_prefix length */ + } sock_hash; + }; } map; struct { enum bpf_cgroup_iter_order order; diff --git a/net/core/sock_map.c b/net/core/sock_map.c index 51930f24d2f9..b0b428190561 100644 --- a/net/core/sock_map.c +++ b/net/core/sock_map.c @@ -889,10 +889,16 @@ static inline void sock_hash_elem_hash(const void *key, u32 *bucket_hash, *hash = hash_len == key_size ? 
*bucket_hash : jhash(key, key_size, 0); } +static inline u32 sock_hash_select_bucket_num(struct bpf_shtab *htab, + u32 hash) +{ + return hash & (htab->buckets_num - 1); +} + static struct bpf_shtab_bucket *sock_hash_select_bucket(struct bpf_shtab *htab, u32 hash) { - return &htab->buckets[hash & (htab->buckets_num - 1)]; + return &htab->buckets[sock_hash_select_bucket_num(htab, hash)]; } static struct bpf_shtab_elem * @@ -1376,6 +1382,8 @@ struct sock_hash_seq_info { struct bpf_map *map; struct bpf_shtab *htab; struct bpf_shtab_elem *next_elem; + void *key_prefix; + u32 key_prefix_len; u32 bucket_id; }; @@ -1384,7 +1392,8 @@ static inline bool bpf_shtab_elem_unhashed(struct bpf_shtab_elem *elem) return READ_ONCE(elem->node.pprev) == LIST_POISON2; } -static struct bpf_shtab_elem *sock_hash_seq_hold_next(struct bpf_shtab_elem *elem) +static struct bpf_shtab_elem *sock_hash_seq_hold_next(struct sock_hash_seq_info *info, + struct bpf_shtab_elem *elem) { hlist_for_each_entry_from_rcu(elem, node) /* It's possible that the first element or its descendants were @@ -1392,6 +1401,9 @@ static struct bpf_shtab_elem *sock_hash_seq_hold_next(struct bpf_shtab_elem *ele * until we get back to the main list. 
*/ if (!bpf_shtab_elem_unhashed(elem) && + (!info->key_prefix || + !memcmp(&elem->key, info->key_prefix, + info->key_prefix_len)) && sock_hash_hold_elem(elem)) return elem; @@ -1435,21 +1447,27 @@ static void *sock_hash_seq_find_next(struct sock_hash_seq_info *info, if (prev_elem) { node = rcu_dereference(hlist_next_rcu(&prev_elem->node)); elem = hlist_entry_safe(node, struct bpf_shtab_elem, node); - elem = sock_hash_seq_hold_next(elem); + elem = sock_hash_seq_hold_next(info, elem); if (elem) goto unlock; - - /* no more elements, continue in the next bucket */ - info->bucket_id++; + if (info->key_prefix) + /* no more elements, skip to the end */ + info->bucket_id = htab->buckets_num; + else + /* no more elements, continue in the next bucket */ + info->bucket_id++; } for (; info->bucket_id < htab->buckets_num; info->bucket_id++) { bucket = &htab->buckets[info->bucket_id]; node = rcu_dereference(hlist_first_rcu(&bucket->head)); elem = hlist_entry_safe(node, struct bpf_shtab_elem, node); - elem = sock_hash_seq_hold_next(elem); + elem = sock_hash_seq_hold_next(info, elem); if (elem) goto unlock; + if (info->key_prefix) + /* no more elements, skip to the end */ + info->bucket_id = htab->buckets_num; } unlock: /* sock_hash_put_elem() will free all elements up until the @@ -1544,10 +1562,18 @@ static int sock_hash_init_seq_private(void *priv_data, struct bpf_iter_aux_info *aux) { struct sock_hash_seq_info *info = priv_data; + u32 hash; bpf_map_inc_with_uref(aux->map); info->map = aux->map; info->htab = container_of(aux->map, struct bpf_shtab, map); + info->key_prefix = aux->sockhash.key_prefix; + info->key_prefix_len = aux->sockhash.key_prefix_len; + if (info->key_prefix) { + sock_hash_elem_hash(info->key_prefix, &hash, &hash, + info->key_prefix_len, info->key_prefix_len); + info->bucket_id = sock_hash_select_bucket_num(info->htab, hash); + } return 0; } @@ -2039,8 +2065,12 @@ static int sock_map_iter_attach_target(struct bpf_prog *prog, union bpf_iter_link_info *linfo, 
struct bpf_iter_aux_info *aux) { + void __user *ukey_prefix; + struct bpf_shtab *htab; struct bpf_map *map; + u32 key_prefix_len; int err = -EINVAL; + void *key_prefix = NULL; if (!linfo->map.map_fd) return -EBADF; @@ -2053,10 +2083,31 @@ static int sock_map_iter_attach_target(struct bpf_prog *prog, map->map_type != BPF_MAP_TYPE_SOCKHASH) goto put_map; if (prog->aux->max_rdonly_access > map->key_size) { err = -EACCES; goto put_map; } + + if (map->map_type == BPF_MAP_TYPE_SOCKHASH) { + ukey_prefix = u64_to_user_ptr(linfo->map.sock_hash.key_prefix); + key_prefix_len = linfo->map.sock_hash.key_prefix_len; + htab = container_of(map, struct bpf_shtab, map); + + if (ukey_prefix) { + if (key_prefix_len != htab->hash_len) + goto put_map; + key_prefix = vmemdup_user(ukey_prefix, key_prefix_len); + if (IS_ERR(key_prefix)) { + err = PTR_ERR(key_prefix); + goto put_map; + } + } else if (linfo->map.sock_hash.key_prefix_len) { + goto put_map; + } + + aux->sockhash.key_prefix_len = key_prefix_len; + aux->sockhash.key_prefix = key_prefix; + } aux->map = map; return 0; @@ -2069,6 +2120,8 @@ static int sock_map_iter_attach_target(struct bpf_prog *prog, static void sock_map_iter_detach_target(struct bpf_iter_aux_info *aux) { bpf_map_put_with_uref(aux->map); + if (aux->sockhash.key_prefix) + kvfree(aux->sockhash.key_prefix); } static struct bpf_iter_reg sock_map_iter_reg = { diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h index 233de8677382..22761dea4635 100644 --- a/tools/include/uapi/linux/bpf.h +++ b/tools/include/uapi/linux/bpf.h @@ -124,6 +124,13 @@ enum bpf_cgroup_iter_order { union bpf_iter_link_info { struct { __u32 map_fd; + union { + /* Parameters for socket hash iterators.
*/ + struct { + __aligned_u64 key_prefix; /* key prefix filter */ + __u32 key_prefix_len; /* key_prefix length */ + } sock_hash; + }; } map; struct { enum bpf_cgroup_iter_order order; -- 2.43.0 This test is meant to destroy all sockets left in the current bucket, but currently destroys all sockets except the last one. This worked for the normal TCP socket iterator tests, since the last socket was removed from the bucket anyway when its counterpart was destroyed. However, with socket hash iterators this doesn't work, since the last socket stays in the bucket until it's closed or destroyed explicitly. Fix this before the next patch which adds test coverage for socket hash iterators. Signed-off-by: Jordan Rife --- tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c b/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c index 27781df8f2fb..e6fc4fd994f9 100644 --- a/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c +++ b/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c @@ -463,7 +463,7 @@ static void remove_all_established(int family, int sock_type, const char *addr, for (i = 0; i < established_socks_len - 1; i++) { close_idx[i] = get_nth_socket(established_socks, established_socks_len, link, - listen_socks_len + i); + listen_socks_len + i + 1); if (!ASSERT_GE(close_idx[i], 0, "close_idx")) return; } -- 2.43.0 Extend the suite of tests that exercise edge cases around iteration over multiple sockets in the same bucket to cover socket hashes using key prefix filtering. 
Signed-off-by: Jordan Rife --- .../bpf/prog_tests/sock_iter_batch.c | 119 +++++++++++++++++- .../selftests/bpf/progs/sock_iter_batch.c | 31 +++++ 2 files changed, 147 insertions(+), 3 deletions(-) diff --git a/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c b/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c index e6fc4fd994f9..2034ddfdf134 100644 --- a/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c +++ b/tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c @@ -10,6 +10,7 @@ #define TEST_CHILD_NS "sock_iter_batch_child_netns" static const int init_batch_size = 16; +static const __u32 key_prefix = 1; static const int nr_soreuse = 4; struct iter_out { @@ -255,6 +256,31 @@ static int *connect_to_server(int family, int sock_type, const char *addr, return NULL; } +static int insert_sockets_hash(struct bpf_map *sock_map, __u32 first_id, + int *sock_fds, int sock_fds_len) +{ + int map_fd = bpf_map__fd(sock_map); + struct { + __u32 bucket_key; + __u32 id; + } key = { + .bucket_key = key_prefix, + }; + __s64 sfd; + int ret; + __u32 i; + + for (i = 0; i < sock_fds_len; i++) { + sfd = sock_fds[i]; + key.id = first_id + i; + ret = bpf_map_update_elem(map_fd, &key, &sfd, BPF_NOEXIST); + if (!ASSERT_OK(ret, "map_update")) + return -1; + } + + return 0; +} + static void remove_seen(int family, int sock_type, const char *addr, __u16 port, int *socks, int socks_len, int *established_socks, int established_socks_len, struct sock_count *counts, @@ -609,6 +635,7 @@ struct test_case { int init_socks; int max_socks; int sock_type; + bool fill_map; int family; }; @@ -660,6 +687,33 @@ static struct test_case resume_tests[] = { .family = AF_INET6, .test = force_realloc, }, + { + .description = "sockhash: udp: resume after removing a seen socket", + .init_socks = nr_soreuse, + .max_socks = nr_soreuse, + .sock_type = SOCK_DGRAM, + .family = AF_INET6, + .test = remove_seen, + .fill_map = true, + }, + { + .description = "sockhash: udp: resume after removing one 
unseen socket", + .init_socks = nr_soreuse, + .max_socks = nr_soreuse, + .sock_type = SOCK_DGRAM, + .family = AF_INET6, + .test = remove_unseen, + .fill_map = true, + }, + { + .description = "sockhash: udp: resume after removing all unseen sockets", + .init_socks = nr_soreuse, + .max_socks = nr_soreuse, + .sock_type = SOCK_DGRAM, + .family = AF_INET6, + .test = remove_all, + .fill_map = true, + }, { .description = "tcp: resume after removing a seen socket (listening)", .init_socks = nr_soreuse, @@ -770,13 +824,49 @@ static struct test_case resume_tests[] = { .family = AF_INET6, .test = force_realloc_established, }, + { + .description = "sockhash: tcp: resume after removing a seen socket", + .connections = nr_soreuse, + .init_socks = nr_soreuse, + /* Room for connect()ed and accept()ed sockets */ + .max_socks = nr_soreuse * 3, + .sock_type = SOCK_STREAM, + .family = AF_INET6, + .test = remove_seen_established, + .fill_map = true, + }, + { + .description = "sockhash: tcp: resume after removing one unseen socket", + .connections = nr_soreuse, + .init_socks = nr_soreuse, + /* Room for connect()ed and accept()ed sockets */ + .max_socks = nr_soreuse * 3, + .sock_type = SOCK_STREAM, + .family = AF_INET6, + .test = remove_unseen_established, + .fill_map = true, + }, + { + .description = "sockhash: tcp: resume after removing all unseen sockets", + .connections = nr_soreuse, + .init_socks = nr_soreuse, + /* Room for connect()ed and accept()ed sockets */ + .max_socks = nr_soreuse * 3, + .sock_type = SOCK_STREAM, + .family = AF_INET6, + .test = remove_all_established, + .fill_map = true, + }, }; static void do_resume_test(struct test_case *tc) { + DECLARE_LIBBPF_OPTS(bpf_iter_attach_opts, opts); + union bpf_iter_link_info linfo = {}; struct sock_iter_batch *skel = NULL; struct sock_count *counts = NULL; static const __u16 port = 10001; + struct bpf_program *prog = NULL; struct nstoken *nstoken = NULL; struct bpf_link *link = NULL; int *established_fds = NULL; @@ -825,10 
+915,33 @@ static void do_resume_test(struct test_case *tc) if (!ASSERT_OK(err, "sock_iter_batch__load")) goto done; - link = bpf_program__attach_iter(tc->sock_type == SOCK_STREAM ? + if (tc->fill_map) { + /* Established sockets must be inserted first so that all + * listening sockets will be seen first during iteration. + */ + if (!ASSERT_OK(insert_sockets_hash(skel->maps.sockets, 0, + established_fds, + tc->connections * 2), + "insert_sockets_hash")) + goto done; + if (!ASSERT_OK(insert_sockets_hash(skel->maps.sockets, + tc->connections * 2, fds, + tc->init_socks), + "insert_sockets_hash")) + goto done; + linfo.map.map_fd = bpf_map__fd(skel->maps.sockets); + linfo.map.sock_hash.key_prefix = (__u64)(void *)&key_prefix; + linfo.map.sock_hash.key_prefix_len = sizeof(key_prefix); + opts.link_info = &linfo; + opts.link_info_len = sizeof(linfo); + prog = skel->progs.iter_sockmap; + } else { + prog = tc->sock_type == SOCK_STREAM ? skel->progs.iter_tcp_soreuse : - skel->progs.iter_udp_soreuse, - NULL); + skel->progs.iter_udp_soreuse; + } + + link = bpf_program__attach_iter(prog, &opts); if (!ASSERT_OK_PTR(link, "bpf_program__attach_iter")) goto done; diff --git a/tools/testing/selftests/bpf/progs/sock_iter_batch.c b/tools/testing/selftests/bpf/progs/sock_iter_batch.c index 77966ded5467..a19581f19eda 100644 --- a/tools/testing/selftests/bpf/progs/sock_iter_batch.c +++ b/tools/testing/selftests/bpf/progs/sock_iter_batch.c @@ -130,4 +130,35 @@ int iter_udp_soreuse(struct bpf_iter__udp *ctx) return 0; } +struct sock_hash_key { + __u32 bucket_key; + __u32 id; +}; + +struct { + __uint(type, BPF_MAP_TYPE_SOCKHASH); + __uint(max_entries, 16); + __ulong(map_extra, offsetof(struct sock_hash_key, id)); + __type(key, struct sock_hash_key); + __type(value, __u64); +} sockets SEC(".maps"); + +SEC("iter/sockmap") +int iter_sockmap(struct bpf_iter__sockmap *ctx) +{ + struct sock *sk = ctx->sk; + __u32 *key = ctx->key; + __u64 sock_cookie; + int idx = 0; + + if (!key || !sk) + return
0; + + sock_cookie = bpf_get_socket_cookie(sk); + bpf_seq_write(ctx->meta->seq, &idx, sizeof(idx)); + bpf_seq_write(ctx->meta->seq, &sock_cookie, sizeof(sock_cookie)); + + return 0; +} + char _license[] SEC("license") = "GPL"; -- 2.43.0 Use a sockops program to automatically insert sockets into a socket map and socket hash and use BPF iterators with key prefix bucketing and filtering to destroy the set of sockets connected to the same remote port regardless of protocol. This test wraps things up by demonstrating the desired end to end flow and showing how all the pieces are meant to fit together. Signed-off-by: Jordan Rife --- .../selftests/bpf/prog_tests/sockmap_basic.c | 277 ++++++++++++++++++ .../selftests/bpf/progs/bpf_iter_sockmap.c | 14 + .../selftests/bpf/progs/test_sockmap_update.c | 43 +++ 3 files changed, 334 insertions(+) diff --git a/tools/testing/selftests/bpf/prog_tests/sockmap_basic.c b/tools/testing/selftests/bpf/prog_tests/sockmap_basic.c index 1e3e4392dcca..00afa377cf7d 100644 --- a/tools/testing/selftests/bpf/prog_tests/sockmap_basic.c +++ b/tools/testing/selftests/bpf/prog_tests/sockmap_basic.c @@ -16,6 +16,7 @@ #include "bpf_iter_sockmap.skel.h" #include "sockmap_helpers.h" +#include "network_helpers.h" #define TCP_REPAIR 19 /* TCP sock is under repair right now */ @@ -364,6 +365,280 @@ static void test_sockmap_copy(enum bpf_map_type map_type) bpf_iter_sockmap__destroy(skel); } +#define TEST_NS "sockmap_basic" + +struct sock_hash_key { + __u32 bucket_key; + __u64 cookie; +} __packed; + +static void close_fds(int fds[], int fds_len) +{ + int i; + + for (i = 0; i < fds_len; i++) + if (fds[i] >= 0) + close(fds[i]); +} + +static __u64 socket_cookie(int fd) +{ + __u64 cookie; + socklen_t cookie_len = sizeof(cookie); + + if (!ASSERT_OK(getsockopt(fd, SOL_SOCKET, SO_COOKIE, &cookie, + &cookie_len), "getsockopt(SO_COOKIE)")) + return 0; + return cookie; +} + +static bool has_socket(struct bpf_map *map, __u64 sk_cookie, int key_size) +{ + void *prev_key 
= NULL, *key = NULL; + int map_fd = bpf_map__fd(map); + bool found = false; + __u64 cookie; + int err; + + key = malloc(key_size); + if (!ASSERT_OK_PTR(key, "malloc(key_size)")) + goto cleanup; + + prev_key = malloc(key_size); + if (!ASSERT_OK_PTR(prev_key, "malloc(key_size)")) + goto cleanup; + + err = bpf_map__get_next_key(map, NULL, key, key_size); + if (!ASSERT_OK(err, "get_next_key")) + goto cleanup; + + do { + err = bpf_map_lookup_elem(map_fd, key, &cookie); + if (!err) + found = sk_cookie == cookie; + else if (!ASSERT_EQ(err, -ENOENT, "bpf_map_lookup_elem")) + goto cleanup; + + memcpy(prev_key, key, key_size); + } while (!found && + bpf_map__get_next_key(map, prev_key, key, key_size) == 0); +cleanup: + if (prev_key) + free(prev_key); + if (key) + free(key); + return found; +} + +static void test_sockmap_insert_sockops_and_destroy(void) +{ + DECLARE_LIBBPF_OPTS(bpf_iter_attach_opts, opts); + struct test_sockmap_update *update_skel = NULL; + static const int port0 = 10000, port1 = 10001; + int prog_fd = -1, cg_fd = -1, iter_fd = -1; + struct bpf_iter_sockmap *iter_skel = NULL; + __u32 key_prefix = htonl((__u32)port0); + int accept_serv[4] = {-1, -1, -1, -1}; + int tcp_clien[4] = {-1, -1, -1, -1}; + union bpf_iter_link_info linfo = {}; + int tcp_serv[4] = {-1, -1, -1, -1}; + struct nstoken *nstoken = NULL; + int tcp_clien_cookies[4] = {}; + struct bpf_link *link = NULL; + char buf[64]; + int len; + int i; + + SYS_NOFAIL("ip netns del " TEST_NS); + SYS(cleanup, "ip netns add %s", TEST_NS); + SYS(cleanup, "ip -net %s link set dev lo up", TEST_NS); + + nstoken = open_netns(TEST_NS); + if (!ASSERT_OK_PTR(nstoken, "open_netns")) + goto cleanup; + + cg_fd = test__join_cgroup("/sockmap_basic"); + if (!ASSERT_OK_FD(cg_fd, "join_cgroup")) + goto cleanup; + + update_skel = test_sockmap_update__open_and_load(); + if (!ASSERT_OK_PTR(update_skel, "test_sockmap_update__open_and_load")) + goto cleanup; + + iter_skel = bpf_iter_sockmap__open_and_load(); + if
(!ASSERT_OK_PTR(iter_skel, "bpf_iter_sockmap__open_and_load")) + goto cleanup; + + if (!ASSERT_OK(bpf_prog_attach(bpf_program__fd(update_skel->progs.insert_sock), + cg_fd, BPF_CGROUP_SOCK_OPS, BPF_F_ALLOW_OVERRIDE), + "bpf_prog_attach")) + goto cleanup; + prog_fd = bpf_program__fd(update_skel->progs.insert_sock); + + /* Create two servers on each port, port0 and port1, and connect a + * client to each. + */ + tcp_serv[0] = start_server(AF_INET, SOCK_STREAM, "127.0.0.1", port0, 0); + if (!ASSERT_OK_FD(tcp_serv[0], "start_server")) + goto cleanup; + + tcp_serv[1] = start_server(AF_INET6, SOCK_STREAM, "::1", port0, 0); + if (!ASSERT_OK_FD(tcp_serv[1], "start_server")) + goto cleanup; + + tcp_serv[2] = start_server(AF_INET, SOCK_STREAM, "127.0.0.1", port1, 0); + if (!ASSERT_OK_FD(tcp_serv[2], "start_server")) + goto cleanup; + + tcp_serv[3] = start_server(AF_INET6, SOCK_STREAM, "::1", port1, 0); + if (!ASSERT_OK_FD(tcp_serv[3], "start_server")) + goto cleanup; + + for (i = 0; i < ARRAY_SIZE(tcp_serv); i++) { + tcp_clien[i] = connect_to_fd(tcp_serv[i], 0); + if (!ASSERT_OK_FD(tcp_clien[i], "connect_to_fd")) + goto cleanup; + + accept_serv[i] = accept(tcp_serv[i], NULL, NULL); + if (!ASSERT_OK_FD(accept_serv[i], "accept")) + goto cleanup; + } + + /* Ensure that sockets are connected. */ + for (i = 0; i < ARRAY_SIZE(tcp_clien); i++) + if (!ASSERT_EQ(send(tcp_clien[i], "a", 1, 0), 1, "send")) + goto cleanup; + + /* Ensure that client sockets exist in the map and the hash.
*/ + if (!ASSERT_EQ(update_skel->bss->count, + ARRAY_SIZE(tcp_clien) + ARRAY_SIZE(udp_clien), + "count")) + goto cleanup; + + for (i = 0; i < ARRAY_SIZE(tcp_clien); i++) + tcp_clien_cookies[i] = socket_cookie(tcp_clien[i]); + + for (i = 0; i < ARRAY_SIZE(tcp_clien); i++) { + if (!ASSERT_TRUE(has_socket(update_skel->maps.sock_map, + tcp_clien_cookies[i], + sizeof(__u32)), + "has_socket")) + goto cleanup; + + if (!ASSERT_TRUE(has_socket(update_skel->maps.sock_hash, + tcp_clien_cookies[i], + sizeof(struct sock_hash_key)), + "has_socket")) + goto cleanup; + } + + /* Destroy sockets connected to port0. */ + linfo.map.map_fd = bpf_map__fd(update_skel->maps.sock_hash); + linfo.map.sock_hash.key_prefix = (__u64)(void *)&key_prefix; + linfo.map.sock_hash.key_prefix_len = sizeof(key_prefix); + opts.link_info = &linfo; + opts.link_info_len = sizeof(linfo); + link = bpf_program__attach_iter(iter_skel->progs.destroy, &opts); + if (!ASSERT_OK_PTR(link, "bpf_program__attach_iter")) + goto cleanup; + + iter_fd = bpf_iter_create(bpf_link__fd(link)); + if (!ASSERT_OK_FD(iter_fd, "bpf_iter_create")) + goto cleanup; + + while ((len = read(iter_fd, buf, sizeof(buf))) > 0) + ; + if (!ASSERT_GE(len, 0, "read")) + goto cleanup; + + /* Ensure that sockets connected to port0 were destroyed. */ + if (!ASSERT_LT(send(tcp_clien[0], "a", 1, 0), 0, "send")) + goto cleanup; + if (!ASSERT_EQ(errno, ECONNABORTED, "ECONNABORTED")) + goto cleanup; + + if (!ASSERT_LT(send(tcp_clien[1], "a", 1, 0), 0, "send")) + goto cleanup; + if (!ASSERT_EQ(errno, ECONNABORTED, "ECONNABORTED")) + goto cleanup; + + if (!ASSERT_EQ(send(tcp_clien[2], "a", 1, 0), 1, "send")) + goto cleanup; + + if (!ASSERT_EQ(send(tcp_clien[3], "a", 1, 0), 1, "send")) + goto cleanup; + + /* Close and ensure that sockets are removed from maps. */ + close(tcp_clien[0]); + close(tcp_clien[1]); + + /* Ensure that the sockets connected to port0 were removed from the + * maps. 
+ */ + if (!ASSERT_FALSE(has_socket(update_skel->maps.sock_map, + tcp_clien_cookies[0], + sizeof(__u32)), + "has_socket")) + goto cleanup; + + if (!ASSERT_FALSE(has_socket(update_skel->maps.sock_map, + tcp_clien_cookies[1], + sizeof(__u32)), + "has_socket")) + goto cleanup; + + if (!ASSERT_TRUE(has_socket(update_skel->maps.sock_map, + tcp_clien_cookies[2], + sizeof(__u32)), + "has_socket")) + goto cleanup; + + if (!ASSERT_TRUE(has_socket(update_skel->maps.sock_map, + tcp_clien_cookies[3], + sizeof(__u32)), + "has_socket")) + goto cleanup; + + if (!ASSERT_FALSE(has_socket(update_skel->maps.sock_hash, + tcp_clien_cookies[0], + sizeof(struct sock_hash_key)), + "has_socket")) + goto cleanup; + + if (!ASSERT_FALSE(has_socket(update_skel->maps.sock_hash, + tcp_clien_cookies[1], + sizeof(struct sock_hash_key)), + "has_socket")) + goto cleanup; + + if (!ASSERT_TRUE(has_socket(update_skel->maps.sock_hash, + tcp_clien_cookies[2], + sizeof(struct sock_hash_key)), + "has_socket")) + goto cleanup; + + if (!ASSERT_TRUE(has_socket(update_skel->maps.sock_hash, + tcp_clien_cookies[3], + sizeof(struct sock_hash_key)), + "has_socket")) + goto cleanup; +cleanup: + close_fds(accept_serv, ARRAY_SIZE(accept_serv)); + close_fds(tcp_clien, ARRAY_SIZE(tcp_clien)); + close_fds(tcp_serv, ARRAY_SIZE(tcp_serv)); + if (prog_fd >= 0) + bpf_prog_detach(cg_fd, BPF_CGROUP_SOCK_OPS); + if (cg_fd >= 0) + close(cg_fd); + if (iter_fd >= 0) + close(iter_fd); + bpf_link__destroy(link); + test_sockmap_update__destroy(update_skel); + bpf_iter_sockmap__destroy(iter_skel); + close_netns(nstoken); + SYS_NOFAIL("ip netns del " TEST_NS); +} + static void test_sockmap_skb_verdict_attach(enum bpf_attach_type first, enum bpf_attach_type second) { @@ -1064,6 +1339,8 @@ void test_sockmap_basic(void) test_sockmap_copy(BPF_MAP_TYPE_SOCKMAP); if (test__start_subtest("sockhash copy")) test_sockmap_copy(BPF_MAP_TYPE_SOCKHASH); + if (test__start_subtest("sock(map|hash) sockops insert and destroy")) + 
test_sockmap_insert_sockops_and_destroy(); if (test__start_subtest("sockmap skb_verdict attach")) { test_sockmap_skb_verdict_attach(BPF_SK_SKB_VERDICT, BPF_SK_SKB_STREAM_VERDICT); diff --git a/tools/testing/selftests/bpf/progs/bpf_iter_sockmap.c b/tools/testing/selftests/bpf/progs/bpf_iter_sockmap.c index 317fe49760cc..9eb2bee443c1 100644 --- a/tools/testing/selftests/bpf/progs/bpf_iter_sockmap.c +++ b/tools/testing/selftests/bpf/progs/bpf_iter_sockmap.c @@ -57,3 +57,17 @@ int copy(struct bpf_iter__sockmap *ctx) ret = bpf_map_delete_elem(&dst, &tmp); return ret && ret != -ENOENT; } + +SEC("iter/sockmap") +int destroy(struct bpf_iter__sockmap *ctx) +{ + struct sock *sk = ctx->sk; + void *key = ctx->key; + + if (!key || !sk) + return 0; + + bpf_sock_destroy((struct sock_common *)sk); + + return 0; +} diff --git a/tools/testing/selftests/bpf/progs/test_sockmap_update.c b/tools/testing/selftests/bpf/progs/test_sockmap_update.c index 6d64ea536e3d..eb84753c6a1a 100644 --- a/tools/testing/selftests/bpf/progs/test_sockmap_update.c +++ b/tools/testing/selftests/bpf/progs/test_sockmap_update.c @@ -45,4 +45,47 @@ int copy_sock_map(void *ctx) return failed ? 
SK_DROP : SK_PASS; } +__u32 count = 0; + +struct sock_hash_key { + __u32 bucket_key; + __u64 cookie; +} __attribute__((__packed__)); + +struct { + __uint(type, BPF_MAP_TYPE_SOCKHASH); + __uint(max_entries, 16); + __ulong(map_extra, offsetof(struct sock_hash_key, cookie)); + __type(key, struct sock_hash_key); + __type(value, __u64); +} sock_hash SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_SOCKMAP); + __uint(max_entries, 16); + __type(key, __u32); + __type(value, __u64); +} sock_map SEC(".maps"); + +SEC("sockops") +int insert_sock(struct bpf_sock_ops *skops) +{ + struct sock_hash_key key = { + .bucket_key = skops->remote_port, + .cookie = bpf_get_socket_cookie(skops), + }; + + switch (skops->op) { + case BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB: + bpf_sock_hash_update(skops, &sock_hash, &key, BPF_NOEXIST); + bpf_sock_map_update(skops, &sock_map, &count, BPF_NOEXIST); + count++; + break; + default: + break; + } + + return 0; +} + char _license[] SEC("license") = "GPL"; -- 2.43.0 Add the BPF_SOCK_OPS_UDP_CONNECTED_CB callback as a sockops hook where connected UDP sockets can be inserted into a socket map. This is invoked on calls to connect() for UDP sockets right after the socket is hashed. Together with the next patch, this provides the missing piece allowing us to fully manage the contents of a socket hash in an environment where we want to keep track of all UDP and TCP sockets connected to some backend. is_locked_tcp_sock was recently introduced in [1] to prevent access to TCP-specific socket fields in contexts where the socket lock isn't held. This patch extends the use of this field to prevent access to these fields in UDP socket contexts. Note: Technically, there should be nothing preventing the use of bpf_sock_ops_setsockopt() and bpf_sock_ops_getsockopt() in this context, but I've avoided removing the is_locked_tcp_sock_ops() guard from these helpers for now to keep the changes in this patch series more focused. 
[1]: https://lore.kernel.org/all/20250220072940.99994-4-kerneljasonxing@gmail.com/ Signed-off-by: Jordan Rife --- include/net/udp.h | 43 ++++++++++++++++++++++++++++++++++ include/uapi/linux/bpf.h | 3 +++ net/ipv4/udp.c | 4 +++- net/ipv6/udp.c | 4 +++- tools/include/uapi/linux/bpf.h | 3 +++ 5 files changed, 55 insertions(+), 2 deletions(-) diff --git a/include/net/udp.h b/include/net/udp.h index e2af3bda90c9..0f55c489e90f 100644 --- a/include/net/udp.h +++ b/include/net/udp.h @@ -18,6 +18,7 @@ #ifndef _UDP_H #define _UDP_H +#include #include #include #include @@ -25,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -661,4 +663,45 @@ struct sk_psock; int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore); #endif +#ifdef CONFIG_BPF + +/* Call BPF_SOCK_OPS program that returns an int. If the return value + * is < 0, then the BPF op failed (for example if the loaded BPF + * program does not support the chosen operation or there is no BPF + * program loaded). + */ +static inline int udp_call_bpf(struct sock *sk, int op) +{ + struct bpf_sock_ops_kern sock_ops; + int ret; + + memset(&sock_ops, 0, offsetof(struct bpf_sock_ops_kern, temp)); + if (sk_fullsock(sk)) { + sock_ops.is_fullsock = 1; + /* sock_ops.is_locked_tcp_sock not set. This prevents + * access to TCP-specific fields. + */ + sock_owned_by_me(sk); + } + + sock_ops.sk = sk; + sock_ops.op = op; + + ret = BPF_CGROUP_RUN_PROG_SOCK_OPS(&sock_ops); + if (ret == 0) + ret = sock_ops.reply; + else + ret = -1; + return ret; +} + +#else + +static inline int udp_call_bpf(struct sock *sk, int op) +{ + return -EPERM; +} + +#endif /* CONFIG_BPF */ + #endif /* _UDP_H */ diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h index 22761dea4635..e30515af1f27 100644 --- a/include/uapi/linux/bpf.h +++ b/include/uapi/linux/bpf.h @@ -7122,6 +7122,9 @@ enum { * sendmsg timestamp with corresponding * tskey.
*/ + BPF_SOCK_OPS_UDP_CONNECTED_CB, /* Called on connect() for UDP sockets + * right after the socket is hashed. + */ }; /* List of TCP states. There is a build check in net/ipv4/tcp.c to detect diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index cc3ce0f762ec..2d51d0ead70d 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -2153,6 +2153,8 @@ static int udp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) res = __ip4_datagram_connect(sk, uaddr, addr_len); - if (!res) + if (!res) { udp4_hash4(sk); + udp_call_bpf(sk, BPF_SOCK_OPS_UDP_CONNECTED_CB); + } release_sock(sk); return res; } diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c index 6a68f77da44b..304b43851e16 100644 --- a/net/ipv6/udp.c +++ b/net/ipv6/udp.c @@ -1310,6 +1310,8 @@ static int udpv6_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) res = __ip6_datagram_connect(sk, uaddr, addr_len); - if (!res) + if (!res) { udp6_hash4(sk); + udp_call_bpf(sk, BPF_SOCK_OPS_UDP_CONNECTED_CB); + } release_sock(sk); return res; } diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h index 22761dea4635..e30515af1f27 100644 --- a/tools/include/uapi/linux/bpf.h +++ b/tools/include/uapi/linux/bpf.h @@ -7122,6 +7122,9 @@ enum { * sendmsg timestamp with corresponding * tskey. */ + BPF_SOCK_OPS_UDP_CONNECTED_CB, /* Called on connect() for UDP sockets + * right after the socket is hashed. + */ }; /* List of TCP states. There is a build check in net/ipv4/tcp.c to detect -- 2.43.0 Finally, enable the use of bpf_sock_map_update and bpf_sock_hash_update from the BPF_SOCK_OPS_UDP_CONNECTED_CB sockops hook to allow automatic management of the contents of a socket hash.
Signed-off-by: Jordan Rife --- net/core/sock_map.c | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/net/core/sock_map.c b/net/core/sock_map.c index b0b428190561..08b6d647100c 100644 --- a/net/core/sock_map.c +++ b/net/core/sock_map.c @@ -522,7 +522,8 @@ static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops) { return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB || ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB || - ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB; + ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB || + ops->op == BPF_SOCK_OPS_UDP_CONNECTED_CB; } static bool sock_map_redirect_allowed(const struct sock *sk) -- 2.43.0 Exercise BPF_SOCK_OPS_UDP_CONNECTED_CB by extending the socket map insert and destroy tests. Signed-off-by: Jordan Rife --- .../selftests/bpf/prog_tests/sockmap_basic.c | 110 ++++++++++++++++++ .../selftests/bpf/progs/test_sockmap_update.c | 1 + 2 files changed, 111 insertions(+) diff --git a/tools/testing/selftests/bpf/prog_tests/sockmap_basic.c b/tools/testing/selftests/bpf/prog_tests/sockmap_basic.c index 00afa377cf7d..7506de15611e 100644 --- a/tools/testing/selftests/bpf/prog_tests/sockmap_basic.c +++ b/tools/testing/selftests/bpf/prog_tests/sockmap_basic.c @@ -440,10 +440,13 @@ static void test_sockmap_insert_sockops_and_destroy(void) __u32 key_prefix = htonl((__u32)port0); int accept_serv[4] = {-1, -1, -1, -1}; int tcp_clien[4] = {-1, -1, -1, -1}; + int udp_clien[4] = {-1, -1, -1, -1}; union bpf_iter_link_info linfo = {}; int tcp_serv[4] = {-1, -1, -1, -1}; + int udp_serv[4] = {-1, -1, -1, -1}; struct nstoken *nstoken = NULL; int tcp_clien_cookies[4] = {}; + int udp_clien_cookies[4] = {}; struct bpf_link *link = NULL; char buf[64]; int len; @@ -494,6 +497,22 @@ static void test_sockmap_insert_sockops_and_destroy(void) if (!ASSERT_OK_FD(tcp_serv[3], "start_server")) goto cleanup; + udp_serv[0] = start_server(AF_INET, SOCK_DGRAM, "127.0.0.1", port0, 0); + if (!ASSERT_OK_FD(udp_serv[0], "start_server")) + goto cleanup; + + 
udp_serv[1] = start_server(AF_INET6, SOCK_DGRAM, "::1", port0, 0); + if (!ASSERT_OK_FD(udp_serv[1], "start_server")) + goto cleanup; + + udp_serv[2] = start_server(AF_INET, SOCK_DGRAM, "127.0.0.1", port1, 0); + if (!ASSERT_OK_FD(udp_serv[2], "start_server")) + goto cleanup; + + udp_serv[3] = start_server(AF_INET6, SOCK_DGRAM, "::1", port1, 0); + if (!ASSERT_OK_FD(udp_serv[3], "start_server")) + goto cleanup; + for (i = 0; i < ARRAY_SIZE(tcp_serv); i++) { tcp_clien[i] = connect_to_fd(tcp_serv[i], 0); if (!ASSERT_OK_FD(tcp_clien[i], "connect_to_fd")) @@ -504,11 +523,21 @@ static void test_sockmap_insert_sockops_and_destroy(void) goto cleanup; } + for (i = 0; i < ARRAY_SIZE(udp_serv); i++) { + udp_clien[i] = connect_to_fd(udp_serv[i], 0); + if (!ASSERT_OK_FD(udp_clien[i], "connect_to_fd")) + goto cleanup; + } + /* Ensure that sockets are connected. */ for (i = 0; i < ARRAY_SIZE(tcp_clien); i++) if (!ASSERT_EQ(send(tcp_clien[i], "a", 1, 0), 1, "send")) goto cleanup; + for (i = 0; i < ARRAY_SIZE(udp_clien); i++) + if (!ASSERT_EQ(send(udp_clien[i], "a", 1, 0), 1, "send")) + goto cleanup; + /* Ensure that client sockets exist in the map and the hash. 
*/ if (!ASSERT_EQ(update_skel->bss->count, ARRAY_SIZE(tcp_clien) + ARRAY_SIZE(udp_clien), @@ -518,6 +547,9 @@ static void test_sockmap_insert_sockops_and_destroy(void) for (i = 0; i < ARRAY_SIZE(tcp_clien); i++) tcp_clien_cookies[i] = socket_cookie(tcp_clien[i]); + for (i = 0; i < ARRAY_SIZE(udp_clien); i++) + udp_clien_cookies[i] = socket_cookie(udp_clien[i]); + for (i = 0; i < ARRAY_SIZE(tcp_clien); i++) { if (!ASSERT_TRUE(has_socket(update_skel->maps.sock_map, tcp_clien_cookies[i], @@ -532,6 +564,20 @@ static void test_sockmap_insert_sockops_and_destroy(void) goto cleanup; } + for (i = 0; i < ARRAY_SIZE(udp_clien); i++) { + if (!ASSERT_TRUE(has_socket(update_skel->maps.sock_map, + udp_clien_cookies[i], + sizeof(__u32)), + "has_socket")) + goto cleanup; + + if (!ASSERT_TRUE(has_socket(update_skel->maps.sock_hash, + udp_clien_cookies[i], + sizeof(struct sock_hash_key)), + "has_socket")) + goto cleanup; + } + /* Destroy sockets connected to port0. */ linfo.map.map_fd = bpf_map__fd(update_skel->maps.sock_hash); linfo.map.sock_hash.key_prefix = (__u64)(void *)&key_prefix; @@ -568,9 +614,23 @@ static void test_sockmap_insert_sockops_and_destroy(void) if (!ASSERT_EQ(send(tcp_clien[3], "a", 1, 0), 1, "send")) goto cleanup; + if (!ASSERT_LT(send(udp_clien[0], "a", 1, 0), 0, "send")) + goto cleanup; + + if (!ASSERT_LT(send(udp_clien[1], "a", 1, 0), 0, "send")) + goto cleanup; + + if (!ASSERT_EQ(send(udp_clien[2], "a", 1, 0), 1, "send")) + goto cleanup; + + if (!ASSERT_EQ(send(udp_clien[3], "a", 1, 0), 1, "send")) + goto cleanup; + /* Close and ensure that sockets are removed from maps. */ close(tcp_clien[0]); close(tcp_clien[1]); + close(udp_clien[0]); + close(udp_clien[1]); /* Ensure that the sockets connected to port0 were removed from the * maps. 
@@ -622,10 +682,60 @@ static void test_sockmap_insert_sockops_and_destroy(void) sizeof(struct sock_hash_key)), "has_socket")) goto cleanup; + + if (!ASSERT_FALSE(has_socket(update_skel->maps.sock_map, + udp_clien_cookies[0], + sizeof(__u32)), + "has_socket")) + goto cleanup; + + if (!ASSERT_FALSE(has_socket(update_skel->maps.sock_map, + udp_clien_cookies[1], + sizeof(__u32)), + "has_socket")) + goto cleanup; + + if (!ASSERT_TRUE(has_socket(update_skel->maps.sock_map, + udp_clien_cookies[2], + sizeof(__u32)), + "has_socket")) + goto cleanup; + + if (!ASSERT_TRUE(has_socket(update_skel->maps.sock_map, + udp_clien_cookies[3], + sizeof(__u32)), + "has_socket")) + goto cleanup; + + if (!ASSERT_FALSE(has_socket(update_skel->maps.sock_hash, + udp_clien_cookies[0], + sizeof(struct sock_hash_key)), + "has_socket")) + goto cleanup; + + if (!ASSERT_FALSE(has_socket(update_skel->maps.sock_hash, + udp_clien_cookies[1], + sizeof(struct sock_hash_key)), + "has_socket")) + goto cleanup; + + if (!ASSERT_TRUE(has_socket(update_skel->maps.sock_hash, + udp_clien_cookies[2], + sizeof(struct sock_hash_key)), + "has_socket")) + goto cleanup; + + if (!ASSERT_TRUE(has_socket(update_skel->maps.sock_hash, + udp_clien_cookies[3], + sizeof(struct sock_hash_key)), + "has_socket")) + goto cleanup; cleanup: close_fds(accept_serv, ARRAY_SIZE(accept_serv)); close_fds(tcp_clien, ARRAY_SIZE(tcp_clien)); + close_fds(udp_clien, ARRAY_SIZE(udp_clien)); close_fds(tcp_serv, ARRAY_SIZE(tcp_serv)); + close_fds(udp_serv, ARRAY_SIZE(udp_serv)); if (prog_fd >= 0) bpf_prog_detach(cg_fd, BPF_CGROUP_SOCK_OPS); if (cg_fd >= 0) diff --git a/tools/testing/selftests/bpf/progs/test_sockmap_update.c b/tools/testing/selftests/bpf/progs/test_sockmap_update.c index eb84753c6a1a..0d826004d56d 100644 --- a/tools/testing/selftests/bpf/progs/test_sockmap_update.c +++ b/tools/testing/selftests/bpf/progs/test_sockmap_update.c @@ -77,6 +77,7 @@ int insert_sock(struct bpf_sock_ops *skops) switch (skops->op) { case 
BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB: + case BPF_SOCK_OPS_UDP_CONNECTED_CB: bpf_sock_hash_update(skops, &sock_hash, &key, BPF_NOEXIST); bpf_sock_map_update(skops, &sock_map, &count, BPF_NOEXIST); count++; -- 2.43.0 Add documentation explaining how to use map_extra with a BPF_MAP_TYPE_SOCKHASH to control bucketing behavior and how to iterate over a specific bucket using a key prefix filter. Signed-off-by: Jordan Rife --- Documentation/bpf/bpf_iterators.rst | 11 +++++++++++ Documentation/bpf/map_sockmap.rst | 6 ++++++ 2 files changed, 17 insertions(+) diff --git a/Documentation/bpf/bpf_iterators.rst b/Documentation/bpf/bpf_iterators.rst index 189e3ec1c6c8..135bf6a6195c 100644 --- a/Documentation/bpf/bpf_iterators.rst +++ b/Documentation/bpf/bpf_iterators.rst @@ -587,3 +587,14 @@ A BPF task iterator with *pid* includes all tasks (threads) of a process. The BPF program receives these tasks one after another. You can specify a BPF task iterator with *tid* parameter to include only the tasks that match the given *tid*. + +--------------------------------------------- +Parametrizing BPF_MAP_TYPE_SOCKHASH Iterators +--------------------------------------------- + +An iterator for a ``BPF_MAP_TYPE_SOCKHASH`` can limit results to only sockets +whose keys share a common prefix by using a key prefix filter. The key prefix +length must match the value of ``map_extra`` if ``map_extra`` is used in the +``BPF_MAP_TYPE_SOCKHASH`` definition; otherwise, it must match the map key +length. This guarantees that the iterator only visits a single hash bucket, +ensuring efficient iteration over a subset of map elements. diff --git a/Documentation/bpf/map_sockmap.rst b/Documentation/bpf/map_sockmap.rst index 2d630686a00b..505e02c79feb 100644 --- a/Documentation/bpf/map_sockmap.rst +++ b/Documentation/bpf/map_sockmap.rst @@ -76,6 +76,12 @@ sk_msg_buff *msg``. All these helpers will be described in more detail below. 
+Hashing behavior is configurable for ``BPF_MAP_TYPE_SOCKHASH`` using the lower +32 bits of ``map_extra``. When provided, ``map_extra`` specifies the number of +leading bytes of a key (its prefix) to use when calculating its bucket hash. +This may be used to force keys sharing a common prefix, e.g. an (address, port) +tuple, into the same bucket for efficient iteration. + Usage ===== Kernel BPF -- 2.43.0