io_register_resize_rings() briefly sets ctx->rings to NULL under completion_lock before assigning the new rings and publishing them via rcu_assign_pointer(ctx->rings_rcu, ...). Several code paths read ctx->rings without holding any of those locks, leading to a NULL pointer dereference if they race with a resize: - io_uring_poll() (VFS poll callback) - io_should_wake() (waitqueue wake callback) - io_cqring_min_timer_wakeup() (hrtimer callback) - io_cqring_wait() (called from io_uring_enter) Commit 96189080265e only addressed io_ctx_mark_taskrun() in tw.c. Protect the remaining sites by reading ctx->rings_rcu under rcu_read_lock() (via guard(rcu)/scoped_guard(rcu)) and treating a NULL rings as "no data available / force re-evaluation". Fixes: 79cfe9e59c2a ("io_uring/register: add IORING_REGISTER_RESIZE_RINGS") Cc: stable@vger.kernel.org Signed-off-by: Junxi Qian --- I'm not entirely sure this is the best approach for all the affected call sites -- I'd appreciate any feedback or suggestions on whether this looks reasonable. --- io_uring/io_uring.c | 17 +++++++++--- io_uring/io_uring.h | 9 ++++++- io_uring/wait.c | 63 +++++++++++++++++++++++++++++++++------------ 3 files changed, 69 insertions(+), 20 deletions(-) diff --git a/io_uring/io_uring.c b/io_uring/io_uring.c index 9a37035e7..98029b039 100644 --- a/io_uring/io_uring.c +++ b/io_uring/io_uring.c @@ -2240,6 +2240,7 @@ __cold void io_activate_pollwq(struct io_ring_ctx *ctx) static __poll_t io_uring_poll(struct file *file, poll_table *wait) { struct io_ring_ctx *ctx = file->private_data; + struct io_rings *rings; __poll_t mask = 0; if (unlikely(!ctx->poll_activated)) @@ -2250,7 +2251,17 @@ static __poll_t io_uring_poll(struct file *file, poll_table *wait) */ poll_wait(file, &ctx->poll_wq, wait); - if (!io_sqring_full(ctx)) + /* + * Use the RCU-protected rings pointer to be safe against + * concurrent ring resizing, which briefly NULLs ctx->rings. + */ + guard(rcu)(); + rings = rcu_dereference(ctx->rings_rcu); + if (unlikely(!rings)) + return 0; + + if (READ_ONCE(rings->sq.tail) - READ_ONCE(rings->sq.head) != + ctx->sq_entries) mask |= EPOLLOUT | EPOLLWRNORM; /* @@ -2266,8 +2277,8 @@ static __poll_t io_uring_poll(struct file *file, poll_table *wait) * Users may get EPOLLIN meanwhile seeing nothing in cqring, this * pushes them to do the flush. */ - - if (__io_cqring_events_user(ctx) || io_has_work(ctx)) + if (READ_ONCE(rings->cq.tail) != READ_ONCE(rings->cq.head) || + io_has_work(ctx)) mask |= EPOLLIN | EPOLLRDNORM; return mask; diff --git a/io_uring/io_uring.h b/io_uring/io_uring.h index 0fa844faf..ea953f2c7 100644 --- a/io_uring/io_uring.h +++ b/io_uring/io_uring.h @@ -145,7 +145,14 @@ struct io_wait_queue { static inline bool io_should_wake(struct io_wait_queue *iowq) { struct io_ring_ctx *ctx = iowq->ctx; - int dist = READ_ONCE(ctx->rings->cq.tail) - (int) iowq->cq_tail; + struct io_rings *rings; + int dist; + + guard(rcu)(); + rings = rcu_dereference(ctx->rings_rcu); + if (unlikely(!rings)) + return true; + dist = READ_ONCE(rings->cq.tail) - (int) iowq->cq_tail; /* * Wake up if we have enough events, or if a timeout occurred since we diff --git a/io_uring/wait.c b/io_uring/wait.c index 0581cadf2..af25f8f16 100644 --- a/io_uring/wait.c +++ b/io_uring/wait.c @@ -78,12 +78,20 @@ static enum hrtimer_restart io_cqring_min_timer_wakeup(struct hrtimer *timer) /* work we may need to run, wake function will see if we need to wake */ if (io_has_work(ctx)) goto out_wake; - /* got events since we started waiting, min timeout is done */ - if (iowq->cq_min_tail != READ_ONCE(ctx->rings->cq.tail)) - goto out_wake; - /* if we have any events and min timeout expired, we're done */ - if (io_cqring_events(ctx)) - goto out_wake; + + scoped_guard(rcu) { + struct io_rings *rings = rcu_dereference(ctx->rings_rcu); + + if (!rings) + goto out_wake; + /* got events since we started waiting, min timeout is done */ + if (iowq->cq_min_tail != READ_ONCE(rings->cq.tail)) + goto out_wake; + /* if we have any events and min timeout expired, we're done */ + smp_rmb(); + if (ctx->cached_cq_tail != READ_ONCE(rings->cq.head)) + goto out_wake; + } /* * If using deferred task_work running and application is waiting on @@ -186,7 +194,7 @@ int io_cqring_wait(struct io_ring_ctx *ctx, int min_events, u32 flags, struct ext_arg *ext_arg) { struct io_wait_queue iowq; - struct io_rings *rings = ctx->rings; + struct io_rings *rings; ktime_t start_time; int ret; @@ -201,15 +209,27 @@ int io_cqring_wait(struct io_ring_ctx *ctx, int min_events, u32 flags, if (unlikely(test_bit(IO_CHECK_CQ_OVERFLOW_BIT, &ctx->check_cq))) io_cqring_do_overflow_flush(ctx); - if (__io_cqring_events_user(ctx) >= min_events) - return 0; init_waitqueue_func_entry(&iowq.wq, io_wake_function); iowq.wq.private = current; INIT_LIST_HEAD(&iowq.wq.entry); iowq.ctx = ctx; - iowq.cq_tail = READ_ONCE(ctx->rings->cq.head) + min_events; - iowq.cq_min_tail = READ_ONCE(ctx->rings->cq.tail); + + scoped_guard(rcu) { + rings = rcu_dereference(ctx->rings_rcu); + if (rings) { + if (READ_ONCE(rings->cq.tail) - + READ_ONCE(rings->cq.head) >= + (unsigned int)min_events) + return 0; + iowq.cq_tail = READ_ONCE(rings->cq.head) + + min_events; + iowq.cq_min_tail = READ_ONCE(rings->cq.tail); + } else { + iowq.cq_tail = min_events; + iowq.cq_min_tail = 0; + } + } iowq.nr_timeouts = atomic_read(&ctx->cq_timeouts); iowq.hit_timeout = 0; iowq.min_timeout = ext_arg->min_time; @@ -243,11 +263,16 @@ int io_cqring_wait(struct io_ring_ctx *ctx, int min_events, u32 flags, int nr_wait; /* if min timeout has been hit, don't reset wait count */ - if (!iowq.hit_timeout) - nr_wait = (int) iowq.cq_tail - - READ_ONCE(ctx->rings->cq.tail); - else + if (!iowq.hit_timeout) { + scoped_guard(rcu) { + rings = rcu_dereference(ctx->rings_rcu); + nr_wait = rings ? + (int) iowq.cq_tail - + READ_ONCE(rings->cq.tail) : 1; + } + } else { nr_wait = 1; + } if (ctx->flags & IORING_SETUP_DEFER_TASKRUN) { atomic_set(&ctx->cq_wait_nr, nr_wait); @@ -304,5 +329,11 @@ int io_cqring_wait(struct io_ring_ctx *ctx, int min_events, u32 flags, finish_wait(&ctx->cq_wait, &iowq.wq); restore_saved_sigmask_unless(ret == -EINTR); - return READ_ONCE(rings->cq.head) == READ_ONCE(rings->cq.tail) ? ret : 0; + scoped_guard(rcu) { + rings = rcu_dereference(ctx->rings_rcu); + if (rings && + READ_ONCE(rings->cq.head) != READ_ONCE(rings->cq.tail)) + ret = 0; + } + return ret; } -- 2.34.1