Add error checking for failure of crypto_skcipher_en/decrypt() to various rxkad function as the crypto functions can fail with ENOMEM at least. Fixes: 17926a79320a ("[AF_RXRPC]: Provide secure RxRPC sockets for use by userspace and kernel both") Closes: https://sashiko.dev/#/patchset/20260401105614.1696001-10-dhowells@redhat.com Signed-off-by: David Howells cc: Marc Dionne cc: Jeffrey Altman cc: Eric Dumazet cc: "David S. Miller" cc: Jakub Kicinski cc: Paolo Abeni cc: Simon Horman cc: linux-afs@lists.infradead.org cc: netdev@vger.kernel.org cc: stable@kernel.org --- net/rxrpc/rxkad.c | 57 +++++++++++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 19 deletions(-) diff --git a/net/rxrpc/rxkad.c b/net/rxrpc/rxkad.c index 0f79d694cb08..eb7f2769d2b1 100644 --- a/net/rxrpc/rxkad.c +++ b/net/rxrpc/rxkad.c @@ -197,6 +197,7 @@ static int rxkad_prime_packet_security(struct rxrpc_connection *conn, struct rxrpc_crypt iv; __be32 *tmpbuf; size_t tmpsize = 4 * sizeof(__be32); + int ret; _enter(""); @@ -225,13 +226,13 @@ static int rxkad_prime_packet_security(struct rxrpc_connection *conn, skcipher_request_set_sync_tfm(req, ci); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, &sg, &sg, tmpsize, iv.x); - crypto_skcipher_encrypt(req); + ret = crypto_skcipher_encrypt(req); skcipher_request_free(req); memcpy(&conn->rxkad.csum_iv, tmpbuf + 2, sizeof(conn->rxkad.csum_iv)); kfree(tmpbuf); - _leave(" = 0"); - return 0; + _leave(" = %d", ret); + return ret; } /* @@ -264,6 +265,7 @@ static int rxkad_secure_packet_auth(const struct rxrpc_call *call, struct scatterlist sg; size_t pad; u16 check; + int ret; _enter(""); @@ -286,11 +288,11 @@ static int rxkad_secure_packet_auth(const struct rxrpc_call *call, skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x); - crypto_skcipher_encrypt(req); + ret = crypto_skcipher_encrypt(req); skcipher_request_zero(req); - _leave(" = 0"); - return 0; + _leave(" = %d", ret); + return ret; } /* @@ -345,7 +347,7 @@ static int rxkad_secure_packet(struct rxrpc_call *call, struct rxrpc_txbuf *txb) union { __be32 buf[2]; } crypto __aligned(8); - u32 x, y; + u32 x, y = 0; int ret; _enter("{%d{%x}},{#%u},%u,", @@ -376,8 +378,10 @@ static int rxkad_secure_packet(struct rxrpc_call *call, struct rxrpc_txbuf *txb) skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x); - crypto_skcipher_encrypt(req); + ret = crypto_skcipher_encrypt(req); skcipher_request_zero(req); + if (ret < 0) + goto out; y = ntohl(crypto.buf[1]); y = (y >> 16) & 0xffff; @@ -413,6 +417,7 @@ static int rxkad_secure_packet(struct rxrpc_call *call, struct rxrpc_txbuf *txb) memset(p + txb->pkt_len, 0, gap); } +out: skcipher_request_free(req); _leave(" = %d [set %x]", ret, y); return ret; @@ -453,8 +458,10 @@ static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb, skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, sg, sg, 8, iv.x); - crypto_skcipher_decrypt(req); + ret = crypto_skcipher_decrypt(req); skcipher_request_zero(req); + if (ret < 0) + return ret; /* Extract the decrypted packet length */ if (skb_copy_bits(skb, sp->offset, &sechdr, sizeof(sechdr)) < 0) @@ -531,10 +538,14 @@ static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb, skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, sg, sg, sp->len, iv.x); - crypto_skcipher_decrypt(req); + ret = crypto_skcipher_decrypt(req); skcipher_request_zero(req); if (sg != _sg) kfree(sg); + if (ret < 0) { + WARN_ON_ONCE(ret != -ENOMEM); + return ret; + } /* Extract the decrypted packet length */ if (skb_copy_bits(skb, sp->offset, &sechdr, sizeof(sechdr)) < 0) @@ -602,8 +613,10 @@ static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb) skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x); - crypto_skcipher_encrypt(req); + ret = crypto_skcipher_encrypt(req); skcipher_request_zero(req); + if (ret < 0) + goto out; y = ntohl(crypto.buf[1]); cksum = (y >> 16) & 0xffff; @@ -1077,21 +1090,23 @@ static int rxkad_decrypt_ticket(struct rxrpc_connection *conn, /* * decrypt the response packet */ -static void rxkad_decrypt_response(struct rxrpc_connection *conn, - struct rxkad_response *resp, - const struct rxrpc_crypt *session_key) +static int rxkad_decrypt_response(struct rxrpc_connection *conn, + struct rxkad_response *resp, + const struct rxrpc_crypt *session_key) { struct skcipher_request *req = rxkad_ci_req; struct scatterlist sg[1]; struct rxrpc_crypt iv; + int ret; _enter(",,%08x%08x", ntohl(session_key->n[0]), ntohl(session_key->n[1])); mutex_lock(&rxkad_ci_mutex); - if (crypto_sync_skcipher_setkey(rxkad_ci, session_key->x, - sizeof(*session_key)) < 0) - BUG(); + ret = crypto_sync_skcipher_setkey(rxkad_ci, session_key->x, + sizeof(*session_key)); + if (ret < 0) + goto unlock; memcpy(&iv, session_key, sizeof(iv)); @@ -1100,12 +1115,14 @@ static void rxkad_decrypt_response(struct rxrpc_connection *conn, skcipher_request_set_sync_tfm(req, rxkad_ci); skcipher_request_set_callback(req, 0, NULL, NULL); skcipher_request_set_crypt(req, sg, sg, sizeof(resp->encrypted), iv.x); - crypto_skcipher_decrypt(req); + ret = crypto_skcipher_decrypt(req); skcipher_request_zero(req); +unlock: mutex_unlock(&rxkad_ci_mutex); _leave(""); + return ret; } /* @@ -1198,7 +1215,9 @@ static int rxkad_verify_response(struct rxrpc_connection *conn, /* use the session key from inside the ticket to decrypt the * response */ - rxkad_decrypt_response(conn, response, &session_key); + ret = rxkad_decrypt_response(conn, response, &session_key); + if (ret < 0) + goto temporary_error_free_ticket; if (ntohl(response->encrypted.epoch) != conn->proto.epoch || ntohl(response->encrypted.cid) != conn->proto.cid ||