In specific use cases combining tailcalls and BPF-to-BPF calls, MAX_TAIL_CALL_CNT won't work because of missing tail_call_cnt back-propagation from callee to caller。This patch fixes this tailcall issue caused by abusing the tailcall in bpf2bpf feature on LoongArch like the way of "bpf, x64: Fix tailcall hierarchy". push tail_call_cnt_ptr and tail_call_cnt into the stack, tail_call_cnt_ptr is passed between tailcall and bpf2bpf, uses tail_call_cnt_ptr to increment tail_call_cnt. Fixes: bb035ef0cc91 ("LoongArch: BPF: Support mixing bpf2bpf and tailcalls") Fixes: 5dc615520c4d ("LoongArch: Add BPF JIT support") Reviewed-by: Hengqi Chen Signed-off-by: Haoran Jiang --- arch/loongarch/net/bpf_jit.c | 159 ++++++++++++++++++++++++----------- 1 file changed, 112 insertions(+), 47 deletions(-) diff --git a/arch/loongarch/net/bpf_jit.c b/arch/loongarch/net/bpf_jit.c index 5da41ccc6698..2109748c4733 100644 --- a/arch/loongarch/net/bpf_jit.c +++ b/arch/loongarch/net/bpf_jit.c @@ -17,10 +17,7 @@ #define LOONGARCH_BPF_FENTRY_NBYTES (LOONGARCH_LONG_JUMP_NINSNS * 4) #define REG_TCC LOONGARCH_GPR_A6 -#define TCC_SAVED LOONGARCH_GPR_S5 - -#define SAVE_RA BIT(0) -#define SAVE_TCC BIT(1) +#define BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack) (round_up(stack, 16) - 80) static const int regmap[] = { /* return value from in-kernel function, and exit value for eBPF program */ @@ -42,32 +39,59 @@ static const int regmap[] = { [BPF_REG_AX] = LOONGARCH_GPR_T0, }; -static void mark_call(struct jit_ctx *ctx) +static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx, int *store_offset) { - ctx->flags |= SAVE_RA; -} + const struct bpf_prog *prog = ctx->prog; + const bool is_main_prog = !bpf_is_subprog(prog); -static void mark_tail_call(struct jit_ctx *ctx) -{ - ctx->flags |= SAVE_TCC; -} + if (is_main_prog) { + /* + * LOONGARCH_GPR_T3 = MAX_TAIL_CALL_CNT + * if (REG_TCC > T3 ) + * std REG_TCC -> LOONGARCH_GPR_SP + store_offset + * else + * std REG_TCC -> LOONGARCH_GPR_SP + store_offset + * REG_TCC = LOONGARCH_GPR_SP + store_offset + * + * std REG_TCC -> LOONGARCH_GPR_SP + store_offset + * + * The purpose of this code is to first push the TCC into stack, + * and then push the address of TCC into stack. + * In cases where bpf2bpf and tailcall are used in combination, + * the value in REG_TCC may be a count or an address, + * these two cases need to be judged and handled separately。 + * + */ + emit_insn(ctx, addid, LOONGARCH_GPR_T3, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT); + *store_offset -= sizeof(long); -static bool seen_call(struct jit_ctx *ctx) -{ - return (ctx->flags & SAVE_RA); -} + emit_cond_jmp(ctx, BPF_JGT, REG_TCC, LOONGARCH_GPR_T3, 4); -static bool seen_tail_call(struct jit_ctx *ctx) -{ - return (ctx->flags & SAVE_TCC); -} + /* If REG_TCC < MAX_TAIL_CALL_CNT, the value in REG_TCC is a count, + * push TCC into stack + */ + emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset); -static u8 tail_call_reg(struct jit_ctx *ctx) -{ - if (seen_call(ctx)) - return TCC_SAVED; + /* Push the address of TCC into the stack */ + emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_SP, *store_offset); + + emit_uncond_jmp(ctx, 2); + + /* If REG_TCC > MAX_TAIL_CALL_CNT, the value in REG_TCC is an address, + * push TCC_ptr into stack + */ + emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset); + + *store_offset -= sizeof(long); + emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset); + + } else { + *store_offset -= sizeof(long); + emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset); - return REG_TCC; + *store_offset -= sizeof(long); + emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset); + } } /* @@ -90,6 +114,10 @@ static u8 tail_call_reg(struct jit_ctx *ctx) * | $s4 | * +-------------------------+ * | $s5 | + * +-------------------------+ + * | tcc | + * +-------------------------+ + * | tcc_ptr | * +-------------------------+ <--BPF_REG_FP * | prog->aux->stack_depth | * | (optional) | @@ -99,11 +127,14 @@ static u8 tail_call_reg(struct jit_ctx *ctx) static void build_prologue(struct jit_ctx *ctx) { int stack_adjust = 0, store_offset, bpf_stack_adjust; + const struct bpf_prog *prog = ctx->prog; + const bool is_main_prog = !bpf_is_subprog(prog); bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16); - /* To store ra, fp, s0, s1, s2, s3, s4 and s5. */ - stack_adjust += sizeof(long) * 8; + /* To store ra, fp, s0, s1, s2, s3, s4, s5, tcc and tcc_ptr */ + stack_adjust += sizeof(long) * 10; + stack_adjust = round_up(stack_adjust, 16); stack_adjust += bpf_stack_adjust; @@ -116,11 +147,12 @@ static void build_prologue(struct jit_ctx *ctx) emit_insn(ctx, nop); /* - * First instruction initializes the tail call count (TCC). - * On tail call we skip this instruction, and the TCC is + * First instruction initializes the tail call count (TCC) register + * to zero. On tail call we skip this instruction, and the TCC is * passed in REG_TCC from the caller. */ - emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT); + if (is_main_prog) + emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, 0); emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust); @@ -148,20 +180,14 @@ static void build_prologue(struct jit_ctx *ctx) store_offset -= sizeof(long); emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset); + prepare_bpf_tail_call_cnt(ctx, &store_offset); + + emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust); if (bpf_stack_adjust) emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust); - /* - * Program contains calls and tail calls, so REG_TCC need - * to be saved across calls. - */ - if (seen_tail_call(ctx) && seen_call(ctx)) - move_reg(ctx, TCC_SAVED, REG_TCC); - else - emit_insn(ctx, nop); - ctx->stack_size = stack_adjust; } @@ -194,6 +220,16 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call) load_offset -= sizeof(long); emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset); + /* + * When push into the stack, follow the order of tcc then tcc_ptr. + * When pop from the stack, first pop tcc_ptr followed by tcc + */ + load_offset -= 2*sizeof(long); + emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset); + + load_offset += sizeof(long); + emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset); + emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust); if (!is_tail_call) { @@ -206,7 +242,7 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call) * Call the next bpf prog and skip the first instruction * of TCC initialization. */ - emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_T3, 1); + emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_T3, 6); } } @@ -228,7 +264,7 @@ bool bpf_jit_supports_far_kfunc_call(void) static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn) { int off; - u8 tcc = tail_call_reg(ctx); + int tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size); u8 a1 = LOONGARCH_GPR_A1; u8 a2 = LOONGARCH_GPR_A2; u8 t1 = LOONGARCH_GPR_T1; @@ -257,11 +293,15 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn) goto toofar; /* - * if (--TCC < 0) - * goto out; + * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT) + * goto out; */ - emit_insn(ctx, addid, REG_TCC, tcc, -1); - if (emit_tailcall_jmp(ctx, BPF_JSLT, REG_TCC, LOONGARCH_GPR_ZERO, jmp_offset) < 0) + emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off); + emit_insn(ctx, ldd, t3, REG_TCC, 0); + emit_insn(ctx, addid, t3, t3, 1); + emit_insn(ctx, std, t3, REG_TCC, 0); + emit_insn(ctx, addid, t2, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT); + if (emit_tailcall_jmp(ctx, BPF_JSGT, t3, t2, jmp_offset) < 0) goto toofar; /* @@ -482,6 +522,7 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext const s16 off = insn->off; const s32 imm = insn->imm; const bool is32 = BPF_CLASS(insn->code) == BPF_ALU || BPF_CLASS(insn->code) == BPF_JMP32; + int tcc_ptr_off; switch (code) { /* dst = src */ @@ -908,12 +949,17 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext /* function call */ case BPF_JMP | BPF_CALL: - mark_call(ctx); ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass, &func_addr, &func_addr_fixed); if (ret < 0) return ret; + if (insn->src_reg == BPF_PSEUDO_CALL) { + tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size); + emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off); + } + + move_addr(ctx, t1, func_addr); emit_insn(ctx, jirl, LOONGARCH_GPR_RA, t1, 0); @@ -924,7 +970,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext /* tail call */ case BPF_JMP | BPF_TAIL_CALL: - mark_tail_call(ctx); if (emit_bpf_tail_call(ctx, i) < 0) return -EINVAL; break; @@ -1592,7 +1637,7 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i { int i; int stack_size = 0, nargs = 0; - int retval_off, args_off, nargs_off, ip_off, run_ctx_off, sreg_off; + int retval_off, args_off, nargs_off, ip_off, run_ctx_off, sreg_off, tcc_ptr_off; struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY]; struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT]; struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN]; @@ -1628,6 +1673,7 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i * * FP - sreg_off [ callee saved reg ] * + * FP - tcc_ptr_off [ tail_call_cnt_ptr ] */ if (m->nr_args > LOONGARCH_MAX_REG_ARGS) @@ -1670,6 +1716,13 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i stack_size += 8; sreg_off = stack_size; + /* room of trampoline frame to store tail_call_cnt_ptr */ + if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) { + stack_size += 8; + tcc_ptr_off = stack_size; + } + + stack_size = round_up(stack_size, 16); if (!is_struct_ops) { @@ -1698,6 +1751,10 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_size); } + if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) + emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off); + + /* callee saved register S1 to pass start time */ emit_insn(ctx, std, LOONGARCH_GPR_S1, LOONGARCH_GPR_FP, -sreg_off); @@ -1745,6 +1802,10 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i if (flags & BPF_TRAMP_F_CALL_ORIG) { restore_args(ctx, m->nr_args, args_off); + + if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) + emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off); + ret = emit_call(ctx, (const u64)orig_call); if (ret) goto out; @@ -1789,6 +1850,10 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i emit_insn(ctx, ldd, LOONGARCH_GPR_S1, LOONGARCH_GPR_FP, -sreg_off); + if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) + emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off); + + if (!is_struct_ops) { /* trampoline called from function entry */ emit_insn(ctx, ldd, LOONGARCH_GPR_T0, LOONGARCH_GPR_SP, stack_size - 8); -- 2.43.0