Require that Live Update preserved devices are in singleton iommu_groups during preservation (outgoing kernel) and retrieval (incoming kernel). PCI devices preserved across Live Update will be allowed to perform memory transactions throughout the Live Update. Thus IOMMU groups for preserved devices must remain fixed. Since all current use cases for Live Update are for PCI devices in singleton iommu_groups, require that as a starting point. This avoids the complexity of needing to enforce arbitrary iommu_group topologies while still allowing all current use cases. Suggested-by: Jason Gunthorpe Signed-off-by: David Matlack --- drivers/pci/liveupdate.c | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/drivers/pci/liveupdate.c b/drivers/pci/liveupdate.c index bec7b3500057..a3dbe06650ff 100644 --- a/drivers/pci/liveupdate.c +++ b/drivers/pci/liveupdate.c @@ -75,6 +75,8 @@ * * * The device must not be a Physical Function (PF). * + * * The device must be the only device in its IOMMU group. + * * Preservation Behavior * ===================== * @@ -105,6 +107,7 @@ #include #include +#include #include #include #include @@ -222,6 +225,31 @@ static void pci_ser_delete(struct pci_ser *ser, struct pci_dev *dev) ser->nr_devices--; } +static int count_devices(struct device *dev, void *__nr_devices) +{ + (*(int *)__nr_devices)++; + return 0; +} + +static int pci_liveupdate_validate_iommu_group(struct pci_dev *dev) +{ + struct iommu_group *group; + int nr_devices = 0; + + group = iommu_group_get(&dev->dev); + if (group) { + iommu_group_for_each_dev(group, &nr_devices, count_devices); + iommu_group_put(group); + } + + if (nr_devices != 1) { + pci_warn(dev, "Live Update preserved devices must be in singleton iommu groups!"); + return -EINVAL; + } + + return 0; +} + int pci_liveupdate_preserve(struct pci_dev *dev) { struct pci_dev_ser new = INIT_PCI_DEV_SER(dev); @@ -232,6 +260,10 @@ int pci_liveupdate_preserve(struct pci_dev *dev) if (dev->is_virtfn || dev->is_physfn) return -EINVAL; + ret = pci_liveupdate_validate_iommu_group(dev); + if (ret) + return ret; + guard(mutex)(&pci_flb_outgoing_lock); if (dev->liveupdate_outgoing) @@ -357,7 +389,7 @@ int pci_liveupdate_retrieve(struct pci_dev *dev) if (!dev->liveupdate_incoming) return -EINVAL; - return 0; + return pci_liveupdate_validate_iommu_group(dev); } EXPORT_SYMBOL_GPL(pci_liveupdate_retrieve); -- 2.53.0.983.g0bb29b3bc5-goog