import asyncio
import json
import math
import pprint
import time
import datetime
import pathlib

import gtfs
from utils import *

# Profiling counter: accumulated wall-clock seconds (time.time() deltas) spent
# matching realtime positions onto GTFS shapes inside TripHistory.
_shape_matching_time = 0


async def unzip_parse(data):
    """Decompress (if needed) and JSON-decode one realtime snapshot.

    ``data`` may be a single-entry dict (its only value is used), bytes, or
    text. Bytes starting with the gzip magic are piped through the external
    ``gunzip`` tool; bytes starting with the zstd magic (checked afterwards,
    so also on gunzip's output) are piped through ``unzstd``. The result is
    passed to ``json.loads``.
    """
    if isinstance(data, dict):
        assert len(data) == 1
        data = [*data.values()][0]

    # (magic prefix, external decompressor), applied in this order.
    decompressors = (
        (b'\x1f\x8b', "gunzip"),
        (b'(\xb5/\xfd', "unzstd"),
    )
    for magic, tool in decompressors:
        if not (isinstance(data, bytes) and data.startswith(magic)):
            continue
        process = await asyncio.create_subprocess_exec(
            tool,
            stdout=asyncio.subprocess.PIPE,
            stdin=asyncio.subprocess.PIPE,
        )
        data, _ = await process.communicate(data)

    return json.loads(data)


def dist(a, b):
    """Return the approximate distance between two (lat, lon) points near Prague.

    Both points are asserted to lie in the box 40 < lat < 60, 5 < lon < 25
    (a sanity check that we are somewhere near Prague). The distance is then
    computed in a local plane where one degree of latitude is 111.2 and one
    degree of longitude is 71.50 (units of those constants, i.e. km).
    """
    for point in (a, b):
        assert 40 < point[0] < 60, point[0]
        assert 5 < point[1] < 25, point[1]
    d_lat = (a[0] - b[0]) * 111.2
    d_lon = (a[1] - b[1]) * 71.50
    return math.sqrt(d_lat ** 2 + d_lon ** 2)

# Magic constant for approximating the globe by a plane near Prague: the ratio
# of the length of one degree of longitude to one degree of latitude there
# (71.50 / 111.2, the same constants as in dist()). After multiplying a
# longitude by it, one unit of latitude and one unit of scaled longitude cover
# approximately the same distance, so (lat, lon * lon_muntiplicator) can be
# treated as planar coordinates. (Name spelling kept for existing callers.)
lon_muntiplicator = 71.50/111.2

def shape_indexer(shape, i: float):
    """Linearly interpolate ``shape`` at the fractional index ``i``.

    ``int(i)`` selects the segment start and the remainder is the weight of
    the following element. A remainder below 0.0001 is treated as exactly on
    the element (this also avoids reading past the last element). Elements
    must support ``* float`` and ``+`` (presumably numeric arrays — confirm).
    """
    base = int(i)
    frac = i - base
    if frac >= 0.0001:
        return shape[base] * (1 - frac) + shape[base + 1] * frac
    return shape[base]


async def get_data_of_trip(trip_id, date_from, date_to):
    """Collect this trip's feature from every realtime snapshot in a time range.

    Snapshots are listed with ``list_realtime_data(date_from, date_to)`` and
    each one is downloaded and decoded with ``unzip_parse``. The return value
    is a list of ``(snapshot, feature_or_None)`` pairs in listing order; when
    several features in one snapshot carry ``trip_id``, the last one is used.
    """
    comm = await get_communication()
    snapshots = await comm.list_realtime_data(date_from, date_to)
    result = []
    for snapshot in snapshots:
        print("GET", snapshot)
        parsed = await unzip_parse(await comm.get_data(snapshot))
        match = None
        for feature in parsed["features"]:
            if feature["properties"]["trip"]["gtfs"]["trip_id"] == trip_id:
                match = feature
        result.append((snapshot, match))
    return result

class Trip:
    """A GTFS trip identified by its ``trip_id`` and a date.

    TripPoint builds the date from the trip's start timestamp; GTFS lookups
    use it to select the feed (``gtfs.for_date``).
    """

    def __init__(self, trip_id, date):
        self.trip_id, self.date = trip_id, date

class AbstractHistoryPoint:
    """Marker base class shared by the realtime and preprocessed history-point types."""

class HistoryPoint(AbstractHistoryPoint):
    """One distinct vehicle position taken from a realtime GeoJSON feature.

    ``first_captured`` and ``last_captured`` both start at ``capture_time``.
    Attributes prefixed ``openapi_`` are copied verbatim from the feature's
    ``properties.last_position``. ``shape_point`` and
    ``shape_point_dist_traveled`` are not set here; they are assigned
    afterwards by the caller.
    """

    def __init__(self, json, capture_time, save_json=True):
        if save_json:
            # Raw feature; not stored at all when save_json is False.
            self.json = json
        position = json['properties']['last_position']
        self.state_position = position['state_position']
        self.openapi_shape_dist_traveled = position['shape_dist_traveled']
        # GeoJSON coordinates are ordered (lon, lat).
        self.lon, self.lat = json["geometry"]["coordinates"]
        self.number_of_captures = 0
        self.first_captured = self.last_captured = capture_time
        self.origin_timestamp = datetime.datetime.fromisoformat(position['origin_timestamp'])
        self.openapi_last_stop = position['last_stop']
        self.openapi_next_stop = position['next_stop']
        self.openapi_delay = position['delay']

class HistoryPointFromPreprocessed(AbstractHistoryPoint):
    """History point rebuilt from a flat, previously serialized record.

    Keys map one-to-one onto attributes; ``first_captured``,
    ``last_captured`` and ``origin_timestamp`` are ISO-8601 strings parsed
    with ``datetime.fromisoformat``.
    """

    def __init__(self, json):
        parse_ts = datetime.datetime.fromisoformat
        self.lat = json["lat"]
        self.lon = json["lon"]
        self.state_position = json["state_position"]
        self.first_captured = parse_ts(json["first_captured"])
        self.last_captured = parse_ts(json["last_captured"])
        self.origin_timestamp = parse_ts(json["origin_timestamp"])
        self.openapi_shape_dist_traveled = json["openapi_shape_dist_traveled"]
        self.openapi_next_stop = json["openapi_next_stop"]
        self.openapi_last_stop = json["openapi_last_stop"]
        self.openapi_delay = json["openapi_delay"]
        self.shape_point = json["shape_point"]
        self.shape_point_dist_traveled = json["shape_point_dist_traveled"]


class TripHistory:
    """Realtime history of a single trip, optionally matched onto its GTFS shape.

    ``history`` holds history points: ``HistoryPoint`` for realtime features
    added via ``add_history_point``/``load_history`` and
    ``HistoryPointFromPreprocessed`` for records added via
    ``add_preprocessed_data``. Await ``load_gtfs_shape`` before adding
    realtime points: shape matching reads ``self.gtfs_shape``, which is only
    set there.
    """

    def __init__(self, trip):
        self.trip = trip
        self.history = []
        # ``properties.trip`` object of the first realtime feature added.
        self.trip_json = None

    async def load_stops(self, data=None):
        """Load this trip's stops from the GTFS feed for ``trip.date`` into ``self.stops``."""
        self.stops = await gtfs.for_date(self.trip.date).get_stops_for_trip_id(self.trip.trip_id, data=data)

    async def load_gtfs_shape(self):
        """Load this trip's shape into ``self.gtfs_shape`` (``None`` if the feed has none)."""
        self.gtfs_shape = await gtfs.for_date(self.trip.date).get_shape_for_trip_id(self.trip.trip_id)
        if self.gtfs_shape is not None:
            assert len(self.gtfs_shape), f"Zero len shape for {self.trip.date} {self.trip.trip_id}"

    async def load_history(self, dt_from, dt_to):
        """Fetch the realtime snapshots between ``dt_from`` and ``dt_to`` and add each one."""
        tps = await get_data_of_trip(self.trip.trip_id, dt_from, dt_to)
        for dt, tp in tps:
            self.add_history_point(dt, tp)

    def add_history_point(self, dt, json, save_json=True):
        """Add one realtime capture of this trip.

        :param dt: capture time of the snapshot the feature came from.
        :param json: the trip's GeoJSON feature from that snapshot, or ``None``
            when the snapshot did not contain the trip (such captures are
            currently ignored).
        :param save_json: keep the raw feature on a newly created HistoryPoint.

        If the feature has the same coordinates, ``origin_timestamp`` and
        ``state_position`` as the last history point, only that point's
        ``last_captured`` is extended. Otherwise a new HistoryPoint is
        appended. It is matched onto the GTFS shape when its state is
        ``on_track``/``at_stop`` and a shape is loaded.
        """
        if json is None:
            # The trip is missing from this snapshot; nothing is recorded.
            # TODO: count such captures on the last point (e.g. "without_data").
            return

        if self.trip_json is None:
            self.trip_json = json["properties"]["trip"]
        # A changed ``properties.trip`` between captures is currently tolerated silently.

        lon, lat = json["geometry"]["coordinates"]
        last = self.history[-1] if self.history else None
        if (
                last is not None
                and lon == last.lon
                and lat == last.lat
                and datetime.datetime.fromisoformat(json["properties"]["last_position"]["origin_timestamp"]) == last.origin_timestamp
                and json["properties"]["last_position"]["state_position"] == last.state_position
        ):
            # The same position report seen again in a later snapshot.
            last.last_captured = dt
            return

        hp = HistoryPoint(json, dt, save_json=save_json)
        if hp.state_position in ['on_track', 'at_stop'] and self.gtfs_shape is not None:
            hp.shape_point = self._match_shape_point(hp)
            hp.shape_point_dist_traveled = shape_indexer(self.gtfs_shape, hp.shape_point)[2]
        else:
            hp.shape_point = hp.shape_point_dist_traveled = None
        self.history.append(hp)

    def _match_shape_point(self, hp):
        """Return the fractional shape index (segment ``i`` + offset ``a``) best matching ``hp``.

        A candidate's cost is ``dist`` from the vehicle to the nearest point of
        a shape segment. When a previous point exists, the cost also includes
        ``0.01 * |third component - that of the previous match|``. In that case
        the search walks forwards and then backwards from the previous match,
        stopping in each direction once that penalty alone reaches the best
        cost found. Without a previous point, every segment is scanned. Time
        spent is added to the module-level ``_shape_matching_time``.
        """
        global _shape_matching_time
        shape = self.gtfs_shape

        if len(shape) < 2:
            # A single-point shape has no segment to project onto; it is the match.
            return 0

        if self.history:
            last_shape_point_id = self.history[-1].shape_point
            if last_shape_point_id is None:
                last_shape_point_id = 0  # We are at the beginning of the track (last point was in a before-track state).
        else:
            last_shape_point_id = None  # We don't know where we are.

        last_shape_point = shape_indexer(shape, last_shape_point_id) if last_shape_point_id is not None else None

        dist_traveled_mutiplicator = 0.01

        def calc_key(i):
            """Return (cost, fractional index) of the point of segment i nearest to hp."""
            # Work in the locally planar (lat, lon * lon_muntiplicator) coordinates.
            x1, y1 = shape[i][0], shape[i][1] * lon_muntiplicator
            x2, y2 = shape[i+1][0], shape[i+1][1] * lon_muntiplicator
            x3, y3 = hp.lat, hp.lon * lon_muntiplicator
            dx, dy = x2-x1, y2-y1
            det = dx*dx + dy*dy
            if det < 0.000000000001:
                a = 0  # Degenerate (zero-length) segment.
            else:
                # Orthogonal projection onto the segment line, clamped to the segment.
                a = (dy*(y3-y1)+dx*(x3-x1))/det
                a = min(max(a, 0), 1)
            near_pt = shape_indexer(shape, i+a)
            penalty = dist_traveled_mutiplicator*abs(near_pt[2] - last_shape_point[2]) if last_shape_point is not None else 0
            return dist((hp.lat, hp.lon), (near_pt[0], near_pt[1])) + penalty, i+a

        _shape_matching_time -= time.time()
        if last_shape_point is None:
            best = min(calc_key(i) for i in range(len(shape)-1))[1]
        else:
            best, best_key = None, 10e9

            # Forward from the previous match; the penalty grows with distance along the shape.
            i = int(last_shape_point_id)
            while i < len(shape)-1 and best_key >= dist_traveled_mutiplicator*abs(shape[i][2] - last_shape_point[2]):
                k, v = calc_key(i)
                if k < best_key:
                    best_key, best = k, v
                i += 1

            # Backward; segment i's point closest to the previous match is its end i+1.
            i = int(last_shape_point_id) - 1
            while i >= 0 and best_key >= dist_traveled_mutiplicator*abs(shape[i+1][2] - last_shape_point[2]):
                k, v = calc_key(i)
                if k < best_key:
                    best_key, best = k, v
                i -= 1
        _shape_matching_time += time.time()
        return best

    def add_preprocessed_data(self, data, dt_from=dt_minf, dt_to=dt_pinf):
        """Add preprocessed history points overlapping the interval [dt_from, dt_to).

        A point is kept when ``last_captured >= dt_from`` and
        ``first_captured < dt_to``. Afterwards the whole history is re-sorted
        by ``first_captured``.
        """
        for x in data["history"]:
            hp = HistoryPointFromPreprocessed(x)
            if dt_from <= hp.last_captured and hp.first_captured < dt_to:
                self.history.append(hp)
        self.history.sort(key=lambda hp: hp.first_captured)






class TripPoint:
    """A single realtime GeoJSON feature together with the time it was captured.

    The trip is identified by its GTFS ``trip_id`` and the date of the
    feature's ``trip.start_timestamp``. Attributes prefixed ``openapi_`` are
    copied verbatim from ``properties.last_position``.
    """

    def __init__(self, json, capture_time):
        self.json = json
        props = json['properties']
        position = props['last_position']
        self.origin_timestamp = datetime.datetime.fromisoformat(position['origin_timestamp'])
        self.start_timestamp = datetime.datetime.fromisoformat(props['trip']['start_timestamp'])
        self.trip = Trip(props["trip"]["gtfs"]["trip_id"], self.start_timestamp.date())
        self.capture_time = capture_time
        self.state_position = position['state_position']
        self.openapi_shape_dist_traveled = position['shape_dist_traveled']
        self.openapi_last_stop = position['last_stop']
        self.openapi_next_stop = position['next_stop']
        self.openapi_delay = position['delay']

def write_data(dt, data):
    """Store one realtime snapshot under ``data/realtime/YYYY-MM-DD/HH/``.

    ``data`` is either zstd-compressed bytes, stored as the file
    ``MM-SS.json.zst``, or a dict mapping plain file names to bytes, stored
    as files inside the directory ``MM-SS``. The content is first written
    under ``tmp/`` (which must already exist, on the same filesystem as
    ``data/``) and then moved into place with one ``os.rename``. Missing
    parent directories of the target are created.

    :raises AssertionError: on non-zstd bytes or an unsafe/non-printable
        file name; nothing is written in that case.
    """
    import os
    import random

    def only_printable_ascii(s):
        # Printable ASCII only: rejects control characters and DEL (0x7f).
        return s.isascii() and s.isprintable()

    is_dir = isinstance(data, dict)
    tmp_file = dt.strftime(f"tmp/realtime-%Y-%m-%d--%H-%M-%S-{random.randint(0,100000)}")

    if is_dir:
        # Validate every name before touching the filesystem so that a bad
        # name cannot leave a half-written temporary directory behind.
        for fname in data:
            assert '/' not in fname and fname not in ('', '.', '..') and only_printable_ascii(fname), fname
        os.mkdir(tmp_file)
        for fname, d in data.items():
            with open(tmp_file + '/' + fname, "xb") as f:
                f.write(d)
    else:
        # Check before creating the file so a rejected payload leaves no empty temp file.
        assert isinstance(data, bytes) and data.startswith(b'(\xb5/\xfd')
        with open(tmp_file, "xb") as f:
            f.write(data)

    hour_dir = "/".join(["data/realtime", dt.strftime('%Y-%m-%d'), dt.strftime('%H')])
    os.makedirs(hour_dir, exist_ok=True)
    target_name = dt.strftime("%M-%S") if is_dir else dt.strftime("%M-%S.json.zst")
    os.rename(tmp_file, hour_dir + "/" + target_name)