aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--arch/x86/net/bpf_jit_comp.c140
1 files changed, 112 insertions, 28 deletions
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index 1d4d50199293..b7a2911bda77 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -869,8 +869,31 @@ static void detect_reg_usage(struct bpf_insn *insn, int insn_cnt,
}
}
+static int emit_nops(u8 **pprog, int len)
+{
+ u8 *prog = *pprog;
+ int i, noplen, cnt = 0;
+
+ while (len > 0) {
+ noplen = len;
+
+ if (noplen > ASM_NOP_MAX)
+ noplen = ASM_NOP_MAX;
+
+ for (i = 0; i < noplen; i++)
+ EMIT1(ideal_nops[noplen][i]);
+ len -= noplen;
+ }
+
+ *pprog = prog;
+
+ return cnt;
+}
+
+#define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp)))
+
static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
- int oldproglen, struct jit_context *ctx)
+ int oldproglen, struct jit_context *ctx, bool jmp_padding)
{
bool tail_call_reachable = bpf_prog->aux->tail_call_reachable;
struct bpf_insn *insn = bpf_prog->insnsi;
@@ -880,7 +903,7 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
bool seen_exit = false;
u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
int i, cnt = 0, excnt = 0;
- int proglen = 0;
+ int ilen, proglen = 0;
u8 *prog = temp;
int err;
@@ -894,7 +917,13 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
bpf_prog_was_classic(bpf_prog), tail_call_reachable,
bpf_prog->aux->func_idx != 0);
push_callee_regs(&prog, callee_regs_used);
- addrs[0] = prog - temp;
+
+ ilen = prog - temp;
+ if (image)
+ memcpy(image + proglen, temp, ilen);
+ proglen += ilen;
+ addrs[0] = proglen;
+ prog = temp;
for (i = 1; i <= insn_cnt; i++, insn++) {
const s32 imm32 = insn->imm;
@@ -903,8 +932,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
u8 b2 = 0, b3 = 0;
s64 jmp_offset;
u8 jmp_cond;
- int ilen;
u8 *func;
+ int nops;
switch (insn->code) {
/* ALU */
@@ -1502,6 +1531,30 @@ emit_cond_jmp: /* Convert BPF opcode to x86 */
}
jmp_offset = addrs[i + insn->off] - addrs[i];
if (is_imm8(jmp_offset)) {
+ if (jmp_padding) {
+ /* To keep the jmp_offset valid, the extra bytes are
+ * padded before the jump insn, so we substract the
+ * 2 bytes of jmp_cond insn from INSN_SZ_DIFF.
+ *
+ * If the previous pass already emits an imm8
+ * jmp_cond, then this BPF insn won't shrink, so
+ * "nops" is 0.
+ *
+ * On the other hand, if the previous pass emits an
+ * imm32 jmp_cond, the extra 4 bytes(*) is padded to
+ * keep the image from shrinking further.
+ *
+ * (*) imm32 jmp_cond is 6 bytes, and imm8 jmp_cond
+ * is 2 bytes, so the size difference is 4 bytes.
+ */
+ nops = INSN_SZ_DIFF - 2;
+ if (nops != 0 && nops != 4) {
+ pr_err("unexpected jmp_cond padding: %d bytes\n",
+ nops);
+ return -EFAULT;
+ }
+ cnt += emit_nops(&prog, nops);
+ }
EMIT2(jmp_cond, jmp_offset);
} else if (is_simm32(jmp_offset)) {
EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
@@ -1524,11 +1577,55 @@ emit_cond_jmp: /* Convert BPF opcode to x86 */
else
jmp_offset = addrs[i + insn->off] - addrs[i];
- if (!jmp_offset)
- /* Optimize out nop jumps */
+ if (!jmp_offset) {
+ /*
+ * If jmp_padding is enabled, the extra nops will
+ * be inserted. Otherwise, optimize out nop jumps.
+ */
+ if (jmp_padding) {
+ /* There are 3 possible conditions.
+ * (1) This BPF_JA is already optimized out in
+ * the previous run, so there is no need
+ * to pad any extra byte (0 byte).
+ * (2) The previous pass emits an imm8 jmp,
+ * so we pad 2 bytes to match the previous
+ * insn size.
+ * (3) Similarly, the previous pass emits an
+ * imm32 jmp, and 5 bytes is padded.
+ */
+ nops = INSN_SZ_DIFF;
+ if (nops != 0 && nops != 2 && nops != 5) {
+ pr_err("unexpected nop jump padding: %d bytes\n",
+ nops);
+ return -EFAULT;
+ }
+ cnt += emit_nops(&prog, nops);
+ }
break;
+ }
emit_jmp:
if (is_imm8(jmp_offset)) {
+ if (jmp_padding) {
+ /* To avoid breaking jmp_offset, the extra bytes
+ * are padded before the actual jmp insn, so
+ * 2 bytes is substracted from INSN_SZ_DIFF.
+ *
+ * If the previous pass already emits an imm8
+ * jmp, there is nothing to pad (0 byte).
+ *
+ * If it emits an imm32 jmp (5 bytes) previously
+ * and now an imm8 jmp (2 bytes), then we pad
+ * (5 - 2 = 3) bytes to stop the image from
+ * shrinking further.
+ */
+ nops = INSN_SZ_DIFF - 2;
+ if (nops != 0 && nops != 3) {
+ pr_err("unexpected jump padding: %d bytes\n",
+ nops);
+ return -EFAULT;
+ }
+ cnt += emit_nops(&prog, INSN_SZ_DIFF - 2);
+ }
EMIT2(0xEB, jmp_offset);
} else if (is_simm32(jmp_offset)) {
EMIT1_off32(0xE9, jmp_offset);
@@ -1671,26 +1768,6 @@ static int invoke_bpf_prog(const struct btf_func_model *m, u8 **pprog,
return 0;
}
-static void emit_nops(u8 **pprog, unsigned int len)
-{
- unsigned int i, noplen;
- u8 *prog = *pprog;
- int cnt = 0;
-
- while (len > 0) {
- noplen = len;
-
- if (noplen > ASM_NOP_MAX)
- noplen = ASM_NOP_MAX;
-
- for (i = 0; i < noplen; i++)
- EMIT1(ideal_nops[noplen][i]);
- len -= noplen;
- }
-
- *pprog = prog;
-}
-
static void emit_align(u8 **pprog, u32 align)
{
u8 *target, *prog = *pprog;
@@ -2065,6 +2142,9 @@ struct x64_jit_data {
struct jit_context ctx;
};
+#define MAX_PASSES 20
+#define PADDING_PASSES (MAX_PASSES - 5)
+
struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{
struct bpf_binary_header *header = NULL;
@@ -2074,6 +2154,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
struct jit_context ctx = {};
bool tmp_blinded = false;
bool extra_pass = false;
+ bool padding = false;
u8 *image = NULL;
int *addrs;
int pass;
@@ -2110,6 +2191,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
image = jit_data->image;
header = jit_data->header;
extra_pass = true;
+ padding = true;
goto skip_init_addrs;
}
addrs = kmalloc_array(prog->len + 1, sizeof(*addrs), GFP_KERNEL);
@@ -2135,8 +2217,10 @@ skip_init_addrs:
* may converge on the last pass. In such case do one more
* pass to emit the final image.
*/
- for (pass = 0; pass < 20 || image; pass++) {
- proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
+ for (pass = 0; pass < MAX_PASSES || image; pass++) {
+ if (!padding && pass >= PADDING_PASSES)
+ padding = true;
+ proglen = do_jit(prog, addrs, image, oldproglen, &ctx, padding);
if (proglen <= 0) {
out_image:
image = NULL;