From: Kuniyuki Iwashima In __ipv6_sock_mc_join(), per-socket mld data is protected by lock_sock(), and only __dev_get_by_index() requires RTNL. Let's use dev_get_by_index() and drop RTNL for IPV6_ADD_MEMBERSHIP and MCAST_JOIN_GROUP. Note that we must call rt6_lookup() and dev_hold() under RCU. If rt6_lookup() returns an entry from the exception table, dst_dev_put() could change rt->dst.dev to loopback concurrently, and the original device could lose the refcount before dev_hold() and unblock device unregistration. dst_dev_put() is called from NETDEV_UNREGISTER and synchronize_net() follows it, so as long as rt6_lookup() and dev_hold() are called within the same RCU critical section, the dev is alive. Even if the race happens, they are synchronised by idev->dead and mcast addresses are cleaned up. For the racy access to rt->dst.dev, we use dst_dev(). Signed-off-by: Kuniyuki Iwashima Reviewed-by: Eric Dumazet --- v3: Add dst_dev() for rt->dst.dev v2: Hold rcu_read_lock() around rt6_lookup & dev_hold() --- net/ipv6/ipv6_sockglue.c | 2 -- net/ipv6/mcast.c | 24 +++++++++++++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/net/ipv6/ipv6_sockglue.c b/net/ipv6/ipv6_sockglue.c index 1e225e6489ea..cb0dc885cbe4 100644 --- a/net/ipv6/ipv6_sockglue.c +++ b/net/ipv6/ipv6_sockglue.c @@ -121,11 +121,9 @@ static bool setsockopt_needs_rtnl(int optname) { switch (optname) { case IPV6_ADDRFORM: - case IPV6_ADD_MEMBERSHIP: case IPV6_DROP_MEMBERSHIP: case IPV6_JOIN_ANYCAST: case IPV6_LEAVE_ANYCAST: - case MCAST_JOIN_GROUP: case MCAST_LEAVE_GROUP: case MCAST_JOIN_SOURCE_GROUP: case MCAST_LEAVE_SOURCE_GROUP: diff --git a/net/ipv6/mcast.c b/net/ipv6/mcast.c index b3f063b5ffd7..d55c1cb4189a 100644 --- a/net/ipv6/mcast.c +++ b/net/ipv6/mcast.c @@ -175,14 +175,12 @@ static int unsolicited_report_interval(struct inet6_dev *idev) static int __ipv6_sock_mc_join(struct sock *sk, int ifindex, const struct in6_addr *addr, unsigned int mode) { - struct net_device *dev = 
NULL; - struct ipv6_mc_socklist *mc_lst; struct ipv6_pinfo *np = inet6_sk(sk); + struct ipv6_mc_socklist *mc_lst; struct net *net = sock_net(sk); + struct net_device *dev = NULL; int err; - ASSERT_RTNL(); - if (!ipv6_addr_is_multicast(addr)) return -EINVAL; @@ -202,13 +200,18 @@ static int __ipv6_sock_mc_join(struct sock *sk, int ifindex, if (ifindex == 0) { struct rt6_info *rt; + + rcu_read_lock(); rt = rt6_lookup(net, addr, NULL, 0, NULL, 0); if (rt) { - dev = rt->dst.dev; + dev = dst_dev(&rt->dst); + dev_hold(dev); ip6_rt_put(rt); } - } else - dev = __dev_get_by_index(net, ifindex); + rcu_read_unlock(); + } else { + dev = dev_get_by_index(net, ifindex); + } if (!dev) { sock_kfree_s(sk, mc_lst, sizeof(*mc_lst)); @@ -219,12 +222,11 @@ static int __ipv6_sock_mc_join(struct sock *sk, int ifindex, mc_lst->sfmode = mode; RCU_INIT_POINTER(mc_lst->sflist, NULL); - /* - * now add/increase the group membership on the device - */ - + /* now add/increase the group membership on the device */ err = __ipv6_dev_mc_inc(dev, addr, mode); + dev_put(dev); + if (err) { sock_kfree_s(sk, mc_lst, sizeof(*mc_lst)); return err; -- 2.49.0