Add the emit_kasan_check() function that emits KASAN shadow memory checks before memory accesses in JIT-compiled BPF programs. The implementation relies on the existing __asan_{load,store}X functions from KASAN subsystem. The helper: - ensures that the kasan instrumention is actually needed: if the instruction being processed accesses the program stack, we skip the instrumentation, as those accesses are already protected with page guards - saves registers. This includes caller-saved registers, but also temporary registers, as those were possibly used by the affected program - computes the accessed address and stores it in %rdi - calls the relevant function, depending on the instruction being a load or a store, and the size of the access. - restores registers The special care needed when inserting this instrumentation comes at the cost of a non negligeable increase in JITed code size. For example, a bare mov 0x0(%si),rbx # Load in rbx content at address stored in rsi becomes push %rax push %rcx push %rdx push %rsi push %rdi push %r8 push %r9 mov %rsi,%rdi call 0xffffffff81da0a60 <__asan_load8> pop %r9 pop %r8 pop %rdi pop %rsi pop %rdx pop %rcx pop %rax mov 0x0(%rsi),rbx Signed-off-by: Alexis Lothoré (eBPF Foundation) --- Changes in v2: - move asan functions declaration directly into jit compiler, and guard them with IS_ENABLED - remove faulty stack alignment, no arg is passed to kasan funcs on the stack anyway - make sure to emit call depth accounting code - do not save unneeded registers - update helper signature to let caller configure some values (eg: is_write) Signed-off-by: Alexis Lothoré (eBPF Foundation) --- arch/x86/net/bpf_jit_comp.c | 93 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c index a0c541a441cf..0981791014eb 100644 --- a/arch/x86/net/bpf_jit_comp.c +++ b/arch/x86/net/bpf_jit_comp.c @@ -21,6 +21,19 @@ #include #include +#if IS_ENABLED(CONFIG_BPF_JIT_KASAN) +void __asan_load1(void *p); +void __asan_store1(void *p); +void __asan_load2(void *p); +void __asan_store2(void *p); +void __asan_load4(void *p); +void __asan_store4(void *p); +void __asan_load8(void *p); +void __asan_store8(void *p); +void __asan_load16(void *p); +void __asan_store16(void *p); +#endif + static bool all_callee_regs_used[4] = {true, true, true, true}; static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len) @@ -1330,6 +1343,86 @@ static void emit_store_stack_imm64(u8 **pprog, int reg, int stack_off, u64 imm64 emit_stx(pprog, BPF_DW, BPF_REG_FP, reg, stack_off); } +static int emit_kasan_check(u8 **pprog, u32 addr_reg, struct bpf_insn *insn, + u8 *ip, bool is_write, bool accesses_stack_only) +{ +#ifdef CONFIG_BPF_JIT_KASAN + u32 bpf_size = BPF_SIZE(insn->code); + s32 off = insn->off; + u8 *prog = *pprog; + void *kasan_func; + + if (accesses_stack_only) + return 0; + + /* Derive KASAN check function from access type and size */ + switch (bpf_size) { + case BPF_B: + kasan_func = is_write ? __asan_store1 : __asan_load1; + break; + case BPF_H: + kasan_func = is_write ? __asan_store2 : __asan_load2; + break; + case BPF_W: + kasan_func = is_write ? __asan_store4 : __asan_load4; + break; + case BPF_DW: + kasan_func = is_write ? __asan_store8 : __asan_load8; + break; + default: + return -EINVAL; + } + + /* Save rax */ + EMIT1(0x50); + /* Save rcx */ + EMIT1(0x51); + /* Save rdx */ + EMIT1(0x52); + /* Save rsi */ + EMIT1(0x56); + /* Save rdi */ + EMIT1(0x57); + /* Save r8 */ + EMIT2(0x41, 0x50); + /* Save r9 */ + EMIT2(0x41, 0x51); + + /* mov rdi, addr_reg */ + EMIT_mov(BPF_REG_1, addr_reg); + + /* add rdi, off (if offset is non-zero) */ + if (off) { + if (is_imm8(off)) { + /* add rdi, imm8 */ + EMIT4(0x48, 0x83, 0xC7, (u8)off); + } else { + /* add rdi, imm32 */ + EMIT3_off32(0x48, 0x81, 0xC7, off); + } + } + + /* Adjust ip to account for the instrumentation generated so far */ + ip += (prog - *pprog); + /* We emit a call, so update call depth counting */ + ip += x86_call_depth_emit_accounting(&prog, kasan_func, ip); + /* call kasan_func */ + if (emit_call(&prog, kasan_func, ip)) + return -ERANGE; + + EMIT2(0x41, 0x59); + EMIT2(0x41, 0x58); + EMIT1(0x5F); + EMIT1(0x5E); + EMIT1(0x5A); + EMIT1(0x59); + EMIT1(0x58); + + *pprog = prog; +#endif /* CONFIG_BPF_JIT_KASAN */ + return 0; +} + static int emit_atomic_rmw(u8 **pprog, u32 atomic_op, u32 dst_reg, u32 src_reg, s16 off, u8 bpf_size) { -- 2.54.0