This commit addresses a challenge explained in an open question ("How can we incorporate correlation in unknown bits across partial products?") left by Harishankar et al. in their paper: https://arxiv.org/abs/2105.05398 When LSB(a) is uncertain, we know for sure that it is either 0 or 1, from which we could find two possible partial products and take a union. Experiment shows that applying this technique in long multiplication improves the precision in a significant number of cases (at the cost of losing precision in a relatively lower number of cases). This commit also removes the value-mask decomposition technique employed by Harishankar et al., as its direct incorporation did not result in any improvements for the new algorithm. Signed-off-by: Nandakumar Edamana --- include/linux/tnum.h | 3 +++ kernel/bpf/tnum.c | 52 ++++++++++++++++++++++++++++++++------------ 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/include/linux/tnum.h b/include/linux/tnum.h index 57ed3035cc30..68e9cdd0a2ab 100644 --- a/include/linux/tnum.h +++ b/include/linux/tnum.h @@ -54,6 +54,9 @@ struct tnum tnum_mul(struct tnum a, struct tnum b); /* Return a tnum representing numbers satisfying both @a and @b */ struct tnum tnum_intersect(struct tnum a, struct tnum b); +/* Returns a tnum representing numbers satisfying either @a or @b */ +struct tnum tnum_union(struct tnum t1, struct tnum t2); + /* Return @a with all but the lowest @size bytes cleared */ struct tnum tnum_cast(struct tnum a, u8 size); diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c index fa353c5d550f..1ae00dbc8b0e 100644 --- a/kernel/bpf/tnum.c +++ b/kernel/bpf/tnum.c @@ -116,31 +116,47 @@ struct tnum tnum_xor(struct tnum a, struct tnum b) return TNUM(v & ~mu, mu); } -/* Generate partial products by multiplying each bit in the multiplier (tnum a) - * with the multiplicand (tnum b), and add the partial products after - * appropriately bit-shifting them. Instead of directly performing tnum addition - * on the generated partial products, equivalenty, decompose each partial - * product into two tnums, consisting of the value-sum (acc_v) and the - * mask-sum (acc_m) and then perform tnum addition on them. The following paper - * explains the algorithm in more detail: https://arxiv.org/abs/2105.05398. +/* Perform long multiplication, iterating through the trits in a. A small trick + * inside the loop finds two possible partial products and takes their union, + * improving the precision significantly. + * A comment inside refers to a paper by Harishankar et al.: + * https://arxiv.org/abs/2105.05398 */ struct tnum tnum_mul(struct tnum a, struct tnum b) { - u64 acc_v = a.value * b.value; - struct tnum acc_m = TNUM(0, 0); + struct tnum acc = TNUM(0, 0); while (a.value || a.mask) { /* LSB of tnum a is a certain 1 */ - if (a.value & 1) - acc_m = tnum_add(acc_m, TNUM(0, b.mask)); + if (a.value & 1) { + acc = tnum_add(acc, b); + } /* LSB of tnum a is uncertain */ - else if (a.mask & 1) - acc_m = tnum_add(acc_m, TNUM(0, b.value | b.mask)); + else if (a.mask & 1) { + /* Simply multiplying b with LSB(a)'s uncertainty results in decreased + * precision, as explained in an open question ("How can we incorporate + * correlation in unknown bits across partial products?") left by + * Harishankar et al. However, we know for sure that LSB(a) is either + * 0 or 1, from which we could find two possible partial products and + * take a union. This improves the precision in a significant number of + * cases. + * + * The first partial product (acc_0) is for the case LSB(a) = 0; + * but acc_0 = acc + 0 * b = acc. + */ + + /* In case LSB(a) is 1 */ + u64 itermask = b.value | b.mask; + struct tnum iterprod = TNUM(b.value & ~itermask, itermask); + struct tnum acc_1 = tnum_add(acc, iterprod); + + acc = tnum_union(acc, acc_1); + } /* Note: no case for LSB is certain 0 */ a = tnum_rshift(a, 1); b = tnum_lshift(b, 1); } - return tnum_add(TNUM(acc_v, 0), acc_m); + return acc; } /* Note that if a and b disagree - i.e. one has a 'known 1' where the other has @@ -155,6 +171,14 @@ struct tnum tnum_intersect(struct tnum a, struct tnum b) return TNUM(v & ~mu, mu); } +struct tnum tnum_union(struct tnum a, struct tnum b) +{ + u64 v = a.value & b.value; + u64 mu = (a.value ^ b.value) | a.mask | b.mask; + + return TNUM(v & ~mu, mu); +} + struct tnum tnum_cast(struct tnum a, u8 size) { a.value &= (1ULL << (size * 8)) - 1; -- 2.39.5