Once an mshare shared page table has been linked with one or more
process page tables, it becomes necessary to ensure that the shared
page table is not completely freed when objects in it are unmapped, in
order to avoid a potential use-after-free (UAF) bug. To do this,
introduce and use a reference count for PUD pages: free_pud_range()
now skips freeing a PUD page whose share count is non-zero and reports
this to free_p4d_range() so that the containing P4D page is preserved
as well.

Signed-off-by: Anthony Yznaga
---
 include/linux/mm.h       |  1 +
 include/linux/mm_types.h | 36 ++++++++++++++++++++++++++++++++++--
 mm/memory.c              | 21 +++++++++++++++++++--
 3 files changed, 54 insertions(+), 4 deletions(-)

diff --git a/include/linux/mm.h b/include/linux/mm.h
index 96440082a633..c8dfa5c6e7d4 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -3217,6 +3217,7 @@ static inline spinlock_t *pud_lock(struct mm_struct *mm, pud_t *pud)
 
 static inline void pagetable_pud_ctor(struct ptdesc *ptdesc)
 {
+	ptdesc_pud_pts_init(ptdesc);
 	__pagetable_ctor(ptdesc);
 }
 
diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
index c8f4d2a2c60b..da5a7a31a81d 100644
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -537,7 +537,7 @@ FOLIO_MATCH(compound_head, _head_3);
  * @pt_index:         Used for s390 gmap.
  * @pt_mm:            Used for x86 pgds.
  * @pt_frag_refcount: For fragmented page table tracking. Powerpc only.
- * @pt_share_count:   Used for HugeTLB PMD page table share count.
+ * @pt_share_count:   Used for HugeTLB PMD or mshare PUD page table share count.
  * @_pt_pad_2:        Padding to ensure proper alignment.
  * @ptl:              Lock for the page table.
  * @__page_type:      Same as page->page_type. Unused for page tables.
@@ -564,7 +564,7 @@ struct ptdesc {
 		pgoff_t pt_index;
 		struct mm_struct *pt_mm;
 		atomic_t pt_frag_refcount;
-#ifdef CONFIG_HUGETLB_PMD_PAGE_TABLE_SHARING
+#if defined(CONFIG_HUGETLB_PMD_PAGE_TABLE_SHARING) || defined(CONFIG_MSHARE)
 		atomic_t pt_share_count;
 #endif
 	};
@@ -638,6 +638,38 @@ static inline void ptdesc_pmd_pts_init(struct ptdesc *ptdesc)
 }
 #endif
 
+#ifdef CONFIG_MSHARE
+static inline void ptdesc_pud_pts_init(struct ptdesc *ptdesc)
+{
+	atomic_set(&ptdesc->pt_share_count, 0);
+}
+
+static inline void ptdesc_pud_pts_inc(struct ptdesc *ptdesc)
+{
+	atomic_inc(&ptdesc->pt_share_count);
+}
+
+static inline void ptdesc_pud_pts_dec(struct ptdesc *ptdesc)
+{
+	atomic_dec(&ptdesc->pt_share_count);
+}
+
+static inline int ptdesc_pud_pts_count(struct ptdesc *ptdesc)
+{
+	return atomic_read(&ptdesc->pt_share_count);
+}
+#else
+static inline void ptdesc_pud_pts_init(struct ptdesc *ptdesc)
+{
+}
+
+static inline int ptdesc_pud_pts_count(struct ptdesc *ptdesc)
+{
+	return 0;
+}
+#endif
+
+
 /*
  * Used for sizing the vmemmap region on some architectures
  */
diff --git a/mm/memory.c b/mm/memory.c
index dbc299aa82c2..4e3bb49b95e2 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -228,9 +228,18 @@ static inline void free_pmd_range(struct mmu_gather *tlb, pud_t *pud,
 	mm_dec_nr_pmds(tlb->mm);
 }
 
+static inline bool pud_range_is_shared(pud_t *pud)
+{
+	if (ptdesc_pud_pts_count(virt_to_ptdesc(pud)))
+		return true;
+
+	return false;
+}
+
 static inline void free_pud_range(struct mmu_gather *tlb, p4d_t *p4d,
 				unsigned long addr, unsigned long end,
-				unsigned long floor, unsigned long ceiling)
+				unsigned long floor, unsigned long ceiling,
+				bool *pud_is_shared)
 {
 	pud_t *pud;
 	unsigned long next;
@@ -257,6 +266,10 @@ static inline void free_pud_range(struct mmu_gather *tlb, p4d_t *p4d,
 		return;
 
 	pud = pud_offset(p4d, start);
+	if (unlikely(pud_range_is_shared(pud))) {
+		*pud_is_shared = true;
+		return;
+	}
 	p4d_clear(p4d);
 	pud_free_tlb(tlb, pud, start);
 	mm_dec_nr_puds(tlb->mm);
@@ -269,6 +282,7 @@ static inline void free_p4d_range(struct mmu_gather *tlb, pgd_t *pgd,
 	p4d_t *p4d;
 	unsigned long next;
 	unsigned long start;
+	bool pud_is_shared = false;
 
 	start = addr;
 	p4d = p4d_offset(pgd, addr);
@@ -276,7 +290,8 @@ static inline void free_p4d_range(struct mmu_gather *tlb, pgd_t *pgd,
 		next = p4d_addr_end(addr, end);
 		if (p4d_none_or_clear_bad(p4d))
 			continue;
-		free_pud_range(tlb, p4d, addr, next, floor, ceiling);
+		free_pud_range(tlb, p4d, addr, next, floor, ceiling,
+			       &pud_is_shared);
 	} while (p4d++, addr = next, addr != end);
 
 	start &= PGDIR_MASK;
@@ -290,6 +305,8 @@ static inline void free_p4d_range(struct mmu_gather *tlb, pgd_t *pgd,
 	if (end - 1 > ceiling - 1)
 		return;
 
+	if (unlikely(pud_is_shared))
+		return;
 	p4d = p4d_offset(pgd, start);
 	pgd_clear(pgd);
 	p4d_free_tlb(tlb, p4d, start);
-- 
2.47.1
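
P.S. for reviewers: a minimal sketch of how the new helpers are meant
to be driven by the mshare linking code (which is not part of this
patch). The names mshare_link_pud()/mshare_unlink_pud() are
hypothetical, invented only for this illustration:

/*
 * Hypothetical sketch, assuming the linking code takes and drops one
 * count per process page table that references the shared PUD page.
 * While the count is non-zero, pud_range_is_shared() returns true, so
 * free_pud_range() leaves the shared PUD page in place and
 * free_p4d_range() preserves the containing P4D page as well.
 */
static void mshare_link_pud(pud_t *pud)
{
	ptdesc_pud_pts_inc(virt_to_ptdesc(pud));
}

static void mshare_unlink_pud(pud_t *pud)
{
	ptdesc_pud_pts_dec(virt_to_ptdesc(pud));
}

Under that assumption, only the final teardown of the host mm, after
every count has been dropped, actually frees the shared PUD page.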