Use reference counting to decide when to free socket hash elements
instead of freeing them immediately after they are unlinked from a
bucket's list. In the next patch this is essential, allowing socket
hash iterators to hold a reference to a `struct bpf_shtab_elem` outside
of an RCU read-side critical section.

sock_hash_put_elem() follows the list, scheduling elements to be freed
until it hits an element where the reference count is two or greater.
This does nothing yet; in this patch the loop will never iterate more
than once, since we always take a reference to the next element in
sock_hash_unlink_elem() before calling sock_hash_put_elem(), and in
general, the reference count to any element is always one except during
these transitions.

However, in the next patch it's possible for an iterator to hold a
reference to an element that has been unlinked from a bucket's list. In
this context, sock_hash_put_elem() may free several unlinked elements
up until the point where it finds an element that is still in the
bucket's list.

Signed-off-by: Jordan Rife
---
 net/core/sock_map.c | 67 +++++++++++++++++++++++++++++++++------------
 1 file changed, 50 insertions(+), 17 deletions(-)

diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index 5947b38e4f8b..005112ba19fd 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -847,6 +847,7 @@ struct bpf_shtab_elem {
 	u32 hash;
 	struct sock *sk;
 	struct hlist_node node;
+	refcount_t ref;
 	u8 key[];
 };
 
@@ -906,11 +907,46 @@ static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
 	return elem ? elem->sk : NULL;
 }
 
-static void sock_hash_free_elem(struct bpf_shtab *htab,
-				struct bpf_shtab_elem *elem)
+static void sock_hash_free_elem(struct rcu_head *rcu_head)
 {
+	struct bpf_shtab_elem *elem = container_of(rcu_head,
+						   struct bpf_shtab_elem, rcu);
+
+	/* Matches sock_hold() in sock_hash_alloc_elem(). */
+	sock_put(elem->sk);
+	kfree(elem);
+}
+
+static void sock_hash_put_elem(struct bpf_shtab_elem *elem)
+{
+	while (elem && refcount_dec_and_test(&elem->ref)) {
+		call_rcu(&elem->rcu, sock_hash_free_elem);
+		elem = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&elem->node)),
+					struct bpf_shtab_elem, node);
+	}
+}
+
+static bool sock_hash_hold_elem(struct bpf_shtab_elem *elem)
+{
+	return refcount_inc_not_zero(&elem->ref);
+}
+
+static void sock_hash_unlink_elem(struct bpf_shtab *htab,
+				  struct bpf_shtab_elem *elem)
+{
+	struct bpf_shtab_elem *elem_next;
+
+	elem_next = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&elem->node)),
+				     struct bpf_shtab_elem, node);
+	hlist_del_rcu(&elem->node);
+	sock_map_unref(elem->sk, elem);
+	/* Take a reference to the next element first to make sure it's not
+	 * freed by the call to sock_hash_put_elem().
+	 */
+	if (elem_next)
+		sock_hash_hold_elem(elem_next);
+	sock_hash_put_elem(elem);
 	atomic_dec(&htab->count);
-	kfree_rcu(elem, rcu);
 }
 
 static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
@@ -930,11 +966,8 @@ static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
 	spin_lock_bh(&bucket->lock);
 	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
 					       elem->key, map->key_size);
-	if (elem_probe && elem_probe == elem) {
-		hlist_del_rcu(&elem->node);
-		sock_map_unref(elem->sk, elem);
-		sock_hash_free_elem(htab, elem);
-	}
+	if (elem_probe && elem_probe == elem)
+		sock_hash_unlink_elem(htab, elem);
 	spin_unlock_bh(&bucket->lock);
 }
 
@@ -952,9 +985,7 @@ static long sock_hash_delete_elem(struct bpf_map *map, void *key)
 	spin_lock_bh(&bucket->lock);
 	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
 	if (elem) {
-		hlist_del_rcu(&elem->node);
-		sock_map_unref(elem->sk, elem);
-		sock_hash_free_elem(htab, elem);
+		sock_hash_unlink_elem(htab, elem);
 		ret = 0;
 	}
 	spin_unlock_bh(&bucket->lock);
@@ -985,6 +1016,11 @@ static struct bpf_shtab_elem *sock_hash_alloc_elem(struct bpf_shtab *htab,
 	memcpy(new->key, key, key_size);
 	new->sk = sk;
 	new->hash = hash;
+	refcount_set(&new->ref, 1);
+	/* Matches sock_put() in sock_hash_free_elem(). Ensure that sk is not
+	 * freed until elem is.
+	 */
+	sock_hold(sk);
 	return new;
 }
 
@@ -1038,11 +1074,8 @@ static int sock_hash_update_common(struct bpf_map *map, void *key,
 	 * concurrent search will find it before old elem.
 	 */
 	hlist_add_head_rcu(&elem_new->node, &bucket->head);
-	if (elem) {
-		hlist_del_rcu(&elem->node);
-		sock_map_unref(elem->sk, elem);
-		sock_hash_free_elem(htab, elem);
-	}
+	if (elem)
+		sock_hash_unlink_elem(htab, elem);
 	spin_unlock_bh(&bucket->lock);
 	return 0;
 out_unlock:
@@ -1182,7 +1215,7 @@ static void sock_hash_free(struct bpf_map *map)
 			rcu_read_unlock();
 			release_sock(elem->sk);
 			sock_put(elem->sk);
-			sock_hash_free_elem(htab, elem);
+			call_rcu(&elem->rcu, sock_hash_free_elem);
 		}
 		cond_resched();
 	}
-- 
2.43.0
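
A note for reviewers, not part of the patch itself: a rough sketch, under
the assumption that the follow-up iterator patch uses the new helpers along
these lines, of how an element could be pinned so it remains usable outside
the RCU read-side critical section. The bucket, hash, key, and key_size
variables are placeholders standing in for the iterator's own lookup state:

	rcu_read_lock();
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	/* Pin elem; sock_hash_hold_elem() fails if the refcount has already
	 * dropped to zero and the element is about to be freed.
	 */
	if (elem && !sock_hash_hold_elem(elem))
		elem = NULL;
	rcu_read_unlock();

	/* ... elem (and elem->sk) can now be used without the RCU read
	 * lock held ...
	 */

	if (elem)
		sock_hash_put_elem(elem);	/* may also free unlinked
						 * elements chained after elem
						 */

This only illustrates the intent described in the commit message; the actual
iterator changes land in the next patch.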