Select Git revision
org_round.py
-
Martin Mareš authored
communication.py 7.30 KiB
import asyncio
import json
import sys, os
import cbor2
import functools
import traceback
from datetime import datetime, timedelta, timezone
def eprint(*args):
    """Print ``args`` to stderr, space-separated and newline-terminated.

    The stream is flushed immediately so diagnostics appear even when stdout
    is used as a data channel.
    """
    stream = sys.stderr
    print(*args, sep=" ", end="\n", file=stream, flush=True)
async def print_exceptions(corutine):
    """Await ``corutine`` and return its result, logging any exception.

    An ``Exception`` raised by the coroutine has its traceback printed to
    stderr (via ``eprint``) and is then re-raised, so failures are visible
    even for tasks whose result nobody awaits.
    """
    try:
        return await corutine
    except Exception:
        eprint(traceback.format_exc())
        # Bare ``raise`` re-raises the active exception as-is, without adding
        # this line as an extra traceback entry (unlike ``raise e``).
        raise
async def connect_stdin_stdout():
    """Wrap this process's stdin/stdout as an asyncio ``(reader, writer)`` pair.

    Returns a StreamReader fed from ``sys.stdin`` and a StreamWriter that
    writes to ``sys.stdout``.
    """
    # Inside a coroutine, get_running_loop() is the documented way to reach
    # the current loop; get_event_loop() also consults the loop policy.
    loop = asyncio.get_running_loop()
    reader = asyncio.StreamReader()
    protocol = asyncio.StreamReaderProtocol(reader)
    await loop.connect_read_pipe(lambda: protocol, sys.stdin)
    # NOTE(review): FlowControlMixin is an undocumented asyncio internal; it
    # provides the flow-control hooks StreamWriter.drain() relies on.
    w_transport, w_protocol = await loop.connect_write_pipe(
        asyncio.streams.FlowControlMixin, sys.stdout)
    writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)
    return reader, writer
class Socket():
    """Abstract byte-stream transport used by MsgParser, Client and Server.

    Subclasses implement ``write`` and ``read`` with the signatures below.
    """

    async def write(self, data):
        """Send ``data`` (bytes) to the peer."""
        raise NotImplementedError

    async def read(self):
        """Return the next chunk of received bytes; ``b''`` signals EOF."""
        raise NotImplementedError
class SSHRunSocket(Socket):
    """Socket backed by the stdin/stdout pipes of a child shell command.

    Any shell command works; presumably ``cmd`` is an ``ssh`` invocation that
    starts the remote peer — TODO confirm with callers.
    """

    async def connect(self, cmd):
        """Start ``cmd`` through the shell with piped stdin/stdout; return ``self``.

        ``cmd`` is interpreted by the shell, so it must never contain
        untrusted text.
        """
        pipe = asyncio.subprocess.PIPE
        self.proc = await asyncio.create_subprocess_shell(cmd, stdin=pipe, stdout=pipe)
        return self

    async def write(self, data):
        """Queue ``data`` on the child's stdin (no drain, so no back-pressure)."""
        child_stdin = self.proc.stdin
        child_stdin.write(data)

    async def read(self):
        """Return up to 1000 bytes from the child's stdout (``b''`` at EOF)."""
        chunk = await self.proc.stdout.read(1000)
        return chunk
class StdInOutSocket(Socket):
    """Socket over this process's own stdin (for reading) and stdout (for writing)."""

    async def connect(self):
        """Wrap stdin/stdout as asyncio streams and return ``self``."""
        streams = await connect_stdin_stdout()
        self.reader, self.writer = streams
        return self

    async def write(self, data):
        """Queue ``data`` on stdout without waiting for it to be flushed."""
        out = self.writer
        out.write(data)

    async def read(self):
        """Return up to 1000 bytes from stdin (``b''`` at EOF)."""
        chunk = await self.reader.read(1000)
        return chunk
class MsgParser():
    """Split the chunked byte stream of ``socket`` into framed messages.

    A frame is the head length in 5 ASCII characters, then the data length
    in 12 ASCII characters (both parsed with ``int``, so space padding is
    accepted), then the head bytes, then the data bytes.
    """

    def __init__(self, socket):
        self.socket = socket
        self.buffer = []  # received chunks not yet handed out
        self.len = 0      # total number of bytes held in self.buffer

    async def get_bytes(self, l):
        """Return exactly the next ``l`` bytes of the stream.

        Reads more chunks from the socket as needed. Raises EOFError if the
        socket yields an empty chunk before ``l`` bytes are available.
        """
        while l > self.len:
            chunk = await self.socket.read()
            self.buffer.append(chunk)
            if len(chunk) == 0:
                raise EOFError()
            self.len += len(chunk)
        pending = b''.join(self.buffer)
        self.len -= l
        self.buffer = [pending[l:]]
        return pending[:l]

    async def get_msg(self):
        """Read one frame and return ``(head_bytes, data_bytes)``."""
        head_size = int(await self.get_bytes(5))
        data_size = int(await self.get_bytes(12))
        head = await self.get_bytes(head_size)
        return head, await self.get_bytes(data_size)
def int_to_str_len(val, l):
    """Return ``str(val)`` right-aligned (space-padded) to exactly ``l`` characters.

    Raises:
        ValueError: if ``str(val)`` is longer than ``l``. The old ``assert``
            disappeared under ``python -O`` and the value was then silently
            truncated, corrupting the frame's length prefix.
    """
    s = str(val)
    if len(s) > l:
        raise ValueError(f"{val!r} does not fit in {l} characters")
    return s.rjust(l)
def msg(head, data):
    """Frame one message (``head`` and ``data`` are bytes) for the wire.

    Layout: ``len(head)`` right-aligned in 5 characters, ``len(data)``
    right-aligned in 12 characters, then ``head`` and ``data``. This is the
    layout MsgParser.get_msg reads back.

    Raises:
        ValueError: if ``head`` is empty. Receivers JSON-decode the head, so
            an empty one can never be valid.
    """
    # Explicit check rather than ``assert``: asserts vanish under -O, and an
    # empty head would otherwise only fail later, on the receiving side.
    if not head:
        raise ValueError("message head must not be empty")
    return (int_to_str_len(len(head), 5).encode('utf-8')
            + int_to_str_len(len(data), 12).encode('utf-8')
            + head + data)
class WaitingQuestion():
    """Holder for the future that will receive one question's reply.

    Must be constructed while an event loop is running. The ``id`` argument
    is accepted but not stored.
    """

    def __init__(self, id):
        running_loop = asyncio.get_running_loop()
        self.future = running_loop.create_future()
class Client():
    """Request side of the framed request/reply protocol.

    Each ``question`` is tagged with a fresh integer ``id``. A background
    task reads replies from the socket and resolves the pending question
    whose ``id`` the reply head carries.
    """

    def __init__(self, socket):
        self.socket = socket
        self.msg_parser = MsgParser(socket)
        # id -> WaitingQuestion for every question still awaiting its reply.
        self.waiting_questions = {}
        self.input_task = asyncio.create_task(print_exceptions(self.input_task_f()))
        self.id_alocator = 0

    def _fail_pending(self, cause):
        """Fail every still-pending question with a ConnectionError from ``cause``."""
        pending = list(self.waiting_questions.values())
        self.waiting_questions.clear()
        for q in pending:
            if not q.future.done():
                err = ConnectionError("connection lost while waiting for a reply")
                err.__cause__ = cause
                q.future.set_exception(err)

    async def input_task_f(self):
        """Dispatch incoming replies to their waiting questions until reading fails.

        When reading stops (e.g. EOFError once the peer closes the stream),
        all pending questions are failed instead of being left to wait
        forever, and the original exception is re-raised.
        """
        try:
            while True:
                head_raw, data_raw = await self.msg_parser.get_msg()
                head = json.loads(head_raw)
                id = head["id"]
                assert id in self.waiting_questions
                # pop(): answered questions must not pile up in the dict.
                q = self.waiting_questions.pop(id)
                # The asker may have been cancelled meanwhile; set_result on a
                # cancelled future would raise InvalidStateError and kill
                # this reader.
                if not q.future.done():
                    q.future.set_result((head, data_raw))
        except BaseException as e:
            self._fail_pending(e)
            raise

    async def question(self, head, data):
        """Send ``head`` (a dict, extended with ``id``) plus raw ``data`` bytes.

        Returns the reply as ``(reply_head_dict, reply_data_bytes)``.

        Raises:
            ConnectionError: if the reply reader has already stopped, or if it
                stops while this question is pending.
        """
        if self.input_task.done():
            # Nobody reads replies any more; waiting would hang forever.
            raise ConnectionError("reply reader has stopped; connection is unusable")
        id = self.id_alocator
        self.id_alocator += 1
        head = {'id': id, **head}
        self.waiting_questions[id] = q = WaitingQuestion(id)
        try:
            await self.socket.write(msg(json.dumps(head).encode('utf-8'), data))
        except BaseException:
            # The question was not sent, so no reply will ever resolve it.
            self.waiting_questions.pop(id, None)
            raise
        return await q.future
class Server():
    """Reply side of the framed request/reply protocol.

    Each incoming message is handled in its own task by
    ``callback(head_dict, data_bytes, connection_controll)``. The bytes the
    callback returns are sent back with a head carrying the request's ``id``.

    NOTE(review): if the callback raises, no reply is sent and the asking
    side keeps waiting for it.
    """

    def __init__(self, socket, callback):
        self.socket = socket
        self.msg_parser = MsgParser(socket)
        self.callback = callback
        self.id_alocator = 0  # NOTE(review): not used by this class
        # Strong references to in-flight handler tasks. The event loop keeps
        # only weak references to tasks, so an unreferenced handler could be
        # garbage-collected before it finishes.
        self._question_tasks = set()

    async def question_task_f(self, in_head, in_data_raw):
        """Run the callback for one request and write its reply."""
        # Ellipsis is passed as the third (connection_controll) argument.
        out_data_raw = await self.callback(in_head, in_data_raw, ...)
        out_head = {'id': in_head['id']}
        await self.socket.write(msg(json.dumps(out_head).encode('utf-8'), out_data_raw))

    async def run(self):
        """Read requests forever, starting a handler task for each.

        Returns only by raising, e.g. EOFError when the peer closes the stream.
        """
        while True:
            head_raw, data_raw = await self.msg_parser.get_msg()
            head = json.loads(head_raw)
            task = asyncio.create_task(print_exceptions(self.question_task_f(head, data_raw)))
            self._question_tasks.add(task)
            task.add_done_callback(self._question_tasks.discard)
            # Kept for backward compatibility: the most recently started handler.
            self.input_task = task
# The local UTC offset, sampled once at import time. astimezone() without an
# argument yields a fixed-offset ``timezone``, so this does not follow later
# DST changes.
local_timezone = datetime.now(timezone.utc).astimezone().tzinfo


def cbor_dump(x):
    """Serialize ``x`` to CBOR bytes.

    Naive datetimes inside ``x`` are encoded as if they were in
    ``local_timezone`` (cbor2's ``timezone`` argument).
    """
    encoded = cbor2.dumps(x, timezone=local_timezone)
    return encoded
class FuncCaller():
    """Base for objects whose ``@server_exec()`` methods may run remotely.

    With ``is_server=True`` the instance executes incoming calls itself via
    ``_server_caller_``. Otherwise ``@server_exec()`` methods are forwarded
    by name over ``socket`` through a Client (``_client_caller_``).
    Arguments and results travel as CBOR.
    """

    def __init__(self, socket, is_server=False):
        self._socket_ = socket
        self._is_server_ = is_server
        if is_server:
            # NOTE(review): nothing in this class starts this Server's run()
            # loop; confirm the caller does.
            self._server_ = Server(socket, self._server_caller_)
        else:
            self._client_ = Client(socket)

    async def _run_(self, socket):
        """Serve calls arriving on ``socket`` until reading from it fails."""
        # Dispatch through the same handler that __init__ installs; the
        # former ``self._caller_`` does not exist and raised AttributeError.
        server = Server(socket, self._server_caller_)
        await server.run()

    async def _server_caller_(self, in_head, in_data_raw, connection_controll):
        """Execute one remote call and return the CBOR-encoded reply.

        ``in_head["func_name"]`` names the method to call. ``in_data_raw`` is
        CBOR of ``{'args': [...], 'kwargs': {...}}``. The reply is CBOR of
        ``{'return': result}``. ``connection_controll`` is unused.

        Raises:
            ValueError: if the requested name is not a non-empty string, or
                starts with ``_``.
        """
        func_name = in_head["func_name"]
        # The name comes from the remote peer. Validate it with a real check
        # (``assert`` is stripped under -O) so private/dunder attributes can
        # never be reached.
        # NOTE(review): any public coroutine method is callable, not only
        # ``@server_exec()`` ones.
        if not isinstance(func_name, str) or not func_name or func_name.startswith('_'):
            raise ValueError(f"refusing remote call to {func_name!r}")
        f = getattr(self, func_name)
        in_data = cbor2.loads(in_data_raw)
        r = await f(*in_data['args'], **in_data['kwargs'])
        return cbor_dump({'return': r})

    async def _client_caller_(self, func_name, *args, **kwargs):
        """Ask the server to run ``func_name(*args, **kwargs)`` and return its result."""
        in_head, in_data_raw = await self._client_.question(
            {'func_name': func_name},
            cbor_dump({'args': args, 'kwargs': kwargs}))
        in_data = cbor2.loads(in_data_raw)
        return in_data["return"]
def server_exec():
    """Return a decorator that routes an async method to the server side.

    On an instance whose ``_is_server_`` is true, the wrapped coroutine runs
    locally. Otherwise the call is forwarded by method name through
    ``self._client_caller_``.
    """
    def decorator(method):
        @functools.wraps(method)
        async def dispatch(self, *args, **kwargs):
            if not self._is_server_:
                return await self._client_caller_(method.__name__, *args, **kwargs)
            return await method(self, *args, **kwargs)

        return dispatch

    return decorator
def path_to_dt(path):
    """Parse a ``dt_to_path``-style path, or a prefix of one, into a datetime.

    Both ``/`` and ``-`` act as separators and empty fields are skipped, so
    ``"2024-05-01"``, ``"2024-05-01/13"`` and ``"2024-05-01/13/07-30"`` are
    all accepted. The result has ``tzinfo=local_timezone``.
    """
    fields = path.replace('/', '-').split('-')
    numbers = [int(field) for field in fields if field]
    return datetime(*numbers, tzinfo=local_timezone)
def dt_to_path(dt):
    """Format ``dt`` as the relative snapshot path ``YYYY-mm-dd/HH/MM-SS``."""
    # datetime.__format__ with a non-empty spec delegates to strftime.
    return format(dt, "%Y-%m-%d/%H/%M-%S")
class DownloadServer(FuncCaller):
    """Remote-callable access to snapshots stored under ``data/realtime``.

    Snapshot directories are laid out as ``YYYY-mm-dd/HH/MM-SS``, the format
    produced by ``dt_to_path``.
    """

    @server_exec()
    async def list_realtime_data(self, date_from: datetime, date_to: datetime):
        """Return the timestamps of all snapshots with ``date_from <= dt <= date_to``.

        Directories are visited in sorted name order, which is chronological
        for the zero-padded names ``dt_to_path`` produces. Day and hour
        directories that cannot contain a match are skipped.
        """
        out = []
        d = "data/realtime"
        for d_Y_m_d in sorted(os.listdir(d)):
            day = path_to_dt(d_Y_m_d)
            # The upper bound is inclusive (see the leaf test below), so a
            # day or hour starting exactly at date_to must not be pruned.
            if day + timedelta(days=1) < date_from or day > date_to:
                continue
            for d_H in sorted(os.listdir(f"{d}/{d_Y_m_d}")):
                hour_path = f"{d_Y_m_d}/{d_H}"
                hour = path_to_dt(hour_path)
                if hour + timedelta(hours=1) < date_from or hour > date_to:
                    continue
                for d_M_S in sorted(os.listdir(f"{d}/{hour_path}")):
                    dt = path_to_dt(f"{hour_path}/{d_M_S}")
                    if date_from <= dt <= date_to:
                        out.append(dt)
        return out

    @server_exec()
    async def get_data(self, dt: datetime):
        """Return ``{filename: file bytes}`` for every entry in ``dt``'s snapshot directory."""
        path = "data/realtime/" + dt_to_path(dt)
        out = {}
        for filename in os.listdir(path):
            with open(path + '/' + filename, "rb") as f:
                out[filename] = f.read()
        return out