Skip to content
Snippets Groups Projects
Commit 2db208cd authored by Radek Hušek's avatar Radek Hušek
Browse files

improve memory consumption

parent f9d909e9
No related branches found
No related tags found
No related merge requests found
...@@ -136,13 +136,14 @@ def _init_(): ...@@ -136,13 +136,14 @@ def _init_():
return (p.finalize(ret), self.enumerate(ret)) return (p.finalize(ret), self.enumerate(ret))
@staticmethod @staticmethod
def join(gadgets, joins, outs): def join(gadgets, joins, outs, *, no_cache=False):
assert all( isinstance(g, Gadget) for g in gadgets ), \ assert all( isinstance(g, Gadget) for g in gadgets ), \
"Not all gadgets are instance of Gadget: %s" % gadgets "Not all gadgets are instance of Gadget: %s" % gadgets
assert sorted(list(chain.from_iterable(joins)) + outs) == \ assert sorted(list(chain.from_iterable(joins)) + outs) == \
[ (i+1, j+1) for i in range(len(gadgets)) for j in range(gadgets[i].size()) ] [ (i+1, j+1) for i in range(len(gadgets)) for j in range(gadgets[i].size()) ]
info = (tuple(gadgets), tuple(joins), tuple(outs)) info = (tuple(gadgets), tuple(joins), tuple(outs))
if no_cache: return JoinGadget(*info)
return Gadget._gadget_cache.get(info, lambda: JoinGadget(*info)) return Gadget._gadget_cache.get(info, lambda: JoinGadget(*info))
class FakeGadget(Gadget): class FakeGadget(Gadget):
...@@ -191,7 +192,7 @@ def _init_(): ...@@ -191,7 +192,7 @@ def _init_():
def do_eval(): def do_eval():
ret = parameter.eval_join( ret = parameter.eval_join(
[ g.eval_gadget(parameter, track_origins=track_origins) for g in gadgets ], [ g.eval_gadget(parameter, track_origins=track_origins, no_cache=no_cache) for g in gadgets ],
joins, outs, self.offsets, track_origins joins, outs, self.offsets, track_origins
) )
return ret return ret
......
...@@ -161,21 +161,22 @@ def _init_(): ...@@ -161,21 +161,22 @@ def _init_():
def edge_model_join(g1 : Gadget, g2 : Gadget): def edge_model_join(g1 : Gadget, g2 : Gadget):
assert g1.size() == g2.size() assert g1.size() == g2.size()
joins = [ ((1, i+1), (2, i+1)) for i in range(g1.size()) ] joins = [ ((1, i+1), (2, i+1)) for i in range(g1.size()) ]
return Gadget.join([ g1, g2 ], joins, []) return Gadget.join([ g1, g2 ], joins, [], no_cache=True)
def parameter_matrix(P, gadgets): def parameter_matrix(P, gadgets):
from sage.all import matrix, QQ from sage.all import matrix, QQ
m = matrix(QQ, len(gadgets), len(gadgets)) m = matrix(QQ, len(gadgets), len(gadgets), sparse=True)
for r, gr in enumerate(gadgets): for r, gr in enumerate(gadgets):
for c, gc in enumerate(gadgets): for c, gc in enumerate(gadgets):
m[r, c] = edge_model_join(gr, gc).eval(P) m[r, c] = edge_model_join(gr, gc).eval(P, no_cache=True)
return m return m
# FIXME k cemu je to dobry?
def edge_model_join_matrix(P, k): def edge_model_join_matrix(P, k):
BG = [ FakeGadget(k, [ BoundaryValue(b, 1) ]) for b in P.enumerate_boundaries(k) ] BG = [ FakeGadget(k, [ BoundaryValue(b, 1) ]) for b in P.enumerate_boundaries(k) ]
return parameter_matrix(P, BG) return parameter_matrix(P, BG)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment