Wasm interpreter optimisation: part 3

This post is part of a series; see part 1 and part 2.

In the last post I tried to measure, and so understand, the branch prediction behaviour of the interpreter. I expected lots of branch mispredictions because of the single jmp instruction inherent in the giant-switch-statement design of the interpreter, and callgrind was indeed hinting at very poor branch prediction, but this didn’t actually seem to be the case. Rather, callgrind was simulating a primitive branch predictor that is not as powerful as the branch predictors in my machines.

However, the interpreter was still quite slow. On my rough and not very comprehensive test suite of one, fib(39) was taking 26 seconds compared to 6 seconds for Lua. That gap gives hope that there is room for optimisation.

Serendipitously, around the time I wrote that last post, Andrew proposed adding language support for a labelled continue syntax that would produce more optimal machine code (including a separate jmp for each instruction). That proposal has been accepted but is still pending implementation, so I couldn’t try it out (though I think you can sometimes convince the compiler to generate machine code of that shape). More serendipitously still, haberman joined the conversation and suggested using tail calls for interpreter dispatch. Guaranteed tail calls are already supported in Zig via the @call builtin.
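To make the difference in shape concrete, here is a minimal C sketch of both dispatch styles (C rather than Zig, because labelled continue isn’t available yet; the opcodes and the Instr layout are hypothetical). run_switch is the giant-switch design, where every virtual instruction funnels through the same dispatch jump. run_threaded uses the GNU C “labels as values” extension (computed goto, supported by GCC and Clang), which is one way of coaxing a compiler into emitting a separate indirect jmp at the end of each handler, roughly the shape labelled continue is meant to produce. Compilers are free to merge or duplicate these jumps, so the disassembly is the final word.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical opcodes for illustration; not the interpreter's real set. */
enum { OP_NOP, OP_CONST, OP_ADD, OP_HALT };

typedef struct {
    uint8_t op;
    uint32_t arg; /* immediate operand, used by OP_CONST */
} Instr;

/* Giant switch: every handler goes back round the loop, so all virtual
   instructions share whatever indirect jump the compiler emits for the
   switch (typically a single jump-table jmp for a large switch). */
static uint32_t run_switch(const Instr *code) {
    uint32_t stack[64];
    size_t sp = 0, ip = 0;
    for (;;) {
        const Instr in = code[ip++];
        switch (in.op) {
        case OP_NOP:
            break;
        case OP_CONST:
            stack[sp++] = in.arg;
            break;
        case OP_ADD:
            assert(sp >= 2);
            sp--;
            stack[sp - 1] += stack[sp]; /* unsigned, so it wraps like Zig's +% */
            break;
        case OP_HALT:
            return stack[sp - 1];
        default:
            return 0; /* unknown opcode; a real interpreter would trap */
        }
    }
}

/* Threaded code via computed goto: each handler ends with its own copy of
   the fetch-and-jump sequence instead of returning to a shared switch. */
static uint32_t run_threaded(const Instr *code) {
    /* Label addresses indexed by opcode; the order must match the enum. */
    static void *const labels[] = {&&do_nop, &&do_const, &&do_add, &&do_halt};
    uint32_t stack[64];
    size_t sp = 0, ip = 0;
    const Instr *in;

#define NEXT()                \
    do {                      \
        in = &code[ip++];     \
        goto *labels[in->op]; \
    } while (0)

    NEXT();
do_nop:
    NEXT();
do_const:
    stack[sp++] = in->arg;
    NEXT();
do_add:
    assert(sp >= 2);
    sp--;
    stack[sp - 1] += stack[sp];
    NEXT();
do_halt:
    return stack[sp - 1];
}
#undef NEXT
```

Both functions compute the same result; only the placement of the dispatch jumps differs.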

What might a tail call interpreter look like in Zig? (This is not exactly the code as it currently stands; at this point I want to show its shape rather than the details. We’ll come back to the details later.)

fn @"i32.add"(self: *Interpreter) void {
    const c2 = self.popOperand(u32);
    const c1 = self.popOperand(u32);

    self.pushOperandNoCheck(u32, c1 +% c2);

    return @call(.{ .modifier = .always_tail }, dispatch, .{ self });
}

fn nop(self: *Interpreter) void {
    return @call(.{ .modifier = .always_tail }, dispatch, .{ self });
}

inline fn dispatch(self: *Interpreter) void {
    const next_instr = self.code[self.ip];
    self.ip += 1;

    return @call(.{ .modifier = .always_tail }, lookup[@enumToInt(next_instr)], .{ self });
}

const lookup = [256]InstructionFunction{nop, @"i32.add"...

So what do we have here? We’re implementing two virtual (i.e. WebAssembly) instructions: nop and @"i32.add". With the tail call interpreter we get a native function for every virtual instruction; this is different from the giant-switch-statement approach, where all the virtual instruction implementations live inside one big function.

The final statement in each virtual instruction implementation is return @call(.{ .modifier = .always_tail }, dispatch, .{ self });. The .modifier = .always_tail is what gives us a tail call. Rather than writing out the complete dispatch code in every function, I have a separate dispatch function declared inline: we write that code once, but it still ends up…inline in the generated code for each virtual instruction function.

What’s this lookup thing? We need to give @call the address of the function we want to call and we’ll do this here by just having a giant table of virtual instruction function pointers that can be indexed by the Opcode enum.
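For comparison, the same shape can be sketched in C, with Clang’s musttail statement attribute standing in for .modifier = .always_tail. This is an analogue, not the interpreter’s code: the opcodes, the VM struct and the DISPATCH macro (which plays the role of the inline dispatch function) are hypothetical, and the handlers return the final result rather than void because ISO C doesn’t allow returning a void expression.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Clang's musttail guarantees a tail call (a jmp, not a call). Without it we
   fall back to an ordinary call and hope the optimiser emits a jmp anyway. */
#if defined(__has_attribute)
#  if __has_attribute(musttail)
#    define MUSTTAIL __attribute__((musttail))
#  endif
#endif
#ifndef MUSTTAIL
#  define MUSTTAIL
#endif

/* Hypothetical opcodes. OP_DEC_BR_NZ decrements the top of the stack and
   jumps to instruction index `arg` while it is non-zero. */
enum { OP_NOP, OP_CONST, OP_ADD, OP_DEC_BR_NZ, OP_HALT, OP_COUNT };

typedef struct {
    uint8_t op;
    uint32_t arg;
} Instr;

typedef struct {
    const Instr *code;
    size_t ip; /* index of the next virtual instruction to execute */
    size_t sp; /* operand stack pointer */
    uint32_t stack[64];
} VM;

/* Every handler has the same signature, which musttail requires. */
typedef uint32_t (*Handler)(VM *vm);

static uint32_t op_nop(VM *vm);
static uint32_t op_const(VM *vm);
static uint32_t op_add(VM *vm);
static uint32_t op_dec_br_nz(VM *vm);
static uint32_t op_halt(VM *vm);

/* One native function per virtual instruction, indexed by opcode. */
static const Handler lookup[OP_COUNT] = {
    [OP_NOP] = op_nop,
    [OP_CONST] = op_const,
    [OP_ADD] = op_add,
    [OP_DEC_BR_NZ] = op_dec_br_nz,
    [OP_HALT] = op_halt,
};

/* Shared dispatch, expanded at the end of every handler: fetch the next
   opcode, advance ip and tail call that opcode's handler. */
#define DISPATCH(vm) MUSTTAIL return lookup[(vm)->code[(vm)->ip++].op](vm)

static uint32_t op_nop(VM *vm) { DISPATCH(vm); }

static uint32_t op_const(VM *vm) {
    /* ip has already been advanced past this instruction. */
    vm->stack[vm->sp++] = vm->code[vm->ip - 1].arg;
    DISPATCH(vm);
}

static uint32_t op_add(VM *vm) {
    assert(vm->sp >= 2);
    uint32_t c2 = vm->stack[--vm->sp];
    uint32_t c1 = vm->stack[--vm->sp];
    vm->stack[vm->sp++] = c1 + c2; /* unsigned, so it wraps like Zig's +% */
    DISPATCH(vm);
}

static uint32_t op_dec_br_nz(VM *vm) {
    uint32_t target = vm->code[vm->ip - 1].arg;
    if (--vm->stack[vm->sp - 1] != 0)
        vm->ip = target;
    DISPATCH(vm);
}

/* The one handler that does not tail call. When every dispatch is a real tail
   call, returning here goes straight back to run(). */
static uint32_t op_halt(VM *vm) { return vm->stack[vm->sp - 1]; }

/* Entry point: an ordinary call into the first handler. */
static uint32_t run(const Instr *code) {
    VM vm = {.code = code, .ip = 0, .sp = 0};
    return lookup[vm.code[vm.ip++].op](&vm);
}
```

Running {CONST 2, CONST 3, ADD, HALT} returns 5. A countdown loop from 1000 ({CONST 1000, NOP, DEC_BR_NZ 1, HALT}) executes about two thousand handlers, and with musttail in effect they all reuse a single native frame.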

Hold on. Why do we need tail calls? What happens if we just call the function directly? If we don’t insist on a tail call, then even a WebAssembly function that only ever uses a single virtual frame (let’s say a for loop counting up to some number) would allocate a number of native frames linear in the number of virtual instructions it executes, because each virtual instruction’s function would call the next one without returning. We obviously don’t want this…we don’t want a WebAssembly loop to blow up the native stack when the count is large. The tail call ensures that as we transition between virtual instructions we reuse the same native stack frame over and over again.
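A back-of-the-envelope bound shows how quickly that bites (the numbers are hypothetical). Without tail calls, each virtual instruction is a nested native call that pushes at least an 8-byte return address on x86-64, so a loop of $N$ iterations with $k$ virtual instructions per iteration needs

$$
\text{native stack} \;\ge\; N \cdot k \cdot 8\ \text{bytes},
\qquad N = 10^{6},\ k = 5 \;\Rightarrow\; 4 \times 10^{7}\ \text{bytes} \approx 40\ \text{MB},
$$

well beyond the 8 MiB main-thread stack that many Linux systems default to (ulimit -s 8192). Real frames are bigger than a bare return address, so the true figure would be worse; with a guaranteed tail call the stack usage stays constant in $N$.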

Another thing to note about tail calls is that the caller and callee need the same parameter list, so each virtual instruction function passes the same set of arguments on to the next.

Now, there’s a problem: in various places a virtual instruction implementation can error. So let’s have our virtual instruction functions return an error union (!void), e.g.:

fn nop(self: *Interpreter) !void {
    return @call(.{ .modifier = .always_tail }, dispatch, .{ self });
}

Unfortunately this does not compile due to https://github.com/ziglang/zig/issues/5692. Okay, for the time being we’re going to have to sacrifice some aesthetics. Instead of returning !void, we’ll continue to return void, but we’ll pass in an additional parameter: a pointer to an optional error (*?WasmError):

fn nop(self: *Interpreter, err: *?WasmError) void {
    return @call(.{ .modifier = .always_tail }, dispatch, .{ self, err });
}

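The optional that err points to is set to null when we enter the interpreter. To signal an error, a virtual instruction sets err.* to a WasmError and returns instead of tail calling; since each tail call replaced the previous native frame, that return goes all the way back to the code that entered the interpreter. For example, the @"unreachable" function would be: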

fn @"unreachable"(self: *Interpreter, err: *?WasmError) void {
    err.* = error.TrapUnreachable;
}
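The same workaround translates directly to C, sketched below with hypothetical names throughout (WasmError, WASM_OK, the opcodes and STACK_MAX are illustrative, and Clang’s musttail stands in for always_tail). WASM_OK plays the role of null: run() clears *err on entry, and a handler that fails writes *err and returns instead of tail calling. C has no error unions, so this only illustrates the shape; the uint32_t return value just carries the final result (and keeps the return statement legal C).

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#if defined(__has_attribute)
#  if __has_attribute(musttail)
#    define MUSTTAIL __attribute__((musttail))
#  endif
#endif
#ifndef MUSTTAIL
#  define MUSTTAIL
#endif

/* C stand-in for Zig's ?WasmError: WASM_OK plays the role of null. */
typedef enum { WASM_OK = 0, WASM_TRAP_UNREACHABLE, WASM_STACK_OVERFLOW } WasmError;

enum { OP_NOP, OP_CONST, OP_UNREACHABLE, OP_HALT, OP_COUNT };

typedef struct {
    uint8_t op;
    uint32_t arg;
} Instr;

#define STACK_MAX 4

typedef struct {
    const Instr *code;
    size_t ip;
    size_t sp;
    uint32_t stack[STACK_MAX];
} VM;

/* The error travels through `err`; the return value only carries the result. */
typedef uint32_t (*Handler)(VM *vm, WasmError *err);

static uint32_t op_nop(VM *vm, WasmError *err);
static uint32_t op_const(VM *vm, WasmError *err);
static uint32_t op_unreachable(VM *vm, WasmError *err);
static uint32_t op_halt(VM *vm, WasmError *err);

static const Handler lookup[OP_COUNT] = {
    [OP_NOP] = op_nop,
    [OP_CONST] = op_const,
    [OP_UNREACHABLE] = op_unreachable,
    [OP_HALT] = op_halt,
};

#define DISPATCH(vm, err) MUSTTAIL return lookup[(vm)->code[(vm)->ip++].op](vm, err)

static uint32_t op_nop(VM *vm, WasmError *err) { DISPATCH(vm, err); }

static uint32_t op_const(VM *vm, WasmError *err) {
    if (vm->sp == STACK_MAX) {
        *err = WASM_STACK_OVERFLOW; /* signal the error... */
        return 0;                   /* ...and return instead of tail calling */
    }
    vm->stack[vm->sp++] = vm->code[vm->ip - 1].arg;
    DISPATCH(vm, err);
}

static uint32_t op_unreachable(VM *vm, WasmError *err) {
    (void)vm;
    *err = WASM_TRAP_UNREACHABLE;
    return 0;
}

static uint32_t op_halt(VM *vm, WasmError *err) {
    (void)err;
    assert(vm->sp >= 1);
    return vm->stack[vm->sp - 1];
}

/* Entry point: clear the error, run, then let the caller inspect *err. */
static uint32_t run(const Instr *code, WasmError *err) {
    VM vm = {.code = code, .ip = 0, .sp = 0};
    *err = WASM_OK;
    return lookup[vm.code[vm.ip++].op](&vm, err);
}
```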

In the initial code I wrote, implementing the bare minimum set of virtual instructions needed for fib, the runtime for fib(39) dropped from 26 seconds to 18 seconds. The assembly for nop in that version looks like:

0000000000207c50 <interpreter.Interpreter.impl_nop>:
  207c50:	48 8b 47 60          	mov    rax,QWORD PTR [rdi+0x60]
  207c54:	c5 f8 10 00          	vmovups xmm0,XMMWORD PTR [rax]
  207c58:	c5 f8 10 48 10       	vmovups xmm1,XMMWORD PTR [rax+0x10]
  207c5d:	c5 f8 10 50 20       	vmovups xmm2,XMMWORD PTR [rax+0x20]
  207c62:	c5 f8 10 58 30       	vmovups xmm3,XMMWORD PTR [rax+0x30]
  207c67:	c5 f8 29 5c 24 e8    	vmovaps XMMWORD PTR [rsp-0x18],xmm3
  207c6d:	c5 f8 29 54 24 d8    	vmovaps XMMWORD PTR [rsp-0x28],xmm2
  207c73:	c5 f8 29 4c 24 c8    	vmovaps XMMWORD PTR [rsp-0x38],xmm1
  207c79:	c5 f8 29 44 24 b8    	vmovaps XMMWORD PTR [rsp-0x48],xmm0
  207c7f:	48 83 c0 40          	add    rax,0x40
  207c83:	48 89 47 60          	mov    QWORD PTR [rdi+0x60],rax
  207c87:	48 ff 4f 68          	dec    QWORD PTR [rdi+0x68]
  207c8b:	0f b6 44 24 f0       	movzx  eax,BYTE PTR [rsp-0x10]
  207c90:	48 8d 74 24 b8       	lea    rsi,[rsp-0x48]
  207c95:	ff 24 c5 48 34 20 00 	jmp    QWORD PTR [rax*8+0x203448]
  207c9c:	0f 1f 40 00          	nop    DWORD PTR [rax+0x0]

Geez, look at all the vmovups / vmovaps stuff. Let’s ignore that for the moment. One of the issues with that initial code is that I was using slices to demarcate the beginning and end of a continuation. As we iterate over the virtual instructions we increment a pointer (the start of the continuation) but also need to decrement the slice length. In reality we only need a single virtual instruction pointer to track where we are in the virtual code, so the dec QWORD PTR [rdi+0x68] instruction is unneeded. That’s at least one instruction we can get rid of, but there still seems to be a lot going on in the above.

Back in https://github.com/ziglang/zig/issues/8220, haberman explains (see also his fantastic article about parsing protobufs) that the smaller the function, the more opportunity the compiler has to optimise, so small changes like removing dec could allow a greater improvement than you might expect. The compiler can keep commonly used arguments in registers, which is fast and avoids pushing and popping values to and from the (native) stack. On x86-64 the first few integer and pointer arguments are passed in registers (rdi, rsi, rdx, rcx, …), so by choosing the function parameters we effectively choose which values live in registers (and hopefully we can keep them there).
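As a rough C illustration of that idea (not the interpreter’s code: the opcodes, VM layout and DISPATCH macro are hypothetical, and Clang’s musttail stands in for always_tail), the virtual instruction pointer below is a plain index passed as an argument alongside the code pointer. Advancing it is a single add with no slice length to decrement, and on x86-64 both values can stay in argument registers from one handler to the next. Error handling is left out to keep it short.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#if defined(__has_attribute)
#  if __has_attribute(musttail)
#    define MUSTTAIL __attribute__((musttail))
#  endif
#endif
#ifndef MUSTTAIL
#  define MUSTTAIL
#endif

enum { OP_NOP, OP_CONST, OP_ADD, OP_DEC_BR_NZ, OP_HALT, OP_COUNT };

typedef struct {
    uint8_t op;
    uint32_t arg;
} Instr;

/* Only the operand stack stays in memory; ip and code travel as arguments. */
typedef struct {
    size_t sp;
    uint32_t stack[64];
} VM;

typedef uint32_t (*Handler)(VM *vm, size_t ip, const Instr *code);

static uint32_t op_nop(VM *, size_t, const Instr *);
static uint32_t op_const(VM *, size_t, const Instr *);
static uint32_t op_add(VM *, size_t, const Instr *);
static uint32_t op_dec_br_nz(VM *, size_t, const Instr *);
static uint32_t op_halt(VM *, size_t, const Instr *);

static const Handler lookup[OP_COUNT] = {
    [OP_NOP] = op_nop,
    [OP_CONST] = op_const,
    [OP_ADD] = op_add,
    [OP_DEC_BR_NZ] = op_dec_br_nz,
    [OP_HALT] = op_halt,
};

/* Tail call the handler for code[ip]; each handler receives its own index. */
#define DISPATCH(vm, ip, code) \
    MUSTTAIL return lookup[(code)[(ip)].op]((vm), (ip), (code))

static uint32_t op_nop(VM *vm, size_t ip, const Instr *code) {
    DISPATCH(vm, ip + 1, code); /* one add: no slice length to update */
}

static uint32_t op_const(VM *vm, size_t ip, const Instr *code) {
    vm->stack[vm->sp++] = code[ip].arg;
    DISPATCH(vm, ip + 1, code);
}

static uint32_t op_add(VM *vm, size_t ip, const Instr *code) {
    assert(vm->sp >= 2);
    uint32_t c2 = vm->stack[--vm->sp];
    uint32_t c1 = vm->stack[--vm->sp];
    vm->stack[vm->sp++] = c1 + c2; /* unsigned, so it wraps like Zig's +% */
    DISPATCH(vm, ip + 1, code);
}

static uint32_t op_dec_br_nz(VM *vm, size_t ip, const Instr *code) {
    /* A branch is just a different ip argument, not a store to memory. */
    size_t next = (--vm->stack[vm->sp - 1] != 0) ? code[ip].arg : ip + 1;
    DISPATCH(vm, next, code);
}

static uint32_t op_halt(VM *vm, size_t ip, const Instr *code) {
    (void)ip;
    (void)code;
    return vm->stack[vm->sp - 1];
}

static uint32_t run(const Instr *code) {
    VM vm = {.sp = 0};
    return lookup[code[0].op](&vm, 0, code);
}
```

Whether a given value is better off as an argument or as a field behind the self pointer is exactly the kind of trade-off that has to be measured, since every extra argument competes for the same handful of registers.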

In my PR that merged the tail call dispatch, I got rid of all the slice continuations and replaced them with usize values representing the operand stack pointer, label stack pointer and frame stack pointer. I also promoted the (virtual) instruction pointer ip to a function parameter, along with the []Instruction slice for the current module. The nop function implementation then looks like:

fn nop(self: *Interpreter, ip: usize, code: []Instruction, err: *?WasmError) void {
    return @call(.{ .modifier = .always_tail }, dispatch, .{ self, ip + 1, code, err });
}

This compiles to:

0000000000222fb0 <interpreter.Interpreter.nop>:
  222fb0:	48 ff c6             	inc    rsi
  222fb3:	4c 8b 02             	mov    r8,QWORD PTR [rdx]
  222fb6:	48 8d 04 76          	lea    rax,[rsi+rsi*2]
  222fba:	48 c1 e0 04          	shl    rax,0x4
  222fbe:	41 0f b6 44 00 28    	movzx  eax,BYTE PTR [r8+rax*1+0x28]
  222fc4:	ff 24 c5 40 61 21 00 	jmp    QWORD PTR [rax*8+0x216140]
  222fcb:	0f 1f 44 00 00       	nop    DWORD PTR [rax+rax*1+0x0]

…and you can see (at least partially) why we’ve gone from 18 seconds down to 13 seconds. We no longer have the slice decrement, and all the vmovups / vmovaps crap has disappeared. It is also worth comparing the perf output (perf stat -d -d -d ./fib) from the previous master with the tail call dispatch version.

Previous master:

fib(39) = 63245986

 Performance counter stats for './fib':

         25,520.93 msec task-clock                #    1.000 CPUs utilized
               113      context-switches          #    4.428 /sec
                10      cpu-migrations            #    0.392 /sec
                26      page-faults               #    1.019 /sec
    81,344,938,138      cycles                    #    3.187 GHz                      (28.57%)
    31,119,585,302      stalled-cycles-frontend   #   38.26% frontend cycles idle     (28.57%)
   128,751,238,578      instructions              #    1.58  insn per cycle
                                                  #    0.24  stalled cycles per insn  (35.71%)
    16,658,200,203      branches                  #  652.727 M/sec                    (35.71%)
       809,936,529      branch-misses             #    4.86% of all branches          (35.71%)
    60,359,256,646      L1-dcache-loads           #    2.365 G/sec                    (28.56%)
         4,012,985      L1-dcache-load-misses     #    0.01% of all L1-dcache accesses  (14.29%)
           200,827      LLC-loads                 #    7.869 K/sec                    (14.29%)
   <not supported>      LLC-load-misses
   <not supported>      L1-icache-loads
         4,631,769      L1-icache-load-misses                                         (21.43%)
    60,383,842,949      dTLB-loads                #    2.366 G/sec                    (21.43%)
           356,086      dTLB-load-misses          #    0.00% of all dTLB cache accesses  (14.29%)
           449,370      iTLB-loads                #   17.608 K/sec                    (14.29%)
       255,509,966      iTLB-load-misses          # 56859.60% of all iTLB cache accesses  (21.43%)
   <not supported>      L1-dcache-prefetches
           422,578      L1-dcache-prefetch-misses #   16.558 K/sec                    (28.57%)

      25.522272275 seconds time elapsed

      25.356742000 seconds user
       0.000960000 seconds sys

With tail call dispatch:

fib(39) = 63245986

 Performance counter stats for './fib':

         12,397.63 msec task-clock                #    1.000 CPUs utilized
               131      context-switches          #   10.567 /sec
                 2      cpu-migrations            #    0.161 /sec
                24      page-faults               #    1.936 /sec
    39,424,853,935      cycles                    #    3.180 GHz                      (28.56%)
    10,592,820,823      stalled-cycles-frontend   #   26.87% frontend cycles idle     (28.56%)
    78,865,250,519      instructions              #    2.00  insn per cycle
                                                  #    0.13  stalled cycles per insn  (35.71%)
     7,680,828,125      branches                  #  619.540 M/sec                    (35.72%)
       358,732,112      branch-misses             #    4.67% of all branches          (35.72%)
    35,630,774,927      L1-dcache-loads           #    2.874 G/sec                    (28.57%)
         2,375,000      L1-dcache-load-misses     #    0.01% of all L1-dcache accesses  (14.29%)
           393,521      LLC-loads                 #   31.742 K/sec                    (14.29%)
   <not supported>      LLC-load-misses
   <not supported>      L1-icache-loads
         2,973,606      L1-icache-load-misses                                         (21.44%)
    35,648,628,162      dTLB-loads                #    2.875 G/sec                    (21.42%)
           242,155      dTLB-load-misses          #    0.00% of all dTLB cache accesses  (14.28%)
           144,986      iTLB-loads                #   11.695 K/sec                    (14.28%)
           145,246      iTLB-load-misses          #  100.18% of all iTLB cache accesses  (21.41%)
   <not supported>      L1-dcache-prefetches
           385,581      L1-dcache-prefetch-misses #   31.101 K/sec                    (28.55%)

      12.399759040 seconds time elapsed

      12.304415000 seconds user
       0.002956000 seconds sys

We’re not seeing much difference in branch misprediction or L1 cache misses. Rather it seems like we’re just executing far fewer instructions (128,751,238,578 vs 78,865,250,519) due to better code generation.
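Since cycles = instructions / IPC, the ratio of cycles between the two runs factors into an instruction-count ratio and an IPC ratio. Plugging in the perf numbers above (old = previous master, new = tail call dispatch):

$$
\frac{\text{cycles}_{\text{old}}}{\text{cycles}_{\text{new}}}
= \frac{\text{insns}_{\text{old}}}{\text{insns}_{\text{new}}} \cdot \frac{\text{IPC}_{\text{new}}}{\text{IPC}_{\text{old}}}
\approx \frac{128.75 \times 10^{9}}{78.87 \times 10^{9}} \cdot \frac{2.00}{1.58}
\approx 1.63 \times 1.27 \approx 2.07
$$

This agrees with the measured cycle ratio of $81.34 / 39.42 \approx 2.06$ up to rounding in perf’s IPC figures. So the roughly 1.63× reduction in instructions accounts for most of the speedup, with the remaining factor of about 1.27 coming from higher IPC, alongside a drop in frontend-idle cycles from 38.26% to 26.87%.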

Now, nop is the simplest instruction (almost…unreachable actually wins here because it doesn’t tail call but just returns), but not every instruction compiles to code this clean. Take loop, for example:

0000000000223030 <interpreter.Interpreter.loop>:
  223030:	41 56                	push   r14
  223032:	53                   	push   rbx
  223033:	4c 8b 4f 40          	mov    r9,QWORD PTR [rdi+0x40]
  223037:	4c 3b 4f 38          	cmp    r9,QWORD PTR [rdi+0x38]
  22303b:	75 09                	jne    223046 <interpreter.Interpreter.loop+0x16>
  22303d:	66 c7 01 42 00       	mov    WORD PTR [rcx],0x42
  223042:	5b                   	pop    rbx
  223043:	41 5e                	pop    r14
  223045:	c3                   	ret
  223046:	4c 8b 02             	mov    r8,QWORD PTR [rdx]
  223049:	48 8d 04 76          	lea    rax,[rsi+rsi*2]
  22304d:	48 c1 e0 04          	shl    rax,0x4
  223051:	4d 8b 14 00          	mov    r10,QWORD PTR [r8+rax*1]
  223055:	4c 8b 5f 10          	mov    r11,QWORD PTR [rdi+0x10]
  223059:	48 8b 47 30          	mov    rax,QWORD PTR [rdi+0x30]
  22305d:	4d 29 d3             	sub    r11,r10
  223060:	48 89 f3             	mov    rbx,rsi
  223063:	48 c1 e3 04          	shl    rbx,0x4
  223067:	4c 8d 34 5b          	lea    r14,[rbx+rbx*2]
  22306b:	4f 8b 44 30 10       	mov    r8,QWORD PTR [r8+r14*1+0x10]
  223070:	49 8d 59 01          	lea    rbx,[r9+0x1]
  223074:	48 89 5f 40          	mov    QWORD PTR [rdi+0x40],rbx
  223078:	4b 8d 1c 49          	lea    rbx,[r9+r9*2]
  22307c:	4c 89 14 d8          	mov    QWORD PTR [rax+rbx*8],r10
  223080:	4c 89 44 d8 08       	mov    QWORD PTR [rax+rbx*8+0x8],r8
  223085:	4c 89 5c d8 10       	mov    QWORD PTR [rax+rbx*8+0x10],r11
  22308a:	48 ff c6             	inc    rsi
  22308d:	48 8b 02             	mov    rax,QWORD PTR [rdx]
  223090:	42 0f b6 44 30 58    	movzx  eax,BYTE PTR [rax+r14*1+0x58]
  223096:	5b                   	pop    rbx
  223097:	41 5e                	pop    r14
  223099:	ff 24 c5 40 61 21 00 	jmp    QWORD PTR [rax*8+0x216140]

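where we see register saving at the top of the function: the callee-saved registers r14 and rbx are pushed on entry and popped again before each exit (the early ret and the final jmp). With rdi, rsi, rdx and rcx holding self, ip, code and err, loop needs more scratch registers than the remaining caller-saved ones provide.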

Now, I’m not sure if my choice of function parameters is optimal. I’ll need to do some experimentation to see whether there is more speedup to be gained by choosing them more carefully (this is quite hard, because even a simple change in parameters means changing a lot of code).

The other place where we can potentially get speedup is just limiting the amount of work done by particular instructions. For example, I think call is probably doing more work than it needs to.

That will hopefully be for the next article!