hmm_range_fault() currently triggers page faults from inside the page-table
walk callbacks: hmm_vma_walk_pmd(), hmm_vma_walk_pud(),
hmm_vma_walk_hugetlb_entry() and the pte-level helper all call
hmm_vma_fault(), which in turn calls handle_mm_fault() while the walker
still holds nested locks. The pte spinlock is dropped explicitly by each
caller, and the hugetlb path manually drops and retakes
hugetlb_vma_lock_read around the fault to dodge a deadlock against the
walk framework's unconditional unlock.

This layering does not extend cleanly to fault handlers that may release
mmap_lock (VM_FAULT_RETRY, VM_FAULT_COMPLETED). If the lock is dropped
while walk_page_range() is mid-traversal, the VMA can be freed before the
walk framework's matching hugetlb_vma_unlock_read(), turning that unlock
into a use-after-free.

Split the responsibilities the way get_user_pages() does. Walk callbacks
become inspect-only: when they detect a range that needs to be faulted in,
they record it in struct hmm_vma_walk and return a private sentinel
(HMM_FAULT_PENDING). The outer loop in hmm_range_fault() then drops out of
walk_page_range(), invokes a new helper hmm_do_fault() that calls
handle_mm_fault() with only mmap_lock held, and restarts the walk so the
now-present entries are collected into hmm_pfns.

No functional change for existing callers. As a side effect the hugetlb
callback no longer needs the hugetlb_vma_{un}lock_read dance, and every
fault-path exit from the callbacks now releases the pte spinlock on a
single, common path.

This refactor is also a precursor for adding an unlockable variant of
hmm_range_fault() in a follow-up patch.

Signed-off-by: Stanislav Kinsburskii
---
 mm/hmm.c | 118 +++++++++++++++++++++++++++++++++++++++-----------------------
 1 file changed, 75 insertions(+), 43 deletions(-)

diff --git a/mm/hmm.c b/mm/hmm.c
index 5955f2f0c83db..2b157fcbc2928 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -33,8 +33,17 @@
 struct hmm_vma_walk {
 	struct hmm_range	*range;
 	unsigned long		last;
+	unsigned long		end;
+	unsigned int		required_fault;
 };
 
+/*
+ * Internal sentinel returned by walk callbacks when they need a page fault.
+ * The callback stores end/required_fault in hmm_vma_walk; the outer loop
+ * consumes the sentinel and never propagates it to the caller.
+ */
+#define HMM_FAULT_PENDING	-EAGAIN
+
 enum {
 	HMM_NEED_FAULT = 1 << 0,
 	HMM_NEED_WRITE_FAULT = 1 << 1,
@@ -60,37 +69,25 @@ static int hmm_pfns_fill(unsigned long addr, unsigned long end,
 }
 
 /*
- * hmm_vma_fault() - fault in a range lacking valid pmd or pte(s)
- * @addr: range virtual start address (inclusive)
- * @end: range virtual end address (exclusive)
- * @required_fault: HMM_NEED_* flags
- * @walk: mm_walk structure
- * Return: -EBUSY after page fault, or page fault error
+ * hmm_record_fault() - record a range that needs to be faulted in
  *
- * This function will be called whenever pmd_none() or pte_none() returns true,
- * or whenever there is no page directory covering the virtual address range.
+ * Called by the walk callbacks when they discover that part of the range
+ * needs a page fault. The callback records what to fault and returns
+ * HMM_FAULT_PENDING; the outer loop in hmm_range_fault() drops back out of
+ * walk_page_range() and invokes handle_mm_fault() from a context where no
+ * page-table or hugetlb_vma_lock is held.
  */
-static int hmm_vma_fault(unsigned long addr, unsigned long end,
-			 unsigned int required_fault, struct mm_walk *walk)
+static int hmm_record_fault(unsigned long addr, unsigned long end,
+			    unsigned int required_fault,
+			    struct mm_walk *walk)
 {
 	struct hmm_vma_walk *hmm_vma_walk = walk->private;
-	struct vm_area_struct *vma = walk->vma;
-	unsigned int fault_flags = FAULT_FLAG_REMOTE;
 
 	WARN_ON_ONCE(!required_fault);
 	hmm_vma_walk->last = addr;
-
-	if (required_fault & HMM_NEED_WRITE_FAULT) {
-		if (!(vma->vm_flags & VM_WRITE))
-			return -EPERM;
-		fault_flags |= FAULT_FLAG_WRITE;
-	}
-
-	for (; addr < end; addr += PAGE_SIZE)
-		if (handle_mm_fault(vma, addr, fault_flags, NULL) &
-		    VM_FAULT_ERROR)
-			return -EFAULT;
-	return -EBUSY;
+	hmm_vma_walk->end = end;
+	hmm_vma_walk->required_fault = required_fault;
+	return HMM_FAULT_PENDING;
 }
 
 static unsigned int hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
@@ -174,7 +171,7 @@ static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
 		return hmm_pfns_fill(addr, end, range, HMM_PFN_ERROR);
 	}
 	if (required_fault)
-		return hmm_vma_fault(addr, end, required_fault, walk);
+		return hmm_record_fault(addr, end, required_fault, walk);
 	return hmm_pfns_fill(addr, end, range, 0);
 }
 
@@ -209,7 +206,7 @@ static int hmm_vma_handle_pmd(struct mm_walk *walk, unsigned long addr,
 	required_fault = hmm_range_need_fault(hmm_vma_walk, hmm_pfns, npages,
 					      cpu_flags);
 	if (required_fault)
-		return hmm_vma_fault(addr, end, required_fault, walk);
+		return hmm_record_fault(addr, end, required_fault, walk);
 
 	pfn = pmd_pfn(pmd) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
 	for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++) {
@@ -328,7 +325,7 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
 fault:
 	pte_unmap(ptep);
 	/* Fault any virtual address we were asked to fault */
-	return hmm_vma_fault(addr, end, required_fault, walk);
+	return hmm_record_fault(addr, end, required_fault, walk);
 }
 
 #ifdef CONFIG_ARCH_ENABLE_THP_MIGRATION
@@ -371,7 +368,7 @@ static int hmm_vma_handle_absent_pmd(struct mm_walk *walk, unsigned long start,
 					     npages, 0);
 	if (required_fault) {
 		if (softleaf_is_device_private(entry))
-			return hmm_vma_fault(addr, end, required_fault, walk);
+			return hmm_record_fault(addr, end, required_fault, walk);
 		else
 			return -EFAULT;
 	}
@@ -517,7 +514,7 @@ static int hmm_vma_walk_pud(pud_t *pudp, unsigned long start, unsigned long end,
 						      npages, cpu_flags);
 		if (required_fault) {
 			spin_unlock(ptl);
-			return hmm_vma_fault(addr, end, required_fault, walk);
+			return hmm_record_fault(addr, end, required_fault, walk);
 		}
 
 		pfn = pud_pfn(pud) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
@@ -564,21 +561,8 @@ static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
 	required_fault = hmm_pte_need_fault(hmm_vma_walk, pfn_req_flags,
 					    cpu_flags);
 	if (required_fault) {
-		int ret;
-
 		spin_unlock(ptl);
-		hugetlb_vma_unlock_read(vma);
-		/*
-		 * Avoid deadlock: drop the vma lock before calling
-		 * hmm_vma_fault(), which will itself potentially take and
-		 * drop the vma lock. This is also correct from a
-		 * protection point of view, because there is no further
-		 * use here of either pte or ptl after dropping the vma
-		 * lock.
-		 */
-		ret = hmm_vma_fault(addr, end, required_fault, walk);
-		hugetlb_vma_lock_read(vma);
-		return ret;
+		return hmm_record_fault(addr, end, required_fault, walk);
 	}
 
 	pfn = pte_pfn(entry) + ((start & ~hmask) >> PAGE_SHIFT);
@@ -637,6 +621,44 @@ static const struct mm_walk_ops hmm_walk_ops = {
 	.walk_lock	= PGWALK_RDLOCK,
 };
 
+/*
+ * hmm_do_fault - fault in a range recorded by a walk callback
+ *
+ * Called from the outer loop in hmm_range_fault() after a callback
+ * returned HMM_FAULT_PENDING. At this point we hold only mmap_lock;
+ * the page-table spinlock and any hugetlb_vma_lock acquired by the walk
+ * framework have already been released by the unwind.
+ *
+ * Returns -EBUSY on success (all pages faulted, caller should re-walk).
+ * Returns a negative errno on failure.
+ */
+static int hmm_do_fault(struct mm_struct *mm,
+			struct hmm_vma_walk *hmm_vma_walk)
+{
+	unsigned long addr = hmm_vma_walk->last;
+	unsigned long end = hmm_vma_walk->end;
+	unsigned int required_fault = hmm_vma_walk->required_fault;
+	unsigned int fault_flags = FAULT_FLAG_REMOTE;
+	struct vm_area_struct *vma;
+
+	vma = vma_lookup(mm, addr);
+	if (!vma)
+		return -EFAULT;
+
+	if (required_fault & HMM_NEED_WRITE_FAULT) {
+		if (!(vma->vm_flags & VM_WRITE))
+			return -EPERM;
+		fault_flags |= FAULT_FLAG_WRITE;
+	}
+
+	for (; addr < end; addr += PAGE_SIZE)
+		if (handle_mm_fault(vma, addr, fault_flags, NULL) &
+		    VM_FAULT_ERROR)
+			return -EFAULT;
+
+	return -EBUSY;
+}
+
 /**
  * hmm_range_fault - try to fault some address in a virtual address range
  * @range: argument structure
@@ -674,6 +696,16 @@ int hmm_range_fault(struct hmm_range *range)
 			return -EBUSY;
 		ret = walk_page_range(mm, hmm_vma_walk.last, range->end,
 				      &hmm_walk_ops, &hmm_vma_walk);
+		/*
+		 * When HMM_FAULT_PENDING is returned a walk callback
+		 * recorded a range that needs handle_mm_fault();
+		 * hmm_do_fault() runs the fault outside walk_page_range()
+		 * (so no page-table or hugetlb_vma_lock is held) and
+		 * returns -EBUSY so the loop re-walks and picks up the
+		 * now-present entries.
+		 */
+		if (ret == HMM_FAULT_PENDING)
+			ret = hmm_do_fault(mm, &hmm_vma_walk);
 		/*
 		 * When -EBUSY is returned the loop restarts with
 		 * hmm_vma_walk.last set to an address that has not been stored
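
For reference, the caller-visible contract is unchanged: hmm_range_fault() is
still called with mmap_lock held for read and still returns 0 or a negative
errno, with the retry handling (now including the HMM_FAULT_PENDING sentinel)
kept entirely inside its internal loop. Below is a minimal sketch of the usual
caller pattern from the HMM documentation, included only to illustrate that
nothing here changes; example_fault_range() and the surrounding driver context
are illustrative placeholders, not part of this patch.

/*
 * Illustrative only: the standard hmm_range_fault() caller loop. The
 * HMM_FAULT_PENDING sentinel introduced by this patch is consumed inside
 * hmm_range_fault() and is never visible at this level.
 */
static int example_fault_range(struct mmu_interval_notifier *interval_sub,
			       struct hmm_range *range)
{
	struct mm_struct *mm = interval_sub->mm;
	int ret;

again:
	range->notifier_seq = mmu_interval_read_begin(interval_sub);

	mmap_read_lock(mm);
	ret = hmm_range_fault(range);
	mmap_read_unlock(mm);
	if (ret) {
		/* -EBUSY means the range was invalidated mid-walk: retry. */
		if (ret == -EBUSY)
			goto again;
		return ret;
	}

	/*
	 * range->hmm_pfns is now populated; a real driver would take its
	 * own pagetable lock and call mmu_interval_read_retry() before
	 * using the entries.
	 */
	return 0;
}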