memcg->socket_isolated can change at any time, so we must snapshot the value for each socket to ensure consistency. Given sk->sk_memcg can be accessed in the fast path, it would be preferable to place the flag field in the same cache line as sk->sk_memcg. However, struct sock does not have such a 1-byte hole. Let's store the flag in the lowest bit of sk->sk_memcg and add a helper to check the bit. Signed-off-by: Kuniyuki Iwashima --- include/net/sock.h | 20 +++++++++++++++++++- mm/memcontrol.c | 13 +++++++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/include/net/sock.h b/include/net/sock.h index 5e8c73731531c..2e9d76fc2bf38 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -2599,10 +2599,16 @@ static inline gfp_t gfp_memcg_charge(void) #ifdef CONFIG_MEMCG #define MEMCG_SOCK_ISOLATED 1UL +#define MEMCG_SOCK_FLAG_MASK MEMCG_SOCK_ISOLATED +#define MEMCG_SOCK_PTR_MASK ~(MEMCG_SOCK_FLAG_MASK) static inline struct mem_cgroup *mem_cgroup_from_sk(const struct sock *sk) { - return sk->sk_memcg; + unsigned long val = (unsigned long)sk->sk_memcg; + + val &= MEMCG_SOCK_PTR_MASK; + + return (struct mem_cgroup *)val; } static inline bool mem_cgroup_sk_enabled(const struct sock *sk) @@ -2610,6 +2616,13 @@ static inline bool mem_cgroup_sk_enabled(const struct sock *sk) return mem_cgroup_sockets_enabled && mem_cgroup_from_sk(sk); } +static inline bool mem_cgroup_sk_isolated(const struct sock *sk) +{ + struct mem_cgroup *memcg = sk->sk_memcg; + + return (unsigned long)memcg & MEMCG_SOCK_ISOLATED; +} + static inline bool mem_cgroup_sk_under_memory_pressure(const struct sock *sk) { struct mem_cgroup *memcg = mem_cgroup_from_sk(sk); @@ -2636,6 +2649,11 @@ static inline bool mem_cgroup_sk_enabled(const struct sock *sk) return false; } +static inline bool mem_cgroup_sk_isolated(const struct sock *sk) +{ + return false; +} + static inline bool mem_cgroup_sk_under_memory_pressure(const struct sock *sk) { return false; diff --git a/mm/memcontrol.c b/mm/memcontrol.c index 0a55c12a6679b..85decc4319f96 100644 --- a/mm/memcontrol.c +++ b/mm/memcontrol.c @@ -5098,6 +5098,15 @@ void mem_cgroup_migrate(struct folio *old, struct folio *new) DEFINE_STATIC_KEY_FALSE(memcg_sockets_enabled_key); EXPORT_SYMBOL(memcg_sockets_enabled_key); +static void mem_cgroup_sk_set(struct sock *sk, const struct mem_cgroup *memcg) +{ + unsigned long val = (unsigned long)memcg; + + val |= READ_ONCE(memcg->socket_isolated); + + sk->sk_memcg = (struct mem_cgroup *)val; +} + void mem_cgroup_sk_alloc(struct sock *sk) { struct mem_cgroup *memcg; @@ -5116,7 +5125,7 @@ void mem_cgroup_sk_alloc(struct sock *sk) if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) && !memcg1_tcpmem_active(memcg)) goto out; if (css_tryget(&memcg->css)) - sk->sk_memcg = memcg; + mem_cgroup_sk_set(sk, memcg); out: rcu_read_unlock(); } @@ -5138,7 +5147,7 @@ void mem_cgroup_sk_inherit(const struct sock *sk, struct sock *newsk) mem_cgroup_sk_free(newsk); css_get(&memcg->css); - newsk->sk_memcg = memcg; + mem_cgroup_sk_set(newsk, memcg); } /** -- 2.50.0.727.gbf7dc18ff4-goog