MVP features and fixes
author     Jakob Cornell <jakob+gpg@jcornell.net>
           Fri, 3 Jun 2022 19:42:57 +0000 (14:42 -0500)
committer  Jakob Cornell <jakob+gpg@jcornell.net>
           Fri, 3 Jun 2022 19:42:57 +0000 (14:42 -0500)
docs/sample_config.ini
src/strikebot/__init__.py
src/strikebot/__main__.py
src/strikebot/db.py
src/strikebot/live_ws.py
src/strikebot/queue.py
src/strikebot/reddit_api.py
src/strikebot/tests.py
src/strikebot/updates.py

diff --git a/docs/sample_config.ini b/docs/sample_config.ini
index 66fc49e4a64c4c980071470eb66738a23ef35d43..28e510e54497f38b680613bea9c9f53bbc979c1b 100644
@@ -1,27 +1,52 @@
+# see Python `configparser' docs for precise file syntax
+
 [config]
+
+######## basic operating parameters
+
+# space-separated authorization IDs for Reddit API token lookup
+auth IDs = 13 15
+
 bot user = count_better
+enforcing = true
+thread ID = abc123
 
-# (seconds) maximum allowable spread in WebSocket message arrival times
-WS parity time = 1.0
 
-# goal size of WebSocket connection pool
-WS pool size = 5
+######## API pool error handling
+
+# (seconds) how long to pause the API pool when backing off after errors
+API pool error delay = 5.5
+
+# (seconds) pause the pool when the Reddit API returns two errors within any time span of this length
+API pool error window = 5.5
+
+# terminate if the Reddit API request queue exceeds this size
+request queue limit = 100
 
-# (seconds) time after WebSocket handshake completion during which missed updates are excused
-WS warmup time = 0.3
+
+######## live thread update handling
 
 # (seconds) maximum time to hold updates for reordering
 reorder buffer time = 0.25
 
-enforcing = true
-
 # (seconds) minimum time to retain updates to enable future resyncs
 update retention time = 120.0
 
-thread ID = abc123
 
-# space-separated authorization IDs for Reddit API token lookup
-auth IDs = 13 15
+######## WebSocket pool
+
+# (seconds) maximum allowable spread in WebSocket message arrival times
+WS parity time = 1.0
+
+# goal size of WebSocket connection pool
+WS pool size = 5
+
+# (seconds) after receiving no messages for this much time, attempt to replace the connection if the pool isn't depleted
+WS silent limit = 30
+
+# (seconds) time after WebSocket handshake completion during which missed updates are excused
+WS warmup time = 0.3
+
 
 # Postgres database configuration; same options as in a connect string
 [db connect params]
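
For reference, these options are consumed in src/strikebot/__main__.py via the
standard library configparser, as in this minimal sketch (the .ini path here is
hypothetical):

    import configparser
    import datetime as dt

    parser = configparser.ConfigParser()
    parser.read("strikebot.ini")  # hypothetical path
    main_cfg = parser["config"]

    auth_ids = set(map(int, main_cfg["auth IDs"].split()))  # {13, 15}
    enforcing = main_cfg.getboolean("enforcing")
    ws_pool_size = main_cfg.getint("WS pool size")
    # durations given in seconds become timedeltas
    ws_parity_time = dt.timedelta(seconds = main_cfg.getfloat("WS parity time"))
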
diff --git a/src/strikebot/__init__.py b/src/strikebot/__init__.py
index 63fff59e32ad736b7689be238e7bbd632bd2b0cd..10b4013b2096d5d10e178e956db10f7d54674530 100644
@@ -1,14 +1,14 @@
 from contextlib import nullcontext as nullcontext
 from dataclasses import dataclass
 from functools import total_ordering
-from typing import Optional, Set
+from itertools import islice
+from typing import Optional
 from uuid import UUID
 import bisect
 import datetime as dt
 import importlib.metadata
 import itertools
 import logging
-import re
 
 from trio import CancelScope, EndOfChannel
 import trio
@@ -17,6 +17,7 @@ from strikebot.updates import Command, parse_update
 
 
 __version__ = importlib.metadata.version(__package__)
+QUIET = True  # TODO remove
 
 
 @dataclass
@@ -51,11 +52,14 @@ class _BufferedUpdate:
 class _TimelineUpdate:
        update: _Update
        accepted: bool
-       bad_count: bool
 
        def __lt__(self, other):
                return self.update.ts < other.update.ts
 
+       def rejected(self) -> bool:
+               """Whether the update should be stricken."""
+               return self.update.count_attempt and not self.accepted
+
 
 def _parse_rfc_4122_ts(ts: int) -> dt.datetime:
        epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc)
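
For context, version-1 UUID timestamps count 100 ns intervals from the start of
the Gregorian calendar (1582-10-15), so the conversion amounts to something
like the following (the sample UUID is hypothetical):

    import datetime as dt
    from uuid import UUID

    epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc)
    ts = UUID("3f9a7e1e-e3a2-11ec-8fea-0242ac120002").time  # 100 ns units
    when = epoch + dt.timedelta(microseconds = ts / 10)
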
@@ -95,16 +99,21 @@ async def count_tracker_impl(
        buffer_ = []
        timeline = []
        last_valid: Optional[_Update] = None
-       pending_strikes: Set[str] = set()  # names of updates to mark stricken on arrival
+       pending_strikes: set[str] = set()  # names of updates to mark stricken on arrival
 
        def handle_update(update):
-               nonlocal delete_start, last_valid
+               nonlocal last_valid
 
-               tu = _TimelineUpdate(update, accepted = False, bad_count = False)
+               tu = _TimelineUpdate(update, accepted = False)
 
                pos = bisect.bisect(timeline, tu)
                if pos != len(timeline):
                        logger.warning(f"long transpo: {update.name}")
+               if pos > 0 and timeline[pos - 1].update.id == update.id:
+                       # The pool sent this update message multiple times.  This could be because the message reached parity and
+                       # its bucket was dropped just before a new connection was opened and sent its own copy, but connections also
+                       # seem to send duplicate messages.
+                       return
 
                pred = next(
                        (timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted),
@@ -113,6 +122,8 @@ async def count_tracker_impl(
                if update.command is not Command.RESET and update.number is not None and last_valid and not pred:
                        logger.warning("ignoring {update.name}: no valid prior count on record")
                else:
+                       timeline.insert(pos, tu)
+                       assert len({ctu.update.id for ctu in timeline}) == len(timeline)  # TODO remove
                        tu.accepted = (
                                update.command is Command.RESET
                                or (
@@ -120,24 +131,32 @@ async def count_tracker_impl(
                                        and (pred is None or pred.number is None or update.can_follow(pred))
                                )
                        )
-                       timeline.insert(pos, tu)
+                       logger.debug(f"accepted: {tu.accepted}")
+                       logger.debug("  pred: {}".format(pred and (pred.id, pred.number, pred.author)))
                        if tu.accepted:
-                               # resync subsequent updates already processed
+                               # resync subsequent updates
                                newly_valid = []
                                newly_invalid = []
-                               last_valid = update
-                               for scan_tu in timeline[pos + 1:]:
+                               resync_last_valid = update
+                               for scan_tu in islice(timeline, pos + 1, None):
                                        if scan_tu.update.command is Command.RESET:
-                                               last_valid = scan_tu.update
+                                               break
                                        elif scan_tu.update.number is not None:
-                                               accept = last_valid.number is None or scan_tu.update.can_follow(last_valid)
+                                               accept = resync_last_valid.number is None or scan_tu.update.can_follow(resync_last_valid)
-                                               if accept and not scan_tu.accepted:
+                                               if accept and scan_tu.accepted:
+                                                       # resync would have no effect past this point
+                                                       break
+                                               elif accept:
                                                        newly_valid.append(scan_tu)
-                                               elif not accept and scan_tu.accepted:
+                                                       resync_last_valid = scan_tu.update
+                                               elif scan_tu.accepted:
                                                        newly_invalid.append(scan_tu)
                                                scan_tu.accepted = accept
-                                               if accept:
-                                                       last_valid = scan_tu.update
+
+                               if last_valid:
+                                       last_valid = max(last_valid, resync_last_valid, key = lambda u: u.ts)
+                               else:
+                                       last_valid = resync_last_valid
 
                                parts = []
                                if newly_valid:
@@ -145,8 +164,6 @@ async def count_tracker_impl(
                                                "The following counts are valid:\n\n"
                                                + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in newly_valid)
                                        )
-                                       for tu in newly_valid:
-                                               tu.bad_count = False
 
                                unstrikable = [tu for tu in newly_invalid if tu.update.stricken]
                                if unstrikable:
@@ -154,30 +171,29 @@ async def count_tracker_impl(
                                                "The following counts are invalid:\n\n"
                                                + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in unstrikable)
                                        )
-                                       for tu in unstrikable:
-                                               tu.bad_count = True
 
                                if update.stricken:
                                        logger.info(f"bad strike of {update.name}")
                                        parts.append(_format_bad_strike_alert(update, thread_id))
 
+                               if parts:
+                                       parts.append(_format_curr_count(last_valid))
+                                       if not QUIET:
+                                               api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts)))
+
                                for invalid_tu in newly_invalid:
                                        if not invalid_tu.update.stricken:
                                                if enforcing:
                                                        api_pool.enqueue_request(StrikeRequest(thread_id, invalid_tu.update.name, invalid_tu.update.ts))
-                                               invalid_tu.bad_count = True
                                                invalid_tu.update.stricken = True
-                               if parts:
-                                       parts.append(_format_curr_count(last_valid))
-                                       api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts)))
-                       elif update.number is not None or update.count_attempt:
+                       elif update.count_attempt:
                                if enforcing:
                                        api_pool.enqueue_request(StrikeRequest(thread_id, update.name, update.ts))
-                               tu.bad_count = True
                                update.stricken = True
 
                if update.command is Command.REPORT:
-                       api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
+                       if not QUIET:
+                               api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
 
        with message_rx:
                while True:
@@ -200,7 +216,10 @@ async def count_tracker_impl(
                                        stricken = payload_data["name"] in pending_strikes
                                        pending_strikes.discard(payload_data["name"])
                                        if payload_data["author"] != bot_user:
-                                               next_up = last_valid.number if last_valid else None
+                                               if last_valid and last_valid.number is not None:
+                                                       next_up = last_valid.number + 1
+                                               else:
+                                                       next_up = None
                                                pu = parse_update(payload_data, next_up, bot_user)
                                                rfc_ts = UUID(payload_data["id"]).time
                                                update = _Update(
@@ -217,7 +236,6 @@ async def count_tracker_impl(
                                                release_at = msg.timestamp + reorder_buffer_time.total_seconds()
                                                bisect.insort(buffer_, _BufferedUpdate(update, release_at))
                                elif msg.data["type"] == "strike":
-                                       UUID(re.match("LiveUpdate_(.+)$", msg.data["payload"])[1])  # sanity check payload
                                        slot = next(
                                                (
                                                        slot for slot in itertools.chain(buffer_, reversed(timeline))
@@ -226,11 +244,12 @@ async def count_tracker_impl(
                                                None
                                        )
                                        if slot:
-                                               slot.update.stricken = True
-                                               if isinstance(slot, _TimelineUpdate) and slot.accepted:
-                                                       logger.info(f"bad strike of {slot.update.name}")
-                                                       body = _format_bad_strike_alert(slot.update, thread_id)
-                                                       api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body))
+                                               if not slot.update.stricken:
+                                                       slot.update.stricken = True
+                                                       if not QUIET and isinstance(slot, _TimelineUpdate) and slot.accepted:
+                                                               logger.info(f"bad strike of {slot.update.name}")
+                                                               body = _format_bad_strike_alert(slot.update, thread_id)
+                                                               api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body))
                                        else:
                                                pending_strikes.add(msg.data["payload"])
 
@@ -245,27 +264,34 @@ async def count_tracker_impl(
 
                                        # valid count
                                        or (last_valid is None or last_valid.number is None or bu.update.can_follow(last_valid))
-
-                                       # invalid and not likely to become valid by transpo
-                                       or bu.update.number > last_valid.number + 3
+                                       or bu.update.command is Command.RESET
 
                                        or trio.current_time() >= bu.release_at
                                )
 
                                if process:
                                        if bu.update.ts < threshold:
                                                logger.warning(f"ignoring {bu.update.name}: arrived past retention window")
                                        else:
+                                               logger.debug(f"processing update {bu.update.id} ({bu.update.number} by {bu.update.author})")
                                                handle_update(bu.update)
                                else:
+                                       logger.debug(f"holding {bu.update.id} ({bu.update.number}, checked against {last_valid.number})")
                                        new_buffer.append(bu)
                        buffer_ = new_buffer
+                       logger.debug("last count {}".format(last_valid and (last_valid.id, last_valid.number)))  # TODO remove
 
                        # delete/forget old updates
+                       new_timeline = []
+                       i = 0
                        for (i, tu) in enumerate(timeline):
                                if i >= len(timeline) - 10 or tu.update.ts >= threshold:
                                        break
-                               elif tu.bad_count and tu.update.deletable:
+                               elif tu.rejected() and tu.update.deletable:
                                        if enforcing:
                                                api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts))
-                       del timeline[: i]
+                               elif tu.update is last_valid:
+                                       new_timeline.append(tu)
+                       new_timeline.extend(islice(timeline, i, None))
+                       timeline = new_timeline
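
The tracker above feeds on a reorder buffer: each update is held briefly so
late WebSocket arrivals can be slotted into timestamp order before entering the
bisected timeline. A simplified, hypothetical sketch of that idea (the real
loop also releases counts early when they can be validated immediately):

    import bisect

    buffer_ = []  # _BufferedUpdate instances, kept sorted

    def on_update(update, now, reorder_buffer_time):
        # hold the update until its release time
        bisect.insort(buffer_, _BufferedUpdate(update, now + reorder_buffer_time))

    def flush(now):
        ready = [bu for bu in buffer_ if now >= bu.release_at]
        buffer_[:] = [bu for bu in buffer_ if now < bu.release_at]
        for bu in ready:
            handle_update(bu.update)
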
diff --git a/src/strikebot/__main__.py b/src/strikebot/__main__.py
index a1c6705af5323d27ca6beb161aa54b5ead8f03f7..217f5c90f80355d51c92bbed5027573d605fa5f1 100644
@@ -25,14 +25,18 @@ with open(args.config_path) as config_file:
 
 main_cfg = parser["config"]
 
+api_pool_error_delay = dt.timedelta(seconds = main_cfg.getfloat("API pool error delay"))
+api_pool_error_window = dt.timedelta(seconds = main_cfg.getfloat("API pool error window"))
 auth_ids = set(map(int, main_cfg["auth IDs"].split()))
 bot_user = main_cfg["bot user"]
 enforcing = main_cfg.getboolean("enforcing")
 reorder_buffer_time = dt.timedelta(seconds = main_cfg.getfloat("reorder buffer time"))
+request_queue_limit = main_cfg.getint("request queue limit")
 thread_id = main_cfg["thread ID"]
 update_retention = dt.timedelta(seconds = main_cfg.getfloat("update retention time"))
 ws_parity_time = dt.timedelta(seconds = main_cfg.getfloat("WS parity time"))
 ws_pool_size = main_cfg.getint("WS pool size")
+ws_silent_limit = dt.timedelta(seconds = main_cfg.getfloat("WS silent limit"))
 ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup time"))
 
 db_connect_params = dict(parser["db connect params"])
@@ -54,7 +58,9 @@ async def main():
 
                nursery_a.start_soon(db_messenger.db_client_impl)
 
-               api_pool = ApiClientPool(auth_ids, db_messenger, logger)
+               api_pool = ApiClientPool(
+                       auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay, logger
+               )
                nursery_a.start_soon(api_pool.token_updater_impl)
                for _ in auth_ids:
                        nursery_a.start_soon(api_pool.worker_impl)
@@ -62,16 +68,17 @@ async def main():
                (message_tx, message_rx) = trio.open_memory_channel(0)
                (pool_event_tx, pool_event_rx) = trio.open_memory_channel(0)
                merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger)
-               async with trio.open_nursery() as nursery_b:
+               async with trio.open_nursery() as nursery_b, trio.open_nursery() as ws_pool_nursery:
                        nursery_b.start_soon(merger.event_reader_impl)
                        nursery_b.start_soon(merger.timeout_handler_impl)
-                       pool = HealingReadPool(nursery_b, ws_pool_size, thread_id, api_pool, pool_event_tx, logger)
-                       async with pool, trio.open_nursery() as nursery_c:
-                               nursery_c.start_soon(pool.conn_refresher_impl)
+                       ws_pool = HealingReadPool(ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger, ws_silent_limit)
+                       async with ws_pool, trio.open_nursery() as nursery_c:
+                               nursery_c.start_soon(ws_pool.conn_refresher_impl)
                                nursery_c.start_soon(
                                        count_tracker_impl,
                                        message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing, update_retention, logger
                                )
+                               await ws_pool.init_workers()
 
 
 trio_asyncio.run(main)
diff --git a/src/strikebot/db.py b/src/strikebot/db.py
index 76a19b6a4482a7cdef9549d3c7721b721e8a19cf..a88ffe89b9cccbbfb62bc24a16dfb6ead068c524 100644
@@ -20,8 +20,20 @@ class Client:
        conn: Any
 
        @_channel_sender
-       async def get_auth_tokens(self, auth_ids: set[int]):
-               raise NotImplementedError()
+       async def get_auth_tokens(self, auth_ids: set[int]) -> dict[int, str]:
+               results = await self.conn.fetch(
+                       """
+                               select
+                                       distinct on (id)
+                                       id, access_token
+                               from
+                                       public.reddit_app_authorization
+                                       join unnest($1::integer[]) as request_id on id = request_id
+                               where expires > current_timestamp
+                       """,
+                       list(auth_ids),
+               )
+               return {r["id"]: r["access_token"] for r in results}
 
 
 class Messenger:
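
Callers reach this query through the Messenger wrapper rather than the Client
directly; reddit_api.py below does exactly this (result values illustrative):

    tokens = await db_messenger.do("get_auth_tokens", (auth_ids,))
    # -> {13: "eyJhbGc...", 15: "eyJhbGc..."} for unexpired authorizations
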
diff --git a/src/strikebot/live_ws.py b/src/strikebot/live_ws.py
index 40f02dcbce1a08ae02923fe8ec638e69a50680d6..b257905abf0d2727a150629af541d580b0590e1a 100644
@@ -7,6 +7,7 @@ import json
 import logging
 import math
 
+from trio.lowlevel import checkpoint_if_cancelled
 from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
 import trio
 
@@ -44,44 +45,63 @@ class _ConnectionUp(_PoolEvent):
 
 
 class HealingReadPool:
-       def __init__(self, nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger):
+       def __init__(
+               self, pool_nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger,
+               silent_limit: dt.timedelta
+       ):
                assert size >= 2
-               self._nursery = nursery
+               self._nursery = pool_nursery
                self._size = size
                self._live_thread_id = live_thread_id
                self._api_client_pool = api_client_pool
                self._pool_event_tx = pool_event_tx
                self._logger = logger
+               self._silent_limit = silent_limit
                (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size)
-               self._active_count = 0
 
        async def __aenter__(self):
-               for _ in range(self._size):
-                       await self._spawn_reader()
-               if not self._active_count:
-                       raise RuntimeError("Unable to create any WS connections")
+               pass
 
        async def __aexit__(self, exc_type, exc_value, traceback):
                self._refresh_queue_tx.close()
 
+       async def init_workers(self):
+               for _ in range(self._size):
+                       await self._spawn_reader()
+               if not self._nursery.child_tasks:
+                       raise RuntimeError("Unable to create any WS connections")
+
        async def _reader_impl(self, conn_ctx: AsyncContextManager, cancel_scope, refresh_tx):
+               # TODO could use task's implicit cancel scope
                try:
                        async with conn_ctx as conn:
-                               self._active_count += 1
                                with refresh_tx:
+                                       _tsc = False  # TODO remove
                                        with cancel_scope, suppress(ConnectionClosed):
                                                while True:
-                                                       message = await conn.get_message()
-                                                       event = _Message(trio.current_time(), json.loads(message), cancel_scope)
-                                                       await self._pool_event_tx.send(event)
-
-                                       if cancel_scope.cancelled_caught:
-                                               await conn.aclose(4058, "Server unexpectedly stopped sending messages")
-                                               self._logger.warning("replacing WS connection due to missed updates")
+                                                       with trio.move_on_after(self._silent_limit.total_seconds()) as timeout_scope:
+                                                               message = await conn.get_message()
+
+                                                       if timeout_scope.cancelled_caught:
+                                                               if len(self._nursery.child_tasks) == self._size:
+                                                                       _tsc = True
+                                                                       cancel_scope.cancel()
+                                                                       await checkpoint_if_cancelled()
+                                                       else:
+                                                               event = _Message(trio.current_time(), json.loads(message), cancel_scope)
+                                                               await self._pool_event_tx.send(event)
+
+                                       assert _tsc == timeout_scope.cancelled_caught
+                                       if timeout_scope.cancelled_caught:
+                                               self._logger.debug("replacing WS connection due to silent timeout")
+                                               await conn.aclose()
+                                       elif cancel_scope.cancelled_caught:
+                                               await conn.aclose(1008, "Server unexpectedly stopped sending messages")
+                                               self._logger.warning("replacing WS connection due to missed update")
                                        else:
                                                self._logger.warning("replacing WS connection closed by server")
+                                       self._logger.debug(f"post-kill tasks {len(self._nursery.child_tasks)}")
                                        refresh_tx.send_nowait(cancel_scope)
-                               self._active_count -= 1
                except HandshakeError:
                        self._logger.error("handshake error while opening WS connection")
 
@@ -102,6 +122,7 @@ class HealingReadPool:
                        await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope))
                        refresh_tx = self._refresh_queue_tx.clone()
                        self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx)
+               self._logger.debug(f"post-spawn tasks {len(self._nursery.child_tasks)}")
 
        async def conn_refresher_impl(self):
                """Task to monitor and replace WS connections as they disconnect or go silent."""
@@ -109,7 +130,7 @@ class HealingReadPool:
                        async for old_scope in self._refresh_queue_rx:
                                await self._pool_event_tx.send(_ConnectionDown(trio.current_time(), old_scope))
                                await self._spawn_reader()
-                               if not self._active_count:
+                               if not self._nursery.child_tasks:
                                        raise RuntimeError("WS pool depleted")
 
 
@@ -133,6 +154,7 @@ class PoolMerger:
                self._buckets = {}
                self._scope_activations = {}
                self._pending = deque()
+               self._outgoing_scopes = set()
                (self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf)
 
        async def event_reader_impl(self):
@@ -145,27 +167,29 @@ class PoolMerger:
                                        # by timestamp to avoid that.
                                        self._pending.append(event)
                                elif isinstance(event, _Message):
-                                       if event.data["type"] in {"update", "strike", "delete"}:
+                                       if event.data["type"] in ["update", "strike", "delete"]:
                                                tag = event.dedup_tag()
                                                if tag in self._buckets:
                                                        b = self._buckets[tag]
                                                        b.recipients.add(event.scope)
                                                        if b.recipients >= self._scope_activations.keys():
                                                                del self._buckets[tag]
-                                               else:
+                                               elif event.scope not in self._outgoing_scopes:
                                                        sane = (
                                                                event.scope in self._scope_activations
                                                                or any(e.scope == event.scope for e in self._pending)
                                                        )
                                                        if sane:
                                                                self._buckets[tag] = self._Bucket(event.timestamp, {event.scope})
-                                                               self._message_tx.send_nowait(event)
+                                                               await self._message_tx.send(event)
                                                        else:
                                                                raise RuntimeError("recieved message from unrecognized WS connection")
                                elif isinstance(event, _ConnectionDown):
                                        # We don't need to worry about canceling this scope at all, so no need to require it for parity for
                                        # any message, even older ones.  The scope may be gone already, if we canceled it previously.
                                        self._scope_activations.pop(event.scope, None)
+                                       self._pending = deque(ev for ev in self._pending if ev.scope is not event.scope)
+                                       self._outgoing_scopes.discard(event.scope)
                                else:
                                        raise TypeError(f"Expected pool event, found {event!r}")
                                self._timer_poke_tx.send_nowait(None)  # may have new work for the timer
@@ -187,12 +211,10 @@ class PoolMerger:
                                                scope for (scope, active) in self._scope_activations.items()
                                                if active + self._conn_warmup.total_seconds() < now
                                        }
-                                       if bucket.recipients >= target_scopes:
-                                               del self._buckets[tag]
-                                       elif now > bucket.start + self._parity_timeout.total_seconds():
+                                       if now > bucket.start + self._parity_timeout.total_seconds():
                                                for scope in target_scopes - bucket.recipients:
+                                                       self._outgoing_scopes.add(scope)
                                                        scope.cancel()
-                                                       del self._scope_activations[scope]
                                                del self._buckets[tag]
                                        else:
                                                await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds())
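
A simplified, hypothetical sketch of the parity scheme implemented above: the
first connection to deliver a message forwards it and opens a bucket; copies
from other connections accumulate until every live connection has reached
parity, and connections still missing the message after the parity timeout are
canceled and replaced.

    buckets = {}  # dedup tag -> scopes that have delivered the message

    def on_message(tag, scope, live_scopes, forward):
        if tag in buckets:
            buckets[tag].add(scope)
            if buckets[tag] >= live_scopes:
                del buckets[tag]  # all connections delivered it
        else:
            buckets[tag] = {scope}
            forward()  # first arrival wins
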
diff --git a/src/strikebot/queue.py b/src/strikebot/queue.py
index fed075734d83d6d53a03300abcb7f8be933b089d..b4ebbe32a56ebc37b3336e2bcc4b289a82a1727c 100644
@@ -1,13 +1,33 @@
-"""Unbounded blocking priority queue for Trio"""
+"""Unbounded blocking queues for Trio"""
 
+from collections import deque
 from dataclasses import dataclass
 from functools import total_ordering
-from typing import Any
+from typing import Any, Iterable
 import heapq
 
 from trio.lowlevel import ParkingLot
 
 
+class Queue:
+       def __init__(self):
+               self._deque = deque()
+               self._empty_wait = ParkingLot()
+
+       def push(self, el: Any) -> None:
+               self._deque.append(el)
+               self._empty_wait.unpark()
+
+       def extend(self, els: Iterable[Any]) -> None:
+               for el in els:
+                       self.push(el)
+
+       async def pop(self) -> Any:
+               while not self._deque:
+                       # re-check after waking: another consumer may have emptied the deque
+                       await self._empty_wait.park()
+               return self._deque.popleft()
+
+
 @total_ordering
 @dataclass
 class _ReverseOrdWrapper:
@@ -24,8 +44,7 @@ class MaxHeap:
 
        def push(self, item):
                heapq.heappush(self._heap, _ReverseOrdWrapper(item))
-               if len(self._empty_wait):
-                       self._empty_wait.unpark()
+               self._empty_wait.unpark()
 
        async def pop(self):
                if not self._heap:
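
A hypothetical usage sketch for the new Queue under trio: pop() blocks until a
producer pushes, and elements come out in FIFO order.

    import trio
    from strikebot.queue import Queue

    async def demo():
        q = Queue()

        async def consumer():
            while True:
                item = await q.pop()
                if item is None:  # sentinel
                    return
                print("got", item)

        async with trio.open_nursery() as nursery:
            nursery.start_soon(consumer)
            q.extend([1, 2])
            q.push(None)

    trio.run(demo)
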
diff --git a/src/strikebot/reddit_api.py b/src/strikebot/reddit_api.py
index 600b38ea28222dd5271031eb7b22917bf6b0c1e9..ab59a95fd99c1d0a0c6b3c0f210427e7b6d78590 100644
@@ -1,9 +1,9 @@
 """Multi-auth HTTP connection pool with Reddit API wrappers and rate limiting."""
 
 from abc import ABCMeta, abstractmethod
-from collections import deque
 from dataclasses import dataclass, field
 from functools import total_ordering
+from socket import EAI_AGAIN, EAI_FAIL, gaierror
 import datetime as dt
 import logging
 
@@ -12,7 +12,7 @@ import asks
 import trio
 
 from strikebot import __version__ as VERSION
-from strikebot.queue import MaxHeap
+from strikebot.queue import MaxHeap, Queue
 
 
 REQUEST_DELAY = dt.timedelta(seconds = 1)
@@ -30,14 +30,14 @@ class _Request(metaclass = ABCMeta):
 
        def __lt__(self, other):
                if type(self) is type(other):
-                       return self._subtype_key() < other._subtype_key()
+                       return self._subtype_cmp_key() < other._subtype_cmp_key()
                else:
                        prec = self._SUBTYPE_PRECEDENCE
                        return prec.index(type(other)) < prec.index(type(self))
 
        def __eq__(self, other):
                if type(self) is type(other):
-                       return self._subtype_key() == other._subtype_key()
+                       return self._subtype_cmp_key() == other._subtype_cmp_key()
                else:
                        return False
 
@@ -155,74 +155,129 @@ class AppCooldown:
 
 
 class ApiClientPool:
-       def __init__(self, auth_ids: set[int], db_messenger, logger: logging.Logger):
+       def __init__(
+               self,
+               auth_ids: set[int],
+               db_messenger: "strikebot.db.Messenger",
+               request_queue_limit: int,
+               error_window: dt.timedelta,
+               error_delay: dt.timedelta,
+               logger: logging.Logger,
+       ):
+               self._auth_ids = auth_ids
                self._db_messenger = db_messenger
+               self._request_queue_limit = request_queue_limit
+               self._error_window = error_window
+               self._error_delay = error_delay
                self._logger = logger
 
-               self._tokens = {id_: None for id_ in auth_ids}
-               self._tokens_installed = trio.Event()
                now = trio.current_time()
-               self._app_queue = deque(AppCooldown(id_, now) for id_ in auth_ids)
+               self._tokens = {}
+               self._app_queue = Queue()
+               self._waiting = {id_: AppCooldown(id_, now) for id_ in auth_ids}
                self._request_queue = MaxHeap()
 
+               # pool-wide API error backoff
+               self._last_error = None
+               self._global_resume = None
+
                self._session = asks.Session(connections = len(auth_ids))
                self._session.base_location = API_BASE_URL
 
        async def _update_tokens(self):
-               tokens = await self._db_messenger.do("get_auth_tokens", (self._tokens.keys(),))
-               assert tokens.keys() == self._tokens.keys()
-               self._tokens = tokens
-               self._tokens_installed.set()
+               tokens = await self._db_messenger.do("get_auth_tokens", (self._auth_ids,))
+               self._tokens.update(tokens)
+
+               awaken_auths = self._waiting.keys() & tokens.keys()
+               self._app_queue.extend(self._waiting.pop(auth_id) for auth_id in awaken_auths)
 
        async def token_updater_impl(self):
                last_update = None
                while True:
                        if last_update is not None:
                                await trio.sleep_until(last_update + TOKEN_UPDATE_DELAY.total_seconds())
-                       last_update = trio.current_time()
+
+                       if last_update is None:
+                               last_update = trio.current_time()
+                       else:
+                               last_update += TOKEN_UPDATE_DELAY.total_seconds()
+
                        await self._update_tokens()
 
        def _check_queue_size(self) -> None:
-               if len(self._request_queue) > 5:
-                       self._logger.warning(f"API workers may be saturated; {len(self._request_queue)} requests in queue")
+               if len(self._request_queue) > self._request_queue_limit:
+                       raise RuntimeError("request queue size exceeded limit")
 
        async def make_request(self, request: _Request) -> Response:
                (resp_tx, resp_rx) = trio.open_memory_channel(0)
                self._request_queue.push((request, resp_tx))
+               self._check_queue_size()
                async with resp_rx:
                        return await resp_rx.receive()
 
        def enqueue_request(self, request: _Request) -> None:
                self._request_queue.push((request, None))
+               self._check_queue_size()
 
        async def worker_impl(self):
-               await self._tokens_installed.wait()
                while True:
                        (request, resp_tx) = await self._request_queue.pop()
-                       while trio.current_time() < self._app_queue[0].ready_at:
-                               await trio.sleep_until(self._app_queue[0].ready_at)
+                       cooldown = await self._app_queue.pop()
+                       await trio.sleep_until(cooldown.ready_at)
+                       if self._global_resume:
+                               await trio.sleep_until(self._global_resume)
 
-                       cooldown = self._app_queue.popleft()
                        asks_kwargs = request.to_asks_kwargs()
-                       asks_kwargs.setdefault("headers", {}).update({
+                       headers = asks_kwargs.setdefault("headers", {})
+                       headers.update({
                                "Authorization": "Bearer {}".format(self._tokens[cooldown.auth_id]),
                                "User-Agent": USER_AGENT_FMT.format(auth_id = cooldown.auth_id),
                        })
-                       cooldown.ready_at = trio.current_time() + REQUEST_DELAY.total_seconds()
-                       resp = await self._session.request(**asks_kwargs)
-
-                       if resp.status_code == 429:
-                               # We disagreed about the rate limit state; just try again later.
-                               self._request_queue.put((request, resp_tx))
-                       elif 400 <= resp.status_code < 500:
-                               # If we're doing something wrong, let's catch it right away.
-                               if resp_tx:
-                                       resp_tx.close()
-                               raise RuntimeError("Unexpected client error response: {}".format(resp.status_code))
-                       else:
-                               if resp.status_code != 200:
-                                       self._logger.warning(f"got HTTP {resp.status_code} from Reddit API")
-                               if resp_tx:
-                                       await resp_tx.send(resp)
 
-                       self._app_queue.append(cooldown)
+                       request_time = trio.current_time()
+                       wait_for_token = False  # defined up front so the DNS-failure path below can't leave it unbound
+                       try:
+                               resp = await self._session.request(**asks_kwargs)
+                       except gaierror as e:
+                               if e.errno in [EAI_FAIL, EAI_AGAIN]:
+                                       # DNS failure, probably temporary
+                                       error = True
+                               else:
+                                       raise
+                       else:
+                               resp.body  # read response
+                               error = False
+                               if resp.status_code == 429:
+                                       # We disagreed about the rate limit state; just try again later.
+                                       self._logger.warning("rate limited by Reddit API")
+                                       error = True
+                               elif resp.status_code == 401:
+                                       self._logger.warning("got HTTP 401 from Reddit API")
+                                       error = True
+                                       wait_for_token = True
+                               elif resp.status_code in [404, 500]:
+                                       self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying")
+                                       error = True
+                               elif 400 <= resp.status_code < 500:
+                                       # If we're doing something wrong, let's catch it right away.
+                                       raise RuntimeError(f"unexpected client error response: {resp.status_code}")
+                               else:
+                                       if resp.status_code != 200:
+                                               raise RuntimeError(f"unexpected status code {resp.status_code}")
+                                       if resp_tx:
+                                               await resp_tx.send(resp)
+
+                       if error:
+                               self._request_queue.push((request, resp_tx))
+                               self._check_queue_size()
+                               if self._last_error:
+                                       spread = dt.timedelta(seconds = request_time - self._last_error)
+                                       if spread <= self._error_window:
+                                               self._global_resume = request_time + self._error_delay.total_seconds()
+                               self._last_error = request_time
+
+                       cooldown.ready_at = request_time + REQUEST_DELAY.total_seconds()
+                       if wait_for_token:
+                               self._waiting[cooldown.auth_id] = cooldown
+                       else:
+                               self._app_queue.push(cooldown)
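
The backoff bookkeeping at the end of worker_impl reduces to this rule, shown
here in a hypothetical free-standing form (times are trio clock floats): one
error only requeues the request, but two errors within error_window set a
pool-wide resume time error_delay in the future.

    def record_error(pool, request_time):
        if pool.last_error is not None:
            spread = request_time - pool.last_error
            if spread <= pool.error_window.total_seconds():
                pool.global_resume = request_time + pool.error_delay.total_seconds()
        pool.last_error = request_time
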
diff --git a/src/strikebot/tests.py b/src/strikebot/tests.py
index 9b9ddaeffc8d4a11fa5ffe582999ab8115360b5e..55b34261dd44c685a08fb93c88b2cee3af58e54b 100644
@@ -9,34 +9,34 @@ def _build_payload(body_html: str) -> dict[str, str]:
 
 class UpdateParsingTests(TestCase):
        def test_successful_counts(self):
-               pu = parse_update(_build_payload("<div>12,345,678 spaghetti</div>"), None)
+               pu = parse_update(_build_payload("<div>12,345,678 spaghetti</div>"), None, "")
                self.assertEqual(pu.number, 12345678)
                self.assertFalse(pu.deletable)
 
-               pu = parse_update(_build_payload("<div><p>0</p><br/>oz</div>"), None)
+               pu = parse_update(_build_payload("<div><p>0</p><br/>oz</div>"), None, "")
                self.assertEqual(pu.number, 0)
                self.assertFalse(pu.deletable)
 
        def test_non_counts(self):
-               pu = parse_update(_build_payload("<div><table><tbody><tr><td>zoo</td></tr></tbody></table></div>"), None)
+               pu = parse_update(_build_payload("<div><table><tbody><tr><td>zoo</td></tr></tbody></table></div>"), None, "")
                self.assertFalse(pu.count_attempt)
                self.assertFalse(pu.deletable)
 
        def test_typos(self):
-               pu = parse_update(_build_payload("<span>v9</span>"), 888)
+               pu = parse_update(_build_payload("<span>v9</span>"), 888, "")
                self.assertIsNone(pu.number)
                self.assertTrue(pu.count_attempt)
 
-               pu = parse_update(_build_payload("<div>v11.585 Empire</div>"), None)
+               pu = parse_update(_build_payload("<div>v11.585 Empire</div>"), None, "")
                self.assertIsNone(pu.number)
                self.assertTrue(pu.count_attempt)
                self.assertFalse(pu.deletable)
 
-               pu = parse_update(_build_payload("<div>11, 585, 22 </div>"), 11_585_202)
+               pu = parse_update(_build_payload("<div>11, 585, 22 </div>"), 11_585_202, "")
                self.assertIsNone(pu.number)
                self.assertTrue(pu.count_attempt)
                self.assertTrue(pu.deletable)
 
-               pu = parse_update(_build_payload("<span>0490499</span>"), 4999)
+               pu = parse_update(_build_payload("<span>0490499</span>"), 4999, "")
                self.assertIsNone(pu.number)
                self.assertTrue(pu.count_attempt)
diff --git a/src/strikebot/updates.py b/src/strikebot/updates.py
index cd814bf8b22473fa314f356f9e425b3acaf1fc44..f61d4e2432549ff6adc8e01c03adcf1990523f00 100644
@@ -14,7 +14,7 @@ Command = Enum("Command", ["RESET", "REPORT"])
 class ParsedUpdate:
        number: Optional[int]
        command: Optional[Command]
-       count_attempt: bool
+       count_attempt: bool  # either well-formed or typo
        deletable: bool
 
 
@@ -28,6 +28,8 @@ def _parse_command(line: str, bot_user: str) -> Optional[Command]:
 
 
 def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
+       # despite its name, curr_count is the next expected number: one more than the last valid count
+
        NEW_LINE = object()
        SPACE = object()
 
@@ -52,7 +54,6 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -
                        if el.text:
                                worklist.appendleft(el.text)
                elif el.tag == "li":
-                       assert not el.tail
                        worklist.appendleft(NEW_LINE)
                        worklist.appendleft(el.text)
                elif el.tag in ["p", "div", "h1", "h2", "blockquote"]:
@@ -66,28 +67,26 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -
                elif el.tag in ["ul", "ol"]:
                        if el.tail:
                                worklist.appendleft(el.tail)
-                       assert not el.text
                        for sub in reversed(el):
                                worklist.appendleft(sub)
                        worklist.appendleft(NEW_LINE)
                elif el.tag == "pre":
-                       out.extend([l] for l in el.text.splitlines())
+                       if el.text:
+                               out.extend([l] for l in el.text.splitlines())
+                       worklist.appendleft(NEW_LINE)
                        if el.tail:
                                worklist.appendleft(el.tail)
-                       worklist.appendleft(NEW_LINE)
+                       for sub in reversed(el):
+                               worklist.appendleft(sub)
                elif el.tag == "table":
                        if el.tail:
                                worklist.appendleft(el.tail)
                        worklist.appendleft(NEW_LINE)
-                       assert not el.text
                        for sub in reversed(el):
                                assert sub.tag in ["thead", "tbody"]
-                               assert not sub.text
                                for row in reversed(sub):
                                        assert row.tag == "tr"
-                                       assert not row.text or row.tail
                                        for (i, cell) in enumerate(reversed(row)):
-                                               assert not cell.tail
                                                worklist.appendleft(cell)
                                                if i != len(row) - 1:
                                                        worklist.appendleft(SPACE)
@@ -115,13 +114,14 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -
        if lines:
                first = lines[0]
                match = re.match(
-                       "(?P<v>v)?(?P<neg>-)?(?P<num>\\d+((?P<sep>[,. \u2009]|, )\\d+((?P=sep)\\d+)*)?)( |$)",
+                       "(?P<v>v)?(?P<neg>-)?(?P<num>\\d+((?P<sep>[,. \u2009]|, )\\d+((?P=sep)\\d+)*)?)",
                        first,
                        re.ASCII,  # only recognize ASCII digits
                )
                if match:
                        ct_str = match["num"]
                        sep = match["sep"]
+                       post = first[match.end() :]
 
                        zeros = False
                        while len(ct_str) > 1 and ct_str[0] == "0":
@@ -130,11 +130,14 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -
 
                        parts = ct_str.split(sep) if sep else [ct_str]
                        parts_valid = (
-                               len(parts[0]) in range(1, 3)
-                               and all(len(p) == 3 for p in parts[1:])
+                               sep is None
+                               or (
+                                       len(parts[0]) in range(1, 4)
+                                       and all(len(p) == 3 for p in parts[1:])
+                               )
                        )
                        digits = "".join(parts)
-                       lone = not first[match.end() :].strip() and len(lines) == 1
+                       lone = len(lines) == 1 and (not post or post.isspace())
                        typo = False
                        if lone:
                                if match["v"] and len(ct_str) <= 2:
@@ -143,7 +146,7 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -
                                elif match["v"] and parts_valid:
                                        # v followed by count
                                        typo = True
-                               elif curr_count and curr_count >= 100:
+                               elif curr_count and curr_count >= 100 and bool(match["neg"]) == (curr_count < 0):
                                        goal = (sep or "").join(_separate(str(curr_count)))
                                        partials = [goal[: -2], goal[: -1], goal[: -2] + goal[-1]]
                                        if ct_str in partials:
@@ -152,15 +155,24 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -
                                        elif ct_str in [p + goal for p in partials]:
                                                # double paste
                                                typo = True
+
                        if match["v"] or zeros or typo or (digits == "0" and match["neg"]):
                                number = None
                                count_attempt = True
                                deletable = lone
                        elif parts_valid:
                                number = -int(digits) if match["neg"] else int(digits)
-                               count_attempt = True
-                               special = curr_count is not None and abs(number - curr_count) <= 25 and _is_special_number(number)
+                               special = (
+                                       curr_count is not None
+                                       and abs(number - curr_count) <= 25
+                                       and _is_special_number(number)
+                               )
                                deletable = lone and not special
+                               if post and not post[0].isspace():
+                                       count_attempt = curr_count is not None and abs(number - curr_count) <= 25
+                                       number = None
+                               else:
+                                       count_attempt = True
                        else:
                                number = None
                                count_attempt = False
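
Worked examples of the separator-group rule above: with a separator present,
the leading group may be one to three digits and every later group must be
exactly three (a hypothetical free-standing form of the check):

    def parts_valid(ct_str, sep):
        parts = ct_str.split(sep) if sep else [ct_str]
        return sep is None or (
            len(parts[0]) in range(1, 4)
            and all(len(p) == 3 for p in parts[1:])
        )

    assert parts_valid("12,345,678", ",")         # well-formed grouping
    assert parts_valid("123456", None)            # no separator: always valid
    assert not parts_valid("11, 585, 22", ", ")   # trailing group of two: typo
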