The sock was not being released. Other than leaking, the stale socket will conflict with subsequent bind() calls in unrelated MCTP tests. Fixes: 11b67f6f22d6 ("net: mctp: test: Add extaddr routing output test") Signed-off-by: Matt Johnston --- Added in v3. The problem was introduced in current net-next so this patch isn't needed in the stable tree. --- net/mctp/test/route-test.c | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/net/mctp/test/route-test.c b/net/mctp/test/route-test.c index 7a398f41b6216afef72adecf118199753ed1bfea..12811032a2696167b4f319cbc9c81fef4cb2d951 100644 --- a/net/mctp/test/route-test.c +++ b/net/mctp/test/route-test.c @@ -1164,8 +1164,6 @@ static void mctp_test_route_extaddr_input(struct kunit *test) rc = mctp_dst_input(&dst, skb); KUNIT_ASSERT_EQ(test, rc, 0); - mctp_test_dst_release(&dst, &tpq); - skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc); KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb2); KUNIT_ASSERT_EQ(test, skb2->len, len); @@ -1179,8 +1177,8 @@ static void mctp_test_route_extaddr_input(struct kunit *test) KUNIT_EXPECT_EQ(test, cb2->halen, sizeof(haddr)); KUNIT_EXPECT_MEMEQ(test, cb2->haddr, haddr, sizeof(haddr)); - skb_free_datagram(sock->sk, skb2); - mctp_test_destroy_dev(dev); + kfree_skb(skb2); + __mctp_route_test_fini(test, dev, &dst, &tpq, sock); } static void mctp_test_route_gw_lookup(struct kunit *test) -- 2.43.0 Disallow bind() calls that have the same arguments as existing bound sockets. Previously multiple sockets could bind() to the same type/local address, with an arbitrary socket receiving matched messages. This is only a partial fix; a future commit will define precedence order for MCTP_ADDR_ANY versus specific EID bind(), which are allowed to exist together. 
Signed-off-by: Matt Johnston --- net/mctp/af_mctp.c | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c index aef74308c18e3273008cb84aabe23ff700d0f842..0d073bc32ec17905ac0118d1aa653a46d829b150 100644 --- a/net/mctp/af_mctp.c +++ b/net/mctp/af_mctp.c @@ -73,7 +73,6 @@ static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen) lock_sock(sk); - /* TODO: allow rebind */ if (sk_hashed(sk)) { rc = -EADDRINUSE; goto out_release; @@ -611,15 +610,36 @@ static void mctp_sk_close(struct sock *sk, long timeout) static int mctp_sk_hash(struct sock *sk) { struct net *net = sock_net(sk); + struct sock *existing; + struct mctp_sock *msk; + int rc; + + msk = container_of(sk, struct mctp_sock, sk); /* Bind lookup runs under RCU, remain live during that. */ sock_set_flag(sk, SOCK_RCU_FREE); mutex_lock(&net->mctp.bind_lock); - sk_add_node_rcu(sk, &net->mctp.binds); - mutex_unlock(&net->mctp.bind_lock); - return 0; + /* Prevent duplicate binds. */ + sk_for_each(existing, &net->mctp.binds) { + struct mctp_sock *mex = + container_of(existing, struct mctp_sock, sk); + + if (mex->bind_type == msk->bind_type && + mex->bind_addr == msk->bind_addr && + mex->bind_net == msk->bind_net) { + rc = -EADDRINUSE; + goto out; + } + } + + sk_add_node_rcu(sk, &net->mctp.binds); + rc = 0; + +out: + mutex_unlock(&net->mctp.bind_lock); + return rc; } static void mctp_sk_unhash(struct sock *sk) -- 2.43.0 When a specific EID is passed as a bind address, it only makes sense to interpret it with an actual network ID, so resolve that to the default network at bind time. For bind address of MCTP_ADDR_ANY, we want to be able to capture traffic to any network and address, so keep the current behaviour of matching traffic from any network interface (don't interpret MCTP_NET_ANY as the default network ID). 
Signed-off-by: Matt Johnston --- net/mctp/af_mctp.c | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c index 0d073bc32ec17905ac0118d1aa653a46d829b150..20edaf840a607700c04b740708763fbd02a2df47 100644 --- a/net/mctp/af_mctp.c +++ b/net/mctp/af_mctp.c @@ -53,6 +53,7 @@ static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen) { struct sock *sk = sock->sk; struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + struct net *net = sock_net(&msk->sk); struct sockaddr_mctp *smctp; int rc; @@ -77,8 +78,21 @@ static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen) rc = -EADDRINUSE; goto out_release; } - msk->bind_net = smctp->smctp_network; + msk->bind_addr = smctp->smctp_addr.s_addr; + + /* MCTP_NET_ANY with a specific EID is resolved to the default net + * at bind() time. + * For bind_addr=MCTP_ADDR_ANY it is handled specially at route + * lookup time. + */ + if (smctp->smctp_network == MCTP_NET_ANY && + msk->bind_addr != MCTP_ADDR_ANY) { + msk->bind_net = mctp_default_net(net); + } else { + msk->bind_net = smctp->smctp_network; + } + msk->bind_type = smctp->smctp_type & 0x7f; /* ignore the IC bit */ rc = sk->sk_prot->hash(sk); -- 2.43.0 Test pairwise combinations of bind addresses, networks and types, checking which pairs conflict. 
Signed-off-by: Matt Johnston --- v3: - Moved test code to mctp/test/sock-test.c recently added in net-next, common bind test code in mctp/test/utils.c v2: - Remove unused bind test case - Fix line lengths --- net/mctp/test/sock-test.c | 130 ++++++++++++++++++++++++++++++++++++++++++++++ net/mctp/test/utils.c | 22 ++++++++ net/mctp/test/utils.h | 10 ++++ 3 files changed, 162 insertions(+) diff --git a/net/mctp/test/sock-test.c b/net/mctp/test/sock-test.c index 4eb3a724dca39eb22615cbfc1201b45ee4c78d16..0cfc337be687e7ad903023d2fae9f12f75628532 100644 --- a/net/mctp/test/sock-test.c +++ b/net/mctp/test/sock-test.c @@ -215,9 +215,139 @@ static void mctp_test_sock_recvmsg_extaddr(struct kunit *test) __mctp_sock_test_fini(test, dev, rt, sock); } +static const struct mctp_test_bind_setup bind_addrany_netdefault_type1 = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = MCTP_NET_ANY, .bind_type = 1, +}; + +static const struct mctp_test_bind_setup bind_addrany_net2_type1 = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = 2, .bind_type = 1, +}; + +/* 1 is default net */ +static const struct mctp_test_bind_setup bind_addr8_net1_type1 = { + .bind_addr = 8, .bind_net = 1, .bind_type = 1, +}; + +static const struct mctp_test_bind_setup bind_addrany_net1_type1 = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = 1, .bind_type = 1, +}; + +/* 2 is an arbitrary net */ +static const struct mctp_test_bind_setup bind_addr8_net2_type1 = { + .bind_addr = 8, .bind_net = 2, .bind_type = 1, +}; + +static const struct mctp_test_bind_setup bind_addr8_netdefault_type1 = { + .bind_addr = 8, .bind_net = MCTP_NET_ANY, .bind_type = 1, +}; + +static const struct mctp_test_bind_setup bind_addrany_net2_type2 = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = 2, .bind_type = 2, +}; + +struct mctp_bind_pair_test { + const struct mctp_test_bind_setup *bind1; + const struct mctp_test_bind_setup *bind2; + int error; +}; + +/* Pairs of binds and whether they will conflict */ +static const struct mctp_bind_pair_test 
mctp_bind_pair_tests[] = { + /* Both ADDR_ANY, conflict */ + { &bind_addrany_netdefault_type1, &bind_addrany_netdefault_type1, + EADDRINUSE }, + /* Same specific EID, conflict */ + { &bind_addr8_netdefault_type1, &bind_addr8_netdefault_type1, + EADDRINUSE }, + /* ADDR_ANY vs specific EID, OK */ + { &bind_addrany_netdefault_type1, &bind_addr8_netdefault_type1, 0 }, + /* ADDR_ANY different types, OK */ + { &bind_addrany_net2_type2, &bind_addrany_net2_type1, 0 }, + /* ADDR_ANY different nets, OK */ + { &bind_addrany_net2_type1, &bind_addrany_netdefault_type1, 0 }, + + /* specific EID, NET_ANY (resolves to default) + * vs specific EID, explicit default net 1, conflict + */ + { &bind_addr8_netdefault_type1, &bind_addr8_net1_type1, EADDRINUSE }, + + /* specific EID, net 1 vs specific EID, net 2, ok */ + { &bind_addr8_net1_type1, &bind_addr8_net2_type1, 0 }, + + /* ADDR_ANY, NET_ANY (doesn't resolve to default) + * vs ADDR_ANY, explicit default net 1, OK + */ + { &bind_addrany_netdefault_type1, &bind_addrany_net1_type1, 0 }, +}; + +static void mctp_bind_pair_desc(const struct mctp_bind_pair_test *t, char *desc) +{ + snprintf(desc, KUNIT_PARAM_DESC_SIZE, + "{bind(addr %d, type %d, net %d)} {bind(addr %d, type %d, net %d)} -> error %d", + t->bind1->bind_addr, t->bind1->bind_type, t->bind1->bind_net, + t->bind2->bind_addr, t->bind2->bind_type, t->bind2->bind_net, + t->error); +} + +KUNIT_ARRAY_PARAM(mctp_bind_pair, mctp_bind_pair_tests, mctp_bind_pair_desc); + +static int +mctp_test_bind_conflicts_inner(struct kunit *test, + const struct mctp_test_bind_setup *bind1, + const struct mctp_test_bind_setup *bind2) +{ + struct socket *sock1 = NULL, *sock2 = NULL, *sock3 = NULL; + int bind_errno; + + /* Bind to first address, always succeeds */ + mctp_test_bind_run(test, bind1, &bind_errno, &sock1); + KUNIT_EXPECT_EQ(test, bind_errno, 0); + + /* A second identical bind always fails */ + mctp_test_bind_run(test, bind1, &bind_errno, &sock2); + KUNIT_EXPECT_EQ(test, -bind_errno, 
EADDRINUSE); + + /* A different bind, result is returned */ + mctp_test_bind_run(test, bind2, &bind_errno, &sock3); + + if (sock1) + sock_release(sock1); + if (sock2) + sock_release(sock2); + if (sock3) + sock_release(sock3); + + return bind_errno; +} + +static void mctp_test_bind_conflicts(struct kunit *test) +{ + const struct mctp_bind_pair_test *pair; + int bind_errno; + + pair = test->param_value; + + bind_errno = + mctp_test_bind_conflicts_inner(test, pair->bind1, pair->bind2); + KUNIT_EXPECT_EQ(test, -bind_errno, pair->error); + + /* swapping the bind order must give the same result */ + bind_errno = + mctp_test_bind_conflicts_inner(test, pair->bind2, pair->bind1); + KUNIT_EXPECT_EQ(test, -bind_errno, pair->error); +} + +static void mctp_test_assumptions(struct kunit *test) +{ + /* check assumption of default net from bind_addr8_net1_type1 */ + KUNIT_ASSERT_EQ(test, mctp_default_net(&init_net), 1); +} + static struct kunit_case mctp_test_cases[] = { + KUNIT_CASE(mctp_test_assumptions), KUNIT_CASE(mctp_test_sock_sendmsg_extaddr), KUNIT_CASE(mctp_test_sock_recvmsg_extaddr), + KUNIT_CASE_PARAM(mctp_test_bind_conflicts, mctp_bind_pair_gen_params), {} }; diff --git a/net/mctp/test/utils.c b/net/mctp/test/utils.c index 01f5af416b814baf812b4352c513ffcdd9939cb2..c971e2c326f3564f95b3f693c450b3e6f3d9c594 100644 --- a/net/mctp/test/utils.c +++ b/net/mctp/test/utils.c @@ -258,3 +258,25 @@ struct sk_buff *__mctp_test_create_skb_data(const struct mctp_hdr *hdr, return skb; } + +void mctp_test_bind_run(struct kunit *test, + const struct mctp_test_bind_setup *setup, + int *ret_bind_errno, struct socket **sock) +{ + struct sockaddr_mctp addr; + int rc; + + *ret_bind_errno = -EIO; + + rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, sock); + KUNIT_ASSERT_EQ(test, rc, 0); + + memset(&addr, 0x0, sizeof(addr)); + addr.smctp_family = AF_MCTP; + addr.smctp_network = setup->bind_net; + addr.smctp_addr.s_addr = setup->bind_addr; + addr.smctp_type = setup->bind_type; + + 
*ret_bind_errno = + kernel_bind(*sock, (struct sockaddr *)&addr, sizeof(addr)); +} diff --git a/net/mctp/test/utils.h b/net/mctp/test/utils.h index f10d1d9066ccde53bbaf471ea79b87b1d94cd755..7dd1a92ab770995db506c24dc805bb9e0839eeef 100644 --- a/net/mctp/test/utils.h +++ b/net/mctp/test/utils.h @@ -31,6 +31,12 @@ struct mctp_test_pktqueue { struct sk_buff_head pkts; }; +struct mctp_test_bind_setup { + mctp_eid_t bind_addr; + int bind_net; + u8 bind_type; +}; + struct mctp_test_dev *mctp_test_create_dev(void); struct mctp_test_dev *mctp_test_create_dev_lladdr(unsigned short lladdr_len, const unsigned char *lladdr); @@ -61,4 +67,8 @@ struct sk_buff *__mctp_test_create_skb_data(const struct mctp_hdr *hdr, #define mctp_test_create_skb_data(h, d) \ __mctp_test_create_skb_data(h, d, sizeof(*d)) +void mctp_test_bind_run(struct kunit *test, + const struct mctp_test_bind_setup *setup, + int *ret_bind_errno, struct socket **sock); + #endif /* __NET_MCTP_TEST_UTILS_H */ -- 2.43.0 Ensure that a specific EID (remote or local) bind will match in preference to a MCTP_ADDR_ANY bind. This adds infrastructure for binding a socket to receive messages from a specific remote peer address; a future commit will expose an API for this. 
Signed-off-by: Matt Johnston --- v4: - Correct the comments describing the bind hash table and the lookup order - Use u32 for the hash bucket index v2: - Use DECLARE_HASHTABLE - Fix long lines --- include/net/netns/mctp.h | 20 +++++++++--- net/mctp/af_mctp.c | 11 ++++--- net/mctp/route.c | 81 ++++++++++++++++++++++++++++++++++++++---------- 3 files changed, 87 insertions(+), 25 deletions(-) diff --git a/include/net/netns/mctp.h b/include/net/netns/mctp.h index 1db8f9aaddb4b96f4803df9f30a762f5f88d7f7f..89555f90b97b297e50a571b26c5232b824909da7 100644 --- a/include/net/netns/mctp.h +++ b/include/net/netns/mctp.h @@ -6,19 +6,25 @@ #ifndef __NETNS_MCTP_H__ #define __NETNS_MCTP_H__ +#include +#include #include #include +#define MCTP_BINDS_BITS 7 + struct netns_mctp { /* Only updated under RTNL, entries freed via RCU */ struct list_head routes; - /* Bound sockets: list of sockets bound by type. - * This list is updated from non-atomic contexts (under bind_lock), - * and read (under rcu) in packet rx + /* Bound sockets: hash table of sockets, keyed by + * (type, local_eid, peer_eid), see mctp_bind_hash(). + * Each socket has one entry; rx lookups try specific EIDs before + * MCTP_ADDR_ANY. This table is updated from non-atomic contexts + * (under bind_lock), and read (under rcu) in packet rx. */ struct mutex bind_lock; - struct hlist_head binds; + DECLARE_HASHTABLE(binds, MCTP_BINDS_BITS); /* tag allocations. 
This list is read and updated from atomic contexts, * but elements are free()ed after a RCU grace-period @@ -34,4 +40,10 @@ struct netns_mctp { struct list_head neighbours; }; +static inline u32 mctp_bind_hash(u8 type, u8 local_addr, u8 peer_addr) +{ + return hash_32(type | (u32)local_addr << 8 | (u32)peer_addr << 16, + MCTP_BINDS_BITS); +} + #endif /* __NETNS_MCTP_H__ */ diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c index 20edaf840a607700c04b740708763fbd02a2df47..16341de5cf2893bbc04a8c05a038c30be6570296 100644 --- a/net/mctp/af_mctp.c +++ b/net/mctp/af_mctp.c @@ -626,17 +626,17 @@ static int mctp_sk_hash(struct sock *sk) struct net *net = sock_net(sk); struct sock *existing; struct mctp_sock *msk; + u32 hash; int rc; msk = container_of(sk, struct mctp_sock, sk); - /* Bind lookup runs under RCU, remain live during that. */ - sock_set_flag(sk, SOCK_RCU_FREE); + hash = mctp_bind_hash(msk->bind_type, msk->bind_addr, MCTP_ADDR_ANY); mutex_lock(&net->mctp.bind_lock); /* Prevent duplicate binds. */ - sk_for_each(existing, &net->mctp.binds) { + sk_for_each(existing, &net->mctp.binds[hash]) { struct mctp_sock *mex = container_of(existing, struct mctp_sock, sk); @@ -648,7 +648,10 @@ static int mctp_sk_hash(struct sock *sk) } } - sk_add_node_rcu(sk, &net->mctp.binds); + /* Bind lookup runs under RCU, remain live during that. 
*/ + sock_set_flag(sk, SOCK_RCU_FREE); + + sk_add_node_rcu(sk, &net->mctp.binds[hash]); rc = 0; out: diff --git a/net/mctp/route.c b/net/mctp/route.c index a20d6b11d4186b55cab9d76e367169ea712553c7..69cfb0e6c545c2b44e5defdfac4e602c4f0265b1 100644 --- a/net/mctp/route.c +++ b/net/mctp/route.c @@ -40,14 +40,45 @@ static int mctp_dst_discard(struct mctp_dst *dst, struct sk_buff *skb) return 0; } -static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb) +static struct mctp_sock *mctp_lookup_bind_details(struct net *net, + struct sk_buff *skb, + u8 type, u8 dest, + u8 src, bool allow_net_any) { struct mctp_skb_cb *cb = mctp_cb(skb); - struct mctp_hdr *mh; struct sock *sk; - u8 type; + u32 hash; - WARN_ON(!rcu_read_lock_held()); + WARN_ON_ONCE(!rcu_read_lock_held()); + + hash = mctp_bind_hash(type, dest, src); + + sk_for_each_rcu(sk, &net->mctp.binds[hash]) { + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + + if (!allow_net_any && msk->bind_net == MCTP_NET_ANY) + continue; + + if (msk->bind_net != MCTP_NET_ANY && msk->bind_net != cb->net) + continue; + + if (msk->bind_type != type) + continue; + + if (!mctp_address_matches(msk->bind_addr, dest)) + continue; + + return msk; + } + + return NULL; +} + +static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb) +{ + struct mctp_sock *msk; + struct mctp_hdr *mh; + u8 type; /* TODO: look up in skb->cb? */ mh = mctp_hdr(skb); @@ -57,20 +88,36 @@ static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb) type = (*(u8 *)skb->data) & 0x7f; - sk_for_each_rcu(sk, &net->mctp.binds) { - struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); - - if (msk->bind_net != MCTP_NET_ANY && msk->bind_net != cb->net) - continue; - - if (msk->bind_type != type) - continue; - - if (!mctp_address_matches(msk->bind_addr, mh->dest)) - continue; + /* Look for binds in order of widening scope. 
A given destination or + * source address also implies matching on a particular network. + * + * - Matching destination and source + * - Matching source + * - Matching destination + * - Matching network, any address + * - Any network or address + */ + msk = mctp_lookup_bind_details(net, skb, type, mh->dest, mh->src, + false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, mh->src, + false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, mh->dest, MCTP_ADDR_ANY, + false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, + MCTP_ADDR_ANY, false); + if (msk) + return msk; + msk = mctp_lookup_bind_details(net, skb, type, MCTP_ADDR_ANY, + MCTP_ADDR_ANY, true); + if (msk) return msk; - } return NULL; } @@ -1671,7 +1718,7 @@ static int __net_init mctp_routes_net_init(struct net *net) struct netns_mctp *ns = &net->mctp; INIT_LIST_HEAD(&ns->routes); - INIT_HLIST_HEAD(&ns->binds); + hash_init(ns->binds); mutex_init(&ns->bind_lock); INIT_HLIST_HEAD(&ns->keys); spin_lock_init(&ns->keys_lock); -- 2.43.0 Prior to calling bind(), a program may call connect() on a socket to restrict it to a remote peer address. Using connect() is the normal mechanism to specify a remote network peer, so we use that here. In MCTP, connect() is only used for bound sockets; send() is not available for MCTP since a tag must be provided for each message. The smctp_type must match between connect() and bind() calls. 
Signed-off-by: Matt Johnston --- v4: - Release the socket lock on the bind() error paths that check against a prior connect() - Match the exact (local, peer) bind key in mctp_lookup_bind_details(), as unrelated binds can share a hash bucket - Fix a copy-pasted comment in mctp_connect() --- include/net/mctp.h | 5 ++- net/mctp/af_mctp.c | 103 +++++++++++++++++++++++++++++++++++++++++++++++++---- net/mctp/route.c | 6 +++- 3 files changed, 106 insertions(+), 8 deletions(-) diff --git a/include/net/mctp.h b/include/net/mctp.h index ac4f4ecdfc24f1f481ff22a5673cb95e1bf21310..c3207ce98f07fcbb436e968d503bc45666794fdc 100644 --- a/include/net/mctp.h +++ b/include/net/mctp.h @@ -69,7 +69,10 @@ struct mctp_sock { /* bind() params */ unsigned int bind_net; - mctp_eid_t bind_addr; + mctp_eid_t bind_local_addr; + mctp_eid_t bind_peer_addr; + unsigned int bind_peer_net; + bool bind_peer_set; __u8 bind_type; /* sendmsg()/recvmsg() uses struct sockaddr_mctp_ext */ diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c index 16341de5cf2893bbc04a8c05a038c30be6570296..79f3c53afebe4aafc14b710d8af2582e662b2df8 100644 --- a/net/mctp/af_mctp.c +++ b/net/mctp/af_mctp.c @@ -79,7 +79,7 @@ static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen) goto out_release; } - msk->bind_addr = smctp->smctp_addr.s_addr; + msk->bind_local_addr = smctp->smctp_addr.s_addr; /* MCTP_NET_ANY with a specific EID is resolved to the default net * at bind() time. @@ -87,13 +87,33 @@ static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen) * lookup time. 
*/ if (smctp->smctp_network == MCTP_NET_ANY && - msk->bind_addr != MCTP_ADDR_ANY) { + msk->bind_local_addr != MCTP_ADDR_ANY) { msk->bind_net = mctp_default_net(net); } else { msk->bind_net = smctp->smctp_network; } - msk->bind_type = smctp->smctp_type & 0x7f; /* ignore the IC bit */ + /* ignore the IC bit */ + smctp->smctp_type &= 0x7f; + + if (msk->bind_peer_set) { + if (msk->bind_type != smctp->smctp_type) { + /* Prior connect() had a different type */ + rc = -EINVAL; + goto out_release; + } + /* Restrict to the network passed to connect() */ + if (msk->bind_net == MCTP_NET_ANY) + msk->bind_net = msk->bind_peer_net; + + if (msk->bind_net != msk->bind_peer_net) { + /* connect() had a different net to bind() */ + rc = -EINVAL; + goto out_release; + } + } else { + msk->bind_type = smctp->smctp_type; + } rc = sk->sk_prot->hash(sk); @@ -103,6 +123,67 @@ static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen) return rc; } +/* Used to set a specific peer prior to bind. Not used for outbound + * connections (Tag Owner set) since MCTP is a datagram protocol. 
+ */ +static int mctp_connect(struct socket *sock, struct sockaddr *addr, + int addrlen, int flags) +{ + struct sock *sk = sock->sk; + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + struct net *net = sock_net(&msk->sk); + struct sockaddr_mctp *smctp; + int rc; + + if (addrlen != sizeof(*smctp)) + return -EINVAL; + + if (addr->sa_family != AF_MCTP) + return -EAFNOSUPPORT; + + /* It's a valid sockaddr for MCTP, cast and do protocol checks */ + smctp = (struct sockaddr_mctp *)addr; + + if (!mctp_sockaddr_is_ok(smctp)) + return -EINVAL; + + /* Can't connect by tag */ + if (smctp->smctp_tag) + return -EINVAL; + + /* IC bit must be unset */ + if (smctp->smctp_type & 0x80) + return -EINVAL; + + lock_sock(sk); + + if (sk_hashed(sk)) { + /* bind() already */ + rc = -EADDRINUSE; + goto out_release; + } + + if (msk->bind_peer_set) { + /* connect() already */ + rc = -EADDRINUSE; + goto out_release; + } + + msk->bind_peer_set = true; + msk->bind_peer_addr = smctp->smctp_addr.s_addr; + msk->bind_type = smctp->smctp_type; + if (smctp->smctp_network == MCTP_NET_ANY) + msk->bind_peer_net = mctp_default_net(net); + else + msk->bind_peer_net = smctp->smctp_network; + + rc = 0; + +out_release: + release_sock(sk); + return rc; +} + static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) { DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name); @@ -546,7 +627,7 @@ static const struct proto_ops mctp_dgram_ops = { .family = PF_MCTP, .release = mctp_release, .bind = mctp_bind, - .connect = sock_no_connect, + .connect = mctp_connect, .socketpair = sock_no_socketpair, .accept = sock_no_accept, .getname = sock_no_getname, @@ -613,6 +694,7 @@ static int mctp_sk_init(struct sock *sk) INIT_HLIST_HEAD(&msk->keys); timer_setup(&msk->key_expiry, mctp_sk_expire_keys, 0); + msk->bind_peer_set = false; return 0; } @@ -626,12 +708,17 @@ static int mctp_sk_hash(struct sock *sk) struct net *net = sock_net(sk); struct sock *existing; struct mctp_sock *msk; + 
mctp_eid_t remote; u32 hash; int rc; msk = container_of(sk, struct mctp_sock, sk); - hash = mctp_bind_hash(msk->bind_type, msk->bind_addr, MCTP_ADDR_ANY); + if (msk->bind_peer_set) + remote = msk->bind_peer_addr; + else + remote = MCTP_ADDR_ANY; + hash = mctp_bind_hash(msk->bind_type, msk->bind_local_addr, remote); mutex_lock(&net->mctp.bind_lock); @@ -640,8 +727,12 @@ static int mctp_sk_hash(struct sock *sk) struct mctp_sock *mex = container_of(existing, struct mctp_sock, sk); + bool same_peer = (mex->bind_peer_set && msk->bind_peer_set && + mex->bind_peer_addr == msk->bind_peer_addr) || + (!mex->bind_peer_set && !msk->bind_peer_set); + if (mex->bind_type == msk->bind_type && - mex->bind_addr == msk->bind_addr && + mex->bind_local_addr == msk->bind_local_addr && same_peer && mex->bind_net == msk->bind_net) { rc = -EADDRINUSE; goto out; diff --git a/net/mctp/route.c b/net/mctp/route.c index 69cfb0e6c545c2b44e5defdfac4e602c4f0265b1..2b2b958ef6a37525cc4d3f6a5758bd3880c98e6c 100644 --- a/net/mctp/route.c +++ b/net/mctp/route.c @@ -65,7 +65,11 @@ static struct mctp_sock *mctp_lookup_bind_details(struct net *net, if (msk->bind_type != type) continue; - if (!mctp_address_matches(msk->bind_addr, dest)) + /* Exact key match: other binds may share this hash bucket */ + if (msk->bind_peer_set ? msk->bind_peer_addr != src : + src != MCTP_ADDR_ANY) + continue; + if (msk->bind_local_addr != dest) continue; return msk; -- 2.43.0 The addition of connect() adds new conflict cases to test. 
Signed-off-by: Matt Johnston --- v4: - Size the peer description buffers for the longest possible output v3: - Moved test code to mctp/test/sock-test.c recently added in net-next --- net/mctp/test/sock-test.c | 45 +++++++++++++++++++++++++++++++++++++++++---- net/mctp/test/utils.c | 14 ++++++++++++++ net/mctp/test/utils.h | 4 ++++ 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/net/mctp/test/sock-test.c b/net/mctp/test/sock-test.c index 0cfc337be687e7ad903023d2fae9f12f75628532..b0942deb501980f196ce13da19ab171a3a9c9b8c 100644 --- a/net/mctp/test/sock-test.c +++ b/net/mctp/test/sock-test.c @@ -245,6 +245,11 @@ static const struct mctp_test_bind_setup bind_addrany_net2_type2 = { .bind_addr = MCTP_ADDR_ANY, .bind_net = 2, .bind_type = 2, }; +static const struct mctp_test_bind_setup bind_addrany_net2_type1_peer9 = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = 2, .bind_type = 1, + .have_peer = true, .peer_addr = 9, .peer_net = 2, +}; + struct mctp_bind_pair_test { const struct mctp_test_bind_setup *bind1; const struct mctp_test_bind_setup *bind2; @@ -278,19 +283,50 @@ static const struct mctp_bind_pair_test mctp_bind_pair_tests[] = { * vs ADDR_ANY, explicit default net 1, OK */ { &bind_addrany_netdefault_type1, &bind_addrany_net1_type1, 0 }, + + /* specific remote peer doesn't conflict with any-peer bind */ + { &bind_addrany_net2_type1_peer9, &bind_addrany_net2_type1, 0 }, + + /* bind() NET_ANY is allowed with a connect() net */ + { &bind_addrany_net2_type1_peer9, &bind_addrany_netdefault_type1, 0 }, }; static void mctp_bind_pair_desc(const struct mctp_bind_pair_test *t, char *desc) { + char peer1[32] = {0}, peer2[32] = {0}; + + if (t->bind1->have_peer) + snprintf(peer1, sizeof(peer1), ", peer %d net %d", + t->bind1->peer_addr, t->bind1->peer_net); + if (t->bind2->have_peer) + snprintf(peer2, sizeof(peer2), ", peer %d net %d", + t->bind2->peer_addr, t->bind2->peer_net); + snprintf(desc, KUNIT_PARAM_DESC_SIZE, - "{bind(addr %d, type %d, net %d)} {bind(addr %d, type %d, net %d)} -> 
error %d", - t->bind1->bind_addr, 
t->bind1->bind_type, t->bind1->bind_net, - t->bind2->bind_addr, t->bind2->bind_type, t->bind2->bind_net, - t->error); + "{bind(addr %d, type %d, net %d%s)} {bind(addr %d, type %d, net %d%s)} -> error %d", + t->bind1->bind_addr, t->bind1->bind_type, + t->bind1->bind_net, peer1, + t->bind2->bind_addr, t->bind2->bind_type, + t->bind2->bind_net, peer2, t->error); } KUNIT_ARRAY_PARAM(mctp_bind_pair, mctp_bind_pair_tests, mctp_bind_pair_desc); +static void mctp_test_bind_invalid(struct kunit *test) +{ + struct socket *sock; + int rc; + + /* bind() fails if the bind() vs connect() networks mismatch. */ + const struct mctp_test_bind_setup bind_connect_net_mismatch = { + .bind_addr = MCTP_ADDR_ANY, .bind_net = 1, .bind_type = 1, + .have_peer = true, .peer_addr = 9, .peer_net = 2, + }; + mctp_test_bind_run(test, &bind_connect_net_mismatch, &rc, &sock); + KUNIT_EXPECT_EQ(test, -rc, EINVAL); + sock_release(sock); +} + static int mctp_test_bind_conflicts_inner(struct kunit *test, const struct mctp_test_bind_setup *bind1, @@ -348,6 +384,7 @@ static struct kunit_case mctp_test_cases[] = { KUNIT_CASE(mctp_test_sock_sendmsg_extaddr), KUNIT_CASE(mctp_test_sock_recvmsg_extaddr), KUNIT_CASE_PARAM(mctp_test_bind_conflicts, mctp_bind_pair_gen_params), + KUNIT_CASE(mctp_test_bind_invalid), {} }; diff --git a/net/mctp/test/utils.c b/net/mctp/test/utils.c index c971e2c326f3564f95b3f693c450b3e6f3d9c594..953d419027718959d913956c4c3893ef91624eb5 100644 --- a/net/mctp/test/utils.c +++ b/net/mctp/test/utils.c @@ -271,6 +271,20 @@ void mctp_test_bind_run(struct kunit *test, rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, sock); KUNIT_ASSERT_EQ(test, rc, 0); + /* connect() if requested */ + if (setup->have_peer) { + memset(&addr, 0x0, sizeof(addr)); + addr.smctp_family = AF_MCTP; + addr.smctp_network = setup->peer_net; + addr.smctp_addr.s_addr = setup->peer_addr; + /* connect() type must match bind() type */ + addr.smctp_type = setup->bind_type; + rc = kernel_connect(*sock, (struct 
sockaddr *)&addr, + sizeof(addr), 0); + KUNIT_EXPECT_EQ(test, rc, 0); + } + + /* bind() */ memset(&addr, 0x0, sizeof(addr)); addr.smctp_family = AF_MCTP; addr.smctp_network = setup->bind_net; diff --git a/net/mctp/test/utils.h b/net/mctp/test/utils.h index 7dd1a92ab770995db506c24dc805bb9e0839eeef..c2aaba5188ab82237cb3bcc00d5abf1942753b9d 100644 --- a/net/mctp/test/utils.h +++ b/net/mctp/test/utils.h @@ -35,6 +35,10 @@ struct mctp_test_bind_setup { mctp_eid_t bind_addr; int bind_net; u8 bind_type; + + bool have_peer; + mctp_eid_t peer_addr; + int peer_net; }; struct mctp_test_dev *mctp_test_create_dev(void); -- 2.43.0 Test the preference order of bound socket matches with a series of test packets. Signed-off-by: Matt Johnston --- v3: - Updated test code for changes from net-next --- net/mctp/test/route-test.c | 188 +++++++++++++++++++++++++++++++++++++++++++++ net/mctp/test/utils.h | 3 + 2 files changed, 191 insertions(+) diff --git a/net/mctp/test/route-test.c b/net/mctp/test/route-test.c index 12811032a2696167b4f319cbc9c81fef4cb2d951..fb6b46a952cb432163f6adb40bb395d658745efd 100644 --- a/net/mctp/test/route-test.c +++ b/net/mctp/test/route-test.c @@ -1408,6 +1408,193 @@ static void mctp_test_route_gw_output(struct kunit *test) kfree_skb(skb); } +struct mctp_bind_lookup_test { + /* header of incoming message */ + struct mctp_hdr hdr; + u8 ty; + /* mctp network of incoming interface (smctp_network) */ + unsigned int net; + + /* expected socket, matches .name in lookup_binds, NULL for dropped */ + const char *expect; +}; + +/* Single-packet TO-set message */ +#define LK(src, dst) RX_HDR(1, (src), (dst), FL_S | FL_E | FL_TO) + +/* Input message test cases for bind lookup tests. + * + * 10 and 11 are local EIDs. + * 20 and 21 are remote EIDs. 
*/ +static const struct mctp_bind_lookup_test mctp_bind_lookup_tests[] = { + /* both local-eid and remote-eid binds, remote eid is preferred */ + { .hdr = LK(20, 10), .ty = 1, .net = 1, .expect = "remote20" }, + + { .hdr = LK(20, 255), .ty = 1, .net = 1, .expect = "remote20" }, + { .hdr = LK(20, 0), .ty = 1, .net = 1, .expect = "remote20" }, + { .hdr = LK(0, 255), .ty = 1, .net = 1, .expect = "any" }, + { .hdr = LK(0, 11), .ty = 1, .net = 1, .expect = "any" }, + { .hdr = LK(0, 0), .ty = 1, .net = 1, .expect = "any" }, + { .hdr = LK(0, 10), .ty = 1, .net = 1, .expect = "local10" }, + { .hdr = LK(21, 10), .ty = 1, .net = 1, .expect = "local10" }, + { .hdr = LK(21, 11), .ty = 1, .net = 1, .expect = "remote21local11" }, + + /* both src and dest set to eid=99. unusual, but accepted + * by MCTP stack currently. + */ + { .hdr = LK(99, 99), .ty = 1, .net = 1, .expect = "any" }, + + /* unbound smctp_type */ + { .hdr = LK(20, 10), .ty = 3, .net = 1, .expect = NULL }, + + /* smctp_network tests */ + + { .hdr = LK(0, 0), .ty = 1, .net = 7, .expect = "any" }, + { .hdr = LK(21, 10), .ty = 1, .net = 2, .expect = "any" }, + + /* remote EID 20 matches, but MCTP_NET_ANY in "remote20" resolved + * to net=1, so lookup doesn't match "remote20" + */ + { .hdr = LK(20, 10), .ty = 1, .net = 3, .expect = "any" }, + + { .hdr = LK(21, 10), .ty = 1, .net = 3, .expect = "remote21net3" }, + { .hdr = LK(21, 10), .ty = 1, .net = 4, .expect = "remote21net4" }, + { .hdr = LK(21, 10), .ty = 1, .net = 5, .expect = "remote21net5" }, + /* "remote21local11" is on net 1, so net 5 falls back to the peer bind */ + { .hdr = LK(21, 11), .ty = 1, .net = 5, .expect = "remote21net5" }, + + { .hdr = LK(99, 10), .ty = 1, .net = 8, .expect = "local10net8" }, + + { .hdr = LK(99, 10), .ty = 1, .net = 9, .expect = "anynet9" }, + { .hdr = LK(0, 0), .ty = 1, .net = 9, .expect = "anynet9" }, + { .hdr = LK(99, 99), .ty = 1, .net = 9, .expect = "anynet9" }, + { .hdr = LK(20, 10), .ty = 1, .net = 9, .expect = "anynet9" }, +}; + +/* Binds to create 
during the lookup tests */ +static const 
struct mctp_test_bind_setup lookup_binds[] = { + /* any address and net, type 1 */ + { .name = "any", .bind_addr = MCTP_ADDR_ANY, + .bind_net = MCTP_NET_ANY, .bind_type = 1, }, + /* local eid 10, net 1 (resolved from MCTP_NET_ANY) */ + { .name = "local10", .bind_addr = 10, + .bind_net = MCTP_NET_ANY, .bind_type = 1, }, + /* local eid 10, net 8 */ + { .name = "local10net8", .bind_addr = 10, + .bind_net = 8, .bind_type = 1, }, + /* any EID, net 9 */ + { .name = "anynet9", .bind_addr = MCTP_ADDR_ANY, + .bind_net = 9, .bind_type = 1, }, + + /* remote eid 20, net 1, any local eid */ + { .name = "remote20", .bind_addr = MCTP_ADDR_ANY, + .bind_net = MCTP_NET_ANY, .bind_type = 1, + .have_peer = true, .peer_addr = 20, .peer_net = MCTP_NET_ANY, }, + + /* remote eid 21, net 1, local eid 11 */ + { .name = "remote21local11", .bind_addr = 11, + .bind_net = MCTP_NET_ANY, .bind_type = 1, + .have_peer = true, .peer_addr = 21, .peer_net = MCTP_NET_ANY, }, + + /* remote eid 21, specific net=3 for connect() */ + { .name = "remote21net3", .bind_addr = MCTP_ADDR_ANY, + .bind_net = MCTP_NET_ANY, .bind_type = 1, + .have_peer = true, .peer_addr = 21, .peer_net = 3, }, + + /* remote eid 21, net 4 for bind, specific net=4 for connect() */ + { .name = "remote21net4", .bind_addr = MCTP_ADDR_ANY, + .bind_net = 4, .bind_type = 1, + .have_peer = true, .peer_addr = 21, .peer_net = 4, }, + + /* remote eid 21, net 5 for bind, specific net=5 for connect() */ + { .name = "remote21net5", .bind_addr = MCTP_ADDR_ANY, + .bind_net = 5, .bind_type = 1, + .have_peer = true, .peer_addr = 21, .peer_net = 5, }, +}; + +static void mctp_bind_lookup_desc(const struct mctp_bind_lookup_test *t, + char *desc) +{ + snprintf(desc, KUNIT_PARAM_DESC_SIZE, + "{src %d dst %d ty %d net %d expect %s}", + t->hdr.src, t->hdr.dest, t->ty, t->net, t->expect); +} + +KUNIT_ARRAY_PARAM(mctp_bind_lookup, mctp_bind_lookup_tests, + mctp_bind_lookup_desc); + +static void mctp_test_bind_lookup(struct kunit *test) +{ + const struct 
mctp_bind_lookup_test *rx; + struct socket *socks[ARRAY_SIZE(lookup_binds)]; + struct sk_buff *skb_pkt = NULL, *skb_sock = NULL; + struct socket *sock_ty0, *sock_expect = NULL; + struct mctp_test_pktqueue tpq; + struct mctp_test_dev *dev; + struct mctp_dst dst; + int rc; + + rx = test->param_value; + + __mctp_route_test_init(test, &dev, &dst, &tpq, &sock_ty0, rx->net); + /* Create all binds */ + for (size_t i = 0; i < ARRAY_SIZE(lookup_binds); i++) { + mctp_test_bind_run(test, &lookup_binds[i], + &rc, &socks[i]); + KUNIT_ASSERT_EQ(test, rc, 0); + + /* Record the expected receive socket */ + if (rx->expect && + strcmp(rx->expect, lookup_binds[i].name) == 0) { + KUNIT_ASSERT_NULL(test, sock_expect); + sock_expect = socks[i]; + } + } + KUNIT_ASSERT_EQ(test, !!sock_expect, !!rx->expect); + + /* Create test message */ + skb_pkt = mctp_test_create_skb_data(&rx->hdr, &rx->ty); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb_pkt); + mctp_test_skb_set_dev(skb_pkt, dev); + mctp_test_pktqueue_init(&tpq); + + rc = mctp_dst_input(&dst, skb_pkt); + if (rx->expect) { + /* Test the message is received on the expected socket */ + KUNIT_EXPECT_EQ(test, rc, 0); + skb_sock = skb_recv_datagram(sock_expect->sk, + MSG_DONTWAIT, &rc); + if (!skb_sock) { + /* Find which socket received it instead */ + for (size_t i = 0; i < ARRAY_SIZE(lookup_binds); i++) { + skb_sock = skb_recv_datagram(socks[i]->sk, + MSG_DONTWAIT, &rc); + if (skb_sock) { + KUNIT_FAIL(test, + "received on incorrect socket '%s', expect '%s'", + lookup_binds[i].name, + rx->expect); + goto cleanup; + } + } + KUNIT_FAIL(test, "no message received"); + } + } else { + KUNIT_EXPECT_NE(test, rc, 0); + } + +cleanup: + kfree_skb(skb_sock); + /* skb_pkt is owned by mctp_dst_input() once passed in */ + + /* Drop all binds */ + for (size_t i = 0; i < ARRAY_SIZE(lookup_binds); i++) + sock_release(socks[i]); + + __mctp_route_test_fini(test, dev, &dst, &tpq, sock_ty0); +} + static struct kunit_case mctp_test_cases[] = { 
KUNIT_CASE_PARAM(mctp_test_fragment, mctp_frag_gen_params), 
KUNIT_CASE_PARAM(mctp_test_rx_input, mctp_rx_input_gen_params), @@ -1429,6 +1616,7 @@ static struct kunit_case mctp_test_cases[] = { KUNIT_CASE(mctp_test_route_gw_loop), KUNIT_CASE_PARAM(mctp_test_route_gw_mtu, mctp_route_gw_mtu_gen_params), KUNIT_CASE(mctp_test_route_gw_output), + KUNIT_CASE_PARAM(mctp_test_bind_lookup, mctp_bind_lookup_gen_params), {} }; diff --git a/net/mctp/test/utils.h b/net/mctp/test/utils.h index c2aaba5188ab82237cb3bcc00d5abf1942753b9d..06bdb6cb5eff6560c7378cf37a1bb17757938e82 100644 --- a/net/mctp/test/utils.h +++ b/net/mctp/test/utils.h @@ -39,6 +39,9 @@ struct mctp_test_bind_setup { bool have_peer; mctp_eid_t peer_addr; int peer_net; + + /* optional name. Used for comparison in "lookup" tests */ + const char *name; }; struct mctp_test_dev *mctp_test_create_dev(void); -- 2.43.0