Expand the suite of tests for bpf_sock_destroy() to include invocation from socket map and socket hash iterators. Signed-off-by: Jordan Rife --- .../selftests/bpf/prog_tests/sock_destroy.c | 119 +++++++++++++++--- .../selftests/bpf/progs/sock_destroy_prog.c | 63 ++++++++++ 2 files changed, 164 insertions(+), 18 deletions(-) diff --git a/tools/testing/selftests/bpf/prog_tests/sock_destroy.c b/tools/testing/selftests/bpf/prog_tests/sock_destroy.c index 9c11938fe597..829e6fb9c1d8 100644 --- a/tools/testing/selftests/bpf/prog_tests/sock_destroy.c +++ b/tools/testing/selftests/bpf/prog_tests/sock_destroy.c @@ -8,13 +8,20 @@ #define TEST_NS "sock_destroy_netns" -static void start_iter_sockets(struct bpf_program *prog) +static void start_iter_sockets(struct bpf_program *prog, struct bpf_map *map) { + DECLARE_LIBBPF_OPTS(bpf_iter_attach_opts, opts); + union bpf_iter_link_info linfo = {}; struct bpf_link *link; char buf[50] = {}; int iter_fd, len; - link = bpf_program__attach_iter(prog, NULL); + if (map) + linfo.map.map_fd = bpf_map__fd(map); + opts.link_info = &linfo; + opts.link_info_len = sizeof(linfo); + + link = bpf_program__attach_iter(prog, &opts); if (!ASSERT_OK_PTR(link, "attach_iter")) return; @@ -32,7 +39,22 @@ static void start_iter_sockets(struct bpf_program *prog) bpf_link__destroy(link); } -static void test_tcp_client(struct sock_destroy_prog *skel) +static int insert_socket(struct bpf_map *socks, int fd, __u32 key) +{ + int map_fd = bpf_map__fd(socks); + __s64 sfd = fd; + int ret; + + ret = bpf_map_update_elem(map_fd, &key, &sfd, BPF_NOEXIST); + if (!ASSERT_OK(ret, "map_update")) + return -1; + + return 0; +} + +static void test_tcp_client(struct sock_destroy_prog *skel, + struct bpf_program *prog, + struct bpf_map *socks) { int serv = -1, clien = -1, accept_serv = -1, n; @@ -52,8 +74,17 @@ static void test_tcp_client(struct sock_destroy_prog *skel) if (!ASSERT_EQ(n, 1, "client send")) goto cleanup; + if (socks) { + if (!ASSERT_OK(insert_socket(socks, clien, 0), + "insert_socket")) + goto cleanup; + if (!ASSERT_OK(insert_socket(socks, serv, 1), + "insert_socket")) + goto cleanup; + } + /* Run iterator program that destroys connected client sockets. */ - start_iter_sockets(skel->progs.iter_tcp6_client); + start_iter_sockets(prog, socks); n = send(clien, "t", 1, 0); if (!ASSERT_LT(n, 0, "client_send on destroyed socket")) @@ -69,7 +100,9 @@ static void test_tcp_client(struct sock_destroy_prog *skel) close(serv); } -static void test_tcp_server(struct sock_destroy_prog *skel) +static void test_tcp_server(struct sock_destroy_prog *skel, + struct bpf_program *prog, + struct bpf_map *socks) { int serv = -1, clien = -1, accept_serv = -1, n, serv_port; @@ -93,8 +126,17 @@ static void test_tcp_server(struct sock_destroy_prog *skel) if (!ASSERT_EQ(n, 1, "client send")) goto cleanup; + if (socks) { + if (!ASSERT_OK(insert_socket(socks, clien, 0), + "insert_socket")) + goto cleanup; + if (!ASSERT_OK(insert_socket(socks, accept_serv, 1), + "insert_socket")) + goto cleanup; + } + /* Run iterator program that destroys server sockets. */ - start_iter_sockets(skel->progs.iter_tcp6_server); + start_iter_sockets(prog, socks); n = send(clien, "t", 1, 0); if (!ASSERT_LT(n, 0, "client_send on destroyed socket")) @@ -110,7 +152,9 @@ static void test_tcp_server(struct sock_destroy_prog *skel) close(serv); } -static void test_udp_client(struct sock_destroy_prog *skel) +static void test_udp_client(struct sock_destroy_prog *skel, + struct bpf_program *prog, + struct bpf_map *socks) { int serv = -1, clien = -1, n = 0; @@ -126,8 +170,17 @@ static void test_udp_client(struct sock_destroy_prog *skel) if (!ASSERT_EQ(n, 1, "client send")) goto cleanup; + if (socks) { + if (!ASSERT_OK(insert_socket(socks, clien, 0), + "insert_socket")) + goto cleanup; + if (!ASSERT_OK(insert_socket(socks, serv, 1), + "insert_socket")) + goto cleanup; + } + /* Run iterator program that destroys sockets. */ - start_iter_sockets(skel->progs.iter_udp6_client); + start_iter_sockets(prog, socks); n = send(clien, "t", 1, 0); if (!ASSERT_LT(n, 0, "client_send on destroyed socket")) @@ -143,11 +196,14 @@ static void test_udp_client(struct sock_destroy_prog *skel) close(serv); } -static void test_udp_server(struct sock_destroy_prog *skel) +static void test_udp_server(struct sock_destroy_prog *skel, + struct bpf_program *prog, + struct bpf_map *socks) { int *listen_fds = NULL, n, i, serv_port; unsigned int num_listens = 5; char buf[1]; + __u32 key; /* Start reuseport servers. */ listen_fds = start_reuseport_server(AF_INET6, SOCK_DGRAM, @@ -159,8 +215,15 @@ static void test_udp_server(struct sock_destroy_prog *skel) goto cleanup; skel->bss->serv_port = (__be16) serv_port; + if (socks) + for (key = 0; key < num_listens; key++) + if (!ASSERT_OK(insert_socket(socks, listen_fds[key], + key), + "insert_socket")) + goto cleanup; + /* Run iterator program that destroys server sockets. */ - start_iter_sockets(skel->progs.iter_udp6_server); + start_iter_sockets(prog, socks); for (i = 0; i < num_listens; ++i) { n = read(listen_fds[i], buf, sizeof(buf)); @@ -200,14 +263,34 @@ void test_sock_destroy(void) if (!ASSERT_OK_PTR(nstoken, "open_netns")) goto cleanup; - if (test__start_subtest("tcp_client")) - test_tcp_client(skel); - if (test__start_subtest("tcp_server")) - test_tcp_server(skel); - if (test__start_subtest("udp_client")) - test_udp_client(skel); - if (test__start_subtest("udp_server")) - test_udp_server(skel); + if (test__start_subtest("tcp_client")) { + test_tcp_client(skel, skel->progs.iter_tcp6_client, NULL); + test_tcp_client(skel, skel->progs.iter_sockmap_client, + skel->maps.sock_map); + test_tcp_client(skel, skel->progs.iter_sockmap_client, + skel->maps.sock_hash); + } + if (test__start_subtest("tcp_server")) { + test_tcp_server(skel, skel->progs.iter_tcp6_server, NULL); + test_tcp_server(skel, skel->progs.iter_sockmap_server, + skel->maps.sock_map); + test_tcp_server(skel, skel->progs.iter_sockmap_server, + skel->maps.sock_hash); + } + if (test__start_subtest("udp_client")) { + test_udp_client(skel, skel->progs.iter_udp6_client, NULL); + test_udp_client(skel, skel->progs.iter_sockmap_client, + skel->maps.sock_map); + test_udp_client(skel, skel->progs.iter_sockmap_client, + skel->maps.sock_hash); + } + if (test__start_subtest("udp_server")) { + test_udp_server(skel, skel->progs.iter_udp6_server, NULL); + test_udp_server(skel, skel->progs.iter_sockmap_server, + skel->maps.sock_map); + test_udp_server(skel, skel->progs.iter_sockmap_server, + skel->maps.sock_hash); + } RUN_TESTS(sock_destroy_prog_fail); diff --git a/tools/testing/selftests/bpf/progs/sock_destroy_prog.c b/tools/testing/selftests/bpf/progs/sock_destroy_prog.c index 9e0bf7a54cec..d91f75190bbf 100644 --- a/tools/testing/selftests/bpf/progs/sock_destroy_prog.c +++ b/tools/testing/selftests/bpf/progs/sock_destroy_prog.c @@ -24,6 +24,20 @@ struct { __type(value, __u64); } udp_conn_sockets SEC(".maps"); +struct { + __uint(type, BPF_MAP_TYPE_SOCKMAP); + __uint(max_entries, 5); + __type(key, __u32); + __type(value, __u64); +} sock_map SEC(".maps"); + +struct { + __uint(type, BPF_MAP_TYPE_SOCKHASH); + __uint(max_entries, 5); + __type(key, __u32); + __type(value, __u64); +} sock_hash SEC(".maps"); + SEC("cgroup/connect6") int sock_connect(struct bpf_sock_addr *ctx) { @@ -142,4 +156,53 @@ int iter_udp6_server(struct bpf_iter__udp *ctx) return 0; } +SEC("iter/sockmap") +int iter_sockmap_client(struct bpf_iter__sockmap *ctx) +{ + __u64 sock_cookie = 0, *val; + struct sock *sk = ctx->sk; + __u32 *key = ctx->key; + __u32 zero = 0; + + if (!key || !sk) + return 0; + + sock_cookie = bpf_get_socket_cookie(sk); + val = bpf_map_lookup_elem(&udp_conn_sockets, &zero); + if (val && *val == sock_cookie) + goto destroy; + val = bpf_map_lookup_elem(&tcp_conn_sockets, &zero); + if (val && *val == sock_cookie) + goto destroy; + goto out; +destroy: + bpf_sock_destroy((struct sock_common *)sk); +out: + return 0; +} + +SEC("iter/sockmap") +int iter_sockmap_server(struct bpf_iter__sockmap *ctx) +{ + struct sock *sk = ctx->sk; + struct tcp6_sock *tcp_sk; + struct udp6_sock *udp_sk; + __u32 *key = ctx->key; + + if (!key || !sk) + return 0; + + tcp_sk = bpf_skc_to_tcp6_sock(sk); + if (tcp_sk && tcp_sk->tcp.inet_conn.icsk_inet.inet_sport == serv_port) + goto destroy; + udp_sk = bpf_skc_to_udp6_sock(sk); + if (udp_sk && udp_sk->udp.inet.inet_sport == serv_port) + goto destroy; + goto out; +destroy: + bpf_sock_destroy((struct sock_common *)sk); +out: + return 0; +} + char _license[] SEC("license") = "GPL"; -- 2.43.0