When receiving an ADD_ADDR right after the 3WHS, the connection will switch to 'fully established'. It means the MPTCP worker will be called to treat two events, in this order: ADD_ADDR_RECEIVED, PM_ESTABLISHED. The MPTCP endpoints cannot have the ID 0, because it is reserved to the address and port used by the initial subflow. To be able to deal with this case in different places, msk->mpc_endpoint_id contains the endpoint ID linked to the initial subflow. This variable was only set when treating the first PM_ESTABLISHED event, after ADD_ADDR_RECEIVED. That's why in fill_local_addresses_vec(), the endpoint addresses were compared with the one of the initial subflow, instead of only comparing the IDs. Instead, msk->mpc_endpoint_id is now set when treating ADD_ADDR_RECEIVED as well, if needed, then the IDs can be compared. To be able to do so, the code doing that is now in a dedicated helper, and called from the functions linked to the two actions. While at it, mptcp_endp_get_local_id() has also been moved up, next to this new helper, because they are linked, and to be able to use it in fill_local_addresses_vec() in the next commit. Reviewed-by: Mat Martineau Signed-off-by: Matthieu Baerts (NGI0) --- net/mptcp/pm_kernel.c | 82 +++++++++++++++++++++++++++------------------------ 1 file changed, 44 insertions(+), 38 deletions(-) diff --git a/net/mptcp/pm_kernel.c b/net/mptcp/pm_kernel.c index 117f842fe18e44f6e887d9044c7b0bb55cbb9084..55dbf89d19b8afeb879f5307c035c855601c6b04 100644 --- a/net/mptcp/pm_kernel.c +++ b/net/mptcp/pm_kernel.c @@ -268,6 +268,46 @@ __lookup_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *info) return NULL; } +static u8 mptcp_endp_get_local_id(struct mptcp_sock *msk, + const struct mptcp_addr_info *addr) +{ + return msk->mpc_endpoint_id == addr->id ? 0 : addr->id; +} + +/* Set mpc_endpoint_id, and send MP_PRIO for ID0 if needed */ +static void mptcp_mpc_endpoint_setup(struct mptcp_sock *msk) +{ + struct mptcp_subflow_context *subflow; + struct mptcp_pm_addr_entry *entry; + struct mptcp_addr_info mpc_addr; + struct pm_nl_pernet *pernet; + bool backup = false; + + /* do lazy endpoint usage accounting for the MPC subflows */ + if (likely(msk->pm.status & BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED)) || + !msk->first) + return; + + subflow = mptcp_subflow_ctx(msk->first); + pernet = pm_nl_get_pernet_from_msk(msk); + + mptcp_local_address((struct sock_common *)msk->first, &mpc_addr); + rcu_read_lock(); + entry = __lookup_addr(pernet, &mpc_addr); + if (entry) { + __clear_bit(entry->addr.id, msk->pm.id_avail_bitmap); + msk->mpc_endpoint_id = entry->addr.id; + backup = !!(entry->flags & MPTCP_PM_ADDR_FLAG_BACKUP); + } + rcu_read_unlock(); + + /* Send MP_PRIO */ + if (backup) + mptcp_pm_send_ack(msk, subflow, true, backup); + + msk->pm.status |= BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED); +} + static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk) { u8 limit_extra_subflows = mptcp_pm_get_limit_extra_subflows(msk); @@ -278,28 +318,7 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk) bool signal_and_subflow = false; struct mptcp_pm_local local; - /* do lazy endpoint usage accounting for the MPC subflows */ - if (unlikely(!(msk->pm.status & BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED))) && msk->first) { - struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(msk->first); - struct mptcp_pm_addr_entry *entry; - struct mptcp_addr_info mpc_addr; - bool backup = false; - - mptcp_local_address((struct sock_common *)msk->first, &mpc_addr); - rcu_read_lock(); - entry = __lookup_addr(pernet, &mpc_addr); - if (entry) { - __clear_bit(entry->addr.id, msk->pm.id_avail_bitmap); - msk->mpc_endpoint_id = entry->addr.id; - backup = !!(entry->flags & MPTCP_PM_ADDR_FLAG_BACKUP); - } - rcu_read_unlock(); - - if (backup) - mptcp_pm_send_ack(msk, subflow, true, backup); - - msk->pm.status |= BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED); - } + mptcp_mpc_endpoint_setup(msk); pr_debug("local %d:%d signal %d:%d subflows %d:%d\n", msk->pm.local_addr_used, endp_subflow_max, @@ -396,12 +415,9 @@ fill_local_addresses_vec_fullmesh(struct mptcp_sock *msk, struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk); struct sock *sk = (struct sock *)msk; struct mptcp_pm_addr_entry *entry; - struct mptcp_addr_info mpc_addr; struct mptcp_pm_local *local; int i = 0; - mptcp_local_address((struct sock_common *)msk, &mpc_addr); - rcu_read_lock(); list_for_each_entry_rcu(entry, &pernet->endp_list, list) { bool is_id0; @@ -417,8 +433,7 @@ fill_local_addresses_vec_fullmesh(struct mptcp_sock *msk, local->flags = entry->flags; local->ifindex = entry->ifindex; - is_id0 = mptcp_addresses_equal(&local->addr, &mpc_addr, - local->addr.port); + is_id0 = local->addr.id == msk->mpc_endpoint_id; if (c_flag_case && (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW)) { @@ -452,12 +467,9 @@ fill_local_addresses_vec_c_flag(struct mptcp_sock *msk, struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk); u8 endp_subflow_max = mptcp_pm_get_endp_subflow_max(msk); struct sock *sk = (struct sock *)msk; - struct mptcp_addr_info mpc_addr; struct mptcp_pm_local *local; int i = 0; - mptcp_local_address((struct sock_common *)msk, &mpc_addr); - while (msk->pm.local_addr_used < endp_subflow_max) { local = &locals[i]; @@ -469,8 +481,7 @@ fill_local_addresses_vec_c_flag(struct mptcp_sock *msk, if (!mptcp_pm_addr_families_match(sk, &local->addr, remote)) continue; - if (mptcp_addresses_equal(&local->addr, &mpc_addr, - local->addr.port)) + if (local->addr.id == msk->mpc_endpoint_id) continue; msk->pm.local_addr_used++; @@ -548,6 +559,7 @@ static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk) remote = msk->pm.remote; mptcp_pm_announce_addr(msk, &remote, true); mptcp_pm_addr_send_ack(msk); + mptcp_mpc_endpoint_setup(msk); if (lookup_subflow_by_daddr(&msk->conn_list, &remote)) return; @@ -935,12 +947,6 @@ int mptcp_pm_nl_add_addr_doit(struct sk_buff *skb, struct genl_info *info) return ret; } -static u8 mptcp_endp_get_local_id(struct mptcp_sock *msk, - const struct mptcp_addr_info *addr) -{ - return msk->mpc_endpoint_id == addr->id ? 0 : addr->id; -} - static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk, const struct mptcp_addr_info *addr, bool force) -- 2.51.0