Skip to content
Snippets Groups Projects
Select Git revision
  • 4b2f6d8ef7519d5eded20fc3f24659c7bb152be0
  • devel default
  • master
  • fo
  • jirka/typing
  • fo-base
  • mj/submit-images
  • jk/issue-96
  • jk/issue-196
  • honza/add-contestant
  • honza/mr7
  • honza/mrf
  • honza/mrd
  • honza/mra
  • honza/mr6
  • honza/submit-images
  • honza/kolo-vs-soutez
  • jh-stress-test-wip
  • shorten-schools
19 results

protocols.py

Blame
  • communication.py NaN GiB
    import asyncio
    import json
    import sys, os
    import cbor2
    import functools
    import traceback
    import datetime
    import time
    from pathlib import Path
    
    from utils import *
    
    # Module-wide switch: when true, FuncCaller._client_caller_ logs every
    # remote call (before and after it completes) to stderr.
    debug = False
    
    def eprint(*args):
        """Print ``args`` to standard error and flush the stream right away."""
        stream = sys.stderr
        print(*args, file=stream, flush=True)
    
    async def print_exceptions(corutine):
        """Await ``corutine`` and hand back its result.

        If it raises an ``Exception``, the formatted traceback is written to
        stderr first and the same exception is then re-raised.
        """
        try:
            outcome = await corutine
        except Exception as exc:
            eprint(traceback.format_exc())
            raise exc
        return outcome
    
    async def connect_stdin_stdout():
        """Wrap the process's stdin/stdout in asyncio streams.

        Returns a ``(reader, writer)`` pair: the reader is fed from ``sys.stdin``
        and the writer sends to ``sys.stdout``.
        """
        event_loop = asyncio.get_event_loop()

        # Reading side: a StreamReader fed through its protocol from stdin.
        stdin_reader = asyncio.StreamReader()
        stdin_protocol = asyncio.StreamReaderProtocol(stdin_reader)
        await event_loop.connect_read_pipe(lambda: stdin_protocol, sys.stdin)

        # Writing side: a bare flow-control protocol attached to stdout.
        out_transport, out_protocol = await event_loop.connect_write_pipe(
            asyncio.streams.FlowControlMixin, sys.stdout)
        stdout_writer = asyncio.StreamWriter(out_transport, out_protocol, stdin_reader, event_loop)
        return stdin_reader, stdout_writer
    
    class Socket():
        """Abstract byte transport; concrete sockets override ``write`` and ``read``."""

        async def write(self, data):
            """Send ``data`` over the transport.

            Raises:
                NotImplementedError: always, in this base class.
            """
            raise NotImplementedError

        async def read(self):
            """Return the next chunk of received bytes.

            Raises:
                NotImplementedError: always, in this base class.
            """
            raise NotImplementedError
    
    class SSHRunSocket(Socket):
        """Socket that talks to a shell command through its stdin/stdout pipes."""

        async def connect(self, cmd):
            """Start ``cmd`` via the shell with piped stdin and stdout; return ``self``."""
            pipe = asyncio.subprocess.PIPE
            self.proc = await asyncio.create_subprocess_shell(cmd, stdin=pipe, stdout=pipe)
            return self

        async def write(self, data):
            """Queue ``data`` on the child's stdin (no drain is awaited)."""
            child_stdin = self.proc.stdin
            child_stdin.write(data)

        async def read(self):
            """Return up to 1000 bytes read from the child's stdout."""
            chunk = await self.proc.stdout.read(1000)
            return chunk
    
    class AsincioStreamSocket(Socket):
        """Socket built on an existing asyncio ``(reader, writer)`` stream pair."""

        async def connect(self, reader, writer):
            """Adopt the given stream pair and return ``self``."""
            self.reader, self.writer = reader, writer
            return self

        async def write(self, data):
            """Queue ``data`` on the stream writer (no drain is awaited)."""
            stream_writer = self.writer
            stream_writer.write(data)

        async def read(self):
            """Return up to 1000 bytes from the stream reader."""
            chunk = await self.reader.read(1000)
            return chunk
    
    class StdInOutSocket(AsincioStreamSocket):
        """Stream socket bound to this process's stdin and stdout."""

        async def connect(self):
            """Attach to stdin/stdout via ``connect_stdin_stdout`` and return ``self``."""
            streams = await connect_stdin_stdout()
            self.reader, self.writer = streams
            return self
    
    class UnixSocket(AsincioStreamSocket):
        """Stream socket connected to a Unix domain socket."""

        async def connect(self, path):
            """Open a connection to the Unix socket at ``path`` and return ``self``."""
            streams = await asyncio.open_unix_connection(path)
            self.reader, self.writer = streams
            return self
    
    
    
    class MsgParser():
        """Reassembles length-prefixed messages from the chunks a socket delivers."""

        def __init__(self, socket):
            self.socket = socket
            self.buffer = []   # received chunks not yet consumed
            self.len = 0       # total number of bytes held in ``buffer``

        async def get_bytes(self, l):
            """Return exactly ``l`` bytes, reading from the socket as needed.

            Raises:
                EOFError: when the socket returns an empty chunk before
                    ``l`` bytes are available.
            """
            while self.len < l:
                chunk = await self.socket.read()
                self.buffer.append(chunk)
                if len(chunk) == 0:
                    raise EOFError()
                self.len += len(chunk)
            joined = b''.join(self.buffer)
            wanted, leftover = joined[:l], joined[l:]
            # Keep the unread tail as the only buffered chunk.
            self.buffer = [leftover]
            self.len -= l
            return wanted

        async def get_msg(self):
            """Read one message and return its raw ``(head, data)`` byte strings.

            The frame starts with a 5-byte head length and a 12-byte data length,
            both parsed with ``int()``, followed by the head and the data.
            """
            head_size = int(await self.get_bytes(5))
            data_size = int(await self.get_bytes(12))
            head = await self.get_bytes(head_size)
            data = await self.get_bytes(data_size)
            return head, data
    
    def int_to_str_len(val, l):
        """Return ``str(val)`` left-padded with spaces to exactly ``l`` characters.

        Asserts that the text fits into ``l`` characters.
        """
        text = str(val)
        assert len(text) <= l
        return text.rjust(l)
    
    def msg(head, data):
        """Frame ``head`` and ``data`` (both bytes) as one wire message.

        Layout: 5-character head length, 12-character data length (both
        space-padded decimal), then the head bytes and the data bytes.
        ``head`` must be non-empty.
        """
        assert head
        head_len_field = int_to_str_len(len(head), 5).encode('utf-8')
        data_len_field = int_to_str_len(len(data), 12).encode('utf-8')
        return head_len_field + data_len_field + head + data
    
    class WaitingQuestion():
        """Holds the future that will receive the answer to one outstanding question."""

        def __init__(self, id):
            # ``id`` is accepted but not stored; the owner keys instances by it.
            loop = asyncio.get_running_loop()
            self.future = loop.create_future()
    
    class Client():
        """Client side of the question/answer protocol multiplexed over one socket.

        Each question gets a numeric id; a background task reads answers from
        the socket and resolves the future of the matching question. Must be
        constructed while an event loop is running, because it starts that task.
        """

        def __init__(self, socket):
            self.socket = socket
            self.msg_parser = MsgParser(socket)
            self.waiting_questions = {}   # id -> WaitingQuestion awaiting its answer
            self.input_task = asyncio.create_task(print_exceptions(self.input_task_f()))
            self.id_alocator = 0          # next question id to hand out

        def _fail_waiting_questions(self, cause):
            """Resolve every pending question because the reader is stopping.

            With ``cause`` set, each future gets a ``ConnectionError`` chained to
            it; with ``cause`` of ``None`` (reader cancelled) the futures are
            cancelled.
            """
            pending, self.waiting_questions = self.waiting_questions, {}
            for q in pending.values():
                if q.future.done():
                    continue
                if cause is None:
                    q.future.cancel()
                else:
                    err = ConnectionError("connection reader stopped before the answer arrived")
                    err.__cause__ = cause
                    q.future.set_exception(err)

        async def input_task_f(self):
            """Read answers forever and deliver each one to its waiting question.

            When the loop ends for any reason (EOF, malformed frame, unknown id,
            cancellation), all still-pending questions are failed so their
            callers do not wait forever; the original exception is re-raised.
            """
            try:
                while True:
                    head_raw, data_raw = await self.msg_parser.get_msg()
                    head = json.loads(head_raw)
                    id = head["id"]
                    assert id in self.waiting_questions
                    q = self.waiting_questions.pop(id)
                    try:
                        q.future.set_result((head, data_raw))
                    except asyncio.exceptions.InvalidStateError:
                        pass # Probably task canceled
            except asyncio.CancelledError:
                self._fail_waiting_questions(None)
                raise
            except Exception as e:
                self._fail_waiting_questions(e)
                raise

        async def question(self, head, data):
            """Send one question and wait for its answer.

            Args:
                head: dict of JSON-serialisable header fields; an ``id`` is added.
                data: raw payload bytes.

            Returns:
                ``(answer_head, answer_data_raw)`` as delivered by the reader task.

            Raises:
                ConnectionError: if the reader task has already stopped, or stops
                    before the answer arrives.
            """
            if self.input_task.done():
                # Nobody would ever resolve the future; fail fast instead of hanging.
                raise ConnectionError("connection reader is no longer running")
            question_id = self.id_alocator
            self.id_alocator += 1
            head = {'id': question_id, **head}

            self.waiting_questions[question_id] = q = WaitingQuestion(question_id)
            await self.socket.write(msg(json.dumps(head).encode('utf-8'), data))
            return await q.future
    
    class Server():
        """Server side of the question/answer protocol on one socket.

        Every incoming message is handled in its own task by awaiting
        ``callback(head, data_raw, ...)``; the returned bytes are sent back
        under the same ``id``.
        """

        def __init__(self, socket, callback):
            self.socket = socket
            self.msg_parser = MsgParser(socket)
            self.callback = callback
            self.id_alocator = 0
            # Strong references to in-flight handler tasks. The event loop keeps
            # only weak references, so an unreferenced task may be collected
            # before it finishes.
            self._tasks = set()

        async def question_task_f(self, in_head, in_data_raw):
            """Handle one question: run the callback and write the answer frame."""
            # The third argument is a placeholder (Ellipsis) for the callback's
            # connection-control parameter.
            out_data_raw = await self.callback(in_head, in_data_raw, ...)
            out_head = {'id': in_head['id']}
            await self.socket.write(msg(json.dumps(out_head).encode('utf-8'), out_data_raw))

        async def run(self):
            """Read questions until the parser raises, spawning a handler task for each."""
            while True:
                head_raw, data_raw = await self.msg_parser.get_msg()
                head = json.loads(head_raw)
                task = asyncio.create_task(print_exceptions(self.question_task_f(head, data_raw)))
                self._tasks.add(task)
                task.add_done_callback(self._tasks.discard)
                # Kept for backward compatibility: always the most recent handler task.
                self.input_task = task
    
    
    def cbor_dump(x):
        """Serialise ``x`` to CBOR bytes, passing ``local_timezone`` as cbor2's ``timezone``."""
        encoded = cbor2.dumps(x, timezone=local_timezone)
        return encoded
    
    class FuncCaller():
        """Base class for objects whose methods can be called locally or over a socket.

        In server mode, methods run in-process and incoming questions are
        dispatched to them by name. In client mode, calls are forwarded through
        a ``Client``.
        """

        def __init__(self, socket=None, is_server=False):
            """Set up server or client mode.

            Args:
                socket: transport to use; ``None`` forces server mode without
                    attaching a ``Server`` (one can be attached later via ``_run_``).
                is_server: run methods locally and serve ``socket`` when true;
                    otherwise forward calls through a ``Client`` on ``socket``.
            """
            if socket is None:
                is_server = True
            self._socket_ = socket
            self._is_server_ = is_server
            if is_server:
                if socket is not None:
                    self._server_ = Server(socket, self._server_caller_)
            else:
                self._client_ = Client(socket)

        async def _run_(self, socket):
            """Serve questions arriving on ``socket`` until its server loop ends."""
            server = Server(socket, self._server_caller_)
            await server.run()

        async def _server_caller_(self, in_head, in_data_raw, connection_controll):
            """Dispatch one incoming question to the method named in its head.

            Names starting with an underscore are rejected, so private helpers
            are not reachable. The payload is a CBOR map with ``args`` and
            ``kwargs``; the answer is a CBOR map ``{'return': result}``.
            ``connection_controll`` is accepted but unused here.
            """
            func_name = in_head["func_name"]
            # NOTE(review): this check is an assert and disappears under ``python -O``;
            # any non-underscore attribute of the object remains callable remotely.
            assert isinstance(func_name, str) and not func_name.startswith('_')
            f = getattr(self, func_name)
            in_data = cbor2.loads(in_data_raw)
            r = await f(*in_data['args'], **in_data['kwargs'])
            return cbor_dump({'return': r})

        async def _client_caller_(self, func_name, *args, **kwargs):
            """Ask the remote side to run ``func_name(*args, **kwargs)`` and return its result."""
            if debug:
                eprint("ASK:", func_name, args, kwargs)
            in_head, in_data_raw = await self._client_.question({'func_name': func_name}, cbor_dump({'args': args, 'kwargs': kwargs}))
            if debug:
                eprint("DONE", func_name, args, kwargs)
            in_data = cbor2.loads(in_data_raw)
            return in_data["return"]
    
    
    
    def server_exec():
        """Build a decorator marking an async method as "executed on the server".

        On an object whose ``_is_server_`` is true the wrapped method runs
        directly. Otherwise the call is forwarded to ``_client_caller_`` under
        the method's name with the same arguments.
        """
        def decorator(f):
            remote_name = f.__name__

            @functools.wraps(f)
            async def dispatch(self, *args, **kwargs):
                if not self._is_server_:
                    return await self._client_caller_(remote_name, *args, **kwargs)
                return await f(self, *args, **kwargs)

            return dispatch
        return decorator
    
    def path_to_dt(path):
        """Parse a relative data path such as ``Y-m-d``, ``Y-m-d/H`` or ``Y-m-d/H/M-S``.

        Anything after the first ``.`` is dropped. The remaining ``/``- and
        ``-``-separated numbers become the positional arguments of
        ``datetime.datetime``, with ``local_timezone`` as tzinfo.
        """
        stem = path.split('.')[0]
        fields = [int(part) for part in stem.replace('/', '-').split('-') if part]
        return datetime.datetime(*fields, tzinfo=local_timezone)
    
    def dt_to_path(dt):
        """Format ``dt`` as the relative path ``YYYY-MM-DD/HH/MM-SS``."""
        layout = "%Y-%m-%d/%H/%M-%S"
        return dt.strftime(layout)
    
    def date_to_path(dt):
        """Format ``dt`` as the day directory name ``YYYY-MM-DD``."""
        layout = "%Y-%m-%d"
        return dt.strftime(layout)
    
    def dt_intersect(dt_from: datetime.datetime, dt_to: datetime.datetime, dt: datetime.datetime, delta: datetime.timedelta):
        """Test whether the range ``dt_from .. dt_to`` overlaps the bucket starting at ``dt``.

        The range is closed on both ends; the bucket is half-open, covering
        ``dt`` up to but not including ``dt + delta``.

        Returns:
            True when at least one instant lies in both.
        """
        # The bucket's end is exclusive, so it must lie strictly after dt_from.
        # Its start is inclusive, so a bucket starting exactly at dt_to still
        # contains dt_to.
        return dt + delta > dt_from and dt <= dt_to
    
    class MainServer(FuncCaller):
        """Serves stored realtime snapshots and GTFS files from the ``data/`` tree.

        Methods decorated with ``server_exec`` run locally in server mode and
        are forwarded over the socket in client mode. Underscore-prefixed and
        undecorated helpers are meant for server-side use.
        """

        # Lazily created handles to the helper servers (see the getters below).
        _download_server = None
        _preprocess_server = None

        async def _get_download_server(self):
            """Return the cached ``DownloadServer`` handle, connecting on first use."""
            if self._download_server is None:
                import download_server
                s = await UnixSocket().connect("sockets/download_server")
                self._download_server = download_server.DownloadServer(s)
            return self._download_server

        async def _get_preprocess_server(self):
            """Return the cached ``PreprocessServer`` handle, connecting on first use."""
            if self._preprocess_server is None:
                import preprocess
                s = await UnixSocket().connect("sockets/preprocess_server")
                self._preprocess_server = preprocess.PreprocessServer(s)
            return self._preprocess_server

        async def _tree_walker(self, condition, worker, reverse=False):
            """Walk ``data/realtime/Y-m-d/H/M-S`` entries in sorted order.

            ``condition(dt, delta)`` is awaited for every day directory (delta of
            one day) and every hour directory (delta of one hour); a false
            result prunes that subtree. ``worker(dt)`` is awaited for every
            leaf entry of the hour directories that are kept.

            Args:
                condition: async predicate ``(bucket_start, bucket_length)``.
                worker: async callable receiving each leaf timestamp.
                reverse: iterate every level in descending order when true.
            """
            d = "data/realtime"
            for d_Y_m_d in sorted(os.listdir(d), reverse=reverse):
                path = d_Y_m_d
                dt = path_to_dt(path)
                if await condition(dt, datetime.timedelta(days=1)):
                    for d_H in sorted(os.listdir(d+"/"+path), reverse=reverse):
                        path = d_Y_m_d+"/"+d_H
                        dt = path_to_dt(path)
                        if await condition(dt, datetime.timedelta(hours=1)):
                            for d_M_S in sorted(os.listdir(d+"/"+path), reverse=reverse):
                                path = d_Y_m_d+"/"+d_H+"/"+d_M_S
                                dt = path_to_dt(path)
                                await worker(dt)

        async def _first_realtime_data(self, condition, accept, reverse=False):
            """Return the first leaf timestamp for which ``accept(dt)`` is true, else ``None``.

            The walk is aborted as soon as a match is found by raising a private
            exception out of the worker.
            """
            class Found(Exception):
                def __init__(self, val):
                    self.val = val

            async def worker(dt):
                if accept(dt):
                    raise Found(dt)

            try:
                await self._tree_walker(condition, worker, reverse=reverse)
            except Found as r:
                return r.val
            return None

        @server_exec()
        async def list_realtime_data(self, dt_from: datetime.datetime, dt_to: datetime.datetime):
            """Return all stored timestamps ``dt`` with ``dt_from <= dt <= dt_to``, ascending."""
            out = []
            async def condition(dt, delta):
                return dt_intersect(dt_from, dt_to, dt, delta)
            async def worker(dt):
                if dt >= dt_from and dt <= dt_to:
                    out.append(dt)
            await self._tree_walker(condition, worker)
            return out

        async def list_next_realtime_data(self, dt_from: datetime.datetime):
            """Return the earliest stored timestamp strictly after ``dt_from``, or ``None``."""
            async def condition(dt, delta):
                # dt_pinf is not defined in this file; presumably a "+infinity"
                # datetime provided by ``utils``.
                return dt_intersect(dt_from, dt_pinf, dt, delta)
            return await self._first_realtime_data(condition, lambda dt: dt > dt_from)

        async def list_prev_realtime_data(self, dt_from: datetime.datetime):
            """Return the latest stored timestamp strictly before ``dt_from``, or ``None``."""
            async def condition(dt, delta):
                # dt_minf is not defined in this file; presumably a "-infinity"
                # datetime provided by ``utils``.
                return dt_intersect(dt_minf, dt_from, dt, delta)
            return await self._first_realtime_data(condition, lambda dt: dt < dt_from, reverse=True)

        @server_exec()
        async def get_data(self, dt: datetime.datetime):
            """Return the stored snapshot for exactly ``dt``.

            Returns:
                The bytes of ``<path>.json.zst`` when that file exists; otherwise
                a dict mapping each file name inside the ``<path>`` directory to
                its bytes.
            """
            path = "data/realtime/"+dt_to_path(dt)
            if Path(path+".json.zst").exists():
                with open(path+".json.zst", "rb") as f:
                    return f.read()
            out = {}
            for filename in os.listdir(path):
                with open(path+'/'+filename, "rb") as f:
                    out[filename] = f.read()
            return out

        @server_exec()
        async def get_next_data(self, dt: datetime.datetime):
            """Return ``(next_dt, data)`` for the first snapshot after ``dt``, or ``None``."""
            next_dt = await self.list_next_realtime_data(dt)
            if next_dt is None:
                return None
            # Load the snapshot that was found, not the one at the query time.
            return next_dt, await self.get_data(next_dt)

        @server_exec()
        async def get_prev_data(self, dt: datetime.datetime):
            """Return ``(prev_dt, data)`` for the last snapshot before ``dt``, or ``None``."""
            prev_dt = await self.list_prev_realtime_data(dt)
            if prev_dt is None:
                return None
            # Load the snapshot that was found, not the one at the query time.
            return prev_dt, await self.get_data(prev_dt)

        @server_exec()
        async def get_preprocessed_data(self, dt: datetime.datetime, route_id: str):
            """Return per-route preprocessed data for the hour containing ``dt``.

            If the hour directory under ``data/realtime_by_route`` does not exist,
            the request is delegated to the preprocess server for the span
            ``dt .. dt + 1 hour``. Otherwise returns ``(payload, source_timestamps)``,
            where ``payload`` is the bytes of ``<route_id>.json.zst`` or the
            string ``'{}'`` when the route has no file.
            """
            assert all(i.isalnum() for i in route_id)
            path = "data/realtime_by_route/"+dt.strftime("%Y-%m-%d/%H")
            if not os.path.exists(path):
                return await (await self._get_preprocess_server()).get_preprocessed_data(dt, dt+datetime.timedelta(hours=1), route_id)
            with open(path+"/source_timestamps") as f:
                source_timestamps = [ datetime.datetime.fromisoformat(x.strip()) for x in f ]
            if os.path.exists(path+"/"+route_id+".json.zst"):
                with open(path+"/"+route_id+".json.zst", "rb") as f:
                    return f.read(), source_timestamps
            else:
                # NOTE(review): this placeholder is a str while the file branch
                # returns bytes; confirm callers rely on the difference.
                return '{}', source_timestamps

        def _read_gtfs(self, dt, filename):
            """Read ``filename`` from the GTFS directory of ``dt``'s day, falling back to the day before.

            Returns the file's bytes, or ``None`` (after logging to stderr) when
            neither day has the file.
            """
            path = "data/gtfs/"+date_to_path(dt)+"/"+filename
            try:
                with open(path, "rb") as f:
                    return f.read()
            except FileNotFoundError:
                pass
            path2 = "data/gtfs/"+date_to_path(dt-datetime.timedelta(days=1))+"/"+filename
            try:
                with open(path2, "rb") as f:
                    return f.read()
            except FileNotFoundError:
                eprint(f"GTFS: {path} no such file")
                return None

        @server_exec()
        async def gtfs_get_file(self, dt: datetime.datetime, filename: str):
            """Return the bytes of a GTFS file for ``dt``'s day (or the previous day), or ``None``.

            ``filename`` may contain only alphanumerics, ``_``, ``.`` and ``/``.
            It may not start or end with ``.`` and may not contain ``./`` or
            ``/.``, which rules out ``..`` path components.
            """
            # NOTE(review): these checks are asserts and vanish under ``python -O``.
            assert all(i in "_./" or i.isalnum() for i in filename)
            assert "./" not in filename and "/." not in filename and not filename.startswith('.') and not filename.endswith('.')
            return self._read_gtfs(dt, filename)

        @server_exec()
        async def gtfs_get_stop_times(self, dt: datetime.datetime, trip_filter=None, route_filter=None):
            """Return ``stop_times.txt`` for ``dt``, optionally filtered, or ``None`` if missing.

            The first line is kept as the header. A row is kept when its first
            comma-separated field equals ``trip_filter`` (if given) and the part
            of that field before the first ``_`` equals ``route_filter`` (if given).
            The element after the last newline is discarded.
            """
            s = self._read_gtfs(dt, "stop_times.txt")
            if s is None:
                return None

            head, *data, _ = s.split(b'\n')
            out_data = []
            trip_filter_encoded = trip_filter.encode("utf-8") if trip_filter else None
            route_filter_encoded = route_filter.encode("utf-8") if route_filter else None
            for x in data:
                y = x.split(b',')
                if trip_filter and y[0]!=trip_filter_encoded:
                    continue
                if route_filter and y[0].split(b"_")[0]!=route_filter_encoded:
                    continue
                out_data.append(x)
            return head+b'\n'+b'\n'.join(out_data)+b'\n'


        @server_exec()
        async def get_last_data(self):
            """Return whatever the download server reports as its last data."""
            return await (await self._get_download_server()).get_last_data()

        @server_exec()
        async def wait_next_data(self, last_dt: datetime.datetime, preferably_compressed: bool = True):
            """Return the snapshot following ``last_dt``, waiting for it if necessary.

            When ``last_dt`` is more than 600 seconds old, the answer comes
            straight from the file system. Otherwise the download server is
            asked; a reply of ``1`` from it also means "read from the file system".
            """
            if last_dt.timestamp() + 600 < time.time():
                return await self.get_next_data(last_dt)
            r = await (await self._get_download_server()).wait_next_data(last_dt, preferably_compressed)
            if r == 1:
                eprint("wait_next_data from file system")
                return await self.get_next_data(last_dt)
            else:
                return r
    
    
    class AdminServer(MainServer):
        """``MainServer`` extended with destructive maintenance operations."""

        @server_exec()
        async def remove_data(self, dt: datetime.datetime):
            """Delete the realtime snapshot stored for ``dt``.

            Removes the snapshot directory (its files, then the directory
            itself) if present, and the ``.json.zst`` and ``.json.gzip``
            archives next to it if present.
            """
            base = "data/realtime/"+dt_to_path(dt)
            if os.path.isdir(base):
                for entry in os.listdir(base):
                    os.remove(base+'/'+entry)
                os.rmdir(base)
            for suffix in (".json.zst", ".json.gzip"):
                archive = base+suffix
                if os.path.isfile(archive):
                    os.remove(archive)