Add x86_64 JIT support for BPF functions and kfuncs with more than
5 arguments. The extra arguments are passed through a stack area
addressed by register r12 (BPF_REG_STACK_ARG_BASE) in BPF bytecode,
which the JIT translates to RBP-relative accesses in native code.

There are two possible approaches to allocating the stack arg area:

  Option 1: Allocate a single combined region (incoming + max_outgoing)
  below the program stack in the function prologue. All r12-relative
  accesses become [rbp - prog_stack_depth - offset], where 'offset' is
  the offset within the combined (incoming + max_outgoing) region. This
  is simple because the area is always at a fixed offset from RBP. The
  tradeoff is slightly higher stack usage when multiple callees have
  different stack arg counts — the area is sized to the maximum.

  Option 2: Allocate each outgoing area individually at the call site,
  sized exactly to the callee's needs. This minimizes stack usage but
  significantly complicates the JIT: each call site must dynamically
  adjust RSP, and addresses of stack args would shift depending on
  context, making the offset calculations harder.

This patch uses Option 1 for simplicity.

The native x86_64 stack layout for a function with incoming and
outgoing stack args:

  high address
  ┌─────────────────────────┐
  │ incoming stack arg N    │ [rbp + 16 + (N - 1) * 8] (pushed by caller)
  │ ...                     │
  │ incoming stack arg 1    │ [rbp + 16]
  ├─────────────────────────┤
  │ return address          │ [rbp + 8]
  │ saved rbp               │ [rbp]
  ├─────────────────────────┤
  │ callee-saved regs       │
  │ BPF program stack       │ (stack_depth bytes)
  ├─────────────────────────┤
  │ incoming stack arg 1    │ [rbp - prog_stack_depth - 8]
  │ ... (copied from        │ (copied in prologue)
  │      caller's push)     │
  │ incoming stack arg N    │ [rbp - prog_stack_depth - N * 8]
  ├─────────────────────────┤
  │ outgoing stack arg 1    │ (written via r12-relative STX/ST,
  │ ...                     │  JIT translates to RBP-relative)
  │ outgoing stack arg M    │
  └─────────────────────────┘
    ...
Other stack usage ┌─────────────────────────┐ │ incoming stack arg M │ (copy from outgoing stack arg to │ ... │ incoming stack arg) │ incoming stack arg 1 │ ├─────────────────────────┤ │ return address │ │ saved rbp │ ├─────────────────────────┤ │ ... │ └─────────────────────────┘ low address In prologue, the caller's incoming stack arguments are copied to callee's incoming stack arguments, which will be fetched by later load insns. The outgoing stack arguments are written by JIT RBP-relative STX or ST. For each bpf-to-bpf call, push outgoing stack args onto the native stack before CALL, pop them after return. So the same 'outgoing stack arg' area is used by all bpf-to-bpf functions. For kfunc calls, push stack args (arg 7+) onto the native stack and load arg 6 into R9 per the x86_64 calling convention, then clean up RSP after return. Signed-off-by: Yonghong Song --- arch/x86/net/bpf_jit_comp.c | 145 ++++++++++++++++++++++++++++++++++-- 1 file changed, 138 insertions(+), 7 deletions(-) diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c index 32864dbc2c4e..807493f109e5 100644 --- a/arch/x86/net/bpf_jit_comp.c +++ b/arch/x86/net/bpf_jit_comp.c @@ -367,6 +367,27 @@ static void push_callee_regs(u8 **pprog, bool *callee_regs_used) *pprog = prog; } +static int push_stack_args(u8 **pprog, s32 base_off, int from, int to) +{ + u8 *prog = *pprog; + int j, off, cnt = 0; + + for (j = from; j >= to; j--) { + off = base_off - j * 8; + + /* push qword [rbp + off] */ + if (is_imm8(off)) { + EMIT3(0xFF, 0x75, off); + cnt += 3; + } else { + EMIT2_off32(0xFF, 0xB5, off); + cnt += 6; + } + } + *pprog = prog; + return cnt; +} + static void pop_r12(u8 **pprog) { u8 *prog = *pprog; @@ -1664,19 +1685,35 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image int i, excnt = 0; int ilen, proglen = 0; u8 *prog = temp; - u32 stack_depth; + u16 stack_arg_depth, incoming_stack_arg_depth; + u32 prog_stack_depth, stack_depth; + bool has_stack_args; 
int err; stack_depth = bpf_prog->aux->stack_depth; + stack_arg_depth = bpf_prog->aux->stack_arg_depth; + incoming_stack_arg_depth = bpf_prog->aux->incoming_stack_arg_depth; priv_stack_ptr = bpf_prog->aux->priv_stack_ptr; if (priv_stack_ptr) { priv_frame_ptr = priv_stack_ptr + PRIV_STACK_GUARD_SZ + round_up(stack_depth, 8); stack_depth = 0; } + /* + * Save program stack depth before adding stack arg space. + * Each function allocates its own stack arg space + * (incoming + outgoing) below its BPF stack. + * Stack args are accessed via RBP-based addressing. + */ + prog_stack_depth = round_up(stack_depth, 8); + if (stack_arg_depth) + stack_depth += stack_arg_depth; + has_stack_args = stack_arg_depth > 0; + arena_vm_start = bpf_arena_get_kern_vm_start(bpf_prog->aux->arena); user_vm_start = bpf_arena_get_user_vm_start(bpf_prog->aux->arena); + detect_reg_usage(insn, insn_cnt, callee_regs_used); emit_prologue(&prog, image, stack_depth, @@ -1704,6 +1741,38 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image emit_mov_imm64(&prog, X86_REG_R12, arena_vm_start >> 32, (u32) arena_vm_start); + if (incoming_stack_arg_depth && bpf_is_subprog(bpf_prog)) { + int n = incoming_stack_arg_depth / 8; + + /* + * Caller pushed stack args before CALL, so after prologue + * (CALL saves ret addr, then PUSH saves old RBP) they sit + * above RBP: + * + * [rbp + 16 + (n - 1) * 8] stack_arg n + * ... 
+ * [rbp + 24] stack_arg 2 + * [rbp + 16] stack_arg 1 + * [rbp + 8] return address + * [rbp + 0] saved rbp + * + * Copy each into callee's own region below the program stack: + * [rbp - prog_stack_depth - i * 8] + */ + for (i = 0; i < n; i++) { + s32 src = 16 + i * 8; + s32 dst = -prog_stack_depth - (i + 1) * 8; + + /* mov rax, [rbp + src] */ + EMIT4(0x48, 0x8B, 0x45, src); + /* mov [rbp + dst], rax */ + if (is_imm8(dst)) + EMIT4(0x48, 0x89, 0x45, dst); + else + EMIT3_off32(0x48, 0x89, 0x85, dst); + } + } + if (priv_frame_ptr) emit_priv_frame_ptr(&prog, priv_frame_ptr); @@ -1715,13 +1784,14 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image prog = temp; for (i = 1; i <= insn_cnt; i++, insn++) { + bool adjust_stack_arg_off = false; const s32 imm32 = insn->imm; u32 dst_reg = insn->dst_reg; u32 src_reg = insn->src_reg; u8 b2 = 0, b3 = 0; u8 *start_of_ldx; s64 jmp_offset; - s16 insn_off; + s32 insn_off; u8 jmp_cond; u8 *func; int nops; @@ -1734,6 +1804,21 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image dst_reg = X86_REG_R9; } + if (has_stack_args) { + u8 class = BPF_CLASS(insn->code); + + if (class == BPF_LDX && + src_reg == BPF_REG_STACK_ARG_BASE) { + src_reg = BPF_REG_FP; + adjust_stack_arg_off = true; + } + if ((class == BPF_STX || class == BPF_ST) && + dst_reg == BPF_REG_STACK_ARG_BASE) { + dst_reg = BPF_REG_FP; + adjust_stack_arg_off = true; + } + } + switch (insn->code) { /* ALU */ case BPF_ALU | BPF_ADD | BPF_X: @@ -2131,10 +2216,16 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image case BPF_ST | BPF_MEM | BPF_DW: EMIT2(add_1mod(0x48, dst_reg), 0xC7); -st: if (is_imm8(insn->off)) - EMIT2(add_1reg(0x40, dst_reg), insn->off); +st: { + int off = insn->off; + + if (adjust_stack_arg_off) + off -= prog_stack_depth; + if (is_imm8(off)) + EMIT2(add_1reg(0x40, dst_reg), off); else - EMIT1_off32(add_1reg(0x80, dst_reg), insn->off); + EMIT1_off32(add_1reg(0x80, dst_reg), 
off); + } EMIT(imm32, bpf_size_to_x86_bytes(BPF_SIZE(insn->code))); break; @@ -2143,9 +2234,14 @@ st: if (is_imm8(insn->off)) case BPF_STX | BPF_MEM | BPF_B: case BPF_STX | BPF_MEM | BPF_H: case BPF_STX | BPF_MEM | BPF_W: - case BPF_STX | BPF_MEM | BPF_DW: - emit_stx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off); + case BPF_STX | BPF_MEM | BPF_DW: { + int off = insn->off; + + if (adjust_stack_arg_off) + off -= prog_stack_depth; + emit_stx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, off); break; + } case BPF_ST | BPF_PROBE_MEM32 | BPF_B: case BPF_ST | BPF_PROBE_MEM32 | BPF_H: @@ -2243,6 +2339,8 @@ st: if (is_imm8(insn->off)) case BPF_LDX | BPF_PROBE_MEMSX | BPF_H: case BPF_LDX | BPF_PROBE_MEMSX | BPF_W: insn_off = insn->off; + if (adjust_stack_arg_off) + insn_off -= prog_stack_depth; if (BPF_MODE(insn->code) == BPF_PROBE_MEM || BPF_MODE(insn->code) == BPF_PROBE_MEMSX) { @@ -2440,6 +2538,8 @@ st: if (is_imm8(insn->off)) /* call */ case BPF_JMP | BPF_CALL: { + int off, base_off, n_stack_args, kfunc_stack_args = 0, stack_args = 0; + u16 outgoing_stack_args = stack_arg_depth - incoming_stack_arg_depth; u8 *ip = image + addrs[i - 1]; func = (u8 *) __bpf_call_base + imm32; @@ -2449,6 +2549,29 @@ st: if (is_imm8(insn->off)) } if (!imm32) return -EINVAL; + + if (src_reg == BPF_PSEUDO_CALL && outgoing_stack_args > 0) { + n_stack_args = outgoing_stack_args / 8; + base_off = -(prog_stack_depth + incoming_stack_arg_depth); + ip += push_stack_args(&prog, base_off, n_stack_args, 1); + } + + if (src_reg != BPF_PSEUDO_CALL && insn->off > 0) { + kfunc_stack_args = insn->off; + stack_args = kfunc_stack_args > 1 ? 
kfunc_stack_args - 1 : 0; + base_off = -(prog_stack_depth + incoming_stack_arg_depth); + ip += push_stack_args(&prog, base_off, kfunc_stack_args, 2); + + /* mov r9, [rbp + base_off - 8] */ + off = base_off - 8; + if (is_imm8(off)) { + EMIT4(0x4C, 0x8B, 0x4D, off); + ip += 4; + } else { + EMIT3_off32(0x4C, 0x8B, 0x8D, off); + ip += 7; + } + } if (priv_frame_ptr) { push_r9(&prog); ip += 2; @@ -2458,6 +2581,14 @@ st: if (is_imm8(insn->off)) return -EINVAL; if (priv_frame_ptr) pop_r9(&prog); + if (stack_args > 0) { + /* add rsp, stack_args * 8 */ + EMIT4(0x48, 0x83, 0xC4, stack_args * 8); + } + if (src_reg == BPF_PSEUDO_CALL && outgoing_stack_args > 0) { + /* add rsp, outgoing_stack_args */ + EMIT4(0x48, 0x83, 0xC4, outgoing_stack_args); + } break; } -- 2.52.0