Currently, the BPF verifier only allows shift operations when the shift amount is a known constant. This is overly restrictive for cases where the shift amount is bounded but not fully determined at verification time. For example, the following code is rejected by the verifier even though the shift amount is bounded to [1, 4]: u32 shift = bpf_get_prandom_u32(); shift &= 3; // shift is in range [0, 3] shift += 1; // shift is in range [1, 4] r1 <<= shift; // non-const but bounded shift amount Modify the shift helper functions (scalar_min_max_lsh, scalar32_min_max_lsh, scalar_min_max_rsh, scalar32_min_max_rsh, scalar_min_max_arsh, scalar32_min_max_arsh) to handle non-const but bounded shift amounts. Update is_safe_to_compute_dst_reg_range() to remove the src_is_const check for shift operations. This approach ensures the verifier remains sound while allowing more programs to pass verification. Also modify the comment on is_safe_to_compute_dst_reg_range. Shifts by more than insn bitness are legal in the BPF ISA; they are currently implementation-defined behaviour of the underlying architecture, rather than UB, and have been made legal for performance reasons. See: https://lore.kernel.org/bpf/20210706112502.2064236-47-sashal@kernel.org Co-developed-by: Yazhou Tang Signed-off-by: Yazhou Tang Co-developed-by: Shenghao Yuan Signed-off-by: Shenghao Yuan Signed-off-by: Tianci Cao --- kernel/bpf/verifier.c | 100 ++++++++++++++++++++++++++++++++---------- 1 file changed, 76 insertions(+), 24 deletions(-) diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index 8ed484cb1a8a..6eba2af1b5c4 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -14184,7 +14184,8 @@ static void scalar32_min_max_lsh(struct bpf_reg_state *dst_reg, struct tnum subreg = tnum_subreg(dst_reg->var_off); __scalar32_min_max_lsh(dst_reg, umin_val, umax_val); - dst_reg->var_off = tnum_subreg(tnum_lshift(subreg, umin_val)); + dst_reg->var_off = (umin_val == umax_val) ? 
+ tnum_subreg(tnum_lshift(subreg, umin_val)) : tnum_unknown; /* Not required but being careful mark reg64 bounds as unknown so * that we are forced to pick them up from tnum and zext later and * if some path skips this step we are still safe. @@ -14229,7 +14230,8 @@ static void scalar_min_max_lsh(struct bpf_reg_state *dst_reg, __scalar64_min_max_lsh(dst_reg, umin_val, umax_val); __scalar32_min_max_lsh(dst_reg, umin_val, umax_val); - dst_reg->var_off = tnum_lshift(dst_reg->var_off, umin_val); + dst_reg->var_off = (umin_val == umax_val) ? + tnum_lshift(dst_reg->var_off, umin_val) : tnum_unknown; /* We may learn something more from the var_off */ __update_reg_bounds(dst_reg); } @@ -14256,7 +14258,8 @@ static void scalar32_min_max_rsh(struct bpf_reg_state *dst_reg, * var_off of the result. */ - dst_reg->var_off = tnum_rshift(subreg, umin_val); + dst_reg->var_off = (umin_val == umax_val) ? + tnum_rshift(subreg, umin_val) : tnum_unknown; reg_set_urange32(dst_reg, reg_u32_min(dst_reg) >> umax_val, reg_u32_max(dst_reg) >> umin_val); @@ -14284,7 +14287,8 @@ static void scalar_min_max_rsh(struct bpf_reg_state *dst_reg, * and rely on inferring new ones from the unsigned bounds and * var_off of the result. */ - dst_reg->var_off = tnum_rshift(dst_reg->var_off, umin_val); + dst_reg->var_off = (umin_val == umax_val) ? + tnum_rshift(dst_reg->var_off, umin_val) : tnum_unknown; reg_set_urange64(dst_reg, reg_umin(dst_reg) >> umax_val, reg_umax(dst_reg) >> umin_val); @@ -14299,18 +14303,44 @@ static void scalar_min_max_rsh(struct bpf_reg_state *dst_reg, static void scalar32_min_max_arsh(struct bpf_reg_state *dst_reg, struct bpf_reg_state *src_reg) { - u64 umin_val = reg_u32_min(src_reg); + u32 umin_val = reg_u32_min(src_reg); + u32 umax_val = reg_u32_max(src_reg); + s32 smin = reg_s32_min(dst_reg); + s32 smax = reg_s32_max(dst_reg); - /* Upon reaching here, src_known is true and - * umax_val is equal to umin_val. 
- * Blow away the dst_reg umin_value/umax_value and rely on - * dst_reg var_off to refine the result. + /* + * BPF_ARSH on 32-bit subregister. The shift amount [umin_val, umax_val] + * may be non-constant, so we conservatively derive signed bounds: + * + * smin >= 0: non-negative value; right-shift reduces magnitude. + * result in [smin >> umax_val, smax >> umin_val] + * e.g. [4,8] >> [1,2] -> [1,4] + * smax < 0: negative value; right-shift increases it (toward -1). + * result in [smin >> umin_val, smax >> umax_val] + * e.g. [-8,-4] >> [1,2] -> [-4,-1] + * mixed: result in [smin >> umin_val, smax >> umin_val] + * e.g. [-8,8] >> [1,2] -> [-4,4] + * + * var_off is set to tnum_unknown because without a constant shift + * amount we cannot precisely track which bits remain known. */ - reg_set_srange32(dst_reg, - (u32)(((s32)reg_s32_min(dst_reg)) >> umin_val), - (u32)(((s32)reg_s32_max(dst_reg)) >> umin_val)); - - dst_reg->var_off = tnum_arshift(tnum_subreg(dst_reg->var_off), umin_val, 32); + if (umin_val == umax_val) { + reg_set_srange32(dst_reg, (u32)(smin >> umin_val), + (u32)(smax >> umin_val)); + dst_reg->var_off = tnum_arshift(tnum_subreg(dst_reg->var_off), + umin_val, 32); + } else { + if (smin >= 0) + reg_set_srange32(dst_reg, (u32)(smin >> umax_val), + (u32)(smax >> umin_val)); + else if (smax < 0) + reg_set_srange32(dst_reg, (u32)(smin >> umin_val), + (u32)(smax >> umax_val)); + else + reg_set_srange32(dst_reg, (u32)(smin >> umin_val), + (u32)(smax >> umin_val)); + dst_reg->var_off = tnum_unknown; + } __mark_reg64_unbounded(dst_reg); __update_reg32_bounds(dst_reg); @@ -14320,14 +14350,36 @@ static void scalar_min_max_arsh(struct bpf_reg_state *dst_reg, struct bpf_reg_state *src_reg) { u64 umin_val = reg_umin(src_reg); + u64 umax_val = reg_umax(src_reg); + s64 smin = reg_smin(dst_reg); + s64 smax = reg_smax(dst_reg); - /* Upon reaching here, src_known is true and umax_val is equal - * to umin_val. 
+ /* + * BPF_ARSH (arithmetic right shift) on 64-bit register.
+ * Same three-branch logic as the 32-bit variant (scalar32_min_max_arsh): + * + * smin >= 0: result in [smin >> umax_val, smax >> umin_val] + * e.g. [4,8] >> [1,2] -> [1,4] + * smax < 0: result in [smin >> umin_val, smax >> umax_val] + * e.g. [-8,-4] >> [1,2] -> [-4,-1] + * mixed: result in [smin >> umin_val, smax >> umin_val] + * e.g. [-8,8] >> [1,2] -> [-4,4] + * + * var_off is set to tnum_unknown since a non-constant shift amount + * prevents precise bit tracking. */ - reg_set_srange64(dst_reg, reg_smin(dst_reg) >> umin_val, - reg_smax(dst_reg) >> umin_val); - - dst_reg->var_off = tnum_arshift(dst_reg->var_off, umin_val, 64); + if (umin_val == umax_val) { + reg_set_srange64(dst_reg, smin >> umin_val, smax >> umin_val); + dst_reg->var_off = tnum_arshift(dst_reg->var_off, umin_val, 64); + } else { + if (smin >= 0) + reg_set_srange64(dst_reg, smin >> umax_val, smax >> umin_val); + else if (smax < 0) + reg_set_srange64(dst_reg, smin >> umin_val, smax >> umax_val); + else + reg_set_srange64(dst_reg, smin >> umin_val, smax >> umin_val); + dst_reg->var_off = tnum_unknown; + } /* Its not easy to operate on alu32 bounds here because it depends * on bits being shifted in from upper 32-bits. Take easy way out @@ -14423,14 +14475,14 @@ static bool is_safe_to_compute_dst_reg_range(struct bpf_insn *insn, case BPF_MOD: return src_is_const; - /* Shift operators range is only computable if shift dimension operand - * is a constant. Shifts greater than 31 or 63 are undefined. This - * includes shifts by a negative number. + /* + * Shifts by insn_bitness or more, including negative amounts, are + * implementation-defined, so only smaller amounts yield a range.
*/ case BPF_LSH: case BPF_RSH: case BPF_ARSH: - return (src_is_const && reg_umax(src_reg) < insn_bitness); + return reg_umax(src_reg) < insn_bitness; default: return false; } -- 2.43.0 Add test cases for shift operations with a non-const but bounded source operand. Each case comes in a 32-bit (alu32, wN registers) and a 64-bit (rN registers) variant: - shift_with_non_const_src_lsh_{32,64}: Tests left shift (BPF_LSH) where the shift amount is in range [1, 4] and the destination is a known constant (1). The verifier should compute correct bounds [2, 16] for the result. - shift_with_non_const_src_rsh_{32,64}: Tests logical right shift (BPF_RSH) where the shift amount is in range [1, 4] and the destination is 0xff. The verifier should compute correct bounds [15, 127] for the result. - shift_with_non_const_src_arsh_{32,64}: Tests arithmetic right shift (BPF_ARSH) where the shift amount is in range [1, 4] and the destination is a negative constant (-8). The verifier applies the three-branch signed bound logic to derive result bounds [-4, -1]. When the shift amount is non-constant, the shift itself sets var_off to tnum_unknown; the var_off values the tests expect (e.g. (0x0; 0x1f) for bounds [2, 16]) are the ones implied by the computed bounds.
Co-developed-by: Yazhou Tang Signed-off-by: Yazhou Tang Co-developed-by: Shenghao Yuan Signed-off-by: Shenghao Yuan Signed-off-by: Tianci Cao --- .../selftests/bpf/progs/verifier_bounds.c | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/tools/testing/selftests/bpf/progs/verifier_bounds.c b/tools/testing/selftests/bpf/progs/verifier_bounds.c index bc038ac2df98..63ee7603d482 100644 --- a/tools/testing/selftests/bpf/progs/verifier_bounds.c +++ b/tools/testing/selftests/bpf/progs/verifier_bounds.c @@ -482,6 +482,114 @@ l0_%=: /* exit */ \ : __clobber_all); } +SEC("socket") +__description("pure bounds check after non-const 32-bit left shift") +__success __log_level(2) +__msg("w1 <<= w2 {{.*}}; R1=scalar(smin=umin=smin32=umin32=2,smax=umax=smax32=umax32=16,var_off=(0x0; 0x1f))") +__naked void shift_with_non_const_src_lsh_32(void) +{ + asm volatile (" \ + call %[bpf_get_prandom_u32]; \ + w2 = w0; \ + w2 &= 3; \ + w2 += 1; \ + w1 = 1; \ + w1 <<= w2; \ + exit; \ +" :: __imm(bpf_get_prandom_u32) + : __clobber_all); +} + +SEC("socket") +__description("pure bounds check after non-const 32-bit right shift") +__success __log_level(2) +__msg("w1 >>= w2 {{.*}}; R1=scalar(smin=umin=smin32=umin32=15,smax=umax=smax32=umax32=127,var_off=(0x0; 0x7f))") +__naked void shift_with_non_const_src_rsh_32(void) +{ + asm volatile (" \ + call %[bpf_get_prandom_u32]; \ + w2 = w0; \ + w2 &= 3; \ + w2 += 1; \ + w1 = 0xff; \ + w1 >>= w2; \ + exit; \ +" :: __imm(bpf_get_prandom_u32) + : __clobber_all); +} + +SEC("socket") +__description("pure bounds check after non-const 32-bit arithmetic right shift") +__success __log_level(2) +__msg("w1 s>>= w2 {{.*}}; R1=scalar(smin=umin=umin32=0xfffffffc,smax=umax=0xffffffff,smin32=-4,smax32=-1,var_off=(0xfffffffc; 0x3)) R2=scalar(smin=umin=smin32=umin32=1,smax=umax=smax32=umax32=4,var_off=(0x0; 0x7))") +__naked void shift_with_non_const_src_arsh_32(void) +{ + asm volatile (" \ + call %[bpf_get_prandom_u32]; \ + w2 = w0; \ + w2 &= 3; \ + w2 
+= 1; \ + w1 = -8; \ + w1 s>>= w2; \ + exit; \ +" :: __imm(bpf_get_prandom_u32) + : __clobber_all); +} + +SEC("socket") +__description("pure bounds check after non-const 64-bit left shift") +__success __log_level(2) +__msg("r1 <<= r2 {{.*}}; R1=scalar(smin=umin=smin32=umin32=2,smax=umax=smax32=umax32=16,var_off=(0x0; 0x1f))") +__naked void shift_with_non_const_src_lsh_64(void) +{ + asm volatile (" \ + call %[bpf_get_prandom_u32]; \ + r2 = r0; \ + r2 &= 3; \ + r2 += 1; \ + r1 = 1; \ + r1 <<= r2; \ + exit; \ +" :: __imm(bpf_get_prandom_u32) + : __clobber_all); +} + +SEC("socket") +__description("pure bounds check after non-const 64-bit right shift") +__success __log_level(2) +__msg("r1 >>= r2 {{.*}}; R1=scalar(smin=umin=smin32=umin32=15,smax=umax=smax32=umax32=127,var_off=(0x0; 0x7f))") +__naked void shift_with_non_const_src_rsh_64(void) +{ + asm volatile (" \ + call %[bpf_get_prandom_u32]; \ + r2 = r0; \ + r2 &= 3; \ + r2 += 1; \ + r1 = 0xff; \ + r1 >>= r2; \ + exit; \ +" :: __imm(bpf_get_prandom_u32) + : __clobber_all); +} + +SEC("socket") +__description("pure bounds check after non-const 64-bit arithmetic right shift") +__success __log_level(2) +__msg("r1 s>>= r2 {{.*}}; R1=scalar(smin=smin32=-4,smax=smax32=-1,umin=0xfffffffffffffffc,umin32=0xfffffffc,var_off=(0xfffffffffffffffc; 0x3)) R2=scalar(smin=umin=smin32=umin32=1,smax=umax=smax32=umax32=4,var_off=(0x0; 0x7))") +__naked void shift_with_non_const_src_arsh_64(void) +{ + asm volatile (" \ + call %[bpf_get_prandom_u32]; \ + r2 = r0; \ + r2 &= 3; \ + r2 += 1; \ + r1 = -8; \ + r1 s>>= r2; \ + exit; \ +" :: __imm(bpf_get_prandom_u32) + : __clobber_all); +} + SEC("socket") __description("bounds check after 32-bit right shift with 64-bit input") __failure __msg("math between map_value pointer and 4294967294 is not allowed") -- 2.43.0