Various MVP improvements, bug fixes, and new debugging tools
author Jakob Cornell <jakob+gpg@jcornell.net>
Thu, 21 Jul 2022 02:16:11 +0000 (21:16 -0500)
committer Jakob Cornell <jakob+gpg@jcornell.net>
Thu, 21 Jul 2022 03:31:02 +0000 (22:31 -0500)
setup.cfg
src/strikebot/__init__.py
src/strikebot/__main__.py
src/strikebot/common.py [new file with mode: 0644]
src/strikebot/db.py
src/strikebot/live_ws.py
src/strikebot/queue.py
src/strikebot/reddit_api.py
src/strikebot/tests.py
src/strikebot/updates.py

index 1c54cc81c242112808f86b0209301b831023af3c..eb1fd253cfcaf2babb4f723f244dd8e8e7a7545c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -8,7 +8,8 @@ package_dir =
 packages = strikebot
 python_requires = ~= 3.9
 install_requires =
-       trio ~= 0.19
-       triopg ~= 0.6
-       trio-websocket ~= 0.9
        asks ~= 2.4
+       beautifulsoup4 ~= 4.11
+       trio == 0.19
+       trio-websocket == 0.9.2
+       triopg == 0.6.0
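
The dependency changes above swap compatible-release specifiers (~=) for exact pins (==) on the trio stack and add beautifulsoup4. A quick sketch of what the two specifier styles admit, using the packaging library (an assumption for illustration only; it is not a dependency of this project):

    from packaging.specifiers import SpecifierSet

    compatible, pinned = SpecifierSet("~= 0.19"), SpecifierSet("== 0.19")
    for version in ["0.19", "0.19.1", "0.20"]:
        print(version, version in compatible, version in pinned)
    # "0.20" satisfies "~= 0.19" but not "== 0.19"; exact pins shield the bot
    # from breaking changes in pre-1.0 releases of trio, triopg, and
    # trio-websocket.
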
index 10b4013b2096d5d10e178e956db10f7d54674530..d75d4d1163479ee4b11b6aafd6f24b2e6246c356 100644
--- a/src/strikebot/__init__.py
+++ b/src/strikebot/__init__.py
@@ -17,7 +17,6 @@ from strikebot.updates import Command, parse_update
 
 
 __version__ = importlib.metadata.version(__package__)
-QUIET = True  # TODO remove
 
 
 @dataclass
@@ -36,6 +35,12 @@ class _Update:
                """Determine whether this count update can follow another count update."""
                return self.number == prior.number + 1 and self.author != prior.author
 
+       def __str__(self):
+               content = str(self.number)
+               if self.command:
+                       content += "+" + self.command.name
+               return "{}({} by {})".format(self.id[4 : 8], content, self.author)
+
 
 @total_ordering
 @dataclass
@@ -100,6 +105,7 @@ async def count_tracker_impl(
        timeline = []
        last_valid: Optional[_Update] = None
        pending_strikes: set[str] = set()  # names of updates to mark stricken on arrival
+       forgotten: bool = False  # whether the retention period for an update has passed during the run
 
        def handle_update(update):
                nonlocal last_valid
@@ -108,7 +114,8 @@ async def count_tracker_impl(
 
                pos = bisect.bisect(timeline, tu)
                if pos != len(timeline):
-                       logger.warning(f"long transpo: {update.name}")
+                       logger.warning(f"long transpo: {update.id}")
+
                if pos > 0 and timeline[pos - 1].update.id == update.id:
                        # The pool sent this update message multiple times.  This could be because the message reached parity and
                        # its bucket was dropped just before a new connection was opened and sent its own copy, but connections also
@@ -119,11 +126,13 @@ async def count_tracker_impl(
                        (timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted),
                        None
                )
-               if update.command is not Command.RESET and update.number is not None and last_valid and not pred:
-                       logger.warning("ignoring {update.name}: no valid prior count on record")
+               contender = update.command is Command.RESET or update.number is not None
+               if contender and pred is None and last_valid and forgotten:
+                       # This is a really long transpo which we can't judge valid or invalid since we've forgotten the surrounding
+                       # updates. The best we can do is just ignore it.
+                       logger.warning(f"ignoring {update.id}: no valid prior count on record")
                else:
                        timeline.insert(pos, tu)
-                       assert len({ctu.update.id for ctu in timeline}) == len(timeline)  # TODO remove
                        tu.accepted = (
                                update.command is Command.RESET
                                or (
@@ -131,30 +140,37 @@ async def count_tracker_impl(
                                        and (pred is None or pred.number is None or update.can_follow(pred))
                                )
                        )
-                       logger.debug(f"accepted: {tu.accepted}")
-                       logger.debug("  pred: {}".format(pred and (pred.id, pred.number, pred.author)))
+                       logger.debug("{} {} -> {}".format("accepted" if tu.accepted else "rejected", pred, update))
                        if tu.accepted:
                                # resync subsequent updates
                                newly_valid = []
                                newly_invalid = []
                                resync_last_valid = update
+                               converged = False
                                for scan_tu in islice(timeline, pos + 1, None):
                                        if scan_tu.update.command is Command.RESET:
-                                               break
+                                               resync_last_valid = scan_tu.update
+                                               if last_valid:
+                                                       converged = True
                                        elif scan_tu.update.number is not None:
-                                               accept = last_valid.number is None or scan_tu.update.can_follow(last_valid)
+                                               accept = resync_last_valid.number is None or scan_tu.update.can_follow(resync_last_valid)
                                                if accept and scan_tu.accepted:
                                                        # resync would have no effect past this point
-                                                       break
+                                                       if last_valid:
+                                                               converged = True
                                                elif accept:
                                                        newly_valid.append(scan_tu)
                                                        resync_last_valid = scan_tu.update
                                                elif scan_tu.accepted:
                                                        newly_invalid.append(scan_tu)
+                                                       if scan_tu.update is last_valid:
+                                                               last_valid = None
                                                scan_tu.accepted = accept
+                                       if converged:
+                                               break
 
-                               if last_valid:
-                                       last_valid = max(last_valid, resync_last_valid, key = lambda u: u.ts)
+                               if converged:
+                                       assert last_valid.ts >= resync_last_valid.ts
                                else:
                                        last_valid = resync_last_valid
 
@@ -173,13 +189,12 @@ async def count_tracker_impl(
                                        )
 
                                if update.stricken:
-                                       logger.info(f"bad strike of {update.name}")
+                                       logger.info(f"bad strike of {update.id}")
                                        parts.append(_format_bad_strike_alert(update, thread_id))
 
                                if parts:
                                        parts.append(_format_curr_count(last_valid))
-                                       if not QUIET:
-                                               api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts)))
+                                       api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts)))
 
                                for invalid_tu in newly_invalid:
                                        if not invalid_tu.update.stricken:
@@ -192,8 +207,7 @@ async def count_tracker_impl(
                                update.stricken = True
 
                if update.command is Command.REPORT:
-                       if not QUIET:
-                               api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
+                       api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
 
        with message_rx:
                while True:
@@ -246,8 +260,8 @@ async def count_tracker_impl(
                                        if slot:
                                                if not slot.update.stricken:
                                                        slot.update.stricken = True
-                                                       if not QUIET and isinstance(slot, _TimelineUpdate) and slot.accepted:
-                                                               logger.info(f"bad strike of {slot.update.name}")
+                                                       if isinstance(slot, _TimelineUpdate) and slot.accepted:
+                                                               logger.info(f"bad strike of {slot.update.id}")
                                                                body = _format_bad_strike_alert(slot.update, thread_id)
                                                                api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body))
                                        else:
@@ -272,15 +286,15 @@ async def count_tracker_impl(
 
                                if process:
                                        if bu.update.ts < threshold:
-                                               logger.warning(f"ignoring {bu.update.name}: arrived past retention window")
+                                               logger.warning(f"ignoring {bu.update}: arrived past retention window")
                                        else:
-                                               logger.debug(f"processing update {bu.update.id} ({bu.update.number} by {bu.update.author})")
+                                               logger.debug(f"processing {bu.update}")
                                                handle_update(bu.update)
                                else:
-                                       logger.debug(f"holding {bu.update.id} ({bu.update.number}, checked against {last_valid.number})")
+                                       logger.debug(f"holding {bu.update}, checked against {last_valid}")
                                        new_buffer.append(bu)
                        buffer_ = new_buffer
-                       logger.debug("last count {}".format(last_valid and (last_valid.id, last_valid.number)))  # TODO remove
+                       logger.debug("last count {}".format(last_valid))
 
                        # delete/forget old updates
                        new_timeline = []
@@ -293,5 +307,8 @@ async def count_tracker_impl(
                                                api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts))
                                elif tu.update is last_valid:
                                        new_timeline.append(tu)
+
+                       if i:
+                               forgotten = True
                        new_timeline.extend(islice(timeline, i, None))
                        timeline = new_timeline
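
For readers tracing the resync pass in handle_update, a self-contained toy model of the insert-and-revalidate idea (hypothetical simplified types, not the bot's actual _Update/_TimelineUpdate classes):

    import bisect
    from dataclasses import dataclass, field

    @dataclass(order=True)
    class Update:
        ts: float
        number: int = field(compare=False)
        author: str = field(compare=False)
        accepted: bool = field(default=False, compare=False)

    def can_follow(update, prior):
        return update.number == prior.number + 1 and update.author != prior.author

    def insert(timeline, update):
        # place the late arrival by timestamp, validate it against the nearest
        # accepted predecessor, then revalidate the tail until it converges
        pos = bisect.bisect(timeline, update)
        timeline.insert(pos, update)
        pred = next((t for t in reversed(timeline[:pos]) if t.accepted), None)
        update.accepted = pred is None or can_follow(update, pred)
        if not update.accepted:
            return
        last = update
        for later in timeline[pos + 1:]:
            accept = can_follow(later, last)
            if accept and later.accepted:
                break  # converged: revalidation changes nothing past here
            later.accepted = accept
            if accept:
                last = later

    timeline = []
    for u in [Update(1.0, 1, "a"), Update(3.0, 3, "a"), Update(2.0, 2, "b")]:
        insert(timeline, u)
    print([(t.number, t.accepted) for t in timeline])  # [(1, True), (2, True), (3, True)]
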
index 217f5c90f80355d51c92bbed5027573d605fa5f1..63da2e150d8bd3c0cc30857a57ce8f8afe2f3fc9 100644
--- a/src/strikebot/__main__.py
+++ b/src/strikebot/__main__.py
@@ -1,10 +1,17 @@
+from enum import Enum
+from inspect import getmodule
+from logging import FileHandler, getLogger, StreamHandler
+from signal import SIGUSR1
+from sys import stdout
+from traceback import StackSummary
+from typing import Optional
 import argparse
 import configparser
 import datetime as dt
 import logging
-import sys
 
-import trio
+from trio import open_memory_channel, open_nursery, open_signal_receiver
+from trio.lowlevel import current_root_task, Task
 import trio_asyncio
 import triopg
 
@@ -14,6 +21,84 @@ from strikebot.live_ws import HealingReadPool, PoolMerger
 from strikebot.reddit_api import ApiClientPool
 
 
+_DEBUG_LOG_PATH: Optional[str] = None  # path to file to write debug logs to, if any
+
+
+_TaskKind = Enum(
+       "_TaskKind",
+       [
+               "API_WORKER",
+               "DB_CLIENT",
+               "TOKEN_UPDATE",
+               "TRACKER",
+               "WS_MERGE_READER",
+               "WS_MERGE_TIMER",
+               "WS_REFRESHER",
+               "WS_WORKER",
+       ]
+)
+
+
+def _classify_task(task: Task) -> _TaskKind:
+       mod_name = getmodule(task.coro.cr_code).__name__
+       if mod_name.startswith("trio_asyncio."):
+               return None
+       elif task.coro.__name__ == "_signal_handler_impl":
+               return None
+       else:
+               by_coro_name = {
+                       "db_client_impl": _TaskKind.DB_CLIENT,
+                       "token_updater_impl": _TaskKind.TOKEN_UPDATE,
+                       "event_reader_impl": _TaskKind.WS_MERGE_READER,
+                       "timeout_handler_impl": _TaskKind.WS_MERGE_TIMER,
+                       "conn_refresher_impl": _TaskKind.WS_REFRESHER,
+                       "count_tracker_impl": _TaskKind.TRACKER,
+                       "_reader_impl": _TaskKind.WS_WORKER,
+                       "worker_impl": _TaskKind.API_WORKER,
+               }
+               return by_coro_name[task.coro.__name__]
+
+
+def _get_all_tasks():
+       [ta_loop] = [
+               task
+               for nursery in current_root_task().child_nurseries
+               for task in nursery.child_tasks
+               if getmodule(task.coro.cr_code).__name__.startswith("trio_asyncio.")
+       ]
+       for nursery in ta_loop.child_nurseries:
+               yield from nursery.child_tasks
+
+
+def _print_task_info():
+       groups = {kind: [] for kind in _TaskKind}
+       for task in _get_all_tasks():
+               kind = _classify_task(task)
+               if kind:
+                       coro = task.coro
+                       frame_tuples = []
+                       while coro is not None:
+                               if hasattr(coro, "cr_frame"):
+                                       frame_tuples.append((coro.cr_frame, coro.cr_frame.f_lineno))
+                                       coro = coro.cr_await
+                               else:
+                                       frame_tuples.append((coro.gi_frame, coro.gi_frame.f_lineno))
+                                       coro = coro.gi_yieldfrom
+                       groups[kind].append(StackSummary.extract(iter(frame_tuples)))
+       for kind in _TaskKind:
+               print(kind.name)
+               for ss in groups[kind]:
+                       print("  task")
+                       for line in ss.format():
+                               print("    " + line, end = "")
+
+
+async def _signal_handler_impl():
+       with open_signal_receiver(SIGUSR1) as sig_src:
+               async for _ in sig_src:
+                       _print_task_info()
+
+
 ap = argparse.ArgumentParser(__package__)
 ap.add_argument("config_path")
 args = ap.parse_args()
@@ -39,46 +124,67 @@ ws_pool_size = main_cfg.getint("WS pool size")
 ws_silent_limit = dt.timedelta(seconds = main_cfg.getfloat("WS silent limit"))
 ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup seconds"))
 
-db_connect_params = dict(parser["db connect params"])
+db_cfg = parser["db connect params"]
+getters = {
+       "port": db_cfg.getint,
+}
+db_connect_params = {k: getters.get(k, db_cfg.get)(k) for k in db_cfg}
 
 
-logger = logging.getLogger(__package__)
+logger = getLogger(__package__)
 logger.setLevel(logging.DEBUG)
 
-handler = logging.StreamHandler(sys.stdout)
-handler.setFormatter(logging.Formatter("{asctime:23}: {levelname:8}: {message}", style = "{"))
+handler = StreamHandler(stdout)
+handler.setLevel(logging.WARNING)
+handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{"))
 logger.addHandler(handler)
 
+if _DEBUG_LOG_PATH:
+       debug_handler = FileHandler(_DEBUG_LOG_PATH, mode = "w")
+       debug_handler.setLevel(logging.DEBUG)
+       debug_handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{"))
+       logger.addHandler(debug_handler)
+
 
 async def main():
-       async with trio.open_nursery() as nursery_a, triopg.connect(**db_connect_params) as db_conn:
-               (req_tx, req_rx) = trio.open_memory_channel(0)
+       async with (
+               triopg.connect(**db_connect_params) as db_conn,
+               open_nursery() as nursery_a,
+               open_nursery() as nursery_b,
+               open_nursery() as ws_pool_nursery,
+               open_nursery() as nursery_c,
+       ):
+               nursery_a.start_soon(_signal_handler_impl)
+
                client = Client(db_conn)
                db_messenger = Messenger(client)
-
                nursery_a.start_soon(db_messenger.db_client_impl)
 
                api_pool = ApiClientPool(
-                       auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay, logger
+                       auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay,
+                       logger.getChild("api")
                )
                nursery_a.start_soon(api_pool.token_updater_impl)
                for _ in auth_ids:
-                       nursery_a.start_soon(api_pool.worker_impl)
-
-               (message_tx, message_rx) = trio.open_memory_channel(0)
-               (pool_event_tx, pool_event_rx) = trio.open_memory_channel(0)
-               merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger)
-               async with trio.open_nursery() as nursery_b, trio.open_nursery() as ws_pool_nursery:
-                       nursery_b.start_soon(merger.event_reader_impl)
-                       nursery_b.start_soon(merger.timeout_handler_impl)
-                       ws_pool = HealingReadPool(ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger, ws_silent_limit)
-                       async with ws_pool, trio.open_nursery() as nursery_c:
-                               nursery_c.start_soon(ws_pool.conn_refresher_impl)
-                               nursery_c.start_soon(
-                                       count_tracker_impl,
-                                       message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing, update_retention, logger
-                               )
-                               await ws_pool.init_workers()
+                       await nursery_a.start(api_pool.worker_impl)
 
+               (message_tx, message_rx) = open_memory_channel(0)
+               (pool_event_tx, pool_event_rx) = open_memory_channel(0)
+
+               merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger.getChild("merge"))
+               nursery_b.start_soon(merger.event_reader_impl)
+               nursery_b.start_soon(merger.timeout_handler_impl)
+
+               ws_pool = HealingReadPool(
+                       ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger.getChild("live_ws"),
+                       ws_silent_limit
+               )
+               nursery_c.start_soon(ws_pool.conn_refresher_impl)
+
+               nursery_c.start_soon(
+                       count_tracker_impl, message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing,
+                       update_retention, logger.getChild("track")
+               )
+               await ws_pool.init_workers()
 
 trio_asyncio.run(main)
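
The SIGUSR1 handler above dumps every task's stack by walking suspended coroutine chains. A minimal standalone sketch of that walk (it stops at the first generator-based awaitable, where the real _print_task_info also follows gi_yieldfrom):

    import types
    from traceback import StackSummary

    @types.coroutine
    def suspend():
        yield  # park the coroutine so it can be inspected while suspended

    async def inner():
        await suspend()

    async def outer():
        await inner()

    coro = outer()
    coro.send(None)  # run until the first suspension point
    frames = []
    c = coro
    while c is not None and getattr(c, "cr_frame", None) is not None:
        frames.append((c.cr_frame, c.cr_frame.f_lineno))
        c = getattr(c, "cr_await", None)
    print("".join(StackSummary.extract(iter(frames)).format()))
    coro.close()
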
diff --git a/src/strikebot/common.py b/src/strikebot/common.py
new file mode 100644
index 0000000..519037a
--- /dev/null
+++ b/src/strikebot/common.py
@@ -0,0 +1,10 @@
+from typing import Any
+
+
+def int_digest(val: int) -> str:
+       return "{:04x}".format(val & 0xffff)
+
+
+def obj_digest(obj: Any) -> str:
+       """Makes a short digest of the identity of an object."""
+       return int_digest(id(obj))
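
A usage sketch for the new helpers (the _Scope class here is a hypothetical stand-in for any object worth tracking in logs):

    from strikebot.common import int_digest, obj_digest

    class _Scope:
        pass

    scope = _Scope()
    print(obj_digest(scope))                 # e.g. "9f3c": low 16 bits of id(scope), in hex
    print(int_digest(hash(("update", 42))))  # same idea for any hashable dedup tag
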
index a88ffe89b9cccbbfb62bc24a16dfb6ead068c524..c9a060af8dc5c229d168f6193cbcfc5a2cd78698 100644
--- a/src/strikebot/db.py
+++ b/src/strikebot/db.py
@@ -8,8 +8,9 @@ import trio
 
 def _channel_sender(method):
        async def wrapped(self, resp_channel, *args, **kwargs):
+               ret = await method(self, *args, **kwargs)
                with resp_channel:
-                       await resp_channel.send(await method(self, *args, **kwargs))
+                       await resp_channel.send(ret)
 
        return wrapped
 
@@ -45,7 +46,8 @@ class Messenger:
                """This is run by consumers of the DB wrapper."""
                method = getattr(self._client, method_name)
                (resp_tx, resp_rx) = trio.open_memory_channel(0)
-               await self._request_tx.send(method(resp_tx, *args))
+               coro = method(resp_tx, *args)
+               await self._request_tx.send(coro)
                async with resp_rx:
                        return await resp_rx.receive()
 
index b257905abf0d2727a150629af541d580b0590e1a..cb8b3d0b63f26cd168c310e2e3cd8847b479e179 100644
--- a/src/strikebot/live_ws.py
+++ b/src/strikebot/live_ws.py
@@ -7,10 +7,10 @@ import json
 import logging
 import math
 
-from trio.lowlevel import checkpoint_if_cancelled
 from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
 import trio
 
+from strikebot.common import int_digest, obj_digest
 from strikebot.reddit_api import AboutLiveThreadRequest
 
 
@@ -57,13 +57,11 @@ class HealingReadPool:
                self._pool_event_tx = pool_event_tx
                self._logger = logger
                self._silent_limit = silent_limit
-               (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size)
 
-       async def __aenter__(self):
-               pass
+               (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size)
 
-       async def __aexit__(self, exc_type, exc_value, traceback):
-               self._refresh_queue_tx.close()
+               # number of workers who have stopped receiving updates but whose tasks aren't yet stopped
+               self._closing = 0
 
        async def init_workers(self):
                for _ in range(self._size):
@@ -75,33 +73,45 @@ class HealingReadPool:
                # TODO could use task's implicit cancel scope
                try:
                        async with conn_ctx as conn:
+                               self._logger.debug("scope up: {}".format(obj_digest(cancel_scope)))
                                with refresh_tx:
-                                       _tsc = False  # TODO remove
+                                       silent_timeout = False
                                        with cancel_scope, suppress(ConnectionClosed):
                                                while True:
                                                        with trio.move_on_after(self._silent_limit.total_seconds()) as timeout_scope:
                                                                message = await conn.get_message()
 
                                                        if timeout_scope.cancelled_caught:
-                                                               if len(self._nursery.child_tasks) == self._size:
-                                                                       _tsc = True
-                                                                       cancel_scope.cancel()
-                                                                       await checkpoint_if_cancelled()
+                                                               if len(self._nursery.child_tasks) - self._closing == self._size:
+                                                                       silent_timeout = True
+                                                                       break
+                                                               else:
+                                                                       self._logger.debug("not replacing connection {}; {} tasks, {} closing".format(
+                                                                               obj_digest(cancel_scope),
+                                                                               len(self._nursery.child_tasks),
+                                                                               self._closing,
+                                                                       ))
                                                        else:
                                                                event = _Message(trio.current_time(), json.loads(message), cancel_scope)
                                                                await self._pool_event_tx.send(event)
 
-                                       assert _tsc == timeout_scope.cancelled_caught
-                                       if timeout_scope.cancelled_caught:
-                                               self._logger.debug("replacing WS connection due to silent timeout")
+                                       self._closing += 1
+                                       if silent_timeout:
                                                await conn.aclose()
+                                               self._logger.debug("replacing WS connection {} due to silent timeout".format(obj_digest(cancel_scope)))
                                        elif cancel_scope.cancelled_caught:
                                                await conn.aclose(1008, "Server unexpectedly stopped sending messages")
-                                               self._logger.warning("replacing WS connection due to missed update")
+                                               self._logger.warning("replacing WS connection {} due to missed update".format(obj_digest(cancel_scope)))
                                        else:
-                                               self._logger.warning("replacing WS connection closed by server")
-                                       self._logger.debug(f"post-kill tasks {len(self._nursery.child_tasks)}")
+                                               self._logger.warning("replacing WS connection {} closed by server".format(obj_digest(cancel_scope)))
+
                                        refresh_tx.send_nowait(cancel_scope)
+                                       self._logger.debug("scope down: {} ({} tasks, {} closing)".format(
+                                               obj_digest(cancel_scope),
+                                               len(self._nursery.child_tasks),
+                                               self._closing,
+                                       ))
+                                       self._closing -= 1
                except HandshakeError:
                        self._logger.error("handshake error while opening WS connection")
 
@@ -122,7 +132,6 @@ class HealingReadPool:
                        await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope))
                        refresh_tx = self._refresh_queue_tx.clone()
                        self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx)
-               self._logger.debug(f"post-spawn tasks {len(self._nursery.child_tasks)}")
 
        async def conn_refresher_impl(self):
                """Task to monitor and replace WS connections as they disconnect or go silent."""
@@ -134,6 +143,14 @@ class HealingReadPool:
                                        raise RuntimeError("WS pool depleted")
 
 
+def _tag_digest(tag):
+       return int_digest(hash(tag))
+
+
+def _format_scope_list(scopes):
+       return ", ".join(sorted(map(obj_digest, scopes)))
+
+
 class PoolMerger:
        @dataclass
        class _Bucket:
@@ -157,33 +174,44 @@ class PoolMerger:
                self._outgoing_scopes = set()
                (self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf)
 
+       def _log_current_buckets(self):
+               self._logger.debug("current buckets: {}".format(", ".join(map(_tag_digest, self._buckets.keys()))))
+
        async def event_reader_impl(self):
-               """Drop unused messages, deduplicate useful ones, and install info needed by the timeout handler."""
+               """Drop unused messages, deduplicate useful ones, and communicate with the timeout handler."""
                with self._pool_event_rx, self._timer_poke_tx:
                        async for event in self._pool_event_rx:
                                if isinstance(event, _ConnectionUp):
                                        # An early add of an active scope could mean it's expected on a message that fired before it opened,
-                                       # resulting in a false positive for replacement.  The timer task merges these in among the buckets
-                                       # by timestamp to avoid that.
+                                       # resulting in a false positive for replacement. The timer task merges these in among the buckets by
+                                       # timestamp to avoid that.
                                        self._pending.append(event)
                                elif isinstance(event, _Message):
                                        if event.data["type"] in ["update", "strike", "delete"]:
                                                tag = event.dedup_tag()
+                                               self._logger.debug("recv {} from {}".format(_tag_digest(tag), obj_digest(event.scope)))
                                                if tag in self._buckets:
                                                        b = self._buckets[tag]
                                                        b.recipients.add(event.scope)
-                                                       if b.recipients >= self._scope_activations.keys():
-                                                               del self._buckets[tag]
+                                                       # If this scope is the last one expected for this bucket, we could clear the bucket here.
+                                                       # But connections sometimes receive repeat messages, and a second copy can arrive before
+                                                       # the first copy has reached every connection.  Leaving the bucket open to absorb repeats
+                                                       # makes it less likely that a late second bucket is allocated for a repeat and triggers an
+                                                       # unnecessary connection replacement.
                                                elif event.scope not in self._outgoing_scopes:
                                                        sane = (
                                                                event.scope in self._scope_activations
                                                                or any(e.scope == event.scope for e in self._pending)
                                                        )
                                                        if sane:
+                                                               self._logger.debug("new bucket {}: {}".format(_tag_digest(tag), tag))
                                                                self._buckets[tag] = self._Bucket(event.timestamp, {event.scope})
+                                                               self._log_current_buckets()
                                                                await self._message_tx.send(event)
                                                        else:
-                                                               raise RuntimeError("recieved message from unrecognized WS connection")
+                                                               raise RuntimeError("received message from unrecognized WS connection")
+                                       else:
+                                               self._logger.debug("recv type {!r} from {} (discarding)".format(event.data["type"], obj_digest(event.scope)))
                                elif isinstance(event, _ConnectionDown):
                                        # We don't need to worry about canceling this scope at all, so no need to require it for parity for
                                        # any message, even older ones.  The scope may be gone already, if we canceled it previously.
@@ -212,10 +240,23 @@ class PoolMerger:
                                                if active + self._conn_warmup.total_seconds() < now
                                        }
                                        if now > bucket.start + self._parity_timeout.total_seconds():
-                                               for scope in target_scopes - bucket.recipients:
+                                               filled = target_scopes & bucket.recipients
+                                               missing = target_scopes - bucket.recipients
+                                               extra = bucket.recipients - target_scopes
+
+                                               self._logger.debug("expiring bucket {}".format(_tag_digest(tag)))
+                                               if filled:
+                                                       self._logger.debug("  filled {}: {}".format(len(filled), _format_scope_list(filled)))
+                                               if missing:
+                                                       self._logger.debug("  missing {}: {}".format(len(missing), _format_scope_list(missing)))
+                                               if extra:
+                                                       self._logger.debug("  extra {}: {}".format(len(extra), _format_scope_list(extra)))
+
+                                               for scope in missing:
                                                        self._outgoing_scopes.add(scope)
                                                        scope.cancel()
                                                del self._buckets[tag]
+                                               self._log_current_buckets()
                                        else:
                                                await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds())
                                else:
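
A toy model of the merger's parity buckets (a hypothetical, heavily simplified sketch; the real code also tracks connection warmup and pending scopes): the first copy of a message is forwarded, later copies are absorbed, and a connection that misses a message past the timeout is replaced.

    import time

    class Merger:
        def __init__(self, conns, timeout):
            self.conns = set(conns)
            self.timeout = timeout
            self.buckets = {}  # tag -> (bucket start time, set of connections seen)

        def on_message(self, conn, tag):
            """Return True if this is the first copy of the message (forward it)."""
            if tag in self.buckets:
                self.buckets[tag][1].add(conn)  # absorb the repeat
                return False
            self.buckets[tag] = (time.monotonic(), {conn})
            return True

        def expire(self):
            """Replace connections that missed a message past the parity timeout."""
            now = time.monotonic()
            for tag, (start, seen) in list(self.buckets.items()):
                if now > start + self.timeout:
                    for conn in self.conns - seen:
                        print(f"replacing {conn}: missed {tag}")
                    del self.buckets[tag]

    merger = Merger({"a", "b"}, timeout=0.0)
    print(merger.on_message("a", "update1"))  # True: first copy, forward it
    print(merger.on_message("b", "update1"))  # False: duplicate absorbed
    print(merger.on_message("a", "update2"))  # True
    merger.expire()                           # "b" never saw update2: replaced
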
index b4ebbe32a56ebc37b3336e2bcc4b289a82a1727c..adfbbe64013e0884ca9f8b2624c9cdf87abf9c8e 100644
--- a/src/strikebot/queue.py
+++ b/src/strikebot/queue.py
@@ -27,6 +27,9 @@ class Queue:
                        await self._empty_wait.park()
                return self._deque.popleft()
 
+       def __len__(self):
+               return len(self._deque)
+
 
 @total_ordering
 @dataclass
index ab59a95fd99c1d0a0c6b3c0f210427e7b6d78590..3e69e6c5730764bfa77f66c0e8df12ba2a25086c 100644
--- a/src/strikebot/reddit_api.py
+++ b/src/strikebot/reddit_api.py
@@ -12,6 +12,7 @@ import asks
 import trio
 
 from strikebot import __version__ as VERSION
+from strikebot.common import obj_digest
 from strikebot.queue import MaxHeap, Queue
 
 
@@ -190,6 +191,7 @@ class ApiClientPool:
 
                awaken_auths = self._waiting.keys() & tokens.keys()
                self._app_queue.extend(self._waiting.pop(auth_id) for auth_id in awaken_auths)
+               self._logger.debug("updated API tokens")
 
        async def token_updater_impl(self):
                last_update = None
@@ -210,16 +212,18 @@ class ApiClientPool:
 
        async def make_request(self, request: _Request) -> Response:
                (resp_tx, resp_rx) = trio.open_memory_channel(0)
-               self._request_queue.push((request, resp_tx))
-               self._check_queue_size()
+               self.enqueue_request(request, resp_tx)
                async with resp_rx:
                        return await resp_rx.receive()
 
-       def enqueue_request(self, request: _Request) -> None:
-               self._request_queue.push((request, None))
+       def enqueue_request(self, request: _Request, resp_tx = None) -> None:
+               self._logger.debug("request {}: {}".format(obj_digest(request), type(request).__name__))
+               self._request_queue.push((request, resp_tx))
                self._check_queue_size()
+               self._logger.debug(f"{len(self._request_queue)} requests in queue")
 
-       async def worker_impl(self):
+       async def worker_impl(self, task_status):
+               task_status.started()
                while True:
                        (request, resp_tx) = await self._request_queue.pop()
                        cooldown = await self._app_queue.pop()
@@ -247,23 +251,25 @@ class ApiClientPool:
                                resp.body  # read response
                                error = False
                                wait_for_token = False
+                               log_suffix = " (request {})".format(obj_digest(request))
                                if resp.status_code == 429:
                                        # We disagreed about the rate limit state; just try again later.
-                                       self._logger.warning("rate limited by Reddit API")
+                                       self._logger.warning("rate limited by Reddit API" + log_suffix)
                                        error = True
                                elif resp.status_code == 401:
-                                       self._logger.warning("got HTTP 401 from Reddit API")
+                                       self._logger.warning("got HTTP 401 from Reddit API" + log_suffix)
                                        error = True
                                        wait_for_token = True
-                               elif resp.status_code in [404, 500]:
-                                       self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying")
+                               elif resp.status_code in [404, 500, 503]:
+                                       self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying" + log_suffix)
                                        error = True
                                elif 400 <= resp.status_code < 500:
                                        # If we're doing something wrong, let's catch it right away.
-                                       raise RuntimeError(f"unexpected client error response: {resp.status_code}")
+                                       raise RuntimeError(f"unexpected client error response: {resp.status_code}" + log_suffix)
                                else:
                                        if resp.status_code != 200:
                                                raise RuntimeError(f"unexpected status code {resp.status_code}")
+                                       self._logger.debug("success" + log_suffix)
                                        if resp_tx:
                                                await resp_tx.send(resp)
 
@@ -279,5 +285,7 @@ class ApiClientPool:
                        cooldown.ready_at = request_time + REQUEST_DELAY.total_seconds()
                        if wait_for_token:
                                self._waiting[cooldown.auth_id] = cooldown
+                               if not self._app_queue:
+                                       self._logger.error("all workers waiting for API tokens")
                        else:
                                self._app_queue.push(cooldown)
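
worker_impl now takes a task_status argument so the pool can be launched with nursery.start() instead of start_soon(). A minimal sketch of that trio pattern: start() does not return until the task calls started(), so the caller knows the worker is consuming its queue before continuing.

    import trio

    async def worker(queue_rx, task_status=trio.TASK_STATUS_IGNORED):
        task_status.started()  # unblock the nursery.start() call in main()
        async for job in queue_rx:
            print("handled", job)

    async def main():
        (tx, rx) = trio.open_memory_channel(8)
        async with trio.open_nursery() as nursery:
            await nursery.start(worker, rx)  # returns only once started() ran
            await tx.send("a request")
            tx.close()  # lets the worker's loop (and the nursery) exit

    trio.run(main)
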
index 55b34261dd44c685a08fb93c88b2cee3af58e54b..7bbf2260a3d82f4aa7622a1ef82002d00ca19b4f 100644
--- a/src/strikebot/tests.py
+++ b/src/strikebot/tests.py
@@ -17,6 +17,14 @@ class UpdateParsingTests(TestCase):
                self.assertEqual(pu.number, 0)
                self.assertFalse(pu.deletable)
 
+               pu = parse_update(_build_payload("<div>121 345 621</div>"), 121, "")
+               self.assertEqual(pu.number, 121)
+               self.assertFalse(pu.deletable)
+
+               pu = parse_update(_build_payload("28 336 816"), 28336816, "")
+               self.assertEqual(pu.number, 28336816)
+               self.assertTrue(pu.deletable)
+
        def test_non_counts(self):
                pu = parse_update(_build_payload("<div><table><tbody><tr><td>zoo</td></tr></tbody></table></div>"), None, "")
                self.assertFalse(pu.count_attempt)
index f61d4e2432549ff6adc8e01c03adcf1990523f00..e3f58cd477b2e976b44e7674d563ba8bcbcb3ba7 100644
--- a/src/strikebot/updates.py
+++ b/src/strikebot/updates.py
@@ -1,11 +1,11 @@
 from __future__ import annotations
-from collections import deque
 from dataclasses import dataclass
 from enum import Enum
 from typing import Optional
-from xml.etree import ElementTree
 import re
 
+from bs4 import BeautifulSoup
+
 
 Command = Enum("Command", ["RESET", "REPORT"])
 
@@ -34,84 +34,66 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -
        SPACE = object()
 
        # flatten the update content to plain text
-       doc = ElementTree.fromstring(payload_data["body_html"])
-       worklist = deque([doc])
+       tree = BeautifulSoup(payload_data["body_html"], "html.parser")
+       # reverse the top-level nodes so that pop() visits the document in order
+       worklist = list(reversed(tree.contents))
        out = [[]]
        while worklist:
-               el = worklist.popleft()
-               if el is NEW_LINE:
-                       if out[-1]:
-                               out.append([])
+               el = worklist.pop()
+               if isinstance(el, str):
+                       out[-1].append(el)
                elif el is SPACE:
                        out[-1].append(el)
-               elif isinstance(el, str):
-                       out[-1].append(el.replace("\n", " "))
-               elif el.tag in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]:
-                       if el.tail:
-                               worklist.appendleft(el.tail)
-                       for sub in reversed(el):
-                               worklist.appendleft(sub)
-                       if el.text:
-                               worklist.appendleft(el.text)
-               elif el.tag == "li":
-                       worklist.appendleft(NEW_LINE)
-                       worklist.appendleft(el.text)
-               elif el.tag in ["p", "div", "h1", "h2", "blockquote"]:
-                       if el.tail:
-                               worklist.appendleft(el.tail)
-                       worklist.appendleft(NEW_LINE)
-                       for sub in reversed(el):
-                               worklist.appendleft(sub)
-                       if el.text:
-                               worklist.appendleft(el.text)
-               elif el.tag in ["ul", "ol"]:
-                       if el.tail:
-                               worklist.appendleft(el.tail)
-                       for sub in reversed(el):
-                               worklist.appendleft(sub)
-                       worklist.appendleft(NEW_LINE)
-               elif el.tag == "pre":
-                       if el.text:
-                               out.extend([l] for l in el.text.splitlines())
-                       worklist.appendleft(NEW_LINE)
-                       if el.tail:
-                               worklist.appendleft(el.tail)
-                       for sub in reversed(el):
-                               worklist.appendleft(sub)
-               elif el.tag == "table":
-                       if el.tail:
-                               worklist.appendleft(el.tail)
-                       worklist.appendleft(NEW_LINE)
-                       for sub in reversed(el):
-                               assert sub.tag in ["thead", "tbody"]
-                               for row in reversed(sub):
-                                       assert row.tag == "tr"
-                                       for (i, cell) in enumerate(reversed(row)):
-                                               worklist.appendleft(cell)
-                                               if i != len(row) - 1:
-                                                       worklist.appendleft(SPACE)
-                                       worklist.appendleft(NEW_LINE)
-               elif el.tag == "br":
-                       if el.tail:
-                               worklist.appendleft(el.tail)
-                       worklist.appendleft(NEW_LINE)
+               elif el is NEW_LINE or el.name == "br":
+                       if out[-1]:
+                               out.append([])
+               elif el.name in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]:
+                       worklist.extend(reversed(el.contents))
+               elif el.name in ["ul", "ol", "table", "thead", "tbody"]:
+                       worklist.extend(reversed(el.contents))
+               elif el.name in ["li", "p", "div", "h1", "h2", "blockquote"]:
+                       worklist.append(NEW_LINE)
+                       worklist.extend(reversed(el.contents))
+                       worklist.append(NEW_LINE)
+               elif el.name == "pre":
+                       worklist.append(NEW_LINE)
+                       for line in reversed(el.text.splitlines()):
+                               worklist.append(line)
+                               worklist.append(NEW_LINE)
+                       worklist.append(NEW_LINE)
+               elif el.name == "tr":
+                       worklist.append(NEW_LINE)
+                       for (i, cell) in enumerate(reversed(el.contents)):
+                               worklist.append(cell)
+                               if i != len(el.contents) - 1:
+                                       worklist.append(SPACE)
+                       worklist.append(NEW_LINE)
                else:
-                       raise RuntimeError(f"can't parse tag {el.tag}")
+                       raise RuntimeError(f"can't parse tag {el.name}")
 
-       lines = list(filter(
-               None,
-               (
-                       "".join(" " if part is SPACE else part for part in parts).strip()
-                       for parts in out
-               )
-       ))
+       tmp_lines = (
+               "".join(" " if part is SPACE else part for part in parts).strip()
+               for parts in out
+       )
+       pre_strip_lines = list(filter(None, tmp_lines))
+
+       # normalize whitespace according to HTML rendering rules
+       # https://developer.mozilla.org/en-US/docs/Web/API/Document_Object_Model/Whitespace#explanation
+       stripped_lines = [
+               re.sub(" +", " ", l.replace("\t", " ").replace("\n", " ")).strip(" ")
+               for l in pre_strip_lines
+       ]
+
+       return _parse_from_lines(stripped_lines, curr_count, bot_user)
+
 
+def _parse_from_lines(lines: list[str], curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
        command = next(
                filter(None, (_parse_command(l, bot_user) for l in lines)),
                None
        )
-
        if lines:
+               # look for groups of digits (as many as possible) separated by a uniform separator from the valid set
                first = lines[0]
                match = re.match(
                        "(?P<v>v)?(?P<neg>-)?(?P<num>\\d+((?P<sep>[,. \u2009]|, )\\d+((?P=sep)\\d+)*)?)",
@@ -119,48 +98,75 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -
                        re.ASCII,  # only recognize ASCII digits
                )
                if match:
-                       ct_str = match["num"]
+                       raw_digits = match["num"]
                        sep = match["sep"]
                        post = first[match.end() :]
 
                        zeros = False
-                       while len(ct_str) > 1 and ct_str[0] == "0":
+                       while len(raw_digits) > 1 and raw_digits[0] == "0":
                                zeros = True
-                               ct_str = ct_str.removeprefix("0").removeprefix(sep or "")
-
-                       parts = ct_str.split(sep) if sep else [ct_str]
-                       parts_valid = (
-                               sep is None
-                               or (
-                                       len(parts[0]) in range(1, 4)
-                                       and all(len(p) == 3 for p in parts[1:])
-                               )
-                       )
-                       digits = "".join(parts)
+                               raw_digits = raw_digits.removeprefix("0").removeprefix(sep or "")
+
+                       parts = raw_digits.split(sep) if sep else [raw_digits]
                        lone = len(lines) == 1 and (not post or post.isspace())
                        typo = False
                        if lone:
-                               if match["v"] and len(ct_str) <= 2:
+                               all_parts_valid = (
+                                       sep is None
+                                       or (
+                                               1 <= len(parts[0]) <= 3
+                                               and all(len(p) == 3 for p in parts[1:])
+                                       )
+                               )
+                               if match["v"] and len(parts) == 1 and len(parts[0]) <= 2:
                                        # failed paste of leading digits
                                        typo = True
-                               elif match["v"] and parts_valid:
+                               elif match["v"] and all_parts_valid:
                                        # v followed by count
                                        typo = True
-                               elif curr_count and curr_count >= 100 and bool(match["neg"]) == (curr_count < 0):
-                                       goal = (sep or "").join(_separate(str(curr_count)))
-                                       partials = [goal[: -2], goal[: -1], goal[: -2] + goal[-1]]
-                                       if ct_str in partials:
+                               elif curr_count is not None and abs(curr_count) >= 100 and bool(match["neg"]) == (curr_count < 0):
+                                       goal_parts = _separate(str(abs(curr_count)))
+                                       partials = [
+                                               goal_parts[: -1] + [goal_parts[-1][: -1]],  # missing last digit
+                                               goal_parts[: -1] + [goal_parts[-1][: -2]],  # missing last two digits
+                                               goal_parts[: -1] + [goal_parts[-1][: -2] + goal_parts[-1][-1]],  # missing second-last digit
+                                       ]
+                                       if parts in partials:
                                                # missing any of last two digits
                                                typo = True
-                                       elif ct_str in [p + goal for p in partials]:
+                                       elif parts in [p[: -1] + [p[-1] + goal_parts[0]] + goal_parts[1 :] for p in partials]:
                                                # double paste
                                                typo = True
 
-                       if match["v"] or zeros or typo or (digits == "0" and match["neg"]):
+                       if match["v"] or zeros or typo or (parts == ["0"] and match["neg"]):
                                number = None
                                count_attempt = True
                                deletable = lone
-                       elif parts_valid:
+                       else:
+                               if curr_count is not None and sep and sep.isspace():
+                                       # Presume that the intended count consists of as many valid digit groups as necessary to match the
+                                       # number of digits in the expected count, if possible.
+                                       digit_count = len(str(abs(curr_count)))
+                                       use_parts = []
+                                       accum = 0
+                                       for (i, part) in enumerate(parts):
+                                               part_valid = len(part) <= 3 if i == 0 else len(part) == 3
+                                               if part_valid and accum < digit_count:
+                                                       use_parts.append(part)
+                                                       accum += len(part)
+                                               else:
+                                                       break
+
+                                       # could still be a no-separator count with some extra digit groups on the same line
+                                       if not use_parts:
+                                               use_parts = [parts[0]]
+
+                                       lone = lone and len(use_parts) == len(parts)
+                               else:
+                                       # current count is unknown or no separator was used
+                                       use_parts = parts
+
+                               digits = "".join(use_parts)
                                number = -int(digits) if match["neg"] else int(digits)
                                special = (
                                        curr_count is not None
@@ -168,15 +174,11 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -
                                        and _is_special_number(number)
                                )
                                deletable = lone and not special
-                               if post and not post[0].isspace():
+                               if len(use_parts) == len(parts) and post and not post[0].isspace():
                                        count_attempt = curr_count is not None and abs(number - curr_count) <= 25
                                        number = None
                                else:
                                        count_attempt = True
-                       else:
-                               number = None
-                               count_attempt = False
-                               deletable = lone
                else:
                        # no count attempt found
                        number = None
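
A toy illustration of the new digit-group heuristic in _parse_from_lines (a hypothetical simplification of the use_parts loop above): given the expected count, take as many well-formed separator groups as needed to cover its digit count, and treat the rest of the line as trailing text.

    def pick_groups(parts, expected_count):
        """Choose the separator groups that make up the count attempt."""
        digit_count = len(str(abs(expected_count)))
        use, accum = [], 0
        for i, part in enumerate(parts):
            valid = len(part) <= 3 if i == 0 else len(part) == 3
            if valid and accum < digit_count:
                use.append(part)
                accum += len(part)
            else:
                break
        return use or [parts[0]]

    print(pick_groups(["121", "345", "621"], 121))      # ['121']: count is 121, "345 621" is trailing text
    print(pick_groups(["28", "336", "816"], 28336816))  # ['28', '336', '816']: one full count

These two cases mirror the new assertions in tests.py above.
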