raw_v4_match() is a lockless match helper under sk_for_each_rcu(). It still reads inet->inet_daddr, inet->inet_rcv_saddr and sk->sk_bound_dev_if with plain loads while bind, connect and disconnect paths can update the same match fields concurrently. Annotate only those mutable match fields in raw_v4_match(), and do so at the point of use instead of hoisting the bound-device read before the earlier short-circuit tests. Also annotate the corresponding IPv4 writers so the read side matches explicit WRITE_ONCE() updates when those same fields are modified. Fixes: 0daf07e52709 ("raw: convert raw sockets to RCU") Signed-off-by: Runyu Xiao --- v3: - drop the inet_num annotation for raw sockets - cover inet_daddr as well - avoid hoisting sk_bound_dev_if into a temporary variable - annotate the matching IPv4 writer paths v2: - note that inet_num and sk_bound_dev_if already have WRITE_ONCE() writers - add WRITE_ONCE() in raw_bind() for inet_rcv_saddr - previous version: https://lore.kernel.org/r/20260601073937.1137673-1-runyu.xiao@seu.edu.cn include/net/ip.h | 3 ++- net/ipv4/datagram.c | 4 ++-- net/ipv4/raw.c | 23 ++++++++++++++++------- net/ipv4/udp.c | 4 ++-- 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/include/net/ip.h b/include/net/ip.h index 7f2fe1a8401b..3902e8b58699 100644 --- a/include/net/ip.h +++ b/include/net/ip.h @@ -679,7 +679,8 @@ static inline void ip_ipgre_mc_map(__be32 naddr, const unsigned char *broadcast, static __inline__ void inet_reset_saddr(struct sock *sk) { - inet_sk(sk)->inet_rcv_saddr = inet_sk(sk)->inet_saddr = 0; + inet_sk(sk)->inet_saddr = 0; + WRITE_ONCE(inet_sk(sk)->inet_rcv_saddr, 0); #if IS_ENABLED(CONFIG_IPV6) if (sk->sk_family == PF_INET6) { struct ipv6_pinfo *np = inet6_sk(sk); diff --git a/net/ipv4/datagram.c b/net/ipv4/datagram.c index 1614593b6d72..7d25519a6cdd 100644 --- a/net/ipv4/datagram.c +++ b/net/ipv4/datagram.c @@ -63,12 +63,12 @@ int __ip4_datagram_connect(struct sock *sk, struct sockaddr_unsized *uaddr, int } /* Update addresses before rehashing */ - inet->inet_daddr = fl4->daddr; + WRITE_ONCE(inet->inet_daddr, fl4->daddr); inet->inet_dport = usin->sin_port; if (!inet->inet_saddr) inet->inet_saddr = fl4->saddr; if (!inet->inet_rcv_saddr) { - inet->inet_rcv_saddr = fl4->saddr; + WRITE_ONCE(inet->inet_rcv_saddr, fl4->saddr); if (sk->sk_prot->rehash) sk->sk_prot->rehash(sk); } diff --git a/net/ipv4/raw.c b/net/ipv4/raw.c index 5aaf9c62c8e1..dfb294b5c794 100644 --- a/net/ipv4/raw.c +++ b/net/ipv4/raw.c @@ -120,13 +120,21 @@ bool raw_v4_match(struct net *net, const struct sock *sk, unsigned short num, __be32 raddr, __be32 laddr, int dif, int sdif) { const struct inet_sock *inet = inet_sk(sk); + __be32 daddr, rcv_saddr; - if (net_eq(sock_net(sk), net) && inet->inet_num == num && - !(inet->inet_daddr && inet->inet_daddr != raddr) && - !(inet->inet_rcv_saddr && inet->inet_rcv_saddr != laddr) && - raw_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif)) - return true; - return false; + if (!net_eq(sock_net(sk), net) || inet->inet_num != num) + return false; + + daddr = READ_ONCE(inet->inet_daddr); + if (daddr && daddr != raddr) + return false; + + rcv_saddr = READ_ONCE(inet->inet_rcv_saddr); + if (rcv_saddr && rcv_saddr != laddr) + return false; + + return raw_sk_bound_dev_eq(net, READ_ONCE(sk->sk_bound_dev_if), + dif, sdif); } EXPORT_SYMBOL_GPL(raw_v4_match); @@ -724,7 +732,8 @@ static int raw_bind(struct sock *sk, struct sockaddr_unsized *uaddr, chk_addr_ret)) goto out; - inet->inet_rcv_saddr = inet->inet_saddr = addr->sin_addr.s_addr; + inet->inet_saddr = addr->sin_addr.s_addr; + WRITE_ONCE(inet->inet_rcv_saddr, addr->sin_addr.s_addr); if (chk_addr_ret == RTN_MULTICAST || chk_addr_ret == RTN_BROADCAST) inet->inet_saddr = 0; /* Use device */ sk_dst_reset(sk); diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 0ac2bf4f8759..5d3e07f3ac27 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -2156,10 +2156,10 @@ int __udp_disconnect(struct sock *sk, int flags) */ sk->sk_state = TCP_CLOSE; - inet->inet_daddr = 0; + WRITE_ONCE(inet->inet_daddr, 0); inet->inet_dport = 0; sock_rps_reset_rxhash(sk); - sk->sk_bound_dev_if = 0; + WRITE_ONCE(sk->sk_bound_dev_if, 0); if (!(sk->sk_userlocks & SOCK_BINDADDR_LOCK)) { inet_reset_saddr(sk); if (sk->sk_prot->rehash && -- 2.34.1