We will store a flag in the lowest bit of sk->sk_memcg. Then, we cannot pass the raw pointer to mem_cgroup_charge_skmem() and mem_cgroup_uncharge_skmem(). Let's pass struct sock to the functions. While at it, they are renamed to match other functions starting with mem_cgroup_sk_. Signed-off-by: Kuniyuki Iwashima --- include/linux/memcontrol.h | 29 ++++++++++++++++++++++++----- mm/memcontrol.c | 18 +++++++++++------- net/core/sock.c | 24 +++++++++++------------- net/ipv4/inet_connection_sock.c | 2 +- net/ipv4/tcp_output.c | 3 +-- 5 files changed, 48 insertions(+), 28 deletions(-) diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h index d8319ad5e8ea7..9ccbcddbe3b8e 100644 --- a/include/linux/memcontrol.h +++ b/include/linux/memcontrol.h @@ -1594,15 +1594,16 @@ static inline void mem_cgroup_flush_foreign(struct bdi_writeback *wb) #endif /* CONFIG_CGROUP_WRITEBACK */ struct sock; -bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages, - gfp_t gfp_mask); -void mem_cgroup_uncharge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages); #ifdef CONFIG_MEMCG extern struct static_key_false memcg_sockets_enabled_key; #define mem_cgroup_sockets_enabled static_branch_unlikely(&memcg_sockets_enabled_key) + void mem_cgroup_sk_alloc(struct sock *sk); void mem_cgroup_sk_free(struct sock *sk); void mem_cgroup_sk_inherit(const struct sock *sk, struct sock *newsk); +bool mem_cgroup_sk_charge(const struct sock *sk, unsigned int nr_pages, + gfp_t gfp_mask); +void mem_cgroup_sk_uncharge(const struct sock *sk, unsigned int nr_pages); static inline bool mem_cgroup_under_socket_pressure(struct mem_cgroup *memcg) { @@ -1623,13 +1624,31 @@ void set_shrinker_bit(struct mem_cgroup *memcg, int nid, int shrinker_id); void reparent_shrinker_deferred(struct mem_cgroup *memcg); #else #define mem_cgroup_sockets_enabled 0 -static inline void mem_cgroup_sk_alloc(struct sock *sk) { }; -static inline void mem_cgroup_sk_free(struct sock *sk) { }; + +static inline void mem_cgroup_sk_alloc(struct sock *sk) +{ +} + +static inline void mem_cgroup_sk_free(struct sock *sk) +{ +} static inline void mem_cgroup_sk_inherit(const struct sock *sk, struct sock *newsk) { } +static inline bool mem_cgroup_sk_charge(const struct sock *sk, + unsigned int nr_pages, + gfp_t gfp_mask) +{ + return false; +} + +static inline void mem_cgroup_sk_uncharge(const struct sock *sk, + unsigned int nr_pages) +{ +} + static inline bool mem_cgroup_under_socket_pressure(struct mem_cgroup *memcg) { return false; diff --git a/mm/memcontrol.c b/mm/memcontrol.c index 89b33e635cf89..d7f4e31f4e625 100644 --- a/mm/memcontrol.c +++ b/mm/memcontrol.c @@ -5105,17 +5105,19 @@ void mem_cgroup_sk_inherit(const struct sock *sk, struct sock *newsk) } /** - * mem_cgroup_charge_skmem - charge socket memory - * @memcg: memcg to charge + * mem_cgroup_sk_charge - charge socket memory + * @sk: socket in memcg to charge * @nr_pages: number of pages to charge * @gfp_mask: reclaim mode * * Charges @nr_pages to @memcg. Returns %true if the charge fit within * @memcg's configured limit, %false if it doesn't. */ -bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages, - gfp_t gfp_mask) +bool mem_cgroup_sk_charge(const struct sock *sk, unsigned int nr_pages, + gfp_t gfp_mask) { + struct mem_cgroup *memcg = mem_cgroup_from_sk(sk); + if (!cgroup_subsys_on_dfl(memory_cgrp_subsys)) return memcg1_charge_skmem(memcg, nr_pages, gfp_mask); @@ -5128,12 +5130,14 @@ bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages, } /** - * mem_cgroup_uncharge_skmem - uncharge socket memory - * @memcg: memcg to uncharge + * mem_cgroup_sk_uncharge - uncharge socket memory + * @sk: socket in memcg to uncharge * @nr_pages: number of pages to uncharge */ -void mem_cgroup_uncharge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages) +void mem_cgroup_sk_uncharge(const struct sock *sk, unsigned int nr_pages) { + struct mem_cgroup *memcg = mem_cgroup_from_sk(sk); + if (!cgroup_subsys_on_dfl(memory_cgrp_subsys)) { memcg1_uncharge_skmem(memcg, nr_pages); return; diff --git a/net/core/sock.c b/net/core/sock.c index ab658fe23e1e6..5537ca2638588 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -1041,8 +1041,8 @@ static int sock_reserve_memory(struct sock *sk, int bytes) pages = sk_mem_pages(bytes); /* pre-charge to memcg */ - charged = mem_cgroup_charge_skmem(sk->sk_memcg, pages, - GFP_KERNEL | __GFP_RETRY_MAYFAIL); + charged = mem_cgroup_sk_charge(sk, pages, + GFP_KERNEL | __GFP_RETRY_MAYFAIL); if (!charged) return -ENOMEM; @@ -1054,7 +1054,7 @@ static int sock_reserve_memory(struct sock *sk, int bytes) */ if (allocated > sk_prot_mem_limits(sk, 1)) { sk_memory_allocated_sub(sk, pages); - mem_cgroup_uncharge_skmem(sk->sk_memcg, pages); + mem_cgroup_sk_uncharge(sk, pages); return -ENOMEM; } sk_forward_alloc_add(sk, pages << PAGE_SHIFT); @@ -3263,17 +3263,16 @@ EXPORT_SYMBOL(sk_wait_data); */ int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind) { + bool memcg_enabled = false, charged = false; struct proto *prot = sk->sk_prot; - struct mem_cgroup *memcg = NULL; - bool charged = false; long allocated; sk_memory_allocated_add(sk, amt); allocated = sk_memory_allocated(sk); if (mem_cgroup_sk_enabled(sk)) { - memcg = sk->sk_memcg; - charged = mem_cgroup_charge_skmem(memcg, amt, gfp_memcg_charge()); + memcg_enabled = true; + charged = mem_cgroup_sk_charge(sk, amt, gfp_memcg_charge()); if (!charged) goto suppress_allocation; } @@ -3347,10 +3346,9 @@ int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind) */ if (sk->sk_wmem_queued + size >= sk->sk_sndbuf) { /* Force charge with __GFP_NOFAIL */ - if (memcg && !charged) { - mem_cgroup_charge_skmem(memcg, amt, - gfp_memcg_charge() | __GFP_NOFAIL); - } + if (memcg_enabled && !charged) + mem_cgroup_sk_charge(sk, amt, + gfp_memcg_charge() | __GFP_NOFAIL); return 1; } } @@ -3360,7 +3358,7 @@ int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind) sk_memory_allocated_sub(sk, amt); if (charged) - mem_cgroup_uncharge_skmem(memcg, amt); + mem_cgroup_sk_uncharge(sk, amt); return 0; } @@ -3399,7 +3397,7 @@ void __sk_mem_reduce_allocated(struct sock *sk, int amount) sk_memory_allocated_sub(sk, amount); if (mem_cgroup_sk_enabled(sk)) - mem_cgroup_uncharge_skmem(sk->sk_memcg, amount); + mem_cgroup_sk_uncharge(sk, amount); if (sk_under_global_memory_pressure(sk) && (sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0))) diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c index 93569bbe00f44..0ef1eacd539d1 100644 --- a/net/ipv4/inet_connection_sock.c +++ b/net/ipv4/inet_connection_sock.c @@ -727,7 +727,7 @@ struct sock *inet_csk_accept(struct sock *sk, struct proto_accept_arg *arg) } if (amt) - mem_cgroup_charge_skmem(newsk->sk_memcg, amt, gfp); + mem_cgroup_sk_charge(newsk, amt, gfp); kmem_cache_charge(newsk, gfp); release_sock(newsk); diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c index 4e0af5c824c1a..09f0802f36afa 100644 --- a/net/ipv4/tcp_output.c +++ b/net/ipv4/tcp_output.c @@ -3567,8 +3567,7 @@ void sk_forced_mem_schedule(struct sock *sk, int size) sk_memory_allocated_add(sk, amt); if (mem_cgroup_sk_enabled(sk)) - mem_cgroup_charge_skmem(sk->sk_memcg, amt, - gfp_memcg_charge() | __GFP_NOFAIL); + mem_cgroup_sk_charge(sk, amt, gfp_memcg_charge() | __GFP_NOFAIL); } /* Send a FIN. The caller locks the socket for us. -- 2.50.0.727.gbf7dc18ff4-goog