diff --git a/graph_tools/base.py b/graph_tools/base.py index 4954de52132f78a191305dd53024eb18623f2817..9693fbe24f356807674463f4f3e6d7cef087cb5e 100644 --- a/graph_tools/base.py +++ b/graph_tools/base.py @@ -136,13 +136,14 @@ def _init_(): return (p.finalize(ret), self.enumerate(ret)) @staticmethod - def join(gadgets, joins, outs): + def join(gadgets, joins, outs, *, no_cache=False): assert all( isinstance(g, Gadget) for g in gadgets ), \ "Not all gadgets are instance of Gadget: %s" % gadgets assert sorted(list(chain.from_iterable(joins)) + outs) == \ [ (i+1, j+1) for i in range(len(gadgets)) for j in range(gadgets[i].size()) ] info = (tuple(gadgets), tuple(joins), tuple(outs)) + if no_cache: return JoinGadget(*info) return Gadget._gadget_cache.get(info, lambda: JoinGadget(*info)) class FakeGadget(Gadget): @@ -191,7 +192,7 @@ def _init_(): def do_eval(): ret = parameter.eval_join( - [ g.eval_gadget(parameter, track_origins=track_origins) for g in gadgets ], + [ g.eval_gadget(parameter, track_origins=track_origins, no_cache=no_cache) for g in gadgets ], joins, outs, self.offsets, track_origins ) return ret diff --git a/graph_tools/misc.py b/graph_tools/misc.py index 42d6c3c74a56f3bf43c17ef1a48ff8122068051c..bc737e9ed45f91de72465cbb177a84a68aa67ea5 100644 --- a/graph_tools/misc.py +++ b/graph_tools/misc.py @@ -161,21 +161,22 @@ def _init_(): def edge_model_join(g1 : Gadget, g2 : Gadget): assert g1.size() == g2.size() joins = [ ((1, i+1), (2, i+1)) for i in range(g1.size()) ] - return Gadget.join([ g1, g2 ], joins, []) + return Gadget.join([ g1, g2 ], joins, [], no_cache=True) def parameter_matrix(P, gadgets): from sage.all import matrix, QQ - m = matrix(QQ, len(gadgets), len(gadgets)) + m = matrix(QQ, len(gadgets), len(gadgets), sparse=True) for r, gr in enumerate(gadgets): for c, gc in enumerate(gadgets): - m[r, c] = edge_model_join(gr, gc).eval(P) + m[r, c] = edge_model_join(gr, gc).eval(P, no_cache=True) return m + # FIXME k cemu je to dobry? def edge_model_join_matrix(P, k): BG = [ FakeGadget(k, [ BoundaryValue(b, 1) ]) for b in P.enumerate_boundaries(k) ] return parameter_matrix(P, BG)