These drivers only support TLS 1.2. Return early when TLS 1.3 is requested to prevent unsupported hardware offload attempts. Signed-off-by: Rishikesh Jethwani --- drivers/net/ethernet/chelsio/inline_crypto/ch_ktls/chcr_ktls.c | 3 +++ drivers/net/ethernet/netronome/nfp/crypto/tls.c | 3 +++ 2 files changed, 6 insertions(+) diff --git a/drivers/net/ethernet/chelsio/inline_crypto/ch_ktls/chcr_ktls.c b/drivers/net/ethernet/chelsio/inline_crypto/ch_ktls/chcr_ktls.c index f5acd4be1e69..29e108ce6764 100644 --- a/drivers/net/ethernet/chelsio/inline_crypto/ch_ktls/chcr_ktls.c +++ b/drivers/net/ethernet/chelsio/inline_crypto/ch_ktls/chcr_ktls.c @@ -431,6 +431,9 @@ static int chcr_ktls_dev_add(struct net_device *netdev, struct sock *sk, atomic64_inc(&port_stats->ktls_tx_connection_open); u_ctx = adap->uld[CXGB4_ULD_KTLS].handle; + if (crypto_info->version != TLS_1_2_VERSION) + goto out; + if (direction == TLS_OFFLOAD_CTX_DIR_RX) { pr_err("not expecting for RX direction\n"); goto out; diff --git a/drivers/net/ethernet/netronome/nfp/crypto/tls.c b/drivers/net/ethernet/netronome/nfp/crypto/tls.c index 9983d7aa2b9c..13864c6a55dc 100644 --- a/drivers/net/ethernet/netronome/nfp/crypto/tls.c +++ b/drivers/net/ethernet/netronome/nfp/crypto/tls.c @@ -287,6 +287,9 @@ nfp_net_tls_add(struct net_device *netdev, struct sock *sk, BUILD_BUG_ON(offsetof(struct nfp_net_tls_offload_ctx, rx_end) > TLS_DRIVER_STATE_SIZE_RX); + if (crypto_info->version != TLS_1_2_VERSION) + return -EOPNOTSUPP; + if (!nfp_net_cipher_supported(nn, crypto_info->cipher_type, direction)) return -EOPNOTSUPP; -- 2.25.1 Enable TLS 1.3 TX/RX hardware offload on ConnectX-6 Dx and newer crypto-enabled adapters. Key changes: - Add TLS 1.3 capability checking and version validation - Use MLX5E_STATIC_PARAMS_CONTEXT_TLS_1_3 (0x3) for crypto context - Handle TLS 1.3 IV format: full 12-byte IV copied to gcm_iv + implicit_iv (vs TLS 1.2's 4-byte salt only) Tested with TLS 1.3 AES-GCM-128 and AES-GCM-256 cipher suites. 
Signed-off-by: Rishikesh Jethwani --- .../ethernet/mellanox/mlx5/core/en_accel/ktls.h | 8 +++++++- .../mellanox/mlx5/core/en_accel/ktls_txrx.c | 14 +++++++++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls.h b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls.h index 07a04a142a2e..0469ca6a0762 100644 --- a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls.h +++ b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls.h @@ -30,7 +30,9 @@ static inline bool mlx5e_is_ktls_device(struct mlx5_core_dev *mdev) return false; return (MLX5_CAP_TLS(mdev, tls_1_2_aes_gcm_128) || - MLX5_CAP_TLS(mdev, tls_1_2_aes_gcm_256)); + MLX5_CAP_TLS(mdev, tls_1_2_aes_gcm_256) || + MLX5_CAP_TLS(mdev, tls_1_3_aes_gcm_128) || + MLX5_CAP_TLS(mdev, tls_1_3_aes_gcm_256)); } static inline bool mlx5e_ktls_type_check(struct mlx5_core_dev *mdev, @@ -40,10 +42,14 @@ static inline bool mlx5e_ktls_type_check(struct mlx5_core_dev *mdev, case TLS_CIPHER_AES_GCM_128: if (crypto_info->version == TLS_1_2_VERSION) return MLX5_CAP_TLS(mdev, tls_1_2_aes_gcm_128); + else if (crypto_info->version == TLS_1_3_VERSION) + return MLX5_CAP_TLS(mdev, tls_1_3_aes_gcm_128); break; case TLS_CIPHER_AES_GCM_256: if (crypto_info->version == TLS_1_2_VERSION) return MLX5_CAP_TLS(mdev, tls_1_2_aes_gcm_256); + else if (crypto_info->version == TLS_1_3_VERSION) + return MLX5_CAP_TLS(mdev, tls_1_3_aes_gcm_256); break; } diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_txrx.c b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_txrx.c index 570a912dd6fa..f3f90ad6c6cf 100644 --- a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_txrx.c +++ b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ktls_txrx.c @@ -6,6 +6,7 @@ enum { MLX5E_STATIC_PARAMS_CONTEXT_TLS_1_2 = 0x2, + MLX5E_STATIC_PARAMS_CONTEXT_TLS_1_3 = 0x3, }; enum { @@ -15,8 +16,10 @@ enum { #define EXTRACT_INFO_FIELDS do { \ salt = info->salt; \ rec_seq = info->rec_seq; \ + 
iv = info->iv; \ salt_sz = sizeof(info->salt); \ rec_seq_sz = sizeof(info->rec_seq); \ + iv_sz = sizeof(info->iv); \ } while (0) static void @@ -25,8 +28,8 @@ fill_static_params(struct mlx5_wqe_tls_static_params_seg *params, u32 key_id, u32 resync_tcp_sn) { char *initial_rn, *gcm_iv; - u16 salt_sz, rec_seq_sz; - char *salt, *rec_seq; + u16 salt_sz, rec_seq_sz, iv_sz; + char *salt, *rec_seq, *iv; u8 tls_version; u8 *ctx; @@ -59,7 +62,12 @@ fill_static_params(struct mlx5_wqe_tls_static_params_seg *params, memcpy(gcm_iv, salt, salt_sz); memcpy(initial_rn, rec_seq, rec_seq_sz); - tls_version = MLX5E_STATIC_PARAMS_CONTEXT_TLS_1_2; + if (crypto_info->crypto_info.version == TLS_1_3_VERSION) { + memcpy(gcm_iv + salt_sz, iv, iv_sz); + tls_version = MLX5E_STATIC_PARAMS_CONTEXT_TLS_1_3; + } else { + tls_version = MLX5E_STATIC_PARAMS_CONTEXT_TLS_1_2; + } MLX5_SET(tls_static_params, ctx, tls_version, tls_version); MLX5_SET(tls_static_params, ctx, const_1, 1); -- 2.25.1 Add TLS 1.3 support to the kernel TLS hardware offload infrastructure, enabling hardware acceleration for TLS 1.3 connections on capable NICs. Tested on Mellanox ConnectX-6 Dx (Crypto Enabled) with TLS 1.3 AES-GCM-128 and AES-GCM-256 cipher suites. 
Signed-off-by: Rishikesh Jethwani --- net/tls/tls_device.c | 65 ++++++++++++++++----------- net/tls/tls_device_fallback.c | 58 +++++++++++++----------- net/tls/tls_main.c | 85 ++++++++++++++++++++--------------- 3 files changed, 121 insertions(+), 87 deletions(-) diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index 99c8eff9783e..1321bf9b59b0 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -317,25 +317,34 @@ static void tls_device_record_close(struct sock *sk, unsigned char record_type) { struct tls_prot_info *prot = &ctx->prot_info; - struct page_frag dummy_tag_frag; - - /* append tag - * device will fill in the tag, we just need to append a placeholder - * use socket memory to improve coalescing (re-using a single buffer - * increases frag count) - * if we can't allocate memory now use the dummy page + int tail = prot->tag_size + prot->tail_size; + + /* Append tail: tag for TLS 1.2, content_type + tag for TLS 1.3. + * Device fills in the tag, we just need to append a placeholder. 
+ * Use socket memory to improve coalescing (re-using a single buffer + * increases frag count); if allocation fails use dummy_page + * (offset = record_type gives correct content_type byte via + * identity mapping) */ - if (unlikely(pfrag->size - pfrag->offset < prot->tag_size) && - !skb_page_frag_refill(prot->tag_size, pfrag, sk->sk_allocation)) { - dummy_tag_frag.page = dummy_page; - dummy_tag_frag.offset = 0; - pfrag = &dummy_tag_frag; + if (unlikely(pfrag->size - pfrag->offset < tail) && + !skb_page_frag_refill(tail, pfrag, sk->sk_allocation)) { + struct page_frag dummy_pfrag = { + .page = dummy_page, + .offset = record_type, + }; + tls_append_frag(record, &dummy_pfrag, tail); + } else { + if (prot->tail_size) { + char *content_type_addr = page_address(pfrag->page) + + pfrag->offset; + *content_type_addr = record_type; + } + tls_append_frag(record, pfrag, tail); } - tls_append_frag(record, pfrag, prot->tag_size); /* fill prepend */ tls_fill_prepend(ctx, skb_frag_address(&record->frags[0]), - record->len - prot->overhead_size, + record->len - prot->overhead_size + prot->tail_size, record_type); } @@ -883,6 +892,7 @@ static int tls_device_reencrypt(struct sock *sk, struct tls_context *tls_ctx) { struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; const struct tls_cipher_desc *cipher_desc; int err, offset, copy, data_len, pos; struct sk_buff *skb, *skb_iter; @@ -894,7 +904,7 @@ tls_device_reencrypt(struct sock *sk, struct tls_context *tls_ctx) DEBUG_NET_WARN_ON_ONCE(!cipher_desc || !cipher_desc->offloadable); rxm = strp_msg(tls_strp_msg(sw_ctx)); - orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE + cipher_desc->iv, + orig_buf = kmalloc(rxm->full_len + prot->prepend_size, sk->sk_allocation); if (!orig_buf) return -ENOMEM; @@ -909,9 +919,8 @@ tls_device_reencrypt(struct sock *sk, struct tls_context *tls_ctx) offset = rxm->offset; sg_init_table(sg, 1); - sg_set_buf(&sg[0], buf, - rxm->full_len + 
TLS_HEADER_SIZE + cipher_desc->iv); - err = skb_copy_bits(skb, offset, buf, TLS_HEADER_SIZE + cipher_desc->iv); + sg_set_buf(&sg[0], buf, rxm->full_len + prot->prepend_size); + err = skb_copy_bits(skb, offset, buf, prot->prepend_size); if (err) goto free_buf; @@ -1089,11 +1098,6 @@ int tls_set_device_offload(struct sock *sk) } crypto_info = &ctx->crypto_send.info; - if (crypto_info->version != TLS_1_2_VERSION) { - rc = -EOPNOTSUPP; - goto release_netdev; - } - cipher_desc = get_cipher_desc(crypto_info->cipher_type); if (!cipher_desc || !cipher_desc->offloadable) { rc = -EINVAL; @@ -1196,9 +1200,6 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx) struct net_device *netdev; int rc = 0; - if (ctx->crypto_recv.info.version != TLS_1_2_VERSION) - return -EOPNOTSUPP; - netdev = get_netdev_for_sock(sk); if (!netdev) { pr_err_ratelimited("%s: netdev not found\n", __func__); @@ -1409,12 +1410,22 @@ static struct notifier_block tls_dev_notifier = { int __init tls_device_init(void) { - int err; + unsigned char *page_addr; + int err, i; dummy_page = alloc_page(GFP_KERNEL); if (!dummy_page) return -ENOMEM; + /* Pre-populate dummy_page with identity mapping for all byte values. + * This is used as fallback for TLS 1.3 content type when memory + * allocation fails. By populating all 256 values, we avoid needing + * to validate record_type at runtime. 
+ */ + page_addr = page_address(dummy_page); + for (i = 0; i < 256; i++) + page_addr[i] = (unsigned char)i; + destruct_wq = alloc_workqueue("ktls_device_destruct", WQ_PERCPU, 0); if (!destruct_wq) { err = -ENOMEM; diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c index d3c72f509baa..99d5590d20b0 100644 --- a/net/tls/tls_device_fallback.c +++ b/net/tls/tls_device_fallback.c @@ -37,14 +37,15 @@ #include "tls.h" -static int tls_enc_record(struct aead_request *aead_req, +static int tls_enc_record(struct tls_context *tls_ctx, + struct aead_request *aead_req, struct crypto_aead *aead, char *aad, char *iv, __be64 rcd_sn, struct scatter_walk *in, - struct scatter_walk *out, int *in_len, - struct tls_prot_info *prot) + struct scatter_walk *out, int *in_len) { unsigned char buf[TLS_HEADER_SIZE + TLS_MAX_IV_SIZE]; + struct tls_prot_info *prot = &tls_ctx->prot_info; const struct tls_cipher_desc *cipher_desc; struct scatterlist sg_in[3]; struct scatterlist sg_out[3]; @@ -55,7 +56,7 @@ static int tls_enc_record(struct aead_request *aead_req, cipher_desc = get_cipher_desc(prot->cipher_type); DEBUG_NET_WARN_ON_ONCE(!cipher_desc || !cipher_desc->offloadable); - buf_size = TLS_HEADER_SIZE + cipher_desc->iv; + buf_size = prot->prepend_size; len = min_t(int, *in_len, buf_size); memcpy_from_scatterwalk(buf, in, len); @@ -66,16 +67,27 @@ static int tls_enc_record(struct aead_request *aead_req, return 0; len = buf[4] | (buf[3] << 8); - len -= cipher_desc->iv; + if (prot->version != TLS_1_3_VERSION) + len -= cipher_desc->iv; tls_make_aad(aad, len - cipher_desc->tag, (char *)&rcd_sn, buf[0], prot); - memcpy(iv + cipher_desc->salt, buf + TLS_HEADER_SIZE, cipher_desc->iv); + if (prot->version == TLS_1_3_VERSION) { + void *iv_src = crypto_info_iv(&tls_ctx->crypto_send.info, + cipher_desc); + + memcpy(iv + cipher_desc->salt, iv_src, cipher_desc->iv); + } else { + memcpy(iv + cipher_desc->salt, buf + TLS_HEADER_SIZE, + cipher_desc->iv); + } + + tls_xor_iv_with_seq(prot, 
iv, (char *)&rcd_sn); sg_init_table(sg_in, ARRAY_SIZE(sg_in)); sg_init_table(sg_out, ARRAY_SIZE(sg_out)); - sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE); - sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE); + sg_set_buf(sg_in, aad, prot->aad_size); + sg_set_buf(sg_out, aad, prot->aad_size); scatterwalk_get_sglist(in, sg_in + 1); scatterwalk_get_sglist(out, sg_out + 1); @@ -108,13 +120,6 @@ static int tls_enc_record(struct aead_request *aead_req, return rc; } -static void tls_init_aead_request(struct aead_request *aead_req, - struct crypto_aead *aead) -{ - aead_request_set_tfm(aead_req, aead); - aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); -} - static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead, gfp_t flags) { @@ -124,14 +129,15 @@ static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead, aead_req = kzalloc(req_size, flags); if (aead_req) - tls_init_aead_request(aead_req, aead); + aead_request_set_tfm(aead_req, aead); return aead_req; } -static int tls_enc_records(struct aead_request *aead_req, +static int tls_enc_records(struct tls_context *tls_ctx, + struct aead_request *aead_req, struct crypto_aead *aead, struct scatterlist *sg_in, struct scatterlist *sg_out, char *aad, char *iv, - u64 rcd_sn, int len, struct tls_prot_info *prot) + u64 rcd_sn, int len) { struct scatter_walk out, in; int rc; @@ -140,8 +146,8 @@ static int tls_enc_records(struct aead_request *aead_req, scatterwalk_start(&out, sg_out); do { - rc = tls_enc_record(aead_req, aead, aad, iv, - cpu_to_be64(rcd_sn), &in, &out, &len, prot); + rc = tls_enc_record(tls_ctx, aead_req, aead, aad, iv, + cpu_to_be64(rcd_sn), &in, &out, &len); rcd_sn++; } while (rc == 0 && len); @@ -317,7 +323,10 @@ static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx, cipher_desc = get_cipher_desc(tls_ctx->crypto_send.info.cipher_type); DEBUG_NET_WARN_ON_ONCE(!cipher_desc || !cipher_desc->offloadable); - buf_len = cipher_desc->salt + cipher_desc->iv + TLS_AAD_SPACE_SIZE + + 
aead_request_set_ad(aead_req, tls_ctx->prot_info.aad_size); + + buf_len = cipher_desc->salt + cipher_desc->iv + + tls_ctx->prot_info.aad_size + sync_size + cipher_desc->tag; buf = kmalloc(buf_len, GFP_ATOMIC); if (!buf) @@ -327,7 +336,7 @@ static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx, salt = crypto_info_salt(&tls_ctx->crypto_send.info, cipher_desc); memcpy(iv, salt, cipher_desc->salt); aad = buf + cipher_desc->salt + cipher_desc->iv; - dummy_buf = aad + TLS_AAD_SPACE_SIZE; + dummy_buf = aad + tls_ctx->prot_info.aad_size; nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC); if (!nskb) @@ -338,9 +347,8 @@ static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx, fill_sg_out(sg_out, buf, tls_ctx, nskb, tcp_payload_offset, payload_len, sync_size, dummy_buf); - if (tls_enc_records(aead_req, ctx->aead_send, sg_in, sg_out, aad, iv, - rcd_sn, sync_size + payload_len, - &tls_ctx->prot_info) < 0) + if (tls_enc_records(tls_ctx, aead_req, ctx->aead_send, sg_in, sg_out, + aad, iv, rcd_sn, sync_size + payload_len) < 0) goto free_nskb; complete_skb(nskb, skb, tcp_payload_offset); diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index fd39acf41a61..fd04857fa0ab 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -711,49 +711,64 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, } if (tx) { - rc = tls_set_device_offload(sk); - conf = TLS_HW; - if (!rc) { - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE); - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); - } else { - rc = tls_set_sw_offload(sk, 1, - update ? 
crypto_info : NULL); - if (rc) - goto err_crypto_info; - - if (update) { - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXREKEYOK); - } else { - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW); - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); + if (update && ctx->tx_conf == TLS_HW) { + rc = -EOPNOTSUPP; + goto err_crypto_info; + } + + if (!update) { + rc = tls_set_device_offload(sk); + conf = TLS_HW; + if (!rc) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); + goto out; } - conf = TLS_SW; } - } else { - rc = tls_set_device_offload_rx(sk, ctx); - conf = TLS_HW; - if (!rc) { - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE); - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); + + rc = tls_set_sw_offload(sk, 1, update ? crypto_info : NULL); + if (rc) + goto err_crypto_info; + + if (update) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXREKEYOK); } else { - rc = tls_set_sw_offload(sk, 0, - update ? crypto_info : NULL); - if (rc) - goto err_crypto_info; - - if (update) { - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXREKEYOK); - } else { - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW); - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); + } + conf = TLS_SW; + } else { + if (update && ctx->rx_conf == TLS_HW) { + rc = -EOPNOTSUPP; + goto err_crypto_info; + } + + if (!update) { + rc = tls_set_device_offload_rx(sk, ctx); + conf = TLS_HW; + if (!rc) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); + tls_sw_strparser_arm(sk, ctx); + goto out; } - conf = TLS_SW; } - if (!update) + + rc = tls_set_sw_offload(sk, 0, update ? 
crypto_info : NULL); + if (rc) + goto err_crypto_info; + + if (update) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXREKEYOK); + } else { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); tls_sw_strparser_arm(sk, ctx); + } + conf = TLS_SW; } +out: if (tx) ctx->tx_conf = conf; else -- 2.25.1 Separate cipher context initialization from key material finalization to support staged setup for hardware offload fallback paths. Signed-off-by: Rishikesh Jethwani --- net/tls/tls.h | 4 +++ net/tls/tls_device.c | 3 +- net/tls/tls_sw.c | 77 +++++++++++++++++++++++++++++++------------- 3 files changed, 61 insertions(+), 23 deletions(-) diff --git a/net/tls/tls.h b/net/tls/tls.h index 2f86baeb71fc..56eba13261d4 100644 --- a/net/tls/tls.h +++ b/net/tls/tls.h @@ -147,6 +147,10 @@ void tls_strp_abort_strp(struct tls_strparser *strp, int err); int init_prot_info(struct tls_prot_info *prot, const struct tls_crypto_info *crypto_info, const struct tls_cipher_desc *cipher_desc); +int tls_sw_ctx_init(struct sock *sk, int tx, + struct tls_crypto_info *new_crypto_info); +void tls_sw_ctx_finalize(struct sock *sk, int tx, + struct tls_crypto_info *new_crypto_info); int tls_set_sw_offload(struct sock *sk, int tx, struct tls_crypto_info *new_crypto_info); void tls_update_rx_zc_capable(struct tls_context *tls_ctx); diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index 1321bf9b59b0..cd26873e9063 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -1233,7 +1233,7 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx) context->resync_nh_reset = 1; ctx->priv_ctx_rx = context; - rc = tls_set_sw_offload(sk, 0, NULL); + rc = tls_sw_ctx_init(sk, 0, NULL); if (rc) goto release_ctx; @@ -1247,6 +1247,7 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx) goto free_sw_resources; tls_device_attach(ctx, sk, netdev); + tls_sw_ctx_finalize(sk, 0, NULL); up_read(&device_offload_lock); 
dev_put(netdev); diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 5fe07f110fe8..424e0a11bcf4 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -2775,20 +2775,19 @@ static void tls_finish_key_update(struct sock *sk, struct tls_context *tls_ctx) ctx->saved_data_ready(sk); } -int tls_set_sw_offload(struct sock *sk, int tx, - struct tls_crypto_info *new_crypto_info) +int tls_sw_ctx_init(struct sock *sk, int tx, + struct tls_crypto_info *new_crypto_info) { struct tls_crypto_info *crypto_info, *src_crypto_info; struct tls_sw_context_tx *sw_ctx_tx = NULL; struct tls_sw_context_rx *sw_ctx_rx = NULL; const struct tls_cipher_desc *cipher_desc; - char *iv, *rec_seq, *key, *salt; - struct cipher_context *cctx; struct tls_prot_info *prot; struct crypto_aead **aead; struct tls_context *ctx; struct crypto_tfm *tfm; int rc = 0; + char *key; ctx = tls_get_ctx(sk); prot = &ctx->prot_info; @@ -2809,12 +2808,10 @@ int tls_set_sw_offload(struct sock *sk, int tx, if (tx) { sw_ctx_tx = ctx->priv_ctx_tx; crypto_info = &ctx->crypto_send.info; - cctx = &ctx->tx; aead = &sw_ctx_tx->aead_send; } else { sw_ctx_rx = ctx->priv_ctx_rx; crypto_info = &ctx->crypto_recv.info; - cctx = &ctx->rx; aead = &sw_ctx_rx->aead_recv; } @@ -2830,10 +2827,7 @@ int tls_set_sw_offload(struct sock *sk, int tx, if (rc) goto free_priv; - iv = crypto_info_iv(src_crypto_info, cipher_desc); key = crypto_info_key(src_crypto_info, cipher_desc); - salt = crypto_info_salt(src_crypto_info, cipher_desc); - rec_seq = crypto_info_rec_seq(src_crypto_info, cipher_desc); if (!*aead) { *aead = crypto_alloc_aead(cipher_desc->cipher_name, 0, 0); @@ -2877,19 +2871,6 @@ int tls_set_sw_offload(struct sock *sk, int tx, goto free_aead; } - memcpy(cctx->iv, salt, cipher_desc->salt); - memcpy(cctx->iv + cipher_desc->salt, iv, cipher_desc->iv); - memcpy(cctx->rec_seq, rec_seq, cipher_desc->rec_seq); - - if (new_crypto_info) { - unsafe_memcpy(crypto_info, new_crypto_info, - cipher_desc->crypto_info, - /* size was checked in 
do_tls_setsockopt_conf */); - memzero_explicit(new_crypto_info, cipher_desc->crypto_info); - if (!tx) - tls_finish_key_update(sk, ctx); - } - goto out; free_aead: @@ -2908,3 +2889,55 @@ int tls_set_sw_offload(struct sock *sk, int tx, out: return rc; } + +void tls_sw_ctx_finalize(struct sock *sk, int tx, + struct tls_crypto_info *new_crypto_info) +{ + struct tls_crypto_info *crypto_info, *src_crypto_info; + const struct tls_cipher_desc *cipher_desc; + struct tls_context *ctx = tls_get_ctx(sk); + struct cipher_context *cctx; + char *iv, *salt, *rec_seq; + + if (tx) { + crypto_info = &ctx->crypto_send.info; + cctx = &ctx->tx; + } else { + crypto_info = &ctx->crypto_recv.info; + cctx = &ctx->rx; + } + + src_crypto_info = new_crypto_info ?: crypto_info; + cipher_desc = get_cipher_desc(src_crypto_info->cipher_type); + + iv = crypto_info_iv(src_crypto_info, cipher_desc); + salt = crypto_info_salt(src_crypto_info, cipher_desc); + rec_seq = crypto_info_rec_seq(src_crypto_info, cipher_desc); + + memcpy(cctx->iv, salt, cipher_desc->salt); + memcpy(cctx->iv + cipher_desc->salt, iv, cipher_desc->iv); + memcpy(cctx->rec_seq, rec_seq, cipher_desc->rec_seq); + + if (new_crypto_info) { + unsafe_memcpy(crypto_info, new_crypto_info, + cipher_desc->crypto_info, + /* size was checked in do_tls_setsockopt_conf */); + memzero_explicit(new_crypto_info, cipher_desc->crypto_info); + + if (!tx) + tls_finish_key_update(sk, ctx); + } +} + +int tls_set_sw_offload(struct sock *sk, int tx, + struct tls_crypto_info *new_crypto_info) +{ + int rc; + + rc = tls_sw_ctx_init(sk, tx, new_crypto_info); + if (rc) + return rc; + + tls_sw_ctx_finalize(sk, tx, new_crypto_info); + return 0; +} -- 2.25.1 Add TLS KeyUpdate (rekey) support for hardware offload connections. When tls_dev_add() fails during hardware rekey, the connection gracefully degrades to software encryption/decryption while maintaining TLS_HW configuration. 
Tested on Mellanox ConnectX-6 Dx (Crypto Enabled) with multiple TLS 1.3 key update cycles. Signed-off-by: Rishikesh Jethwani --- include/net/tls.h | 79 ++++-- include/uapi/linux/snmp.h | 2 + net/tls/tls.h | 14 +- net/tls/tls_device.c | 492 +++++++++++++++++++++++++++++----- net/tls/tls_device_fallback.c | 24 ++ net/tls/tls_main.c | 92 ++++--- net/tls/tls_proc.c | 2 + net/tls/tls_sw.c | 28 +- 8 files changed, 594 insertions(+), 139 deletions(-) diff --git a/include/net/tls.h b/include/net/tls.h index ebd2550280ae..4855cd7f1747 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -151,6 +151,22 @@ struct tls_record_info { skb_frag_t frags[MAX_SKB_FRAGS]; }; +struct cipher_context { + char iv[TLS_MAX_IV_SIZE + TLS_MAX_SALT_SIZE]; + char rec_seq[TLS_MAX_REC_SEQ_SIZE]; +}; + +union tls_crypto_context { + struct tls_crypto_info info; + union { + struct tls12_crypto_info_aes_gcm_128 aes_gcm_128; + struct tls12_crypto_info_aes_gcm_256 aes_gcm_256; + struct tls12_crypto_info_chacha20_poly1305 chacha20_poly1305; + struct tls12_crypto_info_sm4_gcm sm4_gcm; + struct tls12_crypto_info_sm4_ccm sm4_ccm; + }; +}; + #define TLS_DRIVER_STATE_SIZE_TX 16 struct tls_offload_context_tx { struct crypto_aead *aead_send; @@ -165,6 +181,11 @@ struct tls_offload_context_tx { void (*sk_destruct)(struct sock *sk); struct work_struct destruct_work; struct tls_context *ctx; + + struct tls_sw_context_tx rekey_sw; /* SW context for new key */ + struct cipher_context rekey_tx; /* IV, rec_seq for new key */ + union tls_crypto_context rekey_crypto_send; /* Crypto for new key */ + /* The TLS layer reserves room for driver specific state * Currently the belief is that there is not enough * driver specific state to justify another layer of indirection @@ -189,22 +210,21 @@ enum tls_context_flags { * tls_dev_del call in tls_device_down if it happens simultaneously. 
*/ TLS_RX_DEV_CLOSED = 2, -}; - -struct cipher_context { - char iv[TLS_MAX_IV_SIZE + TLS_MAX_SALT_SIZE]; - char rec_seq[TLS_MAX_REC_SEQ_SIZE]; -}; - -union tls_crypto_context { - struct tls_crypto_info info; - union { - struct tls12_crypto_info_aes_gcm_128 aes_gcm_128; - struct tls12_crypto_info_aes_gcm_256 aes_gcm_256; - struct tls12_crypto_info_chacha20_poly1305 chacha20_poly1305; - struct tls12_crypto_info_sm4_gcm sm4_gcm; - struct tls12_crypto_info_sm4_ccm sm4_ccm; - }; + /* Flag for TX HW context deleted during failed rekey. + * Prevents double tls_dev_del in cleanup paths. + */ + TLS_TX_DEV_CLOSED = 3, + /* TX rekey is pending, waiting for old-key data to be ACKed. + * While set, new data uses SW path with new key, HW keeps old key + * for retransmissions. + */ + TLS_TX_REKEY_PENDING = 4, + /* All old-key data has been ACKed, ready to install new key in HW. */ + TLS_TX_REKEY_READY = 5, + /* HW rekey failed, permanently stay in SW encrypt mode. + * Prevents tls_tcp_clean_acked from re-setting TLS_TX_REKEY_READY. + */ + TLS_TX_REKEY_FAILED = 6, }; struct tls_prot_info { @@ -253,6 +273,18 @@ struct tls_context { */ unsigned long flags; + /* TCP sequence number boundary for pending rekey. + * Packets with seq < this use old key, >= use new key. 
+ */ + u32 rekey_boundary_seq; + + /* TCP sequence number where SW-encrypted region ends */ + u32 rekey_complete_seq; + + /* Pointers to rekey contexts for SW encryption with new key */ + struct tls_sw_context_tx *rekey_sw_ctx; + struct cipher_context *rekey_cipher_ctx; + /* cache cold stuff */ struct proto *sk_proto; struct sock *sk; @@ -385,9 +417,21 @@ static inline struct tls_sw_context_rx *tls_sw_ctx_rx( static inline struct tls_sw_context_tx *tls_sw_ctx_tx( const struct tls_context *tls_ctx) { + if (unlikely(tls_ctx->rekey_sw_ctx)) + return tls_ctx->rekey_sw_ctx; + return (struct tls_sw_context_tx *)tls_ctx->priv_ctx_tx; } +static inline struct cipher_context *tls_tx_cipher_ctx( + const struct tls_context *tls_ctx) +{ + if (unlikely(tls_ctx->rekey_cipher_ctx)) + return tls_ctx->rekey_cipher_ctx; + + return (struct cipher_context *)&tls_ctx->tx; +} + static inline struct tls_offload_context_tx * tls_offload_ctx_tx(const struct tls_context *tls_ctx) { @@ -500,6 +544,9 @@ struct sk_buff *tls_encrypt_skb(struct sk_buff *skb); #ifdef CONFIG_TLS_DEVICE void tls_device_sk_destruct(struct sock *sk); void tls_offload_tx_resync_request(struct sock *sk, u32 got_seq, u32 exp_seq); +struct sk_buff * +tls_validate_xmit_skb_rekey(struct sock *sk, struct net_device *dev, + struct sk_buff *skb); static inline bool tls_is_sk_rx_device_offloaded(struct sock *sk) { diff --git a/include/uapi/linux/snmp.h b/include/uapi/linux/snmp.h index 49f5640092a0..39fa48821faa 100644 --- a/include/uapi/linux/snmp.h +++ b/include/uapi/linux/snmp.h @@ -369,6 +369,8 @@ enum LINUX_MIB_TLSTXREKEYOK, /* TlsTxRekeyOk */ LINUX_MIB_TLSTXREKEYERROR, /* TlsTxRekeyError */ LINUX_MIB_TLSRXREKEYRECEIVED, /* TlsRxRekeyReceived */ + LINUX_MIB_TLSTXREKEYHWFAIL, /* TlsTxRekeyHwFail */ + LINUX_MIB_TLSRXREKEYHWFAIL, /* TlsRxRekeyHwFail */ __LINUX_MIB_TLSMAX }; diff --git a/net/tls/tls.h b/net/tls/tls.h index 56eba13261d4..98f94a610f23 100644 --- a/net/tls/tls.h +++ b/net/tls/tls.h @@ -157,6 +157,9 @@ void 
tls_update_rx_zc_capable(struct tls_context *tls_ctx); void tls_sw_strparser_arm(struct sock *sk, struct tls_context *ctx); void tls_sw_strparser_done(struct tls_context *tls_ctx); int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); +int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size); +void tls_tx_work_handler(struct work_struct *work); +void tls_sw_ctx_tx_init(struct sock *sk, struct tls_sw_context_tx *sw_ctx); void tls_sw_splice_eof(struct socket *sock); void tls_sw_cancel_work_tx(struct tls_context *tls_ctx); void tls_sw_release_resources_tx(struct sock *sk); @@ -233,9 +236,11 @@ static inline bool tls_strp_msg_mixed_decrypted(struct tls_sw_context_rx *ctx) #ifdef CONFIG_TLS_DEVICE int tls_device_init(void); void tls_device_cleanup(void); -int tls_set_device_offload(struct sock *sk); +int tls_set_device_offload(struct sock *sk, + struct tls_crypto_info *crypto_info); void tls_device_free_resources_tx(struct sock *sk); -int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx); +int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx, + struct tls_crypto_info *crypto_info); void tls_device_offload_cleanup_rx(struct sock *sk); void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq); int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx); @@ -244,7 +249,7 @@ static inline int tls_device_init(void) { return 0; } static inline void tls_device_cleanup(void) {} static inline int -tls_set_device_offload(struct sock *sk) +tls_set_device_offload(struct sock *sk, struct tls_crypto_info *crypto_info) { return -EOPNOTSUPP; } @@ -252,7 +257,8 @@ tls_set_device_offload(struct sock *sk) static inline void tls_device_free_resources_tx(struct sock *sk) {} static inline int -tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx) +tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx, + struct tls_crypto_info *crypto_info) { return -EOPNOTSUPP; } diff 
--git a/net/tls/tls_device.c b/net/tls/tls_device.c index cd26873e9063..ab6ef7742121 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -79,7 +79,9 @@ static void tls_device_tx_del_task(struct work_struct *work) netdev = rcu_dereference_protected(ctx->netdev, !refcount_read(&ctx->refcount)); - netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_TX); + if (!test_bit(TLS_TX_DEV_CLOSED, &ctx->flags)) + netdev->tlsdev_ops->tls_dev_del(netdev, ctx, + TLS_OFFLOAD_CTX_DIR_TX); dev_put(netdev); ctx->netdev = NULL; tls_device_free_ctx(ctx); @@ -159,6 +161,221 @@ static void delete_all_records(struct tls_offload_context_tx *offload_ctx) offload_ctx->retransmit_hint = NULL; } +static bool tls_has_unacked_records(struct tls_offload_context_tx *offload_ctx) +{ + struct tls_record_info *info; + bool has_unacked = false; + unsigned long flags; + + spin_lock_irqsave(&offload_ctx->lock, flags); + list_for_each_entry(info, &offload_ctx->records_list, list) { + if (!tls_record_is_start_marker(info)) { + has_unacked = true; + break; + } + } + spin_unlock_irqrestore(&offload_ctx->lock, flags); + + return has_unacked; +} + +static int tls_device_init_rekey_sw(struct sock *sk, + struct tls_context *ctx, + struct tls_offload_context_tx *offload_ctx, + struct tls_crypto_info *new_crypto_info) +{ + struct tls_sw_context_tx *sw_ctx = &offload_ctx->rekey_sw; + const struct tls_cipher_desc *cipher_desc; + char *key; + int rc; + + cipher_desc = get_cipher_desc(new_crypto_info->cipher_type); + DEBUG_NET_WARN_ON_ONCE(!cipher_desc || !cipher_desc->offloadable); + + memset(sw_ctx, 0, sizeof(*sw_ctx)); + tls_sw_ctx_tx_init(sk, sw_ctx); + + sw_ctx->aead_send = crypto_alloc_aead(cipher_desc->cipher_name, 0, 0); + if (IS_ERR(sw_ctx->aead_send)) { + rc = PTR_ERR(sw_ctx->aead_send); + sw_ctx->aead_send = NULL; + return rc; + } + + key = crypto_info_key(new_crypto_info, cipher_desc); + rc = crypto_aead_setkey(sw_ctx->aead_send, key, cipher_desc->key); + if (rc) + goto free_aead; 
+ + rc = crypto_aead_setauthsize(sw_ctx->aead_send, cipher_desc->tag); + if (rc) + goto free_aead; + + return 0; + +free_aead: + crypto_free_aead(sw_ctx->aead_send); + sw_ctx->aead_send = NULL; + return rc; +} + +static int tls_device_start_rekey(struct sock *sk, + struct tls_context *ctx, + struct tls_offload_context_tx *offload_ctx, + struct tls_crypto_info *new_crypto_info) +{ + bool rekey_pending = test_bit(TLS_TX_REKEY_PENDING, &ctx->flags); + bool rekey_failed = test_bit(TLS_TX_REKEY_FAILED, &ctx->flags); + const struct tls_cipher_desc *cipher_desc; + char *key, *iv, *rec_seq, *salt; + int rc; + + cipher_desc = get_cipher_desc(new_crypto_info->cipher_type); + DEBUG_NET_WARN_ON_ONCE(!cipher_desc || !cipher_desc->offloadable); + + key = crypto_info_key(new_crypto_info, cipher_desc); + iv = crypto_info_iv(new_crypto_info, cipher_desc); + rec_seq = crypto_info_rec_seq(new_crypto_info, cipher_desc); + salt = crypto_info_salt(new_crypto_info, cipher_desc); + + if (rekey_pending || rekey_failed) { + rc = crypto_aead_setkey(offload_ctx->rekey_sw.aead_send, + key, cipher_desc->key); + if (rc) + return rc; + + memcpy(offload_ctx->rekey_tx.iv, salt, cipher_desc->salt); + memcpy(offload_ctx->rekey_tx.iv + cipher_desc->salt, iv, + cipher_desc->iv); + memcpy(offload_ctx->rekey_tx.rec_seq, rec_seq, + cipher_desc->rec_seq); + + if (rekey_failed) { + set_bit(TLS_TX_REKEY_PENDING, &ctx->flags); + clear_bit(TLS_TX_REKEY_FAILED, &ctx->flags); + } + } else { + rc = tls_device_init_rekey_sw(sk, ctx, offload_ctx, + new_crypto_info); + if (rc) + return rc; + + memcpy(offload_ctx->rekey_tx.iv, salt, cipher_desc->salt); + memcpy(offload_ctx->rekey_tx.iv + cipher_desc->salt, iv, + cipher_desc->iv); + memcpy(offload_ctx->rekey_tx.rec_seq, rec_seq, + cipher_desc->rec_seq); + + WRITE_ONCE(ctx->rekey_complete_seq, 0); + WRITE_ONCE(ctx->rekey_boundary_seq, tcp_sk(sk)->write_seq); + + ctx->rekey_sw_ctx = &offload_ctx->rekey_sw; + ctx->rekey_cipher_ctx = &offload_ctx->rekey_tx; + + 
set_bit(TLS_TX_REKEY_PENDING, &ctx->flags); + + /* Ensure rekey context is visible before TX path sees + * new callback + */ + smp_store_release(&sk->sk_validate_xmit_skb, + tls_validate_xmit_skb_rekey); + } + + unsafe_memcpy(&offload_ctx->rekey_crypto_send.info, new_crypto_info, + cipher_desc->crypto_info, + /* checked in do_tls_setsockopt_conf */); + memzero_explicit(new_crypto_info, cipher_desc->crypto_info); + + return 0; +} + +static int tls_device_complete_rekey(struct sock *sk, struct tls_context *ctx) +{ + struct tls_offload_context_tx *offload_ctx = tls_offload_ctx_tx(ctx); + const struct tls_cipher_desc *cipher_desc; + struct net_device *netdev; + unsigned long flags; + __be64 rcd_sn; + char *key; + int rc; + + cipher_desc = get_cipher_desc(offload_ctx->rekey_crypto_send.info.cipher_type); + DEBUG_NET_WARN_ON_ONCE(!cipher_desc || !cipher_desc->offloadable); + + down_read(&device_offload_lock); + + netdev = rcu_dereference_protected(ctx->netdev, + lockdep_is_held(&device_offload_lock)); + if (!netdev) { + rc = -ENODEV; + goto release_lock; + } + + if (!test_bit(TLS_TX_DEV_CLOSED, &ctx->flags)) { + netdev->tlsdev_ops->tls_dev_del(netdev, ctx, + TLS_OFFLOAD_CTX_DIR_TX); + set_bit(TLS_TX_DEV_CLOSED, &ctx->flags); + } + + memcpy(crypto_info_rec_seq(&offload_ctx->rekey_crypto_send.info, cipher_desc), + offload_ctx->rekey_tx.rec_seq, cipher_desc->rec_seq); + + rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX, + &offload_ctx->rekey_crypto_send.info, + tcp_sk(sk)->write_seq); + +release_lock: + up_read(&device_offload_lock); + + spin_lock_irqsave(&offload_ctx->lock, flags); + memcpy(&rcd_sn, offload_ctx->rekey_tx.rec_seq, sizeof(rcd_sn)); + offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn); + spin_unlock_irqrestore(&offload_ctx->lock, flags); + + memcpy(ctx->tx.iv, offload_ctx->rekey_tx.iv, + cipher_desc->salt + cipher_desc->iv); + memcpy(ctx->tx.rec_seq, offload_ctx->rekey_tx.rec_seq, + cipher_desc->rec_seq); + 
unsafe_memcpy(&ctx->crypto_send.info, + &offload_ctx->rekey_crypto_send.info, + cipher_desc->crypto_info, + /* checked during rekey setup */); + + if (rc) + goto rekey_fail; + + clear_bit(TLS_TX_DEV_CLOSED, &ctx->flags); + + key = crypto_info_key(&offload_ctx->rekey_crypto_send.info, cipher_desc); + rc = crypto_aead_setkey(offload_ctx->aead_send, key, cipher_desc->key); + if (rc) + goto rekey_fail; + + clear_bit(TLS_TX_REKEY_READY, &ctx->flags); + clear_bit(TLS_TX_REKEY_PENDING, &ctx->flags); + clear_bit(TLS_TX_REKEY_FAILED, &ctx->flags); + + /* following this assignment tls_is_skb_tx_device_offloaded + * will return true and the context might be accessed + * by the netdev's xmit function. + */ + smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb); + + tls_sw_release_resources_tx(sk); + ctx->rekey_sw_ctx = NULL; + ctx->rekey_cipher_ctx = NULL; + + return 0; + +rekey_fail: + set_bit(TLS_TX_REKEY_FAILED, &ctx->flags); + clear_bit(TLS_TX_REKEY_READY, &ctx->flags); + clear_bit(TLS_TX_REKEY_PENDING, &ctx->flags); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXREKEYHWFAIL); + + return 0; +} + static void tls_tcp_clean_acked(struct sock *sk, u32 acked_seq) { struct tls_context *tls_ctx = tls_get_ctx(sk); @@ -187,6 +404,32 @@ static void tls_tcp_clean_acked(struct sock *sk, u32 acked_seq) } ctx->unacked_record_sn += deleted_records; + + /* Track ACKs to determine when HW rekey can complete: + * complete_seq captures write_seq when old-key data is ACKed, + * REKEY_READY is set once all pending data (including any new) is + * ACKed. 
+ */ + if (test_bit(TLS_TX_REKEY_PENDING, &tls_ctx->flags) && + !test_bit(TLS_TX_REKEY_FAILED, &tls_ctx->flags)) { + u32 boundary_seq = READ_ONCE(tls_ctx->rekey_boundary_seq); + u32 complete_seq = READ_ONCE(tls_ctx->rekey_complete_seq); + + if (!before(acked_seq, boundary_seq) && complete_seq == 0) { + complete_seq = tcp_sk(sk)->write_seq; + WRITE_ONCE(tls_ctx->rekey_complete_seq, complete_seq); + } + + if (complete_seq != 0) { + u32 current_write_seq = tcp_sk(sk)->write_seq; + + if (before(complete_seq, current_write_seq)) + WRITE_ONCE(tls_ctx->rekey_complete_seq, current_write_seq); + else if (!before(acked_seq, complete_seq)) + set_bit(TLS_TX_REKEY_READY, &tls_ctx->flags); + } + } + spin_unlock_irqrestore(&ctx->lock, flags); } @@ -218,6 +461,9 @@ void tls_device_free_resources_tx(struct sock *sk) struct tls_context *tls_ctx = tls_get_ctx(sk); tls_free_partial_record(sk, tls_ctx); + + if (unlikely(tls_ctx->rekey_sw_ctx)) + tls_sw_release_resources_tx(sk); } void tls_offload_tx_resync_request(struct sock *sk, u32 got_seq, u32 exp_seq) @@ -589,6 +835,22 @@ int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) goto out; } + /* Complete the rekey if READY is set. This means all old-key + * records have been ACKed and no new data is in flight, so we + * can safely switch from SW to HW offload with the new key. + */ + if (test_bit(TLS_TX_REKEY_READY, &tls_ctx->flags)) + tls_device_complete_rekey(sk, tls_ctx); + + /* Use SW path if rekey is in progress (PENDING) or if HW rekey + * failed (FAILED). 
+ */ + if (test_bit(TLS_TX_REKEY_PENDING, &tls_ctx->flags) || + test_bit(TLS_TX_REKEY_FAILED, &tls_ctx->flags)) { + rc = tls_sw_sendmsg_locked(sk, msg, size); + goto out; + } + rc = tls_push_data(sk, &msg->msg_iter, size, msg->msg_flags, record_type); @@ -1068,57 +1330,31 @@ static struct tls_offload_context_tx *alloc_offload_ctx_tx(struct tls_context *c return offload_ctx; } -int tls_set_device_offload(struct sock *sk) +static int tls_set_device_offload_initial(struct sock *sk, + struct tls_context *ctx, + struct net_device *netdev, + struct tls_crypto_info *crypto_info, + const struct tls_cipher_desc *cipher_desc) { + struct tls_prot_info *prot = &ctx->prot_info; struct tls_record_info *start_marker_record; struct tls_offload_context_tx *offload_ctx; - const struct tls_cipher_desc *cipher_desc; - struct tls_crypto_info *crypto_info; - struct tls_prot_info *prot; - struct net_device *netdev; - struct tls_context *ctx; char *iv, *rec_seq; int rc; - ctx = tls_get_ctx(sk); - prot = &ctx->prot_info; - - if (ctx->priv_ctx_tx) - return -EEXIST; - - netdev = get_netdev_for_sock(sk); - if (!netdev) { - pr_err_ratelimited("%s: netdev not found\n", __func__); - return -EINVAL; - } - - if (!(netdev->features & NETIF_F_HW_TLS_TX)) { - rc = -EOPNOTSUPP; - goto release_netdev; - } - - crypto_info = &ctx->crypto_send.info; - cipher_desc = get_cipher_desc(crypto_info->cipher_type); - if (!cipher_desc || !cipher_desc->offloadable) { - rc = -EINVAL; - goto release_netdev; - } + iv = crypto_info_iv(crypto_info, cipher_desc); + rec_seq = crypto_info_rec_seq(crypto_info, cipher_desc); rc = init_prot_info(prot, crypto_info, cipher_desc); if (rc) - goto release_netdev; - - iv = crypto_info_iv(crypto_info, cipher_desc); - rec_seq = crypto_info_rec_seq(crypto_info, cipher_desc); + return rc; memcpy(ctx->tx.iv + cipher_desc->salt, iv, cipher_desc->iv); memcpy(ctx->tx.rec_seq, rec_seq, cipher_desc->rec_seq); start_marker_record = kmalloc_obj(*start_marker_record); - if 
(!start_marker_record) { - rc = -ENOMEM; - goto release_netdev; - } + if (!start_marker_record) + return -ENOMEM; offload_ctx = alloc_offload_ctx_tx(ctx); if (!offload_ctx) { @@ -1159,8 +1395,10 @@ int tls_set_device_offload(struct sock *sk) } ctx->priv_ctx_tx = offload_ctx; - rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX, - &ctx->crypto_send.info, + + rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, + TLS_OFFLOAD_CTX_DIR_TX, + crypto_info, tcp_sk(sk)->write_seq); trace_tls_device_offload_set(sk, TLS_OFFLOAD_CTX_DIR_TX, tcp_sk(sk)->write_seq, rec_seq, rc); @@ -1175,7 +1413,6 @@ int tls_set_device_offload(struct sock *sk) * by the netdev's xmit function. */ smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb); - dev_put(netdev); return 0; @@ -1188,18 +1425,112 @@ int tls_set_device_offload(struct sock *sk) ctx->priv_ctx_tx = NULL; free_marker_record: kfree(start_marker_record); + return rc; +} + +static int tls_set_device_offload_rekey(struct sock *sk, + struct tls_context *ctx, + struct net_device *netdev, + struct tls_crypto_info *new_crypto_info) +{ + struct tls_offload_context_tx *offload_ctx = tls_offload_ctx_tx(ctx); + bool rekey_pending = test_bit(TLS_TX_REKEY_PENDING, &ctx->flags); + bool has_unacked = false; + int rc; + + if (!rekey_pending) + has_unacked = tls_has_unacked_records(offload_ctx); + + down_read(&device_offload_lock); + + rc = tls_device_start_rekey(sk, ctx, offload_ctx, new_crypto_info); + if (rc) { + up_read(&device_offload_lock); + return rc; + } + + up_read(&device_offload_lock); + + if (!rekey_pending && !has_unacked && + ctx->rekey_boundary_seq == tcp_sk(sk)->write_seq) + rc = tls_device_complete_rekey(sk, ctx); + + return rc; +} + +int tls_set_device_offload(struct sock *sk, + struct tls_crypto_info *new_crypto_info) +{ + struct tls_crypto_info *crypto_info, *src_crypto_info; + const struct tls_cipher_desc *cipher_desc; + struct net_device *netdev; + struct tls_context *ctx; + int rc; + + ctx = 
tls_get_ctx(sk); + + /* Rekey is only supported for connections that are already + * using HW offload. For SW offload connections, the caller + * should fall back to tls_set_sw_offload() for rekey. + */ + if (new_crypto_info && ctx->tx_conf != TLS_HW) + return -EINVAL; + + netdev = get_netdev_for_sock(sk); + if (!netdev) { + pr_err_ratelimited("%s: netdev not found\n", __func__); + return -EINVAL; + } + + if (!(netdev->features & NETIF_F_HW_TLS_TX)) { + rc = -EOPNOTSUPP; + goto release_netdev; + } + + crypto_info = &ctx->crypto_send.info; + src_crypto_info = new_crypto_info ?: crypto_info; + cipher_desc = get_cipher_desc(src_crypto_info->cipher_type); + if (!cipher_desc || !cipher_desc->offloadable) { + rc = -EINVAL; + goto release_netdev; + } + + if (new_crypto_info) + rc = tls_set_device_offload_rekey(sk, ctx, netdev, + src_crypto_info); + else + rc = tls_set_device_offload_initial(sk, ctx, netdev, + src_crypto_info, + cipher_desc); + release_netdev: dev_put(netdev); return rc; } -int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx) +int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx, + struct tls_crypto_info *new_crypto_info) { - struct tls12_crypto_info_aes_gcm_128 *info; + struct tls_crypto_info *crypto_info, *src_crypto_info; + const struct tls_cipher_desc *cipher_desc; struct tls_offload_context_rx *context; struct net_device *netdev; + char *rec_seq; int rc = 0; + /* Rekey is only supported for connections that are already + * using HW offload. For SW offload connections, the caller + * should fall back to tls_set_sw_offload() for rekey. 
+ */ + if (new_crypto_info && ctx->rx_conf != TLS_HW) + return -EINVAL; + + crypto_info = &ctx->crypto_recv.info; + src_crypto_info = new_crypto_info ?: crypto_info; + cipher_desc = get_cipher_desc(src_crypto_info->cipher_type); + if (!cipher_desc || !cipher_desc->offloadable) + return -EINVAL; + netdev = get_netdev_for_sock(sk); if (!netdev) { pr_err_ratelimited("%s: netdev not found\n", __func__); @@ -1225,29 +1556,50 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx) goto release_lock; } - context = kzalloc_obj(*context); - if (!context) { - rc = -ENOMEM; - goto release_lock; + if (!new_crypto_info) { + context = kzalloc_obj(*context); + if (!context) { + rc = -ENOMEM; + goto release_lock; + } + context->resync_nh_reset = 1; + ctx->priv_ctx_rx = context; } - context->resync_nh_reset = 1; - ctx->priv_ctx_rx = context; - rc = tls_sw_ctx_init(sk, 0, NULL); + rc = tls_sw_ctx_init(sk, 0, new_crypto_info); if (rc) goto release_ctx; + if (new_crypto_info && !test_bit(TLS_RX_DEV_CLOSED, &ctx->flags)) + netdev->tlsdev_ops->tls_dev_del(netdev, ctx, + TLS_OFFLOAD_CTX_DIR_RX); + rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX, - &ctx->crypto_recv.info, + src_crypto_info, tcp_sk(sk)->copied_seq); - info = (void *)&ctx->crypto_recv.info; + + rec_seq = crypto_info_rec_seq(src_crypto_info, cipher_desc); trace_tls_device_offload_set(sk, TLS_OFFLOAD_CTX_DIR_RX, - tcp_sk(sk)->copied_seq, info->rec_seq, rc); - if (rc) - goto free_sw_resources; + tcp_sk(sk)->copied_seq, rec_seq, rc); + if (rc) { + if (new_crypto_info) { + set_bit(TLS_RX_DEV_DEGRADED, &ctx->flags); + set_bit(TLS_RX_DEV_CLOSED, &ctx->flags); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXREKEYHWFAIL); + } else { + goto free_sw_resources; + } + } else { + if (new_crypto_info) { + clear_bit(TLS_RX_DEV_DEGRADED, &ctx->flags); + clear_bit(TLS_RX_DEV_CLOSED, &ctx->flags); + } + + tls_device_attach(ctx, sk, netdev); + } + + tls_sw_ctx_finalize(sk, 0, new_crypto_info); - 
tls_device_attach(ctx, sk, netdev); - tls_sw_ctx_finalize(sk, 0, NULL); up_read(&device_offload_lock); dev_put(netdev); @@ -1256,10 +1608,13 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx) free_sw_resources: up_read(&device_offload_lock); - tls_sw_free_resources_rx(sk); + tls_sw_release_resources_rx(sk); down_read(&device_offload_lock); release_ctx: - ctx->priv_ctx_rx = NULL; + if (!new_crypto_info) { + kfree(ctx->priv_ctx_rx); + ctx->priv_ctx_rx = NULL; + } release_lock: up_read(&device_offload_lock); release_netdev: @@ -1278,8 +1633,9 @@ void tls_device_offload_cleanup_rx(struct sock *sk) if (!netdev) goto out; - netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx, - TLS_OFFLOAD_CTX_DIR_RX); + if (!test_bit(TLS_RX_DEV_CLOSED, &tls_ctx->flags)) + netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx, + TLS_OFFLOAD_CTX_DIR_RX); if (tls_ctx->tx_conf != TLS_HW) { dev_put(netdev); @@ -1319,7 +1675,10 @@ static int tls_device_down(struct net_device *netdev) /* Stop offloaded TX and switch to the fallback. * tls_is_skb_tx_device_offloaded will return false. */ - WRITE_ONCE(ctx->sk->sk_validate_xmit_skb, tls_validate_xmit_skb_sw); + if (!test_bit(TLS_TX_REKEY_PENDING, &ctx->flags) && + !test_bit(TLS_TX_REKEY_FAILED, &ctx->flags)) + WRITE_ONCE(ctx->sk->sk_validate_xmit_skb, + tls_validate_xmit_skb_sw); /* Stop the RX and TX resync. * tls_dev_resync must not be called after tls_dev_del. @@ -1336,13 +1695,18 @@ static int tls_device_down(struct net_device *netdev) synchronize_net(); /* Release the offload context on the driver side. 
*/ - if (ctx->tx_conf == TLS_HW) + if (ctx->tx_conf == TLS_HW && + !test_bit(TLS_TX_DEV_CLOSED, &ctx->flags)) { netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_TX); + set_bit(TLS_TX_DEV_CLOSED, &ctx->flags); + } if (ctx->rx_conf == TLS_HW && - !test_bit(TLS_RX_DEV_CLOSED, &ctx->flags)) + !test_bit(TLS_RX_DEV_CLOSED, &ctx->flags)) { netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_RX); + set_bit(TLS_RX_DEV_CLOSED, &ctx->flags); + } dev_put(netdev); diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c index 99d5590d20b0..40a0ddde2fce 100644 --- a/net/tls/tls_device_fallback.c +++ b/net/tls/tls_device_fallback.c @@ -438,6 +438,30 @@ struct sk_buff *tls_validate_xmit_skb_sw(struct sock *sk, return tls_sw_fallback(sk, skb); } +struct sk_buff *tls_validate_xmit_skb_rekey(struct sock *sk, + struct net_device *dev, + struct sk_buff *skb) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + u32 tcp_seq = ntohl(tcp_hdr(skb)->seq); + u32 boundary_seq; + + if (test_bit(TLS_TX_REKEY_FAILED, &tls_ctx->flags)) + return skb; + + /* If this packet is at or after the rekey boundary, it's already + * SW-encrypted with the new key, pass through unchanged + */ + boundary_seq = READ_ONCE(tls_ctx->rekey_boundary_seq); + if (!before(tcp_seq, boundary_seq)) + return skb; + + /* Packet before boundary means retransmit of old data, + * use SW fallback with the old key + */ + return tls_sw_fallback(sk, skb); +} + struct sk_buff *tls_encrypt_skb(struct sk_buff *skb) { return tls_sw_fallback(skb->sk, skb); diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index fd04857fa0ab..ab701f166b57 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -371,6 +371,8 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) if (ctx->tx_conf == TLS_SW) tls_sw_cancel_work_tx(ctx); + else if (ctx->tx_conf == TLS_HW && ctx->rekey_sw_ctx) + tls_sw_cancel_work_tx(ctx); lock_sock(sk); free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != 
TLS_HW; @@ -711,64 +713,68 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, } if (tx) { - if (update && ctx->tx_conf == TLS_HW) { - rc = -EOPNOTSUPP; - goto err_crypto_info; - } - - if (!update) { - rc = tls_set_device_offload(sk); - conf = TLS_HW; - if (!rc) { + rc = tls_set_device_offload(sk, update ? crypto_info : NULL); + conf = TLS_HW; + if (!rc) { + if (update) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXREKEYOK); + } else { TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE); TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); - goto out; } - } - - rc = tls_set_sw_offload(sk, 1, update ? crypto_info : NULL); - if (rc) + } else if (update && ctx->tx_conf == TLS_HW) { + /* HW rekey failed - return the actual error. + * Cannot fall back to SW for an existing HW connection. + */ goto err_crypto_info; - - if (update) { - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXREKEYOK); } else { - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW); - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); + rc = tls_set_sw_offload(sk, 1, + update ? crypto_info : NULL); + if (rc) + goto err_crypto_info; + + if (update) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXREKEYOK); + } else { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); + } + conf = TLS_SW; } - conf = TLS_SW; } else { - if (update && ctx->rx_conf == TLS_HW) { - rc = -EOPNOTSUPP; - goto err_crypto_info; - } - - if (!update) { - rc = tls_set_device_offload_rx(sk, ctx); - conf = TLS_HW; - if (!rc) { + rc = tls_set_device_offload_rx(sk, ctx, + update ? crypto_info : NULL); + conf = TLS_HW; + if (!rc) { + if (update) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXREKEYOK); + } else { TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE); TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); - tls_sw_strparser_arm(sk, ctx); - goto out; } - } - - rc = tls_set_sw_offload(sk, 0, update ? 
crypto_info : NULL); - if (rc) + } else if (update && ctx->rx_conf == TLS_HW) { + /* HW rekey failed - return the actual error. + * Cannot fall back to SW for an existing HW connection. + */ goto err_crypto_info; - - if (update) { - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXREKEYOK); } else { - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW); - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); - tls_sw_strparser_arm(sk, ctx); + rc = tls_set_sw_offload(sk, 0, + update ? crypto_info : NULL); + if (rc) + goto err_crypto_info; + + if (update) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXREKEYOK); + } else { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); + } + conf = TLS_SW; } - conf = TLS_SW; + if (!update) + tls_sw_strparser_arm(sk, ctx); } -out: if (tx) ctx->tx_conf = conf; else diff --git a/net/tls/tls_proc.c b/net/tls/tls_proc.c index 4012c4372d4c..5599af306aab 100644 --- a/net/tls/tls_proc.c +++ b/net/tls/tls_proc.c @@ -27,6 +27,8 @@ static const struct snmp_mib tls_mib_list[] = { SNMP_MIB_ITEM("TlsTxRekeyOk", LINUX_MIB_TLSTXREKEYOK), SNMP_MIB_ITEM("TlsTxRekeyError", LINUX_MIB_TLSTXREKEYERROR), SNMP_MIB_ITEM("TlsRxRekeyReceived", LINUX_MIB_TLSRXREKEYRECEIVED), + SNMP_MIB_ITEM("TlsTxRekeyHwFail", LINUX_MIB_TLSTXREKEYHWFAIL), + SNMP_MIB_ITEM("TlsRxRekeyHwFail", LINUX_MIB_TLSRXREKEYHWFAIL), }; static int tls_statistics_seq_show(struct seq_file *seq, void *v) diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 424e0a11bcf4..4a4cf838bd0d 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -554,11 +554,11 @@ static int tls_do_encryption(struct sock *sk, break; } - memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv, + memcpy(&rec->iv_data[iv_offset], tls_tx_cipher_ctx(tls_ctx)->iv, prot->iv_size + prot->salt_size); tls_xor_iv_with_seq(prot, rec->iv_data + iv_offset, - tls_ctx->tx.rec_seq); + tls_tx_cipher_ctx(tls_ctx)->rec_seq); sge->offset += prot->prepend_size; sge->length -= prot->prepend_size; @@ -599,7 
+599,7 @@ static int tls_do_encryption(struct sock *sk, /* Unhook the record from context if encryption is not failure */ ctx->open_rec = NULL; - tls_advance_record_sn(sk, prot, &tls_ctx->tx); + tls_advance_record_sn(sk, prot, tls_tx_cipher_ctx(tls_ctx)); return rc; } @@ -806,7 +806,7 @@ static int tls_push_record(struct sock *sk, int flags, sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]); tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size, - tls_ctx->tx.rec_seq, record_type, prot); + tls_tx_cipher_ctx(tls_ctx)->rec_seq, record_type, prot); tls_fill_prepend(tls_ctx, page_address(sg_page(&msg_en->sg.data[i])) + @@ -1022,8 +1022,7 @@ static int tls_sw_sendmsg_splice(struct sock *sk, struct msghdr *msg, return 0; } -static int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg, - size_t size) +int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size) { long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); struct tls_context *tls_ctx = tls_get_ctx(sk); @@ -2621,7 +2620,7 @@ void tls_sw_free_resources_rx(struct sock *sk) } /* The work handler to transmitt the encrypted records in tx_list */ -static void tx_work_handler(struct work_struct *work) +void tls_tx_work_handler(struct work_struct *work) { struct delayed_work *delayed_work = to_delayed_work(work); struct tx_work *tx_work = container_of(delayed_work, @@ -2654,6 +2653,15 @@ static void tx_work_handler(struct work_struct *work) } } +void tls_sw_ctx_tx_init(struct sock *sk, struct tls_sw_context_tx *sw_ctx) +{ + crypto_init_wait(&sw_ctx->async_wait); + atomic_set(&sw_ctx->encrypt_pending, 1); + INIT_LIST_HEAD(&sw_ctx->tx_list); + INIT_DELAYED_WORK(&sw_ctx->tx_work.work, tls_tx_work_handler); + sw_ctx->tx_work.sk = sk; +} + static bool tls_is_tx_ready(struct tls_sw_context_tx *ctx) { struct tls_rec *rec; @@ -2705,11 +2713,7 @@ static struct tls_sw_context_tx *init_ctx_tx(struct tls_context *ctx, struct soc sw_ctx_tx = ctx->priv_ctx_tx; } - 
crypto_init_wait(&sw_ctx_tx->async_wait); - atomic_set(&sw_ctx_tx->encrypt_pending, 1); - INIT_LIST_HEAD(&sw_ctx_tx->tx_list); - INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler); - sw_ctx_tx->tx_work.sk = sk; + tls_sw_ctx_tx_init(sk, sw_ctx_tx); return sw_ctx_tx; } -- 2.25.1 Two-node kTLS hardware offload test using NetDrvEpEnv. Tests TLS 1.2/1.3 with AES-GCM-128/256, rekey operations, and various buffer sizes. Signed-off-by: Rishikesh Jethwani --- .../selftests/drivers/net/hw/.gitignore | 1 + .../testing/selftests/drivers/net/hw/Makefile | 2 + .../selftests/drivers/net/hw/tls_hw_offload.c | 902 ++++++++++++++++++ .../drivers/net/hw/tls_hw_offload.py | 281 ++++++ 4 files changed, 1186 insertions(+) create mode 100644 tools/testing/selftests/drivers/net/hw/tls_hw_offload.c create mode 100755 tools/testing/selftests/drivers/net/hw/tls_hw_offload.py diff --git a/tools/testing/selftests/drivers/net/hw/.gitignore b/tools/testing/selftests/drivers/net/hw/.gitignore index 46540468a775..f0a5d15b469b 100644 --- a/tools/testing/selftests/drivers/net/hw/.gitignore +++ b/tools/testing/selftests/drivers/net/hw/.gitignore @@ -2,3 +2,4 @@ iou-zcrx ncdevmem toeplitz +tls_hw_offload diff --git a/tools/testing/selftests/drivers/net/hw/Makefile b/tools/testing/selftests/drivers/net/hw/Makefile index a64140333a46..6b12b0920cae 100644 --- a/tools/testing/selftests/drivers/net/hw/Makefile +++ b/tools/testing/selftests/drivers/net/hw/Makefile @@ -15,6 +15,7 @@ endif TEST_GEN_FILES := \ $(COND_GEN_FILES) \ + tls_hw_offload \ # end of TEST_GEN_FILES TEST_PROGS = \ @@ -38,6 +39,7 @@ TEST_PROGS = \ rss_drv.py \ rss_flow_label.py \ rss_input_xfrm.py \ + tls_hw_offload.py \ toeplitz.py \ tso.py \ xsk_reconfig.py \ diff --git a/tools/testing/selftests/drivers/net/hw/tls_hw_offload.c b/tools/testing/selftests/drivers/net/hw/tls_hw_offload.c new file mode 100644 index 000000000000..cf059368a801 --- /dev/null +++ b/tools/testing/selftests/drivers/net/hw/tls_hw_offload.c @@ -0,0 +1,902 
@@ +// SPDX-License-Identifier: GPL-2.0 +/* + * TLS Hardware Offload Two-Node Test + * + * Tests kTLS hardware offload between two physical nodes using + * hardcoded keys. Supports TLS 1.2/1.3, AES-GCM-128/256, and rekey. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define TLS_RECORD_TYPE_HANDSHAKE 22 +#define TLS_RECORD_TYPE_APPLICATION_DATA 23 +#define TLS_HANDSHAKE_KEY_UPDATE 0x18 +#define KEY_UPDATE_NOT_REQUESTED 0 +#define KEY_UPDATE_REQUESTED 1 + +#define TEST_ITERATIONS 100 +#define MAX_REKEYS 99 + +/* Initial key material */ +static struct tls12_crypto_info_aes_gcm_128 tls_info_key0_128 = { + .info = { + .version = TLS_1_3_VERSION, + .cipher_type = TLS_CIPHER_AES_GCM_128, + }, + .iv = { 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08 }, + .key = { 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10 }, + .salt = { 0x01, 0x02, 0x03, 0x04 }, + .rec_seq = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, +}; + +static struct tls12_crypto_info_aes_gcm_256 tls_info_key0_256 = { + .info = { + .version = TLS_1_3_VERSION, + .cipher_type = TLS_CIPHER_AES_GCM_256, + }, + .iv = { 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08 }, + .key = { 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20 }, + .salt = { 0x01, 0x02, 0x03, 0x04 }, + .rec_seq = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, +}; + +static int do_rekey; +static int num_rekeys = 1; +static int rekeys_done; +static int cipher_type = 128; +static int tls_version = 13; +static int server_port = 4433; +static char *server_ip; +static int addr_family = AF_INET; + +static int send_size = 16384; +static int random_size_max; + +static int detect_addr_family(const char *ip) +{ + char addr_buf[INET6_ADDRSTRLEN]; + struct in_addr 
addr4; + struct in6_addr addr6; + char *scope_sep; + + if (inet_pton(AF_INET, ip, &addr4) == 1) + return AF_INET; + + strncpy(addr_buf, ip, sizeof(addr_buf) - 1); + addr_buf[sizeof(addr_buf) - 1] = '\0'; + scope_sep = strchr(addr_buf, '%'); + if (scope_sep) + *scope_sep = '\0'; + + if (inet_pton(AF_INET6, addr_buf, &addr6) == 1) + return AF_INET6; + return -1; +} + +/* Derive key for given generation (0 = initial, N = Nth rekey) */ +static void derive_key_128(struct tls12_crypto_info_aes_gcm_128 *key, + int generation) +{ + unsigned char pattern; + int i; + + memcpy(key, &tls_info_key0_128, sizeof(*key)); + key->info.version = (tls_version == 12) ? + TLS_1_2_VERSION : TLS_1_3_VERSION; + + if (generation == 0) + return; + + pattern = (unsigned char)((generation * 0x1B) ^ 0x63); + for (i = 0; i < TLS_CIPHER_AES_GCM_128_KEY_SIZE; i++) { + key->key[i] ^= pattern; + pattern = (pattern << 1) | (pattern >> 7); + } + + pattern = (unsigned char)((generation * 0x2D) ^ 0x7C); + for (i = 0; i < TLS_CIPHER_AES_GCM_128_IV_SIZE; i++) { + key->iv[i] ^= pattern; + pattern = (pattern << 1) | (pattern >> 7); + } + + for (i = 0; i < TLS_CIPHER_AES_GCM_128_SALT_SIZE; i++) + key->salt[i] ^= (unsigned char)(generation & 0xFF); + + memset(key->rec_seq, 0, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE); +} + +static void derive_key_256(struct tls12_crypto_info_aes_gcm_256 *key, + int generation) +{ + unsigned char pattern; + int i; + + memcpy(key, &tls_info_key0_256, sizeof(*key)); + key->info.version = (tls_version == 12) ? 
+ TLS_1_2_VERSION : TLS_1_3_VERSION; + + if (generation == 0) + return; + + pattern = (unsigned char)((generation * 0x1B) ^ 0x63); + for (i = 0; i < TLS_CIPHER_AES_GCM_256_KEY_SIZE; i++) { + key->key[i] ^= pattern; + pattern = (pattern << 1) | (pattern >> 7); + } + + pattern = (unsigned char)((generation * 0x2D) ^ 0x7C); + for (i = 0; i < TLS_CIPHER_AES_GCM_256_IV_SIZE; i++) { + key->iv[i] ^= pattern; + pattern = (pattern << 1) | (pattern >> 7); + } + + for (i = 0; i < TLS_CIPHER_AES_GCM_256_SALT_SIZE; i++) + key->salt[i] ^= (unsigned char)(generation & 0xFF); + + memset(key->rec_seq, 0, TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE); +} + +static const char *cipher_name(int cipher) +{ + switch (cipher) { + case 128: return "AES-GCM-128"; + case 256: return "AES-GCM-256"; + default: return "unknown"; + } +} + +static const char *version_name(int version) +{ + switch (version) { + case 12: return "TLS 1.2"; + case 13: return "TLS 1.3"; + default: return "unknown"; + } +} + +static int setup_tls_ulp(int fd) +{ + int ret; + + ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")); + if (ret < 0) { + printf("TCP_ULP failed: %s\n", strerror(errno)); + return -1; + } + return 0; +} + +static int setup_tls_key(int fd, int is_tx, int generation, int cipher) +{ + int ret; + + if (cipher == 256) { + struct tls12_crypto_info_aes_gcm_256 key; + + derive_key_256(&key, generation); + ret = setsockopt(fd, SOL_TLS, is_tx ? TLS_TX : TLS_RX, + &key, sizeof(key)); + } else { + struct tls12_crypto_info_aes_gcm_128 key; + + derive_key_128(&key, generation); + ret = setsockopt(fd, SOL_TLS, is_tx ? TLS_TX : TLS_RX, + &key, sizeof(key)); + } + + if (ret < 0) { + printf("TLS_%s %s (gen %d) failed: %s\n", + is_tx ? "TX" : "RX", cipher_name(cipher), + generation, strerror(errno)); + return -1; + } + + printf("TLS_%s %s gen %d installed\n", + is_tx ? 
"TX" : "RX", cipher_name(cipher), generation); + return 0; +} + +/* Send TLS 1.3 KeyUpdate handshake message */ +static int send_tls_key_update(int fd, int request_update) +{ + char cmsg_buf[CMSG_SPACE(sizeof(unsigned char))]; + unsigned char key_update_msg[5]; + struct msghdr msg = {0}; + struct cmsghdr *cmsg; + struct iovec iov; + + key_update_msg[0] = TLS_HANDSHAKE_KEY_UPDATE; + key_update_msg[1] = 0; + key_update_msg[2] = 0; + key_update_msg[3] = 1; + key_update_msg[4] = request_update ? KEY_UPDATE_REQUESTED + : KEY_UPDATE_NOT_REQUESTED; + + iov.iov_base = key_update_msg; + iov.iov_len = sizeof(key_update_msg); + + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsg_buf; + msg.msg_controllen = sizeof(cmsg_buf); + + cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_level = SOL_TLS; + cmsg->cmsg_type = TLS_SET_RECORD_TYPE; + cmsg->cmsg_len = CMSG_LEN(sizeof(unsigned char)); + *CMSG_DATA(cmsg) = TLS_RECORD_TYPE_HANDSHAKE; + msg.msg_controllen = cmsg->cmsg_len; + + if (sendmsg(fd, &msg, 0) < 0) { + printf("sendmsg KeyUpdate failed: %s\n", strerror(errno)); + return -1; + } + + printf("Sent TLS KeyUpdate handshake message\n"); + return 0; +} + +static int recv_tls_message(int fd, char *buf, size_t buflen, int *record_type) +{ + char cmsg_buf[CMSG_SPACE(sizeof(unsigned char))]; + struct msghdr msg = {0}; + struct cmsghdr *cmsg; + struct iovec iov; + int ret; + + iov.iov_base = buf; + iov.iov_len = buflen; + + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsg_buf; + msg.msg_controllen = sizeof(cmsg_buf); + + ret = recvmsg(fd, &msg, 0); + if (ret <= 0) + return ret; + + *record_type = TLS_RECORD_TYPE_APPLICATION_DATA; /* default */ + + cmsg = CMSG_FIRSTHDR(&msg); + if (cmsg && cmsg->cmsg_level == SOL_TLS && + cmsg->cmsg_type == TLS_GET_RECORD_TYPE) + *record_type = *((unsigned char *)CMSG_DATA(cmsg)); + + return ret; +} + +static int recv_tls_keyupdate(int fd) +{ + int record_type; + char buf[16]; + int ret; + + ret = recv_tls_message(fd, buf, 
sizeof(buf), &record_type);
+	if (ret < 0) {
+		printf("recv_tls_message failed: %s\n", strerror(errno));
+		return -1;
+	}
+	/*
+	 * recv_tls_message() passes a 0 return (peer closed) straight
+	 * through without storing a record type, so bail out here instead
+	 * of reading the unset record_type and buf below.
+	 */
+	if (ret == 0) {
+		printf("Connection closed while waiting for KeyUpdate\n");
+		return -1;
+	}
+
+	if (record_type != TLS_RECORD_TYPE_HANDSHAKE) {
+		printf("Expected handshake record (0x%02x), got 0x%02x\n",
+		       TLS_RECORD_TYPE_HANDSHAKE, record_type);
+		return -1;
+	}
+
+	/* ret is at least 1 here, so buf[0] was written by the receive */
+	if (ret >= 1 && buf[0] == TLS_HANDSHAKE_KEY_UPDATE) {
+		printf("Received TLS KeyUpdate handshake (%d bytes)\n", ret);
+		return 0;
+	}
+
+	printf("Expected KeyUpdate (0x%02x), got 0x%02x\n",
+	       TLS_HANDSHAKE_KEY_UPDATE, (unsigned char)buf[0]);
+	return -1;
+}
+
+/*
+ * Probe the socket with one non-blocking recv() and print which errno
+ * came back. Informational only: nothing is returned to the caller, and
+ * a recv() that succeeds (ret >= 0) is not reported.
+ */
+static void check_ekeyexpired(int fd)
+{
+	char buf[16];
+	int ret;
+
+	ret = recv(fd, buf, sizeof(buf), MSG_DONTWAIT);
+	if (ret == -1 && errno == EKEYEXPIRED)
+		printf("recv() returned EKEYEXPIRED as expected\n");
+	else if (ret == -1 && errno == EAGAIN)
+		printf("recv() returned EAGAIN (no pending data)\n");
+	else if (ret == -1)
+		printf("recv() returned error: %s\n", strerror(errno));
+}
+
+/*
+ * Install the key derived for the given generation on the TX (is_tx != 0)
+ * or RX side with setsockopt(SOL_TLS). Returns 0 on success, -1 on failure.
+ */
+static int do_tls_rekey(int fd, int is_tx, int generation, int cipher)
+{
+	int ret;
+
+	printf("Performing TLS_%s %s rekey to generation %d...\n",
+	       is_tx ? "TX" : "RX", cipher_name(cipher), generation);
+
+	if (cipher == 256) {
+		struct tls12_crypto_info_aes_gcm_256 key;
+
+		derive_key_256(&key, generation);
+		ret = setsockopt(fd, SOL_TLS, is_tx ? TLS_TX : TLS_RX,
+				 &key, sizeof(key));
+	} else {
+		struct tls12_crypto_info_aes_gcm_128 key;
+
+		derive_key_128(&key, generation);
+		ret = setsockopt(fd, SOL_TLS, is_tx ? TLS_TX : TLS_RX,
+				 &key, sizeof(key));
+	}
+
+	if (ret < 0) {
+		printf("TLS_%s %s rekey failed: %s\n", is_tx ? "TX" : "RX",
+		       cipher_name(cipher), strerror(errno));
+		return -1;
+	}
+	printf("TLS_%s %s rekey to gen %d successful!\n",
+	       is_tx ? 
"TX" : "RX", cipher_name(cipher), generation);
+	return 0;
+}
+
+/*
+ * Client side of the test.
+ *
+ * Connects to server_ip, installs the generation-0 TX and RX keys, then
+ * sends TEST_ITERATIONS random payloads and checks that each one is echoed
+ * back unchanged.  With rekeying enabled, a KeyUpdate exchange (send
+ * KeyUpdate, rekey TX, receive KeyUpdate, rekey RX) runs every
+ * rekey_interval iterations until num_rekeys have been done.
+ *
+ * Returns 0 on success, -1 on any failure.
+ */
+static int do_client(void)
+{
+	struct sockaddr_storage sa;
+	char *buf = NULL, *echo_buf = NULL;
+	int max_size, rekey_interval;
+	ssize_t echo_total, echo_n;
+	int csk = -1, ret, i, j;
+	int test_result = 0;
+	int current_gen = 0;
+	int next_rekey_at;
+	socklen_t sa_len;
+	ssize_t n;
+
+	if (!server_ip) {
+		printf("ERROR: Client requires -s option\n");
+		return -1;
+	}
+
+	max_size = random_size_max > 0 ? random_size_max : send_size;
+	buf = malloc(max_size);
+	echo_buf = malloc(max_size);
+	if (!buf || !echo_buf) {
+		printf("failed to allocate buffers\n");
+		test_result = -1;
+		goto out;
+	}
+
+	csk = socket(addr_family, SOCK_STREAM, IPPROTO_TCP);
+	if (csk < 0) {
+		printf("failed to create socket: %s\n", strerror(errno));
+		test_result = -1;
+		goto out;
+	}
+
+	memset(&sa, 0, sizeof(sa));
+	if (addr_family == AF_INET6) {
+		struct sockaddr_in6 *sa6 = (struct sockaddr_in6 *)&sa;
+		/* Room for "<address>%<ifname>", not just a bare address */
+		char addr_buf[INET6_ADDRSTRLEN + 1 + IF_NAMESIZE];
+		unsigned int scope_id = 0;
+		char *scope_sep;
+		int len;
+
+		len = snprintf(addr_buf, sizeof(addr_buf), "%s", server_ip);
+		if (len < 0 || (size_t)len >= sizeof(addr_buf)) {
+			printf("Invalid IPv6 address: %s\n", server_ip);
+			test_result = -1;
+			goto out;
+		}
+		scope_sep = strchr(addr_buf, '%');
+		if (scope_sep) {
+			*scope_sep = '\0';
+			scope_id = if_nametoindex(scope_sep + 1);
+			if (scope_id == 0) {
+				printf("Invalid interface: %s\n", scope_sep + 1);
+				test_result = -1;
+				goto out;
+			}
+		}
+
+		sa6->sin6_family = AF_INET6;
+		if (inet_pton(AF_INET6, addr_buf, &sa6->sin6_addr) != 1) {
+			printf("Invalid IPv6 address: %s\n", addr_buf);
+			test_result = -1;
+			goto out;
+		}
+		sa6->sin6_port = htons(server_port);
+		sa6->sin6_scope_id = scope_id;
+		sa_len = sizeof(*sa6);
+		printf("Connecting to [%s]:%d (scope_id=%u)...\n",
+		       server_ip, server_port, scope_id);
+	} else {
+		struct sockaddr_in *sa4 = (struct sockaddr_in *)&sa;
+
+		sa4->sin_family = AF_INET;
+		/* inet_pton() reports bad input, unlike inet_addr() */
+		if (inet_pton(AF_INET, server_ip, &sa4->sin_addr) != 1) {
+			printf("Invalid IPv4 address: %s\n", server_ip);
+			test_result = -1;
+			goto out;
+		}
+		sa4->sin_port = htons(server_port);
+		sa_len = sizeof(*sa4);
+		printf("Connecting to %s:%d...\n", server_ip, server_port);
+	}
+
+	ret = connect(csk, (struct sockaddr *)&sa, sa_len);
+	if (ret < 0) {
+		printf("connect failed: %s\n", strerror(errno));
+		test_result = -1;
+		goto out;
+	}
+	printf("Connected!\n");
+
+	if (setup_tls_ulp(csk) < 0) {
+		test_result = -1;
+		goto out;
+	}
+
+	if (setup_tls_key(csk, 1, 0, cipher_type) < 0 ||
+	    setup_tls_key(csk, 0, 0, cipher_type) < 0) {
+		test_result = -1;
+		goto out;
+	}
+
+	if (do_rekey)
+		printf("TLS %s setup complete. Will perform %d rekey(s).\n",
+		       cipher_name(cipher_type), num_rekeys);
+	else
+		printf("TLS setup complete.\n");
+
+	if (random_size_max > 0)
+		printf("Sending %d messages of random size (1..%d bytes)...\n",
+		       TEST_ITERATIONS, random_size_max);
+	else
+		printf("Sending %d messages of %d bytes...\n",
+		       TEST_ITERATIONS, send_size);
+
+	/* Spread the rekeys evenly over the run, at least one iteration apart */
+	rekey_interval = TEST_ITERATIONS / (num_rekeys + 1);
+	if (rekey_interval < 1)
+		rekey_interval = 1;
+	next_rekey_at = rekey_interval;
+
+	for (i = 0; i < TEST_ITERATIONS; i++) {
+		int this_size;
+
+		if (random_size_max > 0)
+			this_size = (rand() % random_size_max) + 1;
+		else
+			this_size = send_size;
+
+		for (j = 0; j < this_size; j++)
+			buf[j] = rand() & 0xFF;
+
+		n = send(csk, buf, this_size, 0);
+		if (n < 0) {
+			printf("FAIL: send failed: %s\n", strerror(errno));
+			test_result = -1;
+			break;
+		}
+		if (n != this_size) {
+			/* errno is not meaningful for a short write */
+			printf("FAIL: short send: %zd of %d bytes\n",
+			       n, this_size);
+			test_result = -1;
+			break;
+		}
+		printf("Sent %zd bytes (iteration %d)\n", n, i + 1);
+
+		/* The echo may arrive in several pieces; collect all of it */
+		echo_total = 0;
+		while (echo_total < n) {
+			echo_n = recv(csk, echo_buf + echo_total,
+				      n - echo_total, 0);
+			if (echo_n < 0) {
+				printf("FAIL: Echo recv failed: %s\n",
+				       strerror(errno));
+				test_result = -1;
+				break;
+			}
+			if (echo_n == 0) {
+				printf("FAIL: Connection closed during echo\n");
+				test_result = -1;
+				break;
+			}
+			echo_total += echo_n;
+		}
+		if (test_result != 0)
+			break;
+
+		if (memcmp(buf, echo_buf, n) != 0) {
+			printf("FAIL: Echo data mismatch!\n");
+			test_result = -1;
+			break;
+		}
+		printf("Received echo %zd bytes (ok)\n", echo_total);
+
+		/* Rekey at intervals: send KeyUpdate, update TX, recv KeyUpdate, update RX */
+		if (do_rekey && rekeys_done < num_rekeys &&
+		    (i + 1) == next_rekey_at) {
+			current_gen++;
+			printf("\n=== Client Rekey #%d (gen %d) ===\n",
+			       rekeys_done + 1, current_gen);
+
+			ret = send_tls_key_update(csk, KEY_UPDATE_REQUESTED);
+			if (ret < 0) {
+				printf("FAIL: send KeyUpdate\n");
+				test_result = -1;
+				break;
+			}
+
+			ret = do_tls_rekey(csk, 1, current_gen, cipher_type);
+			if (ret < 0) {
+				test_result = -1;
+				break;
+			}
+
+			if (recv_tls_keyupdate(csk) < 0) {
+				printf("FAIL: recv KeyUpdate from server\n");
+				test_result = -1;
+				break;
+			}
+
+			check_ekeyexpired(csk);
+
+			ret = do_tls_rekey(csk, 0, current_gen, cipher_type);
+			if (ret < 0) {
+				test_result = -1;
+				break;
+			}
+
+			rekeys_done++;
+			next_rekey_at += rekey_interval;
+			printf("=== Client Rekey #%d Complete ===\n\n",
+			       rekeys_done);
+		}
+	}
+
+	if (i < TEST_ITERATIONS && test_result == 0) {
+		printf("FAIL: Only %d of %d iterations\n", i, TEST_ITERATIONS);
+		test_result = -1;
+	}
+
+	close(csk);
+	csk = -1;
+	if (do_rekey)
+		printf("Rekeys completed: %d/%d\n", rekeys_done, num_rekeys);
+
+out:
+	if (csk >= 0)
+		close(csk);
+	free(buf);
+	free(echo_buf);
+	return test_result;
+}
+
+/*
+ * Server side of the test: accept one client, echo application data back
+ * and answer each KeyUpdate by rekeying.  Returns 0 on success, -1 on
+ * failure.
+ */
+static int do_server(void)
+{
+	struct sockaddr_storage sa;
+	int lsk = -1, csk = -1, ret;
+	ssize_t n, total = 0, sent;
+	int current_gen = 0;
+	int test_result = 0;
+	int recv_count = 0;
+	char *buf = NULL;
+	int record_type;
+	socklen_t sa_len;
+	int max_size;
+	int one = 1;
+
+	max_size = random_size_max > 0 ? 
random_size_max : send_size;
+	buf = malloc(max_size);
+	if (!buf) {
+		printf("failed to allocate buffer\n");
+		test_result = -1;
+		goto out;
+	}
+
+	lsk = socket(addr_family, SOCK_STREAM, IPPROTO_TCP);
+	if (lsk < 0) {
+		printf("failed to create socket: %s\n", strerror(errno));
+		test_result = -1;
+		goto out;
+	}
+
+	/* Best effort: a failure here only matters if bind() then fails */
+	setsockopt(lsk, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));
+
+	memset(&sa, 0, sizeof(sa));
+	if (addr_family == AF_INET6) {
+		struct sockaddr_in6 *sa6 = (struct sockaddr_in6 *)&sa;
+
+		sa6->sin6_family = AF_INET6;
+		sa6->sin6_addr = in6addr_any;
+		sa6->sin6_port = htons(server_port);
+		sa_len = sizeof(*sa6);
+	} else {
+		struct sockaddr_in *sa4 = (struct sockaddr_in *)&sa;
+
+		sa4->sin_family = AF_INET;
+		sa4->sin_addr.s_addr = INADDR_ANY;
+		sa4->sin_port = htons(server_port);
+		sa_len = sizeof(*sa4);
+	}
+
+	ret = bind(lsk, (struct sockaddr *)&sa, sa_len);
+	if (ret < 0) {
+		printf("bind failed: %s\n", strerror(errno));
+		test_result = -1;
+		goto out;
+	}
+
+	ret = listen(lsk, 5);
+	if (ret < 0) {
+		printf("listen failed: %s\n", strerror(errno));
+		test_result = -1;
+		goto out;
+	}
+
+	if (addr_family == AF_INET6)
+		printf("Server listening on [::]:%d (IPv6)\n", server_port);
+	else
+		printf("Server listening on 0.0.0.0:%d (IPv4)\n", server_port);
+	printf("Waiting for client connection...\n");
+
+	csk = accept(lsk, (struct sockaddr *)NULL, (socklen_t *)NULL);
+	if (csk < 0) {
+		printf("accept failed: %s\n", strerror(errno));
+		test_result = -1;
+		goto out;
+	}
+	printf("Client connected!\n");
+
+	if (setup_tls_ulp(csk) < 0) {
+		test_result = -1;
+		goto out;
+	}
+
+	if (setup_tls_key(csk, 1, 0, cipher_type) < 0 ||
+	    setup_tls_key(csk, 0, 0, cipher_type) < 0) {
+		test_result = -1;
+		goto out;
+	}
+
+	printf("TLS %s setup complete. Receiving...\n",
+	       cipher_name(cipher_type));
+
+	/* Main receive loop - detect KeyUpdate via MSG_PEEK + recvmsg */
+	while (1) {
+		n = recv(csk, buf, max_size, MSG_PEEK | MSG_DONTWAIT);
+		if (n == 0) {
+			printf("Connection closed by client\n");
+			break;
+		}
+		/*
+		 * Data is pending, or the peek failed with one of the errors
+		 * the test tolerates: do the real receive through recvmsg()
+		 * to learn the record type.  Any other peek error is
+		 * reported below.
+		 */
+		if (n > 0 ||
+		    errno == EIO || errno == ENOMSG || errno == EAGAIN)
+			n = recv_tls_message(csk, buf, max_size, &record_type);
+
+		if (n < 0) {
+			printf("recv failed: %s\n", strerror(errno));
+			test_result = -1;
+			break;
+		}
+		if (n == 0)
+			break;
+
+		/* Handle KeyUpdate: update RX, send response, update TX */
+		if (record_type == TLS_RECORD_TYPE_HANDSHAKE &&
+		    n >= 1 && buf[0] == TLS_HANDSHAKE_KEY_UPDATE) {
+			current_gen++;
+			printf("\n=== Server Rekey #%d (gen %d) ===\n",
+			       rekeys_done + 1, current_gen);
+			printf("Received KeyUpdate from client (%zd bytes)\n",
+			       n);
+
+			check_ekeyexpired(csk);
+
+			ret = do_tls_rekey(csk, 0, current_gen, cipher_type);
+			if (ret < 0) {
+				test_result = -1;
+				break;
+			}
+
+			ret = send_tls_key_update(csk,
+						  KEY_UPDATE_NOT_REQUESTED);
+			if (ret < 0) {
+				printf("Failed to send KeyUpdate\n");
+				test_result = -1;
+				break;
+			}
+
+			ret = do_tls_rekey(csk, 1, current_gen, cipher_type);
+			if (ret < 0) {
+				test_result = -1;
+				break;
+			}
+
+			rekeys_done++;
+			printf("=== Server Rekey #%d Complete ===\n\n",
+			       rekeys_done);
+			continue;
+		}
+
+		total += n;
+		recv_count++;
+		printf("Received %zd bytes (total: %zd, count: %d)\n",
+		       n, total, recv_count);
+
+		/* Echo everything back; the client waits for all n bytes */
+		sent = 0;
+		while (sent < n) {
+			ssize_t chunk = send(csk, buf + sent, n - sent, 0);
+
+			if (chunk <= 0) {
+				printf("Echo send failed: %s\n",
+				       strerror(errno));
+				test_result = -1;
+				break;
+			}
+			sent += chunk;
+			if (sent != n)
+				printf("Echo partial: %zd of %zd bytes\n",
+				       sent, n);
+		}
+		if (sent < n)
+			break;
+		printf("Echoed %zd bytes back to client\n", sent);
+	}
+
+	printf("Connection closed. Total received: %zd bytes\n", total);
+	if (do_rekey)
+		printf("Rekeys completed: %d\n", rekeys_done);
+
+out:
+	if (csk >= 0)
+		close(csk);
+	if (lsk >= 0)
+		close(lsk);
+	free(buf);
+	return test_result;
+}
+
+/*
+ * Handle a "--rekey" or "--rekey=N" argument; anything else is ignored.
+ *
+ * Plain "--rekey" requests one rekey.  For "--rekey=N", N below 1 falls
+ * back to 1 and N above MAX_REKEYS is clamped to MAX_REKEYS, each with a
+ * warning.  Sets do_rekey and num_rekeys.
+ */
+static void parse_rekey_option(const char *arg)
+{
+	int requested;
+
+	if (strncmp(arg, "--rekey=", 8) == 0) {
+		requested = atoi(arg + 8);
+		if (requested < 1) {
+			printf("WARNING: Invalid rekey count, using 1\n");
+			num_rekeys = 1;
+		} else if (requested > MAX_REKEYS) {
+			printf("WARNING: Rekey count %d > max %d, using %d\n",
+			       requested, MAX_REKEYS, MAX_REKEYS);
+			num_rekeys = MAX_REKEYS;
+		} else {
+			num_rekeys = requested;
+		}
+		do_rekey = 1;
+	} else if (strcmp(arg, "--rekey") == 0) {
+		do_rekey = 1;
+		num_rekeys = 1;
+	}
+}
+
+/*
+ * Parse the value of -c into cipher_type.
+ * Returns 0 for "128" or "256", -1 (after printing an error) otherwise.
+ */
+static int parse_cipher_option(const char *arg)
+{
+	if (strcmp(arg, "128") == 0) {
+		cipher_type = 128;
+		return 0;
+	} else if (strcmp(arg, "256") == 0) {
+		cipher_type = 256;
+		return 0;
+	}
+	printf("ERROR: Invalid cipher '%s'. Must be 128 or 256.\n", arg);
+	return -1;
+}
+
+static int parse_version_option(const char *arg)
+{
+	if (strcmp(arg, "1.2") == 0) {
+		tls_version = 12;
+		return 0;
+	} else if (strcmp(arg, "1.3") == 0) {
+		tls_version = 13;
+		return 0;
+	}
+	printf("ERROR: Invalid TLS version '%s'. 
Must be 1.2 or 1.3.\n", arg);
+	return -1;
+}
+
+/* Print the command-line help for @prog to stdout */
+static void print_usage(const char *prog)
+{
+	printf("TLS Hardware Offload Two-Node Test\n\n");
+	printf("Usage:\n");
+	printf("  %s server [OPTIONS]\n", prog);
+	printf("  %s client -s <ip> [OPTIONS]\n", prog);
+	printf("\nOptions:\n");
+	printf("  -s <ip>       Server IP to connect (client, required)\n");
+	printf("                Supports both IPv4 and IPv6 addresses\n");
+	printf("  -6            Use IPv6 (server only, default: IPv4)\n");
+	printf("  -p <port>     Server port (default: 4433)\n");
+	printf("  -b <size>     Send buffer (record) size (default: 16384)\n");
+	printf("  -r <max>      Use random send buffer sizes (1..<max>)\n");
+	printf("  -v <version>  TLS version: 1.2 or 1.3 (default: 1.3)\n");
+	printf("  -c <cipher>   Cipher: 128 or 256 (default: 128)\n");
+	printf("  --rekey[=N]   Enable rekey (default: 1, TLS 1.3 only)\n");
+	printf("  -h, --help    Show this help message\n");
+	printf("\nExample (IPv4):\n");
+	printf("  Node A: %s server\n", prog);
+	printf("  Node B: %s client -s 192.168.20.2\n", prog);
+	printf("\nExample (IPv6):\n");
+	printf("  Node A: %s server -6\n", prog);
+	printf("  Node B: %s client -s 2001:db8::1\n", prog);
+	printf("\nRekey Example (3 rekeys, TLS 1.3 only):\n");
+	printf("  Node A: %s server --rekey=3\n", prog);
+	printf("  Node B: %s client -s 192.168.20.2 --rekey=3\n", prog);
+}
+
+/*
+ * Entry point: "<prog> server|client [OPTIONS]".
+ *
+ * --help/-h anywhere on the command line prints the usage and exits 0.
+ * The mode in argv[1] is validated before any option is parsed, so a bad
+ * invocation prints only the usage.  Options are then applied to the
+ * file-scope settings, the effective configuration is printed, and
+ * do_client() or do_server() is run; its result is the return value.
+ */
+int main(int argc, char *argv[])
+{
+	int i;
+
+	for (i = 1; i < argc; i++) {
+		if (strcmp(argv[i], "--help") == 0 ||
+		    strcmp(argv[i], "-h") == 0) {
+			print_usage(argv[0]);
+			return 0;
+		}
+	}
+
+	if (argc < 2 ||
+	    (strcmp(argv[1], "server") && strcmp(argv[1], "client"))) {
+		print_usage(argv[0]);
+		return -1;
+	}
+
+	/* argv[1] is the mode; options start at argv[2] */
+	for (i = 2; i < argc; i++) {
+		parse_rekey_option(argv[i]);
+		if (strcmp(argv[i], "-s") == 0 && i + 1 < argc) {
+			server_ip = argv[i + 1];
+			addr_family = detect_addr_family(server_ip);
+			if (addr_family < 0) {
+				printf("ERROR: Invalid IP address '%s'\n",
+				       server_ip);
+				return -1;
+			}
+		}
+		if (strcmp(argv[i], "-p") == 0 && i + 1 < argc) {
+			server_port = atoi(argv[i + 1]);
+			/* htons() would silently truncate anything larger */
+			if (server_port < 1 || server_port > 65535) {
+				printf("ERROR: Invalid port '%s'\n",
+				       argv[i + 1]);
+				return -1;
+			}
+		}
+		if (strcmp(argv[i], "-6") == 0)
+			addr_family = AF_INET6;
+		if (strcmp(argv[i], "-b") == 0 && i + 1 < argc) {
+			send_size = atoi(argv[i + 1]);
+			if (send_size < 1)
+				send_size = 1;
+		}
+		if (strcmp(argv[i], "-r") == 0 && i + 1 < argc) {
+			random_size_max = atoi(argv[i + 1]);
+			if (random_size_max < 1)
+				random_size_max = 1;
+		}
+		if (strcmp(argv[i], "-c") == 0 && i + 1 < argc) {
+			if (parse_cipher_option(argv[i + 1]) < 0)
+				return -1;
+		}
+		if (strcmp(argv[i], "-v") == 0 && i + 1 < argc) {
+			if (parse_version_option(argv[i + 1]) < 0)
+				return -1;
+		}
+	}
+
+	if (tls_version == 12 && do_rekey) {
+		printf("WARNING: TLS 1.2 does not support rekey\n");
+		do_rekey = 0;
+	}
+
+	printf("Address Family: %s\n", addr_family == AF_INET6 ? "IPv6" : "IPv4");
+	printf("TLS Version: %s\n", version_name(tls_version));
+	printf("Cipher: %s\n", cipher_name(cipher_type));
+	if (random_size_max > 0)
+		printf("Buffer size: random (1..%d)\n", random_size_max);
+	else
+		printf("Buffer size: %d\n", send_size);
+
+	if (do_rekey)
+		printf("Rekey testing ENABLED: %d rekey(s)\n", num_rekeys);
+
+	srand(time(NULL));
+
+	if (!strcmp(argv[1], "client"))
+		return do_client();
+
+	return do_server();
+}
diff --git a/tools/testing/selftests/drivers/net/hw/tls_hw_offload.py b/tools/testing/selftests/drivers/net/hw/tls_hw_offload.py
new file mode 100755
index 000000000000..5d14cb7d2e3c
--- /dev/null
+++ b/tools/testing/selftests/drivers/net/hw/tls_hw_offload.py
@@ -0,0 +1,281 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: GPL-2.0
+
+"""Test kTLS hardware offload using a C helper binary."""
+
+from lib.py import ksft_run, ksft_exit, ksft_pr, KsftSkipEx, ksft_true
+from lib.py import NetDrvEpEnv
+from lib.py import cmd, bkg, wait_port_listen, rand_port
+import time
+
+
+def check_tls_support(cfg):
+    """Skip the test unless /proc/net/tls_stat exists locally and remotely."""
+    try:
+        cmd("test -f /proc/net/tls_stat")
+        cmd("test -f /proc/net/tls_stat", host=cfg.remote)
+    except Exception as e:
+        raise KsftSkipEx(f"kTLS not supported: {e}") from e
+
+
+def read_tls_stats():
+    stats = {}
+    output = 
cmd("cat /proc/net/tls_stat")
+    for line in output.stdout.strip().split('\n'):
+        parts = line.split()
+        if len(parts) == 2:
+            stats[parts[0]] = int(parts[1])
+    return stats
+
+
+def verify_tls_counters(stats_before, stats_after, expected_rekeys, is_server):
+    """Log TLS counter deltas and check them against expectations.
+
+    Counters missing from a snapshot count as 0.  A direction passes the
+    path check when its Device or Sw counter grew.  With expected_rekeys > 0
+    the Tx/Rx RekeyOk deltas (and, on the server, TlsRxRekeyReceived) must
+    reach expected_rekeys and the RekeyError counters must not grow.  Any
+    growth of TlsDecryptError is an error.
+
+    Returns True when no check failed.
+    """
+    def delta(name):
+        return stats_after.get(name, 0) - stats_before.get(name, 0)
+
+    def report(name):
+        # Log "before -> after" for one counter and hand back the delta
+        diff = delta(name)
+        ksft_pr(f"{name}: {stats_before.get(name, 0)} -> "
+                f"{stats_after.get(name, 0)} (diff: {diff})")
+        return diff
+
+    errors = 0
+
+    role = 'Server' if is_server else 'Client'
+    ksft_pr(f"=== Counter Verification ({role}) ===")
+
+    for direction in ("Tx", "Rx"):
+        label = direction.upper()
+        hw_diff = report(f"Tls{direction}Device")
+        sw_diff = report(f"Tls{direction}Sw")
+
+        if hw_diff >= 1:
+            ksft_pr(f"{label} Path: HARDWARE OFFLOAD")
+        elif sw_diff >= 1:
+            ksft_pr(f"{label} Path: SOFTWARE")
+        else:
+            ksft_pr(f"{label} Path: FAIL (no TLS {label} activity detected)")
+            errors += 1
+
+    if expected_rekeys > 0:
+        for direction in ("Tx", "Rx"):
+            if report(f"Tls{direction}RekeyOk") < expected_rekeys:
+                ksft_pr(f"FAIL: Expected >= {expected_rekeys} "
+                        f"{direction.upper()} rekeys")
+                errors += 1
+
+        if is_server:
+            if report("TlsRxRekeyReceived") < expected_rekeys:
+                ksft_pr(f"FAIL: Expected >= {expected_rekeys} "
+                        f"KeyUpdate messages")
+                errors += 1
+
+        for name in ("TlsTxRekeyError", "TlsRxRekeyError"):
+            err_diff = delta(name)
+            if err_diff > 0:
+                ksft_pr(f"ERROR: {name} increased by {err_diff}")
+                errors += 1
+
+    decrypt_err_diff = delta("TlsDecryptError")
+    if decrypt_err_diff > 0:
+        ksft_pr(f"ERROR: TlsDecryptError increased by {decrypt_err_diff}")
+        errors += 1
+
+    ksft_pr(f"=== Verification {'PASSED' if errors == 0 else 'FAILED'} ===\n")
+    return errors == 0
+
+
+def run_tls_test(cfg, cipher="128", tls_version="1.3", rekey=0, buffer_size=None, random_max=None):
+    """Run the helper as remote server and local client, then check counters.
+
+    random_max takes precedence over buffer_size for both the command lines
+    and the log description.  /proc/net/tls_stat is sampled on both hosts
+    before and after the run and handed to verify_tls_counters().
+    """
+    port = rand_port()
+
+    # Options shared by both ends, in the order the helper is invoked with
+    common = f"-p {port} -c {cipher} -v {tls_version}"
+    extra = ""
+    test_desc = f"cipher={cipher}, version={tls_version}, rekey={rekey}"
+    if rekey > 0:
+        extra += f" --rekey={rekey}"
+    if random_max:
+        extra += f" -r {random_max}"
+        test_desc += f", random_size=1-{random_max}"
+    elif buffer_size:
+        extra += f" -b {buffer_size}"
+        test_desc += f", buffer={buffer_size}"
+
+    server_cmd = f"{cfg.bin_remote} server {common}{extra}"
+    client_cmd = (f"{cfg.bin_local} client -s {cfg.remote_addr_v['4']} "
+                  f"{common}{extra}")
+
+    ksft_pr(f"Starting TLS test: {test_desc}")
+
+    stats_before_local = read_tls_stats()
+    stats_before_remote = read_tls_stats_remote(cfg)
+
+    with bkg(server_cmd, host=cfg.remote, exit_wait=True):
+        wait_port_listen(port, host=cfg.remote)
+        time.sleep(0.5)
+
+        ksft_pr("Running client...")
+        result = cmd(client_cmd, fail=False)
+        time.sleep(1)
+
+    stats_after_local = read_tls_stats()
+    stats_after_remote = read_tls_stats_remote(cfg)
+
+    ksft_pr("\n=== Client Side Verification ===")
+    client_ok = verify_tls_counters(stats_before_local, stats_after_local, rekey, False)
+
+    ksft_pr("\n=== Server Side Verification ===")
+    server_ok = verify_tls_counters(stats_before_remote, stats_after_remote, rekey, True)
+
+    ksft_true(result.ret == 0, "Client completed successfully")
+    ksft_true(client_ok, "Client TLS counters verified")
+    ksft_true(server_ok, "Server TLS counters verified")
+
+
+def read_tls_stats_remote(cfg):
+    stats = {}
+    output = cmd("cat /proc/net/tls_stat", host=cfg.remote)
+    for line in output.stdout.strip().split('\n'):
+        parts = 
line.split()
+        if len(parts) == 2:
+            stats[parts[0]] = int(parts[1])
+    return stats
+
+
+def test_tls_offload_basic(cfg):
+    """TLS 1.3 with the 128 cipher option, no rekey."""
+    cfg.require_ipver("4")
+    check_tls_support(cfg)
+    run_tls_test(cfg, cipher="128", tls_version="1.3", rekey=0)
+
+
+def test_tls_offload_aes256(cfg):
+    """TLS 1.3 with the 256 cipher option, no rekey."""
+    cfg.require_ipver("4")
+    check_tls_support(cfg)
+    run_tls_test(cfg, cipher="256", tls_version="1.3", rekey=0)
+
+
+def test_tls_offload_tls12(cfg):
+    """TLS 1.2 with the 128 cipher option, no rekey."""
+    cfg.require_ipver("4")
+    check_tls_support(cfg)
+    run_tls_test(cfg, cipher="128", tls_version="1.2", rekey=0)
+
+
+def test_tls_offload_tls12_aes256(cfg):
+    """TLS 1.2 with the 256 cipher option, no rekey."""
+    cfg.require_ipver("4")
+    check_tls_support(cfg)
+    run_tls_test(cfg, cipher="256", tls_version="1.2", rekey=0)
+
+
+def test_tls_offload_rekey(cfg):
+    """TLS 1.3 with a single rekey."""
+    cfg.require_ipver("4")
+    check_tls_support(cfg)
+    run_tls_test(cfg, cipher="128", tls_version="1.3", rekey=1)
+
+
+def test_tls_offload_rekey_multiple(cfg):
+    """TLS 1.3 with 99 rekeys requested."""
+    # NOTE(review): the C helper clamps --rekey=N to MAX_REKEYS, whose value
+    # is not visible here, while the counter check expects all 99 - confirm
+    # that 99 does not exceed it.
+    cfg.require_ipver("4")
+    check_tls_support(cfg)
+    run_tls_test(cfg, cipher="128", tls_version="1.3", rekey=99)
+
+
+def test_tls_offload_small_records(cfg):
+    """TLS 1.3 with 30 rekeys and a 512-byte send buffer."""
+    cfg.require_ipver("4")
+    check_tls_support(cfg)
+    run_tls_test(cfg, cipher="128", tls_version="1.3", rekey=30, buffer_size=512)
+
+
+def test_tls_offload_large_records(cfg):
+    """TLS 1.3 with 10 rekeys and a 2097152-byte send buffer."""
+    cfg.require_ipver("4")
+    check_tls_support(cfg)
+    run_tls_test(cfg, cipher="128", tls_version="1.3", rekey=10, buffer_size=2097152)
+
+
+def test_tls_offload_random_sizes(cfg):
+    """TLS 1.3 with 20 rekeys and random send sizes up to 8192 bytes."""
+    cfg.require_ipver("4")
+    check_tls_support(cfg)
+    run_tls_test(cfg, cipher="128", tls_version="1.3", rekey=20, random_max=8192)
+
+
+def main() -> None:
+    """Deploy the helper binary to the remote host and run every test case."""
+    with NetDrvEpEnv(__file__, nsim_test=False) as cfg:
+        # The helper is looked up next to this script; skip if it is missing
+        cfg.bin_local = cfg.test_dir / "tls_hw_offload"
+        if not cfg.bin_local.exists():
+            raise KsftSkipEx(f"tls_hw_offload binary not found at {cfg.bin_local}")
+        cfg.bin_remote = cfg.remote.deploy(cfg.bin_local)
+
+        ksft_run([
+            test_tls_offload_basic,
+            test_tls_offload_aes256,
+            test_tls_offload_tls12,
+            test_tls_offload_tls12_aes256,
+            test_tls_offload_rekey,
+            test_tls_offload_rekey_multiple,
+            test_tls_offload_small_records,
+            test_tls_offload_large_records,
+            test_tls_offload_random_sizes,
+        ], args=(cfg, ))
+    ksft_exit()
+
+
+if __name__ == "__main__":
+    main()
-- 
2.25.1