From ad7814650fd94162c093bc938ae068ba6decb945 Mon Sep 17 00:00:00 2001 From: Jakob Cornell Date: Wed, 20 Jul 2022 21:16:11 -0500 Subject: [PATCH] Various MVP improvements, bug fixes, and new debugging stuff --- setup.cfg | 7 +- src/strikebot/__init__.py | 63 +++++++----- src/strikebot/__main__.py | 158 +++++++++++++++++++++++++----- src/strikebot/common.py | 10 ++ src/strikebot/db.py | 6 +- src/strikebot/live_ws.py | 89 ++++++++++++----- src/strikebot/queue.py | 3 + src/strikebot/reddit_api.py | 28 ++++-- src/strikebot/tests.py | 8 ++ src/strikebot/updates.py | 190 ++++++++++++++++++------------------ 10 files changed, 380 insertions(+), 182 deletions(-) create mode 100644 src/strikebot/common.py diff --git a/setup.cfg b/setup.cfg index 1c54cc8..eb1fd25 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,7 +8,8 @@ package_dir = packages = strikebot python_requires = ~= 3.9 install_requires = - trio ~= 0.19 - triopg ~= 0.6 - trio-websocket ~= 0.9 asks ~= 2.4 + beautifulsoup4 ~= 4.11 + trio == 0.19 + trio-websocket == 0.9.2 + triopg == 0.6.0 diff --git a/src/strikebot/__init__.py b/src/strikebot/__init__.py index 10b4013..d75d4d1 100644 --- a/src/strikebot/__init__.py +++ b/src/strikebot/__init__.py @@ -17,7 +17,6 @@ from strikebot.updates import Command, parse_update __version__ = importlib.metadata.version(__package__) -QUIET = True # TODO remove @dataclass @@ -36,6 +35,12 @@ class _Update: """Determine whether this count update can follow another count update.""" return self.number == prior.number + 1 and self.author != prior.author + def __str__(self): + content = str(self.number) + if self.command: + content += "+" + self.command.name + return "{}({} by {})".format(self.id[4 : 8], content, self.author) + @total_ordering @dataclass @@ -100,6 +105,7 @@ async def count_tracker_impl( timeline = [] last_valid: Optional[_Update] = None pending_strikes: set[str] = set() # names of updates to mark stricken on arrival + forgotten: bool = False # whether the retention period for an update has passed during the run def handle_update(update): nonlocal last_valid @@ -108,7 +114,8 @@ async def count_tracker_impl( pos = bisect.bisect(timeline, tu) if pos != len(timeline): - logger.warning(f"long transpo: {update.name}") + logger.warning(f"long transpo: {update.id}") + if pos > 0 and timeline[pos - 1].update.id == update.id: # The pool sent this update message multiple times. This could be because the message reached parity and # its bucket was dropped just before a new connection was opened and sent its own copy, but connections also @@ -119,11 +126,13 @@ async def count_tracker_impl( (timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted), None ) - if update.command is not Command.RESET and update.number is not None and last_valid and not pred: - logger.warning("ignoring {update.name}: no valid prior count on record") + contender = update.command is Command.RESET or update.number is not None + if contender and pred is None and last_valid and forgotten: + # This is a really long transpo which we can't judge valid or invalid since we've forgotten the surrounding + # updates. The best we can do is just ignore it. 
+ logger.warning(f"ignoring {update.id}: no valid prior count on record") else: timeline.insert(pos, tu) - assert len({ctu.update.id for ctu in timeline}) == len(timeline) # TODO remove tu.accepted = ( update.command is Command.RESET or ( @@ -131,30 +140,37 @@ async def count_tracker_impl( and (pred is None or pred.number is None or update.can_follow(pred)) ) ) - logger.debug(f"accepted: {tu.accepted}") - logger.debug(" pred: {}".format(pred and (pred.id, pred.number, pred.author))) + logger.debug("{} {} -> {}".format("accepted" if tu.accepted else "rejected", pred, update)) if tu.accepted: # resync subsequent updates newly_valid = [] newly_invalid = [] resync_last_valid = update + converged = False for scan_tu in islice(timeline, pos + 1, None): if scan_tu.update.command is Command.RESET: - break + resync_last_valid = scan_tu.update + if last_valid: + converged = True elif scan_tu.update.number is not None: - accept = last_valid.number is None or scan_tu.update.can_follow(last_valid) + accept = resync_last_valid.number is None or scan_tu.update.can_follow(resync_last_valid) if accept and scan_tu.accepted: # resync would have no effect past this point - break + if last_valid: + converged = True elif accept: newly_valid.append(scan_tu) resync_last_valid = scan_tu.update elif scan_tu.accepted: newly_invalid.append(scan_tu) + if scan_tu.update is last_valid: + last_valid = None scan_tu.accepted = accept + if converged: + break - if last_valid: - last_valid = max(last_valid, resync_last_valid, key = lambda u: u.ts) + if converged: + assert last_valid.ts >= resync_last_valid.ts else: last_valid = resync_last_valid @@ -173,13 +189,12 @@ async def count_tracker_impl( ) if update.stricken: - logger.info(f"bad strike of {update.name}") + logger.info(f"bad strike of {update.id}") parts.append(_format_bad_strike_alert(update, thread_id)) if parts: parts.append(_format_curr_count(last_valid)) - if not QUIET: - api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts))) + api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts))) for invalid_tu in newly_invalid: if not invalid_tu.update.stricken: @@ -192,8 +207,7 @@ async def count_tracker_impl( update.stricken = True if update.command is Command.REPORT: - if not QUIET: - api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid))) + api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid))) with message_rx: while True: @@ -246,8 +260,8 @@ async def count_tracker_impl( if slot: if not slot.update.stricken: slot.update.stricken = True - if not QUIET and isinstance(slot, _TimelineUpdate) and slot.accepted: - logger.info(f"bad strike of {slot.update.name}") + if isinstance(slot, _TimelineUpdate) and slot.accepted: + logger.info(f"bad strike of {slot.update.id}") body = _format_bad_strike_alert(slot.update, thread_id) api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body)) else: @@ -272,15 +286,15 @@ async def count_tracker_impl( if process: if bu.update.ts < threshold: - logger.warning(f"ignoring {bu.update.name}: arrived past retention window") + logger.warning(f"ignoring {bu.update}: arrived past retention window") else: - logger.debug(f"processing update {bu.update.id} ({bu.update.number} by {bu.update.author})") + logger.debug(f"processing {bu.update}") handle_update(bu.update) else: - logger.debug(f"holding {bu.update.id} ({bu.update.number}, checked against {last_valid.number})") + logger.debug(f"holding {bu.update}, checked 
against {last_valid})") new_buffer.append(bu) buffer_ = new_buffer - logger.debug("last count {}".format(last_valid and (last_valid.id, last_valid.number))) # TODO remove + logger.debug("last count {}".format(last_valid)) # delete/forget old updates new_timeline = [] @@ -293,5 +307,8 @@ async def count_tracker_impl( api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts)) elif tu.update is last_valid: new_timeline.append(tu) + + if i: + forgotten = True new_timeline.extend(islice(timeline, i, None)) timeline = new_timeline diff --git a/src/strikebot/__main__.py b/src/strikebot/__main__.py index 217f5c9..63da2e1 100644 --- a/src/strikebot/__main__.py +++ b/src/strikebot/__main__.py @@ -1,10 +1,17 @@ +from enum import Enum +from inspect import getmodule +from logging import FileHandler, getLogger, StreamHandler +from signal import SIGUSR1 +from sys import stdout +from traceback import StackSummary +from typing import Optional import argparse import configparser import datetime as dt import logging -import sys -import trio +from trio import open_memory_channel, open_nursery, open_signal_receiver +from trio.lowlevel import current_root_task, Task import trio_asyncio import triopg @@ -14,6 +21,84 @@ from strikebot.live_ws import HealingReadPool, PoolMerger from strikebot.reddit_api import ApiClientPool +_DEBUG_LOG_PATH: Optional[str] = None # path to file to write debug logs to, if any + + +_TaskKind = Enum( + "_TaskKind", + [ + "API_WORKER", + "DB_CLIENT", + "TOKEN_UPDATE", + "TRACKER", + "WS_MERGE_READER", + "WS_MERGE_TIMER", + "WS_REFRESHER", + "WS_WORKER", + ] +) + + +def _classify_task(task: Task) -> _TaskKind: + mod_name = getmodule(task.coro.cr_code).__name__ + if mod_name.startswith("trio_asyncio."): + return None + elif task.coro.__name__ == "_signal_handler_impl": + return None + else: + by_coro_name = { + "db_client_impl": _TaskKind.DB_CLIENT, + "token_updater_impl": _TaskKind.TOKEN_UPDATE, + "event_reader_impl": _TaskKind.WS_MERGE_READER, + "timeout_handler_impl": _TaskKind.WS_MERGE_TIMER, + "conn_refresher_impl": _TaskKind.WS_REFRESHER, + "count_tracker_impl": _TaskKind.TRACKER, + "_reader_impl": _TaskKind.WS_WORKER, + "worker_impl": _TaskKind.API_WORKER, + } + return by_coro_name[task.coro.__name__] + + +def _get_all_tasks(): + [ta_loop] = [ + task + for nursery in current_root_task().child_nurseries + for task in nursery.child_tasks + if getmodule(task.coro.cr_code).__name__.startswith("trio_asyncio.") + ] + for nursery in ta_loop.child_nurseries: + yield from nursery.child_tasks + + +def _print_task_info(): + groups = {kind: [] for kind in _TaskKind} + for task in _get_all_tasks(): + kind = _classify_task(task) + if kind: + coro = task.coro + frame_tuples = [] + while coro is not None: + if hasattr(coro, "cr_frame"): + frame_tuples.append((coro.cr_frame, coro.cr_frame.f_lineno)) + coro = coro.cr_await + else: + frame_tuples.append((coro.gi_frame, coro.gi_frame.f_lineno)) + coro = coro.gi_yieldfrom + groups[kind].append(StackSummary.extract(iter(frame_tuples))) + for kind in _TaskKind: + print(kind.name) + for ss in groups[kind]: + print(" task") + for line in ss.format(): + print(" " + line, end = "") + + +async def _signal_handler_impl(): + with open_signal_receiver(SIGUSR1) as sig_src: + async for _ in sig_src: + _print_task_info() + + ap = argparse.ArgumentParser(__package__) ap.add_argument("config_path") args = ap.parse_args() @@ -39,46 +124,67 @@ ws_pool_size = main_cfg.getint("WS pool size") ws_silent_limit = dt.timedelta(seconds = 
main_cfg.getfloat("WS silent limit")) ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup seconds")) -db_connect_params = dict(parser["db connect params"]) +db_cfg = parser["db connect params"] +getters = { + "port": db_cfg.getint, +} +db_connect_params = {k: getters.get(k, db_cfg.get)(k) for k in db_cfg} -logger = logging.getLogger(__package__) +logger = getLogger(__package__) logger.setLevel(logging.DEBUG) -handler = logging.StreamHandler(sys.stdout) -handler.setFormatter(logging.Formatter("{asctime:23}: {levelname:8}: {message}", style = "{")) +handler = StreamHandler(stdout) +handler.setLevel(logging.WARNING) +handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{")) logger.addHandler(handler) +if _DEBUG_LOG_PATH: + debug_handler = FileHandler(_DEBUG_LOG_PATH, mode = "w") + debug_handler.setLevel(logging.DEBUG) + debug_handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{")) + logger.addHandler(debug_handler) + async def main(): - async with trio.open_nursery() as nursery_a, triopg.connect(**db_connect_params) as db_conn: - (req_tx, req_rx) = trio.open_memory_channel(0) + async with ( + triopg.connect(**db_connect_params) as db_conn, + open_nursery() as nursery_a, + open_nursery() as nursery_b, + open_nursery() as ws_pool_nursery, + open_nursery() as nursery_c, + ): + nursery_a.start_soon(_signal_handler_impl) + client = Client(db_conn) db_messenger = Messenger(client) - nursery_a.start_soon(db_messenger.db_client_impl) api_pool = ApiClientPool( - auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay, logger + auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay, + logger.getChild("api") ) nursery_a.start_soon(api_pool.token_updater_impl) for _ in auth_ids: - nursery_a.start_soon(api_pool.worker_impl) - - (message_tx, message_rx) = trio.open_memory_channel(0) - (pool_event_tx, pool_event_rx) = trio.open_memory_channel(0) - merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger) - async with trio.open_nursery() as nursery_b, trio.open_nursery() as ws_pool_nursery: - nursery_b.start_soon(merger.event_reader_impl) - nursery_b.start_soon(merger.timeout_handler_impl) - ws_pool = HealingReadPool(ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger, ws_silent_limit) - async with ws_pool, trio.open_nursery() as nursery_c: - nursery_c.start_soon(ws_pool.conn_refresher_impl) - nursery_c.start_soon( - count_tracker_impl, - message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing, update_retention, logger - ) - await ws_pool.init_workers() + await nursery_a.start(api_pool.worker_impl) + (message_tx, message_rx) = open_memory_channel(0) + (pool_event_tx, pool_event_rx) = open_memory_channel(0) + + merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger.getChild("merge")) + nursery_b.start_soon(merger.event_reader_impl) + nursery_b.start_soon(merger.timeout_handler_impl) + + ws_pool = HealingReadPool( + ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger.getChild("live_ws"), + ws_silent_limit + ) + nursery_c.start_soon(ws_pool.conn_refresher_impl) + + nursery_c.start_soon( + count_tracker_impl, message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing, + update_retention, logger.getChild("track") + ) + await ws_pool.init_workers() trio_asyncio.run(main) diff --git a/src/strikebot/common.py 
b/src/strikebot/common.py new file mode 100644 index 0000000..519037a --- /dev/null +++ b/src/strikebot/common.py @@ -0,0 +1,10 @@ +from typing import Any + + +def int_digest(val: int) -> str: + return "{:04x}".format(val & 0xffff) + + +def obj_digest(obj: Any) -> str: + """Makes a short digest of the identity of an object.""" + return int_digest(id(obj)) diff --git a/src/strikebot/db.py b/src/strikebot/db.py index a88ffe8..c9a060a 100644 --- a/src/strikebot/db.py +++ b/src/strikebot/db.py @@ -8,8 +8,9 @@ import trio def _channel_sender(method): async def wrapped(self, resp_channel, *args, **kwargs): + ret = await method(self, *args, **kwargs) with resp_channel: - await resp_channel.send(await method(self, *args, **kwargs)) + await resp_channel.send(ret) return wrapped @@ -45,7 +46,8 @@ class Messenger: """This is run by consumers of the DB wrapper.""" method = getattr(self._client, method_name) (resp_tx, resp_rx) = trio.open_memory_channel(0) - await self._request_tx.send(method(resp_tx, *args)) + coro = method(resp_tx, *args) + await self._request_tx.send(coro) async with resp_rx: return await resp_rx.receive() diff --git a/src/strikebot/live_ws.py b/src/strikebot/live_ws.py index b257905..cb8b3d0 100644 --- a/src/strikebot/live_ws.py +++ b/src/strikebot/live_ws.py @@ -7,10 +7,10 @@ import json import logging import math -from trio.lowlevel import checkpoint_if_cancelled from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection import trio +from strikebot.common import int_digest, obj_digest from strikebot.reddit_api import AboutLiveThreadRequest @@ -57,13 +57,11 @@ class HealingReadPool: self._pool_event_tx = pool_event_tx self._logger = logger self._silent_limit = silent_limit - (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size) - async def __aenter__(self): - pass + (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size) - async def __aexit__(self, exc_type, exc_value, traceback): - self._refresh_queue_tx.close() + # number of workers who have stopped receiving updates but whose tasks aren't yet stopped + self._closing = 0 async def init_workers(self): for _ in range(self._size): @@ -75,33 +73,45 @@ class HealingReadPool: # TODO could use task's implicit cancel scope try: async with conn_ctx as conn: + self._logger.debug("scope up: {}".format(obj_digest(cancel_scope))) with refresh_tx: - _tsc = False # TODO remove + silent_timeout = False with cancel_scope, suppress(ConnectionClosed): while True: with trio.move_on_after(self._silent_limit.total_seconds()) as timeout_scope: message = await conn.get_message() if timeout_scope.cancelled_caught: - if len(self._nursery.child_tasks) == self._size: - _tsc = True - cancel_scope.cancel() - await checkpoint_if_cancelled() + if len(self._nursery.child_tasks) - self._closing == self._size: + silent_timeout = True + break + else: + self._logger.debug("not replacing connection {}; {} tasks, {} closing".format( + obj_digest(cancel_scope), + len(self._nursery.child_tasks), + self._closing, + )) else: event = _Message(trio.current_time(), json.loads(message), cancel_scope) await self._pool_event_tx.send(event) - assert _tsc == timeout_scope.cancelled_caught - if timeout_scope.cancelled_caught: - self._logger.debug("replacing WS connection due to silent timeout") + self._closing += 1 + if silent_timeout: await conn.aclose() + self._logger.debug("replacing WS connection {} due to silent timeout".format(obj_digest(cancel_scope))) elif 
cancel_scope.cancelled_caught: await conn.aclose(1008, "Server unexpectedly stopped sending messages") - self._logger.warning("replacing WS connection due to missed update") + self._logger.warning("replacing WS connection {} due to missed update".format(obj_digest(cancel_scope))) else: - self._logger.warning("replacing WS connection closed by server") - self._logger.debug(f"post-kill tasks {len(self._nursery.child_tasks)}") + self._logger.warning("replacing WS connection {} closed by server".format(obj_digest(cancel_scope))) + refresh_tx.send_nowait(cancel_scope) + self._logger.debug("scope down: {} ({} tasks, {} closing)".format( + obj_digest(cancel_scope), + len(self._nursery.child_tasks), + self._closing, + )) + self._closing -= 1 except HandshakeError: self._logger.error("handshake error while opening WS connection") @@ -122,7 +132,6 @@ class HealingReadPool: await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope)) refresh_tx = self._refresh_queue_tx.clone() self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx) - self._logger.debug(f"post-spawn tasks {len(self._nursery.child_tasks)}") async def conn_refresher_impl(self): """Task to monitor and replace WS connections as they disconnect or go silent.""" @@ -134,6 +143,14 @@ class HealingReadPool: raise RuntimeError("WS pool depleted") +def _tag_digest(tag): + return int_digest(hash(tag)) + + +def _format_scope_list(scopes): + return ", ".join(sorted(map(obj_digest, scopes))) + + class PoolMerger: @dataclass class _Bucket: @@ -157,33 +174,44 @@ class PoolMerger: self._outgoing_scopes = set() (self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf) + def _log_current_buckets(self): + self._logger.debug("current buckets: {}".format(", ".join(map(_tag_digest, self._buckets.keys())))) + async def event_reader_impl(self): - """Drop unused messages, deduplicate useful ones, and install info needed by the timeout handler.""" + """Drop unused messages, deduplicate useful ones, and communicate with the timeout handler.""" with self._pool_event_rx, self._timer_poke_tx: async for event in self._pool_event_rx: if isinstance(event, _ConnectionUp): # An early add of an active scope could mean it's expected on a message that fired before it opened, - # resulting in a false positive for replacement. The timer task merges these in among the buckets - # by timestamp to avoid that. + # resulting in a false positive for replacement. The timer task merges these in among the buckets by + # timestamp to avoid that. self._pending.append(event) elif isinstance(event, _Message): if event.data["type"] in ["update", "strike", "delete"]: tag = event.dedup_tag() + self._logger.debug("recv {} from {}".format(_tag_digest(tag), obj_digest(event.scope))) if tag in self._buckets: b = self._buckets[tag] b.recipients.add(event.scope) - if b.recipients >= self._scope_activations.keys(): - del self._buckets[tag] + # If this scope is the last one for this bucket we could clear the bucket here, but since + # connections sometimes get repeat messages and second copies can arrive before the first + # copy has arrived on all connections, leaving the bucket open to absorb repeats can reduce + # the likelihood of a second bucket being allocated late for the second copy of a message + # and causing unnecessary connection replacement. 
elif event.scope not in self._outgoing_scopes: sane = ( event.scope in self._scope_activations or any(e.scope == event.scope for e in self._pending) ) if sane: + self._logger.debug("new bucket {}: {}".format(_tag_digest(tag), tag)) self._buckets[tag] = self._Bucket(event.timestamp, {event.scope}) + self._log_current_buckets() await self._message_tx.send(event) else: raise RuntimeError("recieved message from unrecognized WS connection") + else: + self._logger.debug("recv type {!r} from {} (discarding)".format(event.data["type"], obj_digest(event.scope))) elif isinstance(event, _ConnectionDown): # We don't need to worry about canceling this scope at all, so no need to require it for parity for # any message, even older ones. The scope may be gone already, if we canceled it previously. @@ -212,10 +240,23 @@ class PoolMerger: if active + self._conn_warmup.total_seconds() < now } if now > bucket.start + self._parity_timeout.total_seconds(): - for scope in target_scopes - bucket.recipients: + filled = target_scopes & bucket.recipients + missing = target_scopes - bucket.recipients + extra = bucket.recipients - target_scopes + + self._logger.debug("expiring bucket {}".format(_tag_digest(tag))) + if filled: + self._logger.debug(" filled {}: {}".format(len(filled), _format_scope_list(filled))) + if missing: + self._logger.debug(" missing {}: {}".format(len(missing), _format_scope_list(missing))) + if extra: + self._logger.debug(" extra {}: {}".format(len(extra), _format_scope_list(extra))) + + for scope in missing: self._outgoing_scopes.add(scope) scope.cancel() del self._buckets[tag] + self._log_current_buckets() else: await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds()) else: diff --git a/src/strikebot/queue.py b/src/strikebot/queue.py index b4ebbe3..adfbbe6 100644 --- a/src/strikebot/queue.py +++ b/src/strikebot/queue.py @@ -27,6 +27,9 @@ class Queue: await self._empty_wait.park() return self._deque.popleft() + def __len__(self): + return len(self._deque) + @total_ordering @dataclass diff --git a/src/strikebot/reddit_api.py b/src/strikebot/reddit_api.py index ab59a95..3e69e6c 100644 --- a/src/strikebot/reddit_api.py +++ b/src/strikebot/reddit_api.py @@ -12,6 +12,7 @@ import asks import trio from strikebot import __version__ as VERSION +from strikebot.common import obj_digest from strikebot.queue import MaxHeap, Queue @@ -190,6 +191,7 @@ class ApiClientPool: awaken_auths = self._waiting.keys() & tokens.keys() self._app_queue.extend(self._waiting.pop(auth_id) for auth_id in awaken_auths) + self._logger.debug("updated API tokens") async def token_updater_impl(self): last_update = None @@ -210,16 +212,18 @@ class ApiClientPool: async def make_request(self, request: _Request) -> Response: (resp_tx, resp_rx) = trio.open_memory_channel(0) - self._request_queue.push((request, resp_tx)) - self._check_queue_size() + self.enqueue_request(request, resp_tx) async with resp_rx: return await resp_rx.receive() - def enqueue_request(self, request: _Request) -> None: - self._request_queue.push((request, None)) + def enqueue_request(self, request: _Request, resp_tx = None) -> None: + self._logger.debug("request {}: {}".format(obj_digest(request), type(request).__name__)) + self._request_queue.push((request, resp_tx)) self._check_queue_size() + self._logger.debug(f"{len(self._request_queue)} requests in queue") - async def worker_impl(self): + async def worker_impl(self, task_status): + task_status.started() while True: (request, resp_tx) = await self._request_queue.pop() cooldown = await 
self._app_queue.pop() @@ -247,23 +251,25 @@ class ApiClientPool: resp.body # read response error = False wait_for_token = False + log_suffix = " (request {})".format(obj_digest(request)) if resp.status_code == 429: # We disagreed about the rate limit state; just try again later. - self._logger.warning("rate limited by Reddit API") + self._logger.warning("rate limited by Reddit API" + log_suffix) error = True elif resp.status_code == 401: - self._logger.warning("got HTTP 401 from Reddit API") + self._logger.warning("got HTTP 401 from Reddit API" + log_suffix) error = True wait_for_token = True - elif resp.status_code in [404, 500]: - self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying") + elif resp.status_code in [404, 500, 503]: + self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying" + log_suffix) error = True elif 400 <= resp.status_code < 500: # If we're doing something wrong, let's catch it right away. - raise RuntimeError(f"unexpected client error response: {resp.status_code}") + raise RuntimeError(f"unexpected client error response: {resp.status_code}" + log_suffix) else: if resp.status_code != 200: raise RuntimeError(f"unexpected status code {resp.status_code}") + self._logger.debug("success" + log_suffix) if resp_tx: await resp_tx.send(resp) @@ -279,5 +285,7 @@ class ApiClientPool: cooldown.ready_at = request_time + REQUEST_DELAY.total_seconds() if wait_for_token: self._waiting[cooldown.auth_id] = cooldown + if not self._app_queue: + self._logger.error("all workers waiting for API tokens") else: self._app_queue.push(cooldown) diff --git a/src/strikebot/tests.py b/src/strikebot/tests.py index 55b3426..7bbf226 100644 --- a/src/strikebot/tests.py +++ b/src/strikebot/tests.py @@ -17,6 +17,14 @@ class UpdateParsingTests(TestCase): self.assertEqual(pu.number, 0) self.assertFalse(pu.deletable) + pu = parse_update(_build_payload("
121 345 621
"), 121, "") + self.assertEqual(pu.number, 121) + self.assertFalse(pu.deletable) + + pu = parse_update(_build_payload("28 336 816"), 28336816, "") + self.assertEqual(pu.number, 28336816) + self.assertTrue(pu.deletable) + def test_non_counts(self): pu = parse_update(_build_payload("
zoo
"), None, "") self.assertFalse(pu.count_attempt) diff --git a/src/strikebot/updates.py b/src/strikebot/updates.py index f61d4e2..e3f58cd 100644 --- a/src/strikebot/updates.py +++ b/src/strikebot/updates.py @@ -1,11 +1,11 @@ from __future__ import annotations -from collections import deque from dataclasses import dataclass from enum import Enum from typing import Optional -from xml.etree import ElementTree import re +from bs4 import BeautifulSoup + Command = Enum("Command", ["RESET", "REPORT"]) @@ -34,84 +34,63 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) - SPACE = object() # flatten the update content to plain text - doc = ElementTree.fromstring(payload_data["body_html"]) - worklist = deque([doc]) + tree = BeautifulSoup(payload_data["body_html"], "html.parser") + worklist = tree.contents out = [[]] while worklist: - el = worklist.popleft() - if el is NEW_LINE: - if out[-1]: - out.append([]) + el = worklist.pop() + if isinstance(el, str): + out[-1].append(el) elif el is SPACE: out[-1].append(el) - elif isinstance(el, str): - out[-1].append(el.replace("\n", " ")) - elif el.tag in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]: - if el.tail: - worklist.appendleft(el.tail) - for sub in reversed(el): - worklist.appendleft(sub) - if el.text: - worklist.appendleft(el.text) - elif el.tag == "li": - worklist.appendleft(NEW_LINE) - worklist.appendleft(el.text) - elif el.tag in ["p", "div", "h1", "h2", "blockquote"]: - if el.tail: - worklist.appendleft(el.tail) - worklist.appendleft(NEW_LINE) - for sub in reversed(el): - worklist.appendleft(sub) - if el.text: - worklist.appendleft(el.text) - elif el.tag in ["ul", "ol"]: - if el.tail: - worklist.appendleft(el.tail) - for sub in reversed(el): - worklist.appendleft(sub) - worklist.appendleft(NEW_LINE) - elif el.tag == "pre": - if el.text: - out.extend([l] for l in el.text.splitlines()) - worklist.appendleft(NEW_LINE) - if el.tail: - worklist.appendleft(el.tail) - for sub in reversed(el): - worklist.appendleft(sub) - elif el.tag == "table": - if el.tail: - worklist.appendleft(el.tail) - worklist.appendleft(NEW_LINE) - for sub in reversed(el): - assert sub.tag in ["thead", "tbody"] - for row in reversed(sub): - assert row.tag == "tr" - for (i, cell) in enumerate(reversed(row)): - worklist.appendleft(cell) - if i != len(row) - 1: - worklist.appendleft(SPACE) - worklist.appendleft(NEW_LINE) - elif el.tag == "br": - if el.tail: - worklist.appendleft(el.tail) - worklist.appendleft(NEW_LINE) + elif el is NEW_LINE or el.name == "br": + if out[-1]: + out.append([]) + elif el.name in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]: + worklist.extend(reversed(el.contents)) + elif el.name in ["ul", "ol", "table", "thead", "tbody"]: + worklist.extend(reversed(el.contents)) + elif el.name in ["li", "p", "div", "h1", "h2", "blockquote"]: + worklist.append(NEW_LINE) + worklist.extend(reversed(el.contents)) + worklist.append(NEW_LINE) + elif el.name == "pre": + worklist.append(NEW_LINE) + worklist.extend([l] for l in reversed(el.text.splitlines())) + worklist.append(NEW_LINE) + elif el.name == "tr": + worklist.append(NEW_LINE) + for (i, cell) in enumerate(reversed(el.contents)): + worklist.append(cell) + if i != len(el.contents) - 1: + worklist.append(SPACE) + worklist.append(NEW_LINE) else: raise RuntimeError(f"can't parse tag {el.tag}") - lines = list(filter( - None, - ( - "".join(" " if part is SPACE else part for part in parts).strip() - for parts in out - ) - )) + tmp_lines = ( + "".join(" " if 
part is SPACE else part for part in parts).strip() + for parts in out + ) + pre_strip_lines = list(filter(None, tmp_lines)) + + # normalize whitespace according to HTML rendering rules + # https://developer.mozilla.org/en-US/docs/Web/API/Document_Object_Model/Whitespace#explanation + stripped_lines = [ + re.sub(" +", " ", l.replace("\t", " ").replace("\n", " ")).strip(" ") + for l in pre_strip_lines + ] + + return _parse_from_lines(stripped_lines, curr_count, bot_user) + +def _parse_from_lines(lines: list[str], curr_count: Optional[int], bot_user: str) -> ParsedUpdate: command = next( filter(None, (_parse_command(l, bot_user) for l in lines)), None ) - if lines: + # look for groups of digits (as many as possible) separated by a uniform separator from the valid set first = lines[0] match = re.match( "(?Pv)?(?P-)?(?P\\d+((?P[,. \u2009]|, )\\d+((?P=sep)\\d+)*)?)", @@ -119,48 +98,75 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) - re.ASCII, # only recognize ASCII digits ) if match: - ct_str = match["num"] + raw_digits = match["num"] sep = match["sep"] post = first[match.end() :] zeros = False - while len(ct_str) > 1 and ct_str[0] == "0": + while len(raw_digits) > 1 and raw_digits[0] == "0": zeros = True - ct_str = ct_str.removeprefix("0").removeprefix(sep or "") - - parts = ct_str.split(sep) if sep else [ct_str] - parts_valid = ( - sep is None - or ( - len(parts[0]) in range(1, 4) - and all(len(p) == 3 for p in parts[1:]) - ) - ) - digits = "".join(parts) + raw_digits = raw_digits.removeprefix("0").removeprefix(sep or "") + + parts = raw_digits.split(sep) if sep else [raw_digits] lone = len(lines) == 1 and (not post or post.isspace()) typo = False if lone: - if match["v"] and len(ct_str) <= 2: + all_parts_valid = ( + sep is None + or ( + 1 <= len(parts[0]) <= 3 + and all(len(p) == 3 for p in parts[1:]) + ) + ) + if match["v"] and len(parts) == 1 and len(parts[0]) <= 2: # failed paste of leading digits typo = True - elif match["v"] and parts_valid: + elif match["v"] and all_parts_valid: # v followed by count typo = True - elif curr_count and curr_count >= 100 and bool(match["neg"]) == (curr_count < 0): - goal = (sep or "").join(_separate(str(curr_count))) - partials = [goal[: -2], goal[: -1], goal[: -2] + goal[-1]] - if ct_str in partials: + elif curr_count is not None and abs(curr_count) >= 100 and bool(match["neg"]) == (curr_count < 0): + goal_parts = _separate(str(abs(curr_count))) + partials = [ + goal_parts[: -1] + [goal_parts[-1][: -1]], # missing last digit + goal_parts[: -1] + [goal_parts[-1][: -2]], # missing last two digits + goal_parts[: -1] + [goal_parts[-1][: -2] + goal_parts[-1][-1]], # missing second-last digit + ] + if parts in partials: # missing any of last two digits typo = True - elif ct_str in [p + goal for p in partials]: + elif parts in [p[: -1] + [p[-1] + goal_parts[0]] + goal_parts[1 :] for p in partials]: # double paste typo = True - if match["v"] or zeros or typo or (digits == "0" and match["neg"]): + if match["v"] or zeros or typo or (parts == ["0"] and match["neg"]): number = None count_attempt = True deletable = lone - elif parts_valid: + else: + if curr_count is not None and sep and sep.isspace(): + # Presume that the intended count consists of as many valid digit groups as necessary to match the + # number of digits in the expected count, if possible. 
+ digit_count = len(str(abs(curr_count))) + use_parts = [] + accum = 0 + for (i, part) in enumerate(parts): + part_valid = len(part) <= 3 if i == 0 else len(part) == 3 + if part_valid and accum < digit_count: + use_parts.append(part) + accum += len(part) + else: + break + + # could still be a no-separator count with some extra digit groups on the same line + if not use_parts: + use_parts = [parts[0]] + + lone = lone and len(use_parts) == len(parts) + else: + # current count is unknown or no separator was used + use_parts = parts + + digits = "".join(use_parts) number = -int(digits) if match["neg"] else int(digits) special = ( curr_count is not None @@ -168,15 +174,11 @@ def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) - and _is_special_number(number) ) deletable = lone and not special - if post and not post[0].isspace(): + if len(use_parts) == len(parts) and post and not post[0].isspace(): count_attempt = curr_count is not None and abs(number - curr_count) <= 25 number = None else: count_attempt = True - else: - number = None - count_attempt = False - deletable = lone else: # no count attempt found number = None -- 2.30.2
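
For reference, the digit-group selection introduced in the updates.py hunk can be read as a standalone heuristic: when the separator is whitespace and the expected count is known, keep leading digit groups until their combined length reaches the expected count's digit length, otherwise fall back to the first group. A minimal sketch under those assumptions follows; the helper name pick_count_parts and the bare list-of-strings signature are illustrative only, not the patch's API.

from typing import Optional

def pick_count_parts(parts: list[str], curr_count: Optional[int]) -> list[str]:
    # Sketch of the patch's heuristic: with a whitespace separator and a known
    # expected count, keep leading digit groups until their combined length
    # reaches the expected count's digit length.
    if curr_count is None:
        return parts
    digit_count = len(str(abs(curr_count)))
    use_parts = []
    accum = 0
    for (i, part) in enumerate(parts):
        # the first group may be 1-3 digits; later groups must be exactly 3
        part_valid = len(part) <= 3 if i == 0 else len(part) == 3
        if part_valid and accum < digit_count:
            use_parts.append(part)
            accum += len(part)
        else:
            break
    # could still be a no-separator count with extra digit groups on the line
    return use_parts or [parts[0]]

# with an expected count of 121, "121 345 621" reads as a count of 121 followed
# by extra digit groups; with an expected count of 121345621 all groups are used
assert pick_count_parts(["121", "345", "621"], 121) == ["121"]
assert pick_count_parts(["121", "345", "621"], 121345621) == ["121", "345", "621"]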
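
The whitespace normalization the patch applies before _parse_from_lines follows the HTML rendering rules its comment links to; in isolation it amounts to the snippet below (the function name normalize_ws is illustrative).

import re

def normalize_ws(line: str) -> str:
    # tabs and newlines become spaces, runs of spaces collapse to one,
    # and leading/trailing spaces are stripped, matching the patch
    return re.sub(" +", " ", line.replace("\t", " ").replace("\n", " ")).strip(" ")

assert normalize_ws("  12\t345 \n") == "12 345"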
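
The SIGUSR1 handler added to __main__.py prints a stack summary for each tracked task kind, which helps when the bot appears wedged. Assuming the bot runs as a single local process whose pid is known, poking it could look like this; the helper name and pid lookup are examples, not part of the patch.

import os
import signal

def dump_strikebot_tasks(pid: int) -> None:
    # ask a running strikebot process to print its per-kind task stacks
    # via the SIGUSR1 handler installed in __main__.py
    os.kill(pid, signal.SIGUSR1)

# e.g. dump_strikebot_tasks(12345); the summary is written to the bot's stdout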