Skip to content
Snippets Groups Projects
Select Git revision
  • b7f1a824a69e8e871507da4b8d88059a462b81d3
  • master default
2 results

gtfs.py

Blame
  • gtfs.py 6.33 KiB
    import datetime
    import csv
    from dataclasses import dataclass
    import numpy as np
    import os, sys
    from utils import *
    import weakref
    
    # Timezone of the machine running this code, captured once when the module
    # is imported. GtfsDay.parse_time attaches it to every parsed stop time.
    # NOTE(review): astimezone() yields the UTC offset valid at import time, so
    # dates on the other side of a DST change presumably get the wrong offset
    # - confirm that this is acceptable.
    local_timezone = datetime.datetime.now(datetime.timezone.utc).astimezone().tzinfo

    # Module-wide counters updated by GtfsDay.get_shape_for_trip_id: how often a
    # shape had to be fetched through get_file versus served from the shape cache.
    _shape_loading_from_disk_count = 0
    _shape_loading_from_cache_count = 0
    
    @dataclass
    class Route:
        """One row of routes.txt, reduced to the fields this module uses."""
        # Value of the "route_id" column; also the key of GtfsDay.routes.
        id: str
        # Value of the "route_short_name" column.
        name: str
    
    @dataclass
    class Trip:
        """One row of trips.txt, resolved against the routes and the calendar.

        Attributes:
            route: The Route object the trip belongs to.
            id: The "trip_id" column, kept as the string read from the CSV file.
            direction: The "direction_id" column converted to int.
            shape_id: The "shape_id" column; identifies the shape file to load.
            services_today: Whether the trip's service runs on the loaded day.
        """
        route: Route
        # GtfsDay.load_trips stores the raw CSV string here (and uses the same
        # string as dictionary key), so the field is a str, not an int.
        id: str
        direction: int
        shape_id: str
        services_today: bool
    
    
    @dataclass
    class Stop:
        """One row of stops.txt.

        Attributes:
            id: The "stop_id" column, kept as the string read from the CSV file.
            name: The "stop_name" column.
            lat: The "stop_lat" column converted to float.
            lon: The "stop_lon" column converted to float.
        """
        # GtfsDay.load_stops stores the raw CSV string here (and uses the same
        # string as dictionary key), so the field is a str, not an int.
        id: str
        name: str
        lat: float
        lon: float
    
    
    @dataclass
    class ShapePoint:
        """A single point of a trip shape.

        Attributes:
            lat: Latitude of the point.
            lon: Longitude of the point.
            dist_traveled: Distance travelled along the shape up to this point.
        """
        # Was annotated as ``str``; a latitude is a number just like ``lon``.
        lat: float
        lon: float
        dist_traveled: float
    
    @dataclass
    class TripStop:
        """One stop on a trip, i.e. one row of stop_times.txt.

        Attributes:
            stop: The Stop the vehicle calls at.
            arrival_time: Timezone-aware arrival time produced by
                GtfsDay.parse_time.
            departure_time: Timezone-aware departure time produced by
                GtfsDay.parse_time.
            shape_dist_traveled: The "shape_dist_traveled" column as float.
        """
        stop: Stop
        # Both times hold the datetime returned by GtfsDay.parse_time, not a float.
        arrival_time: datetime.datetime
        departure_time: datetime.datetime
        shape_dist_traveled: float
    
    
    class GtfsDay:
        """Lazily loaded GTFS timetable data for one calendar day.

        Files are fetched through ``data_getter``, which is awaited as
        ``data_getter(datetime, file_name)`` and must return the file content as
        UTF-8 encoded bytes. Parsed tables are stored on the instance and the
        ``load_*`` coroutines return immediately once their table is present.
        """

        # calendar.txt column names, indexed by ``datetime.date.weekday()``.
        _WEEKDAY_COLUMNS = (
            'monday', 'tuesday', 'wednesday', 'thursday',
            'friday', 'saturday', 'sunday',
        )

        def __init__(self, date, data_getter):
            """Create an empty day; nothing is fetched until a loader is awaited.

            Args:
                date: The ``datetime.date`` this object describes.
                data_getter: Awaitable callable ``(datetime, name) -> bytes``.
            """
            self.date = date
            self.data_getter = data_getter
            self.stops = None  # stop_id -> Stop
            self.trips = None  # trip_id -> Trip
            self.trips_by_routes = None  # route_id -> list of Trip
            self.routes = None  # route_id -> Route
            self.services_today = None  # service_id -> bool
            self.stops_for_trip = None  # trip_id -> list of TripStop
            # Weak values: a shape array stays cached only while a caller holds it.
            self._shape_cache = weakref.WeakValueDictionary()
            self._file_cache = {}  # file name -> parsed rows (only when cache=True)

        @staticmethod
        def _parse_csv(raw):
            """Decode UTF-8 bytes and parse them as CSV into a list of row dicts."""
            # "utf-8-sig" drops the optional byte-order mark a GTFS file may start
            # with; otherwise the first column name would be prefixed with U+FEFF
            # and every lookup of that column would raise KeyError.
            return list(csv.DictReader(raw.decode("utf-8-sig").split("\n")))

        def _trip_stop_from_row(self, row):
            """Build a TripStop from one stop_times row; ``self.stops`` must be loaded."""
            return TripStop(
                self.stops[row["stop_id"]],
                self.parse_time(row["arrival_time"]),
                self.parse_time(row["departure_time"]),
                float(row["shape_dist_traveled"]),
            )

        async def get_file(self, name, cache=False):
            """Fetch and parse one CSV file of this day.

            Args:
                name: File name handed to ``data_getter`` (e.g. ``"stops.txt"``).
                cache: When true, keep the parsed rows so that later calls for the
                    same name return them without fetching again.

            Returns:
                A list with one dict per CSV row, keyed by column name.
            """
            if name in self._file_cache:
                return self._file_cache[name]
            # HACK: Date is not CBOR supported type, so midnight of the day is sent.
            when = datetime.datetime.combine(self.date, datetime.time())
            raw = await self.data_getter(when, name)
            rows = self._parse_csv(raw)
            if cache:
                self._file_cache[name] = rows
            return rows

        async def load_calendar(self):
            """Fill ``self.services_today`` with ``service_id -> runs on this day``.

            The weekly pattern and validity range come from calendar.txt; rows of
            calendar_dates.txt for this exact date then override the result.

            Raises:
                AssertionError: calendar.txt lists the same service_id twice.
            """
            if self.services_today is not None:
                return
            # Dates are YYYYMMDD strings, so plain string comparison orders them.
            date = self.date.strftime('%Y%m%d')
            weekday_column = self._WEEKDAY_COLUMNS[self.date.weekday()]

            rows = await self.get_file("calendar.txt")
            assert len(rows) == len({x['service_id'] for x in rows}), "Duplicit service_id"
            services_today = {
                x['service_id']: (
                    x[weekday_column] == '1'
                    and x['start_date'] <= date <= x['end_date']
                )
                for x in rows
            }

            exceptions = await self.get_file("calendar_dates.txt")
            for x in exceptions:
                if x['date'] == date:
                    # CSV values are strings: "1" adds the service for the date,
                    # anything else ("2") removes it. Comparing with the integer 1
                    # was always false and silently dropped every added service.
                    services_today[x['service_id']] = x['exception_type'] == '1'
            self.services_today = services_today

        async def load_trips(self):
            """Fill ``self.routes``, ``self.trips`` and ``self.trips_by_routes``."""
            await self.load_calendar()
            if self.trips is not None:
                return

            route_rows = await self.get_file("routes.txt")
            self.routes = {
                x["route_id"]: Route(x["route_id"], x["route_short_name"])
                for x in route_rows
            }

            trip_rows = await self.get_file("trips.txt")
            self.trips = {
                x["trip_id"]: Trip(
                    route=self.routes[x["route_id"]],
                    id=x["trip_id"],
                    direction=int(x["direction_id"]),
                    shape_id=x["shape_id"],
                    # A service may exist only in calendar_dates.txt; without an
                    # entry for this date it is unknown here and does not run.
                    services_today=self.services_today.get(x["service_id"], False),
                )
                for x in trip_rows
            }

            trips_by_routes = {}
            for trip in self.trips.values():
                trips_by_routes.setdefault(trip.route.id, []).append(trip)
            self.trips_by_routes = trips_by_routes

        async def load_stops(self):
            """Fill ``self.stops`` with ``stop_id -> Stop`` from stops.txt."""
            if self.stops is not None:
                return
            rows = await self.get_file("stops.txt")
            # Another task may have finished loading while this one was waiting.
            if self.stops is not None:
                return
            self.stops = {
                x["stop_id"]: Stop(
                    x["stop_id"],
                    x["stop_name"],
                    float(x["stop_lat"]),
                    float(x["stop_lon"]),
                )
                for x in rows
            }

        async def get_shape_for_trip_id(self, trip_id):
            """Return the shape of a trip as a numpy array.

            Each row of the array is ``[lat, lon, shape_dist_traveled]``. Arrays
            are shared between trips with the same shape_id for as long as some
            caller keeps a reference to them.

            Returns:
                The array, or None (after logging) when ``trip_id`` is unknown.
            """
            global _shape_loading_from_disk_count, _shape_loading_from_cache_count
            await self.load_trips()
            if trip_id not in self.trips:
                eprint(f"get_shape_for_trip_id failed: no such trip_id {trip_id}")
                return None
            shape_id = self.trips[trip_id].shape_id
            shape = self._shape_cache.get(shape_id)
            if shape is None:
                _shape_loading_from_disk_count += 1
                rows = await self.get_file("shape_by_id/" + shape_id)
                shape = np.array([
                    [
                        float(x["shape_pt_lat"]),
                        float(x["shape_pt_lon"]),
                        float(x["shape_dist_traveled"]),
                    ]
                    for x in rows
                ])
                self._shape_cache[shape_id] = shape
            else:
                _shape_loading_from_cache_count += 1
            return shape

        def parse_time(self, val):
            """Convert a GTFS ``H:MM:SS`` time of this day to an aware datetime.

            Hours above 23 are allowed and roll over into the following day(s).
            """
            h, m, s = map(int, val.split(":"))
            midnight = datetime.datetime.combine(self.date, datetime.time(0, 0, 0), local_timezone)
            # Adding a timedelta instead of building a time() is the hack for h > 23.
            return midnight + datetime.timedelta(hours=h, minutes=m, seconds=s)

        async def get_stops_for_trip_id(self, trip_id, data=None):
            """Return the list of TripStop objects of one trip.

            Args:
                trip_id: Identifier of the trip.
                data: Optional raw bytes in stop_times.txt format to parse instead
                    of fetching the whole stop_times.txt file.

            Returns:
                The stops of the trip in file order; an empty list (after logging)
                when all trips were preloaded and ``trip_id`` is not among them.
            """
            await self.load_trips()
            await self.load_stops()
            if self.stops_for_trip is not None:
                if trip_id not in self.stops_for_trip:
                    eprint(f"get_stops_for_trip_id failed: no such trip {trip_id} at {self.date}")
                    return []
                return self.stops_for_trip[trip_id]
            if data is not None:
                rows = self._parse_csv(data)
            else:
                eprint("LOADING STOPS FOR", trip_id, self.date, "(without cache)")
                rows = await self.get_file("stop_times.txt", cache=False)
            return [self._trip_stop_from_row(x) for x in rows if x["trip_id"] == trip_id]

        async def load_stops_for_all_trips(self):
            """Fill ``self.stops_for_trip`` with ``trip_id -> [TripStop, ...]``."""
            if self.stops_for_trip is not None:
                return
            await self.load_trips()
            await self.load_stops()
            rows = await self.get_file("stop_times.txt")
            stops_for_trip = {}
            for x in rows:
                stops_for_trip.setdefault(x["trip_id"], []).append(self._trip_stop_from_row(x))

            self.stops_for_trip = stops_for_trip
    
    
    
    
    
    
    
    # Getter handed to new GtfsDay objects when for_date() is called without one.
    # It is None here; presumably the application assigns the real getter at
    # start-up - confirm against the callers.
    default_data_getter = None

    # GtfsDay objects created by for_date(), keyed by their datetime.date.
    for_date_cache = {}
    
    def for_date(date, data_getter=None):
        """Return the shared GtfsDay for ``date``, creating it on first use.

        Args:
            date: A ``datetime.date``; a ``datetime.datetime`` is reduced to its
                date part first.
            data_getter: Getter for a newly created GtfsDay; the module-level
                ``default_data_getter`` is used when it is falsy. It is ignored
                when a GtfsDay for the date already exists.

        Returns:
            The GtfsDay stored in ``for_date_cache`` for that day.
        """
        day = date.date() if isinstance(date, datetime.datetime) else date
        try:
            return for_date_cache[day]
        except KeyError:
            pass
        created = GtfsDay(day, data_getter or default_data_getter)
        for_date_cache[day] = created
        return created