Skip to content
Snippets Groups Projects
Select Git revision
  • 11a89b4c4cb1689258f0888df7f5e228f4cbe38d
  • devel default
  • master
  • fo
  • jirka/typing
  • fo-base
  • mj/submit-images
  • jk/issue-96
  • jk/issue-196
  • honza/add-contestant
  • honza/mr7
  • honza/mrf
  • honza/mrd
  • honza/mra
  • honza/mr6
  • honza/submit-images
  • honza/kolo-vs-soutez
  • jh-stress-test-wip
  • shorten-schools
19 results

acct.py

Blame
  • server.py 7.39 KiB
    import asyncio
    import json
    import sys, os
    import cbor2
    import functools
    import traceback
    import datetime
    import time
    from pathlib import Path
    
    from utils import *
    import communication
    from communication import server_exec
    
    class MainServer(communication.FuncCaller):
        """Serve stored realtime snapshots and GTFS files from the ``data/`` tree.

        Realtime snapshots are looked up under ``data/realtime`` through a
        three-level directory walk (see ``_tree_walker``), per-route preprocessed
        data under ``data/realtime_by_route`` and GTFS files under ``data/gtfs``.
        Connections to the download and preprocess servers are opened lazily over
        unix sockets the first time they are needed.

        NOTE(review): methods decorated with ``server_exec()`` are presumably the
        remotely callable ones -- confirm in ``communication``.
        """

        # Cached helper-server connections; stay None until first use.
        _download_server = None
        _preprocess_server = None

        async def _get_download_server(self):
            """Return the download-server proxy, connecting on first use."""
            if self._download_server is None:
                import download_server
                s = await UnixSocket().connect("sockets/download_server")
                self._download_server = download_server.DownloadServer(s)
            return self._download_server

        async def _get_preprocess_server(self):
            """Return the preprocess-server proxy, connecting on first use."""
            if self._preprocess_server is None:
                import preprocess
                s = await UnixSocket().connect("sockets/preprocess_server")
                self._preprocess_server = preprocess.PreprocessServer(s)
            return self._preprocess_server

        async def _tree_walker(self, condition, worker, reverse=False):
            """Walk ``data/realtime`` in sorted order and call ``worker`` on leaves.

            ``condition(dt, delta)`` is awaited for every first-level entry (with
            a one-day delta) and every second-level entry (with a one-hour delta);
            a subtree is only descended into when it returns a truthy value.
            ``worker(dt)`` is awaited for every third-level entry.  Entry names
            are converted to datetimes with ``path_to_dt``.

            With ``reverse=True`` every level is visited in descending order.
            """
            root = "data/realtime"
            day_span = datetime.timedelta(days=1)
            hour_span = datetime.timedelta(hours=1)
            for day in sorted(os.listdir(root), reverse=reverse):
                if not await condition(path_to_dt(day), day_span):
                    continue
                for hour in sorted(os.listdir(root+"/"+day), reverse=reverse):
                    hour_path = day+"/"+hour
                    if not await condition(path_to_dt(hour_path), hour_span):
                        continue
                    for entry in sorted(os.listdir(root+"/"+hour_path), reverse=reverse):
                        await worker(path_to_dt(hour_path+"/"+entry))

        async def _first_realtime_dt(self, condition, accept, reverse=False):
            """Return the first walked timestamp for which ``accept(dt)`` is true.

            The walk is aborted as soon as a match is found by raising a private
            exception out of the worker, so later directories are never listed.
            Returns None when nothing matches.
            """
            class _Found(Exception):
                def __init__(self, val):
                    self.val = val

            async def worker(dt):
                if accept(dt):
                    raise _Found(dt)

            try:
                await self._tree_walker(condition, worker, reverse=reverse)
            except _Found as found:
                return found.val
            return None

        @server_exec()
        async def list_realtime_data(self, dt_from: datetime.datetime, dt_to: datetime.datetime):
            """Return all stored timestamps ``dt`` with ``dt_from <= dt <= dt_to``.

            The result is a list in ascending walk order.
            """
            out = []

            async def condition(dt, delta):
                return dt_intersect(dt_from, dt_to, dt, delta)

            async def worker(dt):
                if dt_from <= dt <= dt_to:
                    out.append(dt)

            await self._tree_walker(condition, worker)
            return out

        async def list_next_realtime_data(self, dt_from: datetime.datetime):
            """Return the first stored timestamp strictly after ``dt_from``, or None."""
            async def condition(dt, delta):
                return dt_intersect(dt_from, dt_pinf, dt, delta)

            return await self._first_realtime_dt(condition, lambda dt: dt > dt_from)

        async def list_prev_realtime_data(self, dt_from: datetime.datetime):
            """Return the last stored timestamp strictly before ``dt_from``, or None."""
            async def condition(dt, delta):
                return dt_intersect(dt_minf, dt_from, dt, delta)

            return await self._first_realtime_dt(condition, lambda dt: dt < dt_from, reverse=True)

        @server_exec()
        async def get_data(self, dt: datetime.datetime):
            """Return the stored snapshot for ``dt``.

            If ``data/realtime/<dt_to_path(dt)>.json.zst`` exists its raw bytes
            are returned.  Otherwise the path is treated as a directory and a
            dict mapping each contained filename to its raw bytes is returned.

            Raises FileNotFoundError (from ``os.listdir``) when neither exists.
            """
            path = "data/realtime/"+dt_to_path(dt)
            try:
                with open(path+".json.zst", "rb") as f:
                    return f.read()
            except FileNotFoundError:
                pass  # no single-file snapshot; fall back to the directory form
            out = {}
            for filename in os.listdir(path):
                with open(path+'/'+filename, "rb") as f:
                    out[filename] = f.read()
            return out

        @server_exec()
        async def get_next_data(self, dt: datetime.datetime):
            """Return ``(next_dt, data)`` for the first snapshot after ``dt``, or None."""
            next_dt = await self.list_next_realtime_data(dt)
            if next_dt is None:
                return None
            # Load the snapshot that was found, not the one the caller started from.
            return next_dt, await self.get_data(next_dt)

        @server_exec()
        async def get_prev_data(self, dt: datetime.datetime):
            """Return ``(prev_dt, data)`` for the last snapshot before ``dt``, or None."""
            prev_dt = await self.list_prev_realtime_data(dt)
            if prev_dt is None:
                return None
            # Load the snapshot that was found, not the one the caller started from.
            return prev_dt, await self.get_data(prev_dt)

        @server_exec()
        async def get_preprocessed_data(self, dt: datetime.datetime, route_id: str):
            """Return preprocessed data of one route for the hour containing ``dt``.

            When ``data/realtime_by_route/<Y-m-d>/<H>`` does not exist the request
            is forwarded to the preprocess server for the interval
            ``[dt, dt + 1 hour]`` and its answer is returned unchanged.

            Otherwise returns ``(payload, source_timestamps)`` where
            ``source_timestamps`` is parsed from the ``source_timestamps`` file
            (one ISO datetime per line) and ``payload`` is the raw bytes of
            ``<route_id>.json.zst``, or the string ``'{}'`` when the route has no
            file.

            Raises AssertionError when ``route_id`` contains a non-alphanumeric
            character.
            """
            # route_id becomes part of a file path; the check is explicit (not an
            # ``assert`` statement) so it still runs under ``python -O``.  The
            # exception type is kept for backward compatibility.
            if not all(ch.isalnum() for ch in route_id):
                raise AssertionError(f"invalid route_id: {route_id!r}")
            path = "data/realtime_by_route/"+dt.strftime("%Y-%m-%d/%H")
            if not os.path.exists(path):
                preprocess_server = await self._get_preprocess_server()
                return await preprocess_server.get_preprocessed_data(dt, dt+datetime.timedelta(hours=1), route_id)
            with open(path+"/source_timestamps") as f:
                source_timestamps = [datetime.datetime.fromisoformat(x.strip()) for x in f]
            route_path = path+"/"+route_id+".json.zst"
            if os.path.exists(route_path):
                with open(route_path, "rb") as f:
                    return f.read(), source_timestamps
            return '{}', source_timestamps

        async def _read_gtfs_file(self, dt, filename):
            """Return the bytes of a GTFS file for ``dt``, or None if it is missing.

            The directory for ``dt`` is tried first, then the directory for the
            previous day.  When both are missing a message naming the first path
            is printed with ``eprint``.  ``filename`` must already be validated.
            """
            path = "data/gtfs/"+date_to_path(dt)+"/"+filename
            try:
                with open(path, "rb") as f:
                    return f.read()
            except FileNotFoundError:
                pass
            path2 = "data/gtfs/"+date_to_path(dt-datetime.timedelta(days=1))+"/"+filename
            try:
                with open(path2, "rb") as f:
                    return f.read()
            except FileNotFoundError:
                eprint(f"GTFS: {path} no such file")
                return None

        @server_exec()
        async def gtfs_get_file(self, dt: datetime.datetime, filename: str):
            """Return the raw bytes of GTFS file ``filename`` for ``dt``, or None.

            Falls back to the previous day's directory when the file is missing.

            Raises AssertionError when ``filename`` contains characters other
            than alphanumerics, ``_``, ``.`` and ``/``, or has a path component
            that starts or ends with a dot.
            """
            # filename becomes part of a file path; the checks are explicit (not
            # ``assert`` statements) so they still run under ``python -O``.
            if not all(ch in "_./" or ch.isalnum() for ch in filename):
                raise AssertionError(f"invalid filename: {filename!r}")
            if "./" in filename or "/." in filename or filename.startswith('.') or filename.endswith('.'):
                raise AssertionError(f"invalid filename: {filename!r}")
            return await self._read_gtfs_file(dt, filename)

        @server_exec()
        async def gtfs_get_stop_times(self, dt: datetime.datetime, trip_filter=None, route_filter=None):
            """Return ``stop_times.txt`` for ``dt``, optionally filtered by trip or route.

            The first line is always kept as the header.  A data row is kept when
            its first comma-separated field equals ``trip_filter`` (if given) and
            the part of that field before the first ``_`` equals ``route_filter``
            (if given).  The result is the kept lines joined by ``\\n`` with one
            trailing newline.

            Returns None when the file is missing for both ``dt`` and the
            previous day, and ``b''`` when the file is empty.
            """
            s = await self._read_gtfs_file(dt, "stop_times.txt")
            if s is None:
                return None

            lines = s.split(b'\n')
            # A trailing newline leaves one empty element; drop only that, so a
            # file without a trailing newline does not lose its last row.
            if lines[-1] == b'':
                lines.pop()
            if not lines:
                return b''
            head, *data = lines

            trip_filter_encoded = trip_filter.encode("utf-8") if trip_filter else None
            route_filter_encoded = route_filter.encode("utf-8") if route_filter else None
            kept = [head]
            for row in data:
                first_field = row.split(b',', 1)[0]
                if trip_filter and first_field != trip_filter_encoded:
                    continue
                if route_filter and first_field.split(b"_")[0] != route_filter_encoded:
                    continue
                kept.append(row)
            return b'\n'.join(kept)+b'\n'

        @server_exec()
        async def get_last_data(self):
            """Forward to the download server's ``get_last_data`` and return its answer."""
            return await (await self._get_download_server()).get_last_data()

        @server_exec()
        async def wait_next_data(self, last_dt: datetime.datetime, preferably_compressed: bool = True):
            """Return the snapshot following ``last_dt``.

            When ``last_dt`` is more than 600 seconds older than the current time
            the answer is read from disk via ``get_next_data``.  Otherwise the
            download server's ``wait_next_data`` is asked; if it answers ``1``
            the on-disk lookup is used instead, else its answer is returned.

            NOTE(review): ``1`` looks like a "not available here" marker from the
            download server -- confirm in ``download_server``.
            """
            if last_dt.timestamp() + 600 < time.time():
                return await self.get_next_data(last_dt)
            download = await self._get_download_server()
            r = await download.wait_next_data(last_dt, preferably_compressed)
            if r == 1:
                eprint("wait_next_data from file system")
                return await self.get_next_data(last_dt)
            return r
    
    
    
    async def main():
        """Connect a stdin/stdout socket and run a ``MainServer`` on it."""
        sock = await communication.StdInOutSocket().connect()
        server = MainServer(sock, is_server=True)
        await server._server_.run()
    
    if __name__ == "__main__":
        # Start the server only when this file is executed as a script.
        asyncio.run(main())