From 2256b4bee8ceb68ef3da2f0ecb5453e1dbda730e Mon Sep 17 00:00:00 2001
From: Martin Mares <mj@ucw.cz>
Date: Fri, 1 Jan 2021 21:39:36 +0100
Subject: [PATCH] =?UTF-8?q?Sjednocen=C3=AD=20funkc=C3=AD=20na=20hled=C3=A1?=
 =?UTF-8?q?n=C3=AD=20m=C3=ADsta=20podle=20k=C3=B3du?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 mo/db.py      | 28 +++++++++++++++++++---------
 mo/imports.py |  2 +-
 mo/web/org.py |  4 ++--
 3 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/mo/db.py b/mo/db.py
index 5d524232..7b128893 100644
--- a/mo/db.py
+++ b/mo/db.py
@@ -3,12 +3,13 @@
 
 import datetime
 from enum import Enum as PythonEnum, auto
+import re
 from sqlalchemy import \
     Boolean, Column, DateTime, Enum, ForeignKey, Integer, String, Text, UniqueConstraint, \
     text, \
     create_engine, inspect
 from sqlalchemy.engine import Engine
-from sqlalchemy.orm import relationship, sessionmaker, Session, class_mapper
+from sqlalchemy.orm import relationship, sessionmaker, Session, class_mapper, joinedload
 from sqlalchemy.orm.attributes import get_history
 from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.ext.declarative import declarative_base
@@ -106,14 +107,23 @@ class Place(Base):
         return len(PlaceType.choices(level=self.level + 1)) > 0
 
 
-def place_by_code(code: str) -> Optional[Place]:
-    if code.startswith("#"):
-        try:
-            id = int(code[1:])
-            return get_session().query(Place).get(id)
-        except ValueError:
-            return None
-    return get_session().query(Place).filter_by(code=code).first()
+def get_root_place():
+    return get_session().query(Place).filter_by(parent=None).one()
+
+
+def get_place_by_code(code: str, fetch_school: bool = False) -> Optional[Place]:
+    if code == "":
+        return None
+
+    q = get_session().query(Place)
+    if fetch_school:
+        q = q.options(joinedload(Place.school))
+
+    m = re.fullmatch(r'#(\d+)', code)
+    if m:
+        return q.get(int(m[1]))
+    else:
+        return q.filter_by(code=code).one_or_none()
 
 
 class School(Base):
diff --git a/mo/imports.py b/mo/imports.py
index e7db1b72..b8b905ff 100644
--- a/mo/imports.py
+++ b/mo/imports.py
@@ -60,7 +60,7 @@ def parse_school(kod: str, errs: List[str]) -> Optional[db.Place]:
         errs.append('Škola je povinná')
         return None
 
-    place = mo.util.get_place_by_code(kod, fetch_school=True)
+    place = db.get_place_by_code(kod, fetch_school=True)
     if not place:
         errs.append('Škola nenalezena')
         return None
diff --git a/mo/web/org.py b/mo/web/org.py
index c48aca72..5f12a39f 100644
--- a/mo/web/org.py
+++ b/mo/web/org.py
@@ -201,7 +201,7 @@ def org_place_move(id: int):
         if form.reset.data:
             return redirect(url_for('org_place_move', id=id))
 
-        new_parent = db.place_by_code(form.code.data)
+        new_parent = db.get_place_by_code(form.code.data)
         if not new_parent:
             search_error = 'Místo s tímto kódem se nepovedlo nalézt'
         else:
@@ -350,7 +350,7 @@ def org_place_new_child(id: int):
 
 @app.route('/org/place/')
 def org_place_root():
-    root = mo.util.get_root_place()
+    root = db.get_root_place()
     return redirect(url_for('org_place', id=root.place_id))
 
 
-- 
GitLab