diff --git a/drivers/vfio/pci/vfio_pci.c b/drivers/vfio/pci/vfio_pci.c
index a603f363835c..b06d76216599 100644
--- a/drivers/vfio/pci/vfio_pci.c
+++ b/drivers/vfio/pci/vfio_pci.c
@@ -1430,6 +1430,7 @@ static vm_fault_t vfio_pci_mmap_fault(struct vm_fault *vmf)
 {
 	struct vm_area_struct *vma = vmf->vma;
 	struct vfio_pci_device *vdev = vma->vm_private_data;
+	struct vfio_pci_mmap_vma *mmap_vma;
 	vm_fault_t ret = VM_FAULT_NOPAGE;
 
 	mutex_lock(&vdev->vma_lock);
@@ -1437,24 +1438,36 @@ static vm_fault_t vfio_pci_mmap_fault(struct vm_fault *vmf)
 
 	if (!__vfio_pci_memory_enabled(vdev)) {
 		ret = VM_FAULT_SIGBUS;
-		mutex_unlock(&vdev->vma_lock);
+		goto up_out;
+	}
+
+	/*
+	 * We populate the whole vma on fault, so we need to test whether
+	 * the vma has already been mapped, such as for concurrent faults
+	 * to the same vma.  io_remap_pfn_range() will trigger a BUG_ON if
+	 * we ask it to fill the same range again.
+	 */
+	list_for_each_entry(mmap_vma, &vdev->vma_list, vma_next) {
+		if (mmap_vma->vma == vma)
+			goto up_out;
+	}
+
+	if (io_remap_pfn_range(vma, vma->vm_start, vma->vm_pgoff,
+			       vma->vm_end - vma->vm_start,
+			       vma->vm_page_prot)) {
+		ret = VM_FAULT_SIGBUS;
+		zap_vma_ptes(vma, vma->vm_start, vma->vm_end - vma->vm_start);
 		goto up_out;
 	}
 
 	if (__vfio_pci_add_vma(vdev, vma)) {
 		ret = VM_FAULT_OOM;
-		mutex_unlock(&vdev->vma_lock);
-		goto up_out;
+		zap_vma_ptes(vma, vma->vm_start, vma->vm_end - vma->vm_start);
 	}
 
-	mutex_unlock(&vdev->vma_lock);
-
-	if (io_remap_pfn_range(vma, vma->vm_start, vma->vm_pgoff,
-			       vma->vm_end - vma->vm_start, vma->vm_page_prot))
-		ret = VM_FAULT_SIGBUS;
-
 up_out:
 	up_read(&vdev->memory_lock);
+	mutex_unlock(&vdev->vma_lock);
 	return ret;
 }
 