Instead of trying to read the data from the msg iterator, callers to tls_alert_recv() need to pass in a kvec directly. Signed-off-by: Olga Kornievskaia --- drivers/nvme/target/tcp.c | 13 +++++-------- include/net/handshake.h | 2 +- net/handshake/alert.c | 6 +++--- net/sunrpc/svcsock.c | 17 ++++++++--------- net/sunrpc/xprtsock.c | 19 +++++++++---------- 5 files changed, 26 insertions(+), 31 deletions(-) diff --git a/drivers/nvme/target/tcp.c b/drivers/nvme/target/tcp.c index 98cee10de713..7ea8644de622 100644 --- a/drivers/nvme/target/tcp.c +++ b/drivers/nvme/target/tcp.c @@ -1106,7 +1106,7 @@ static inline bool nvmet_tcp_pdu_valid(u8 type) } static int nvmet_tcp_tls_record_ok(struct nvmet_tcp_queue *queue, - struct msghdr *msg, char *cbuf) + struct kvec *iov, char *cbuf) { struct cmsghdr *cmsg = (struct cmsghdr *)cbuf; u8 ctype, level, description; @@ -1119,7 +1119,7 @@ static int nvmet_tcp_tls_record_ok(struct nvmet_tcp_queue *queue, case TLS_RECORD_TYPE_DATA: break; case TLS_RECORD_TYPE_ALERT: - tls_alert_recv(queue->sock->sk, msg, &level, &description); + tls_alert_recv(queue->sock->sk, iov, &level, &description); if (level == TLS_ALERT_LEVEL_FATAL) { pr_err("queue %d: TLS Alert desc %u\n", queue->idx, description); @@ -1160,8 +1160,7 @@ static int nvmet_tcp_try_recv_pdu(struct nvmet_tcp_queue *queue) if (unlikely(len < 0)) return len; if (queue->tls_pskid) { - iov_iter_revert(&msg.msg_iter, len); - ret = nvmet_tcp_tls_record_ok(queue, &msg, cbuf); + ret = nvmet_tcp_tls_record_ok(queue, &iov, cbuf); if (ret < 0) return ret; } @@ -1276,8 +1275,7 @@ static int nvmet_tcp_try_recv_ddgst(struct nvmet_tcp_queue *queue) if (unlikely(len < 0)) return len; if (queue->tls_pskid) { - iov_iter_revert(&msg.msg_iter, len); - ret = nvmet_tcp_tls_record_ok(queue, &msg, cbuf); + ret = nvmet_tcp_tls_record_ok(queue, &iov, cbuf); if (ret < 0) return ret; } @@ -1742,8 +1740,7 @@ static int nvmet_tcp_try_peek_pdu(struct nvmet_tcp_queue *queue) return len; } - iov_iter_revert(&msg.msg_iter, len); - ret = nvmet_tcp_tls_record_ok(queue, &msg, cbuf); + ret = nvmet_tcp_tls_record_ok(queue, &iov, cbuf); if (ret < 0) return ret; diff --git a/include/net/handshake.h b/include/net/handshake.h index 8ebd4f9ed26e..33ffc8e88923 100644 --- a/include/net/handshake.h +++ b/include/net/handshake.h @@ -43,7 +43,7 @@ bool tls_handshake_cancel(struct sock *sk); void tls_handshake_close(struct socket *sock); u8 tls_get_record_type(const struct sock *sk, const struct cmsghdr *msg); -void tls_alert_recv(const struct sock *sk, const struct msghdr *msg, +void tls_alert_recv(const struct sock *sk, const struct kvec *iov, u8 *level, u8 *description); #endif /* _NET_HANDSHAKE_H */ diff --git a/net/handshake/alert.c b/net/handshake/alert.c index 329d91984683..4662a406b64a 100644 --- a/net/handshake/alert.c +++ b/net/handshake/alert.c @@ -94,13 +94,13 @@ EXPORT_SYMBOL(tls_get_record_type); * @description: OUT - TLS AlertDescription value * */ -void tls_alert_recv(const struct sock *sk, const struct msghdr *msg, +void tls_alert_recv(const struct sock *sk, const struct kvec *iov, u8 *level, u8 *description) { - const struct kvec *iov; u8 *data; - iov = msg->msg_iter.kvec; + if (!iov) + return; data = iov->iov_base; *level = data[0]; *description = data[1]; diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c index e2c5e0e626f9..8701abd7fff2 100644 --- a/net/sunrpc/svcsock.c +++ b/net/sunrpc/svcsock.c @@ -228,7 +228,7 @@ static int svc_one_sock_name(struct svc_sock *svsk, char *buf, int remaining) } static int -svc_tcp_sock_process_cmsg(struct socket *sock, struct msghdr *msg, +svc_tcp_sock_process_cmsg(struct socket *sock, struct kvec *iov, struct cmsghdr *cmsg, int ret) { u8 content_type = tls_get_record_type(sock->sk, cmsg); @@ -238,14 +238,10 @@ svc_tcp_sock_process_cmsg(struct socket *sock, struct msghdr *msg, case 0: break; case TLS_RECORD_TYPE_DATA: - /* TLS sets EOR at the end of each application data - * record, even though there might be more frames - * waiting to be decrypted. - */ - msg->msg_flags &= ~MSG_EOR; + pr_warn("received TLS DATA; expected TLS control message\n"); break; case TLS_RECORD_TYPE_ALERT: - tls_alert_recv(sock->sk, msg, &level, &description); + tls_alert_recv(sock->sk, iov, &level, &description); ret = (level == TLS_ALERT_LEVEL_FATAL) ? -ENOTCONN : -EAGAIN; break; @@ -280,8 +276,7 @@ svc_tcp_sock_recv_cmsg(struct socket *sock, unsigned int *msg_flags) ret = sock_recvmsg(sock, &msg, MSG_DONTWAIT); if (ret > 0 && tls_get_record_type(sock->sk, &u.cmsg) == TLS_RECORD_TYPE_ALERT) { - iov_iter_revert(&msg.msg_iter, ret); - ret = svc_tcp_sock_process_cmsg(sock, &msg, &u.cmsg, -EAGAIN); + ret = svc_tcp_sock_process_cmsg(sock, &alert_kvec, &u.cmsg, -EAGAIN); } return ret; } @@ -294,6 +289,10 @@ svc_tcp_sock_recvmsg(struct svc_sock *svsk, struct msghdr *msg) ret = sock_recvmsg(sock, msg, MSG_DONTWAIT); if (msg->msg_flags & MSG_CTRUNC) { + /* TLS sets EOR at the end of each application data + * record, even though there might be more frames + * waiting to be decrypted. + */ msg->msg_flags &= ~(MSG_CTRUNC | MSG_EOR); if (ret == 0 || ret == -EIO) ret = svc_tcp_sock_recv_cmsg(sock, &msg->msg_flags); diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c index c5f7bbf5775f..005021773da1 100644 --- a/net/sunrpc/xprtsock.c +++ b/net/sunrpc/xprtsock.c @@ -357,7 +357,7 @@ xs_alloc_sparse_pages(struct xdr_buf *buf, size_t want, gfp_t gfp) } static int -xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg, +xs_sock_process_cmsg(struct socket *sock, struct kvec *iov, unsigned int *msg_flags, struct cmsghdr *cmsg, int ret) { u8 content_type = tls_get_record_type(sock->sk, cmsg); @@ -367,14 +367,10 @@ xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg, case 0: break; case TLS_RECORD_TYPE_DATA: - /* TLS sets EOR at the end of each application data - * record, even though there might be more frames - * waiting to be decrypted. - */ - *msg_flags &= ~MSG_EOR; + pr_warn("received TLS DATA; expected TLS control message\n"); break; case TLS_RECORD_TYPE_ALERT: - tls_alert_recv(sock->sk, msg, &level, &description); + tls_alert_recv(sock->sk, iov, &level, &description); ret = (level == TLS_ALERT_LEVEL_FATAL) ? -EACCES : -EAGAIN; break; @@ -409,9 +405,8 @@ xs_sock_recv_cmsg(struct socket *sock, unsigned int *msg_flags, int flags) ret = sock_recvmsg(sock, &msg, flags); if (ret > 0 && tls_get_record_type(sock->sk, &u.cmsg) == TLS_RECORD_TYPE_ALERT) { - iov_iter_revert(&msg.msg_iter, ret); - ret = xs_sock_process_cmsg(sock, &msg, msg_flags, &u.cmsg, - -EAGAIN); + ret = xs_sock_process_cmsg(sock, &alert_kvec, msg_flags, + &u.cmsg, -EAGAIN); } return ret; } @@ -425,6 +420,10 @@ xs_sock_recvmsg(struct socket *sock, struct msghdr *msg, int flags, size_t seek) ret = sock_recvmsg(sock, msg, flags); /* Handle TLS inband control message lazily */ if (msg->msg_flags & MSG_CTRUNC) { + /* TLS sets EOR at the end of each application data + * ecord, even though there might be more frames + * waiting to be decrypted. + */ msg->msg_flags &= ~(MSG_CTRUNC | MSG_EOR); if (ret == 0 || ret == -EIO) ret = xs_sock_recv_cmsg(sock, &msg->msg_flags, flags); -- 2.47.1