Ensure that a specific EID (remote or local) bind will match in preference to a MCTP_ADDR_ANY bind. This adds infrastructure for binding a socket to receive messages from a specific remote peer address, a future commit will expose an API for this. Signed-off-by: Matt Johnston --- include/net/netns/mctp.h | 15 ++++++++-- net/mctp/af_mctp.c | 11 ++++--- net/mctp/route.c | 76 +++++++++++++++++++++++++++++++++++++----------- 3 files changed, 79 insertions(+), 23 deletions(-) diff --git a/include/net/netns/mctp.h b/include/net/netns/mctp.h index 1db8f9aaddb4b96f4803df9f30a762f5f88d7f7f..9f4f1c1065a8f00bbd5b1df5fa8f1cedf8d60686 100644 --- a/include/net/netns/mctp.h +++ b/include/net/netns/mctp.h @@ -6,19 +6,25 @@ #ifndef __NETNS_MCTP_H__ #define __NETNS_MCTP_H__ +#include #include #include +#define MCTP_BINDS_BITS 7 +#define MCTP_BINDS_SIZE (1 << MCTP_BINDS_BITS) +#define MCTP_BINDS_MASK (MCTP_BINDS_SIZE - 1) + struct netns_mctp { /* Only updated under RTNL, entries freed via RCU */ struct list_head routes; - /* Bound sockets: list of sockets bound by type. + /* Bound sockets: hash table of sockets, keyed by (type, src_eid, dest_eid). + * Specific src_eid/dest_eid entries also have an entry for MCTP_ADDR_ANY. * This list is updated from non-atomic contexts (under bind_lock), * and read (under rcu) in packet rx */ struct mutex bind_lock; - struct hlist_head binds; + struct hlist_head binds[MCTP_BINDS_SIZE]; /* tag allocations. This list is read and updated from atomic contexts, * but elements are free()ed after a RCU grace-period @@ -34,4 +40,9 @@ struct netns_mctp { struct list_head neighbours; }; +static inline u32 mctp_bind_hash(u8 type, u8 local_addr, u8 peer_addr) +{ + return hash_32(type | (u32)local_addr << 8 | (u32)peer_addr << 16, MCTP_BINDS_BITS); +} + #endif /* __NETNS_MCTP_H__ */ diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c index 4751f5fc082dcab27df77a9c5acbc6abb4e861d5..7638e22bf03848868768700fdac07f74891dad0d 100644 --- a/net/mctp/af_mctp.c +++ b/net/mctp/af_mctp.c @@ -643,17 +643,17 @@ static int mctp_sk_hash(struct sock *sk) struct net *net = sock_net(sk); struct sock *existing; struct mctp_sock *msk; + u32 hash; int rc; msk = container_of(sk, struct mctp_sock, sk); - /* Bind lookup runs under RCU, remain live during that. */ - sock_set_flag(sk, SOCK_RCU_FREE); + hash = mctp_bind_hash(msk->bind_type, msk->bind_addr, MCTP_ADDR_ANY); mutex_lock(&net->mctp.bind_lock); /* Prevent duplicate binds. */ - sk_for_each(existing, &net->mctp.binds) { + sk_for_each(existing, &net->mctp.binds[hash]) { struct mctp_sock *mex = container_of(existing, struct mctp_sock, sk); if (mex->bind_type == msk->bind_type && @@ -664,7 +664,10 @@ static int mctp_sk_hash(struct sock *sk) } } - sk_add_node_rcu(sk, &net->mctp.binds); + /* Bind lookup runs under RCU, remain live during that. */ + sock_set_flag(sk, SOCK_RCU_FREE); + + sk_add_node_rcu(sk, &net->mctp.binds[hash]); rc = 0; out: diff --git a/net/mctp/route.c b/net/mctp/route.c index d9c8e5a5f9ce9aefbf16730c65a1f54caa5592b9..8a8c7841d2382717b3c9a6698036d56f64da77f0 100644 --- a/net/mctp/route.c +++ b/net/mctp/route.c @@ -38,14 +38,45 @@ static int mctp_route_discard(struct mctp_route *route, struct sk_buff *skb) return 0; } -static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb) +static struct mctp_sock *mctp_lookup_bind_details(struct net *net, + struct sk_buff *skb, + u8 type, u8 dest, + u8 src, bool allow_net_any) { struct mctp_skb_cb *cb = mctp_cb(skb); - struct mctp_hdr *mh; struct sock *sk; - u8 type; + u8 hash; - WARN_ON(!rcu_read_lock_held()); + WARN_ON_ONCE(!rcu_read_lock_held()); + + hash = mctp_bind_hash(type, dest, src); + + sk_for_each_rcu(sk, &net->mctp.binds[hash]) { + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + + if (!allow_net_any && msk->bind_net == MCTP_NET_ANY) + continue; + + if (msk->bind_net != MCTP_NET_ANY && msk->bind_net != cb->net) + continue; + + if (msk->bind_type != type) + continue; + + if (!mctp_address_matches(msk->bind_addr, dest)) + continue; + + return msk; + } + + return NULL; +} + +static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb) +{ + struct mctp_sock *msk; + struct mctp_hdr *mh; + u8 type; /* TODO: look up in skb->cb? */ mh = mctp_hdr(skb); @@ -55,20 +86,31 @@ static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb) type = (*(u8 *)skb->data) & 0x7f; - sk_for_each_rcu(sk, &net->mctp.binds) { - struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); - - if (msk->bind_net != MCTP_NET_ANY && msk->bind_net != cb->net) - continue; - - if (msk->bind_type != type) - continue; - - if (!mctp_address_matches(msk->bind_addr, mh->dest)) - continue; + /* Look for binds in order of widening scope. A given destination or + * source address also implies matching on a particular network. + * + * - Matching destination and source + * - Matching destination + * - Matching source + * - Matching network, any address + * - Any network or address + */ + msk = mctp_lookup_bind_details(net, skb, type, mh->dest, mh->src, false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, mh->src, false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, mh->dest, MCTP_ADDR_ANY, false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, MCTP_ADDR_ANY, false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, MCTP_ADDR_ANY, true); + if (msk) return msk; - } return NULL; } @@ -1475,7 +1517,7 @@ static int __net_init mctp_routes_net_init(struct net *net) struct netns_mctp *ns = &net->mctp; INIT_LIST_HEAD(&ns->routes); - INIT_HLIST_HEAD(&ns->binds); + hash_init(ns->binds); mutex_init(&ns->bind_lock); INIT_HLIST_HEAD(&ns->keys); spin_lock_init(&ns->keys_lock); -- 2.43.0