commit ab38966272114c858f459328e694354f61dee2ea
parent 6da36f7e4f10e281122a50db81de2b28bffad452
Author: Steve Gattuso <steve@stevegattuso.me>
Date: Mon, 6 Nov 2023 11:08:51 +0100
properly calculate volume
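
detect_status() previously built a timezone-naive hourly range and
compared it against the raw rollup index; it now builds the expected
hours for the month in America/New_York and converts them to UTC, while
fetch_hourly_volume_rollup() returns a UTC-localized DatetimeIndex, so
the expected and stored hours can be compared directly. A new
is_complete() check lets the populate command skip months that are
already rolled up, and the per-module get_logger() helper is replaced
by a single logger configured once in forerad.utils.

As an illustrative sketch (not part of this commit; the month and the
count are examples only, TZ_NYC/TZ_UTC mirror forerad.utils), the
expected-hours calculation amounts to:

    import pandas as pd
    import pytz

    TZ_NYC = pytz.timezone('America/New_York')
    TZ_UTC = pytz.timezone('UTC')

    # Hours of November 2013 in local time, converted to UTC. Building the
    # range in America/New_York first means the fall-back transition on
    # Nov 3 is counted: 721 hours rather than a naive 30 * 24 = 720.
    hours = pd.date_range(
        start="2013-11-01",
        end="2013-12-01",
        freq="1H",
        tz=TZ_NYC,
        inclusive="left",
    )
    expected = set(hours.tz_convert(TZ_UTC))
    assert len(expected) == 721
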
Diffstat:
5 files changed, 41 insertions(+), 37 deletions(-)
diff --git a/bin/hourly-volume-rollup b/bin/hourly-volume-rollup
@@ -11,7 +11,6 @@ import forerad.persistence as persistence
import forerad.utils as utils
store = persistence.SQLiteStore()
-log = utils.get_logger()
# The day that Citibike began publishing data on
ORIGIN_DATE = datetime.date(2013, 6, 1)
@@ -23,19 +22,25 @@ def detect_status(year: int, month: int) -> tuple[set[pd.Timestamp], set[pd.Time
"""
first_day = datetime.date(year, month, 1)
next_month = utils.next_month(first_day.year, first_day.month)
- expected_members = set(pd.date_range(
+ range = pd.date_range(
start=first_day,
end=next_month,
freq="1H",
+ tz=utils.TZ_NYC,
inclusive="left"
- ))
+ )
+ expected = set(range.tz_convert(utils.TZ_UTC))
rollup = store.fetch_hourly_volume_rollup(first_day, next_month)
- actual_members = set(rollup.index.values)
+ actual = set(rollup.index)
+
+ missing = expected - actual
+ return expected, actual, missing
- missing = expected_members - actual_members
- return expected_members, actual_members, missing
+def is_complete(archive: historical.HistoricalTripArchive) -> bool:
+ _, _, missing = detect_status(archive.date.year, archive.date.month)
+ return len(missing) == 0
def derive_rollup(a: historical.HistoricalTripArchive):
df = a.fetch_df()[['started_at']].reset_index()
@@ -45,7 +50,7 @@ def derive_rollup(a: historical.HistoricalTripArchive):
.reset_index()
store.write_hourly_volume_rollup(df)
- log.info(f"Wrote {len(df)} members to table")
+ utils.logger.info(f"Wrote {len(df)} members to table")
def main__populate(month_str):
archives = historical.HistoricalTripArchive.list_cached()
@@ -53,6 +58,10 @@ def main__populate(month_str):
year, month = utils.parse_month_str(month_str)
archives = [a for a in archives if a.date.year == year and a.date.month == month]
+ # Filter out completed rollups
+ archives = [a for a in archives if not is_complete(a)]
+ utils.logger.info(f'Rolling up {len(archives)} months of data')
+
[derive_rollup(a) for a in archives]
diff --git a/bin/scraper b/bin/scraper
@@ -11,7 +11,6 @@ import forerad.utils as utils
import forerad.scrapers.historical as scrape_historical
-log = utils.get_logger()
store = persistence.SQLiteStore()
@@ -21,16 +20,16 @@ def main__fetch(args: argparse.Namespace):
if args.month is not None:
archives = [a for a in archives if a.date.strftime('%Y-%m') == args.month]
if args.month and len(archives) != 1:
- log.error(f'Month filter "{args.month}" yielded {len(archives)} results. Aborting!')
+ utils.logger.error(f'Month filter "{args.month}" yielded {len(archives)} results. Aborting!')
sys.exit(1)
for archive in archives:
month_str = archive.date.strftime("%Y-%m")
if archive.is_downloaded:
- log.info(f'{month_str} is already persisted, skipping.')
+ utils.logger.info(f'{month_str} is already persisted, skipping.')
continue
- log.info(f'Fetching and storing {month_str}')
+ utils.logger.info(f'Fetching and storing {month_str}')
archive.fetch_df()
diff --git a/forerad/persistence.py b/forerad/persistence.py
@@ -18,28 +18,30 @@ class SQLiteStore():
be used to query the started_at and ended_at columns.
"""
dt = datetime.datetime.combine(date, datetime.datetime.min.time())
- unix_epoch_str = utils.TZ_NYC.localize(dt)\
- .astimezone(utils.TZ_UTC)\
- .strftime("%s")
+ localized = utils.TZ_NYC.localize(dt)\
+ .astimezone(utils.TZ_UTC)
- return int(unix_epoch_str)
+ return int(localized.strftime('%s'))
def fetch_hourly_volume_rollup(self, start_date: datetime.date, end_date: datetime.date) -> pd.DataFrame:
query = """
SELECT
- date(datetime, 'unixepoch') AS datetime,
+ datetime(datetime, 'unixepoch') AS datetime,
trip_count
FROM hourly_volume_rollup
WHERE
datetime >= ? AND
- datetime < ?
+ datetime <= ?
"""
start_dt = self.__localize_date(start_date)
end_dt = self.__localize_date(end_date)
results = pd.read_sql(query, self.db, params=(start_dt, end_dt))
+ results['datetime'] = pd.to_datetime(results['datetime'])\
+ .dt.tz_localize(utils.TZ_UTC)
+
return results.set_index('datetime')
@@ -58,6 +60,5 @@ class SQLiteStore():
"""
cur = self.db.cursor()
- print(values)
cur.executemany(query, values.to_dict('records'))
self.db.commit()
diff --git a/forerad/scrapers/historical.py b/forerad/scrapers/historical.py
@@ -11,13 +11,12 @@ from botocore import UNSIGNED
import forerad.utils as utils
-log = utils.get_logger()
ARCHIVE_REGEX = re.compile("^([0-9]{4})([0-9]{2})-citibike-tripdata((.zip$)|(.csv.zip$))")
CACHE_DIR = pathlib.Path(__file__).parent.parent.parent / pathlib.Path('.forerad-cache')
TRIP_BUCKET = 'tripdata'
if not CACHE_DIR.exists():
CACHE_DIR.mkdir()
- log.debug('Initializing .cache dir')
+ utils.logger.debug('Initializing .cache dir')
def __get_s3_client():
config = bclient.Config(signature_version=UNSIGNED)
@@ -40,7 +39,7 @@ class HistoricalTripArchive():
"""
match = ARCHIVE_REGEX.match(obj['Key'])
if match is None:
- log.error(f"Skipping object {obj['Key']}")
+ utils.logger.error(f"Skipping object {obj['Key']}")
return None
groups = match.groups()
@@ -79,13 +78,13 @@ class HistoricalTripArchive():
return None
with open(archive_path, 'rb') as f:
- log.info(f"Loading {self.object_key} from cache")
+ utils.logger.info(f"Loading {self.object_key} from cache")
return io.BytesIO(f.read())
def __store_blob(self, blob: io.BytesIO):
archive_path = CACHE_DIR / self.object_key
with open(archive_path, 'wb') as f:
- log.info(f"Storing {self.object_key} in cache")
+ utils.logger.info(f"Storing {self.object_key} in cache")
f.write(blob.getbuffer())
blob.seek(0)
@@ -97,7 +96,7 @@ class HistoricalTripArchive():
"""
blob = self.__fetch_cached_blob()
if blob is None:
- log.info(f"Fetching {self.csv_name} from S3")
+ utils.logger.info(f"Fetching {self.csv_name} from S3")
s3 = __get_s3_client()
resp = s3.get_object(Bucket=TRIP_BUCKET, Key=self.object_key)
blob = io.BytesIO(resp['Body'].read())
@@ -109,8 +108,8 @@ class HistoricalTripArchive():
csv_name = self.csv_name
if csv_name not in file_list and len(file_list) != 1:
- log.error(f"Could not extract {self.csv_name}:")
- log.error(file_list)
+ utils.logger.error(f"Could not extract {self.csv_name}:")
+ utils.logger.error(file_list)
raise Exception(f"Could not extract {self.csv_name}")
if csv_name not in file_list:
diff --git a/forerad/utils.py b/forerad/utils.py
@@ -6,6 +6,13 @@ import pytz
TZ_NYC = pytz.timezone('America/New_York')
TZ_UTC = pytz.timezone('UTC')
+logger = logging.getLogger('forerad')
+stream = logging.StreamHandler()
+fmt = logging.Formatter("%(asctime)s [%(levelname)s]: %(message)s")
+stream.setFormatter(fmt)
+logger.addHandler(stream)
+logger.setLevel(logging.INFO)
+
def parse_month_str(month_str: str) -> tuple[int, int]:
"""
Parses a string formatted YYYY-MM into a tuple of (year, month). Used for
@@ -24,14 +31,3 @@ def next_month(year, month) -> datetime.date:
"""
date = datetime.date(year, month, 1)
return (date + datetime.timedelta(days=32)).replace(day=1)
-
-def get_logger():
- logger = logging.getLogger('forerad')
-
- stream = logging.StreamHandler()
- fmt = logging.Formatter("%(asctime)s [%(levelname)s]: %(message)s")
- stream.setFormatter(fmt)
-
- logger.addHandler(stream)
- logger.setLevel(logging.INFO)
- return logger
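
Aside on the logging change (illustrative sketch, not part of the diff
above): the removed get_logger() helper added a fresh StreamHandler to
the shared 'forerad' logger on every call, so each module that called
it at import time stacked another handler and every record was emitted
once per handler. Configuring the logger a single time in forerad.utils
avoids that:

    import logging

    # Old behaviour, sketched: two calls leave two handlers on the same
    # singleton logger, so one .info() call prints two lines.
    def get_logger():
        logger = logging.getLogger('forerad')
        logger.addHandler(logging.StreamHandler())
        logger.setLevel(logging.INFO)
        return logger

    a = get_logger()
    b = get_logger()          # same logger object, second handler
    a.info("rolled up")       # emitted twice, once per handler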