atalk_sendmsg() looks up an AppleTalk route, stores the returned atalk_route and net_device pointers, and then drops the socket lock around sock_alloc_send_skb(). The route pointer returned by atrtr_find() is only protected while atalk_routes_lock is held; after that lock is dropped, a concurrent SIOCDELRT or device-down path can unlink the route, drop the device reference, and free the route. When sendmsg resumes, it can still dereference the stale route and device pointers while building or transmitting the packet. A KASAN reproducer using AF_APPLETALK sockets and SIOCADDRT/SIOCDELRT reports slab-use-after-free reads in atalk_sendmsg(), with the object allocated by atrtr_create() and freed by atrtr_delete(). Fix this by splitting the route lookup into a helper that is called with atalk_routes_lock already held. atalk_sendmsg() now performs route lookup, copies the route fields it needs, and takes references to the selected devices while still holding atalk_routes_lock. After the lock is dropped and skb allocation sleeps, the send path uses only the copied route data and the held net_device references, which are released before returning. This preserves the existing route selection behaviour, including the separate loopback route used for broadcast loopback, while removing the dangling route/device window. Fixes: 60d9f461a20b ("appletalk: remove the BKL") Cc: stable@vger.kernel.org Reported-by: Yizhou Zhao Reported-by: Yuxiang Yang Reported-by: Ao Wang Reported-by: Xuewei Feng Reported-by: Qi Li Reported-by: Ke Xu Assisted-by: GLM:GLM-5.1 Signed-off-by: Yizhou Zhao --- net/appletalk/ddp.c | 68 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 47 insertions(+), 21 deletions(-) diff --git a/net/appletalk/ddp.c b/net/appletalk/ddp.c index 30a6dc06291c..e7fb4613c518 100644 --- a/net/appletalk/ddp.c +++ b/net/appletalk/ddp.c @@ -434,7 +434,7 @@ static struct atalk_iface *atalk_find_interface(__be16 net, int node) * the socket (later on...). We know about host routes and the fact * that a route must be direct to broadcast. */ -static struct atalk_route *atrtr_find(struct atalk_addr *target) +static struct atalk_route *atrtr_find_locked(struct atalk_addr *target) { /* * we must search through all routes unless we find a @@ -444,7 +444,6 @@ static struct atalk_route *atrtr_find(struct atalk_addr *target) struct atalk_route *net_route = NULL; struct atalk_route *r; - read_lock_bh(&atalk_routes_lock); for (r = atalk_routes; r; r = r->next) { if (!(r->flags & RTF_UP)) continue; @@ -477,6 +476,15 @@ static struct atalk_route *atrtr_find(struct atalk_addr *target) else /* No route can be found */ r = NULL; out: + return r; +} + +static struct atalk_route *atrtr_find(struct atalk_addr *target) +{ + struct atalk_route *r; + + read_lock_bh(&atalk_routes_lock); + r = atrtr_find_locked(target); read_unlock_bh(&atalk_routes_lock); return r; } @@ -1553,10 +1561,12 @@ static int atalk_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) int loopback = 0; struct sockaddr_at local_satalk, gsat; struct sk_buff *skb; - struct net_device *dev; + struct net_device *dev = NULL, *dev_lo = NULL; struct ddpehdr *ddp; int size, hard_header_len; struct atalk_route *rt, *rt_lo = NULL; + int rt_flags; + struct atalk_addr rt_gateway; int err; if (flags & ~(MSG_DONTWAIT|MSG_CMSG_COMPAT)) @@ -1600,39 +1610,50 @@ static int atalk_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) /* For headers */ size = sizeof(struct ddpehdr) + len + ddp_dl->header_length; + read_lock_bh(&atalk_routes_lock); if (usat->sat_addr.s_net || usat->sat_addr.s_node == ATADDR_ANYNODE) { - rt = atrtr_find(&usat->sat_addr); + rt = atrtr_find_locked(&usat->sat_addr); } else { struct atalk_addr at_hint; at_hint.s_node = 0; at_hint.s_net = at->src_net; - rt = atrtr_find(&at_hint); + rt = atrtr_find_locked(&at_hint); } err = -ENETUNREACH; - if (!rt) + if (!rt) { + read_unlock_bh(&atalk_routes_lock); goto out; + } dev = rt->dev; - - net_dbg_ratelimited("SK %p: Size needed %d, device %s\n", - sk, size, dev->name); + dev_hold(dev); + rt_flags = rt->flags; + rt_gateway = rt->gateway; hard_header_len = dev->hard_header_len; /* Leave room for loopback hardware header if necessary */ if (usat->sat_addr.s_node == ATADDR_BCAST && - (dev->flags & IFF_LOOPBACK || !(rt->flags & RTF_GATEWAY))) { + (dev->flags & IFF_LOOPBACK || !(rt_flags & RTF_GATEWAY))) { struct atalk_addr at_lo; at_lo.s_node = 0; at_lo.s_net = 0; - rt_lo = atrtr_find(&at_lo); + rt_lo = atrtr_find_locked(&at_lo); - if (rt_lo && rt_lo->dev->hard_header_len > hard_header_len) - hard_header_len = rt_lo->dev->hard_header_len; + if (rt_lo) { + dev_lo = rt_lo->dev; + dev_hold(dev_lo); + if (dev_lo->hard_header_len > hard_header_len) + hard_header_len = dev_lo->hard_header_len; + } } + read_unlock_bh(&atalk_routes_lock); + + net_dbg_ratelimited("SK %p: Size needed %d, device %s\n", + sk, size, dev->name); size += hard_header_len; release_sock(sk); @@ -1675,7 +1696,7 @@ static int atalk_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) * to group we are in) */ if (ddp->deh_dnode == ATADDR_BCAST && - !(rt->flags & RTF_GATEWAY) && !(dev->flags & IFF_LOOPBACK)) { + !(rt_flags & RTF_GATEWAY) && !(dev->flags & IFF_LOOPBACK)) { struct sk_buff *skb2 = skb_copy(skb, GFP_KERNEL); if (skb2) { @@ -1693,20 +1714,21 @@ static int atalk_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) /* loop back */ skb_orphan(skb); if (ddp->deh_dnode == ATADDR_BCAST) { - if (!rt_lo) { + if (!dev_lo) { kfree_skb(skb); err = -ENETUNREACH; goto out; } - dev = rt_lo->dev; - skb->dev = dev; + skb->dev = dev_lo; + ddp_dl->request(ddp_dl, skb, dev_lo->dev_addr); + } else { + ddp_dl->request(ddp_dl, skb, dev->dev_addr); } - ddp_dl->request(ddp_dl, skb, dev->dev_addr); } else { net_dbg_ratelimited("SK %p: send out.\n", sk); - if (rt->flags & RTF_GATEWAY) { - gsat.sat_addr = rt->gateway; - usat = &gsat; + if (rt_flags & RTF_GATEWAY) { + gsat.sat_addr = rt_gateway; + usat = &gsat; } /* @@ -1717,6 +1739,10 @@ static int atalk_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) net_dbg_ratelimited("SK %p: Done write (%zd).\n", sk, len); out: + if (dev) + dev_put(dev); + if (dev_lo) + dev_put(dev_lo); release_sock(sk); return err ? : len; }