From 684a6be9820deaa0f989b1c9ed48fac9b7795395 Mon Sep 17 00:00:00 2001 From: Jakob Cornell Date: Fri, 3 Jun 2022 14:42:57 -0500 Subject: [PATCH] MVP features and fixes --- docs/sample_config.ini | 47 ++++++++++--- src/strikebot/__init__.py | 102 +++++++++++++++++----------- src/strikebot/__main__.py | 17 +++-- src/strikebot/db.py | 16 ++++- src/strikebot/live_ws.py | 70 ++++++++++++------- src/strikebot/queue.py | 27 ++++++-- src/strikebot/reddit_api.py | 129 +++++++++++++++++++++++++----------- src/strikebot/tests.py | 14 ++-- src/strikebot/updates.py | 44 +++++++----- 9 files changed, 322 insertions(+), 144 deletions(-) diff --git a/docs/sample_config.ini b/docs/sample_config.ini index 66fc49e..28e510e 100644 --- a/docs/sample_config.ini +++ b/docs/sample_config.ini @@ -1,27 +1,52 @@ +# see Python `configparser' docs for precise file syntax + [config] + +######## basic operating parameters + +# space-separated authorization IDs for Reddit API token lookup +auth IDs = 13 15 + bot user = count_better +enforcing = true +thread ID = abc123 -# (seconds) maximum allowable spread in WebSocket message arrival times -WS parity time = 1.0 -# goal size of WebSocket connection pool -WS pool size = 5 +######## API pool error handling + +# (seconds) for error backoff, pause the API pool for this much time +API pool error delay = 5.5 + +# (seconds) pause the pool when the Reddit API returns two errors within any time span of this length +API pool error window = 5.5 + +# terminate if the Reddit API request queue exceeds this size +request queue limit = 100 -# (seconds) time after WebSocket handshake completion during which missed updates are excused -WS warmup time = 0.3 + +######## live thread update handling # (seconds) maximum time to hold updates for reordering reorder buffer time = 0.25 -enforcing = true - # (seconds) minimum time to retain updates to enable future resyncs update retention time = 120.0 -thread ID = abc123 -# space-separated authorization IDs for Reddit API token lookup -auth IDs = 13 15 +######## WebSocket pool + +# (seconds) maximum allowable spread in WebSocket message arrival times +WS parity time = 1.0 + +# goal size of WebSocket connection pool +WS pool size = 5 + +# (seconds) after receiving no messages for this much time, attempt to replace the connection if the pool isn't depleted +WS silent limit = 30 + +# (seconds) time after WebSocket handshake completion during which missed updates are excused +WS warmup time = 0.3 + # Postgres database configuration; same options as in a connect string [db connect params] diff --git a/src/strikebot/__init__.py b/src/strikebot/__init__.py index 63fff59..10b4013 100644 --- a/src/strikebot/__init__.py +++ b/src/strikebot/__init__.py @@ -1,14 +1,14 @@ from contextlib import nullcontext as nullcontext from dataclasses import dataclass from functools import total_ordering -from typing import Optional, Set +from itertools import islice +from typing import Optional from uuid import UUID import bisect import datetime as dt import importlib.metadata import itertools import logging -import re from trio import CancelScope, EndOfChannel import trio @@ -17,6 +17,7 @@ from strikebot.updates import Command, parse_update __version__ = importlib.metadata.version(__package__) +QUIET = True # TODO remove @dataclass @@ -51,11 +52,14 @@ class _BufferedUpdate: class _TimelineUpdate: update: _Update accepted: bool - bad_count: bool def __lt__(self, other): return self.update.ts < other.update.ts + def rejected(self) -> bool: + """Whether the update should be 
stricken.""" + return self.update.count_attempt and not self.accepted + def _parse_rfc_4122_ts(ts: int) -> dt.datetime: epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc) @@ -95,16 +99,21 @@ async def count_tracker_impl( buffer_ = [] timeline = [] last_valid: Optional[_Update] = None - pending_strikes: Set[str] = set() # names of updates to mark stricken on arrival + pending_strikes: set[str] = set() # names of updates to mark stricken on arrival def handle_update(update): - nonlocal delete_start, last_valid + nonlocal last_valid - tu = _TimelineUpdate(update, accepted = False, bad_count = False) + tu = _TimelineUpdate(update, accepted = False) pos = bisect.bisect(timeline, tu) if pos != len(timeline): logger.warning(f"long transpo: {update.name}") + if pos > 0 and timeline[pos - 1].update.id == update.id: + # The pool sent this update message multiple times. This could be because the message reached parity and + # its bucket was dropped just before a new connection was opened and sent its own copy, but connections also + # seem to send duplicate messages. + return pred = next( (timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted), @@ -113,6 +122,8 @@ async def count_tracker_impl( if update.command is not Command.RESET and update.number is not None and last_valid and not pred: logger.warning("ignoring {update.name}: no valid prior count on record") else: + timeline.insert(pos, tu) + assert len({ctu.update.id for ctu in timeline}) == len(timeline) # TODO remove tu.accepted = ( update.command is Command.RESET or ( @@ -120,24 +131,32 @@ async def count_tracker_impl( and (pred is None or pred.number is None or update.can_follow(pred)) ) ) - timeline.insert(pos, tu) + logger.debug(f"accepted: {tu.accepted}") + logger.debug(" pred: {}".format(pred and (pred.id, pred.number, pred.author))) if tu.accepted: - # resync subsequent updates already processed + # resync subsequent updates newly_valid = [] newly_invalid = [] - last_valid = update - for scan_tu in timeline[pos + 1:]: + resync_last_valid = update + for scan_tu in islice(timeline, pos + 1, None): if scan_tu.update.command is Command.RESET: - last_valid = scan_tu.update + break elif scan_tu.update.number is not None: accept = last_valid.number is None or scan_tu.update.can_follow(last_valid) - if accept and not scan_tu.accepted: + if accept and scan_tu.accepted: + # resync would have no effect past this point + break + elif accept: newly_valid.append(scan_tu) - elif not accept and scan_tu.accepted: + resync_last_valid = scan_tu.update + elif scan_tu.accepted: newly_invalid.append(scan_tu) scan_tu.accepted = accept - if accept: - last_valid = scan_tu.update + + if last_valid: + last_valid = max(last_valid, resync_last_valid, key = lambda u: u.ts) + else: + last_valid = resync_last_valid parts = [] if newly_valid: @@ -145,8 +164,6 @@ async def count_tracker_impl( "The following counts are valid:\n\n" + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in newly_valid) ) - for tu in newly_valid: - tu.bad_count = False unstrikable = [tu for tu in newly_invalid if tu.update.stricken] if unstrikable: @@ -154,30 +171,29 @@ async def count_tracker_impl( "The following counts are invalid:\n\n" + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in unstrikable) ) - for tu in unstrikable: - tu.bad_count = True if update.stricken: logger.info(f"bad strike of {update.name}") parts.append(_format_bad_strike_alert(update, thread_id)) + if parts: + parts.append(_format_curr_count(last_valid)) + if 
+						api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts)))
 
 				for invalid_tu in newly_invalid:
 					if not invalid_tu.update.stricken:
 						if enforcing:
 							api_pool.enqueue_request(StrikeRequest(thread_id, invalid_tu.update.name, invalid_tu.update.ts))
-						invalid_tu.bad_count = True
 						invalid_tu.update.stricken = True
 
-		elif update.number is not None or update.count_attempt:
+		elif update.count_attempt:
 			if enforcing:
 				api_pool.enqueue_request(StrikeRequest(thread_id, update.name, update.ts))
-			tu.bad_count = True
 			update.stricken = True
 
 		if update.command is Command.REPORT:
-			api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
+			if not QUIET:
+				api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
 
 	with message_rx:
 		while True:
@@ -200,7 +216,10 @@ async def count_tracker_impl(
 				stricken = payload_data["name"] in pending_strikes
 				pending_strikes.discard(payload_data["name"])
 				if payload_data["author"] != bot_user:
-					next_up = last_valid.number if last_valid else None
+					if last_valid and last_valid.number is not None:
+						next_up = last_valid.number + 1
+					else:
+						next_up = None
 					pu = parse_update(payload_data, next_up, bot_user)
 					rfc_ts = UUID(payload_data["id"]).time
 					update = _Update(
@@ -217,7 +236,6 @@ async def count_tracker_impl(
 					release_at = msg.timestamp + reorder_buffer_time.total_seconds()
 					bisect.insort(buffer_, _BufferedUpdate(update, release_at))
 			elif msg.data["type"] == "strike":
-				UUID(re.match("LiveUpdate_(.+)$", msg.data["payload"])[1])  # sanity check payload
 				slot = next(
 					(
 						slot for slot in itertools.chain(buffer_, reversed(timeline))
@@ -226,11 +244,12 @@ async def count_tracker_impl(
 					None
 				)
 				if slot:
-					slot.update.stricken = True
-					if isinstance(slot, _TimelineUpdate) and slot.accepted:
-						logger.info(f"bad strike of {slot.update.name}")
-						body = _format_bad_strike_alert(slot.update, thread_id)
-						api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body))
+					if not slot.update.stricken:
+						slot.update.stricken = True
+						if not QUIET and isinstance(slot, _TimelineUpdate) and slot.accepted:
+							logger.info(f"bad strike of {slot.update.name}")
+							body = _format_bad_strike_alert(slot.update, thread_id)
+							api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body))
 				else:
 					pending_strikes.add(msg.data["payload"])
 
@@ -245,27 +264,34 @@ async def count_tracker_impl(
 					# valid count
 					or (last_valid is None or last_valid.number is None or bu.update.can_follow(last_valid))
-
-					# invalid and not likely to become valid by transpo
-					or bu.update.number > last_valid.number + 3
+					or bu.update.command is Command.RESET
 					or trio.current_time() >= bu.release_at
 				)
 				if process:
 					if bu.update.ts < threshold:
 						logger.warning(f"ignoring {bu.update.name}: arrived past retention window")
 					else:
+						logger.debug(f"processing update {bu.update.id} ({bu.update.number} by {bu.update.author})")
 						handle_update(bu.update)
 				else:
+					logger.debug(f"holding {bu.update.id} ({bu.update.number}, checked against {last_valid.number})")
 					new_buffer.append(bu)
 			buffer_ = new_buffer
+			logger.debug("last count {}".format(last_valid and (last_valid.id, last_valid.number)))  # TODO remove
 
 			# delete/forget old updates
+			new_timeline = []
+			i = 0
 			for (i, tu) in enumerate(timeline):
 				if i >= len(timeline) - 10 or tu.update.ts >= threshold:
 					break
-				elif tu.bad_count and tu.update.deletable:
+			elif tu.rejected() and tu.update.deletable:
 				if enforcing:
 					api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts))
-		del timeline[: i]
+			elif tu.update is last_valid:
+				new_timeline.append(tu)
+		new_timeline.extend(islice(timeline, i, None))
+		timeline = new_timeline
diff --git a/src/strikebot/__main__.py b/src/strikebot/__main__.py
index a1c6705..217f5c9 100644
--- a/src/strikebot/__main__.py
+++ b/src/strikebot/__main__.py
@@ -25,14 +25,18 @@ with open(args.config_path) as config_file:
 main_cfg = parser["config"]
 
+api_pool_error_delay = dt.timedelta(seconds = main_cfg.getfloat("API pool error delay"))
+api_pool_error_window = dt.timedelta(seconds = main_cfg.getfloat("API pool error window"))
 auth_ids = set(map(int, main_cfg["auth IDs"].split()))
 bot_user = main_cfg["bot user"]
 enforcing = main_cfg.getboolean("enforcing")
 reorder_buffer_time = dt.timedelta(seconds = main_cfg.getfloat("reorder buffer time"))
+request_queue_limit = main_cfg.getint("request queue limit")
 thread_id = main_cfg["thread ID"]
 update_retention = dt.timedelta(seconds = main_cfg.getfloat("update retention time"))
 ws_parity_time = dt.timedelta(seconds = main_cfg.getfloat("WS parity time"))
 ws_pool_size = main_cfg.getint("WS pool size")
+ws_silent_limit = dt.timedelta(seconds = main_cfg.getfloat("WS silent limit"))
 ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup time"))
 
 db_connect_params = dict(parser["db connect params"])
@@ -54,7 +58,9 @@ async def main():
 
 		nursery_a.start_soon(db_messenger.db_client_impl)
 
-		api_pool = ApiClientPool(auth_ids, db_messenger, logger)
+		api_pool = ApiClientPool(
+			auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay, logger
+		)
 		nursery_a.start_soon(api_pool.token_updater_impl)
 		for _ in auth_ids:
 			nursery_a.start_soon(api_pool.worker_impl)
@@ -62,16 +68,17 @@ async def main():
 		(message_tx, message_rx) = trio.open_memory_channel(0)
 		(pool_event_tx, pool_event_rx) = trio.open_memory_channel(0)
 		merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger)
-		async with trio.open_nursery() as nursery_b:
+		async with trio.open_nursery() as nursery_b, trio.open_nursery() as ws_pool_nursery:
 			nursery_b.start_soon(merger.event_reader_impl)
 			nursery_b.start_soon(merger.timeout_handler_impl)
 
-			pool = HealingReadPool(nursery_b, ws_pool_size, thread_id, api_pool, pool_event_tx, logger)
-			async with pool, trio.open_nursery() as nursery_c:
-				nursery_c.start_soon(pool.conn_refresher_impl)
+			ws_pool = HealingReadPool(ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger, ws_silent_limit)
+			async with ws_pool, trio.open_nursery() as nursery_c:
+				nursery_c.start_soon(ws_pool.conn_refresher_impl)
 				nursery_c.start_soon(
 					count_tracker_impl, message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing,
 					update_retention, logger
 				)
+				await ws_pool.init_workers()
 
 trio_asyncio.run(main)
diff --git a/src/strikebot/db.py b/src/strikebot/db.py
index 76a19b6..a88ffe8 100644
--- a/src/strikebot/db.py
+++ b/src/strikebot/db.py
@@ -20,8 +20,20 @@ class Client:
 	conn: Any
 
 	@_channel_sender
-	async def get_auth_tokens(self, auth_ids: set[int]):
-		raise NotImplementedError()
+	async def get_auth_tokens(self, auth_ids: set[int]) -> dict[int, str]:
+		results = await self.conn.fetch(
+			"""
+				select
+					distinct on (id)
+					id, access_token
+				from
+					public.reddit_app_authorization
+					join unnest($1::integer[]) as request_id on id = request_id
+				where expires > current_timestamp
+			""",
+ list(auth_ids), + ) + return {r["id"]: r["access_token"] for r in results} class Messenger: diff --git a/src/strikebot/live_ws.py b/src/strikebot/live_ws.py index 40f02dc..b257905 100644 --- a/src/strikebot/live_ws.py +++ b/src/strikebot/live_ws.py @@ -7,6 +7,7 @@ import json import logging import math +from trio.lowlevel import checkpoint_if_cancelled from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection import trio @@ -44,44 +45,63 @@ class _ConnectionUp(_PoolEvent): class HealingReadPool: - def __init__(self, nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger): + def __init__( + self, pool_nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger, + silent_limit: dt.timedelta + ): assert size >= 2 - self._nursery = nursery + self._nursery = pool_nursery self._size = size self._live_thread_id = live_thread_id self._api_client_pool = api_client_pool self._pool_event_tx = pool_event_tx self._logger = logger + self._silent_limit = silent_limit (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size) - self._active_count = 0 async def __aenter__(self): - for _ in range(self._size): - await self._spawn_reader() - if not self._active_count: - raise RuntimeError("Unable to create any WS connections") + pass async def __aexit__(self, exc_type, exc_value, traceback): self._refresh_queue_tx.close() + async def init_workers(self): + for _ in range(self._size): + await self._spawn_reader() + if not self._nursery.child_tasks: + raise RuntimeError("Unable to create any WS connections") + async def _reader_impl(self, conn_ctx: AsyncContextManager, cancel_scope, refresh_tx): + # TODO could use task's implicit cancel scope try: async with conn_ctx as conn: - self._active_count += 1 with refresh_tx: + _tsc = False # TODO remove with cancel_scope, suppress(ConnectionClosed): while True: - message = await conn.get_message() - event = _Message(trio.current_time(), json.loads(message), cancel_scope) - await self._pool_event_tx.send(event) - - if cancel_scope.cancelled_caught: - await conn.aclose(4058, "Server unexpectedly stopped sending messages") - self._logger.warning("replacing WS connection due to missed updates") + with trio.move_on_after(self._silent_limit.total_seconds()) as timeout_scope: + message = await conn.get_message() + + if timeout_scope.cancelled_caught: + if len(self._nursery.child_tasks) == self._size: + _tsc = True + cancel_scope.cancel() + await checkpoint_if_cancelled() + else: + event = _Message(trio.current_time(), json.loads(message), cancel_scope) + await self._pool_event_tx.send(event) + + assert _tsc == timeout_scope.cancelled_caught + if timeout_scope.cancelled_caught: + self._logger.debug("replacing WS connection due to silent timeout") + await conn.aclose() + elif cancel_scope.cancelled_caught: + await conn.aclose(1008, "Server unexpectedly stopped sending messages") + self._logger.warning("replacing WS connection due to missed update") else: self._logger.warning("replacing WS connection closed by server") + self._logger.debug(f"post-kill tasks {len(self._nursery.child_tasks)}") refresh_tx.send_nowait(cancel_scope) - self._active_count -= 1 except HandshakeError: self._logger.error("handshake error while opening WS connection") @@ -102,6 +122,7 @@ class HealingReadPool: await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope)) refresh_tx = self._refresh_queue_tx.clone() 
self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx)
+		self._logger.debug(f"post-spawn tasks {len(self._nursery.child_tasks)}")
 
 	async def conn_refresher_impl(self):
 		"""Task to monitor and replace WS connections as they disconnect or go silent."""
@@ -109,7 +130,7 @@ class HealingReadPool:
 		async for old_scope in self._refresh_queue_rx:
 			await self._pool_event_tx.send(_ConnectionDown(trio.current_time(), old_scope))
 			await self._spawn_reader()
-			if not self._active_count:
+			if not self._nursery.child_tasks:
 				raise RuntimeError("WS pool depleted")
 
 
@@ -133,6 +154,7 @@ class PoolMerger:
 		self._buckets = {}
 		self._scope_activations = {}
 		self._pending = deque()
+		self._outgoing_scopes = set()
 		(self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf)
 
 	async def event_reader_impl(self):
@@ -145,27 +167,29 @@ class PoolMerger:
 				# by timestamp to avoid that.
 				self._pending.append(event)
 			elif isinstance(event, _Message):
-				if event.data["type"] in {"update", "strike", "delete"}:
+				if event.data["type"] in ["update", "strike", "delete"]:
 					tag = event.dedup_tag()
 					if tag in self._buckets:
 						b = self._buckets[tag]
 						b.recipients.add(event.scope)
 						if b.recipients >= self._scope_activations.keys():
 							del self._buckets[tag]
-					else:
+					elif event.scope not in self._outgoing_scopes:
 						sane = (
 							event.scope in self._scope_activations
 							or any(e.scope == event.scope for e in self._pending)
 						)
 						if sane:
 							self._buckets[tag] = self._Bucket(event.timestamp, {event.scope})
-							self._message_tx.send_nowait(event)
+							await self._message_tx.send(event)
 						else:
 							raise RuntimeError("received message from unrecognized WS connection")
 			elif isinstance(event, _ConnectionDown):
 				# We don't need to worry about canceling this scope at all, so no need to require it for parity for
 				# any message, even older ones. The scope may be gone already, if we canceled it previously.
self._scope_activations.pop(event.scope, None) + self._pending = deque(ev for ev in self._pending if ev.scope is not event.scope) + self._outgoing_scopes.discard(event.scope) else: raise TypeError(f"Expected pool event, found {event!r}") self._timer_poke_tx.send_nowait(None) # may have new work for the timer @@ -187,12 +211,10 @@ class PoolMerger: scope for (scope, active) in self._scope_activations.items() if active + self._conn_warmup.total_seconds() < now } - if bucket.recipients >= target_scopes: - del self._buckets[tag] - elif now > bucket.start + self._parity_timeout.total_seconds(): + if now > bucket.start + self._parity_timeout.total_seconds(): for scope in target_scopes - bucket.recipients: + self._outgoing_scopes.add(scope) scope.cancel() - del self._scope_activations[scope] del self._buckets[tag] else: await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds()) diff --git a/src/strikebot/queue.py b/src/strikebot/queue.py index fed0757..b4ebbe3 100644 --- a/src/strikebot/queue.py +++ b/src/strikebot/queue.py @@ -1,13 +1,33 @@ -"""Unbounded blocking priority queue for Trio""" +"""Unbounded blocking queues for Trio""" +from collections import deque from dataclasses import dataclass from functools import total_ordering -from typing import Any +from typing import Any, Iterable import heapq from trio.lowlevel import ParkingLot +class Queue: + def __init__(self): + self._deque = deque() + self._empty_wait = ParkingLot() + + def push(self, el: Any) -> None: + self._deque.append(el) + self._empty_wait.unpark() + + def extend(self, els: Iterable[Any]) -> None: + for el in els: + self.push(el) + + async def pop(self) -> Any: + if not self._deque: + await self._empty_wait.park() + return self._deque.popleft() + + @total_ordering @dataclass class _ReverseOrdWrapper: @@ -24,8 +44,7 @@ class MaxHeap: def push(self, item): heapq.heappush(self._heap, _ReverseOrdWrapper(item)) - if len(self._empty_wait): - self._empty_wait.unpark() + self._empty_wait.unpark() async def pop(self): if not self._heap: diff --git a/src/strikebot/reddit_api.py b/src/strikebot/reddit_api.py index 600b38e..ab59a95 100644 --- a/src/strikebot/reddit_api.py +++ b/src/strikebot/reddit_api.py @@ -1,9 +1,9 @@ """Multi-auth HTTP connection pool with Reddit API wrappers and rate limiting.""" from abc import ABCMeta, abstractmethod -from collections import deque from dataclasses import dataclass, field from functools import total_ordering +from socket import EAI_AGAIN, EAI_FAIL, gaierror import datetime as dt import logging @@ -12,7 +12,7 @@ import asks import trio from strikebot import __version__ as VERSION -from strikebot.queue import MaxHeap +from strikebot.queue import MaxHeap, Queue REQUEST_DELAY = dt.timedelta(seconds = 1) @@ -30,14 +30,14 @@ class _Request(metaclass = ABCMeta): def __lt__(self, other): if type(self) is type(other): - return self._subtype_key() < other._subtype_key() + return self._subtype_cmp_key() < other._subtype_cmp_key() else: prec = self._SUBTYPE_PRECEDENCE return prec.index(type(other)) < prec.index(type(self)) def __eq__(self, other): if type(self) is type(other): - return self._subtype_key() == other._subtype_key() + return self._subtype_cmp_key() == other._subtype_cmp_key() else: return False @@ -155,74 +155,129 @@ class AppCooldown: class ApiClientPool: - def __init__(self, auth_ids: set[int], db_messenger, logger: logging.Logger): + def __init__( + self, + auth_ids: set[int], + db_messenger: "strikebot.db.Messenger", + request_queue_limit: int, + error_window: dt.timedelta, + 
error_delay: dt.timedelta, + logger: logging.Logger, + ): + self._auth_ids = auth_ids self._db_messenger = db_messenger + self._request_queue_limit = request_queue_limit + self._error_window = error_window + self._error_delay = error_delay self._logger = logger - self._tokens = {id_: None for id_ in auth_ids} - self._tokens_installed = trio.Event() now = trio.current_time() - self._app_queue = deque(AppCooldown(id_, now) for id_ in auth_ids) + self._tokens = {} + self._app_queue = Queue() + self._waiting = {id_: AppCooldown(id_, now) for id_ in auth_ids} self._request_queue = MaxHeap() + # pool-wide API error backoff + self._last_error = None + self._global_resume = None + self._session = asks.Session(connections = len(auth_ids)) self._session.base_location = API_BASE_URL async def _update_tokens(self): - tokens = await self._db_messenger.do("get_auth_tokens", (self._tokens.keys(),)) - assert tokens.keys() == self._tokens.keys() - self._tokens = tokens - self._tokens_installed.set() + tokens = await self._db_messenger.do("get_auth_tokens", (self._auth_ids,)) + self._tokens.update(tokens) + + awaken_auths = self._waiting.keys() & tokens.keys() + self._app_queue.extend(self._waiting.pop(auth_id) for auth_id in awaken_auths) async def token_updater_impl(self): last_update = None while True: if last_update is not None: await trio.sleep_until(last_update + TOKEN_UPDATE_DELAY.total_seconds()) - last_update = trio.current_time() + + if last_update is None: + last_update = trio.current_time() + else: + last_update += TOKEN_UPDATE_DELAY.total_seconds() + await self._update_tokens() def _check_queue_size(self) -> None: - if len(self._request_queue) > 5: - self._logger.warning(f"API workers may be saturated; {len(self._request_queue)} requests in queue") + if len(self._request_queue) > self._request_queue_limit: + raise RuntimeError("request queue size exceeded limit") async def make_request(self, request: _Request) -> Response: (resp_tx, resp_rx) = trio.open_memory_channel(0) self._request_queue.push((request, resp_tx)) + self._check_queue_size() async with resp_rx: return await resp_rx.receive() def enqueue_request(self, request: _Request) -> None: self._request_queue.push((request, None)) + self._check_queue_size() async def worker_impl(self): - await self._tokens_installed.wait() while True: (request, resp_tx) = await self._request_queue.pop() - while trio.current_time() < self._app_queue[0].ready_at: - await trio.sleep_until(self._app_queue[0].ready_at) + cooldown = await self._app_queue.pop() + await trio.sleep_until(cooldown.ready_at) + if self._global_resume: + await trio.sleep_until(self._global_resume) - cooldown = self._app_queue.popleft() asks_kwargs = request.to_asks_kwargs() - asks_kwargs.setdefault("headers", {}).update({ + headers = asks_kwargs.setdefault("headers", {}) + headers.update({ "Authorization": "Bearer {}".format(self._tokens[cooldown.auth_id]), "User-Agent": USER_AGENT_FMT.format(auth_id = cooldown.auth_id), }) - cooldown.ready_at = trio.current_time() + REQUEST_DELAY.total_seconds() - resp = await self._session.request(**asks_kwargs) - - if resp.status_code == 429: - # We disagreed about the rate limit state; just try again later. - self._request_queue.put((request, resp_tx)) - elif 400 <= resp.status_code < 500: - # If we're doing something wrong, let's catch it right away. 
-				if resp_tx:
-					resp_tx.close()
-				raise RuntimeError("Unexpected client error response: {}".format(resp.status_code))
-			else:
-				if resp.status_code != 200:
-					self._logger.warning(f"got HTTP {resp.status_code} from Reddit API")
-				if resp_tx:
-					await resp_tx.send(resp)
-			self._app_queue.append(cooldown)
+			request_time = trio.current_time()
+			error = False
+			wait_for_token = False
+			try:
+				resp = await self._session.request(**asks_kwargs)
+			except gaierror as e:
+				if e.errno in [EAI_FAIL, EAI_AGAIN]:
+					# DNS failure, probably temporary
+					error = True
+				else:
+					raise
+			else:
+				resp.body  # read response
+				if resp.status_code == 429:
+					# We disagreed about the rate limit state; just try again later.
+					self._logger.warning("rate limited by Reddit API")
+					error = True
+				elif resp.status_code == 401:
+					self._logger.warning("got HTTP 401 from Reddit API")
+					error = True
+					wait_for_token = True
+				elif resp.status_code in [404, 500]:
+					self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying")
+					error = True
+				elif 400 <= resp.status_code < 500:
+					# If we're doing something wrong, let's catch it right away.
+					raise RuntimeError(f"unexpected client error response: {resp.status_code}")
+				else:
+					if resp.status_code != 200:
+						raise RuntimeError(f"unexpected status code {resp.status_code}")
+					if resp_tx:
+						await resp_tx.send(resp)
+
+			if error:
+				self._request_queue.push((request, resp_tx))
+				self._check_queue_size()
+				if self._last_error:
+					spread = dt.timedelta(seconds = request_time - self._last_error)
+					if spread <= self._error_window:
+						self._global_resume = request_time + self._error_delay.total_seconds()
+				self._last_error = request_time
+
+			cooldown.ready_at = request_time + REQUEST_DELAY.total_seconds()
+			if wait_for_token:
+				self._waiting[cooldown.auth_id] = cooldown
+			else:
+				self._app_queue.push(cooldown)
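The rewritten worker loop above gives the pool two distinct submission paths. A minimal usage sketch follows (illustrative only, not part of the patch; it assumes an ApiClientPool wired up and running as in strikebot.__main__, and uses the request types defined in strikebot.reddit_api):

	# Hypothetical illustration; `api_pool` and `thread_id` come from the caller.
	async def poke_api(api_pool, thread_id):
		# Fire-and-forget: queued by priority, rate limited, retried on API errors.
		api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "correction body"))

		# Awaited: parks until a worker completes the request and delivers the response.
		resp = await api_pool.make_request(ReportUpdateRequest(thread_id, body = "count status"))
		# Workers only deliver 200 responses; other statuses are retried or raise.
		return resp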
diff --git a/src/strikebot/tests.py b/src/strikebot/tests.py
index 9b9ddae..55b3426 100644
--- a/src/strikebot/tests.py
+++ b/src/strikebot/tests.py
@@ -9,34 +9,34 @@ def _build_payload(body_html: str) -> dict[str, str]:
 class UpdateParsingTests(TestCase):
 	def test_successful_counts(self):
-		pu = parse_update(_build_payload("<div>12,345,678 spaghetti</div>"), None)
+		pu = parse_update(_build_payload("<div>12,345,678 spaghetti</div>"), None, "")
 		self.assertEqual(pu.number, 12345678)
 		self.assertFalse(pu.deletable)
 
-		pu = parse_update(_build_payload("<div><p>0</p><p>oz</p></div>"), None)
+		pu = parse_update(_build_payload("<div><p>0</p><p>oz</p></div>"), None, "")
 		self.assertEqual(pu.number, 0)
 		self.assertFalse(pu.deletable)
 
 	def test_non_counts(self):
-		pu = parse_update(_build_payload("<div>zoo</div>"), None)
+		pu = parse_update(_build_payload("<div>zoo</div>"), None, "")
 		self.assertFalse(pu.count_attempt)
 		self.assertFalse(pu.deletable)
 
 	def test_typos(self):
-		pu = parse_update(_build_payload("v9"), 888)
+		pu = parse_update(_build_payload("v9"), 888, "")
 		self.assertIsNone(pu.number)
 		self.assertTrue(pu.count_attempt)
 
-		pu = parse_update(_build_payload("<div>v11.585 Empire</div>"), None)
+		pu = parse_update(_build_payload("<div>v11.585 Empire</div>"), None, "")
 		self.assertIsNone(pu.number)
 		self.assertTrue(pu.count_attempt)
 		self.assertFalse(pu.deletable)
 
-		pu = parse_update(_build_payload("<div>11, 585, 22</div>"), 11_585_202)
+		pu = parse_update(_build_payload("<div>11, 585, 22</div>"), 11_585_202, "")
 		self.assertIsNone(pu.number)
 		self.assertTrue(pu.count_attempt)
 		self.assertTrue(pu.deletable)
 
-		pu = parse_update(_build_payload("0490499"), 4999)
+		pu = parse_update(_build_payload("0490499"), 4999, "")
 		self.assertIsNone(pu.number)
 		self.assertTrue(pu.count_attempt)
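The updates.py changes below implement the behavior these tests pin down. As a standalone illustration of the parser's contract (a sketch, not part of the patch; it assumes the patched tree is importable and that parse_update reads only the "body_html" key, as _build_payload above suggests):

	from strikebot.updates import parse_update

	# "11, 585, 22" is a partial paste of the expected "11, 585, 202" (the second
	# argument is the next number up), so it parses as a deletable count attempt
	# with no usable number.
	pu = parse_update({"body_html": "<div>11, 585, 22</div>"}, 11_585_202, "count_better")
	assert pu.number is None and pu.count_attempt and pu.deletable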
"), 11_585_202, "") self.assertIsNone(pu.number) self.assertTrue(pu.count_attempt) self.assertTrue(pu.deletable) - pu = parse_update(_build_payload("0490499"), 4999) + pu = parse_update(_build_payload("0490499"), 4999, "") self.assertIsNone(pu.number) self.assertTrue(pu.count_attempt) diff --git a/src/strikebot/updates.py b/src/strikebot/updates.py index cd814bf..f61d4e2 100644 --- a/src/strikebot/updates.py +++ b/src/strikebot/updates.py @@ -14,7 +14,7 @@ Command = Enum("Command", ["RESET", "REPORT"]) class ParsedUpdate: number: Optional[int] command: Optional[Command] - count_attempt: bool + count_attempt: bool # either well-formed or typo deletable: bool @@ -28,6 +28,8 @@ def _parse_command(line: str, bot_user: str) -> Optional[Command]: def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -> ParsedUpdate: + # curr_count is the next number up, one more than the last count + NEW_LINE = object() SPACE = object() @@ -52,7 +54,6 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) - if el.text: worklist.appendleft(el.text) elif el.tag == "li": - assert not el.tail worklist.appendleft(NEW_LINE) worklist.appendleft(el.text) elif el.tag in ["p", "div", "h1", "h2", "blockquote"]: @@ -66,28 +67,26 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) - elif el.tag in ["ul", "ol"]: if el.tail: worklist.appendleft(el.tail) - assert not el.text for sub in reversed(el): worklist.appendleft(sub) worklist.appendleft(NEW_LINE) elif el.tag == "pre": - out.extend([l] for l in el.text.splitlines()) + if el.text: + out.extend([l] for l in el.text.splitlines()) + worklist.appendleft(NEW_LINE) if el.tail: worklist.appendleft(el.tail) - worklist.appendleft(NEW_LINE) + for sub in reversed(el): + worklist.appendleft(sub) elif el.tag == "table": if el.tail: worklist.appendleft(el.tail) worklist.appendleft(NEW_LINE) - assert not el.text for sub in reversed(el): assert sub.tag in ["thead", "tbody"] - assert not sub.text for row in reversed(sub): assert row.tag == "tr" - assert not row.text or row.tail for (i, cell) in enumerate(reversed(row)): - assert not cell.tail worklist.appendleft(cell) if i != len(row) - 1: worklist.appendleft(SPACE) @@ -115,13 +114,14 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) - if lines: first = lines[0] match = re.match( - "(?Pv)?(?P-)?(?P\\d+((?P[,. \u2009]|, )\\d+((?P=sep)\\d+)*)?)( |$)", + "(?Pv)?(?P-)?(?P\\d+((?P[,. 
 			first,
 			re.ASCII,  # only recognize ASCII digits
 		)
 		if match:
 			ct_str = match["num"]
 			sep = match["sep"]
+			post = first[match.end() :]
 
 			zeros = False
 			while len(ct_str) > 1 and ct_str[0] == "0":
@@ -130,11 +130,14 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -
 			parts = ct_str.split(sep) if sep else [ct_str]
 			parts_valid = (
-				len(parts[0]) in range(1, 3)
-				and all(len(p) == 3 for p in parts[1:])
+				sep is None
+				or (
+					len(parts[0]) in range(1, 4)
+					and all(len(p) == 3 for p in parts[1:])
+				)
 			)
 			digits = "".join(parts)
-			lone = not first[match.end() :].strip() and len(lines) == 1
+			lone = len(lines) == 1 and (not post or post.isspace())
 			typo = False
 			if lone:
 				if match["v"] and len(ct_str) <= 2:
@@ -143,7 +146,7 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -
 				elif match["v"] and parts_valid:
 					# v followed by count
 					typo = True
-				elif curr_count and curr_count >= 100:
+				elif curr_count and curr_count >= 100 and bool(match["neg"]) == (curr_count < 0):
 					goal = (sep or "").join(_separate(str(curr_count)))
 					partials = [goal[: -2], goal[: -1], goal[: -2] + goal[-1]]
 					if ct_str in partials:
@@ -152,15 +155,24 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -
 				elif ct_str in [p + goal for p in partials]:
 					# double paste
 					typo = True
+
 			if match["v"] or zeros or typo or (digits == "0" and match["neg"]):
 				number = None
 				count_attempt = True
 				deletable = lone
 			elif parts_valid:
 				number = -int(digits) if match["neg"] else int(digits)
-				count_attempt = True
-				special = curr_count is not None and abs(number - curr_count) <= 25 and _is_special_number(number)
+				special = (
+					curr_count is not None
+					and abs(number - curr_count) <= 25
+					and _is_special_number(number)
+				)
 				deletable = lone and not special
+				if post and not post[0].isspace():
+					count_attempt = curr_count is not None and abs(number - curr_count) <= 25
+					number = None
+				else:
+					count_attempt = True
 			else:
 				number = None
 				count_attempt = False
-- 
2.30.2
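A quick demonstration of the Queue added in src/strikebot/queue.py (a sketch, not part of the patch; it assumes trio and the patched strikebot package are installed). push never blocks, and pop parks until an element is available:

	import trio

	from strikebot.queue import Queue

	async def main():
		q = Queue()
		async with trio.open_nursery() as nursery:
			async def consumer():
				print(await q.pop())  # parks until the push below
			nursery.start_soon(consumer)
			await trio.sleep(0.1)  # let the consumer park first
			q.push("hello")

	trio.run(main)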