Implement pci_reachable_set() to efficiently compute a set of devices on the same bus that are "reachable" from a starting device. The meaning of reachability is defined by the caller through a callback function. This is a faster implementation of the same logic in pci_device_group(). Being inside the PCI core allows use of pci_bus_sem so it can use list_for_each_entry() on a small list of devices instead of the expensive for_each_pci_dev(). Server systems can now have hundreds of PCI devices, but typically only a very small number of devices per bus. An example of a reachability function would be pci_devs_are_dma_aliases() which would compute a set of devices on the same bus that are aliases. This would also be useful in future support for the ACS P2P Egress Vector which has a similar reachability problem. This is effectively a graph algorithm where the set of devices on the bus are vertexes and the reachable() function defines the edges. It returns a set of vertexes that form a connected graph. Signed-off-by: Jason Gunthorpe --- drivers/pci/search.c | 90 ++++++++++++++++++++++++++++++++++++++++++++ include/linux/pci.h | 12 ++++++ 2 files changed, 102 insertions(+) diff --git a/drivers/pci/search.c b/drivers/pci/search.c index fe6c07e67cb8ce..dac6b042fd5f5d 100644 --- a/drivers/pci/search.c +++ b/drivers/pci/search.c @@ -595,3 +595,93 @@ int pci_dev_present(const struct pci_device_id *ids) return 0; } EXPORT_SYMBOL(pci_dev_present); + +/** + * pci_reachable_set - Generate a bitmap of devices within a reachability set + * @start: First device in the set + * @devfns: The set of devices on the bus + * @reachable: Callback to tell if two devices can reach each other + * + * Compute a bitmap where every set bit is a device on the bus that is reachable + * from the start device, including the start device. Reachability between two + * devices is determined by a callback function. + * + * This is a non-recursive implementation that invokes the callback once per + * pair. The callback must be commutative: + * reachable(a, b) == reachable(b, a) + * reachable() can form a cyclic graph: + * reachable(a,b) == reachable(b,c) == reachable(c,a) == true + * + * Since this function is limited to a single bus the largest set can be 256 + * devices large. + */ +void pci_reachable_set(struct pci_dev *start, struct pci_reachable_set *devfns, + bool (*reachable)(struct pci_dev *deva, + struct pci_dev *devb)) +{ + struct pci_reachable_set todo_devfns = {}; + struct pci_reachable_set next_devfns = {}; + struct pci_bus *bus = start->bus; + bool again; + + /* Assume devfn of all PCI devices is bounded by MAX_NR_DEVFNS */ + static_assert(sizeof(next_devfns.devfns) * BITS_PER_BYTE >= + MAX_NR_DEVFNS); + + memset(devfns, 0, sizeof(devfns->devfns)); + __set_bit(start->devfn, devfns->devfns); + __set_bit(start->devfn, next_devfns.devfns); + + down_read(&pci_bus_sem); + while (true) { + unsigned int devfna; + unsigned int i; + + /* + * For each device that hasn't been checked compare every + * device on the bus against it. + */ + again = false; + for_each_set_bit(devfna, next_devfns.devfns, MAX_NR_DEVFNS) { + struct pci_dev *deva = NULL; + struct pci_dev *devb; + + list_for_each_entry(devb, &bus->devices, bus_list) { + if (devb->devfn == devfna) + deva = devb; + + if (test_bit(devb->devfn, devfns->devfns)) + continue; + + if (!deva) { + deva = devb; + list_for_each_entry_continue( + deva, &bus->devices, bus_list) + if (deva->devfn == devfna) + break; + } + + if (!reachable(deva, devb)) + continue; + + __set_bit(devb->devfn, todo_devfns.devfns); + again = true; + } + } + + if (!again) + break; + + /* + * Every new bit adds a new deva to check, reloop the whole + * thing. Expect this to be rare. + */ + for (i = 0; i != ARRAY_SIZE(devfns->devfns); i++) { + devfns->devfns[i] |= todo_devfns.devfns[i]; + next_devfns.devfns[i] = todo_devfns.devfns[i]; + todo_devfns.devfns[i] = 0; + } + } + up_read(&pci_bus_sem); +} +EXPORT_SYMBOL_GPL(pci_reachable_set); diff --git a/include/linux/pci.h b/include/linux/pci.h index fb9adf0562f8ef..21f6b20b487f8d 100644 --- a/include/linux/pci.h +++ b/include/linux/pci.h @@ -855,6 +855,10 @@ struct pci_dynids { struct list_head list; /* For IDs added at runtime */ }; +struct pci_reachable_set { + DECLARE_BITMAP(devfns, 256); +}; + enum pci_bus_isolation { /* * The bus is off a root port and the root port has isolated ACS flags @@ -1269,6 +1273,9 @@ struct pci_dev *pci_get_domain_bus_and_slot(int domain, unsigned int bus, struct pci_dev *pci_get_class(unsigned int class, struct pci_dev *from); struct pci_dev *pci_get_base_class(unsigned int class, struct pci_dev *from); +void pci_reachable_set(struct pci_dev *start, struct pci_reachable_set *devfns, + bool (*reachable)(struct pci_dev *deva, + struct pci_dev *devb)); enum pci_bus_isolation pci_bus_isolated(struct pci_bus *bus); int pci_dev_present(const struct pci_device_id *ids); @@ -2084,6 +2091,11 @@ static inline struct pci_dev *pci_get_base_class(unsigned int class, struct pci_dev *from) { return NULL; } +static inline void +pci_reachable_set(struct pci_dev *start, struct pci_reachable_set *devfns, + bool (*reachable)(struct pci_dev *deva, struct pci_dev *devb)) +{ } + static inline enum pci_bus_isolation pci_bus_isolated(struct pci_bus *bus) { return PCIE_NON_ISOLATED; } -- 2.43.0