diff --git a/groupConnectivity.pyx b/groupConnectivity.pyx
index 068d1da96ba8d198829d33b776155ff2d8344fd3..f405d70b38fecf256c36ff3a6bd3b533ad77fcf0 100644
--- a/groupConnectivity.pyx
+++ b/groupConnectivity.pyx
@@ -139,7 +139,7 @@ def subdivisionIterator(G, edges = None):
 
 def generateSubdivisions(file, G, edges = None,
                          function = lambda H: H.graph6_string(),
-                         parallel = 0, chunksize = 1, q_size = 3):
+                         parallel = 0, chunksize = 1, in_flight = None):
     f = open(file, 'w+')
     f.write("# Subdivisions of edges '%s' of graph '%s' (%s)\n" %
             (edges, G.graph6_string(), G))
@@ -148,7 +148,8 @@ def generateSubdivisions(file, G, edges = None,
         iterator = map(function, subdivisionIterator(G, edges))
     else:
         iterator = parmap(function, subdivisionIterator(G, edges),
-                          nprocs = parallel, chunksize = chunksize, queue_size = q_size)
+                          nprocs = parallel, chunksize = chunksize,
+                          chunks_in_flight = in_flight)
 
     for line in iterator:
         f.write(line)
diff --git a/parmap.py b/parmap.py
index cdb0f89f551045ea9031d5fd12bce54301cc0a8d..f3b36de7eaa4124e1f88140e74b1eb5534ba0953 100644
--- a/parmap.py
+++ b/parmap.py
@@ -14,16 +14,17 @@ def chunkme(X, chunksize):
     if len(chunk):
         yield chunk
 
-def producer_fun(X, q_in, q_control, nprocs, chunksize):
+def producer_fun(X, q_in, q_out, cont, nprocs, chunksize):
     sent = 0
     for i, x in enumerate(chunkme(X, chunksize)):
+        cont.acquire()
         q_in.put((i, x))
         sent += 1
 
     for _ in range(nprocs):
         q_in.put((None, None))
 
-    q_control.put((None, sent))
+    q_out.put((None, sent))
 
 def worker_fun(f, q_in, q_out):
     while True:
@@ -32,11 +33,17 @@ def worker_fun(f, q_in, q_out):
             break
         q_out.put((i, [ f(x) for x in chunk ]))
 
-def parmap(f, X, nprocs = None, chunksize = 1, queue_size = 3):
+def parmap(f, X, nprocs = None, chunksize = 1, chunks_in_flight = None):
     if nprocs is None:
         nprocs = multiprocessing.cpu_count()
 
-    q_in = multiprocessing.Queue(queue_size)
+    if chunks_in_flight is None:
+        chunks_in_flight = 10 + 3 * nprocs
+
+    chunks_in_flight = max(chunks_in_flight, nprocs + 1)
+
+    cont = multiprocessing.Semaphore(chunks_in_flight)
+    q_in = multiprocessing.Queue()
     q_out = multiprocessing.Queue()
 
     proc = [ multiprocessing.Process(
@@ -45,28 +52,33 @@ def parmap(f, X, nprocs = None, chunksize = 1, queue_size = 3):
 
     proc.append(multiprocessing.Process(
         target = producer_fun,
-        args = (X, q_in, q_out, nprocs, chunksize)
+        args = (X, q_in, q_out, cont, nprocs, chunksize)
     ))
 
     for p in proc:
         p.daemon = True
         p.start()
 
-    ret = {}
-    jobs = None
-    while True:
-        i, val = q_out.get()
-        if i is None:
-            jobs = val
-            break
-        ret[i] = val
+    def get_chunk():
+        ret = {}
+        chunk_index = 0
+        jobs = None
 
-    for _ in range(jobs - len(ret)):
-        i, val = q_out.get()
-        ret[i] = val
+        while jobs is None or chunk_index < jobs:
+            i, val = q_out.get()
+            if i is None:
+                jobs = val
+                continue
 
-    for p in proc:
-        p.join()
+            ret[i] = val
+            while chunk_index in ret:
+                val = ret[chunk_index]
+                chunk_index += 1
+                cont.release()
+                yield val
+
+        for p in proc:
+            p.join()
 
-    return chain.from_iterable([ ret[i] for i in range(len(ret)) ])
+    return chain.from_iterable(get_chunk())
 
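
Usage note (not part of the patch): a minimal sketch of how the reworked parmap call could be driven; slow_square, the input range, and the parameter values are illustrative only.

    # Illustrative only: consume parmap's lazy result iterator while the
    # internal semaphore keeps at most `chunks_in_flight` chunks queued or
    # buffered, so memory stays bounded even for very large inputs.
    from parmap import parmap

    def slow_square(x):
        return x * x

    if __name__ == '__main__':
        results = parmap(slow_square, range(100000),
                         nprocs = 4, chunksize = 100, chunks_in_flight = 8)
        # Chunks are yielded in input order as they become available; the
        # producer blocks once 8 chunks are outstanding until some are consumed.
        print(sum(results))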