Skip to content
Snippets Groups Projects
Select Git revision
  • b3eba8b82b292d625146f816690721558f50b535
  • devel default
  • master
  • fo
  • jirka/typing
  • fo-base
  • mj/submit-images
  • jk/issue-96
  • jk/issue-196
  • honza/add-contestant
  • honza/mr7
  • honza/mrf
  • honza/mrd
  • honza/mra
  • honza/mr6
  • honza/submit-images
  • honza/kolo-vs-soutez
  • jh-stress-test-wip
  • shorten-schools
19 results

util.py

Blame
  • db.py 12.28 KiB
    # SQLAlchemy definitions of all tables in the database
    # Generated by sqlacodegen and then heavily edited.
    
    from enum import Enum as PythonEnum, auto
    from sqlalchemy import \
        Boolean, Column, DateTime, Enum, ForeignKey, Integer, String, Text, UniqueConstraint, \
        text, \
        create_engine, inspect
    from sqlalchemy.engine import Engine
    from sqlalchemy.orm import relationship, sessionmaker, Session, class_mapper
    from sqlalchemy.orm.attributes import get_history
    from sqlalchemy.dialects.postgresql import JSONB
    from sqlalchemy.ext.declarative import declarative_base
    from typing import Optional, List
    
    
    # HACK: Work-around for https://github.com/dropbox/sqlalchemy-stubs/issues/114
    # Type checkers see the stub below instead of sqlalchemy's Enum: it declares
    # Enum as a generic TypeEngine constructed from an enum class, so that
    # Column(Enum(SomeEnum, name=...)) type-checks. At runtime (else branch) the
    # real sqlalchemy Enum is imported and the stub is never defined.
    from typing import TYPE_CHECKING, TypeVar, Type, Any
    if TYPE_CHECKING:
        from sqlalchemy.sql.type_api import TypeEngine
        T = TypeVar('T')
        class Enum(TypeEngine[T]):
            def __init__(self, enum: Type[T], **kwargs: Any) -> None: ...
    else:
        from sqlalchemy import Enum
    
    
    # Declarative base class of every ORM model below; `metadata` is the
    # MetaData collection holding their Table definitions.
    Base = declarative_base()
    metadata = Base.metadata
    
    
    class PlaceType(str, PythonEnum):
        """Kind of a Place: a region, a school, or a contest site.

        Stored in the database as the enum type 'place_type'.
        """
        region = auto()
        school = auto()
        site = auto()

        @classmethod
        def choices(cls):
            """Return (member name, user-facing Czech label) pairs for a select widget."""
            return [('region', 'Region'), ('school', 'Škola'), ('site', 'Soutěžní místo')]

        @classmethod
        def coerce(cls, name):
            """Convert a PlaceType member or a member name to a PlaceType member.

            Presumably used as the coerce hook of a form field together with
            choices() — confirm with callers.

            Raises:
                ValueError: if ``name`` is not the name of any member.
            """
            if isinstance(name, cls):
                return name
            try:
                return cls[name]
            except KeyError:
                # Report the bad value only; the internal KeyError is noise.
                raise ValueError(name) from None
    
    
    # Czech names of place levels, indexed by Place.level (see Place.type_name):
    # state, region (kraj), district (okres), municipality (obec), school.
    place_level_names = ['stát', 'kraj', 'okres', 'obec', 'škola']
    
    
    class Place(Base):
        """A node of the place hierarchy: a region, a school or a contest site."""
        __tablename__ = 'places'

        place_id = Column(Integer, primary_key=True, server_default=text("nextval('places_place_id_seq'::regclass)"))
        # Index into place_level_names (used by type_name()).
        level = Column(Integer, nullable=False)
        # Enclosing place (self-reference); nullable, presumably NULL for the top-level place.
        parent = Column(Integer, ForeignKey('places.place_id'))
        name = Column(String(255))
        code = Column(String(255))
        type = Column(Enum(PlaceType, name='place_type'), nullable=False)
        nuts = Column(String(255), unique=True, server_default=text("NULL::character varying"))
        note = Column(Text, nullable=False, server_default=text("''::text"))

        # Places whose `parent` points at this one (one-to-many self-reference).
        children = relationship('Place')

        def type_name(self):
            """Return the human-readable (Czech) name of this place's type.

            Sites and schools have fixed names. Other places are named after
            their level via place_level_names, falling back to "region" when
            the level is outside that list.
            """
            if self.type == PlaceType.site:
                return "soutěžní místo"
            elif self.type == PlaceType.school:
                return "škola"
            # The lower bound matters too: a negative level would otherwise
            # silently select a name from the end of the list.
            elif 0 <= self.level < len(place_level_names):
                return place_level_names[self.level]
            else:
                return "region"
    
    
    class School(Base):
        """Additional data of a place that is a school; keyed by the Place's id."""
        __tablename__ = 'schools'

        # Shares the primary key of the Place row; the row is removed with it (ON DELETE CASCADE).
        place_id = Column(Integer, ForeignKey('places.place_id', ondelete='CASCADE'), primary_key=True)
        # Presumably the identifier from the Czech school registry (RED IZO) — confirm.
        red_izo = Column(String(255), server_default=text("NULL::character varying"))
        official_name = Column(String(255), server_default=text("NULL::character varying"))
        address = Column(String(255), server_default=text("NULL::character varying"))
        # Presumably flags for primary ("ZŠ") and secondary ("SŠ") school — confirm.
        is_zs = Column(Boolean, nullable=False, server_default=text("false"))
        is_ss = Column(Boolean, nullable=False, server_default=text("false"))

        place = relationship('Place')
    
    
    class Round(Base):
        """One round of a category in a given year; (year, category, seq) is unique."""
        __tablename__ = 'rounds'
        __table_args__ = (
            UniqueConstraint('year', 'category', 'seq'),
        )

        round_id = Column(Integer, primary_key=True, server_default=text("nextval('rounds_round_id_seq'::regclass)"))
        year = Column(Integer, nullable=False)
        category = Column(String(2), nullable=False)
        # Ordinal of the round within its year and category.
        seq = Column(Integer, nullable=False)
        level = Column(Integer, nullable=False)
        name = Column(String(255), nullable=False)

        def round_code(self):
            """Return the short code "<year>-<category>-<seq>" identifying this round."""
            return f"{self.year}-{self.category}-{self.seq}"
    
    
    class User(Base):
        """An account: identified by a unique e-mail, with organizer/admin flags."""
        __tablename__ = 'users'

        user_id = Column(Integer, primary_key=True, server_default=text("nextval('users_user_id_seq'::regclass)"))
        email = Column(String(255), nullable=False, unique=True)
        first_name = Column(String(255), nullable=False)
        last_name = Column(String(255), nullable=False)
        is_org = Column(Boolean, nullable=False, server_default=text("false"))
        is_admin = Column(Boolean, nullable=False, server_default=text("false"))
        created_at = Column(DateTime(True), nullable=False, server_default=text("CURRENT_TIMESTAMP"))
        last_login_at = Column(DateTime(True))
        reset_at = Column(DateTime(True))
        # NULL when no password has been set.
        password_hash = Column(String(255), server_default=text("NULL::character varying"))

        # The join must be explicit: user_roles references users twice
        # (user_id and assigned_by), so the foreign key alone is ambiguous.
        roles = relationship('UserRole', primaryjoin='UserRole.user_id == User.user_id')
    
    class Contest(Base):
        """A round in a region; at most one contest exists per (round, region) pair."""
        __tablename__ = 'contests'
        __table_args__ = (
            UniqueConstraint('round', 'region'),
        )

        contest_id = Column(Integer, primary_key=True, server_default=text("nextval('contests_contest_id_seq'::regclass)"))
        # Foreign-key ids; the related objects are region_object / round_object.
        round = Column(Integer, ForeignKey('rounds.round_id'), nullable=False)
        region = Column(Integer, ForeignKey('places.place_id'), nullable=False)

        region_object = relationship('Place')
        round_object = relationship('Round')
    
    class LogType(str, PythonEnum):
        """Kind of object a Log entry is about (database enum type 'log_type')."""
        general = auto()
        user = auto()
        place = auto()
        round = auto()
        contest = auto()
        participant = auto()
        task = auto()
        user_role = auto()
    
    
    class Log(Base):
        """A log entry: who changed what and when, with free-form JSON details."""
        __tablename__ = 'log'

        log_entry_id = Column(Integer, primary_key=True, server_default=text("nextval('log_log_entry_id_seq'::regclass)"))
        changed_by = Column(Integer, ForeignKey('users.user_id'))
        changed_at = Column(DateTime(True), nullable=False, server_default=text("CURRENT_TIMESTAMP"))
        # `type` says what kind of object the entry concerns; `id` presumably is
        # that object's primary key (no foreign key is declared) — confirm.
        type = Column(Enum(LogType, name='log_type'), nullable=False)
        id = Column(Integer, nullable=False)
        details = Column(JSONB, nullable=False)

        # The user who made the change (via changed_by).
        user = relationship('User')
    
    
    class Participant(Base):
        """Per-year data of a contestant; one row per (user_id, year)."""
        __tablename__ = 'participants'

        user_id = Column(Integer, ForeignKey('users.user_id'), primary_key=True, nullable=False)
        year = Column(Integer, primary_key=True, nullable=False)
        # Place id of the contestant's school.
        school = Column(Integer, ForeignKey('places.place_id'), nullable=False)
        birth_year = Column(Integer, nullable=False)
        grade = Column(String(20), nullable=False)

        # The school's Place object (via the `school` foreign key).
        place = relationship('Place')
        user = relationship('User')
    
    
    class PartState(str, PythonEnum):
        """State of a Participation (database enum type 'part_state')."""
        invited = auto()
        refused = auto()
        present = auto()
        absent = auto()
    
    
    class Participation(Base):
        """A user's participation in a contest; one row per (user_id, contest_id)."""
        __tablename__ = 'participations'

        user_id = Column(Integer, ForeignKey('users.user_id'), primary_key=True, nullable=False)
        contest_id = Column(Integer, ForeignKey('contests.contest_id'), primary_key=True, nullable=False)
        # Presumably the place (e.g. contest site) where the user takes part — confirm.
        place_id = Column(Integer, ForeignKey('places.place_id'), nullable=False)
        state = Column(Enum(PartState, name='part_state'), nullable=False)

        contest = relationship('Contest', primaryjoin='Participation.contest_id == Contest.contest_id')
        place = relationship('Place', primaryjoin='Participation.place_id == Place.place_id')
        user = relationship('User')
    
    
    class Task(Base):
        """A task of a round; its `code` is unique within the round."""
        __tablename__ = 'tasks'
        __table_args__ = (
            UniqueConstraint('round_id', 'code'),
        )

        task_id = Column(Integer, primary_key=True, server_default=text("nextval('tasks_task_id_seq'::regclass)"))
        round_id = Column(Integer, ForeignKey('rounds.round_id'), nullable=False)
        code = Column(String(255), nullable=False)
        name = Column(String(255), nullable=False)

        round = relationship('Round')
    
    
    class RoleType(str, PythonEnum):
        """Role a user can hold (database enum type 'role_type').

        Member names are Czech; roughly: garant = person in charge (overall,
        regional "kraj", district "okres"), dozor = invigilator,
        opravovatel = grader — translations unconfirmed.
        """
        garant = auto()
        garant_kraj = auto()
        garant_okres = auto()
        dozor = auto()
        opravovatel = auto()
    
    
    class UserRole(Base):
        """A role granted to a user for a place, with optional category and round."""
        __tablename__ = 'user_roles'

        user_role_id = Column(Integer, primary_key=True, server_default=text("nextval('user_roles_user_role_id_seq'::regclass)"))
        user_id = Column(Integer, ForeignKey('users.user_id'), nullable=False)
        place_id = Column(Integer, ForeignKey('places.place_id'), nullable=False)
        role = Column(Enum(RoleType, name='role_type'), nullable=False)
        # Optional; presumably narrow the role to one category / one round when set — confirm.
        category = Column(String(2), server_default=text("NULL::character varying"))
        round_id = Column(Integer, ForeignKey('rounds.round_id'))
        # Who granted the role and when (both optional).
        assigned_by = Column(Integer, ForeignKey('users.user_id'))
        assigned_at = Column(DateTime(True))

        # Two foreign keys point at users, so each User relationship names its join explicitly.
        user = relationship('User', primaryjoin='UserRole.user_id == User.user_id')
        assigned_by_user = relationship('User', primaryjoin='UserRole.assigned_by == User.user_id')
        place_object = relationship('Place')
        round = relationship('Round')
    
    
    class PaperType(str, PythonEnum):
        """Kind of an uploaded Paper (database enum type 'paper_type')."""
        solution = auto()
        feedback = auto()
    
    
    class Paper(Base):
        """An uploaded file (solution or feedback) for one task and one user."""
        __tablename__ = 'papers'

        paper_id = Column(Integer, primary_key=True, server_default=text("nextval('papers_paper_id_seq'::regclass)"))
        for_task = Column(Integer, ForeignKey('tasks.task_id'), nullable=False)
        # The user the paper belongs to; may differ from the uploader (uploaded_by).
        for_user = Column(Integer, ForeignKey('users.user_id'), nullable=False)
        type = Column(Enum(PaperType, name='paper_type'), nullable=False)
        uploaded_by = Column(Integer, ForeignKey('users.user_id'), nullable=False)
        uploaded_at = Column(DateTime(True), nullable=False, server_default=text("CURRENT_TIMESTAMP"))
        # Optional; presumably page count and file size in bytes — confirm.
        pages = Column(Integer)
        bytes = Column(Integer)
        file_name = Column(String(255), nullable=False)

        task = relationship('Task')
        # Two foreign keys point at users, so both joins are explicit.
        for_user_obj = relationship('User', primaryjoin='Paper.for_user == User.user_id')
        uploaded_by_obj = relationship('User', primaryjoin='Paper.uploaded_by == User.user_id')
    
    
    class PointsHistory(Base):
        """A recorded points value for a participant's task, with who set it and when."""
        __tablename__ = 'points_history'

        points_history_id = Column(Integer, primary_key=True, server_default=text("nextval('points_history_points_history_id_seq'::regclass)"))
        task_id = Column(Integer, ForeignKey('tasks.task_id'), nullable=False)
        participant_id = Column(Integer, ForeignKey('users.user_id'), nullable=False)
        points = Column(Integer, nullable=False)
        points_by = Column(Integer, ForeignKey('users.user_id'), nullable=False)
        points_at = Column(DateTime(True), nullable=False)

        # Two foreign keys point at users: `participant` is the contestant,
        # `user` is the one who set the points (points_by).
        participant = relationship('User', primaryjoin='PointsHistory.participant_id == User.user_id')
        user = relationship('User', primaryjoin='PointsHistory.points_by == User.user_id')
        task = relationship('Task')
    
    
    class Solution(Base):
        """A user's solution of a task; one row per (task_id, user_id)."""
        __tablename__ = 'solutions'

        task_id = Column(Integer, ForeignKey('tasks.task_id'), primary_key=True, nullable=False)
        user_id = Column(Integer, ForeignKey('users.user_id'), primary_key=True, nullable=False)
        # Papers holding the latest submission and latest feedback (both optional).
        last_submit = Column(Integer, ForeignKey('papers.paper_id'))
        last_feedback = Column(Integer, ForeignKey('papers.paper_id'))
        points = Column(Integer)

        # Two foreign keys point at papers, so both joins are explicit.
        last_submit_obj = relationship('Paper', primaryjoin='Solution.last_submit == Paper.paper_id')
        last_feedback_obj = relationship('Paper', primaryjoin='Solution.last_feedback == Paper.paper_id')
        task = relationship('Task')
        user = relationship('User')
    
    
    # Engine and session created lazily by get_session() when flask_db is not set.
    _engine: Optional[Engine] = None
    _session: Optional[Session] = None
    # When set (presumably to the web app's Flask-SQLAlchemy object), get_session()
    # returns its `session` instead of the module-level one.
    flask_db: Any = None
    
    
    def get_session() -> Session:
        """Return the SQLAlchemy session used for database access.

        If the module-level ``flask_db`` has been set, its ``session`` is
        returned. Otherwise a single module-level session is created on the
        first call, bound to an engine built from
        ``config.SQLALCHEMY_DATABASE_URI``. Every later call returns that
        same session.

        NOTE(review): the fallback session is one object shared by all
        callers, and SQLAlchemy sessions are not thread-safe, so this path
        presumably serves single-threaded scripts — confirm.
        """
        global _session, _engine
        # flask_db is an arbitrary object; test for "has been set", not truthiness.
        if flask_db is not None:
            return flask_db.session
        if _session is None:
            if _engine is None:
                # Imported lazily, only when the standalone engine is needed.
                import config
                _engine = create_engine(config.SQLALCHEMY_DATABASE_URI, echo=config.SQLALCHEMY_ECHO)
            session_factory = sessionmaker(bind=_engine)
            _session = session_factory()
        return _session
    
    
    def get_place_parents(place: Place) -> List[Place]:
        """Return ``place`` together with all of its ancestors, as Place objects.

        The ``parent`` links are followed upward by a single recursive CTE
        query. The order of the returned list is not specified.
        """
        from sqlalchemy.orm import aliased

        sess = get_session()

        # Anchor member: the starting place itself.
        topq = (sess.query(Place)
                .filter(Place.place_id == place.place_id)
                .cte('parents', recursive=True))

        # Recursive member: the parent of every place collected so far.
        botq = (sess.query(Place)
                .join(topq, Place.place_id == topq.c.parent))

        # UNION (not UNION ALL) discards rows already produced, so the
        # recursion terminates even if the parent links formed a cycle.
        recq = topq.union(botq)

        # Map the CTE's columns back onto the Place entity. Querying the bare
        # CTE would return plain column tuples instead of Place instances.
        return sess.query(aliased(Place, recq)).all()
    
    
    def get_object_changes(obj):
        """Describe the column changes of ``obj`` not yet flushed to the database.

        Returns a dict keyed by column attribute name, where each value is
        ``{'before': <previous value>, 'after': <current value>}``, e.g.::

            {'some_field': {'before': 'old', 'after': 'new'}}

        Only attributes whose history records a replaced value are reported.
        NOTE(review): changes where both the old and the new value are falsy
        (e.g. None -> '' or None -> 0) are left out; confirm callers accept that.

        Adapted from
        https://stackoverflow.com/questions/15952733/sqlalchemy-logging-of-changes-with-date-and-user
        """
        state = inspect(obj)
        changes = {}
        for column_attr in class_mapper(obj.__class__).column_attrs:
            key = column_attr.key
            if not getattr(state.attrs, key).history.has_changes():
                continue
            # History is the (added, unchanged, deleted) triple; `deleted`
            # holds the value that was replaced, if one was recorded.
            replaced = get_history(obj, key).deleted
            if not replaced:
                continue
            old_value = replaced[-1]
            new_value = getattr(obj, key)
            if old_value != new_value and (old_value or new_value):
                changes[key] = {'before': old_value, 'after': new_value}
        return changes
    
    
    def row2dict(row):
        """Map each column name of ``row``'s table to that attribute's value on ``row``.

        Keys follow the table's column order.
        """
        return {col.name: getattr(row, col.name) for col in row.__table__.columns}