# parmap.py
#
# Modified version of code written by klaus se, from
# http://stackoverflow.com/a/16071616

import multiprocessing
Radek Hušek's avatar
Radek Hušek committed
5
from itertools import chain
6

Radek Hušek's avatar
Radek Hušek committed
7
8
9
10
11
12
13
14
15
16
def chunkme(X, chunksize):
  """Lazily group the items of X into consecutive lists.

  A list is handed out once it already holds at least chunksize items and
  a further item arrives from X; that item then starts the next list.
  Whatever has been collected when X runs out is yielded last, provided
  it is non-empty.
  """
  batch = []
  for item in X:
    if len(batch) >= chunksize:
      # Current batch is full: emit it and let this item open a new one.
      yield batch
      batch = [item]
    else:
      batch.append(item)
  if batch:
    yield batch

17
def producer_fun(X, q_in, q_out, cont, nprocs, chunksize):
  """Queue the chunks of X for the workers, then signal the end of input.

  Args:
    X: iterable of work items; it is split into lists by chunkme().
    q_in: queue that receives one (chunk_index, chunk) tuple per chunk,
      followed by nprocs (None, None) stop markers.
    q_out: queue that receives a single (None, number_of_chunks) message
      after everything has been put on q_in.
    cont: semaphore acquired once before each chunk is queued, so this
      function blocks whenever no permit is available.
    nprocs: number of stop markers to put on q_in.
    chunksize: forwarded to chunkme().
  """
  sent = 0
  for chunk in chunkme(X, chunksize):
    # Wait for a free slot before queueing more work.
    cont.acquire()
    q_in.put((sent, chunk))
    sent += 1

  # One stop marker per worker.
  for _ in range(nprocs):
    q_in.put((None, None))

  # Report how many chunks were queued in total.
  q_out.put((None, sent))
28
29

def worker_fun(f, q_in, q_out):
    """Apply f to every item of each chunk read from q_in until told to stop.

    Args:
        f: callable applied to each item of a chunk.
        q_in: queue yielding (index, chunk) tuples; a tuple whose index is
            None ends the loop.
        q_out: queue that receives (index, results) for every processed
            chunk, where results is the list of f(x) for x in the chunk.
    """
    while True:
        i, chunk = q_in.get()
        if i is None:
            # Stop marker: leave the remaining queue content untouched.
            break
        q_out.put((i, [f(x) for x in chunk]))
35

36
def parmap(f, X, nprocs = None, chunksize = 1, chunks_in_flight = None):
    """Map f over X using worker processes, yielding results in input order.

    The worker and producer processes are created and started as soon as
    this function is called; the results themselves are produced lazily
    as the returned iterator is consumed.

    Args:
        f: callable applied to every item of X (inside worker_fun).
        X: iterable of inputs; it is consumed by producer_fun in a
            separate process.
        nprocs: number of worker processes; defaults to
            multiprocessing.cpu_count().
        chunksize: chunk size handed to producer_fun.
        chunks_in_flight: initial value of the semaphore that producer_fun
            acquires per chunk and that is released here per yielded
            chunk. Defaults to 10 + 3 * nprocs and is raised to at least
            nprocs + 1.

    Returns:
        An iterator over the individual results, chunk by chunk in
        increasing chunk index.

    NOTE(review): nothing here handles an exception raised by f or by
    iterating X inside a child process; the consumer would presumably keep
    waiting on the result queue in that case -- confirm.
    """
    if nprocs is None:
        nprocs = multiprocessing.cpu_count()

    if chunks_in_flight is None:
        chunks_in_flight = 10 + 3 * nprocs

    chunks_in_flight = max(chunks_in_flight, nprocs + 1)

    cont = multiprocessing.Semaphore(chunks_in_flight)
    q_in = multiprocessing.Queue()
    q_out = multiprocessing.Queue()

    proc = [
        multiprocessing.Process(target=worker_fun, args=(f, q_in, q_out))
        for _ in range(nprocs)
    ]
    proc.append(multiprocessing.Process(
        target=producer_fun,
        args=(X, q_in, q_out, cont, nprocs, chunksize),
    ))

    for p in proc:
        p.daemon = True
        p.start()

    def get_chunk():
        # Finished chunks that arrived ahead of their turn, keyed by index.
        pending = {}
        chunk_index = 0
        # Total number of chunks; unknown until the (None, count) message.
        jobs = None

        while jobs is None or chunk_index < jobs:
            i, val = q_out.get()
            if i is None:
                jobs = val
                continue

            pending[i] = val
            while chunk_index in pending:
                # pop() drops the chunk from the buffer once it is handed
                # out, so memory does not grow with the total output.
                val = pending.pop(chunk_index)
                chunk_index += 1
                cont.release()
                yield val

        for p in proc:
            p.join()

    return chain.from_iterable(get_chunk())
84