diff --git a/.gitignore b/.gitignore index eab1f010c2563666f34ea210afc4f794765cc558..35e4cbac38d394eb8ab7cbe77fd41654467982da 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ .*.swp groupConnectivity.so +parmap.pyc + diff --git a/Makefile b/Makefile index c870548c1fbab8690af880a12cd17fb510d708af..572cac9d2733af7d59da56f26fb11da42ee70f88 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ default: groupConnectivity.so clean_obj -groupConnectivity.so: groupConnectivity.pyx group-connectivity.h setup.py compileTimeOptions.h +groupConnectivity.so: groupConnectivity.pyx group-connectivity.h setup.py compileTimeOptions.h parmap.py python setup.py build_ext cp build/lib*/groupConnectivity.so . diff --git a/groupConnectivity.pyx b/groupConnectivity.pyx index 2cb25061d003ddf3dc8a1f1a715b64719bdb3c2c..7e30885d219e57fd1b47b76ebe2d1095dc275bfd 100644 --- a/groupConnectivity.pyx +++ b/groupConnectivity.pyx @@ -2,6 +2,7 @@ from libcpp.vector cimport vector from libcpp cimport bool from libcpp.utility cimport pair from sage.graphs.graph import Graph +from parmap import parmap cdef extern from "group-connectivity.h" namespace "Ring": cdef cppclass Z4[T]: @@ -87,12 +88,21 @@ def subdivisionIterator(G, edges = None): for L in subdivisionIterator(H, edges[1:]): yield L -def generateSubdivisions(file, G, edges = None, function = lambda H: H.graph6_string()): + +def generateSubdivisions(file, G, edges = None, + function = lambda H: H.graph6_string(), + parallel = 1): f = open(file, 'w+') f.write("# Subdivisions of edges '%s' of graph '%s' (%s)\n" % (edges, G.graph6_string(), G)) - for H in subdivisionIterator(G, edges): - f.write(function(H)) + + if parallel == 1: + iterator = map(function, subdivisionIterator(G, edges)) + else: + iterator = parmap(function, subdivisionIterator(G, edges), parallel) + + for line in iterator: + f.write(line) f.write("\n") f.close() diff --git a/parmap.py b/parmap.py new file mode 100644 index 0000000000000000000000000000000000000000..411820dff38984cf63c79711177acf760e225b68 --- /dev/null +++ b/parmap.py @@ -0,0 +1,36 @@ +# modified version of code written by klaus se from +# http://stackoverflow.com/a/16071616 + +import multiprocessing + +def fun(f, q_in, q_out): + while True: + i,x = q_in.get() + if i is None: + break + q_out.put((i, f(x))) + +def parmap(f, X, nprocs = None): + if nprocs is None: + nprocs = multiprocessing.cpu_count() + + q_in = multiprocessing.Queue(100) + q_out = multiprocessing.Queue() + + proc = [multiprocessing.Process(target=fun,args=(f,q_in,q_out)) for _ in range(nprocs)] + for p in proc: + p.daemon = True + p.start() + + sent = [q_in.put((i,x)) for i,x in enumerate(X)] + [q_in.put((None,None)) for _ in range(nprocs)] + + ret = {} + for _ in range(len(sent)): + i, val = q_out.get() + ret[i] = val + + [ p.join() for p in proc ] + + return [ ret[i] for i in range(len(sent)) ] +