Add x86_64 JIT support for BPF functions and kfuncs with more than 5
arguments. The extra arguments are passed through a stack area addressed
by register r12 (BPF_REG_STACK_ARG_BASE) in BPF bytecode, which the JIT
translates to RBP-relative accesses in native code.

The JIT follows the native x86_64 calling convention for stack argument
placement. Incoming stack args from the caller sit above the callee's
frame pointer at [rbp + 16], [rbp + 24], etc., exactly where x86_64
expects them after CALL + PUSH RBP. Only the outgoing stack arg area is
allocated below the program stack in the prologue.

The native x86_64 stack layout for a function with incoming and outgoing
stack args:

  high address
  ┌─────────────────────────┐
  │ incoming stack arg N    │ [rbp + 16 + (N-1)*8]   (from caller)
  │ ...                     │
  │ incoming stack arg 1    │ [rbp + 16]
  ├─────────────────────────┤
  │ return address          │ [rbp + 8]
  │ saved rbp               │ [rbp]
  ├─────────────────────────┤
  │ BPF program stack       │ (stack_depth bytes)
  ├─────────────────────────┤
  │ outgoing stack arg M    │ [rbp - prog_stack_depth - 8]
  │ ...                     │ (written via r12-relative STX/ST)
  │ outgoing stack arg 1    │ [rbp - prog_stack_depth - outgoing_depth]
  ├─────────────────────────┤
  │ callee-saved regs ...   │ (pushed after sub rsp)
  └─────────────────────────┘ rsp
  low address

BPF r12-relative offsets are translated to native RBP-relative offsets
with two formulas:

- Incoming args (load: -off <= incoming_depth):
    native_off = 8 - bpf_off                         → [rbp + 16 + ...]
- Outgoing args (store: -off > incoming_depth):
    native_off = -(bpf_prog_stack + stack_arg_depth + 8) - bpf_off

Since callee-saved registers are pushed below the outgoing area,
outgoing args are not at [rsp] at call time. Therefore, for both
BPF-to-BPF calls and kfunc calls, outgoing args are explicitly pushed
from the outgoing area onto the stack before CALL and rsp is restored
after return. For kfunc calls specifically, arg 6 is loaded into R9 and
args 7+ are pushed onto the native stack, per the x86_64 calling
convention.
Signed-off-by: Yonghong Song --- arch/x86/net/bpf_jit_comp.c | 135 ++++++++++++++++++++++++++++++++++-- 1 file changed, 129 insertions(+), 6 deletions(-) diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c index 32864dbc2c4e..206f342a0ca0 100644 --- a/arch/x86/net/bpf_jit_comp.c +++ b/arch/x86/net/bpf_jit_comp.c @@ -390,6 +390,28 @@ static void pop_callee_regs(u8 **pprog, bool *callee_regs_used) *pprog = prog; } +/* Push stack args from [rbp + outgoing_base + (k - 1) * 8] in reverse order. */ +static int push_stack_args(u8 **pprog, s32 outgoing_base, int from, int to) +{ + u8 *prog = *pprog; + int k, bytes = 0; + s32 off; + + for (k = from; k >= to; k--) { + off = outgoing_base + (k - 1) * 8; + /* push qword [rbp + off] */ + if (is_imm8(off)) { + EMIT3(0xFF, 0x75, off); + bytes += 3; + } else { + EMIT2_off32(0xFF, 0xB5, off); + bytes += 6; + } + } + *pprog = prog; + return bytes; +} + static void emit_nops(u8 **pprog, int len) { u8 *prog = *pprog; @@ -1664,16 +1686,33 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image int i, excnt = 0; int ilen, proglen = 0; u8 *prog = temp; - u32 stack_depth; + u16 stack_arg_depth, incoming_stack_arg_depth, outgoing_stack_arg_depth; + u32 prog_stack_depth, stack_depth; + bool has_stack_args; int err; stack_depth = bpf_prog->aux->stack_depth; + stack_arg_depth = bpf_prog->aux->stack_arg_depth; + incoming_stack_arg_depth = bpf_prog->aux->incoming_stack_arg_depth; + outgoing_stack_arg_depth = stack_arg_depth - incoming_stack_arg_depth; priv_stack_ptr = bpf_prog->aux->priv_stack_ptr; if (priv_stack_ptr) { priv_frame_ptr = priv_stack_ptr + PRIV_STACK_GUARD_SZ + round_up(stack_depth, 8); stack_depth = 0; } + /* + * Save program stack depth before adding outgoing stack arg space. + * Incoming stack args are read directly from [rbp + 16 + ...]. + * Only the outgoing stack arg area is allocated below the + * program stack. Outgoing args written here become the callee's + * incoming args. 
+ */ + prog_stack_depth = round_up(stack_depth, 8); + if (outgoing_stack_arg_depth) + stack_depth += outgoing_stack_arg_depth; + has_stack_args = stack_arg_depth > 0; + arena_vm_start = bpf_arena_get_kern_vm_start(bpf_prog->aux->arena); user_vm_start = bpf_arena_get_user_vm_start(bpf_prog->aux->arena); @@ -1715,13 +1754,14 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image prog = temp; for (i = 1; i <= insn_cnt; i++, insn++) { + bool adjust_stack_arg_off = false; const s32 imm32 = insn->imm; u32 dst_reg = insn->dst_reg; u32 src_reg = insn->src_reg; u8 b2 = 0, b3 = 0; u8 *start_of_ldx; s64 jmp_offset; - s16 insn_off; + s32 insn_off; u8 jmp_cond; u8 *func; int nops; @@ -1734,6 +1774,37 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image dst_reg = X86_REG_R9; } + if (has_stack_args) { + u8 class = BPF_CLASS(insn->code); + + if (class == BPF_LDX && + src_reg == BPF_REG_STACK_ARG_BASE) { + src_reg = BPF_REG_FP; + adjust_stack_arg_off = true; + } + if ((class == BPF_STX || class == BPF_ST) && + dst_reg == BPF_REG_STACK_ARG_BASE) { + dst_reg = BPF_REG_FP; + adjust_stack_arg_off = true; + } + } + + /* + * Translate BPF r12-relative offset to native RBP-relative: + * + * Incoming args (load: offset >= -incoming_depth): + * BPF: r12 + bpf_off = r12 - k * 8 (k = 1,2,...) 
for incoming arg k + * Native: [rbp + 8 + k * 8] + * Formula: native_off = 8 + k * 8 = 8 - bpf_off + * + * Outgoing args (store: offset < -incoming_depth): + * BPF: r12 + bpf_off = r12 - (incoming + k * 8) for outgoing arg k + * Native: [rbp - prog_stack_depth - outgoing + (k - 1) * 8] + * Formula: native_off = -(prog_stack_depth + outgoing) + (k - 1) * 8 + * = -(prog_stack_depth + outgoing + incoming + 8) - bpf_off + * = -(prog_stack_depth + stack_arg_depth + 8) - bpf_off + */ + switch (insn->code) { /* ALU */ case BPF_ALU | BPF_ADD | BPF_X: @@ -2131,10 +2202,13 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image case BPF_ST | BPF_MEM | BPF_DW: EMIT2(add_1mod(0x48, dst_reg), 0xC7); -st: if (is_imm8(insn->off)) - EMIT2(add_1reg(0x40, dst_reg), insn->off); +st: insn_off = insn->off; + if (adjust_stack_arg_off) + insn_off = -(prog_stack_depth + stack_arg_depth + 8) - insn_off; + if (is_imm8(insn_off)) + EMIT2(add_1reg(0x40, dst_reg), insn_off); else - EMIT1_off32(add_1reg(0x80, dst_reg), insn->off); + EMIT1_off32(add_1reg(0x80, dst_reg), insn_off); EMIT(imm32, bpf_size_to_x86_bytes(BPF_SIZE(insn->code))); break; @@ -2144,7 +2218,10 @@ st: if (is_imm8(insn->off)) case BPF_STX | BPF_MEM | BPF_H: case BPF_STX | BPF_MEM | BPF_W: case BPF_STX | BPF_MEM | BPF_DW: - emit_stx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off); + insn_off = insn->off; + if (adjust_stack_arg_off) + insn_off = -(prog_stack_depth + stack_arg_depth + 8) - insn_off; + emit_stx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off); break; case BPF_ST | BPF_PROBE_MEM32 | BPF_B: @@ -2243,6 +2320,8 @@ st: if (is_imm8(insn->off)) case BPF_LDX | BPF_PROBE_MEMSX | BPF_H: case BPF_LDX | BPF_PROBE_MEMSX | BPF_W: insn_off = insn->off; + if (adjust_stack_arg_off) + insn_off = 8 - insn_off; if (BPF_MODE(insn->code) == BPF_PROBE_MEM || BPF_MODE(insn->code) == BPF_PROBE_MEMSX) { @@ -2441,6 +2520,7 @@ st: if (is_imm8(insn->off)) /* call */ case BPF_JMP | BPF_CALL: { u8 
*ip = image + addrs[i - 1]; + int stack_args = 0; func = (u8 *) __bpf_call_base + imm32; if (src_reg == BPF_PSEUDO_CALL && tail_call_reachable) { @@ -2449,6 +2529,41 @@ st: if (is_imm8(insn->off)) } if (!imm32) return -EINVAL; + + if (src_reg == BPF_PSEUDO_CALL && outgoing_stack_arg_depth > 0) { + /* + * BPF-to-BPF calls: push outgoing stack args from + * the outgoing area onto the stack before CALL. + * The outgoing area is at [rbp - prog_stack - outgoing], + * but rsp is below that due to callee-saved reg pushes, + * so we must explicitly push args for the callee. + */ + s32 outgoing_base = -(prog_stack_depth + outgoing_stack_arg_depth); + int n_args = outgoing_stack_arg_depth / 8; + + ip += push_stack_args(&prog, outgoing_base, n_args, 1); + } + + if (src_reg != BPF_PSEUDO_CALL && insn->off > 0) { + /* Kfunc calls: arg 6 → R9, args 7+ → push. */ + s32 outgoing_base = -(prog_stack_depth + outgoing_stack_arg_depth); + int kfunc_stack_args = insn->off; + + stack_args = kfunc_stack_args > 1 ? kfunc_stack_args - 1 : 0; + + /* Push args 7+ in reverse order */ + if (stack_args > 0) + ip += push_stack_args(&prog, outgoing_base, kfunc_stack_args, 2); + + /* mov r9, [rbp + outgoing_base] (arg 6) */ + if (is_imm8(outgoing_base)) { + EMIT4(0x4C, 0x8B, 0x4D, outgoing_base); + ip += 4; + } else { + EMIT3_off32(0x4C, 0x8B, 0x8D, outgoing_base); + ip += 7; + } + } if (priv_frame_ptr) { push_r9(&prog); ip += 2; @@ -2458,6 +2573,14 @@ st: if (is_imm8(insn->off)) return -EINVAL; if (priv_frame_ptr) pop_r9(&prog); + if (stack_args > 0) { + /* add rsp, stack_args * 8 */ + EMIT4(0x48, 0x83, 0xC4, stack_args * 8); + } + if (src_reg == BPF_PSEUDO_CALL && outgoing_stack_arg_depth > 0) { + /* add rsp, outgoing_stack_arg_depth */ + EMIT4(0x48, 0x83, 0xC4, outgoing_stack_arg_depth); + } break; } -- 2.52.0