Simplify tnum_step() from a 10-variable algorithm into a straight line sequence of bitwise operations. Problem Reduction: tnum_step(): Given a tnum `(tval, tmask)` where `tval & tmask == 0`, and a value `z` with `tval ≤ z < (tval | tmask)`, find the smallest `r > z`, a tnum-satisfying value, i.e., `r & ~tmask == tval`. Every tnum-satisfying value has the form tval | s where s is a subset of tmask bits (s & ~tmask == 0). Since tval and tmask are disjoint: tval | s = tval + s Similarly z = tval + d where d = z - tval, so r > z becomes: tval + s > tval + d s > d The problem reduces to: find the smallest s, a subset of tmask, such that s > d. Notice that `s` must be a subset of tmask, the problem now is simplified. Algorithm: The mask bits of `d` form a "counter" that we want to increment by one, but the counter has gaps at the fixed-bit positions. A normal +1 would stop at the first 0-bit it meets; we need it to skip over fixed-bit gaps and land on the next mask bit. Step 1 -- plug the gaps: d | carry_mask | ~tmask - ~tmask fills all fixed-bit positions with 1. - carry_mask = (1 << fls64(d & ~tmask)) - 1 fills all positions (including mask positions) below the highest non-mask bit of d. After this, the only remaining 0s are mask bits above the highest non-mask bit of d where d is also 0 -- exactly the positions where the carry can validly land. Step 2 -- increment: (d | carry_mask | ~tmask) + 1 Adding 1 flips all trailing 1s to 0 and sets the first 0 to 1. Since every gap has been plugged, that first 0 is guaranteed to be a mask bit above all non-mask bits of d. Step 3 -- mask: ((d | carry_mask | ~tmask) + 1) & tmask Strip the scaffolding, keeping only mask bits. Call the result inc. Step 4 -- result: tval | inc Reattach the fixed bits. A simple 8-bit example: tmask: 1 1 0 1 0 1 1 0 d: 1 0 1 0 0 0 1 0 (d = 162) ^ non-mask 1 at bit 5 With carry_mask = 0b00111111 (smeared from bit 5): d|carry|~tm 1 0 1 1 1 1 1 1 + 1 1 1 0 0 0 0 0 0 & tmask 1 1 0 0 0 0 0 0 The patch passes my local test: test_verifier, test_progs for `-t verifier` and `-t reg_bounds`. CBMC shows the new code is equiv to original one[1], and a lean4 proof of correctness is available[2]: theorem tnumStep_correct (tval tmask z : BitVec 64) -- Precondition: valid tnum and input z (h_consistent : (tval &&& tmask) = 0) (h_lo : tval ≤ z) (h_hi : z < (tval ||| tmask)) : -- Postcondition: r must be: -- (1) tnum member -- (2) z < r -- (3) for any other member w > z, r <= w let r := tnumStep tval tmask z satisfiesTnum64 r tval tmask ∧ tval ≤ r ∧ r ≤ (tval ||| tmask) ∧ z < r ∧ ∀ w, satisfiesTnum64 w tval tmask → z < w → r ≤ w := by -- unfold definition unfold tnumStep satisfiesTnum64 simp only [] refine ⟨?_, ?_, ?_, ?_, ?_⟩ -- the solver proves each conjunct · bv_decide · bv_decide · bv_decide · bv_decide · intro w hw1 hw2; bv_decide [1] https://github.com/eddyz87/tnum-step-verif/blob/master/main.c [2] https://pastebin.com/raw/czHKiyY0 Signed-off-by: Hao Sun Acked-by: Eduard Zingerman Acked-by: Shung-Hsi Yu --- v1 -> v2: inline proof, add code comments, add a variable `filled`. kernel/bpf/tnum.c | 46 +++++++++++++++++++--------------------------- 1 file changed, 19 insertions(+), 27 deletions(-) diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c index 4abc359b3db0..ec9c310cf5d7 100644 --- a/kernel/bpf/tnum.c +++ b/kernel/bpf/tnum.c @@ -286,8 +286,7 @@ struct tnum tnum_bswap64(struct tnum a) */ u64 tnum_step(struct tnum t, u64 z) { - u64 tmax, j, p, q, r, s, v, u, w, res; - u8 k; + u64 tmax, d, carry_mask, filled, inc; tmax = t.value | t.mask; @@ -299,29 +298,22 @@ u64 tnum_step(struct tnum t, u64 z) if (z < t.value) return t.value; - /* keep t's known bits, and match all unknown bits to z */ - j = t.value | (z & t.mask); - - if (j > z) { - p = ~z & t.value & ~t.mask; - k = fls64(p); /* k is the most-significant 0-to-1 flip */ - q = U64_MAX << k; - r = q & z; /* positions > k matched to z */ - s = ~q & t.value; /* positions <= k matched to t.value */ - v = r | s; - res = v; - } else { - p = z & ~t.value & ~t.mask; - k = fls64(p); /* k is the most-significant 1-to-0 flip */ - q = U64_MAX << k; - r = q & t.mask & z; /* unknown positions > k, matched to z */ - s = q & ~t.mask; /* known positions > k, set to 1 */ - v = r | s; - /* add 1 to unknown positions > k to make value greater than z */ - u = v + (1ULL << k); - /* extract bits in unknown positions > k from u, rest from t.value */ - w = (u & t.mask) | t.value; - res = w; - } - return res; + /* + * Let r be the result tnum member, z = t.value + d. + * Every tnum member is t.value | s for some submask s of t.mask, + * and since t.value & t.mask == 0, t.value | s == t.value + s. + * So r > z becomes s > d where d = z - t.value. + * + * Find the smallest submask s of t.mask greater than d by + * "incrementing d within the mask": fill every non-mask + * position with 1 (`filled`) so +1 ripples through the gaps, + * then keep only mask bits. `carry_mask` additionally fills + * positions below the highest non-mask 1 in d, preventing + * it from trapping the carry. + */ + d = z - t.value; + carry_mask = (1ULL << fls64(d & ~t.mask)) - 1; + filled = d | carry_mask | ~t.mask; + inc = (filled + 1) & t.mask; + return t.value | inc; } -- 2.34.1