diff --git a/mo/imports.py b/mo/imports.py index 2953f9433114297b52989d2ce68116d437f16e5b..634908c5712110bd6d87db5f497d99fd5560433d 100644 --- a/mo/imports.py +++ b/mo/imports.py @@ -1,11 +1,11 @@ -from dataclasses import dataclass -import decimal +from dataclasses import dataclass, make_dataclass, field +from decimal import Decimal from enum import auto import io import re from sqlalchemy import and_ from sqlalchemy.orm import joinedload, Query -from typing import List, Optional, Any, Dict, Type, Union +from typing import List, Optional, Any, Dict, Type, Union, Tuple import mo.csv from mo.csv import FileFormat, MissingHeaderError @@ -202,7 +202,7 @@ class Import: self.new_user_ids.append(user.user_id) return user - def parse_points(self, points_str: str) -> Union[decimal.Decimal, str, None]: + def parse_points(self, points_str: str, task: db.Task) -> Union[Decimal, str, None]: if points_str == "": return self.error('Body musí být vyplněny') @@ -210,7 +210,7 @@ class Import: if points_str in ['X', '?']: return points_str - pts, error = mo.util.parse_points(points_str, self.task, self.round) + pts, error = mo.util.parse_points(points_str, task, self.round) if error: return self.error(error) @@ -635,24 +635,28 @@ class GlobalOrgsImport(OrgsImport): @dataclass -class PointsImportRow(mo.csv.Row): +class PointsImportRowBase(mo.csv.Row): user_id: str = "" krestni: str = "" prijmeni: str = "" - body: str = "" + # sloupce pro body za úlohy generujeme za běhu + + +_SolutionDict = Dict[Tuple[int, int], Optional[db.Solution]] +_TaskPoints = List[Tuple[db.Task, Union[Decimal, str, None]]] class PointsImport(Import): - row_class = PointsImportRow log_msg_prefix = 'Body' + task_columns: List[Tuple[db.Task, str]] allow_add_del: bool # je povoleno zakládat/mazat řešení def __init__( self, user: db.User, round: db.Round, - task: db.Task, + task: Optional[db.Task] = None, contest: Optional[db.Contest] = None, only_region: Optional[db.Place] = None, allow_add_del: bool = False, @@ -660,20 +664,32 @@ class PointsImport(Import): super().__init__(user) self.round = round self.contest = contest - self.task = task self.only_region = only_region self.allow_add_del = allow_add_del assert self.round is not None - assert self.task is not None - self.log_details = {'action': 'import-points', 'task': self.task.code} - self.template_basename = 'body-' + self.task.code + self.log_details = {'action': 'import-points'} + if task is None: + self.task_columns = [] + tasks = db.get_session().query(db.Task).filter_by(round=round).all() + task_cnt = 0 + for t in sorted(tasks, key=lambda task: task.code): + task_cnt += 1 + self.task_columns.append((t, f'body{task_cnt}')) + self.template_basename = 'body-' + self.round.round_code() + else: + self.task = task + self.task_columns = [(task, 'body')] + self.log_details['task'] = task.code + self.template_basename = 'body-' + task.code - def _pion_sol_query(self) -> Query: + fields = [(name, str, field(default="")) for task, name in self.task_columns] + self.row_class = make_dataclass('PointsImportRow', fields, bases=(PointsImportRowBase,)) + + def _pion_query(self) -> Query: sess = db.get_session() - query = (sess.query(db.Participation, db.Solution) + query = (sess.query(db.Participation) .select_from(db.Participation) - .outerjoin(db.Solution, and_(db.Solution.user_id == db.Participation.user_id, db.Solution.task == self.task)) - .options(joinedload(db.Participation.user))) + .options(joinedload(db.Participation.user), joinedload(db.Participation.contest))) if self.contest is not None: query = query.filter(db.Participation.contest_id == self.contest.master_contest_id) @@ -686,29 +702,39 @@ class PointsImport(Import): return query + def _get_solutions(self, pions: List[db.Participation]) -> _SolutionDict: + sess = db.get_session() + user_ids = [pion.user.user_id for pion in pions] + task_ids = [task.task_id for task, col in self.task_columns] + out: _SolutionDict = {(uid, tid): None for uid in user_ids for tid in task_ids} + for sol in (sess.query(db.Solution) + .filter(db.Solution.user_id.in_(user_ids)) + .filter(db.Solution.task_id.in_(task_ids)) + .all()): + out[sol.user_id, sol.task_id] = sol + return out + def import_row(self, r: mo.csv.Row) -> None: - assert isinstance(r, PointsImportRow) + assert isinstance(r, PointsImportRowBase) num_prev_errs = len(self.errors) user_id = self.parse_user_id(r.user_id) krestni = self.parse_name(r.krestni) prijmeni = self.parse_name(r.prijmeni) - body = self.parse_points(r.body) + task_points = [(task, self.parse_points(getattr(r, col), task)) for task, col in self.task_columns] if (len(self.errors) > num_prev_errs or user_id is None or krestni is None or prijmeni is None - or body is None): + or any(pts is None for task, pts in task_points)): return assert self.round is not None - assert self.task is not None - task_id = self.task.task_id sess = db.get_session() - query = self._pion_sol_query().filter(db.Participation.user_id == user_id) - pion_sols = query.all() - if not pion_sols: + query = self._pion_query().filter(db.Participation.user_id == user_id) + pions = query.all() + if not pions: if self.contest is not None: msg = self.round.get_level().name_locative('tomto', 'této', 'tomto') elif self.only_region is not None: @@ -716,14 +742,15 @@ class PointsImport(Import): else: msg = 'tomto kole' return self.error(f'Soutěžící nenalezen v {msg}') - elif len(pion_sols) > 1: + elif len(pions) > 1: return self.error('Soutěžící v tomto kole soutěží vícekrát, neumím zpracovat') - pion, sol = pion_sols[0] + pion = pions[0] + sols = self._get_solutions(pions) if not self.round.is_subround(): contest = pion.contest else: - contest = sess.query(db.Contest).filter_by(round=self.round, master_contest_id=pion.contest_id).one() + contest = sess.query(db.Contest).filter_by(round=self.round, master_contest_id=pion.contest.contest_id).one() rights = self.gatekeeper.rights_for_contest(contest) if not rights.can_edit_points(): @@ -733,67 +760,102 @@ class PointsImport(Import): if user.first_name != krestni or user.last_name != prijmeni: return self.error('Neodpovídá ID a jméno soutěžícího') - if sol is None: - if body == 'X': - return - if not self.allow_add_del: - return self.error('Tento soutěžící úlohu neodevzdal') - if not rights.can_upload_solutions(): - return self.error('Nemáte právo na zakládání nových řešení') - sol = db.Solution(user_id=user_id, task_id=task_id) - sess.add(sol) - logger.info(f'Import: Založeno řešení user=#{user_id} task=#{task_id}') - mo.util.log( - type=db.LogType.participant, - what=user_id, - details={'action': 'solution-created', 'task': task_id}, - ) - self.cnt_add_sols += 1 - elif body == 'X': - if not self.allow_add_del: - return self.error('Tento soutěžící úlohu odevzdal') - if sol.final_submit is not None or sol.final_feedback is not None: - return self.error('Nelze smazat řešení, ke kterému existují odevzdané soubory') - if not rights.can_upload_solutions(): - return self.error('Nemáte právo na mazání řešení') - logger.info(f'Import: Smazáno řešení user=#{user_id} task=#{task_id}') - mo.util.log( - type=db.LogType.participant, - what=user_id, - details={'action': 'solution-removed', 'task': task_id}, - ) - self.cnt_del_sols += 1 - sess.delete(sol) - return + self._add_del_solutions(user, sols, task_points, rights) + self._set_points(user, sols, task_points) - points = body if isinstance(body, decimal.Decimal) else None - if sol.points != points: - sol.points = points - sess.add(db.PointsHistory( - task=self.task, - participant_id=user_id, - user=self.user, - points_at=mo.now, - points=points, - )) - self.cnt_set_points += 1 + def _add_del_solutions(self, + user: db.User, + sols: _SolutionDict, + task_points: _TaskPoints, + rights: mo.rights.ContestRights) -> None: + user_id = user.user_id + sess = db.get_session() - def get_template(self) -> str: - rows = [] - for pion, sol in sorted(self._pion_sol_query().all(), key=lambda pair: pair[0].user.sort_key()): + for task, pts in task_points: + task_id = task.task_id + sol = sols[user_id, task_id] + if (sol is None) != (pts == 'X'): + if sol is None: + if not self.allow_add_del: + return self.error('Tento soutěžící úlohu neodevzdal') + if not rights.can_upload_solutions(): + return self.error('Nemáte právo na zakládání nových řešení') + sol = db.Solution(user_id=user_id, task_id=task_id) + sols[user_id, task_id] = sol + sess.add(sol) + logger.info(f'Import: Založeno řešení user=#{user_id} task=#{task_id}') + mo.util.log( + type=db.LogType.participant, + what=user_id, + details={'action': 'solution-created', 'task': task_id}, + ) + self.cnt_add_sols += 1 + elif pts == 'X': + if not self.allow_add_del: + return self.error('Tento soutěžící úlohu odevzdal') + if sol.final_submit is not None or sol.final_feedback is not None: + return self.error('Nelze smazat řešení, ke kterému existují odevzdané soubory') + if not rights.can_upload_solutions(): + return self.error('Nemáte právo na mazání řešení') + logger.info(f'Import: Smazáno řešení user=#{user_id} task=#{task_id}') + mo.util.log( + type=db.LogType.participant, + what=user_id, + details={'action': 'solution-removed', 'task': task_id}, + ) + sols[user_id, task_id] = None + sess.delete(sol) + self.cnt_del_sols += 1 + + def _set_points(self, + user: db.User, + sols: _SolutionDict, + task_points: _TaskPoints) -> None: + user_id = user.user_id + sess = db.get_session() + + for task, pts in task_points: + task_id = task.task_id + sol = sols[user_id, task_id] if sol is None: - pts = 'X' - elif sol.points is None: - pts = '?' - else: - pts = format_decimal(sol.points) + continue + + points = pts if isinstance(pts, Decimal) else None + if sol.points != points: + sol.points = points + sess.add(db.PointsHistory( + task=task, + participant_id=user_id, + user=self.user, + points_at=mo.now, + points=points, + )) + self.cnt_set_points += 1 + + def get_template(self) -> str: + # Není to přímo PointsImportRowBase, ale její dynamicky generovaný potomek + rows: List[PointsImportRowBase] = [] + assert issubclass(self.row_class, PointsImportRowBase) + pions = self._pion_query().all() + sols = self._get_solutions(pions) + + for pion in sorted(pions, key=lambda pion: pion.user.sort_key()): user = pion.user - rows.append(PointsImportRow( + row = self.row_class( user_id=user.user_id, krestni=user.first_name, prijmeni=user.last_name, - body=pts, - )) + ) + for task, col in self.task_columns: + sol = sols[user.user_id, task.task_id] + if sol is None: + pts = 'X' + elif sol.points is None: + pts = '?' + else: + pts = format_decimal(sol.points) + setattr(row, col, pts) + rows.append(row) out = io.StringIO() mo.csv.write(file=out, fmt=self.fmt, row_class=self.row_class, rows=rows)