From: Wei Wang Introduce 2 versions of psp_device_get_locked: 1. psp_device_get_locked_admin(): This version is used for operations that would change the status of the psd, and is currently used for dev-set and key-rotation. 2. psp_device_get_locked(): This is the non-admin version, which is used for broader user-issued operations including: dev-get, rx-assoc, tx-assoc, get-stats. A following commit will implement both of the checks. Signed-off-by: Wei Wang --- Documentation/netlink/specs/psp.yaml | 4 ++-- net/psp/psp-nl-gen.c | 4 ++-- net/psp/psp-nl-gen.h | 2 ++ net/psp/psp.h | 2 +- net/psp/psp_main.c | 7 ++++++- net/psp/psp_nl.c | 28 +++++++++++++++++++++------- 6 files changed, 34 insertions(+), 13 deletions(-) diff --git a/Documentation/netlink/specs/psp.yaml b/Documentation/netlink/specs/psp.yaml index f3a57782d2cf..fe2cdc966604 100644 --- a/Documentation/netlink/specs/psp.yaml +++ b/Documentation/netlink/specs/psp.yaml @@ -195,7 +195,7 @@ operations: - psp-versions-ena reply: attributes: [] - pre: psp-device-get-locked + pre: psp-device-get-locked-admin post: psp-device-unlock - name: dev-change-ntf @@ -214,7 +214,7 @@ operations: reply: attributes: - id - pre: psp-device-get-locked + pre: psp-device-get-locked-admin post: psp-device-unlock - name: key-rotate-ntf diff --git a/net/psp/psp-nl-gen.c b/net/psp/psp-nl-gen.c index 22a48d0fa378..1f5e73e7ccc1 100644 --- a/net/psp/psp-nl-gen.c +++ b/net/psp/psp-nl-gen.c @@ -71,7 +71,7 @@ static const struct genl_split_ops psp_nl_ops[] = { }, { .cmd = PSP_CMD_DEV_SET, - .pre_doit = psp_device_get_locked, + .pre_doit = psp_device_get_locked_admin, .doit = psp_nl_dev_set_doit, .post_doit = psp_device_unlock, .policy = psp_dev_set_nl_policy, @@ -80,7 +80,7 @@ static const struct genl_split_ops psp_nl_ops[] = { }, { .cmd = PSP_CMD_KEY_ROTATE, - .pre_doit = psp_device_get_locked, + .pre_doit = psp_device_get_locked_admin, .doit = psp_nl_key_rotate_doit, .post_doit = psp_device_unlock, .policy = 
psp_key_rotate_nl_policy, diff --git a/net/psp/psp-nl-gen.h b/net/psp/psp-nl-gen.h index 599c5f1c82f2..977355455395 100644 --- a/net/psp/psp-nl-gen.h +++ b/net/psp/psp-nl-gen.h @@ -17,6 +17,8 @@ extern const struct nla_policy psp_keys_nl_policy[PSP_A_KEYS_SPI + 1]; int psp_device_get_locked(const struct genl_split_ops *ops, struct sk_buff *skb, struct genl_info *info); +int psp_device_get_locked_admin(const struct genl_split_ops *ops, + struct sk_buff *skb, struct genl_info *info); int psp_assoc_device_get_locked(const struct genl_split_ops *ops, struct sk_buff *skb, struct genl_info *info); void diff --git a/net/psp/psp.h b/net/psp/psp.h index 9f19137593a0..0f9c4e4e52cb 100644 --- a/net/psp/psp.h +++ b/net/psp/psp.h @@ -14,7 +14,7 @@ extern struct xarray psp_devs; extern struct mutex psp_devs_lock; void psp_dev_free(struct psp_dev *psd); -int psp_dev_check_access(struct psp_dev *psd, struct net *net); +int psp_dev_check_access(struct psp_dev *psd, struct net *net, bool admin); void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd); diff --git a/net/psp/psp_main.c b/net/psp/psp_main.c index 9508b6c38003..82de78a1d6bd 100644 --- a/net/psp/psp_main.c +++ b/net/psp/psp_main.c @@ -27,10 +27,15 @@ struct mutex psp_devs_lock; * psp_dev_check_access() - check if user in a given net ns can access PSP dev * @psd: PSP device structure user is trying to access * @net: net namespace user is in + * @admin: If true, only allow access from @psd's main device's netns, + * for admin operations like config changes and key rotation. + * If false, also allow access from network namespaces that have + * an associated device with @psd, for read-only and association + * management operations. * * Return: 0 if PSP device should be visible in @net, errno otherwise. 
*/ -int psp_dev_check_access(struct psp_dev *psd, struct net *net) +int psp_dev_check_access(struct psp_dev *psd, struct net *net, bool admin) { if (dev_net(psd->main_netdev) == net) return 0; diff --git a/net/psp/psp_nl.c b/net/psp/psp_nl.c index 6afd7707ec12..b988f35412df 100644 --- a/net/psp/psp_nl.c +++ b/net/psp/psp_nl.c @@ -41,7 +41,8 @@ static int psp_nl_reply_send(struct sk_buff *rsp, struct genl_info *info) /* Device stuff */ static struct psp_dev * -psp_device_get_and_lock(struct net *net, struct nlattr *dev_id) +psp_device_get_and_lock(struct net *net, struct nlattr *dev_id, + bool admin) { struct psp_dev *psd; int err; @@ -56,7 +57,7 @@ psp_device_get_and_lock(struct net *net, struct nlattr *dev_id) mutex_lock(&psd->lock); mutex_unlock(&psp_devs_lock); - err = psp_dev_check_access(psd, net); + err = psp_dev_check_access(psd, net, admin); if (err) { mutex_unlock(&psd->lock); return ERR_PTR(err); @@ -65,6 +66,18 @@ psp_device_get_and_lock(struct net *net, struct nlattr *dev_id) return psd; } +int psp_device_get_locked_admin(const struct genl_split_ops *ops, + struct sk_buff *skb, struct genl_info *info) +{ + if (GENL_REQ_ATTR_CHECK(info, PSP_A_DEV_ID)) + return -EINVAL; + + info->user_ptr[0] = psp_device_get_and_lock(genl_info_net(info), + info->attrs[PSP_A_DEV_ID], + true); + return PTR_ERR_OR_ZERO(info->user_ptr[0]); +} + int psp_device_get_locked(const struct genl_split_ops *ops, struct sk_buff *skb, struct genl_info *info) { @@ -72,7 +85,8 @@ int psp_device_get_locked(const struct genl_split_ops *ops, return -EINVAL; info->user_ptr[0] = psp_device_get_and_lock(genl_info_net(info), - info->attrs[PSP_A_DEV_ID]); + info->attrs[PSP_A_DEV_ID], + false); return PTR_ERR_OR_ZERO(info->user_ptr[0]); } @@ -160,7 +174,7 @@ static int psp_nl_dev_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb, struct psp_dev *psd) { - if (psp_dev_check_access(psd, sock_net(rsp->sk))) + if (psp_dev_check_access(psd, sock_net(rsp->sk), false)) return 0; return 
psp_nl_dev_fill(psd, rsp, genl_info_dump(cb)); @@ -305,7 +319,7 @@ int psp_assoc_device_get_locked(const struct genl_split_ops *ops, psd = psp_dev_get_for_sock(socket->sk); if (psd) { - err = psp_dev_check_access(psd, genl_info_net(info)); + err = psp_dev_check_access(psd, genl_info_net(info), false); if (err) { psp_dev_put(psd); psd = NULL; @@ -330,7 +344,7 @@ int psp_assoc_device_get_locked(const struct genl_split_ops *ops, psp_dev_put(psd); } else { - psd = psp_device_get_and_lock(genl_info_net(info), id); + psd = psp_device_get_and_lock(genl_info_net(info), id, false); if (IS_ERR(psd)) { err = PTR_ERR(psd); goto err_sock_put; @@ -573,7 +587,7 @@ static int psp_nl_stats_get_dumpit_one(struct sk_buff *rsp, struct netlink_callback *cb, struct psp_dev *psd) { - if (psp_dev_check_access(psd, sock_net(rsp->sk))) + if (psp_dev_check_access(psd, sock_net(rsp->sk), false)) return 0; return psp_nl_stats_fill(psd, rsp, genl_info_dump(cb)); -- 2.52.0 From: Wei Wang The main purpose of this command is to be able to associate a non-PSP-capable device (e.g. veth or netkit) with a PSP device. One use case is when we create a veth/netkit pair and assign one end inside a netns, while leaving the other end within the default netns, with a real PSP device, e.g. netdevsim or a physical PSP-capable NIC. With this command, we could associate the veth/netkit inside the netns with the PSP device, so the virtual device could act as a PSP-capable device to initiate PSP connections, and perform PSP encryption/decryption on the real PSP device. 
Signed-off-by: Wei Wang --- Documentation/netlink/specs/psp.yaml | 67 +++++- include/net/psp/types.h | 15 ++ include/uapi/linux/psp.h | 13 ++ net/psp/psp-nl-gen.c | 32 +++ net/psp/psp-nl-gen.h | 2 + net/psp/psp_main.c | 21 +- net/psp/psp_nl.c | 309 ++++++++++++++++++++++++++- 7 files changed, 447 insertions(+), 12 deletions(-) diff --git a/Documentation/netlink/specs/psp.yaml b/Documentation/netlink/specs/psp.yaml index fe2cdc966604..336ef19155ff 100644 --- a/Documentation/netlink/specs/psp.yaml +++ b/Documentation/netlink/specs/psp.yaml @@ -13,6 +13,17 @@ definitions: hdr0-aes-gmac-128, hdr0-aes-gmac-256] attribute-sets: + - + name: assoc-dev-info + attributes: + - + name: ifindex + doc: ifindex of an associated network device. + type: u32 + - + name: nsid + doc: Network namespace ID of the associated device. + type: s32 - name: dev attributes: @@ -24,7 +35,9 @@ attribute-sets: min: 1 - name: ifindex - doc: ifindex of the main netdevice linked to the PSP device. + doc: | + ifindex of the main netdevice linked to the PSP device, + or the ifindex to associate with the PSP device. type: u32 - name: psp-versions-cap @@ -38,6 +51,28 @@ attribute-sets: type: u32 enum: version enum-as-flags: true + - + name: assoc-list + doc: List of associated virtual devices. + type: nest + nested-attributes: assoc-dev-info + multi-attr: true + - + name: nsid + doc: | + Network namespace ID for the device to associate/disassociate. + Optional for dev-assoc and dev-disassoc; if not present, the + device is looked up in the caller's network namespace. + type: s32 + - + name: by-association + doc: | + Flag indicating the PSP device is an associated device from a + different network namespace. + Present when in associated namespace, absent when in primary/host + namespace. 
+ type: flag + - name: assoc attributes: @@ -170,6 +205,8 @@ operations: - ifindex - psp-versions-cap - psp-versions-ena + - assoc-list + - by-association pre: psp-device-get-locked post: psp-device-unlock dump: @@ -271,6 +308,34 @@ operations: post: psp-device-unlock dump: reply: *stats-all + - + name: dev-assoc + doc: Associate a network device with a PSP device. + attribute-set: dev + do: + request: + attributes: + - id + - ifindex + - nsid + reply: + attributes: [] + pre: psp-device-get-locked + post: psp-device-unlock + - + name: dev-disassoc + doc: Disassociate a network device from a PSP device. + attribute-set: dev + do: + request: + attributes: + - id + - ifindex + - nsid + reply: + attributes: [] + pre: psp-device-get-locked + post: psp-device-unlock mcast-groups: list: diff --git a/include/net/psp/types.h b/include/net/psp/types.h index 25a9096d4e7d..4bd432ed107a 100644 --- a/include/net/psp/types.h +++ b/include/net/psp/types.h @@ -5,6 +5,7 @@ #include #include +#include struct netlink_ext_ack; @@ -43,9 +44,22 @@ struct psp_dev_config { u32 versions; }; +/** + * struct psp_assoc_dev - wrapper for associated net_device + * @dev_list: list node for psp_dev::assoc_dev_list + * @assoc_dev: the associated net_device + * @dev_tracker: tracker for the net_device reference + */ +struct psp_assoc_dev { + struct list_head dev_list; + struct net_device *assoc_dev; + netdevice_tracker dev_tracker; +}; + /** * struct psp_dev - PSP device struct * @main_netdev: original netdevice of this PSP device + * @assoc_dev_list: list of psp_assoc_dev entries associated with this PSP device * @ops: driver callbacks * @caps: device capabilities * @drv_priv: driver priv pointer @@ -67,6 +81,7 @@ struct psp_dev_config { */ struct psp_dev { struct net_device *main_netdev; + struct list_head assoc_dev_list; struct psp_dev_ops *ops; struct psp_dev_caps *caps; diff --git a/include/uapi/linux/psp.h b/include/uapi/linux/psp.h index a3a336488dc3..1c8899cd4da5 100644 --- 
a/include/uapi/linux/psp.h +++ b/include/uapi/linux/psp.h @@ -17,11 +17,22 @@ enum psp_version { PSP_VERSION_HDR0_AES_GMAC_256, }; +enum { + PSP_A_ASSOC_DEV_INFO_IFINDEX = 1, + PSP_A_ASSOC_DEV_INFO_NSID, + + __PSP_A_ASSOC_DEV_INFO_MAX, + PSP_A_ASSOC_DEV_INFO_MAX = (__PSP_A_ASSOC_DEV_INFO_MAX - 1) +}; + enum { PSP_A_DEV_ID = 1, PSP_A_DEV_IFINDEX, PSP_A_DEV_PSP_VERSIONS_CAP, PSP_A_DEV_PSP_VERSIONS_ENA, + PSP_A_DEV_ASSOC_LIST, + PSP_A_DEV_NSID, + PSP_A_DEV_BY_ASSOCIATION, __PSP_A_DEV_MAX, PSP_A_DEV_MAX = (__PSP_A_DEV_MAX - 1) @@ -74,6 +85,8 @@ enum { PSP_CMD_RX_ASSOC, PSP_CMD_TX_ASSOC, PSP_CMD_GET_STATS, + PSP_CMD_DEV_ASSOC, + PSP_CMD_DEV_DISASSOC, __PSP_CMD_MAX, PSP_CMD_MAX = (__PSP_CMD_MAX - 1) diff --git a/net/psp/psp-nl-gen.c b/net/psp/psp-nl-gen.c index 1f5e73e7ccc1..114299c64423 100644 --- a/net/psp/psp-nl-gen.c +++ b/net/psp/psp-nl-gen.c @@ -53,6 +53,20 @@ static const struct nla_policy psp_get_stats_nl_policy[PSP_A_STATS_DEV_ID + 1] = [PSP_A_STATS_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1), }; +/* PSP_CMD_DEV_ASSOC - do */ +static const struct nla_policy psp_dev_assoc_nl_policy[PSP_A_DEV_NSID + 1] = { + [PSP_A_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1), + [PSP_A_DEV_IFINDEX] = { .type = NLA_U32, }, + [PSP_A_DEV_NSID] = { .type = NLA_S32, }, +}; + +/* PSP_CMD_DEV_DISASSOC - do */ +static const struct nla_policy psp_dev_disassoc_nl_policy[PSP_A_DEV_NSID + 1] = { + [PSP_A_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1), + [PSP_A_DEV_IFINDEX] = { .type = NLA_U32, }, + [PSP_A_DEV_NSID] = { .type = NLA_S32, }, +}; + /* Ops table for psp */ static const struct genl_split_ops psp_nl_ops[] = { { @@ -119,6 +133,24 @@ static const struct genl_split_ops psp_nl_ops[] = { .dumpit = psp_nl_get_stats_dumpit, .flags = GENL_CMD_CAP_DUMP, }, + { + .cmd = PSP_CMD_DEV_ASSOC, + .pre_doit = psp_device_get_locked, + .doit = psp_nl_dev_assoc_doit, + .post_doit = psp_device_unlock, + .policy = psp_dev_assoc_nl_policy, + .maxattr = PSP_A_DEV_NSID, + .flags = GENL_CMD_CAP_DO, + }, + { + .cmd = 
PSP_CMD_DEV_DISASSOC, + .pre_doit = psp_device_get_locked, + .doit = psp_nl_dev_disassoc_doit, + .post_doit = psp_device_unlock, + .policy = psp_dev_disassoc_nl_policy, + .maxattr = PSP_A_DEV_NSID, + .flags = GENL_CMD_CAP_DO, + }, }; static const struct genl_multicast_group psp_nl_mcgrps[] = { diff --git a/net/psp/psp-nl-gen.h b/net/psp/psp-nl-gen.h index 977355455395..4dd0f0f23053 100644 --- a/net/psp/psp-nl-gen.h +++ b/net/psp/psp-nl-gen.h @@ -33,6 +33,8 @@ int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info); int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info); int psp_nl_get_stats_doit(struct sk_buff *skb, struct genl_info *info); int psp_nl_get_stats_dumpit(struct sk_buff *skb, struct netlink_callback *cb); +int psp_nl_dev_assoc_doit(struct sk_buff *skb, struct genl_info *info); +int psp_nl_dev_disassoc_doit(struct sk_buff *skb, struct genl_info *info); enum { PSP_NLGRP_MGMT, diff --git a/net/psp/psp_main.c b/net/psp/psp_main.c index 82de78a1d6bd..178b848989f1 100644 --- a/net/psp/psp_main.c +++ b/net/psp/psp_main.c @@ -37,8 +37,18 @@ struct mutex psp_devs_lock; */ int psp_dev_check_access(struct psp_dev *psd, struct net *net, bool admin) { + struct psp_assoc_dev *entry; + if (dev_net(psd->main_netdev) == net) return 0; + + if (!admin) { + list_for_each_entry(entry, &psd->assoc_dev_list, dev_list) { + if (dev_net(entry->assoc_dev) == net) + return 0; + } + } + return -ENOENT; } @@ -74,6 +84,7 @@ psp_dev_create(struct net_device *netdev, return ERR_PTR(-ENOMEM); psd->main_netdev = netdev; + INIT_LIST_HEAD(&psd->assoc_dev_list); psd->ops = psd_ops; psd->caps = psd_caps; psd->drv_priv = priv_ptr; @@ -121,6 +132,7 @@ void psp_dev_free(struct psp_dev *psd) */ void psp_dev_unregister(struct psp_dev *psd) { + struct psp_assoc_dev *entry, *entry_tmp; struct psp_assoc *pas, *next; mutex_lock(&psp_devs_lock); @@ -140,6 +152,14 @@ void psp_dev_unregister(struct psp_dev *psd) list_for_each_entry_safe(pas, next, &psd->stale_assocs, 
assocs_list) psp_dev_tx_key_del(psd, pas); + list_for_each_entry_safe(entry, entry_tmp, &psd->assoc_dev_list, + dev_list) { + list_del(&entry->dev_list); + rcu_assign_pointer(entry->assoc_dev->psp_dev, NULL); + netdev_put(entry->assoc_dev, &entry->dev_tracker); + kfree(entry); + } + rcu_assign_pointer(psd->main_netdev->psp_dev, NULL); psd->ops = NULL; @@ -361,5 +381,4 @@ static int __init psp_init(void) return genl_register_family(&psp_nl_family); } - subsys_initcall(psp_init); diff --git a/net/psp/psp_nl.c b/net/psp/psp_nl.c index b988f35412df..aa60a8277829 100644 --- a/net/psp/psp_nl.c +++ b/net/psp/psp_nl.c @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -38,6 +39,73 @@ static int psp_nl_reply_send(struct sk_buff *rsp, struct genl_info *info) return genlmsg_reply(rsp, info); } +/** + * psp_nl_multicast_per_ns() - multicast a notification to each unique netns + * @psd: PSP device (must be locked) + * @group: multicast group + * @build_ntf: callback to build an skb for a given netns, or NULL on failure + * @ctx: opaque context passed to @build_ntf + * + * Iterates all unique network namespaces from the associated device list + * plus the main device's netns. For each unique netns, calls @build_ntf + * to construct a notification skb and multicasts it. 
+ */ +static void psp_nl_multicast_per_ns(struct psp_dev *psd, unsigned int group, + struct sk_buff *(*build_ntf)(struct psp_dev *, + struct net *, + void *), + void *ctx) +{ + struct psp_assoc_dev *entry; + struct xarray sent_nets; + struct net *main_net; + struct sk_buff *ntf; + + main_net = dev_net(psd->main_netdev); + xa_init(&sent_nets); + + list_for_each_entry(entry, &psd->assoc_dev_list, dev_list) { + struct net *assoc_net = dev_net(entry->assoc_dev); + int ret; + + if (net_eq(assoc_net, main_net)) + continue; + + ret = xa_insert(&sent_nets, (unsigned long)assoc_net, assoc_net, + GFP_KERNEL); + if (ret == -EBUSY) + continue; + + ntf = build_ntf(psd, assoc_net, ctx); + if (!ntf) + continue; + + genlmsg_multicast_netns(&psp_nl_family, assoc_net, ntf, 0, + group, GFP_KERNEL); + } + xa_destroy(&sent_nets); + + /* Send to main device netns */ + ntf = build_ntf(psd, main_net, ctx); + if (!ntf) + return; + genlmsg_multicast_netns(&psp_nl_family, main_net, ntf, 0, group, + GFP_KERNEL); +} + +static struct sk_buff *psp_nl_clone_ntf(struct psp_dev *psd, struct net *net, + void *ctx) +{ + return skb_clone(ctx, GFP_KERNEL); +} + +static void psp_nl_multicast_all_ns(struct psp_dev *psd, struct sk_buff *ntf, + unsigned int group) +{ + psp_nl_multicast_per_ns(psd, group, psp_nl_clone_ntf, ntf); + nlmsg_free(ntf); +} + /* Device stuff */ static struct psp_dev * @@ -102,11 +170,74 @@ psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb, sockfd_put(socket); } +static bool psp_has_assoc_dev_in_ns(struct psp_dev *psd, struct net *net) +{ + struct psp_assoc_dev *entry; + + list_for_each_entry(entry, &psd->assoc_dev_list, dev_list) { + if (dev_net(entry->assoc_dev) == net) + return true; + } + + return false; +} + +static int psp_nl_fill_assoc_dev_list(struct psp_dev *psd, struct sk_buff *rsp, + struct net *cur_net, + struct net *filter_net) +{ + struct psp_assoc_dev *entry; + struct net *dev_net_ns; + struct nlattr *nest; + int nsid; + + 
list_for_each_entry(entry, &psd->assoc_dev_list, dev_list) { + dev_net_ns = dev_net(entry->assoc_dev); + + if (filter_net && dev_net_ns != filter_net) + continue; + + /* When filtering by namespace, all devices are in the caller's + * namespace so nsid is always NETNSA_NSID_NOT_ASSIGNED (-1). + * Otherwise, calculate the nsid relative to cur_net. + */ + nsid = filter_net ? NETNSA_NSID_NOT_ASSIGNED : + peernet2id_alloc(cur_net, dev_net_ns, + GFP_KERNEL); + + nest = nla_nest_start(rsp, PSP_A_DEV_ASSOC_LIST); + if (!nest) + return -1; + + if (nla_put_u32(rsp, PSP_A_ASSOC_DEV_INFO_IFINDEX, + entry->assoc_dev->ifindex) || + nla_put_s32(rsp, PSP_A_ASSOC_DEV_INFO_NSID, nsid)) { + nla_nest_cancel(rsp, nest); + return -1; + } + + nla_nest_end(rsp, nest); + } + + return 0; +} + static int psp_nl_dev_fill(struct psp_dev *psd, struct sk_buff *rsp, const struct genl_info *info) { + struct net *cur_net; void *hdr; + int err; + + cur_net = genl_info_net(info); + + /* Skip this device if we're in an associated netns but have no + * associated devices in cur_net + */ + if (cur_net != dev_net(psd->main_netdev) && + !psp_has_assoc_dev_in_ns(psd, cur_net)) + return 0; hdr = genlmsg_iput(rsp, info); if (!hdr) @@ -118,6 +249,22 @@ psp_nl_dev_fill(struct psp_dev *psd, struct sk_buff *rsp, nla_put_u32(rsp, PSP_A_DEV_PSP_VERSIONS_ENA, psd->config.versions)) goto err_cancel_msg; + if (cur_net == dev_net(psd->main_netdev)) { + /* Primary device - dump assoc list */ + err = psp_nl_fill_assoc_dev_list(psd, rsp, cur_net, NULL); + if (err) + goto err_cancel_msg; + } else { + /* In netns: set by-association flag and dump filtered + * assoc list containing only devices in cur_net + */ + if (nla_put_flag(rsp, PSP_A_DEV_BY_ASSOCIATION)) + goto err_cancel_msg; + err = psp_nl_fill_assoc_dev_list(psd, rsp, cur_net, cur_net); + if (err) + goto err_cancel_msg; + } + genlmsg_end(rsp, hdr); return 0; @@ -126,27 +273,34 @@ psp_nl_dev_fill(struct psp_dev *psd, struct sk_buff *rsp, return -EMSGSIZE; } -void 
psp_nl_notify_dev(struct psp_dev *psd, u32 cmd) +static struct sk_buff *psp_nl_build_dev_ntf(struct psp_dev *psd, + struct net *net, void *ctx) { + u32 cmd = *(u32 *)ctx; struct genl_info info; struct sk_buff *ntf; - if (!genl_has_listeners(&psp_nl_family, dev_net(psd->main_netdev), - PSP_NLGRP_MGMT)) - return; + if (!genl_has_listeners(&psp_nl_family, net, PSP_NLGRP_MGMT)) + return NULL; ntf = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); if (!ntf) - return; + return NULL; genl_info_init_ntf(&info, &psp_nl_family, cmd); + genl_info_net_set(&info, net); if (psp_nl_dev_fill(psd, ntf, &info)) { nlmsg_free(ntf); - return; + return NULL; } - genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf, - 0, PSP_NLGRP_MGMT, GFP_KERNEL); + return ntf; +} + +void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd) +{ + psp_nl_multicast_per_ns(psd, PSP_NLGRP_MGMT, + psp_nl_build_dev_ntf, &cmd); } int psp_nl_dev_get_doit(struct sk_buff *req, struct genl_info *info) @@ -280,8 +434,9 @@ int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info) psd->stats.rotations++; nlmsg_end(ntf, (struct nlmsghdr *)ntf->data); - genlmsg_multicast_netns(&psp_nl_family, dev_net(psd->main_netdev), ntf, - 0, PSP_NLGRP_USE, GFP_KERNEL); + + psp_nl_multicast_all_ns(psd, ntf, PSP_NLGRP_USE); + return psp_nl_reply_send(rsp, info); err_free_ntf: @@ -291,6 +446,140 @@ int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info) return err; } +int psp_nl_dev_assoc_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct psp_dev *psd = info->user_ptr[0]; + struct psp_assoc_dev *psp_assoc_dev; + struct net_device *assoc_dev; + u32 assoc_ifindex; + struct sk_buff *rsp; + struct net *net; + int nsid; + + if (GENL_REQ_ATTR_CHECK(info, PSP_A_DEV_IFINDEX)) + return -EINVAL; + + if (info->attrs[PSP_A_DEV_NSID]) { + nsid = nla_get_s32(info->attrs[PSP_A_DEV_NSID]); + + net = get_net_ns_by_id(genl_info_net(info), nsid); + if (!net) { + NL_SET_BAD_ATTR(info->extack, + 
info->attrs[PSP_A_DEV_NSID]); + return -EINVAL; + } + } else { + net = get_net(genl_info_net(info)); + } + + psp_assoc_dev = kzalloc(sizeof(*psp_assoc_dev), GFP_KERNEL); + if (!psp_assoc_dev) { + put_net(net); + return -ENOMEM; + } + + assoc_ifindex = nla_get_u32(info->attrs[PSP_A_DEV_IFINDEX]); + assoc_dev = netdev_get_by_index(net, assoc_ifindex, + &psp_assoc_dev->dev_tracker, + GFP_KERNEL); + if (!assoc_dev) { + put_net(net); + kfree(psp_assoc_dev); + NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_DEV_IFINDEX]); + return -ENODEV; + } + + /* Check if device is already associated with a PSP device */ + if (rcu_access_pointer(assoc_dev->psp_dev)) { + NL_SET_ERR_MSG(info->extack, + "Device already associated with a PSP device"); + netdev_put(assoc_dev, &psp_assoc_dev->dev_tracker); + put_net(net); + kfree(psp_assoc_dev); + return -EBUSY; + } + + psp_assoc_dev->assoc_dev = assoc_dev; + rsp = psp_nl_reply_new(info); + if (!rsp) { + netdev_put(assoc_dev, &psp_assoc_dev->dev_tracker); + put_net(net); + kfree(psp_assoc_dev); + return -ENOMEM; + } + + rcu_assign_pointer(assoc_dev->psp_dev, psd); + list_add_tail(&psp_assoc_dev->dev_list, &psd->assoc_dev_list); + + put_net(net); + + psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF); + + return psp_nl_reply_send(rsp, info); +} + +int psp_nl_dev_disassoc_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct psp_assoc_dev *entry, *found = NULL; + struct psp_dev *psd = info->user_ptr[0]; + u32 assoc_ifindex; + struct sk_buff *rsp; + struct net *net; + int nsid; + + if (GENL_REQ_ATTR_CHECK(info, PSP_A_DEV_IFINDEX)) + return -EINVAL; + + if (info->attrs[PSP_A_DEV_NSID]) { + nsid = nla_get_s32(info->attrs[PSP_A_DEV_NSID]); + + net = get_net_ns_by_id(genl_info_net(info), nsid); + if (!net) { + NL_SET_BAD_ATTR(info->extack, + info->attrs[PSP_A_DEV_NSID]); + return -EINVAL; + } + } else { + net = get_net(genl_info_net(info)); + } + + assoc_ifindex = nla_get_u32(info->attrs[PSP_A_DEV_IFINDEX]); + + /* Search the association list 
by ifindex and netns */ + list_for_each_entry(entry, &psd->assoc_dev_list, dev_list) { + if (entry->assoc_dev->ifindex == assoc_ifindex && + dev_net(entry->assoc_dev) == net) { + found = entry; + break; + } + } + + if (!found) { + put_net(net); + NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_DEV_IFINDEX]); + return -ENODEV; + } + + rsp = psp_nl_reply_new(info); + if (!rsp) { + put_net(net); + return -ENOMEM; + } + + /* Notify before removal */ + psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF); + + /* Remove from the association list */ + list_del(&found->dev_list); + rcu_assign_pointer(found->assoc_dev->psp_dev, NULL); + netdev_put(found->assoc_dev, &found->dev_tracker); + kfree(found); + + put_net(net); + + return psp_nl_reply_send(rsp, info); +} + /* Key etc. */ int psp_assoc_device_get_locked(const struct genl_split_ops *ops, -- 2.52.0 From: Wei Wang Add a new netdev event for dev unregister and handle the removal of this dev from psp->assoc_dev_list, upon the first successful dev-assoc operation. 
Signed-off-by: Wei Wang --- net/psp/psp.h | 1 + net/psp/psp_main.c | 70 ++++++++++++++++++++++++++++++++++++++++++++++ net/psp/psp_nl.c | 7 +++++ 3 files changed, 78 insertions(+) diff --git a/net/psp/psp.h b/net/psp/psp.h index 0f9c4e4e52cb..fd7457dedd30 100644 --- a/net/psp/psp.h +++ b/net/psp/psp.h @@ -15,6 +15,7 @@ extern struct mutex psp_devs_lock; void psp_dev_free(struct psp_dev *psd); int psp_dev_check_access(struct psp_dev *psd, struct net *net, bool admin); +void psp_attach_netdev_notifier(void); void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd); diff --git a/net/psp/psp_main.c b/net/psp/psp_main.c index 178b848989f1..db4593e76fa7 100644 --- a/net/psp/psp_main.c +++ b/net/psp/psp_main.c @@ -375,10 +375,80 @@ int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv) } EXPORT_SYMBOL(psp_dev_rcv); +static void psp_dev_disassoc_one(struct psp_dev *psd, struct net_device *dev) +{ + struct psp_assoc_dev *entry, *tmp; + + list_for_each_entry_safe(entry, tmp, &psd->assoc_dev_list, dev_list) { + if (entry->assoc_dev == dev) { + list_del(&entry->dev_list); + rcu_assign_pointer(entry->assoc_dev->psp_dev, NULL); + netdev_put(entry->assoc_dev, &entry->dev_tracker); + kfree(entry); + return; + } + } +} + +static int psp_netdev_event(struct notifier_block *nb, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct psp_dev *psd; + + if (event != NETDEV_UNREGISTER) + return NOTIFY_DONE; + + rcu_read_lock(); + psd = rcu_dereference(dev->psp_dev); + if (psd && psp_dev_tryget(psd)) { + rcu_read_unlock(); + mutex_lock(&psd->lock); + psp_dev_disassoc_one(psd, dev); + mutex_unlock(&psd->lock); + psp_dev_put(psd); + } else { + rcu_read_unlock(); + } + + return NOTIFY_DONE; +} + +static struct notifier_block psp_netdev_notifier = { + .notifier_call = psp_netdev_event, +}; + +static bool psp_notifier_registered; + +/** + * psp_attach_netdev_notifier() - register netdev notifier on first use + * + * 
Register the netdevice notifier when the first device association + * is created. In many installations no associations will be created and + * the notifier won't be needed. + * + * Must be called without psd->lock held, due to lock ordering: + * rtnl_lock -> psd->lock (the notifier callback runs under rtnl_lock + * and takes psd->lock). + */ +void psp_attach_netdev_notifier(void) +{ + if (READ_ONCE(psp_notifier_registered)) + return; + + mutex_lock(&psp_devs_lock); + if (!psp_notifier_registered) { + register_netdevice_notifier(&psp_netdev_notifier); + WRITE_ONCE(psp_notifier_registered, true); + } + mutex_unlock(&psp_devs_lock); +} + static int __init psp_init(void) { mutex_init(&psp_devs_lock); return genl_register_family(&psp_nl_family); } + subsys_initcall(psp_init); diff --git a/net/psp/psp_nl.c b/net/psp/psp_nl.c index aa60a8277829..44e00add4211 100644 --- a/net/psp/psp_nl.c +++ b/net/psp/psp_nl.c @@ -515,6 +515,13 @@ int psp_nl_dev_assoc_doit(struct sk_buff *skb, struct genl_info *info) psp_nl_notify_dev(psd, PSP_CMD_DEV_CHANGE_NTF); + /* Register netdev notifier for assoc cleanup on success. + * Must drop psd->lock to ensure lock ordering: rtnl_lock -> psd->lock + */ + mutex_unlock(&psd->lock); + psp_attach_netdev_notifier(); + mutex_lock(&psd->lock); + return psp_nl_reply_send(rsp, info); } -- 2.52.0 From: Wei Wang Add nk_redirect.bpf.c, a BPF program that forwards skbs matching some IPv6 prefix received on eth0 ifindex to a specified dev ifindex. bpf_redirect_neigh() is used to make sure neighbor lookup is performed and proper MAC addr is being used. 
Signed-off-by: Wei Wang Reviewed-by: Bobby Eshleman Tested-by: Bobby Eshleman --- .../drivers/net/hw/nk_redirect.bpf.c | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 tools/testing/selftests/drivers/net/hw/nk_redirect.bpf.c diff --git a/tools/testing/selftests/drivers/net/hw/nk_redirect.bpf.c b/tools/testing/selftests/drivers/net/hw/nk_redirect.bpf.c new file mode 100644 index 000000000000..7ac9ffd50f15 --- /dev/null +++ b/tools/testing/selftests/drivers/net/hw/nk_redirect.bpf.c @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * BPF program for redirecting traffic using bpf_redirect_neigh(). + * Unlike bpf_redirect() which preserves L2 headers, bpf_redirect_neigh() + * performs neighbor lookup and fills in the correct L2 addresses for the + * target interface. This is necessary when redirecting across different + * device types (e.g., from netdevsim to netkit). + */ +#include +#include +#include +#include +#include +#include +#include + +#define TC_ACT_OK 0 +#define ETH_P_IPV6 0x86DD + +#define ctx_ptr(field) ((void *)(long)(field)) + +#define v6_p64_equal(a, b) (a.s6_addr32[0] == b.s6_addr32[0] && \ + a.s6_addr32[1] == b.s6_addr32[1]) + +volatile __u32 redirect_ifindex; +volatile __u8 ipv6_prefix[16]; + +SEC("tc/ingress") +int tc_redirect(struct __sk_buff *skb) +{ + void *data_end = ctx_ptr(skb->data_end); + void *data = ctx_ptr(skb->data); + struct in6_addr *match_prefix; + struct ipv6hdr *ip6h; + struct ethhdr *eth; + + match_prefix = (struct in6_addr *)ipv6_prefix; + + if (skb->protocol != bpf_htons(ETH_P_IPV6)) + return TC_ACT_OK; + + eth = data; + if ((void *)(eth + 1) > data_end) + return TC_ACT_OK; + + ip6h = data + sizeof(struct ethhdr); + if ((void *)(ip6h + 1) > data_end) + return TC_ACT_OK; + + if (!v6_p64_equal(ip6h->daddr, (*match_prefix))) + return TC_ACT_OK; + + /* + * Use bpf_redirect_neigh() to perform neighbor lookup and fill in + * correct L2 addresses for the target interface. 
+ */ + return bpf_redirect_neigh(redirect_ifindex, NULL, 0, 0); +} + +char __license[] SEC("license") = "GPL"; -- 2.52.0 From: Wei Wang Add a new param to NetDrvContEnv to add an additional bpf redirect program on nk_host to redirect traffic to the psp_dev_local. The topology looks like this: Host NS: psp_dev_local <---> nk_host | | | | (netkit pair) | | Remote NS: psp_dev_peer Guest NS: nk_guest (responder) (PSP tests) Add following tests for dev-assoc/dev-disassoc functionality: 1. Test the output of `./tools/net/ynl/pyynl/cli.py --spec Documentation/netlink/specs/psp.yaml --dump dev-get` in both default and the guest netns. 2. Test the case where we associate netkit with psp_dev_local, and send PSP traffic from nk_guest to psp_dev_peer in 2 different netns. 3. Test to make sure the key rotation notification is sent to the netns for associated dev as well 4. Test to make sure the dev change notification is sent to the netns for associated dev as well 5. Test for dev-assoc/dev-disassoc without nsid parameter. 6. Test the deletion of nk_guest in client netns, and proper cleanup in the assoc-list for psp dev. 
Signed-off-by: Wei Wang --- tools/testing/selftests/drivers/net/config | 1 + .../selftests/drivers/net/lib/py/env.py | 54 +- tools/testing/selftests/drivers/net/psp.py | 498 ++++++++++++++++-- 3 files changed, 518 insertions(+), 35 deletions(-) diff --git a/tools/testing/selftests/drivers/net/config b/tools/testing/selftests/drivers/net/config index 77ccf83d87e0..cdde8234dc07 100644 --- a/tools/testing/selftests/drivers/net/config +++ b/tools/testing/selftests/drivers/net/config @@ -7,4 +7,5 @@ CONFIG_NETCONSOLE=m CONFIG_NETCONSOLE_DYNAMIC=y CONFIG_NETCONSOLE_EXTENDED_LOG=y CONFIG_NETDEVSIM=m +CONFIG_NETKIT=y CONFIG_XDP_SOCKETS=y diff --git a/tools/testing/selftests/drivers/net/lib/py/env.py b/tools/testing/selftests/drivers/net/lib/py/env.py index ccff345fe1c1..2b88e4738fae 100644 --- a/tools/testing/selftests/drivers/net/lib/py/env.py +++ b/tools/testing/selftests/drivers/net/lib/py/env.py @@ -2,6 +2,7 @@ import ipaddress import os +import re import time import json from pathlib import Path @@ -327,7 +328,7 @@ class NetDrvContEnv(NetDrvEpEnv): +---------------+ """ - def __init__(self, src_path, rxqueues=1, **kwargs): + def __init__(self, src_path, rxqueues=1, install_tx_redirect_bpf=False, **kwargs): self.netns = None self._nk_host_ifname = None self._nk_guest_ifname = None @@ -338,6 +339,8 @@ class NetDrvContEnv(NetDrvEpEnv): self._init_ns_attached = False self._old_fwd = None self._old_accept_ra = None + self._nk_host_tc_attached = False + self._nk_host_bpf_prog_pref = None super().__init__(src_path, **kwargs) @@ -388,7 +391,13 @@ class NetDrvContEnv(NetDrvEpEnv): self._setup_ns() self._attach_bpf() + if install_tx_redirect_bpf: + self._attach_tx_redirect_bpf() + def __del__(self): + if self._nk_host_tc_attached: + cmd(f"tc filter del dev {self._nk_host_ifname} ingress pref {self._nk_host_bpf_prog_pref}", fail=False) + self._nk_host_tc_attached = False if self._tc_attached: cmd(f"tc filter del dev {self.ifname} ingress pref {self._bpf_prog_pref}") 
self._tc_attached = False @@ -496,3 +505,46 @@ class NetDrvContEnv(NetDrvEpEnv): value = ipv6_bytes + ifindex_bytes value_hex = ' '.join(f'{b:02x}' for b in value) bpftool(f"map update id {bss_map_id} key hex 00 00 00 00 value hex {value_hex}") + + def _attach_tx_redirect_bpf(self): + """ + Attach BPF program on nk_host ingress to redirect TX traffic. + + Packets from nk_guest destined for the nsim network arrive at nk_host + via the netkit pair. This BPF program redirects them to the physical + interface so they can reach the remote peer. + """ + bpf_obj = self.test_dir / "nk_redirect.bpf.o" + if not bpf_obj.exists(): + raise KsftSkipEx("BPF prog nk_redirect.bpf.o not found") + + cmd(f"tc qdisc add dev {self._nk_host_ifname} clsact") + + cmd(f"tc filter add dev {self._nk_host_ifname} ingress bpf obj {bpf_obj} sec tc/ingress direct-action") + self._nk_host_tc_attached = True + + tc_info = cmd(f"tc filter show dev {self._nk_host_ifname} ingress").stdout + match = re.search(r'pref (\d+).*nk_redirect\.bpf.*id (\d+)', tc_info) + if not match: + raise Exception("Failed to get TX redirect BPF prog ID") + self._nk_host_bpf_prog_pref = int(match.group(1)) + nk_host_bpf_prog_id = int(match.group(2)) + + prog_info = bpftool(f"prog show id {nk_host_bpf_prog_id}", json=True) + map_ids = prog_info.get("map_ids", []) + + bss_map_id = None + for map_id in map_ids: + map_info = bpftool(f"map show id {map_id}", json=True) + if map_info.get("name").endswith("bss"): + bss_map_id = map_id + + if bss_map_id is None: + raise Exception("Failed to find TX redirect BPF .bss map") + + ipv6_addr = ipaddress.IPv6Address(self.nsim_v6_pfx) + ipv6_bytes = ipv6_addr.packed + ifindex_bytes = self.ifindex.to_bytes(4, byteorder='little') + value = ipv6_bytes + ifindex_bytes + value_hex = ' '.join(f'{b:02x}' for b in value) + bpftool(f"map update id {bss_map_id} key hex 00 00 00 00 value hex {value_hex}") diff --git a/tools/testing/selftests/drivers/net/psp.py 
b/tools/testing/selftests/drivers/net/psp.py index 864d9fce1094..2d2c3cdabf24 100755 --- a/tools/testing/selftests/drivers/net/psp.py +++ b/tools/testing/selftests/drivers/net/psp.py @@ -5,6 +5,7 @@ import errno import fcntl +import os import socket import struct import termios @@ -14,9 +15,12 @@ from lib.py import defer from lib.py import ksft_run, ksft_exit, ksft_pr from lib.py import ksft_true, ksft_eq, ksft_ne, ksft_gt, ksft_raises from lib.py import ksft_not_none -from lib.py import KsftSkipEx -from lib.py import NetDrvEpEnv, PSPFamily, NlError -from lib.py import bkg, rand_port, wait_port_listen +from lib.py import ksft_variants, KsftNamedVariant +from lib.py import KsftSkipEx, KsftFailEx +from lib.py import NetDrvEpEnv, NetDrvContEnv, PSPFamily, NlError +from lib.py import NetNSEnter +from lib.py import bkg, cmd, rand_port, wait_port_listen +from lib.py import ip def _get_outq(s): @@ -117,11 +121,13 @@ def _get_stat(cfg, key): # Test case boiler plate # -def _init_psp_dev(cfg): +def _init_psp_dev(cfg, use_psp_ifindex=False): if not hasattr(cfg, 'psp_dev_id'): # Figure out which local device we are testing against + # For NetDrvContEnv: use psp_ifindex instead of ifindex + target_ifindex = cfg.psp_ifindex if use_psp_ifindex else cfg.ifindex for dev in cfg.pspnl.dev_get({}, dump=True): - if dev['ifindex'] == cfg.ifindex: + if dev['ifindex'] == target_ifindex: cfg.psp_info = dev cfg.psp_dev_id = cfg.psp_info['id'] break @@ -394,6 +400,301 @@ def _data_basic_send(cfg, version, ipver): _close_psp_conn(cfg, s) +def _data_basic_send_netkit_psp_assoc(cfg, version, ipver): + """ + Test basic data send with netkit interface associated with PSP dev. 
+ """ + + _init_psp_dev(cfg, True) + psp_dev_id_for_assoc = cfg.psp_dev_id + + # Associate PSP device with nk_guest interface (in guest namespace) + nk_guest_dev = ip(f"link show dev {cfg._nk_guest_ifname}", json=True, ns=cfg.netns)[0] + nk_guest_ifindex = nk_guest_dev['ifindex'] + + cfg.pspnl.dev_assoc({'id': psp_dev_id_for_assoc, 'ifindex': nk_guest_ifindex, 'nsid': cfg.psp_dev_peer_nsid}) + + # Test connectivity in both directions before PSP operations + remote_addr = cfg.remote_addr_v["6"] # remote peer address + nk_guest_addr = cfg.nk_guest_ipv6 # nk_guest address + + # Check if assoc-list contains nk_guest + dev_info = cfg.pspnl.dev_get({'id': psp_dev_id_for_assoc}) + + if 'assoc-list' in dev_info: + found = False + for assoc in dev_info['assoc-list']: + if assoc['ifindex'] == nk_guest_ifindex and assoc['nsid'] == cfg.psp_dev_peer_nsid: + found = True + break + ksft_true(found, "Associated device not found in dev_get() response") + else: + raise RuntimeError("No assoc-list in dev_get() response after association") + + # Enter guest namespace (netns) to run PSP test + with NetNSEnter(cfg.netns.name): + cfg.pspnl = PSPFamily() + + s = _make_psp_conn(cfg, version, ipver) + + rx_assoc = cfg.pspnl.rx_assoc({"version": version, + "dev-id": cfg.psp_dev_id, + "sock-fd": s.fileno()}) + rx = rx_assoc['rx-key'] + tx = _spi_xchg(s, rx) + + cfg.pspnl.tx_assoc({"dev-id": cfg.psp_dev_id, + "version": version, + "tx-key": tx, + "sock-fd": s.fileno()}) + + data_len = _send_careful(cfg, s, 100) + _check_data_rx(cfg, data_len) + _close_psp_conn(cfg, s) + + # Clean up - back in host namespace + cfg.pspnl = PSPFamily() + cfg.pspnl.dev_disassoc({'id': psp_dev_id_for_assoc, 'ifindex': nk_guest_ifindex, 'nsid': cfg.psp_dev_peer_nsid}) + + del cfg.psp_dev_id + del cfg.psp_info + + +def _key_rotation_notify_multi_ns_netkit(cfg, version, ipver): + """ Test key rotation notifications across multiple namespaces using netkit """ + _init_psp_dev(cfg, True) + psp_dev_id_for_assoc = 
cfg.psp_dev_id + + # Associate PSP device with nk_guest interface (in guest namespace) + nk_guest_dev = ip(f"link show dev {cfg._nk_guest_ifname}", json=True, ns=cfg.netns)[0] + nk_guest_ifindex = nk_guest_dev['ifindex'] + + cfg.pspnl.dev_assoc({'id': psp_dev_id_for_assoc, 'ifindex': nk_guest_ifindex, 'nsid': cfg.psp_dev_peer_nsid}) + + # Create listener in guest namespace; socket stays bound to that ns + with NetNSEnter(cfg.netns.name): + peer_pspnl = PSPFamily() + peer_pspnl.ntf_subscribe('use') + + # Create listener in main namespace + main_pspnl = PSPFamily() + main_pspnl.ntf_subscribe('use') + + # Trigger key rotation on the PSP device + cfg.pspnl.key_rotate({"id": psp_dev_id_for_assoc}) + + # Poll both sockets from main thread + for pspnl, label in [(main_pspnl, "main"), (peer_pspnl, "guest")]: + for i in range(100): + pspnl.check_ntf() + + try: + msg = pspnl.async_msg_queue.get_nowait() + break + except Exception: + pass + + time.sleep(0.1) + else: + raise KsftFailEx(f"No key rotation notification received in {label} namespace") + + ksft_true(msg['msg'].get('id') == psp_dev_id_for_assoc, + f"Key rotation notification for correct device not found in {label} namespace") + + # Clean up + cfg.pspnl.dev_disassoc({'id': psp_dev_id_for_assoc, 'ifindex': nk_guest_ifindex, 'nsid': cfg.psp_dev_peer_nsid}) + del cfg.psp_dev_id + del cfg.psp_info + + +def _dev_change_notify_multi_ns_netkit(cfg, version, ipver): + """ Test dev_change notifications across multiple namespaces using netkit """ + _init_psp_dev(cfg, True) + psp_dev_id_for_assoc = cfg.psp_dev_id + + # Associate PSP device with nk_guest interface (in guest namespace) + nk_guest_dev = ip(f"link show dev {cfg._nk_guest_ifname}", json=True, ns=cfg.netns)[0] + nk_guest_ifindex = nk_guest_dev['ifindex'] + + cfg.pspnl.dev_assoc({'id': psp_dev_id_for_assoc, 'ifindex': nk_guest_ifindex, 'nsid': cfg.psp_dev_peer_nsid}) + + # Create listener in guest namespace; socket stays bound to that ns + with 
NetNSEnter(cfg.netns.name): + peer_pspnl = PSPFamily() + peer_pspnl.ntf_subscribe('mgmt') + + # Create listener in main namespace + main_pspnl = PSPFamily() + main_pspnl.ntf_subscribe('mgmt') + + # Trigger dev_change by calling dev_set (notification is always sent) + cfg.pspnl.dev_set({'id': psp_dev_id_for_assoc, 'psp-versions-ena': cfg.psp_info['psp-versions-cap']}) + + # Poll both sockets from main thread + for pspnl, label in [(main_pspnl, "main"), (peer_pspnl, "guest")]: + for i in range(100): + pspnl.check_ntf() + + try: + msg = pspnl.async_msg_queue.get_nowait() + break + except Exception: + pass + + time.sleep(0.1) + else: + raise KsftFailEx(f"No dev_change notification received in {label} namespace") + + ksft_true(msg['msg'].get('id') == psp_dev_id_for_assoc, + f"Dev_change notification for correct device not found in {label} namespace") + + # Clean up + cfg.pspnl.dev_disassoc({'id': psp_dev_id_for_assoc, 'ifindex': nk_guest_ifindex, 'nsid': cfg.psp_dev_peer_nsid}) + del cfg.psp_dev_id + del cfg.psp_info + + +def _psp_dev_get_check_netkit_psp_assoc(cfg, version, ipver): + """ Check psp dev-get output with netkit interface associated with PSP dev """ + + _init_psp_dev(cfg, True) + psp_dev_id_for_assoc = cfg.psp_dev_id + + # Associate PSP device with nk_guest interface (in guest namespace) + nk_guest_dev = ip(f"link show dev {cfg._nk_guest_ifname}", json=True, ns=cfg.netns)[0] + nk_guest_ifindex = nk_guest_dev['ifindex'] + + cfg.pspnl.dev_assoc({'id': psp_dev_id_for_assoc, 'ifindex': nk_guest_ifindex, 'nsid': cfg.psp_dev_peer_nsid}) + + # Check 1: In default netns, verify dev-get has correct ifindex and assoc-list + dev_info = cfg.pspnl.dev_get({'id': psp_dev_id_for_assoc}) + + # Verify the PSP device has the correct ifindex + ksft_eq(dev_info['ifindex'], cfg.psp_ifindex) + + # Verify assoc-list exists and contains the associated nk_guest with correct ifindex and nsid + ksft_true('assoc-list' in dev_info, "No assoc-list in dev_get() response after 
association") + found = False + for assoc in dev_info['assoc-list']: + if assoc['ifindex'] == nk_guest_ifindex and assoc['nsid'] == cfg.psp_dev_peer_nsid: + found = True + break + ksft_true(found, "Associated device not found in assoc-list with correct ifindex and nsid") + + # Check 2: In guest netns, verify dev-get has assoc-list with nk_guest device + with NetNSEnter(cfg.netns.name): + peer_pspnl = PSPFamily() + + # Dump all devices in the guest namespace + peer_devices = peer_pspnl.dev_get({}, dump=True) + + # Find the device with by-association flag + peer_dev = None + for dev in peer_devices: + if dev.get('by-association'): + peer_dev = dev + break + + ksft_not_none(peer_dev, "No PSP device found with by-association flag in guest netns") + + # Verify assoc-list contains the nk_guest device + ksft_true('assoc-list' in peer_dev and len(peer_dev['assoc-list']) > 0, + "Guest device should have assoc-list with local devices") + + # Verify the assoc-list contains nk_guest ifindex with nsid=-1 (same namespace) + found = False + for assoc in peer_dev['assoc-list']: + if assoc['ifindex'] == nk_guest_ifindex: + ksft_eq(assoc['nsid'], -1, + "nsid should be -1 (NETNSA_NSID_NOT_ASSIGNED) for same-namespace device") + found = True + break + ksft_true(found, "nk_guest ifindex not found in assoc-list") + + # Clean up + cfg.pspnl.dev_disassoc({'id': psp_dev_id_for_assoc, 'ifindex': nk_guest_ifindex, 'nsid': cfg.psp_dev_peer_nsid}) + + del cfg.psp_dev_id + del cfg.psp_info + + +def _dev_assoc_no_nsid(cfg): + """ Test dev-assoc and dev-disassoc without nsid attribute """ + _init_psp_dev(cfg, True) + psp_dev_id = cfg.psp_dev_id + + # Get nk_host's ifindex (in host namespace, same as caller) + nk_host_dev = ip(f"link show dev {cfg._nk_host_ifname}", json=True)[0] + nk_host_ifindex = nk_host_dev['ifindex'] + + # Associate without nsid - should look up ifindex in caller's netns + cfg.pspnl.dev_assoc({'id': psp_dev_id, 'ifindex': nk_host_ifindex}) + + # Verify assoc-list contains the 
device + dev_info = cfg.pspnl.dev_get({'id': psp_dev_id}) + ksft_true('assoc-list' in dev_info, "No assoc-list after association") + found = False + for assoc in dev_info['assoc-list']: + if assoc['ifindex'] == nk_host_ifindex: + found = True + break + ksft_true(found, "Associated device not found in assoc-list") + + # Disassociate without nsid - should also use caller's netns + cfg.pspnl.dev_disassoc({'id': psp_dev_id, 'ifindex': nk_host_ifindex}) + + # Verify assoc-list no longer contains the device + dev_info = cfg.pspnl.dev_get({'id': psp_dev_id}) + found = False + if 'assoc-list' in dev_info: + for assoc in dev_info['assoc-list']: + if assoc['ifindex'] == nk_host_ifindex: + found = True + break + ksft_true(not found, "Device should not be in assoc-list after disassociation") + + del cfg.psp_dev_id + del cfg.psp_info + + +def _psp_dev_assoc_cleanup_on_netkit_del(cfg): + """ Test that assoc-list is cleared when associated netkit interface is deleted """ + _init_psp_dev(cfg, True) + psp_dev_id_for_assoc = cfg.psp_dev_id + + # Associate PSP device with nk_guest interface (in guest namespace) + nk_guest_dev = ip(f"link show dev {cfg._nk_guest_ifname}", json=True, ns=cfg.netns)[0] + nk_guest_ifindex = nk_guest_dev['ifindex'] + + cfg.pspnl.dev_assoc({'id': psp_dev_id_for_assoc, 'ifindex': nk_guest_ifindex, 'nsid': cfg.psp_dev_peer_nsid}) + + # Verify assoc-list exists in default netns + dev_info = cfg.pspnl.dev_get({'id': psp_dev_id_for_assoc}) + ksft_true('assoc-list' in dev_info, "No assoc-list after association") + found = False + for assoc in dev_info['assoc-list']: + if assoc['ifindex'] == nk_guest_ifindex and assoc['nsid'] == cfg.psp_dev_peer_nsid: + found = True + break + ksft_true(found, "Associated device not found in assoc-list") + + # Delete the netkit interface in the guest namespace + ip(f"link del {cfg._nk_guest_ifname}", ns=cfg.netns) + + # Mark netkit as already deleted so cleanup won't try to delete it again + # (deleting nk_guest also removes 
nk_host since they're a pair) + cfg._nk_host_ifname = None + cfg._nk_guest_ifname = None + + # Verify assoc-list is gone in default netns after netkit deletion + dev_info = cfg.pspnl.dev_get({'id': psp_dev_id_for_assoc}) + ksft_true('assoc-list' not in dev_info or len(dev_info['assoc-list']) == 0, + "assoc-list should be empty after netkit deletion") + + del cfg.psp_dev_id + del cfg.psp_info + + def __bad_xfer_do(cfg, s, tx, version='hdr0-aes-gcm-128'): # Make sure we accept the ACK for the SPI before we seal with the bad assoc _check_data_outq(s, 0) @@ -571,33 +872,162 @@ def removal_device_bi(cfg): _close_conn(cfg, s) -def psp_ip_ver_test_builder(name, test_func, psp_ver, ipver): - """Build test cases for each combo of PSP version and IP version""" - def test_case(cfg): - cfg.require_ipver(ipver) - test_func(cfg, psp_ver, ipver) - - test_case.__name__ = f"{name}_v{psp_ver}_ip{ipver}" - return test_case +@ksft_variants([ + KsftNamedVariant(f"v{v}_ip{ip}", v, ip) + for v in range(4) for ip in ("4", "6") +]) +def data_basic_send(cfg, version, ipver): + cfg.require_ipver(ipver) + _data_basic_send(cfg, version, ipver) + + +@ksft_variants([ + KsftNamedVariant(f"ip{ip}", ip) + for ip in ("4", "6") +]) +def data_mss_adjust(cfg, ipver): + cfg.require_ipver(ipver) + _data_mss_adjust(cfg, ipver) + + +@ksft_variants([ + KsftNamedVariant(f"v{v}_ip6", v, "6") + for v in range(4) +]) +def data_basic_send_netkit_psp_assoc(cfg, version, ipver): + cfg.require_ipver(ipver) + _data_basic_send_netkit_psp_assoc(cfg, version, ipver) + + +@ksft_variants([ + KsftNamedVariant(f"v{v}_ip6", v, "6") + for v in range(4) +]) +def key_rotation_notify_multi_ns_netkit(cfg, version, ipver): + cfg.require_ipver(ipver) + _key_rotation_notify_multi_ns_netkit(cfg, version, ipver) + + +@ksft_variants([ + KsftNamedVariant(f"v{v}_ip6", v, "6") + for v in range(4) +]) +def dev_change_notify_multi_ns_netkit(cfg, version, ipver): + cfg.require_ipver(ipver) + _dev_change_notify_multi_ns_netkit(cfg, version, 
ipver) + + +@ksft_variants([ + KsftNamedVariant(f"v{v}_ip6", v, "6") + for v in range(4) +]) +def psp_dev_get_check_netkit_psp_assoc(cfg, version, ipver): + cfg.require_ipver(ipver) + _psp_dev_get_check_netkit_psp_assoc(cfg, version, ipver) + + +@ksft_variants([ + KsftNamedVariant(f"v{v}_ip6", v, "6") + for v in range(4) +]) +def dev_assoc_no_nsid(cfg, version, ipver): + cfg.require_ipver(ipver) + _dev_assoc_no_nsid(cfg) + + +def _get_nsid(ns_name): + """Get the nsid for a namespace.""" + for entry in ip("netns list-id", json=True): + if entry.get("name") == str(ns_name): + return entry["nsid"] + raise KsftSkipEx(f"nsid not found for namespace {ns_name}") + + +def _setup_psp_attributes(cfg): + """ + Set up PSP-specific attributes on the environment. + + This sets attributes needed for PSP tests based on whether we're using + netdevsim or a real NIC. + """ + if cfg._ns is not None: + # netdevsim case: PSP device is the local dev (in host namespace) + cfg.psp_dev = cfg._ns.nsims[0].dev + cfg.psp_ifname = cfg.psp_dev['ifname'] + cfg.psp_ifindex = cfg.psp_dev['ifindex'] + + # PSP peer device is the remote dev (in _netns, where psp_responder runs) + cfg.psp_dev_peer = cfg._ns_peer.nsims[0].dev + cfg.psp_dev_peer_ifname = cfg.psp_dev_peer['ifname'] + cfg.psp_dev_peer_ifindex = cfg.psp_dev_peer['ifindex'] + else: + # Real NIC case: PSP device is the local interface + cfg.psp_dev = cfg.dev + cfg.psp_ifname = cfg.ifname + cfg.psp_ifindex = cfg.ifindex + + # PSP peer device is the remote interface + cfg.psp_dev_peer = cfg.remote_dev + cfg.psp_dev_peer_ifname = cfg.remote_ifname + cfg.psp_dev_peer_ifindex = cfg.remote_ifindex + + # Get nsid for the guest namespace (netns) where nk_guest is + cfg.psp_dev_peer_nsid = _get_nsid(cfg.netns.name) + + +def _setup_psp_routes(cfg): + """ + Set up routes for cross-namespace connectivity. + + Traffic flows: + 1. 
remote (_netns) -> nk_guest (netns): + psp_dev_peer -> psp_dev_local -> BPF redirect -> nk_host -> nk_guest + Needs: route in _netns to nk_v6_pfx/64 via psp_dev_local + + 2. nk_guest (netns) -> remote (_netns): + nk_guest -> nk_host -> psp_dev_local -> psp_dev_peer + Needs: route in netns to dev_v6_pfx/64 via nk_host + """ + # In _netns (remote namespace): add route to nk_guest prefix via psp_dev_local + # psp_dev_peer can reach psp_dev_local via the link, then traffic goes through BPF + ip(f"-6 route add {cfg.nk_v6_pfx}/64 via {cfg.nsim_v6_pfx}1 dev {cfg.psp_dev_peer_ifname}", + ns=cfg._netns) + + # In netns (guest namespace): add route to remote peer prefix + # nk_guest default route goes to nk_host, but we need explicit route to dev_v6_pfx/64 + ip(f"-6 route add {cfg.nsim_v6_pfx}/64 via fe80::1 dev {cfg._nk_guest_ifname}", + ns=cfg.netns) -def ipver_test_builder(name, test_func, ipver): - """Build test cases for each IP version""" - def test_case(cfg): - cfg.require_ipver(ipver) - test_func(cfg, ipver) +def main() -> None: + """ Ksft boiler plate main """ - test_case.__name__ = f"{name}_ip{ipver}" - return test_case + # Use a different prefix for netkit guest to avoid conflict with dev prefix + nk_v6_pfx = "2001:db9::" + # Set LOCAL_PREFIX_V6 to a DIFFERENT prefix than the dev prefix to avoid BPF + # redirecting psp_responder traffic. The BPF only redirects traffic + # matching LOCAL_PREFIX_V6, so dev traffic (2001:db8::) won't be affected. 
+ if "LOCAL_PREFIX_V6" not in os.environ: + os.environ["LOCAL_PREFIX_V6"] = nk_v6_pfx -def main() -> None: - """ Ksft boiler plate main """ + try: + env = NetDrvContEnv(__file__, install_tx_redirect_bpf=True) + has_cont = True + except KsftSkipEx: + env = NetDrvEpEnv(__file__) + has_cont = False - with NetDrvEpEnv(__file__) as cfg: + with env as cfg: cfg.pspnl = PSPFamily() + if has_cont: + cfg.nk_v6_pfx = nk_v6_pfx + _setup_psp_attributes(cfg) + _setup_psp_routes(cfg) + # Set up responder and communication sock + # psp_responder runs in _netns (remote namespace with psp_dev_peer) responder = cfg.remote.deploy("psp_responder") cfg.comm_port = rand_port() @@ -611,17 +1041,17 @@ def main() -> None: cfg.comm_port), timeout=1) - cases = [ - psp_ip_ver_test_builder( - "data_basic_send", _data_basic_send, version, ipver - ) - for version in range(0, 4) - for ipver in ("4", "6") - ] - cases += [ - ipver_test_builder("data_mss_adjust", _data_mss_adjust, ipver) - for ipver in ("4", "6") - ] + cases = [data_basic_send, data_mss_adjust] + + if has_cont: + cases += [ + data_basic_send_netkit_psp_assoc, + key_rotation_notify_multi_ns_netkit, + dev_change_notify_multi_ns_netkit, + psp_dev_get_check_netkit_psp_assoc, + dev_assoc_no_nsid, + _psp_dev_assoc_cleanup_on_netkit_del, + ] ksft_run(cases=cases, globs=globals(), case_pfx={"dev_", "data_", "assoc_", "removal_"}, -- 2.52.0