Add TLS 1.3 support to the kernel TLS hardware offload infrastructure, enabling hardware acceleration for TLS 1.3 connections on capable NICs. Tested on Mellanox ConnectX-6 Dx (Crypto Enabled) with TLS 1.3 AES-GCM-128 and AES-GCM-256 cipher suites. Signed-off-by: Rishikesh Jethwani --- net/tls/tls_device.c | 65 ++++++++++++++++----------- net/tls/tls_device_fallback.c | 58 +++++++++++++----------- net/tls/tls_main.c | 85 ++++++++++++++++++++--------------- 3 files changed, 121 insertions(+), 87 deletions(-) diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index 99c8eff9783e..1321bf9b59b0 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -317,25 +317,34 @@ static void tls_device_record_close(struct sock *sk, unsigned char record_type) { struct tls_prot_info *prot = &ctx->prot_info; - struct page_frag dummy_tag_frag; - - /* append tag - * device will fill in the tag, we just need to append a placeholder - * use socket memory to improve coalescing (re-using a single buffer - * increases frag count) - * if we can't allocate memory now use the dummy page + int tail = prot->tag_size + prot->tail_size; + + /* Append tail: tag for TLS 1.2, content_type + tag for TLS 1.3. + * Device fills in the tag, we just need to append a placeholder. + * Use socket memory to improve coalescing (re-using a single buffer + * increases frag count); if allocation fails use dummy_page + * (offset = record_type gives correct content_type byte via + * identity mapping) */ - if (unlikely(pfrag->size - pfrag->offset < prot->tag_size) && - !skb_page_frag_refill(prot->tag_size, pfrag, sk->sk_allocation)) { - dummy_tag_frag.page = dummy_page; - dummy_tag_frag.offset = 0; - pfrag = &dummy_tag_frag; + if (unlikely(pfrag->size - pfrag->offset < tail) && + !skb_page_frag_refill(tail, pfrag, sk->sk_allocation)) { + struct page_frag dummy_pfrag = { + .page = dummy_page, + .offset = record_type, + }; + tls_append_frag(record, &dummy_pfrag, tail); + } else { + if (prot->tail_size) { + char *content_type_addr = page_address(pfrag->page) + + pfrag->offset; + *content_type_addr = record_type; + } + tls_append_frag(record, pfrag, tail); } - tls_append_frag(record, pfrag, prot->tag_size); /* fill prepend */ tls_fill_prepend(ctx, skb_frag_address(&record->frags[0]), - record->len - prot->overhead_size, + record->len - prot->overhead_size + prot->tail_size, record_type); } @@ -883,6 +892,7 @@ static int tls_device_reencrypt(struct sock *sk, struct tls_context *tls_ctx) { struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; const struct tls_cipher_desc *cipher_desc; int err, offset, copy, data_len, pos; struct sk_buff *skb, *skb_iter; @@ -894,7 +904,7 @@ tls_device_reencrypt(struct sock *sk, struct tls_context *tls_ctx) DEBUG_NET_WARN_ON_ONCE(!cipher_desc || !cipher_desc->offloadable); rxm = strp_msg(tls_strp_msg(sw_ctx)); - orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE + cipher_desc->iv, + orig_buf = kmalloc(rxm->full_len + prot->prepend_size, sk->sk_allocation); if (!orig_buf) return -ENOMEM; @@ -909,9 +919,8 @@ tls_device_reencrypt(struct sock *sk, struct tls_context *tls_ctx) offset = rxm->offset; sg_init_table(sg, 1); - sg_set_buf(&sg[0], buf, - rxm->full_len + TLS_HEADER_SIZE + cipher_desc->iv); - err = skb_copy_bits(skb, offset, buf, TLS_HEADER_SIZE + cipher_desc->iv); + sg_set_buf(&sg[0], buf, rxm->full_len + prot->prepend_size); + err = skb_copy_bits(skb, offset, buf, prot->prepend_size); if (err) goto free_buf; @@ -1089,11 +1098,6 @@ int tls_set_device_offload(struct sock *sk) } crypto_info = &ctx->crypto_send.info; - if (crypto_info->version != TLS_1_2_VERSION) { - rc = -EOPNOTSUPP; - goto release_netdev; - } - cipher_desc = get_cipher_desc(crypto_info->cipher_type); if (!cipher_desc || !cipher_desc->offloadable) { rc = -EINVAL; @@ -1196,9 +1200,6 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx) struct net_device *netdev; int rc = 0; - if (ctx->crypto_recv.info.version != TLS_1_2_VERSION) - return -EOPNOTSUPP; - netdev = get_netdev_for_sock(sk); if (!netdev) { pr_err_ratelimited("%s: netdev not found\n", __func__); @@ -1409,12 +1410,22 @@ static struct notifier_block tls_dev_notifier = { int __init tls_device_init(void) { - int err; + unsigned char *page_addr; + int err, i; dummy_page = alloc_page(GFP_KERNEL); if (!dummy_page) return -ENOMEM; + /* Pre-populate dummy_page with identity mapping for all byte values. + * This is used as fallback for TLS 1.3 content type when memory + * allocation fails. By populating all 256 values, we avoid needing + * to validate record_type at runtime. + */ + page_addr = page_address(dummy_page); + for (i = 0; i < 256; i++) + page_addr[i] = (unsigned char)i; + destruct_wq = alloc_workqueue("ktls_device_destruct", WQ_PERCPU, 0); if (!destruct_wq) { err = -ENOMEM; diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c index 3b7d0ab2bcf1..1110f7ac6bcb 100644 --- a/net/tls/tls_device_fallback.c +++ b/net/tls/tls_device_fallback.c @@ -37,14 +37,15 @@ #include "tls.h" -static int tls_enc_record(struct aead_request *aead_req, +static int tls_enc_record(struct tls_context *tls_ctx, + struct aead_request *aead_req, struct crypto_aead *aead, char *aad, char *iv, __be64 rcd_sn, struct scatter_walk *in, - struct scatter_walk *out, int *in_len, - struct tls_prot_info *prot) + struct scatter_walk *out, int *in_len) { unsigned char buf[TLS_HEADER_SIZE + TLS_MAX_IV_SIZE]; + struct tls_prot_info *prot = &tls_ctx->prot_info; const struct tls_cipher_desc *cipher_desc; struct scatterlist sg_in[3]; struct scatterlist sg_out[3]; @@ -55,7 +56,7 @@ static int tls_enc_record(struct aead_request *aead_req, cipher_desc = get_cipher_desc(prot->cipher_type); DEBUG_NET_WARN_ON_ONCE(!cipher_desc || !cipher_desc->offloadable); - buf_size = TLS_HEADER_SIZE + cipher_desc->iv; + buf_size = prot->prepend_size; len = min_t(int, *in_len, buf_size); memcpy_from_scatterwalk(buf, in, len); @@ -66,16 +67,27 @@ static int tls_enc_record(struct aead_request *aead_req, return 0; len = buf[4] | (buf[3] << 8); - len -= cipher_desc->iv; + if (prot->version != TLS_1_3_VERSION) + len -= cipher_desc->iv; tls_make_aad(aad, len - cipher_desc->tag, (char *)&rcd_sn, buf[0], prot); - memcpy(iv + cipher_desc->salt, buf + TLS_HEADER_SIZE, cipher_desc->iv); + if (prot->version == TLS_1_3_VERSION) { + void *iv_src = crypto_info_iv(&tls_ctx->crypto_send.info, + cipher_desc); + + memcpy(iv + cipher_desc->salt, iv_src, cipher_desc->iv); + } else { + memcpy(iv + cipher_desc->salt, buf + TLS_HEADER_SIZE, + cipher_desc->iv); + } + + tls_xor_iv_with_seq(prot, iv, (char *)&rcd_sn); sg_init_table(sg_in, ARRAY_SIZE(sg_in)); sg_init_table(sg_out, ARRAY_SIZE(sg_out)); - sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE); - sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE); + sg_set_buf(sg_in, aad, prot->aad_size); + sg_set_buf(sg_out, aad, prot->aad_size); scatterwalk_get_sglist(in, sg_in + 1); scatterwalk_get_sglist(out, sg_out + 1); @@ -108,13 +120,6 @@ static int tls_enc_record(struct aead_request *aead_req, return rc; } -static void tls_init_aead_request(struct aead_request *aead_req, - struct crypto_aead *aead) -{ - aead_request_set_tfm(aead_req, aead); - aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); -} - static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead, gfp_t flags) { @@ -124,14 +129,15 @@ static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead, aead_req = kzalloc(req_size, flags); if (aead_req) - tls_init_aead_request(aead_req, aead); + aead_request_set_tfm(aead_req, aead); return aead_req; } -static int tls_enc_records(struct aead_request *aead_req, +static int tls_enc_records(struct tls_context *tls_ctx, + struct aead_request *aead_req, struct crypto_aead *aead, struct scatterlist *sg_in, struct scatterlist *sg_out, char *aad, char *iv, - u64 rcd_sn, int len, struct tls_prot_info *prot) + u64 rcd_sn, int len) { struct scatter_walk out, in; int rc; @@ -140,8 +146,8 @@ static int tls_enc_records(struct aead_request *aead_req, scatterwalk_start(&out, sg_out); do { - rc = tls_enc_record(aead_req, aead, aad, iv, - cpu_to_be64(rcd_sn), &in, &out, &len, prot); + rc = tls_enc_record(tls_ctx, aead_req, aead, aad, iv, + cpu_to_be64(rcd_sn), &in, &out, &len); rcd_sn++; } while (rc == 0 && len); @@ -314,7 +320,10 @@ static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx, cipher_desc = get_cipher_desc(tls_ctx->crypto_send.info.cipher_type); DEBUG_NET_WARN_ON_ONCE(!cipher_desc || !cipher_desc->offloadable); - buf_len = cipher_desc->salt + cipher_desc->iv + TLS_AAD_SPACE_SIZE + + aead_request_set_ad(aead_req, tls_ctx->prot_info.aad_size); + + buf_len = cipher_desc->salt + cipher_desc->iv + + tls_ctx->prot_info.aad_size + sync_size + cipher_desc->tag; buf = kmalloc(buf_len, GFP_ATOMIC); if (!buf) @@ -324,7 +333,7 @@ static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx, salt = crypto_info_salt(&tls_ctx->crypto_send.info, cipher_desc); memcpy(iv, salt, cipher_desc->salt); aad = buf + cipher_desc->salt + cipher_desc->iv; - dummy_buf = aad + TLS_AAD_SPACE_SIZE; + dummy_buf = aad + tls_ctx->prot_info.aad_size; nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC); if (!nskb) @@ -335,9 +344,8 @@ static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx, fill_sg_out(sg_out, buf, tls_ctx, nskb, tcp_payload_offset, payload_len, sync_size, dummy_buf); - if (tls_enc_records(aead_req, ctx->aead_send, sg_in, sg_out, aad, iv, - rcd_sn, sync_size + payload_len, - &tls_ctx->prot_info) < 0) + if (tls_enc_records(tls_ctx, aead_req, ctx->aead_send, sg_in, sg_out, + aad, iv, rcd_sn, sync_size + payload_len) < 0) goto free_nskb; complete_skb(nskb, skb, tcp_payload_offset); diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index fd39acf41a61..fd04857fa0ab 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -711,49 +711,64 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, } if (tx) { - rc = tls_set_device_offload(sk); - conf = TLS_HW; - if (!rc) { - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE); - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); - } else { - rc = tls_set_sw_offload(sk, 1, - update ? crypto_info : NULL); - if (rc) - goto err_crypto_info; - - if (update) { - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXREKEYOK); - } else { - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW); - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); + if (update && ctx->tx_conf == TLS_HW) { + rc = -EOPNOTSUPP; + goto err_crypto_info; + } + + if (!update) { + rc = tls_set_device_offload(sk); + conf = TLS_HW; + if (!rc) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); + goto out; } - conf = TLS_SW; } - } else { - rc = tls_set_device_offload_rx(sk, ctx); - conf = TLS_HW; - if (!rc) { - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE); - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); + + rc = tls_set_sw_offload(sk, 1, update ? crypto_info : NULL); + if (rc) + goto err_crypto_info; + + if (update) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXREKEYOK); } else { - rc = tls_set_sw_offload(sk, 0, - update ? crypto_info : NULL); - if (rc) - goto err_crypto_info; - - if (update) { - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXREKEYOK); - } else { - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW); - TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); + } + conf = TLS_SW; + } else { + if (update && ctx->rx_conf == TLS_HW) { + rc = -EOPNOTSUPP; + goto err_crypto_info; + } + + if (!update) { + rc = tls_set_device_offload_rx(sk, ctx); + conf = TLS_HW; + if (!rc) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); + tls_sw_strparser_arm(sk, ctx); + goto out; } - conf = TLS_SW; } - if (!update) + + rc = tls_set_sw_offload(sk, 0, update ? crypto_info : NULL); + if (rc) + goto err_crypto_info; + + if (update) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXREKEYOK); + } else { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); tls_sw_strparser_arm(sk, ctx); + } + conf = TLS_SW; } +out: if (tx) ctx->tx_conf = conf; else -- 2.25.1