From 5b3b85630f2d80ceda12fcb8c273b50e2d0f0039 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Radek=20Hu=C5=A1ek?= <husek@iuuk.mff.cuni.cz>
Date: Tue, 6 Apr 2021 15:55:05 +0200
Subject: [PATCH] add option to disable caching

---
 graph_tools/base.py | 29 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/graph_tools/base.py b/graph_tools/base.py
index 6669124..e5f3a15 100644
--- a/graph_tools/base.py
+++ b/graph_tools/base.py
@@ -129,9 +129,9 @@ def _init_():
 
     _gadget_cache = utils.DynamicLRU()
 
-    def eval(self, p, track_origins=False):
+    def eval(self, p, track_origins=False, **kwargs):
       assert self.is_graph()
-      ret = self.eval_gadget(p, track_origins=track_origins)
+      ret = self.eval_gadget(p, track_origins=track_origins, **kwargs)
       if not track_origins: return p.finalize(ret)
       return (p.finalize(ret), self.enumerate(ret))
 
@@ -150,12 +150,12 @@ def _init_():
       super().__init__(size)
       self.value = value
 
-    def eval_gadget(self, _, track_origins=None):
+    def eval_gadget(self, _, track_origins=None, no_cache=False):
       return self.value
 
 
   class BaseGadget(Gadget):
-    def eval_gadget(self, parameter, track_origins=None):
+    def eval_gadget(self, parameter, track_origins=None, no_cache=False):
       try:
         if self is CUBIC_VERTEX:
           return parameter.CUBIC_VERTEX
@@ -186,7 +186,7 @@ def _init_():
 
       self._info = (tuple(gadgets), joins, outs)
 
-    def eval_gadget(self, parameter, track_origins=False):
+    def eval_gadget(self, parameter, track_origins=False, no_cache=False):
       gadgets, joins, outs = self._info
 
       def do_eval():
@@ -196,9 +196,8 @@ def _init_():
         )
         return ret
 
-      return self._cache.get((parameter, self._info), do_eval)
-
-
+      if no_cache: return do_eval()
+      return self._cache.get((parameter, self), do_eval)
 
 
   def ParametrizedGraphSequence(cls):
@@ -246,7 +245,8 @@ def _init_():
       start_gadget = self.gadget(start_at)
 
       boundaries = set()
-      new_boundaries = set( b.boundary for b in start_gadget.eval_gadget(parameter) )
+      new_boundaries = set( b.boundary for b in
+        start_gadget.eval_gadget(parameter, no_cache=True) )
 
       i = 0
       while True:
@@ -258,7 +258,7 @@ def _init_():
           start_gadget.size(),
           [ Boundary(k, 1) for k in boundaries ]
         )
-        out = self._next_gadget(gadget).eval_gadget(parameter)
+        out = self._next_gadget(gadget).eval_gadget(parameter, no_cache=True)
         new_boundaries = set( b.boundary for b in out )
         print("Loop %i done - %i boundaries (%i new)" % \
           (i, len(new_boundaries), len(new_boundaries.difference(boundaries))))
@@ -271,15 +271,18 @@ def _init_():
       m = matrix(QQ, size, size, {
         (rename[b.boundary], rename[col]): b.value
         for col in sorted_boundaries
-        for b in self._next_gadget(FakeGadget(start_gadget.size(), [ Boundary(col, 1) ])).eval_gadget(parameter)
+        for b in self._next_gadget(FakeGadget(start_gadget.size(),
+          [ Boundary(col, 1) ])).eval_gadget(parameter, no_cache=True)
       })
 
       iv = matrix(QQ, size, 1, {
-        (rename[b.boundary], 0): b.value for b in start_gadget.eval_gadget(parameter)
+        (rename[b.boundary], 0): b.value for b in
+          start_gadget.eval_gadget(parameter, no_cache=True)
       })
 
       fin = matrix(QQ, 1, size, {
-        (0, rename[col]): self._make_graph(FakeGadget(start_gadget.size(), [ Boundary(col, 1) ])).eval(parameter)
+        (0, rename[col]): self._make_graph(FakeGadget(start_gadget.size(),
+          [ Boundary(col, 1) ])).eval(parameter, no_cache=True)
         for col in sorted_boundaries
       })
 
-- 
GitLab