Ensure that a specific EID (remote or local) bind will match in preference to a MCTP_ADDR_ANY bind. This adds infrastructure for binding a socket to receive messages from a specific remote peer address, a future commit will expose an API for this. Signed-off-by: Matt Johnston --- v2: - Use DECLARE_HASHTABLE - Fix long lines --- include/net/netns/mctp.h | 20 +++++++++--- net/mctp/af_mctp.c | 11 ++++--- net/mctp/route.c | 81 ++++++++++++++++++++++++++++++++++++++---------- 3 files changed, 87 insertions(+), 25 deletions(-) diff --git a/include/net/netns/mctp.h b/include/net/netns/mctp.h index 1db8f9aaddb4b96f4803df9f30a762f5f88d7f7f..89555f90b97b297e50a571b26c5232b824909da7 100644 --- a/include/net/netns/mctp.h +++ b/include/net/netns/mctp.h @@ -6,19 +6,25 @@ #ifndef __NETNS_MCTP_H__ #define __NETNS_MCTP_H__ +#include +#include #include #include +#define MCTP_BINDS_BITS 7 + struct netns_mctp { /* Only updated under RTNL, entries freed via RCU */ struct list_head routes; - /* Bound sockets: list of sockets bound by type. - * This list is updated from non-atomic contexts (under bind_lock), - * and read (under rcu) in packet rx + /* Bound sockets: hash table of sockets, keyed by + * (type, src_eid, dest_eid). + * Specific src_eid/dest_eid entries also have an entry for + * MCTP_ADDR_ANY. This list is updated from non-atomic contexts + * (under bind_lock), and read (under rcu) in packet rx. */ struct mutex bind_lock; - struct hlist_head binds; + DECLARE_HASHTABLE(binds, MCTP_BINDS_BITS); /* tag allocations. This list is read and updated from atomic contexts, * but elements are free()ed after a RCU grace-period @@ -34,4 +40,10 @@ struct netns_mctp { struct list_head neighbours; }; +static inline u32 mctp_bind_hash(u8 type, u8 local_addr, u8 peer_addr) +{ + return hash_32(type | (u32)local_addr << 8 | (u32)peer_addr << 16, + MCTP_BINDS_BITS); +} + #endif /* __NETNS_MCTP_H__ */ diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c index 4f456b1c82d182ac2c64acebb0e603726826a7e7..a07da537bab41005ce643862b23d3050e958a66a 100644 --- a/net/mctp/af_mctp.c +++ b/net/mctp/af_mctp.c @@ -644,17 +644,17 @@ static int mctp_sk_hash(struct sock *sk) struct net *net = sock_net(sk); struct sock *existing; struct mctp_sock *msk; + u32 hash; int rc; msk = container_of(sk, struct mctp_sock, sk); - /* Bind lookup runs under RCU, remain live during that. */ - sock_set_flag(sk, SOCK_RCU_FREE); + hash = mctp_bind_hash(msk->bind_type, msk->bind_addr, MCTP_ADDR_ANY); mutex_lock(&net->mctp.bind_lock); /* Prevent duplicate binds. */ - sk_for_each(existing, &net->mctp.binds) { + sk_for_each(existing, &net->mctp.binds[hash]) { struct mctp_sock *mex = container_of(existing, struct mctp_sock, sk); @@ -666,7 +666,10 @@ static int mctp_sk_hash(struct sock *sk) } } - sk_add_node_rcu(sk, &net->mctp.binds); + /* Bind lookup runs under RCU, remain live during that. */ + sock_set_flag(sk, SOCK_RCU_FREE); + + sk_add_node_rcu(sk, &net->mctp.binds[hash]); rc = 0; out: diff --git a/net/mctp/route.c b/net/mctp/route.c index d9c8e5a5f9ce9aefbf16730c65a1f54caa5592b9..815fcb8db3beff338eedbabe6b3f4d44dd238f11 100644 --- a/net/mctp/route.c +++ b/net/mctp/route.c @@ -38,14 +38,45 @@ static int mctp_route_discard(struct mctp_route *route, struct sk_buff *skb) return 0; } -static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb) +static struct mctp_sock *mctp_lookup_bind_details(struct net *net, + struct sk_buff *skb, + u8 type, u8 dest, + u8 src, bool allow_net_any) { struct mctp_skb_cb *cb = mctp_cb(skb); - struct mctp_hdr *mh; struct sock *sk; - u8 type; + u8 hash; - WARN_ON(!rcu_read_lock_held()); + WARN_ON_ONCE(!rcu_read_lock_held()); + + hash = mctp_bind_hash(type, dest, src); + + sk_for_each_rcu(sk, &net->mctp.binds[hash]) { + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + + if (!allow_net_any && msk->bind_net == MCTP_NET_ANY) + continue; + + if (msk->bind_net != MCTP_NET_ANY && msk->bind_net != cb->net) + continue; + + if (msk->bind_type != type) + continue; + + if (!mctp_address_matches(msk->bind_addr, dest)) + continue; + + return msk; + } + + return NULL; +} + +static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb) +{ + struct mctp_sock *msk; + struct mctp_hdr *mh; + u8 type; /* TODO: look up in skb->cb? */ mh = mctp_hdr(skb); @@ -55,20 +86,36 @@ static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb) type = (*(u8 *)skb->data) & 0x7f; - sk_for_each_rcu(sk, &net->mctp.binds) { - struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); - - if (msk->bind_net != MCTP_NET_ANY && msk->bind_net != cb->net) - continue; - - if (msk->bind_type != type) - continue; - - if (!mctp_address_matches(msk->bind_addr, mh->dest)) - continue; + /* Look for binds in order of widening scope. A given destination or + * source address also implies matching on a particular network. + * + * - Matching destination and source + * - Matching destination + * - Matching source + * - Matching network, any address + * - Any network or address + */ + msk = mctp_lookup_bind_details(net, skb, type, mh->dest, mh->src, + false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, mh->src, + false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, mh->dest, MCTP_ADDR_ANY, + false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, + MCTP_ADDR_ANY, false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, + MCTP_ADDR_ANY, true); + if (msk) return msk; - } return NULL; } @@ -1475,7 +1522,7 @@ static int __net_init mctp_routes_net_init(struct net *net) struct netns_mctp *ns = &net->mctp; INIT_LIST_HEAD(&ns->routes); - INIT_HLIST_HEAD(&ns->binds); + hash_init(ns->binds); mutex_init(&ns->bind_lock); INIT_HLIST_HEAD(&ns->keys); spin_lock_init(&ns->keys_lock); -- 2.43.0