Require that iter->batch always contains a full bucket snapshot. This invariant is important to avoid skipping or repeating sockets during iteration when combined with the next few patches. Before, there were two cases where a call to bpf_iter_tcp_batch may only capture part of a bucket: 1. When bpf_iter_tcp_realloc_batch() returns -ENOMEM. 2. When more sockets are added to the bucket while calling bpf_iter_tcp_realloc_batch(), making the updated batch size insufficient. In cases where the batch size only covers part of a bucket, it is possible to forget which sockets were already visited, especially if we have to process a bucket in more than two batches. This forces us to choose between repeating or skipping sockets, so don't allow this: 1. Stop iteration and propagate -ENOMEM up to userspace if reallocation fails instead of continuing with a partial batch. 2. Try bpf_iter_tcp_realloc_batch() with GFP_USER just as before, but if we still aren't able to capture the full bucket, call bpf_iter_tcp_realloc_batch() again while holding the bucket lock to guarantee the bucket does not change. On the second attempt use GFP_NOWAIT since we hold onto the spin lock. I did some manual testing to exercise the code paths where GFP_NOWAIT is used and where ERR_PTR(err) is returned. I used the realloc test cases included later in this series to trigger a scenario where a realloc happens inside bpf_iter_tcp_batch and made a small code tweak to force the first realloc attempt to allocate a too-small batch, thus requiring another attempt with GFP_NOWAIT. Some printks showed both reallocs with the tests passing: Jun 27 00:00:53 crow kernel: again GFP_USER Jun 27 00:00:53 crow kernel: again GFP_NOWAIT Jun 27 00:00:53 crow kernel: again GFP_USER Jun 27 00:00:53 crow kernel: again GFP_NOWAIT With this setup, I also forced each of the bpf_iter_tcp_realloc_batch calls to return -ENOMEM to ensure that iteration ends and that the read() in userspace fails. Signed-off-by: Jordan Rife Reviewed-by: Kuniyuki Iwashima Acked-by: Stanislav Fomichev --- net/ipv4/tcp_ipv4.c | 109 +++++++++++++++++++++++++++++++------------- 1 file changed, 77 insertions(+), 32 deletions(-) diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index 2e40af6aff37..8dfb87be422e 100644 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -3057,7 +3057,7 @@ static int bpf_iter_tcp_realloc_batch(struct bpf_tcp_iter_state *iter, if (!new_batch) return -ENOMEM; - bpf_iter_tcp_put_batch(iter); + memcpy(new_batch, iter->batch, sizeof(*iter->batch) * iter->end_sk); kvfree(iter->batch); iter->batch = new_batch; iter->max_sk = new_batch_sz; @@ -3066,69 +3066,95 @@ static int bpf_iter_tcp_realloc_batch(struct bpf_tcp_iter_state *iter, } static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq, - struct sock *start_sk) + struct sock **start_sk) { - struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct bpf_tcp_iter_state *iter = seq->private; - struct tcp_iter_state *st = &iter->state; struct hlist_nulls_node *node; unsigned int expected = 1; struct sock *sk; - sock_hold(start_sk); - iter->batch[iter->end_sk++] = start_sk; + sock_hold(*start_sk); + iter->batch[iter->end_sk++] = *start_sk; - sk = sk_nulls_next(start_sk); + sk = sk_nulls_next(*start_sk); + *start_sk = NULL; sk_nulls_for_each_from(sk, node) { if (seq_sk_match(seq, sk)) { if (iter->end_sk < iter->max_sk) { sock_hold(sk); iter->batch[iter->end_sk++] = sk; + } else if (!*start_sk) { + /* Remember where we left off. */ + *start_sk = sk; } expected++; } } - spin_unlock(&hinfo->lhash2[st->bucket].lock); return expected; } static unsigned int bpf_iter_tcp_established_batch(struct seq_file *seq, - struct sock *start_sk) + struct sock **start_sk) { - struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct bpf_tcp_iter_state *iter = seq->private; - struct tcp_iter_state *st = &iter->state; struct hlist_nulls_node *node; unsigned int expected = 1; struct sock *sk; - sock_hold(start_sk); - iter->batch[iter->end_sk++] = start_sk; + sock_hold(*start_sk); + iter->batch[iter->end_sk++] = *start_sk; - sk = sk_nulls_next(start_sk); + sk = sk_nulls_next(*start_sk); + *start_sk = NULL; sk_nulls_for_each_from(sk, node) { if (seq_sk_match(seq, sk)) { if (iter->end_sk < iter->max_sk) { sock_hold(sk); iter->batch[iter->end_sk++] = sk; + } else if (!*start_sk) { + /* Remember where we left off. */ + *start_sk = sk; } expected++; } } - spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket)); return expected; } +static unsigned int bpf_iter_fill_batch(struct seq_file *seq, + struct sock **start_sk) +{ + struct bpf_tcp_iter_state *iter = seq->private; + struct tcp_iter_state *st = &iter->state; + + if (st->state == TCP_SEQ_STATE_LISTENING) + return bpf_iter_tcp_listening_batch(seq, start_sk); + else + return bpf_iter_tcp_established_batch(seq, start_sk); +} + +static void bpf_iter_tcp_unlock_bucket(struct seq_file *seq) +{ + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; + struct bpf_tcp_iter_state *iter = seq->private; + struct tcp_iter_state *st = &iter->state; + + if (st->state == TCP_SEQ_STATE_LISTENING) + spin_unlock(&hinfo->lhash2[st->bucket].lock); + else + spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket)); +} + static struct sock *bpf_iter_tcp_batch(struct seq_file *seq) { struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; struct bpf_tcp_iter_state *iter = seq->private; struct tcp_iter_state *st = &iter->state; unsigned int expected; - bool resized = false; struct sock *sk; + int err; /* The st->bucket is done. Directly advance to the next * bucket instead of having the tcp_seek_last_pos() to skip @@ -3145,33 +3171,52 @@ static struct sock *bpf_iter_tcp_batch(struct seq_file *seq) } } -again: - /* Get a new batch */ iter->cur_sk = 0; iter->end_sk = 0; - iter->st_bucket_done = false; + iter->st_bucket_done = true; sk = tcp_seek_last_pos(seq); if (!sk) return NULL; /* Done */ - if (st->state == TCP_SEQ_STATE_LISTENING) - expected = bpf_iter_tcp_listening_batch(seq, sk); - else - expected = bpf_iter_tcp_established_batch(seq, sk); + expected = bpf_iter_fill_batch(seq, &sk); + if (likely(iter->end_sk == expected)) + goto done; - if (iter->end_sk == expected) { - iter->st_bucket_done = true; - return sk; - } + /* Batch size was too small. */ + bpf_iter_tcp_unlock_bucket(seq); + bpf_iter_tcp_put_batch(iter); + err = bpf_iter_tcp_realloc_batch(iter, expected * 3 / 2, + GFP_USER); + if (err) + return ERR_PTR(err); + + iter->cur_sk = 0; + iter->end_sk = 0; + + sk = tcp_seek_last_pos(seq); + if (!sk) + return NULL; /* Done */ + + expected = bpf_iter_fill_batch(seq, &sk); + if (likely(iter->end_sk == expected)) + goto done; - if (!resized && !bpf_iter_tcp_realloc_batch(iter, expected * 3 / 2, - GFP_USER)) { - resized = true; - goto again; + /* Batch size was still too small. Hold onto the lock while we try + * again with a larger batch to make sure the current bucket's size + * does not change in the meantime. + */ + err = bpf_iter_tcp_realloc_batch(iter, expected, GFP_NOWAIT); + if (err) { + bpf_iter_tcp_unlock_bucket(seq); + return ERR_PTR(err); } - return sk; + expected = bpf_iter_fill_batch(seq, &sk); + WARN_ON_ONCE(iter->end_sk != expected); +done: + bpf_iter_tcp_unlock_bucket(seq); + return iter->batch[0]; } static void *bpf_iter_tcp_seq_start(struct seq_file *seq, loff_t *pos) -- 2.43.0