diff --git a/graph_tools/misc.py b/graph_tools/misc.py index 11c5fe4cbafb447deb75ef1d67ea7b81952f8f4d..f8657b786bc34d57fc54353f99c9d9d23fe941be 100644 --- a/graph_tools/misc.py +++ b/graph_tools/misc.py @@ -164,36 +164,33 @@ def _init_(): return Gadget.join([ g1, g2 ], joins, [], no_cache=True) - def parameter_matrix(P, gadgets): - from sage.all import matrix, QQ - - m = matrix(QQ, len(gadgets), len(gadgets), sparse=True) - - for r, gr in enumerate(gadgets): - for c, gc in enumerate(gadgets): - m[r, c] = edge_model_join(gr, gc).eval(P, no_cache=True) - - return m - - - def edge_model_join_matrix(P, k, threads=1): + def parameter_matrix(P, gadgets, threads=1): from sage.all import matrix, QQ from parmap import parmap - BG = [ FakeGadget(k, [ BoundaryValue(b, 1) ]) for b in P.enumerate_boundaries(k) ] - inp = ( (r, c) for r in range(len(BG)) for c in range(len(BG)) ) + N = len(gadgets) + inp = ( (r, c) for r in range(N) for c in range(N) ) def worker(x): r, c = x - return (r, c, edge_model_join(BG[r], BG[c]).eval(P, no_cache=True)) + return (r, c, edge_model_join(gadgets[r], gadgets[c]).eval(P, no_cache=True)) - m = matrix(QQ, len(BG), len(BG), sparse=True) - for r, c, v in parmap(worker, inp, nprocs=threads, in_order=False): + m = matrix(QQ, N, N, sparse=True) + for r, c, v in parmap(worker, inp, nprocs=threads, in_order=False, + chunksize=100, out_chunksize=100): m[r, c] = v return m + def edge_model_join_matrix(P, k, threads=1, boundaries=False): + B = list(P.enumerate_boundaries(k)) + BG = [ FakeGadget(k, [ BoundaryValue(b, 1) ]) for b in B ] + M = parameter_matrix(P, BG, threads=threads) + if boundaries: return (B, M) + return M + + def enumerate_diamond_matchings(k): assert k % 2 == 0 and k >= 2