Skip to content
Snippets Groups Projects
Select Git revision
  • 0e78c590f981f3cea39a2aa3d6d00a91266cc373
  • devel default
  • master
  • fo
  • jirka/typing
  • fo-base
  • mj/submit-images
  • jk/issue-96
  • jk/issue-196
  • honza/add-contestant
  • honza/mr7
  • honza/mrf
  • honza/mrd
  • honza/mra
  • honza/mr6
  • honza/submit-images
  • honza/kolo-vs-soutez
  • jh-stress-test-wip
  • shorten-schools
19 results

org_round.py

Blame
  • communication.py 7.30 KiB
    import asyncio
    import json
    import sys, os
    import cbor2
    import functools
    import traceback
    from datetime import datetime, timedelta, timezone
    
    def eprint(*args):
        """Print ``args`` to standard error and flush the stream right away."""
        stream = sys.stderr
        print(*args, file=stream)
        stream.flush()
    
    async def print_exceptions(corutine):
        """Await ``corutine`` and return its result.

        If it raises an ``Exception``, the formatted traceback is written via
        ``eprint`` and the same exception is raised again to the caller.
        """
        try:
            outcome = await corutine
        except Exception as exc:
            # Report first, then let the failure propagate unchanged.
            eprint(traceback.format_exc())
            raise exc
        else:
            return outcome
    
    async def connect_stdin_stdout():
        """Wrap the process' stdin/stdout in asyncio stream objects.

        Returns:
            A ``(reader, writer)`` pair: an ``asyncio.StreamReader`` connected
            to ``sys.stdin`` and an ``asyncio.StreamWriter`` connected to
            ``sys.stdout``.
        """
        # We are inside a coroutine, so the running loop is the one to use;
        # this also matches how the rest of the file obtains its loop.
        loop = asyncio.get_running_loop()
        reader = asyncio.StreamReader()
        protocol = asyncio.StreamReaderProtocol(reader)
        await loop.connect_read_pipe(lambda: protocol, sys.stdin)
        w_transport, w_protocol = await loop.connect_write_pipe(asyncio.streams.FlowControlMixin, sys.stdout)
        writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)
        return reader, writer
    
    class Socket:
        """Abstract byte transport; subclasses provide ``write`` and ``read``."""

        async def write(self, data):
            """Send ``data`` (bytes) to the peer. Must be overridden."""
            raise NotImplementedError

        async def read(self):
            """Return the next chunk of bytes from the peer. Must be overridden."""
            raise NotImplementedError
    
    class SSHRunSocket(Socket):
        """Socket that talks to a shell subprocess through its stdin/stdout pipes."""

        async def connect(self, cmd):
            """Start ``cmd`` in a shell with piped stdin/stdout and return ``self``."""
            pipe = asyncio.subprocess.PIPE
            self.proc = await asyncio.create_subprocess_shell(cmd, stdin=pipe, stdout=pipe)
            return self

        async def write(self, data):
            """Queue ``data`` on the subprocess' stdin."""
            # NOTE(review): the write is not followed by ``drain()``, so no
            # back-pressure is applied here.
            child_stdin = self.proc.stdin
            child_stdin.write(data)

        async def read(self):
            """Read up to 1000 bytes from the subprocess' stdout."""
            child_stdout = self.proc.stdout
            return await child_stdout.read(1000)
    
    class StdInOutSocket(Socket):
        """Socket that uses this process' own stdin (read) and stdout (write)."""

        async def connect(self):
            """Attach asyncio streams to stdin/stdout and return ``self``."""
            streams = await connect_stdin_stdout()
            self.reader, self.writer = streams
            return self

        async def write(self, data):
            """Queue ``data`` on the stdout stream writer."""
            out = self.writer
            out.write(data)

        async def read(self):
            """Read up to 1000 bytes from the stdin stream reader."""
            source = self.reader
            return await source.read(1000)
    
    
    class MsgParser:
        """Reassembles framed messages from the chunks a socket delivers.

        A frame is a 5-character head length, a 12-character data length
        (both parsed with ``int``), then the head bytes and the data bytes.
        """

        def __init__(self, socket):
            self.socket = socket
            # Chunks received but not yet consumed, and their total size.
            self.buffer = []
            self.len = 0

        async def get_bytes(self, l):
            """Return exactly ``l`` bytes, reading from the socket as needed.

            Raises:
                EOFError: the socket returned an empty chunk before ``l``
                    bytes were available.
            """
            while self.len < l:
                chunk = await self.socket.read()
                self.buffer.append(chunk)
                if not len(chunk):
                    raise EOFError()
                self.len += len(chunk)
            joined = b''.join(self.buffer)
            wanted, leftover = joined[:l], joined[l:]
            # Keep whatever was read beyond the request for the next call.
            self.buffer = [leftover]
            self.len -= l
            return wanted

        async def get_msg(self):
            """Read one complete frame and return ``(head, data)`` as bytes."""
            head_size = int(await self.get_bytes(5))
            data_size = int(await self.get_bytes(12))
            head = await self.get_bytes(head_size)
            data = await self.get_bytes(data_size)
            return head, data
    
    def int_to_str_len(val, l):
        """Return ``str(val)`` left-padded with spaces to exactly ``l`` characters.

        Asserts that the textual form of ``val`` fits into ``l`` characters.
        """
        text = str(val)
        assert len(text) <= l
        return text.rjust(l)
    
    def msg(head, data):
        """Build one wire frame from ``head`` and ``data``.

        The frame is the length of ``head`` in a 5-character field, the length
        of ``data`` in a 12-character field, then ``head`` and ``data``
        themselves. ``head`` must be non-empty.
        """
        assert head
        head_len_field = int_to_str_len(len(head), 5).encode('utf-8')
        data_len_field = int_to_str_len(len(data), 12).encode('utf-8')
        return head_len_field + data_len_field + head + data
    
    class WaitingQuestion:
        """Holds the future that will receive the answer to one question."""

        def __init__(self, id):
            # ``id`` is accepted but not stored; only the future is kept.
            loop = asyncio.get_running_loop()
            self.future = loop.create_future()
    
    class Client():
        """Question/answer client on top of a ``Socket``.

        Every question gets a numeric id; a background task reads answers
        from the socket and resolves the future of the matching question.
        """

        def __init__(self, socket):
            self.socket = socket
            self.msg_parser = MsgParser(socket)
            # Questions sent but not yet answered, keyed by question id.
            self.waiting_questions = {}
            self.input_task = asyncio.create_task(print_exceptions(self.input_task_f()))
            self.id_alocator = 0

        async def input_task_f(self):
            """Read answers forever and hand each one to its waiting question.

            If reading fails (for instance with ``EOFError``), every question
            that is still pending is failed with that exception, so callers
            of ``question`` do not wait forever; the exception is then
            re-raised.
            """
            try:
                while True:
                    head_raw, data_raw = await self.msg_parser.get_msg()
                    head = json.loads(head_raw)
                    msg_id = head["id"]
                    assert msg_id in self.waiting_questions
                    # Remove the entry so answered questions do not accumulate.
                    q = self.waiting_questions.pop(msg_id)
                    # The asker may have been cancelled in the meantime.
                    if not q.future.done():
                        q.future.set_result((head, data_raw))
            except Exception as exc:
                pending = list(self.waiting_questions.values())
                self.waiting_questions.clear()
                for q in pending:
                    if not q.future.done():
                        q.future.set_exception(exc)
                raise

        async def question(self, head, data):
            """Send a question and wait for its answer.

            Args:
                head: dict merged into the JSON head (an ``id`` key is added).
                data: raw bytes payload.

            Returns:
                ``(answer_head, answer_data_raw)`` as delivered by the reader task.
            """
            msg_id = self.id_alocator
            self.id_alocator += 1
            head = {'id': msg_id, **head}

            self.waiting_questions[msg_id] = q = WaitingQuestion(msg_id)
            await self.socket.write(msg(json.dumps(head).encode('utf-8'), data))
            return await q.future
    
    class Server():
        """Reads questions from a ``Socket`` and answers them via ``callback``.

        ``callback(in_head, in_data_raw, connection_control)`` is awaited for
        each question and must return the raw answer bytes.
        """

        def __init__(self, socket, callback):
            self.socket = socket
            self.msg_parser = MsgParser(socket)
            self.callback = callback
            self.id_alocator = 0
            # Strong references to in-flight handler tasks; the event loop
            # itself only keeps weak references to tasks.
            self._tasks = set()

        async def question_task_f(self, in_head, in_data_raw):
            """Handle one question and write the answer with the same id."""
            # The third argument is a placeholder (Ellipsis) in this version.
            out_data_raw = await self.callback(in_head, in_data_raw, ...)
            out_head = {'id': in_head['id']}
            await self.socket.write(msg(json.dumps(out_head).encode('utf-8'), out_data_raw))

        async def run(self):
            """Read questions forever, spawning one handler task per question."""
            while True:
                head_raw, data_raw = await self.msg_parser.get_msg()
                head = json.loads(head_raw)
                task = asyncio.create_task(print_exceptions(self.question_task_f(head, data_raw)))
                # Kept for backward compatibility: the most recent handler task.
                self.input_task = task
                self._tasks.add(task)
                task.add_done_callback(self._tasks.discard)
    
    # tzinfo of the local timezone, evaluated once when the module is imported.
    local_timezone = (
        datetime.now(timezone.utc)
        .astimezone()
        .tzinfo
    )
    
    def cbor_dump(x):
        """Serialize ``x`` to CBOR bytes, passing ``local_timezone`` to ``cbor2.dumps``."""
        encoded = cbor2.dumps(x, timezone=local_timezone)
        return encoded
    
    class FuncCaller():
        """Base class whose methods can be executed locally or on a remote peer.

        With ``is_server=True`` the instance answers incoming calls by looking
        up the requested method on itself; otherwise it forwards calls through
        a ``Client``. Arguments and return values travel as CBOR.
        """

        def __init__(self, socket, is_server=False):
            self._socket_ = socket
            self._is_server_ = is_server
            if is_server:
                self._server_ = Server(socket, self._server_caller_)
            else:
                self._client_ = Client(socket)

        async def _run_(self, socket):
            """Serve incoming calls on ``socket`` until reading fails."""
            server = Server(socket, self._server_caller_)
            await server.run()

        async def _server_caller_(self, in_head, in_data_raw, connection_controll):
            """Execute the call described by ``in_head``/``in_data_raw``.

            Args:
                in_head: dict with a ``func_name`` entry.
                in_data_raw: CBOR-encoded ``{'args': [...], 'kwargs': {...}}``.
                connection_controll: unused here.

            Returns:
                CBOR-encoded ``{'return': result}``.

            Raises:
                ValueError: ``func_name`` is not a non-empty string or starts
                    with an underscore.
            """
            func_name = in_head["func_name"]
            # Explicit check (not ``assert``) so it still runs under ``-O``:
            # the name comes from the peer and must never reach private members.
            if not isinstance(func_name, str) or not func_name or func_name.startswith('_'):
                raise ValueError(f"refusing remote call to {func_name!r}")
            # NOTE(review): any attribute without a leading underscore can be
            # requested by the peer, not only ``server_exec`` methods.
            f = getattr(self, func_name)
            in_data = cbor2.loads(in_data_raw)
            r = await f(*in_data['args'], **in_data['kwargs'])
            return cbor_dump({'return': r})

        async def _client_caller_(self, func_name, *args, **kwargs):
            """Ask the peer to run ``func_name(*args, **kwargs)`` and return its result."""
            in_head, in_data_raw = await self._client_.question({'func_name': func_name}, cbor_dump({'args': args, 'kwargs': kwargs}))
            in_data = cbor2.loads(in_data_raw)
            return in_data["return"]
            
    
    
    
    def server_exec():
        """Decorator factory for methods that must execute on the server side.

        The wrapped coroutine runs the original method when the instance's
        ``_is_server_`` is true; otherwise it forwards the call (by method
        name) through ``_client_caller_``.
        """
        def decorator(f):
            @functools.wraps(f)
            async def dispatch(self, *args, **kwargs):
                if not self._is_server_:
                    # Client side: ship the call to the peer.
                    return await self._client_caller_(f.__name__, *args, **kwargs)
                return await f(self, *args, **kwargs)
            return dispatch
        return decorator
    
    def path_to_dt(path):
        """Parse a ``-``/``/``-separated path of integers into a datetime.

        The numeric fields are passed positionally to ``datetime`` and the
        result carries ``local_timezone`` as its tzinfo.
        """
        normalized = path.replace('/', '-')
        fields = [int(piece) for piece in normalized.split('-') if piece]
        return datetime(*fields, tzinfo=local_timezone)
    
    def dt_to_path(dt):
        """Format ``dt`` as the relative path ``YYYY-MM-DD/HH/MM-SS``."""
        layout = "%Y-%m-%d/%H/%M-%S"
        return dt.strftime(layout)
    
    class DownloadServer(FuncCaller):
        """Remote access to files stored under ``data/realtime/YYYY-MM-DD/HH/MM-SS``."""

        @server_exec()
        async def list_realtime_data(self, date_from: datetime, date_to: datetime):
            """List the timestamps of stored entries within ``[date_from, date_to]``.

            Day and hour directories that cannot overlap the interval are
            skipped without being listed.

            Returns:
                Datetimes (tzinfo ``local_timezone``) in sorted directory order.
            """
            out = []
            root = "data/realtime"
            for day in sorted(os.listdir(root)):
                day_dt = path_to_dt(day)
                # Upper bound is inclusive, like the leaf test below; an
                # exclusive bound would drop an entry exactly at ``date_to``.
                if not (day_dt + timedelta(days=1) >= date_from and day_dt <= date_to):
                    continue
                for hour in sorted(os.listdir(root + "/" + day)):
                    hour_rel = day + "/" + hour
                    hour_dt = path_to_dt(hour_rel)
                    if not (hour_dt + timedelta(hours=1) >= date_from and hour_dt <= date_to):
                        continue
                    for minute_second in sorted(os.listdir(root + "/" + hour_rel)):
                        leaf_dt = path_to_dt(hour_rel + "/" + minute_second)
                        if date_from <= leaf_dt <= date_to:
                            out.append(leaf_dt)

            return out

        @server_exec()
        async def get_data(self, dt: datetime):
            """Return ``{filename: bytes}`` for every file stored for ``dt``."""
            path = "data/realtime/" + dt_to_path(dt)
            out = {}
            for filename in os.listdir(path):
                with open(path + '/' + filename, "rb") as f:
                    out[filename] = f.read()
            return out