diff --git a/groupConnectivity.pyx b/groupConnectivity.pyx
index 007c5deaee2ddd6f2ae1bdb3e39ce1f5b4057fad..bb11b0d8036b118f07035cd6fc6d03227faadbb6 100644
--- a/groupConnectivity.pyx
+++ b/groupConnectivity.pyx
@@ -139,7 +139,7 @@ def subdivisionIterator(G, edges = None):
def generateSubdivisions(file, G, edges = None,
function = lambda H: H.graph6_string(),
- parallel = 0):
+ parallel = 0, chunksize = 1):
f = open(file, 'w+')
f.write("# Subdivisions of edges '%s' of graph '%s' (%s)\n" %
(edges, G.graph6_string(), G))
@@ -147,7 +147,8 @@ def generateSubdivisions(file, G, edges = None,
if parallel == 0:
iterator = map(function, subdivisionIterator(G, edges))
else:
- iterator = parmap(function, subdivisionIterator(G, edges), parallel)
+ iterator = parmap(function, subdivisionIterator(G, edges),
+ nprocs = parallel, chunksize = chunksize)
for line in iterator:
f.write(line)
diff --git a/parmap.py b/parmap.py
index 040e8ddce735f1e3a5982bb953dbe3a146b83505..dcb8521873d82301a1a98f8c04b5e3c6e14ae27c 100644
--- a/parmap.py
+++ b/parmap.py
@@ -2,10 +2,21 @@
# http://stackoverflow.com/a/16071616
import multiprocessing
+from itertools import chain
-def producer_fun(X, q_in, q_control, nprocs):
+def chunkme(X, chunksize):
+ chunk = []
+ for x in X:
+ if len(chunk) >= chunksize:
+ yield chunk
+ chunk = []
+ chunk.append(x)
+ if len(chunk):
+ yield chunk
+
+def producer_fun(X, q_in, q_control, nprocs, chunksize):
sent = 0
- for i, x in enumerate(X):
+ for i, x in enumerate(chunkme(X, chunksize)):
q_in.put((i, x))
sent += 1
@@ -16,12 +27,12 @@ def producer_fun(X, q_in, q_control, nprocs):
def worker_fun(f, q_in, q_out):
while True:
- i, x = q_in.get()
+ i, chunk = q_in.get()
if i is None:
break
- q_out.put((i, f(x)))
+ q_out.put((i, [ f(x) for x in chunk ]))
-def parmap(f, X, nprocs = None):
+def parmap(f, X, nprocs = None, chunksize = 1):
if nprocs is None:
nprocs = multiprocessing.cpu_count()
@@ -34,7 +45,7 @@ def parmap(f, X, nprocs = None):
proc.append(multiprocessing.Process(
target = producer_fun,
- args = (X, q_in, q_out, nprocs)
+ args = (X, q_in, q_out, nprocs, chunksize)
))
for p in proc:
@@ -57,5 +68,5 @@ def parmap(f, X, nprocs = None):
for p in proc:
p.join()
- return [ ret[i] for i in range(len(ret)) ]
+ return chain.from_iterable([ ret[i] for i in range(len(ret)) ])