This post originated from an RSS feed registered with Python Buzz
by Andrew Dalke.
Original Post: Faster parity calculation
Feed Title: Andrew Dalke's writings
In the previous essay I needed to determine the parity of a
permutation. I used a Shell sort and counted the number of swaps
needed to order the list. The parity is even (or "0") if the number of
swaps is even, otherwise it's odd (or "1"). The final code was:
```python
def parity_shell(values):
    # Simple Shell sort; while O(N^2), we only deal with at most 4 values
    values = list(values)
    N = len(values)
    num_swaps = 0
    for i in range(N-1):
        for j in range(i+1, N):
            if values[i] > values[j]:
                values[i], values[j] = values[j], values[i]
                num_swaps += 1
    return num_swaps % 2
```
I chose this implementation because it's easy to understand, and any
failure case is easily found. However, it's not fast.
It's tempting to use a better sort method. The Shell sort takes
quadratic time in the number of elements, while others take
O(N*ln(N)) time in the asymptotic case.
However, an asymptotic analysis is pointless here. The code will only
ever receive 3 terms (if there is a chiral hydrogen) or 4 terms,
because it is only ever called for tetrahedral chirality.
Sorting networks
The first time I worked on this problem, I used a sorting
network. A sorting network works on a fixed number of elements. It
uses a predetermined sequence of pairwise comparisons, each followed
by a swap if needed. Sorting networks are often used where code
branches are expensive, like in hardware or on a GPU. Because a
sorting network always performs the same comparisons, it takes
constant time, which can help mitigate timing side-channel attacks,
where the time to sort may give some insight into what is being
sorted.
No general algorithm is known for finding an optimal sorting network
for an arbitrary number of elements N. There are non-optimal
constructions like Bose-Nelson and Batcher's odd–even mergesort, and
optimal networks are known for up to N=10.
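Here is a sketch of one well-known five-comparator network for four elements. It uses the comparator pairs (0, 1), (2, 3), (0, 2), (1, 3) and (1, 2). It is written by hand for illustration and is not necessarily what any particular generator emits. The names compare_swap(), NETWORK4 and network_sort4() are hypothetical.

```python
def compare_swap(values, i, j):
    # One comparator: afterwards values[i] <= values[j]
    if values[i] > values[j]:
        values[i], values[j] = values[j], values[i]

# Fixed comparator sequence for N=4; the order never depends on the data
NETWORK4 = [(0, 1), (2, 3), (0, 2), (1, 3), (1, 2)]

def network_sort4(values):
    """Sort a 4-element list in place using a fixed sorting network."""
    for i, j in NETWORK4:
        compare_swap(values, i, j)
```

Every input goes through the same five comparisons. Only whether each swap happens depends on the data.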
John M. Gamble has a CGI script which will generate a sorting
network for a given number of elements and choice of
algorithm. For N=4 it generates:
As a test, I'll sort every permutation of four values and make sure
the result is sorted. I could write the test cases out manually, but
it's easier to use the "permutations()"
function from Python's itertools module, as in this example with 3
values:
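As an illustration, here are the values 1, 2 and 3 (chosen only for this sketch). permutations() yields all 3! = 6 orderings as tuples, in lexicographic order because the input is sorted:

```python
import itertools

orderings = list(itertools.permutations([1, 2, 3]))
# Each ordering is a tuple:
# [(1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1)]
print(orderings)
```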
Here's the test, which confirms that the function sorts correctly:
```python
>>> import itertools
>>> for permutation in itertools.permutations([0, 1, 2, 3]):
...     permutation = list(permutation)  # Convert the tuple to a list so sort4() can swap elements
...     sort4(permutation)
...     if permutation != [0, 1, 2, 3]:
...         print("ERROR:", permutation)
...
>>>
```
I think it's obvious how to turn this into a parity function by adding
a swap counter. If the input array cannot be modified, then the parity
function needs to make a copy of the array first. That's what
parity_shell() does.
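Here is a sketch of that change. It assumes a standard four-element network with the comparator pairs (0, 1), (2, 3), (0, 2), (1, 3), (1, 2), and the name network_parity4() is hypothetical. It copies the input so the caller's list is untouched. Each swap is a single transposition, so the parity of the swap count is the parity of the permutation.

```python
def network_parity4(values):
    """Parity (0 = even, 1 = odd) of 4 distinct values via a sorting network."""
    values = list(values)  # work on a copy so the caller's data is unchanged
    num_swaps = 0
    for i, j in [(0, 1), (2, 3), (0, 2), (1, 3), (1, 2)]:
        if values[i] > values[j]:
            values[i], values[j] = values[j], values[i]
            num_swaps += 1  # each swap is one transposition
    return num_swaps % 2
```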
No need to sort
A sorting network always does D comparisons, but not every comparison
is needed for every input. The reason is simple: if you think of the
network as a decision tree, where each comparison is a branch, then D
comparisons always produce 2^D leaves. This must be at least as large
as N!, where N is the number of elements in the list. But N! for N>2
is not a power of 2, so there will be some unused leaves.
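To see the slack concretely, consider the fewest comparisons with which a decision tree can tell all N! orderings apart. That is the smallest D with 2^D >= N!. This small check computes that bound and how many leaves are left over. The helper name min_comparisons() is only for illustration.

```python
import math

def min_comparisons(n):
    """Smallest D such that 2**D >= n!, i.e. ceil(log2(n!)), using exact integers."""
    return (math.factorial(n) - 1).bit_length()

for n in range(2, 7):
    d = min_comparisons(n)
    unused = 2**d - math.factorial(n)
    print("N={}: N!={}, D>={}, 2^D={}, unused leaves={}".format(
        n, math.factorial(n), d, 2**d, unused))
```

For N=4 the bound is 5 comparisons, and 32 - 24 = 8 leaves are unused. For N=2 the bound is tight because 2! = 2 is itself a power of 2.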
I would like to minimize the number of comparisons. I would also like
to avoid modifying the array in place by actually sorting it.
The key realization is that there's no need to sort in order to
determine the parity. For example, if there are only two elements in
the list, then the parity is as simple as testing
It's complicated enough that it took several attempts before it was
correct. I fixed it with the help of the following test code, which
uses parity_shell() as a reference because I'm confident that it gives
the correct values. (A useful development technique is to write
something that you know works, even if it's slow, so you can use it to
test more complicated code which better fits your needs.)
The test code is:
```python
import itertools

def test_three_element_parity():
    for x in itertools.permutations([1, 2, 3]):
        p1 = parity_shell(x)
        p2 = three_element_parity(x)
        if p1 != p2:
            print("MISMATCH", x, p1, p2)
        else:
            print("Match", x, p1, p2)
```
which gives the output:
```
>>> test_three_element_parity()
Match (1, 2, 3) 0 0
Match (1, 3, 2) 1 1
Match (2, 1, 3) 1 1
Match (2, 3, 1) 0 0
Match (3, 1, 2) 0 0
Match (3, 2, 1) 1 1
```
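Here is one hand-written way to get the parity of two or three distinct values with nested comparisons only. It illustrates the "no need to sort" idea and is not the three_element_parity() used above. The names parity_of_two() and parity_of_three() are hypothetical.

```python
def parity_of_two(data):
    # Two distinct values: already in order is even, otherwise one swap is odd
    return 0 if data[0] < data[1] else 1

def parity_of_three(data):
    """Parity of three distinct values using at most three comparisons."""
    a, b, c = data
    if a < b:
        if b < c:
            return 0  # a < b < c: already sorted, even
        elif a < c:
            return 1  # a < c < b: one swap away, odd
        else:
            return 0  # c < a < b: a 3-cycle, even
    else:
        if a < c:
            return 1  # b < a < c: one swap away, odd
        elif b < c:
            return 0  # b < c < a: a 3-cycle, even
        else:
            return 1  # c < b < a: reversed (swap the ends), odd
```

Nothing is copied or modified. Each leaf of the nested if-statements corresponds to exactly one ordering.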
A debugging technique
As I said, it took a couple of iterations to get the code correct.
Sometimes I wasn't sure which branch produced a 0 or 1, so during
development I added a second field to each return value to serve as a
tag. The code looked like:
Of course the "MISMATCH" label is now misleading, since a tagged
return value never equals the plain parity from parity_shell(), so I
need to compare things by eye. With so few elements that's fine. For
more complicated code I would modify the test code as well.
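One way to modify the test so that "MISMATCH" stays meaningful while the tags are in place is to compare only the parity field and print the tag alongside it. This is a sketch with hypothetical names. tagged_parity_of_two() stands in for a function under development, and check_tagged() takes a small table of expected parities.

```python
def tagged_parity_of_two(data):
    # During development, each return value also names the branch that produced it
    if data[0] < data[1]:
        return 0, "lt-branch"
    else:
        return 1, "gt-branch"

def check_tagged(func, expected_parities):
    """Compare only the parity field, but print the tag so each branch can be traced."""
    all_match = True
    for values, expected in expected_parities.items():
        parity, tag = func(values)
        status = "Match" if parity == expected else "MISMATCH"
        print(status, values, expected, parity, tag)
        all_match = all_match and parity == expected
    return all_match

check_tagged(tagged_parity_of_two, {(1, 2): 0, (2, 1): 1})
```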
Brute force solution
The last time I worked on this problem, I turned the sorting network
for N=4 into a decision tree. With 5 comparisons there are 2^5 = 32
terminal nodes, but only N! = 4! = 24 of them will be used. I pruned
them by hand, which is feasible with only 32 leaves.
This time I thought I would come up with some clever way to handle
this, so I pulled out Knuth's "The Art of Computer Programming" for a
pointer, since it has a lot about optimal sorting and sorting networks.
Oddly, "parity" wasn't in the index.
There's probably some interesting pattern I could use to figure out
which code paths to use, but N is small, so I decided to brute force
it.
I want to build a decision tree where each leaf contains only one
permutation. Each decision is made by choosing two indices to use for
the comparison test. I'll go through the permutations. If a
permutation's values at those indices are in sorted order, then I'll
put it into the "lt_permutations" list ("lt" is short for "less
than"); otherwise it goes into the "gt_permutations" list.
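As a small illustration of that split, here are the six permutations of three items partitioned on the index pair (0, 1). The helper name split_on_pair() is hypothetical; the program later in this post uses its own partition_permutations().

```python
import itertools

def split_on_pair(i, j, permutations):
    """Return (lt, gt): permutations with p[i] < p[j], and the rest."""
    lt_permutations = [p for p in permutations if p[i] < p[j]]
    gt_permutations = [p for p in permutations if p[i] > p[j]]
    return lt_permutations, gt_permutations

lt, gt = split_on_pair(0, 1, list(itertools.permutations(range(3))))
print("lt:", lt)  # lt: [(0, 1, 2), (0, 2, 1), (1, 2, 0)]
print("gt:", gt)  # gt: [(1, 0, 2), (2, 0, 1), (2, 1, 0)]
```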
For now, I'll assume the first pair of indices to swap is (0, 1):
Each partitioning corresponds to another level of if-statements, until
there is only one permutation in the branch. I want to use the above
information to make a decision tree which looks like:
```python
def parity3(data):
    if data[0] < data[1]:
        if data[1] < data[2]:
            return 0 # parity of (0, 1, 2)
        else:
            if data[0] < data[2]:
                return 1 # parity of (0, 2, 1)
            else:
                return 0 # parity of (1, 2, 0)
    ...
```
Partition scoring
In the previous section I partitioned using the successive pairs (0, 1),
(1, 2) and (0, 2). These are pretty obvious. What should I use for N=4
or higher? In truth, I could likely use the same swap pairs as the
sorting network, but I decided to continue with brute force.
For N items there are N*(N-1)/2 possible swap pairs:
```python
>>> n = 4
>>> swap_pairs = [(i, j) for i in range(n-1) for j in range(i+1, n)]
>>> swap_pairs
[(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
```
I decided to pick the pair which comes closest to partitioning the set
of permutations in half. For each pair, I partition the given
permutations and use the absolute value of the difference between the
sizes of the "less than" and "greater than" subsets as the score, so a
lower score is better.
```python
def get_partition_score(swap_pair, permutations):
    i, j = swap_pair
    num_lt = num_gt = 0
    for permutation in permutations:
        if permutation[i] < permutation[j]:
            num_lt += 1
        else:
            num_gt += 1
    return abs(num_lt - num_gt)
```
I'll create all permutations for 4 terms and score each of the pairs
on the results:
Obviously it does no good to use (0, 1) again, because those values
are all sorted. Most of the other pairs are also partially sorted, so
using them leads to an imbalanced 4-8 partitioning, but (2, 3) gives a
perfect partitioning. I'll use it for the next partitioning, and
again select the "less-than" subset and re-score:
Repeat this process until only one permutation is left, and use the
parity of that permutation as the return value.
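The select-and-repeat loop can be sketched as follows, following only the "less-than" branch each time. follow_lt_branch() is a hypothetical helper. get_partition_score() is restated in a shorter but equivalent form so the block runs on its own. Ties go to the first pair in order, because that is how Python's min() behaves.

```python
import itertools

def get_partition_score(swap_pair, permutations):
    # |(number with p[i] < p[j]) - (number with p[i] > p[j])|; 0 is a perfect split
    i, j = swap_pair
    num_lt = sum(1 for p in permutations if p[i] < p[j])
    return abs(num_lt - (len(permutations) - num_lt))

def follow_lt_branch(n):
    """Pick the best-scoring pair, keep the less-than subset, and repeat
    until one permutation is left. Returns (pairs chosen, final permutation)."""
    permutations = list(itertools.permutations(range(n)))
    swap_pairs = [(i, j) for i in range(n - 1) for j in range(i + 1, n)]
    chosen = []
    while len(permutations) > 1:
        best = min(swap_pairs, key=lambda pair: get_partition_score(pair, permutations))
        chosen.append(best)
        i, j = best
        permutations = [p for p in permutations if p[i] < p[j]]
    return chosen, permutations[0]

print(follow_lt_branch(4))  # ([(0, 1), (2, 3), (0, 2), (1, 2)], (0, 1, 2, 3))
```

For n=4 the first two choices are (0, 1) and then (2, 3), as reasoned above. These are followed by (0, 2) and (1, 2), ending at the sorted permutation with parity 0.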
Code generation
I'll combine the above code into a program which generates Python code
that computes the parity of a list with N distinct items. It uses
recursion. The main entry point is "generate_parity_function()", which
sets up the data for the recursive function "_generate_comparison()".
When there is more than one permutation, that function identifies the
best pair of indices to compare, then calls itself to process each
side. If there's only one permutation left, there's nothing more to
do but compute the parity of that permutation and use it as the return
value for that case.
```python
import itertools

def parity_shell(values):
    # Simple Shell sort; while O(N^2), we only deal with at most 4 values
    values = list(values)
    N = len(values)
    num_swaps = 0
    for i in range(N-1):
        for j in range(i+1, N):
            if values[i] > values[j]:
                values[i], values[j] = values[j], values[i]
                num_swaps += 1
    return num_swaps % 2

def get_partition_score(swap_pair, permutations):
    i, j = swap_pair
    num_lt = num_gt = 0
    for permutation in permutations:
        if permutation[i] < permutation[j]:
            num_lt += 1
        else:
            num_gt += 1
    return abs(num_lt - num_gt)

def partition_permutations(i, j, permutations):
    lt_permutations = []
    gt_permutations = []
    for permutation in permutations:
        if permutation[i] < permutation[j]:
            lt_permutations.append(permutation)
        else:
            gt_permutations.append(permutation)
    return lt_permutations, gt_permutations

def generate_parity_function(n):
    print("def parity{}(data):".format(n))
    permutations = list(itertools.permutations(range(n)))
    swap_pairs = [(i, j) for i in range(n-1) for j in range(i+1, n)]
    _generate_comparison(permutations, swap_pairs, "    ")

def _generate_comparison(permutations, swap_pairs, indent):
    if len(permutations) == 1:
        parity = parity_shell(permutations[0])
        print(indent + "return {} # {}".format(parity, permutations[0]))
        return

    swap_pair = min(swap_pairs, key=lambda x: get_partition_score(x, permutations))
    # Delete the swap pair because it can't be used again.
    # (Not strictly needed as it will always have the worst score.)
    del swap_pairs[swap_pairs.index(swap_pair)]

    # I could have a case where the lt subset has 0 elements while the
    # gt subset has 1 element. Rather than have the 'if' block do nothing,
    # I'll swap the comparison indices and swap branches.
    i, j = swap_pair
    lt_permutations, gt_permutations = partition_permutations(i, j, permutations)
    if not lt_permutations:
        lt_permutations, gt_permutations = gt_permutations, lt_permutations
        i, j = j, i

    print(indent + "if data[{i}] < data[{j}]:".format(i=i, j=j))
    # Need to copy the swap_pairs because the 'else' case may reuse a pair.
    _generate_comparison(lt_permutations, swap_pairs[:], indent + "    ")
    if gt_permutations:
        print(indent + "else:")
        _generate_comparison(gt_permutations, swap_pairs, indent + "    ")

if __name__ == "__main__":
    import sys
    n = 4
    if sys.argv[1:]:
        n = int(sys.argv[1])
    generate_parity_function(n)
```
The output for n=2 elements is the expected trivial case:
The test code is essentially the same as
"test_three_element_parity()", so I won't include it here.
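For anyone who wants a reusable version anyway, here is a hypothetical generalization of test_three_element_parity(). It checks any candidate function against a brute-force reference for every permutation of n items. The reference counts inversions instead of calling parity_shell(). Both give the same answer, because each swap in parity_shell() is one transposition. The names reference_parity() and check_parity_function() are made up for this sketch. The parity2() below is hand-written to mirror the trivial n=2 case.

```python
import itertools

def reference_parity(values):
    # Slow but trustworthy reference: parity of the inversion count
    n = len(values)
    return sum(1 for i in range(n) for j in range(i + 1, n) if values[i] > values[j]) % 2

def check_parity_function(func, n):
    """Check func against the reference for every permutation of n items.
    Returns the list of mismatching permutations (empty when all match)."""
    mismatches = []
    for permutation in itertools.permutations(range(n)):
        expected = reference_parity(permutation)
        got = func(list(permutation))
        if got != expected:
            print("MISMATCH", permutation, expected, got)
            mismatches.append(permutation)
    return mismatches

# Hand-written two-element case, in the same shape as a generated function
def parity2(data):
    if data[0] < data[1]:
        return 0  # (0, 1)
    else:
        return 1  # (1, 0)

print(check_parity_function(parity2, 2))  # []
```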
Evaluation
I don't think it makes much sense to use this approach beyond n=5,
because it generates so much code. Here's a table of the number of
lines of code it generates for different values of n:
This appears to be roughly factorial growth, which is expected, since
the decision tree needs one leaf per permutation. For my case n=4, so
71 lines is not a problem.
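The 71-line figure can be checked by counting the lines the generator prints:

- Each of the N! leaves prints one "return" line.
- Each of the N! - 1 internal comparisons prints an "if" line and an "else:" line. Both branches are non-empty, because the best-scoring pair always splits a set of two or more permutations.
- One more line is the "def" line.

That gives 3*N! - 1 lines. The sketch below walks the same decision tree without printing and agrees with that formula. count_generated_lines() is a hypothetical helper that mirrors the generator's pair-selection rule.

```python
import itertools
import math

def count_generated_lines(n):
    """Count the lines the generator would print for n items, by walking
    the same decision tree without printing anything."""
    def score(pair, perms):
        i, j = pair
        num_lt = sum(1 for p in perms if p[i] < p[j])
        return abs(num_lt - (len(perms) - num_lt))

    def walk(perms, pairs):
        if len(perms) == 1:
            return 1  # a single 'return' line for this leaf
        i, j = min(pairs, key=lambda pair: score(pair, perms))
        lt = [p for p in perms if p[i] < p[j]]
        gt = [p for p in perms if p[i] > p[j]]
        if not lt:  # mirror the generator's branch swap (not expected to trigger)
            lt, gt = gt, lt
        lines = 1 + walk(lt, pairs)  # the 'if' line plus its body
        if gt:
            lines += 1 + walk(gt, pairs)  # the 'else:' line plus its body
        return lines

    perms = list(itertools.permutations(range(n)))
    pairs = [(i, j) for i in range(n - 1) for j in range(i + 1, n)]
    return 1 + walk(perms, pairs)  # +1 for the 'def' line

for n in range(2, 7):
    print(n, count_generated_lines(n), 3 * math.factorial(n) - 1)
```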
I wrote some timing code which makes 100,000 random selections from
the possible permutations and compares the performance of the
parityN() function with parity_shell(). To put them on a more even
footing, I changed the parity_shell() implementation so that it
mutates the input values rather than making a temporary list. The
timing code for parity5() looks like:
```python
import itertools

# The generated parity5() function must also be defined in this file.

def parity_shell(values):
    # Simple Shell sort; while O(N^2), we only deal with at most 4 values
    #values = list(values)  # disabled: sort the input in place for timing
    N = len(values)
    num_swaps = 0
    for i in range(N-1):
        for j in range(i+1, N):
            if values[i] > values[j]:
                values[i], values[j] = values[j], values[i]
                num_swaps += 1
    return num_swaps % 2

if __name__ == "__main__":
    import random
    import time
    permutations = list(itertools.permutations(range(5)))
    perms = [list(random.choice(permutations)) for i in range(100000)]
    t1 = time.time()
    p1 = [parity5(perm) for perm in perms]
    t2 = time.time()
    # parity_shell() sorts each list in place, so it must run after parity5()
    p2 = [parity_shell(perm) for perm in perms]
    t3 = time.time()
    if p1 != p2:
        print("oops")
    print("parity5:", t2-t1, "parity_shell:", t3-t2)
    print("ratio:", (t2-t1)/(t3-t2))
```
The decision tree version is consistently 5-6x faster than the Shell sort version across all the sizes I tested.
A performance improvement
By the way, I was able to raise the speedup to 9x by unpacking the
input into local variables instead of indexing into the list for every
comparison. Here's the start of parity4() with that change:
```python
def parity4(data):
    data0, data1, data2, data3 = data
    if data0 < data1:
        if data2 < data3:
            if data0 < data2:
                if data1 < data2:
                    return 0 # (0, 1, 2, 3)
                else:
                    if data1 < data3:
                        return 1 # (0, 2, 1, 3)
                    else:
                        return 0 # (0, 3, 1, 2)
    # … additional code omitted …
```
It's easy to change the generator so it produces this version instead,
or you can use a bit of text replacement and hand-editing to convert
the output of the code I gave earlier.
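The text-replacement route can be sketched like this. The sketch assumes the generated body is indented with four spaces. localize_indexing() and PARITY2_SOURCE are hypothetical names, and the source string is written by hand in the generator's output style.

```python
import re

def localize_indexing(source, n):
    """Rewrite 'data[k]' as 'datak' and unpack the input right after the def line.
    Assumes the generated body is indented with four spaces."""
    lines = source.splitlines()
    unpack = "    " + ", ".join("data{}".format(k) for k in range(n)) + " = data"
    body = [re.sub(r"data\[(\d+)\]", r"data\1", line) for line in lines[1:]]
    return "\n".join([lines[0], unpack] + body) + "\n"

# Generator-style output for n=2, written by hand for this example
PARITY2_SOURCE = """def parity2(data):
    if data[0] < data[1]:
        return 0 # (0, 1)
    else:
        return 1 # (1, 0)
"""

new_source = localize_indexing(PARITY2_SOURCE, 2)
print(new_source)
namespace = {}
exec(new_source, namespace)  # define the rewritten parity2()
print(namespace["parity2"]([3, 5]), namespace["parity2"]([5, 3]))  # 0 1
```

One behavioral difference is that the unpacking line requires exactly n values. The indexed version would silently ignore extra trailing elements.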