Skip to content
Snippets Groups Projects
Unverified Commit ed6903a7 authored by Ranald Lam's avatar Ranald Lam Committed by GitHub
Browse files

Add no-users option to DumpImporter and DumpExporter (#1165)


* feat: Allow DumpExporter to only export tasks

* fix: Dump exporter tests

* fix: Add DumpExporterTest for skip_users, fix bug

* feat: Add no-users option to DumpImporter

* fixup

Co-authored-by: default avatarAndrey Vihrov <andrey.vihrov@gmail.com>
parent d77b3bd6
No related branches found
No related tags found
No related merge requests found
...@@ -274,8 +274,8 @@ def get_datasets_to_judge(task): ...@@ -274,8 +274,8 @@ def get_datasets_to_judge(task):
def enumerate_files( def enumerate_files(
session, contest=None, session, contest=None,
skip_submissions=False, skip_user_tests=False, skip_print_jobs=False, skip_submissions=False, skip_user_tests=False, skip_users=False,
skip_generated=False): skip_print_jobs=False, skip_generated=False):
"""Enumerate all the files (by digest) referenced by the """Enumerate all the files (by digest) referenced by the
contest. contest.
...@@ -302,7 +302,7 @@ def enumerate_files( ...@@ -302,7 +302,7 @@ def enumerate_files(
queries.append(dataset_q.join(Dataset.testcases) queries.append(dataset_q.join(Dataset.testcases)
.with_entities(Testcase.output)) .with_entities(Testcase.output))
if not skip_submissions: if not skip_submissions and not skip_users:
submission_q = task_q.join(Task.submissions) submission_q = task_q.join(Task.submissions)
queries.append(submission_q.join(Submission.files) queries.append(submission_q.join(Submission.files)
.with_entities(File.digest)) .with_entities(File.digest))
...@@ -312,7 +312,7 @@ def enumerate_files( ...@@ -312,7 +312,7 @@ def enumerate_files(
.join(SubmissionResult.executables) .join(SubmissionResult.executables)
.with_entities(Executable.digest)) .with_entities(Executable.digest))
if not skip_user_tests: if not skip_user_tests and not skip_users:
user_test_q = task_q.join(Task.user_tests) user_test_q = task_q.join(Task.user_tests)
queries.append(user_test_q.with_entities(UserTest.input)) queries.append(user_test_q.with_entities(UserTest.input))
queries.append(user_test_q.join(UserTest.files) queries.append(user_test_q.join(UserTest.files)
...@@ -328,7 +328,7 @@ def enumerate_files( ...@@ -328,7 +328,7 @@ def enumerate_files(
.filter(UserTestResult.output != None) .filter(UserTestResult.output != None)
.with_entities(UserTestResult.output)) .with_entities(UserTestResult.output))
if not skip_print_jobs: if not skip_print_jobs and not skip_users:
queries.append(contest_q.join(Contest.participations) queries.append(contest_q.join(Contest.participations)
.join(Participation.printjobs) .join(Participation.printjobs)
.with_entities(PrintJob.digest)) .with_entities(PrintJob.digest))
......
...@@ -47,7 +47,7 @@ from cms import rmtree, utf8_decoder ...@@ -47,7 +47,7 @@ from cms import rmtree, utf8_decoder
from cms.db import version as model_version, Codename, Filename, \ from cms.db import version as model_version, Codename, Filename, \
FilenameSchema, FilenameSchemaArray, Digest, SessionGen, Contest, User, \ FilenameSchema, FilenameSchemaArray, Digest, SessionGen, Contest, User, \
Task, Submission, UserTest, SubmissionResult, UserTestResult, PrintJob, \ Task, Submission, UserTest, SubmissionResult, UserTestResult, PrintJob, \
enumerate_files Announcement, Participation, enumerate_files
from cms.db.filecacher import FileCacher from cms.db.filecacher import FileCacher
from cmscommon.datetime import make_timestamp from cmscommon.datetime import make_timestamp
from cmscommon.digest import path_digest from cmscommon.digest import path_digest
...@@ -136,13 +136,16 @@ class DumpExporter: ...@@ -136,13 +136,16 @@ class DumpExporter:
def __init__(self, contest_ids, export_target, def __init__(self, contest_ids, export_target,
dump_files, dump_model, skip_generated, dump_files, dump_model, skip_generated,
skip_submissions, skip_user_tests, skip_print_jobs): skip_submissions, skip_user_tests, skip_users, skip_print_jobs):
if contest_ids is None: if contest_ids is None:
with SessionGen() as session: with SessionGen() as session:
contests = session.query(Contest).all() contests = session.query(Contest).all()
self.contests_ids = [contest.id for contest in contests] self.contests_ids = [contest.id for contest in contests]
if not skip_users:
users = session.query(User).all() users = session.query(User).all()
self.users_ids = [user.id for user in users] self.users_ids = [user.id for user in users]
else:
self.users_ids = []
tasks = session.query(Task)\ tasks = session.query(Task)\
.filter(Task.contest_id.is_(None)).all() .filter(Task.contest_id.is_(None)).all()
self.tasks_ids = [task.id for task in tasks] self.tasks_ids = [task.id for task in tasks]
...@@ -158,6 +161,7 @@ class DumpExporter: ...@@ -158,6 +161,7 @@ class DumpExporter:
self.skip_generated = skip_generated self.skip_generated = skip_generated
self.skip_submissions = skip_submissions self.skip_submissions = skip_submissions
self.skip_user_tests = skip_user_tests self.skip_user_tests = skip_user_tests
self.skip_users = skip_users
self.skip_print_jobs = skip_print_jobs self.skip_print_jobs = skip_print_jobs
self.export_target = export_target self.export_target = export_target
...@@ -208,6 +212,7 @@ class DumpExporter: ...@@ -208,6 +212,7 @@ class DumpExporter:
session, contest, session, contest,
skip_submissions=self.skip_submissions, skip_submissions=self.skip_submissions,
skip_user_tests=self.skip_user_tests, skip_user_tests=self.skip_user_tests,
skip_users=self.skip_users,
skip_print_jobs=self.skip_print_jobs, skip_print_jobs=self.skip_print_jobs,
skip_generated=self.skip_generated) skip_generated=self.skip_generated)
for file_ in files: for file_ in files:
...@@ -317,6 +322,17 @@ class DumpExporter: ...@@ -317,6 +322,17 @@ class DumpExporter:
if self.skip_user_tests and other_cls is UserTest: if self.skip_user_tests and other_cls is UserTest:
continue continue
if self.skip_users:
skip = False
# User-related classes reachable from root
for rel_class in [Participation, Submission, UserTest,
Announcement]:
if other_cls is rel_class:
skip = True
break
if skip:
continue
# Skip print jobs if requested # Skip print jobs if requested
if self.skip_print_jobs and other_cls is PrintJob: if self.skip_print_jobs and other_cls is PrintJob:
continue continue
...@@ -397,6 +413,8 @@ def main(): ...@@ -397,6 +413,8 @@ def main():
help="don't export submissions") help="don't export submissions")
parser.add_argument("-U", "--no-user-tests", action="store_true", parser.add_argument("-U", "--no-user-tests", action="store_true",
help="don't export user tests") help="don't export user tests")
parser.add_argument("-X", "--no-users", action="store_true",
help="don't export users")
parser.add_argument("-P", "--no-print-jobs", action="store_true", parser.add_argument("-P", "--no-print-jobs", action="store_true",
help="don't export print jobs") help="don't export print jobs")
parser.add_argument("export_target", action="store", parser.add_argument("export_target", action="store",
...@@ -412,6 +430,7 @@ def main(): ...@@ -412,6 +430,7 @@ def main():
skip_generated=args.no_generated, skip_generated=args.no_generated,
skip_submissions=args.no_submissions, skip_submissions=args.no_submissions,
skip_user_tests=args.no_user_tests, skip_user_tests=args.no_user_tests,
skip_users=args.no_users,
skip_print_jobs=args.no_print_jobs) skip_print_jobs=args.no_print_jobs)
success = exporter.do_export() success = exporter.do_export()
return 0 if success is True else 1 return 0 if success is True else 1
......
...@@ -49,8 +49,8 @@ import cms.db as class_hook ...@@ -49,8 +49,8 @@ import cms.db as class_hook
from cms import utf8_decoder from cms import utf8_decoder
from cms.db import version as model_version, Codename, Filename, \ from cms.db import version as model_version, Codename, Filename, \
FilenameSchema, FilenameSchemaArray, Digest, SessionGen, Contest, \ FilenameSchema, FilenameSchemaArray, Digest, SessionGen, Contest, \
Submission, SubmissionResult, UserTest, UserTestResult, PrintJob, init_db, \ Submission, SubmissionResult, User, Participation, UserTest, \
drop_db, enumerate_files UserTestResult, PrintJob, Announcement, init_db, drop_db, enumerate_files
from cms.db.filecacher import FileCacher from cms.db.filecacher import FileCacher
from cmscommon.archive import Archive from cmscommon.archive import Archive
from cmscommon.datetime import make_datetime from cmscommon.datetime import make_datetime
...@@ -128,13 +128,14 @@ class DumpImporter: ...@@ -128,13 +128,14 @@ class DumpImporter:
def __init__(self, drop, import_source, def __init__(self, drop, import_source,
load_files, load_model, skip_generated, load_files, load_model, skip_generated,
skip_submissions, skip_user_tests, skip_print_jobs): skip_submissions, skip_user_tests, skip_users, skip_print_jobs):
self.drop = drop self.drop = drop
self.load_files = load_files self.load_files = load_files
self.load_model = load_model self.load_model = load_model
self.skip_generated = skip_generated self.skip_generated = skip_generated
self.skip_submissions = skip_submissions self.skip_submissions = skip_submissions
self.skip_user_tests = skip_user_tests self.skip_user_tests = skip_user_tests
self.skip_users = skip_users
self.skip_print_jobs = skip_print_jobs self.skip_print_jobs = skip_print_jobs
self.import_source = import_source self.import_source = import_source
...@@ -233,9 +234,6 @@ class DumpImporter: ...@@ -233,9 +234,6 @@ class DumpImporter:
for id_, data in self.datas.items(): for id_, data in self.datas.items():
if not id_.startswith("_"): if not id_.startswith("_"):
self.objs[id_] = self.import_object(data) self.objs[id_] = self.import_object(data)
for id_, data in self.datas.items():
if not id_.startswith("_"):
self.add_relationships(data, self.objs[id_])
for k, v in list(self.objs.items()): for k, v in list(self.objs.items()):
...@@ -244,18 +242,28 @@ class DumpImporter: ...@@ -244,18 +242,28 @@ class DumpImporter:
del self.objs[k] del self.objs[k]
# Skip user_tests if requested # Skip user_tests if requested
if self.skip_user_tests and isinstance(v, UserTest): elif self.skip_user_tests and isinstance(v, UserTest):
del self.objs[k]
# Skip users if requested
elif self.skip_users and \
isinstance(v, (User, Participation, Submission,
UserTest, Announcement)):
del self.objs[k] del self.objs[k]
# Skip print jobs if requested # Skip print jobs if requested
if self.skip_print_jobs and isinstance(v, PrintJob): elif self.skip_print_jobs and isinstance(v, PrintJob):
del self.objs[k] del self.objs[k]
# Skip generated data if requested # Skip generated data if requested
if self.skip_generated and \ elif self.skip_generated and \
isinstance(v, (SubmissionResult, UserTestResult)): isinstance(v, (SubmissionResult, UserTestResult)):
del self.objs[k] del self.objs[k]
for id_, data in self.datas.items():
if not id_.startswith("_") and id_ in self.objs:
self.add_relationships(data, self.objs[id_])
contest_id = list() contest_id = list()
contest_files = set() contest_files = set()
...@@ -266,6 +274,11 @@ class DumpImporter: ...@@ -266,6 +274,11 @@ class DumpImporter:
# that depended on submissions or user tests that we # that depended on submissions or user tests that we
# might have removed above). # might have removed above).
for id_ in self.datas["_objects"]: for id_ in self.datas["_objects"]:
# It could have been removed by request
if id_ not in self.objs:
continue
obj = self.objs[id_] obj = self.objs[id_]
session.add(obj) session.add(obj)
session.flush() session.flush()
...@@ -277,6 +290,7 @@ class DumpImporter: ...@@ -277,6 +290,7 @@ class DumpImporter:
skip_submissions=self.skip_submissions, skip_submissions=self.skip_submissions,
skip_user_tests=self.skip_user_tests, skip_user_tests=self.skip_user_tests,
skip_print_jobs=self.skip_print_jobs, skip_print_jobs=self.skip_print_jobs,
skip_users=self.skip_users,
skip_generated=self.skip_generated) skip_generated=self.skip_generated)
session.commit() session.commit()
...@@ -405,12 +419,12 @@ class DumpImporter: ...@@ -405,12 +419,12 @@ class DumpImporter:
if val is None: if val is None:
setattr(obj, prp.key, None) setattr(obj, prp.key, None)
elif isinstance(val, str): elif isinstance(val, str):
setattr(obj, prp.key, self.objs[val]) setattr(obj, prp.key, self.objs.get(val))
elif isinstance(val, list): elif isinstance(val, list):
setattr(obj, prp.key, list(self.objs[i] for i in val)) setattr(obj, prp.key, list(self.objs[i] for i in val if i in self.objs))
elif isinstance(val, dict): elif isinstance(val, dict):
setattr(obj, prp.key, setattr(obj, prp.key,
dict((k, self.objs[v]) for k, v in val.items())) dict((k, self.objs[v]) for k, v in val.items() if v in self.objs))
else: else:
raise RuntimeError( raise RuntimeError(
"Unknown RelationshipProperty value: %s" % type(val)) "Unknown RelationshipProperty value: %s" % type(val))
...@@ -472,6 +486,8 @@ def main(): ...@@ -472,6 +486,8 @@ def main():
help="don't import submissions") help="don't import submissions")
parser.add_argument("-U", "--no-user-tests", action="store_true", parser.add_argument("-U", "--no-user-tests", action="store_true",
help="don't import user tests") help="don't import user tests")
parser.add_argument("-X", "--no-users", action="store_true",
help="don't import users")
parser.add_argument("-P", "--no-print-jobs", action="store_true", parser.add_argument("-P", "--no-print-jobs", action="store_true",
help="don't import print jobs") help="don't import print jobs")
parser.add_argument("import_source", action="store", type=utf8_decoder, parser.add_argument("import_source", action="store", type=utf8_decoder,
...@@ -486,6 +502,7 @@ def main(): ...@@ -486,6 +502,7 @@ def main():
skip_generated=args.no_generated, skip_generated=args.no_generated,
skip_submissions=args.no_submissions, skip_submissions=args.no_submissions,
skip_user_tests=args.no_user_tests, skip_user_tests=args.no_user_tests,
skip_users=args.no_users,
skip_print_jobs=args.no_print_jobs) skip_print_jobs=args.no_print_jobs)
success = importer.do_import() success = importer.do_import()
return 0 if success is True else 1 return 0 if success is True else 1
......
...@@ -85,7 +85,7 @@ class TestDumpExporter(DatabaseMixin, FileSystemMixin, unittest.TestCase): ...@@ -85,7 +85,7 @@ class TestDumpExporter(DatabaseMixin, FileSystemMixin, unittest.TestCase):
super().tearDown() super().tearDown()
def do_export(self, contest_ids, dump_files=True, skip_generated=False, def do_export(self, contest_ids, dump_files=True, skip_generated=False,
skip_submissions=False): skip_submissions=False, skip_users=False):
"""Create an exporter and call do_export in a convenient way""" """Create an exporter and call do_export in a convenient way"""
r = DumpExporter( r = DumpExporter(
contest_ids, contest_ids,
...@@ -95,6 +95,7 @@ class TestDumpExporter(DatabaseMixin, FileSystemMixin, unittest.TestCase): ...@@ -95,6 +95,7 @@ class TestDumpExporter(DatabaseMixin, FileSystemMixin, unittest.TestCase):
skip_generated=skip_generated, skip_generated=skip_generated,
skip_submissions=skip_submissions, skip_submissions=skip_submissions,
skip_user_tests=False, skip_user_tests=False,
skip_users=skip_users,
skip_print_jobs=False).do_export() skip_print_jobs=False).do_export()
dump_path = os.path.join(self.target, "contest.json") dump_path = os.path.join(self.target, "contest.json")
try: try:
...@@ -269,6 +270,25 @@ class TestDumpExporter(DatabaseMixin, FileSystemMixin, unittest.TestCase): ...@@ -269,6 +270,25 @@ class TestDumpExporter(DatabaseMixin, FileSystemMixin, unittest.TestCase):
self.assertNotInDump(SubmissionResult) self.assertNotInDump(SubmissionResult)
self.assertFileNotInDump(self.exe_digest) self.assertFileNotInDump(self.exe_digest)
def test_skip_users(self):
"""Test skipping users.
Should not export users and depending objects.
Should still export contest, tasks and their depending objects.
"""
self.assertTrue(self.do_export(None, skip_users=True))
self.assertInDump(Statement, digest=self.st_digest)
self.assertFileInDump(self.st_digest, self.st_content)
self.assertNotInDump(User)
self.assertNotInDump(Participation)
self.assertNotInDump(Submission)
self.assertNotInDump(SubmissionResult)
self.assertFileNotInDump(self.file_digest)
self.assertFileNotInDump(self.exe_digest)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -25,7 +25,7 @@ import unittest ...@@ -25,7 +25,7 @@ import unittest
# Needs to be first to allow for monkey patching the DB connection string. # Needs to be first to allow for monkey patching the DB connection string.
from cmstestsuite.unit_tests.databasemixin import DatabaseMixin from cmstestsuite.unit_tests.databasemixin import DatabaseMixin
from cms.db import Contest, FSObject, Session, version from cms.db import Contest, User, FSObject, Session, version
from cmscommon.digest import bytes_digest from cmscommon.digest import bytes_digest
from cmscontrib.DumpImporter import DumpImporter from cmscontrib.DumpImporter import DumpImporter
from cmstestsuite.unit_tests.filesystemmixin import FileSystemMixin from cmstestsuite.unit_tests.filesystemmixin import FileSystemMixin
...@@ -136,7 +136,8 @@ class TestDumpImporter(DatabaseMixin, FileSystemMixin, unittest.TestCase): ...@@ -136,7 +136,8 @@ class TestDumpImporter(DatabaseMixin, FileSystemMixin, unittest.TestCase):
super().tearDown() super().tearDown()
def do_import(self, drop=False, load_files=True, def do_import(self, drop=False, load_files=True,
skip_generated=False, skip_submissions=False): skip_generated=False, skip_submissions=False,
skip_users=False):
"""Create an importer and call do_import in a convenient way""" """Create an importer and call do_import in a convenient way"""
return DumpImporter( return DumpImporter(
drop, drop,
...@@ -146,6 +147,7 @@ class TestDumpImporter(DatabaseMixin, FileSystemMixin, unittest.TestCase): ...@@ -146,6 +147,7 @@ class TestDumpImporter(DatabaseMixin, FileSystemMixin, unittest.TestCase):
skip_generated=skip_generated, skip_generated=skip_generated,
skip_submissions=skip_submissions, skip_submissions=skip_submissions,
skip_user_tests=False, skip_user_tests=False,
skip_users=skip_users,
skip_print_jobs=False).do_import() skip_print_jobs=False).do_import()
def write_dump(self, dump): def write_dump(self, dump):
...@@ -195,6 +197,12 @@ class TestDumpImporter(DatabaseMixin, FileSystemMixin, unittest.TestCase): ...@@ -195,6 +197,12 @@ class TestDumpImporter(DatabaseMixin, FileSystemMixin, unittest.TestCase):
.filter(Contest.name == name).all() .filter(Contest.name == name).all()
self.assertEqual(len(db_contests), 0) self.assertEqual(len(db_contests), 0)
def assertUserNotInDb(self, username):
"""Assert that the user with the given username is not in the DB."""
db_users = self.session.query(User)\
.filter(User.username == username).all()
self.assertEqual(len(db_users), 0)
def assertFileInDb(self, digest, description, content): def assertFileInDb(self, digest, description, content):
"""Assert that the file with the given data is in the DB.""" """Assert that the file with the given data is in the DB."""
fsos = self.session.query(FSObject)\ fsos = self.session.query(FSObject)\
...@@ -279,6 +287,24 @@ class TestDumpImporter(DatabaseMixin, FileSystemMixin, unittest.TestCase): ...@@ -279,6 +287,24 @@ class TestDumpImporter(DatabaseMixin, FileSystemMixin, unittest.TestCase):
self.assertFileNotInDb(TestDumpImporter.GENERATED_FILE_DIGEST) self.assertFileNotInDb(TestDumpImporter.GENERATED_FILE_DIGEST)
self.assertFileNotInDb(TestDumpImporter.NON_GENERATED_FILE_DIGEST) self.assertFileNotInDb(TestDumpImporter.NON_GENERATED_FILE_DIGEST)
def test_import_skip_users(self):
"""Test importing everything but not the users."""
self.write_dump(TestDumpImporter.DUMP)
self.write_files(TestDumpImporter.FILES)
self.assertTrue(self.do_import(skip_users=True))
self.assertContestInDb("contestname", "contest description 你好",
[("taskname", "task title")],
[])
self.assertContestInDb(
self.other_contest_name, self.other_contest_description, [], [])
self.assertUserNotInDb("username")
self.assertFileNotInDb(TestDumpImporter.GENERATED_FILE_DIGEST)
self.assertFileNotInDb(TestDumpImporter.NON_GENERATED_FILE_DIGEST)
def test_import_old(self): def test_import_old(self):
"""Test importing an old dump. """Test importing an old dump.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment