diff --git a/mo/db.py b/mo/db.py index 5d52423256f57938497f734d3aea13e1ba8e2b1d..7b1288938cbdf31d94f05208c8cce53477cf82d7 100644 --- a/mo/db.py +++ b/mo/db.py @@ -3,12 +3,13 @@ import datetime from enum import Enum as PythonEnum, auto +import re from sqlalchemy import \ Boolean, Column, DateTime, Enum, ForeignKey, Integer, String, Text, UniqueConstraint, \ text, \ create_engine, inspect from sqlalchemy.engine import Engine -from sqlalchemy.orm import relationship, sessionmaker, Session, class_mapper +from sqlalchemy.orm import relationship, sessionmaker, Session, class_mapper, joinedload from sqlalchemy.orm.attributes import get_history from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.declarative import declarative_base @@ -106,14 +107,23 @@ class Place(Base): return len(PlaceType.choices(level=self.level + 1)) > 0 -def place_by_code(code: str) -> Optional[Place]: - if code.startswith("#"): - try: - id = int(code[1:]) - return get_session().query(Place).get(id) - except ValueError: - return None - return get_session().query(Place).filter_by(code=code).first() +def get_root_place(): + return get_session().query(Place).filter_by(parent=None).one() + + +def get_place_by_code(code: str, fetch_school: bool = False) -> Optional[Place]: + if code == "": + return None + + q = get_session().query(Place) + if fetch_school: + q = q.options(joinedload(Place.school)) + + m = re.fullmatch(r'#(\d+)', code) + if m: + return q.get(int(m[1])) + else: + return q.filter_by(code=code).one_or_none() class School(Base): diff --git a/mo/imports.py b/mo/imports.py index e7db1b721dcaa066463d7ae90d1438e8ffbd36a9..b8b905ff026bded24b9f92aae5a9b402a8d5c889 100644 --- a/mo/imports.py +++ b/mo/imports.py @@ -60,7 +60,7 @@ def parse_school(kod: str, errs: List[str]) -> Optional[db.Place]: errs.append('Škola je povinná') return None - place = mo.util.get_place_by_code(kod, fetch_school=True) + place = db.get_place_by_code(kod, fetch_school=True) if not place: errs.append('Škola nenalezena') return None diff --git a/mo/web/org.py b/mo/web/org.py index c48aca7272210bf2f65c5d0e1e32d9025b00b1a8..5f12a39fbd1d66d1810b2c0ab488eb402bf2bc89 100644 --- a/mo/web/org.py +++ b/mo/web/org.py @@ -201,7 +201,7 @@ def org_place_move(id: int): if form.reset.data: return redirect(url_for('org_place_move', id=id)) - new_parent = db.place_by_code(form.code.data) + new_parent = db.get_place_by_code(form.code.data) if not new_parent: search_error = 'Místo s tímto kódem se nepovedlo nalézt' else: @@ -350,7 +350,7 @@ def org_place_new_child(id: int): @app.route('/org/place/') def org_place_root(): - root = mo.util.get_root_place() + root = db.get_root_place() return redirect(url_for('org_place', id=root.place_id))