SCTP_DIAG endpoint dumping currently walks the endpoint hash table without taking the socket lock before calling inet_sctp_diag_fill(). This is problematic because inet_sctp_diag_fill() eventually calls inet_diag_msg_sctpladdrs_fill(), which traverses the endpoint's local address list twice: once to count entries for nla_reserve(), and once again to copy the addresses into the netlink buffer. Since these two traversals are protected only by separate RCU read-side critical sections, concurrent socket operations such as SCTP_SOCKOPT_BINDX_REM may remove entries from the address list between them. In that case, fewer addresses are copied than the buffer was reserved for, leaving part of the netlink payload uninitialized and potentially leaking kernel memory to user space. Fix this by changing sctp_for_each_endpoint() to iterate with net and position awareness, taking a reference on each socket and then releasing the endpoint hash bucket's read_lock_bh() before invoking the callback. A socket reference is required because the callback acquires lock_sock(), which must be called outside of read_lock_bh() since lock_sock() may sleep. Holding a socket reference ensures the socket remains valid after dropping the bucket lock and before acquiring the socket lock. With the socket lock held, concurrent bind-address modifications are serialized against the diagnostic dump, ensuring the local address list remains stable during buffer sizing and initialization. This also simplifies endpoint traversal by removing the callback-local temporary position tracking in args[4] and moving dump progress tracking into sctp_for_each_endpoint() itself. While at it, fix the idiag_states check in sctp_ep_dump(): skip dumping an endpoint when states other than LISTEN|CLOSE are also requested and the endpoint has associations, since such cases will be handled later by sctp_sock_dump(). 
Reported-by: Zero Day Initiative Fixes: 8f840e47f190 ("sctp: add the sctp_diag.c file") Signed-off-by: Xin Long --- include/net/sctp/sctp.h | 3 +- net/sctp/diag.c | 62 +++++++++++++++++++---------------------- net/sctp/socket.c | 34 +++++++++++++++++----- 3 files changed, 57 insertions(+), 42 deletions(-) diff --git a/include/net/sctp/sctp.h b/include/net/sctp/sctp.h index 58242b37b47a..cd82b05354a3 100644 --- a/include/net/sctp/sctp.h +++ b/include/net/sctp/sctp.h @@ -111,7 +111,8 @@ int sctp_transport_lookup_process(sctp_callback_t cb, struct net *net, const union sctp_addr *paddr, void *p, int dif); int sctp_transport_traverse_process(sctp_callback_t cb, sctp_callback_t cb_done, struct net *net, int *pos, void *p); -int sctp_for_each_endpoint(int (*cb)(struct sctp_endpoint *, void *), void *p); +int sctp_for_each_endpoint(int (*cb)(struct sctp_endpoint *, void *), + struct net *net, int *pos, void *p); int sctp_get_sctp_info(struct sock *sk, struct sctp_association *asoc, struct sctp_info *info); diff --git a/net/sctp/diag.c b/net/sctp/diag.c index d758f5c3e06e..9108272ca527 100644 --- a/net/sctp/diag.c +++ b/net/sctp/diag.c @@ -92,6 +92,7 @@ static int inet_diag_msg_sctpladdrs_fill(struct sk_buff *skb, if (!--addrcnt) break; } + WARN_ON_ONCE(addrcnt); rcu_read_unlock(); return 0; @@ -373,42 +374,36 @@ static int sctp_ep_dump(struct sctp_endpoint *ep, void *p) struct sk_buff *skb = commp->skb; struct netlink_callback *cb = commp->cb; const struct inet_diag_req_v2 *r = commp->r; - struct net *net = sock_net(skb->sk); struct inet_sock *inet = inet_sk(sk); int err = 0; - if (!net_eq(sock_net(sk), net)) + lock_sock(sk); + if (sctp_sstate(sk, CLOSED)) goto out; - if (cb->args[4] < cb->args[1]) - goto next; - - if (!(r->idiag_states & TCPF_LISTEN) && !list_empty(&ep->asocs)) - goto next; + if ((r->idiag_states & ~(TCPF_LISTEN | TCPF_CLOSE)) && + !list_empty(&ep->asocs)) + goto out; if (r->sdiag_family != AF_UNSPEC && sk->sk_family != r->sdiag_family) - goto next; + 
goto out; if (r->id.idiag_sport != inet->inet_sport && r->id.idiag_sport) - goto next; + goto out; if (r->id.idiag_dport != inet->inet_dport && r->id.idiag_dport) - goto next; - - if (inet_sctp_diag_fill(sk, NULL, skb, r, - sk_user_ns(NETLINK_CB(cb->skb).sk), - NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, NLM_F_MULTI, - cb->nlh, commp->net_admin) < 0) { - err = 2; goto out; - } -next: - cb->args[4]++; + + err = inet_sctp_diag_fill(sk, NULL, skb, r, + sk_user_ns(NETLINK_CB(cb->skb).sk), + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + cb->nlh, commp->net_admin); out: + release_sock(sk); return err; } @@ -479,41 +474,40 @@ static void sctp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, .r = r, .net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN), }; - int pos = cb->args[2]; + int pos; /* eps hashtable dumps * args: * 0 : if it will traversal listen sock * 1 : to record the sock pos of this time's traversal - * 4 : to work as a temporary variable to traversal list */ if (cb->args[0] == 0) { - if (!(idiag_states & TCPF_LISTEN)) - goto skip; - if (sctp_for_each_endpoint(sctp_ep_dump, &commp)) - goto done; -skip: + if (idiag_states & TCPF_LISTEN) { + pos = cb->args[1]; + if (sctp_for_each_endpoint(sctp_ep_dump, net, &pos, + &commp)) { + cb->args[1] = pos; + return; + } + } cb->args[0] = 1; cb->args[1] = 0; - cb->args[4] = 0; } + if (!(idiag_states & ~(TCPF_LISTEN | TCPF_CLOSE))) + return; + /* asocs by transport hashtable dump * args: * 1 : to record the assoc pos of this time's traversal * 2 : to record the transport pos of this time's traversal * 3 : to mark if we have dumped the ep info of the current asoc * 4 : to work as a temporary variable to traversal list - * 5 : to save the sk we get from travelsing the tsp list. 
*/ - if (!(idiag_states & ~(TCPF_LISTEN | TCPF_CLOSE))) - goto done; - + pos = cb->args[2]; sctp_transport_traverse_process(sctp_sock_filter, sctp_sock_dump, net, &pos, &commp); cb->args[2] = pos; - -done: cb->args[1] = cb->args[4]; cb->args[4] = 0; } diff --git a/net/sctp/socket.c b/net/sctp/socket.c index 66e12fb0c646..1ed405dedc01 100644 --- a/net/sctp/socket.c +++ b/net/sctp/socket.c @@ -5369,24 +5369,44 @@ struct sctp_transport *sctp_transport_get_idx(struct net *net, } int sctp_for_each_endpoint(int (*cb)(struct sctp_endpoint *, void *), - void *p) { - int err = 0; - int hash = 0; - struct sctp_endpoint *ep; + struct net *net, int *pos, void *p) { + int err, hash = 0, idx = 0, start; struct sctp_hashbucket *head; + struct sctp_endpoint *ep; + struct sock *sk; for (head = sctp_ep_hashtable; hash < sctp_ep_hashsize; hash++, head++) { + start = idx; +again: + sk = NULL; read_lock_bh(&head->lock); sctp_for_each_hentry(ep, &head->chain) { - err = cb(ep, p); - if (err) + if (sock_net(ep->base.sk) != net) + continue; + if (idx++ >= *pos) { + sk = ep->base.sk; + sock_hold(sk); break; + } } read_unlock_bh(&head->lock); + + if (sk) { + err = cb(ep, p); + if (err) { + sock_put(sk); + return err; + } + sock_put(sk); + (*pos)++; + + idx = start; + goto again; + } } - return err; + return 0; } EXPORT_SYMBOL_GPL(sctp_for_each_endpoint); -- 2.47.1