Select Git revision
protocols.py
-
Martin Mareš authored
communication.py NaN GiB
import asyncio
import json
import sys, os
import cbor2
import functools
import traceback
import datetime
import time
from pathlib import Path
from utils import *
debug = False  # when True, FuncCaller._client_caller_ traces each remote call ("ASK:"/"DONE") to stderr
def eprint(*args):
    """Print ``args`` to stderr, like print(), flushing immediately."""
    print(*args, flush=True, file=sys.stderr)
async def print_exceptions(corutine):
    """Await ``corutine`` and return its result.

    Any Exception it raises is first written to stderr as a full traceback
    (so failures inside tasks started with create_task are not silently
    lost) and then re-raised.
    """
    try:
        result = await corutine
    except Exception as err:
        eprint(traceback.format_exc())
        raise err
    return result
async def connect_stdin_stdout():
    """Wrap this process's stdin/stdout as an asyncio ``(StreamReader, StreamWriter)`` pair.

    Must be awaited from inside a running event loop.
    """
    # Inside a coroutine, get_running_loop() is the documented way to reach the loop.
    loop = asyncio.get_running_loop()
    reader = asyncio.StreamReader()
    protocol = asyncio.StreamReaderProtocol(reader)
    await loop.connect_read_pipe(lambda: protocol, sys.stdin)
    # FlowControlMixin is an undocumented asyncio helper; it supplies the
    # drain support that StreamWriter expects from its protocol.
    w_transport, w_protocol = await loop.connect_write_pipe(asyncio.streams.FlowControlMixin, sys.stdout)
    writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)
    return reader, writer
class Socket():
    """Abstract byte-stream transport used by MsgParser, Client and Server.

    Subclasses implement ``write(data)`` to send bytes and ``read()`` to
    return the next received chunk, where ``b''`` means end of stream
    (MsgParser turns an empty read into EOFError).
    """

    async def write(self, data):
        """Send ``data``. Must be overridden."""
        raise NotImplementedError

    async def read(self):
        """Return the next received chunk (``b''`` at end of stream). Must be overridden."""
        raise NotImplementedError
class SSHRunSocket(Socket):
    """Socket over the stdin/stdout pipes of a shell command.

    Presumably used with an ``ssh ...`` command line, given the name; any
    shell command works.
    """

    async def connect(self, cmd):
        """Start ``cmd`` through the shell with piped stdin/stdout and return self."""
        # NOTE(review): cmd is handed to the shell verbatim; it must not be built
        # from untrusted input.
        pipe = asyncio.subprocess.PIPE
        self.proc = await asyncio.create_subprocess_shell(cmd, stdin=pipe, stdout=pipe)
        return self

    async def write(self, data):
        """Queue ``data`` on the child's stdin (without draining)."""
        child_stdin = self.proc.stdin
        child_stdin.write(data)

    async def read(self):
        """Return up to 1000 bytes from the child's stdout; ``b''`` signals EOF."""
        chunk = await self.proc.stdout.read(1000)
        return chunk
class AsincioStreamSocket(Socket):
    """Socket backed by an asyncio StreamReader/StreamWriter pair."""

    async def connect(self, reader, writer):
        """Adopt an already-open ``reader``/``writer`` pair and return self."""
        self.reader, self.writer = reader, writer
        return self

    async def write(self, data):
        """Hand ``data`` to the writer's buffer (without draining)."""
        stream_writer = self.writer
        stream_writer.write(data)

    async def read(self):
        """Return up to 1000 bytes from the reader; ``b''`` signals EOF."""
        chunk = await self.reader.read(1000)
        return chunk
class StdInOutSocket(AsincioStreamSocket):
    """Socket that reads this process's stdin and writes its stdout."""

    async def connect(self):
        """Attach to stdin/stdout via connect_stdin_stdout() and return self."""
        streams = await connect_stdin_stdout()
        self.reader, self.writer = streams
        return self
class UnixSocket(AsincioStreamSocket):
    """Socket over a Unix-domain stream connection."""

    async def connect(self, path):
        """Connect to the Unix socket at ``path`` and return self."""
        stream_reader, stream_writer = await asyncio.open_unix_connection(path)
        self.reader = stream_reader
        self.writer = stream_writer
        return self
class MsgParser():
    """Reassemble length-prefixed messages from a Socket's ``read()`` chunks.

    Wire format (as produced by msg()): a 5-character decimal header length,
    a 12-character decimal payload length (both space padded), then the
    header bytes, then the payload bytes.
    """

    def __init__(self, socket):
        self.socket = socket
        self.buffer = []  # received chunks not yet handed out
        self.len = 0      # total number of bytes held in self.buffer

    async def get_bytes(self, l):
        """Return exactly ``l`` bytes, reading further chunks as needed.

        Raises EOFError when the socket returns an empty chunk before enough
        bytes have arrived.
        """
        while self.len < l:
            chunk = await self.socket.read()
            self.buffer.append(chunk)
            if not chunk:
                raise EOFError()
            self.len += len(chunk)
        joined = b''.join(self.buffer)
        wanted, leftover = joined[:l], joined[l:]
        self.buffer = [leftover]
        self.len -= l
        return wanted

    async def get_msg(self):
        """Read one framed message and return ``(header_bytes, payload_bytes)``."""
        head_len, data_len = [int(await self.get_bytes(width)) for width in (5, 12)]
        head = await self.get_bytes(head_len)
        data = await self.get_bytes(data_len)
        return head, data
def int_to_str_len(val, l):
    """Right-align ``str(val)`` in a space-padded field exactly ``l`` characters wide.

    Asserts that the value fits in the field.
    """
    text = str(val)
    assert len(text) <= l
    return text.rjust(l)
def msg(head, data):
    """Frame one message for the wire.

    Layout: 5-character header length, 12-character payload length (both
    space-padded decimal), the ``head`` bytes, then the ``data`` bytes.
    ``head`` must be non-empty.
    """
    assert head
    lengths = int_to_str_len(len(head), 5) + int_to_str_len(len(data), 12)
    return b''.join((lengths.encode('utf-8'), head, data))
class WaitingQuestion():
    """Slot for one outstanding request: a future completed when its reply arrives."""

    def __init__(self, id):
        # ``id`` is accepted for the caller's convenience but not stored.
        running_loop = asyncio.get_running_loop()
        self.future = running_loop.create_future()
class Client():
    """Requesting side of the framed protocol.

    Every question gets a fresh integer ``id`` in its JSON header. A
    background task reads replies from the socket and completes the waiting
    future with the matching id.
    """

    def __init__(self, socket):
        self.socket = socket
        self.msg_parser = MsgParser(socket)
        self.waiting_questions = {}  # id -> WaitingQuestion still awaiting its reply
        self.id_alocator = 0         # next question id to hand out
        self.input_task = asyncio.create_task(print_exceptions(self.input_task_f()))

    def _fail_waiting_questions(self, exc):
        """Wake every pending question after the reply reader stopped with ``exc``."""
        pending, self.waiting_questions = self.waiting_questions, {}
        for q in pending.values():
            if q.future.done():
                continue  # the asking task already gave up (e.g. was cancelled)
            if isinstance(exc, Exception):
                q.future.set_exception(exc)
            else:
                # Cancellation / interpreter shutdown: cancel rather than store it.
                q.future.cancel()

    async def input_task_f(self):
        """Read replies forever and dispatch each one to its waiting question.

        If reading stops for any reason (EOFError when the peer closes, a
        malformed message, cancellation), all still-pending questions are
        woken with that error instead of waiting forever. The error is then
        re-raised.
        """
        try:
            while True:
                head_raw, data_raw = await self.msg_parser.get_msg()
                head = json.loads(head_raw)
                qid = head["id"]
                assert qid in self.waiting_questions
                q = self.waiting_questions.pop(qid)
                # Already done if the asking task was cancelled meanwhile.
                if not q.future.done():
                    q.future.set_result((head, data_raw))
        except BaseException as e:
            self._fail_waiting_questions(e)
            raise

    async def question(self, head, data):
        """Send a question and return the reply as ``(head_dict, data_bytes)``.

        ``head`` is a JSON-serialisable dict (an ``id`` key is added) and
        ``data`` the raw payload bytes.

        Raises ConnectionError if the reply reader has already stopped. Raises
        the reader's error if it stops while this question is waiting.
        """
        if self.input_task.done():
            raise ConnectionError("connection closed: reply reader is no longer running")
        qid = self.id_alocator
        self.id_alocator += 1
        head = {'id': qid, **head}
        self.waiting_questions[qid] = q = WaitingQuestion(qid)
        try:
            await self.socket.write(msg(json.dumps(head).encode('utf-8'), data))
        except BaseException:
            # Nothing was sent, so no reply will ever come for this id.
            self.waiting_questions.pop(qid, None)
            raise
        return await q.future
class Server():
    """Answering side of the framed protocol; each question is handled in its own task."""

    def __init__(self, socket, callback):
        self.socket = socket
        self.msg_parser = MsgParser(socket)
        # Awaited as callback(head, data_raw, connection_controll); its result is
        # sent back as the reply payload.
        self.callback = callback
        self.id_alocator = 0
        # Strong references to in-flight question tasks. The event loop keeps
        # only weak references to tasks, so an unreferenced task can be garbage
        # collected before it finishes.
        self._question_tasks = set()

    async def question_task_f(self, in_head, in_data_raw):
        """Run the callback for one question and send its result back under the same id."""
        # The third callback argument is currently always Ellipsis (a placeholder).
        out_data_raw = await self.callback(in_head, in_data_raw, ...)
        out_head = {'id': in_head['id']}
        await self.socket.write(msg(json.dumps(out_head).encode('utf-8'), out_data_raw))

    async def run(self):
        """Read questions forever and spawn one task per question.

        Only returns by raising, e.g. EOFError once the peer closes the stream.
        """
        while True:
            head_raw, data_raw = await self.msg_parser.get_msg()
            head = json.loads(head_raw)
            task = asyncio.create_task(print_exceptions(self.question_task_f(head, data_raw)))
            self._question_tasks.add(task)
            task.add_done_callback(self._question_tasks.discard)
            self.input_task = task  # most recent task, kept for compatibility
def cbor_dump(x):
    """Encode ``x`` as CBOR bytes; naive datetimes are encoded using ``local_timezone``."""
    encoded = cbor2.dumps(x, timezone=local_timezone)
    return encoded
class FuncCaller():
    """Base for objects whose ``@server_exec()`` methods run locally or via RPC.

    With no socket, or with ``is_server=True``, methods run locally.
    Otherwise they are forwarded through a Client as
    ``{'func_name': ...}`` questions carrying CBOR-encoded arguments.
    """

    def __init__(self, socket=None, is_server=False):
        if socket is None:
            is_server = True
        self._socket_ = socket
        self._is_server_ = is_server
        if is_server:
            if socket is not None:
                self._server_ = Server(socket, self._server_caller_)
        else:
            self._client_ = Client(socket)

    async def _run_(self, socket):
        """Serve RPC requests arriving on ``socket`` until the connection ends."""
        # Each incoming request is dispatched through _server_caller_.
        server = Server(socket, self._server_caller_)
        await server.run()

    async def _server_caller_(self, in_head, in_data_raw, connection_controll):
        """Execute one incoming RPC request and return CBOR-encoded ``{'return': value}``.

        Raises ValueError unless ``func_name`` is a non-empty string that does
        not start with '_'. This check is what keeps remote peers away from
        private methods, so it must not be an ``assert``, which ``python -O``
        strips.
        """
        func_name = in_head["func_name"]
        if not isinstance(func_name, str) or not func_name or func_name.startswith('_'):
            raise ValueError(f"refusing remote call to {func_name!r}")
        f = getattr(self, func_name)
        in_data = cbor2.loads(in_data_raw)
        r = await f(*in_data['args'], **in_data['kwargs'])
        return cbor_dump({'return': r})

    async def _client_caller_(self, func_name, *args, **kwargs):
        """Forward a call to the server and return the ``'return'`` value of its reply."""
        if debug:
            eprint("ASK:", func_name, args, kwargs)
        in_head, in_data_raw = await self._client_.question({'func_name': func_name}, cbor_dump({'args': args, 'kwargs': kwargs}))
        if debug:
            eprint("DONE", func_name, args, kwargs)
        in_data = cbor2.loads(in_data_raw)
        return in_data["return"]
def server_exec():
    """Decorator factory for FuncCaller methods that must execute on the server.

    If ``self._is_server_`` is true, the wrapped coroutine runs directly.
    Otherwise the call is forwarded by method name through
    ``self._client_caller_``.
    """
    def decorator(method):
        @functools.wraps(method)
        async def wrapper(self, *args, **kwargs):
            if not self._is_server_:
                return await self._client_caller_(method.__name__, *args, **kwargs)
            return await method(self, *args, **kwargs)

        return wrapper
    return decorator
def path_to_dt(path):
    """Parse a data-tree path such as ``2023-01-02/03/04-05.json.zst`` into a datetime.

    Everything from the first '.' on is ignored. The remaining fields,
    separated by '/' or '-', are read as integers and passed to datetime(...)
    in order, with ``tzinfo=local_timezone``. Empty fields are skipped.
    """
    stem = path.partition('.')[0]
    fields = stem.replace('/', '-').split('-')
    numbers = [int(part) for part in fields if part]
    return datetime.datetime(*numbers, tzinfo=local_timezone)
def dt_to_path(dt):
    """Return the relative snapshot path for ``dt``: ``YYYY-MM-DD/HH/MM-SS``."""
    return f"{dt:%Y-%m-%d/%H/%M-%S}"
def date_to_path(dt):
    """Return the day directory name for ``dt``: ``YYYY-MM-DD``."""
    return f"{dt:%Y-%m-%d}"
def dt_intersect(dt_from: datetime.datetime, dt_to: datetime.datetime, dt: datetime.datetime, delta: datetime.timedelta):
    """Test whether the closed range ``dt_from .. dt_to`` can meet ``dt .. dt + delta - epsilon``.

    Used to prune the data-tree walk, where ``[dt, dt + delta)`` is a day or
    hour bucket. ``dt_to`` is inclusive: a bucket starting exactly at
    ``dt_to`` counts, because list_realtime_data keeps records with
    ``dt <= dt_to``. The lower test is deliberately non-strict. A bucket
    ending exactly at ``dt_from`` is also reported, which only costs an
    extra directory listing.
    """
    return dt + delta >= dt_from and dt <= dt_to
class MainServer(FuncCaller):
    """RPC API over the local data tree.

    Serves realtime snapshots (``data/realtime``), per-route preprocessed
    data (``data/realtime_by_route``) and GTFS files (``data/gtfs``). Where
    needed it delegates to the download/preprocess helper servers over Unix
    sockets. Methods marked ``@server_exec()`` run locally on the server and
    are forwarded over the connection on the client side.
    """

    # Helper-server stubs, connected lazily on first use.
    _download_server = None
    _preprocess_server = None

    async def _get_download_server(self):
        """Return the DownloadServer stub, connecting to sockets/download_server on first use."""
        if self._download_server is None:
            # NOTE(review): imported lazily, presumably to avoid an import cycle — confirm.
            import download_server
            s = await UnixSocket().connect("sockets/download_server")
            self._download_server = download_server.DownloadServer(s)
        return self._download_server

    async def _get_preprocess_server(self):
        """Return the PreprocessServer stub, connecting to sockets/preprocess_server on first use."""
        if self._preprocess_server is None:
            import preprocess
            s = await UnixSocket().connect("sockets/preprocess_server")
            self._preprocess_server = preprocess.PreprocessServer(s)
        return self._preprocess_server

    async def _tree_walker(self, condition, worker, reverse=False):
        """Visit the snapshots under data/realtime in name order (reversed if ``reverse``).

        The tree is laid out as ``YYYY-MM-DD/HH/<entry>``, where each entry
        name starts with ``MM-SS``. ``await condition(dt, delta)`` decides
        whether a day or hour directory starting at ``dt`` and spanning
        ``delta`` is descended into. ``await worker(dt)`` is called for each
        entry of every selected hour.
        """
        root = "data/realtime"
        for day in sorted(os.listdir(root), reverse=reverse):
            if not await condition(path_to_dt(day), datetime.timedelta(days=1)):
                continue
            for hour in sorted(os.listdir(f"{root}/{day}"), reverse=reverse):
                hour_path = f"{day}/{hour}"
                if not await condition(path_to_dt(hour_path), datetime.timedelta(hours=1)):
                    continue
                for entry in sorted(os.listdir(f"{root}/{hour_path}"), reverse=reverse):
                    await worker(path_to_dt(f"{hour_path}/{entry}"))

    @server_exec()
    async def list_realtime_data(self, dt_from: datetime.datetime, dt_to: datetime.datetime):
        """Return the timestamps of all snapshots with ``dt_from <= dt <= dt_to``, in walk order."""
        out = []

        async def condition(dt, delta):
            return dt_intersect(dt_from, dt_to, dt, delta)

        async def worker(dt):
            if dt_from <= dt <= dt_to:
                out.append(dt)

        await self._tree_walker(condition, worker)
        return out

    async def _find_first_realtime_data(self, range_from, range_to, accept, reverse):
        """Return the first snapshot dt, in walk order, for which ``accept(dt)`` holds, or None.

        Only day/hour buckets for which dt_intersect(range_from, range_to, ...)
        holds are scanned.
        """
        class Return(Exception):
            # Carries the found timestamp out of the walk and stops it early.
            def __init__(self, val):
                self.val = val

        async def condition(dt, delta):
            return dt_intersect(range_from, range_to, dt, delta)

        async def worker(dt):
            if accept(dt):
                raise Return(dt)

        try:
            await self._tree_walker(condition, worker, reverse=reverse)
        except Return as r:
            return r.val
        return None

    async def list_next_realtime_data(self, dt_from: datetime.datetime):
        """Return the first snapshot timestamp strictly after ``dt_from``, or None."""
        return await self._find_first_realtime_data(dt_from, dt_pinf, lambda dt: dt > dt_from, reverse=False)

    async def list_prev_realtime_data(self, dt_from: datetime.datetime):
        """Return the last snapshot timestamp strictly before ``dt_from``, or None."""
        return await self._find_first_realtime_data(dt_minf, dt_from, lambda dt: dt < dt_from, reverse=True)

    @server_exec()
    async def get_data(self, dt: datetime.datetime):
        """Return the raw snapshot stored for ``dt``.

        If ``<path>.json.zst`` exists, its bytes are returned unchanged.
        Otherwise the snapshot directory is read into a
        ``{filename: bytes}`` dict.
        """
        path = "data/realtime/"+dt_to_path(dt)
        if Path(path+".json.zst").exists():
            with open(path+".json.zst", "rb") as f:
                return f.read()
        out = {}
        for filename in os.listdir(path):
            with open(path+'/'+filename, "rb") as f:
                out[filename] = f.read()
        return out

    @server_exec()
    async def get_next_data(self, dt: datetime.datetime):
        """Return ``(next_dt, data)`` for the first snapshot after ``dt``, or None."""
        next_dt = await self.list_next_realtime_data(dt)
        if next_dt is None:
            return None
        # The payload must be the one belonging to next_dt, not to dt.
        return next_dt, await self.get_data(next_dt)

    @server_exec()
    async def get_prev_data(self, dt: datetime.datetime):
        """Return ``(prev_dt, data)`` for the last snapshot before ``dt``, or None."""
        prev_dt = await self.list_prev_realtime_data(dt)
        if prev_dt is None:
            return None
        # The payload must be the one belonging to prev_dt, not to dt.
        return prev_dt, await self.get_data(prev_dt)

    @server_exec()
    async def get_preprocessed_data(self, dt: datetime.datetime, route_id: str):
        """Return ``(data, source_timestamps)`` for ``route_id`` in the hour containing ``dt``.

        If that hour has no directory on disk, the request is delegated to the
        preprocess server for ``dt .. dt + 1 hour``.
        """
        # Explicit check rather than an assert (stripped under -O): route_id is
        # spliced into a file path.
        if not all(i.isalnum() for i in route_id):
            raise ValueError(f"invalid route_id {route_id!r}")
        path = "data/realtime_by_route/"+dt.strftime("%Y-%m-%d/%H")
        if not os.path.exists(path):
            return await (await self._get_preprocess_server()).get_preprocessed_data(dt, dt+datetime.timedelta(hours=1), route_id)
        with open(path+"/source_timestamps") as f:
            source_timestamps = [datetime.datetime.fromisoformat(x.strip()) for x in f]
        route_path = path+"/"+route_id+".json.zst"
        if os.path.exists(route_path):
            with open(route_path, "rb") as f:
                return f.read(), source_timestamps
        # NOTE(review): this branch returns a str, while the one above returns raw
        # .json.zst bytes — confirm that callers handle both.
        return '{}', source_timestamps

    def _read_gtfs(self, dt, filename):
        """Return the bytes of data/gtfs/<date>/<filename> for ``dt``'s date, else the previous day's.

        Returns None, after logging to stderr, if neither file exists.
        """
        path = "data/gtfs/"+date_to_path(dt)+"/"+filename
        path2 = "data/gtfs/"+date_to_path(dt-datetime.timedelta(days=1))+"/"+filename
        for candidate in (path, path2):
            try:
                with open(candidate, "rb") as f:
                    return f.read()
            except FileNotFoundError:
                pass
        eprint(f"GTFS: {path} no such file")
        return None

    @server_exec()
    async def gtfs_get_file(self, dt: datetime.datetime, filename: str):
        """Return GTFS ``filename`` for ``dt``'s date (falling back to the day before), or None."""
        # Explicit checks rather than asserts (stripped under -O): filename is
        # spliced into a path. Only alphanumerics and "_./" are allowed, and no
        # path component may start or end with '.', which rules out "..".
        if not all(i in "_./" or i.isalnum() for i in filename):
            raise ValueError(f"invalid GTFS filename {filename!r}")
        if "./" in filename or "/." in filename or filename.startswith('.') or filename.endswith('.'):
            raise ValueError(f"invalid GTFS filename {filename!r}")
        return self._read_gtfs(dt, filename)

    @server_exec()
    async def gtfs_get_stop_times(self, dt: datetime.datetime, trip_filter=None, route_filter=None):
        """Return GTFS stop_times.txt for ``dt``'s date, optionally filtered, or None if missing.

        A row is kept if its first column equals ``trip_filter`` (when given)
        and the part of that column before the first '_' equals
        ``route_filter`` (when given). The header is always kept, and every
        output line ends with a newline.
        """
        s = self._read_gtfs(dt, "stop_times.txt")
        if s is None:
            return None
        lines = s.split(b'\n')
        if lines[-1] == b'':
            lines.pop()  # empty piece after the trailing newline, if any
        if not lines:
            return s  # empty file: nothing to filter
        head, *data = lines
        trip_filter_encoded = trip_filter.encode("utf-8") if trip_filter else None
        route_filter_encoded = route_filter.encode("utf-8") if route_filter else None
        out_data = []
        for x in data:
            # NOTE(review): assumes trip_id is the first column; GTFS does not fix
            # column order — confirm against the feed.
            first = x.split(b',', 1)[0]
            if trip_filter and first != trip_filter_encoded:
                continue
            if route_filter and first.split(b"_")[0] != route_filter_encoded:
                continue
            out_data.append(x)
        return b''.join(line + b'\n' for line in (head, *out_data))

    @server_exec()
    async def get_last_data(self):
        """Return whatever the download server reports as its latest data."""
        return await (await self._get_download_server()).get_last_data()

    @server_exec()
    async def wait_next_data(self, last_dt: datetime.datetime, preferably_compressed: bool = True):
        """Return the snapshot after ``last_dt``.

        If ``last_dt`` is more than 600 s in the past, the answer comes from
        disk via get_next_data. Otherwise the download server is asked
        (presumably blocking until new data arrives). Its answer is returned,
        unless it is ``1``, in which case this falls back to get_next_data.
        """
        if last_dt.timestamp() + 600 < time.time():
            return await self.get_next_data(last_dt)
        r = await (await self._get_download_server()).wait_next_data(last_dt, preferably_compressed)
        if r == 1:
            eprint("wait_next_data from file system")
            return await self.get_next_data(last_dt)
        return r
class AdminServer(MainServer):
    """MainServer extended with destructive maintenance operations."""

    @server_exec()
    async def remove_data(self, dt: datetime.datetime):
        """Delete the realtime snapshot stored for ``dt``, in every form that exists.

        Removes the snapshot directory (its entries, then the directory
        itself) and the ``.json.zst`` / ``.json.gzip`` files. Forms that do
        not exist are skipped.
        """
        path = "data/realtime/"+dt_to_path(dt)
        if os.path.isdir(path):
            # os.remove fails on subdirectories, so this assumes the snapshot
            # directory holds only files.
            for filename in os.listdir(path):
                os.remove(path+'/'+filename)
            os.rmdir(path)
        for suffix in (".json.zst", ".json.gzip"):
            if os.path.isfile(path+suffix):
                os.remove(path+suffix)