mht.wtf A blog about computer science, programming, and whatnot.

Code Generation and Merge Sort

April 24, 2019

I was reading a few pages of Knuth’s The Art of Computer Programming, Volume 4A about “branchless computation” (p. 180), in which he demonstrates how to get rid of branches by using conditional instructions. As an instructive example he considers the inner part of merge sort, in which we merge two sorted lists of numbers into one bigger sorted list. The description as given by Knuth is as follows:

If $x_i < y_j$ set $z_k \gets x_i$, $i \gets i+1$, and go to *x_done* if $i = i_{max}$.

Otherwise set $z_k \gets y_j$, $j \gets j+1$, and go to *y_done* if $j = j_{max}$.

Then set $k \gets k+1$ and go to *z_done* if $k = k_{max}$.

$x$ and $y$ are the input lists, and $z$ is the merged output list. $i$, $j$, and $k$ are loop indices for the three respective lists, and the $_{max}$ variants are the lists’ lengths.

I got curious and decided to see how a standard optimizing compiler would handle this case, and whether writing the assembly yourself would provide any gain in performance. After all, this is just slightly more complicated than the trivial examples used to show off good codegen, so it would not be unreasonable for the compiler to manage to fix a bad implementation of this. In addition, it would serve as a great excuse to finally learn how to write x86.

Basics

Here’s the inner loop in C code:

    void branching(uint64_t *xs, size_t xmax,
                   uint64_t *ys, size_t ymax,
                   uint64_t *zs, size_t zmax) {
        size_t i = 0, j = 0, k = 0;
        while (k < zmax) {
            if (xs[i] < ys[j]) {
                zs[k++] = xs[i++];
                if (i == xmax) {
                    // x_done
                    memcpy(zs + k, ys + j, 8 * (zmax - k));
                    return;
                }
            } else {
                zs[k++] = ys[j++];
                if (j == ymax) {
                    // y_done
                    memcpy(zs + k, xs + i, 8 * (zmax - k));
                    return;
                }
            }
        }
        // z_done
    }

This seems to be a more or less straightforward textbook implementation of the procedure, so it will do fine as a benchmark. As a quick check before going any deeper into this we can use godbolt.org to see whether this experiment is even worth doing. Godbolt’s x86-64 gcc 8.3 with -O3 spits out this (annotations are mine):

    branching(unsigned long*, unsigned long, unsigned long*, unsigned long, unsigned long*, unsigned long):
            test    r9, r9                          ; if (r9 == 0)
            je      .L15                            ;   goto .L15
            push    r13                             ;
            xor     eax, eax                        ;
            xor     r11d, r11d                      ; j = 0
            xor     r10d, r10d                      ; i = 0
            push    r12                             ;
            push    rbp                             ;
            push    rbx                             ;
            jmp     .L2                             ;
    .L17:
            add     r10, 1                          ; i++
            mov     QWORD PTR [r8-8+rax*8], rbp     ; zs[k-1] = xi
            cmp     r10, rsi                        ; if (i == xmax)
            je      .L16                            ;   goto .L16
    .L6:
            cmp     r9, rax                         ; if (k == zmax)
            je      .L1                             ;   goto .L1
    .L2:
            lea     r12, [rdi+r10*8]                ; calculate xs + i
            lea     r13, [rdx+r11*8]                ; calculate ys + j
            add     rax, 1                          ; k++
            mov     rbp, QWORD PTR [r12]            ; xi = xs[i]
            mov     rbx, QWORD PTR [r13+0]          ; yj = ys[j]
            cmp     rbp, rbx                        ; if (xi < yj)
            jb      .L17                            ;   goto .L17
            add     r11, 1                          ; j++
            mov     QWORD PTR [r8-8+rax*8], rbx     ; zs[k-1] = yj
            cmp     r11, rcx                        ; if (j != ymax)
            jne     .L6                             ;   goto .L6
            sub     r9, rax                         ; y_done
            pop     rbx                             ;
            mov     rsi, r12                        ;
            pop     rbp                             ;
            lea     rdi, [r8+rax*8]                 ;
            pop     r12                             ;
            lea     rdx, [0+r9*8]                   ;
            pop     r13                             ;
            jmp     memcpy                          ;
    .L1:
            pop     rbx                             ; z_done
            pop     rbp                             ;
            pop     r12                             ;
            pop     r13                             ;
            ret                                     ;
    .L16:
            sub     r9, rax                         ; x_done
            pop     rbx                             ;
            mov     rsi, r13                        ;
            pop     rbp                             ;
            lea     rdi, [r8+rax*8]                 ;
            pop     r12                             ;
            lea     rdx, [0+r9*8]                   ;
            pop     r13                             ;
            jmp     memcpy                          ;
    .L15:
            ret

Plenty of branches

Now, maybe it turns out that it doesn’t matter if we’re branching or not and that the compiler knows best. We could guess that the reason we’re still getting branches is that branching really is the best way to go here. After all, “you can’t beat the compiler” seems to be the consensus in many programming circles. Let’s try to write a version in C without excessive use of branching. Then perhaps the compiler will generate different code, and we can see what that difference amounts to in terms of running time. We can adapt Knuth’s branchless version:

    void nonbranching_but_branching(uint64_t *xs, size_t xmax,
                                    uint64_t *ys, size_t ymax,
                                    uint64_t *zs, size_t zmax) {
        size_t i = 0, j = 0, k = 0;
        uint64_t xi = xs[i], yj = ys[j];
        while ((i < xmax) && (j < ymax) && (k < zmax)) {
            int64_t t = one_if_lt(xi - yj);
            yj = min(xi, yj);
            zs[k] = yj;
            i += t;
            xi = xs[i];
            t ^= 1;
            j += t;
            yj = ys[j];
            k += 1;
        }
        if (i == xmax) memcpy(zs + k, ys + j, 8 * (zmax - k));
        if (j == ymax) memcpy(zs + k, xs + i, 8 * (zmax - k));
    }

What is going on, you might ask? The general idea is to first compute min(xi, yj), and then have a number t that is 1 if xi < yj and 0 otherwise. We can add t to i, since t = 1 exactly when we just wrote xi to zs[k]. Then we XOR it with 1, flipping 1 to 0 and 0 to 1, and add the result to j; this causes either i or j to be incremented, but not both. We used two convenience functions here, one_if_lt and min, both implemented straightforwardly with branching, hoping that the compiler will figure this out for us now that the branches are much smaller.
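The post does not show the two helpers, so here is a minimal sketch of what they might look like. Both bodies are my guess rather than the original code: `one_if_lt` is assumed to take the wrapped difference `xi - yj` and test its sign, and `min_u64` is a hypothetical stand-in for `min`.

```c
#include <stdint.h>

/* Hypothetical stand-in for one_if_lt: takes the wrapped difference
 * xi - yj and returns 1 if it is negative when read as a signed
 * 64-bit value, 0 otherwise. The unsigned-to-signed conversion is
 * implementation-defined in C; gcc and clang wrap modulo 2^64. */
static int64_t one_if_lt(uint64_t diff) {
    if ((int64_t)diff < 0) {
        return 1;
    }
    return 0;
}

/* Hypothetical stand-in for min: the smaller of two unsigned values. */
static uint64_t min_u64(uint64_t a, uint64_t b) {
    if (a < b) {
        return a;
    }
    return b;
}
```

Note that a sign test on the difference only agrees with `xi < yj` while both inputs stay below 2^63, which is the same assumption the later versions make explicit.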

Next, if we cheat a little and assume that the highest bit of the numbers is never set, we can get rid of those branches:

    void nonbranching(uint64_t *xs, size_t xmax,
                      uint64_t *ys, size_t ymax,
                      uint64_t *zs, size_t zmax) {
        size_t i = 0, j = 0, k = 0;
        uint64_t xi = xs[i], yj = ys[j];
        while ((i < xmax) && (j < ymax) && (k < zmax)) {
            uint64_t neg = (xi - yj) >> 63;
            yj = neg * xi + (1 - neg) * yj;
            zs[k] = yj;
            i += neg;
            xi = xs[i];
            neg ^= 1;
            j += neg;
            yj = ys[j];
            k += 1;
        }
        if (i == xmax) memcpy(zs + k, ys + j, 8 * (zmax - k));
        if (j == ymax) memcpy(zs + k, xs + i, 8 * (zmax - k));
    }

What is up with (xi - yj) >> 63, you may ask? If xi < yj the true difference is negative, so the unsigned subtraction wraps around and the most significant bit of the result is set (this is where we rely on the inputs never having their highest bit set). Then we shift right logically (since we’re using unsigned integers), so the bits that are filled in are all zeroes. Since the width is 64, we effectively move the upper bit to the lowest position while setting all other bits to zero.
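To convince myself of the shift trick, it helps to pull it out into two tiny functions. This is only a sketch; the names `lt_bit` and `select_min` are made up for illustration and do not appear in the benchmark code.

```c
#include <stdint.h>

/* 1 if a < b, else 0. Only valid while both inputs are below 2^63. */
static uint64_t lt_bit(uint64_t a, uint64_t b) {
    return (a - b) >> 63;
}

/* Branch-free select: yields a when neg == 1 and b when neg == 0,
 * which is min(a, b) under the same assumption on the inputs. */
static uint64_t select_min(uint64_t a, uint64_t b) {
    uint64_t neg = lt_bit(a, b);
    return neg * a + (1 - neg) * b;
}
```

The cheat is easy to break once the top bit is in play: with a = 1 and b = 2^63 + 2 the difference wraps to 2^63 - 1, whose top bit is clear, so `lt_bit` reports 0 even though a < b.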

Knuth has another quirk, namely that his array pointers usually point to the end of the array, and his indices are negative, going from -xmax up to 0 instead of the more standard 0 up to xmax. One consequence of this is that the termination check can be done with one comparison instead of three, by AND-ing together the three indices: since they are negative they have their most significant bit set, unless they are zero. Here are both of the previous versions with this reversal trick:

    void nonbranching_but_branching_reverse(uint64_t *xs, size_t xmax,
                                            uint64_t *ys, size_t ymax,
                                            uint64_t *zs, size_t zmax) {
        uint64_t *xse = xs + xmax;
        uint64_t *yse = ys + ymax;
        uint64_t *zse = zs + zmax;
        ssize_t i = -((ssize_t) xmax);
        ssize_t j = -((ssize_t) ymax);
        ssize_t k = -((ssize_t) zmax);
        uint64_t xi = xse[i], yj = yse[j];
        while (i & j & k) {
            uint64_t t = one_if_lt(xi - yj);
            yj = min(xi, yj);
            zse[k] = yj;
            i += t;
            xi = xse[i];
            t ^= 1;
            j += t;
            yj = yse[j];
            k += 1;
        }
        if (i == 0) memcpy(zse + k, yse + j, -8 * k);
        if (j == 0) memcpy(zse + k, xse + i, -8 * k);
    }

    void nonbranching_reverse(uint64_t *xs, size_t xmax,
                              uint64_t *ys, size_t ymax,
                              uint64_t *zs, size_t zmax) {
        uint64_t *xse = xs + xmax;
        uint64_t *yse = ys + ymax;
        uint64_t *zse = zs + zmax;
        ssize_t i = -((ssize_t) xmax);
        ssize_t j = -((ssize_t) ymax);
        ssize_t k = -((ssize_t) zmax);
        uint64_t xi = xse[i], yj = yse[j];
        while (i & j & k) {
            uint64_t neg = (xi - yj) >> 63;
            yj = neg * xi + (1 - neg) * yj;
            zse[k] = yj;
            i += neg;
            xi = xse[i];
            neg ^= 1;
            j += neg;
            yj = yse[j];
            k += 1;
        }
        if (i == 0) memcpy(zse + k, yse + j, -8 * k);
        if (j == 0) memcpy(zse + k, xse + i, -8 * k);
    }

Technically, I suppose we do assume that the lengths of the arrays are less than 2**63, so that they fit in an ssize_t, but considering that the address space of x86-64 is not 64 bits but merely 48 bits, this is not a problem, even in theory.
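The sign-bit argument behind the `i & j & k` test, and the end-pointer indexing itself, can be checked in isolation. The sketch below uses `int64_t` instead of `ssize_t` to stay within standard C headers, and both function names are hypothetical helpers, not part of the benchmark.

```c
#include <stddef.h>
#include <stdint.h>

/* Returns 1 while all three indices are negative. Every negative
 * value has its sign bit set, so the AND is nonzero; as soon as one
 * index reaches 0 the AND collapses to 0. This assumes the indices
 * never become positive. */
static int all_negative(int64_t i, int64_t j, int64_t k) {
    return (i & j & k) != 0;
}

/* Sums an array by indexing backwards from its end pointer, with the
 * index running from -n up to 0 as in the reversed merge routines. */
static uint64_t sum_from_end(const uint64_t *xs, size_t n) {
    const uint64_t *end = xs + n;
    uint64_t total = 0;
    for (int64_t i = -(int64_t)n; i != 0; i++) {
        total += end[i];
    }
    return total;
}
```

The loop condition in `sum_from_end` is a plain comparison against zero, which is what makes the single combined test possible once three such indices are involved.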

Writing the ASM ourselves

Lastly, we can try to write the assembly ourselves. When translating the branch-free routine by Knuth into x86 there are a number of things to do. First we need to figure out how to get -1/0/+1 by comparing two variables, as MMIX’s CMP instruction does. However, instead of trying to translate this line by line, which would end up with us having more instructions than needed, we should rather look more closely at what we’re doing, so that we really understand the minimal amount of work that we have to do.

We only need to do two things: compare $x_i$ and $y_j$ and load the smaller into a register, and increment either i or j. The former can be done using cmovl, and the latter can be done in a similar fashion to how Knuth does it, which is basically what we’ve been doing up to this point in C. This is the version I ended up with (here in GCC inline asm format):

    1:  mov     %[minxy], %[yj]                      ;
        cmp     %[xi], %[yj]                         ; minxy = min(xi, yj)
        cmovl   %[minxy], %[xi]                      ;
        mov     QWORD PTR [%[zse]+8*%[k]], %[minxy]  ; zs[k] = minxy
        mov     %[t], 0                              ; t = 0
        cmovl   %[t], %[one]                         ; if xi < yj: t = 1
        add     %[i], %[t]                           ; i += t
        mov     %[xi], QWORD PTR [%[xse]+8*%[i]]     ; xi = xs[i]
        xor     %[t], 1                              ; t ^= 1
        add     %[j], %[t]                           ; j += t
        mov     %[yj], QWORD PTR [%[yse]+8*%[j]]     ; yj = ys[j]
        add     %[k], 1                              ; k += 1
        mov     %[u], %[i]                           ;
        and     %[u], %[j]                           ;
        test    %[u], %[k]                           ; if ((i & j & k) != 0)
        jnz     1b                                   ;   goto 1

There are a few quirks here, like having a couple of mov instructions in between the second conditional move and the cmp instruction it depends on, and the fact that cmovl cannot take an immediate value, so I had to set up a register holding only the value 1. A sneaky detail to keep in mind is that when we set t = 0 we cannot use the trick of XOR-ing t with itself, since this would change the flags, causing the subsequent cmovl to be wrong.

Now we can take a look at the assembly generated from some of the other functions by using objdump -d. Our own programs are compiled with -O3 -march=native. Here is the inner loop in nonbranching_reverse:

    <nonbranching_reverse>:
        1ef0:   mov    rax,rdi
        1ef3:   sub    rax,rsi
        1ef6:   shr    rax,0x3f
        1efa:   mov    rdx,r8
        1efd:   sub    rdx,rax
        1f00:   imul   rdx,rsi
        1f04:   imul   rdi,rax
        1f08:   add    rbp,rax
        1f0b:   xor    rax,0x1
        1f0f:   add    rdi,rdx
        1f12:   mov    QWORD PTR [r13+r12*8+0x0],rdi
        1f17:   add    rcx,rax
        1f1a:   inc    r12
        1f1d:   mov    rax,rbp
        1f20:   and    rax,r12
        1f23:   mov    rdi,QWORD PTR [rbx+rbp*8]
        1f27:   mov    rsi,QWORD PTR [r10+rcx*8]
        1f2b:   test   rax,rcx
        1f2e:   jne    1ef0 <nonbranching_reverse+0x40>

Sure looks a lot better than branching! This seems more or less reasonable, but we can see that the multiplication trickery that we used to avoid the min branch takes up some space here; presumably it also takes some time. Maybe one little branch isn’t too bad though, and perhaps the compiler is more willing to use conditional instructions if we use the ternary operator, like this:

    void nonbranching_reverse_ternary(uint64_t *xs, size_t xmax,
                                      uint64_t *ys, size_t ymax,
                                      uint64_t *zs, size_t zmax) {
        uint64_t *xse = xs + xmax;
        uint64_t *yse = ys + ymax;
        uint64_t *zse = zs + zmax;
        ssize_t i = -((ssize_t) xmax);
        ssize_t j = -((ssize_t) ymax);
        ssize_t k = -((ssize_t) zmax);
        uint64_t xi = xse[i], yj = yse[j];
        while (i & j & k) {
            uint64_t ybig = (xi - yj) >> 63;
            yj = ybig ? xi : yj;
            zse[k] = yj;
            i += ybig;
            xi = xse[i];
            ybig ^= 1;
            j += ybig;
            yj = yse[j];
            k += 1;
        }
        if (i == 0) memcpy(zse + k, yse + j, -8 * k);
        if (j == 0) memcpy(zse + k, xse + i, -8 * k);
    }

This time, if we look at the assembly, we can see that the compiler is finally getting it: cmove !

    2080:   mov    rax,yj                  ;
    2083:   sub    rax,xi                  ;
    2086:   shr    rax,0x3f                ; t = (yj - xi) >> 63
    208a:   cmove  yj,xi                   ; yj = t == 0 ? xi : yj
    208e:   add    j,rax                   ; j += t
    2091:   mov    QWORD PTR [zs+k*8],yj   ; z[k] = yj
    2096:   xor    rax,0x1                 ; t ^= 1
    209a:   inc    k                       ; k++
    209d:   add    i,rax                   ; i += t
    20a0:   mov    rax,k                   ;
    20a3:   and    rax,j                   ; t = k & j
    20a6:   mov    yj,QWORD PTR [ys+j*8]   ; yj = ys[j]
    20aa:   mov    xi,QWORD PTR [xs+i*8]   ; xi = xs[i]
    20ae:   test   rax,i                   ; if ((i & j & k) != 0)
    20b1:   jne    2080                    ;   goto .2080

So we see it’s really the same! Curiously, the compiler turned our code around to have t be 1 if xi was the bigger, whereas our ybig was 1 if yj was the bigger.

Results

And now for the results! We fill two arrays with random elements and run branching on them, such that we get the merged array back. This is used as the ground truth which all other variations are checked against, in case we have messed up. Then we use clock_gettime to measure the wall clock time that we spend, per method. The following is the running time in milliseconds where both lists are 2**25 elements long, averaged over 100 runs: 10 iterations per seed and 10 different seeds (srand(i) for each iteration).

These are the numbers I got on an Intel i7-7500U@2.7GHz (avg +/- var):

    branching:                           30.998 +/- 0.001
    nonbranching_but_branching:          27.330 +/- 0.002
    nonbranching:                        24.770 +/- 0.000
    nonbranching_but_branching_reverse:  19.387 +/- 0.000
    nonbranching_reverse:                20.015 +/- 0.000
    nonbranching_reverse_ternary:        19.038 +/- 0.000
    asm_nb_rev:                          18.987 +/- 0.001

I also ran the suite on another machine with an Intel i5-8250U@1.60GHz, in order to see if there would be any significant difference:

    branching:                           31.405 +/- 0.034
    nonbranching_but_branching:          27.646 +/- 0.097
    nonbranching:                        27.894 +/- 0.021
    nonbranching_but_branching_reverse:  22.760 +/- 0.040
    nonbranching_reverse:                21.284 +/- 0.050
    nonbranching_reverse_ternary:        19.299 +/- 0.002
    asm_nb_rev:                          19.793 +/- 0.009

Interestingly, on this CPU our assembly is slightly slower than the ternary version; I guess this is due to us using a cmovl where the compiler-generated version used the shifting trick.
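For reference, the measurement described above can be sketched roughly as follows. This is not the harness from the repo: the choice of `CLOCK_MONOTONIC`, the `merge_fn` typedef, the helper names, and the use of population variance for the "var" column are all assumptions on my part.

```c
#ifndef _POSIX_C_SOURCE
#define _POSIX_C_SOURCE 199309L /* exposes clock_gettime under strict -std */
#endif
#include <stddef.h>
#include <stdint.h>
#include <time.h>

/* Signature shared by all the merge variants in this post. */
typedef void (*merge_fn)(uint64_t *, size_t, uint64_t *, size_t,
                         uint64_t *, size_t);

/* Difference between two timestamps in milliseconds. */
static double elapsed_ms(struct timespec start, struct timespec end) {
    return (double)(end.tv_sec - start.tv_sec) * 1e3
         + (double)(end.tv_nsec - start.tv_nsec) / 1e6;
}

/* Wall clock time of a single merge call, in milliseconds. */
static double time_merge_ms(merge_fn merge,
                            uint64_t *xs, size_t xmax,
                            uint64_t *ys, size_t ymax,
                            uint64_t *zs) {
    struct timespec start, end;
    clock_gettime(CLOCK_MONOTONIC, &start);
    merge(xs, xmax, ys, ymax, zs, xmax + ymax);
    clock_gettime(CLOCK_MONOTONIC, &end);
    return elapsed_ms(start, end);
}

/* Mean and (population) variance of n timing samples. */
static void mean_and_variance(const double *samples, size_t n,
                              double *mean, double *var) {
    double sum = 0.0;
    for (size_t s = 0; s < n; s++) {
        sum += samples[s];
    }
    *mean = sum / (double)n;
    double squares = 0.0;
    for (size_t s = 0; s < n; s++) {
        double d = samples[s] - *mean;
        squares += d * d;
    }
    *var = squares / (double)n;
}
```

One would call `time_merge_ms` once per run, collect the 100 samples in an array, and pass them to `mean_and_variance` to get one row of the tables above.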

Bonus: Sorting

We can’t possibly have done all this merging without making a proper mergesort in the end! Luckily for us, the merge part is really the only difficult part of the routine:

    void merge_sort(uint64_t *xs, size_t n, uint64_t *buf) {
        if (n < 2) return;
        size_t h = n / 2;
        merge_sort(xs, h, buf);
        merge_sort(xs + h, n - h, buf + h);
        merge(xs, h, xs + h, n - h, buf, n);
        memcpy(xs, buf, 8 * n);
    }

Unfortunately we have to merge to a buffer and then memcpy it back. Perhaps this is fixable: we can make the sorting routine put the result either in xs or in buf, and by having the recursive calls report which one they used, we can merge into the other, assuming both recursive calls agree(!). That is, if the recursive calls say that the sorted subarrays are in xs, we merge into buf and tell our caller that our result is in buf. At the end, we just need to make sure that the final sorted numbers are in xs.

    void _sort_asm(uint64_t *xs, size_t n, uint64_t *buf, int *into_buf) {
        if (n < 2) {
            *into_buf = 0;
            return;
        }
        size_t h = n / 2;
        int res_in_buf;
        _sort_asm(xs, h, buf, &res_in_buf);             // WARNING: `res_in_buf` for the two calls need
        _sort_asm(xs + h, n - h, buf + h, &res_in_buf); // not be the same in the real world!
        *into_buf = res_in_buf ^ 1;
        if (res_in_buf)
            asm_nb_rev(buf, h, buf + h, n - h, xs, n);
        else
            asm_nb_rev(xs, h, xs + h, n - h, buf, n);
    }

    void sort_asm(uint64_t *xs, size_t n, uint64_t *buf) {
        int res_in_buf;
        _sort_asm(xs, n, buf, &res_in_buf);
        if (res_in_buf) {
            memcpy(xs, buf, 8 * n);
        }
    }

and similarly for the other variants. You might see the branch and wonder if we can remove it. I tried, by making an array {xs, buf} and indexing it with res_in_buf, but it caused a minor slowdown: maybe some branching is fine after all.
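The WARNING in the listing is worth taking seriously: whenever the two halves recurse to different depths (n = 3 is the smallest case), one half ends up in xs and the other in buf. One possible way to handle that, sketched here with a plain branching merge so the example is self-contained, is to copy the left half over to wherever the right half landed before merging. The names `merge_simple`, `sort_pingpong`, and `sort_checked` are made up for this sketch, and this is not the code that produced the timings below.

```c
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Plain branching merge, only here to keep the sketch self-contained. */
static void merge_simple(const uint64_t *xs, size_t xmax,
                         const uint64_t *ys, size_t ymax,
                         uint64_t *zs) {
    size_t i = 0, j = 0, k = 0;
    while (i < xmax && j < ymax) {
        if (xs[i] < ys[j]) zs[k++] = xs[i++];
        else zs[k++] = ys[j++];
    }
    /* At most one of these two copies moves any elements. */
    memcpy(zs + k, xs + i, sizeof(uint64_t) * (xmax - i));
    memcpy(zs + k + (xmax - i), ys + j, sizeof(uint64_t) * (ymax - j));
}

/* Sorts xs[0..n). Returns 1 if the sorted result is in buf, 0 if in xs. */
static int sort_pingpong(uint64_t *xs, size_t n, uint64_t *buf) {
    if (n < 2) return 0;
    size_t h = n / 2;
    int left_in_buf = sort_pingpong(xs, h, buf);
    int right_in_buf = sort_pingpong(xs + h, n - h, buf + h);
    if (left_in_buf != right_in_buf) {
        /* The halves disagree: move the left half next to the right. */
        if (right_in_buf) memcpy(buf, xs, sizeof(uint64_t) * h);
        else memcpy(xs, buf, sizeof(uint64_t) * h);
    }
    if (right_in_buf) {
        merge_simple(buf, h, buf + h, n - h, xs);
        return 0;
    }
    merge_simple(xs, h, xs + h, n - h, buf);
    return 1;
}

/* Entry point: makes sure the final result lives in xs. */
static void sort_checked(uint64_t *xs, size_t n, uint64_t *buf) {
    if (sort_pingpong(xs, n, buf)) {
        memcpy(xs, buf, sizeof(uint64_t) * n);
    }
}
```

The extra copy only happens when the halves disagree, which requires them to have different sizes, so it costs nothing when n is a power of two. I have not measured what it costs for other sizes.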

Here are the running times:

                                                  i7-7500U             i5-8250U
    sort_branching:                           369.479 +/- 0.047    393.762 +/- 0.082
    sort_nonbranching_but_branching:          324.337 +/- 0.014    337.120 +/- 0.099
    sort_nonbranching:                        325.658 +/- 0.028    352.802 +/- 0.120
    sort_nonbranching_but_branching_reverse:  279.237 +/- 0.164    287.799 +/- 0.154
    sort_nonbranching_reverse:                283.927 +/- 0.033    299.277 +/- 0.929
    sort_nonbranching_reverse_ternary:        270.668 +/- 0.009    278.644 +/- 1.677
    sort_asm_nb_rev:                          270.228 +/- 0.009    281.657 +/- 0.360

If you would like to run the suite yourself, the git repo is available here.

Thanks for reading.



This work is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License
