From: Jakob Cornell Date: Tue, 6 Sep 2022 03:53:11 +0000 (-0500) Subject: Restructure for Debian packaging (WIP) X-Git-Tag: strikebot-0.0.7~27 X-Git-Url: https://jcornell.net/gitweb/gitweb.cgi?a=commitdiff_plain;h=b06d469eca7e86ae25381d7a0b1509250ffe8e60;p=counting.git Restructure for Debian packaging (WIP) --- diff --git a/.gitignore b/.gitignore index fbf1bc5..9de7c6a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.egg-info __pycache__ .mypy_cache +.pybuild/ diff --git a/build_helper.py b/build_helper.py new file mode 100644 index 0000000..c669c83 --- /dev/null +++ b/build_helper.py @@ -0,0 +1,62 @@ +""" +A few tools to help keep Debian packages and artifacts well organized. +""" + +from argparse import ArgumentParser +from pathlib import Path +from subprocess import run +import re + + +def _is_output(path: Path) -> bool: + return any( + path.name.endswith(suffix) + for suffix in [ + ".build", + ".buildinfo", + ".changes", + ".deb", + ".debian.tar.xz", + ".dsc", + ] + ) + + +project_root = Path(__file__).parent +upstream_root = project_root.joinpath("strikebot") +build_dir = project_root.joinpath("build") + +parser = ArgumentParser() +tmp = parser.add_subparsers(dest = "cmd", required = True) + +tmp.add_parser("build") +tmp.add_parser("clean") + +args = parser.parse_args() +if args.cmd == "build": + # extract version from Python package config + patt = re.compile("version *= *(.+?) *$") + with upstream_root.joinpath("setup.cfg").open() as f: + [m] = filter(None, map(patt.match, f)) + version = m[1] + + # delete stale "orig" tarball + for p in project_root.glob("strikebot_*.orig.tar.xz"): + p.unlink() + + # regenerate the "orig" tarball from the current source + run(["dh_make", "-y", "--indep", "--createorig", "-p", "strikebot_" + version], cwd = upstream_root) + + # build the package + run(["debuild"], cwd = upstream_root, check = True) + + # move outputs to build directory + for p in project_root.iterdir(): + if _is_output(p): + p.rename(build_dir.joinpath(p.relative_to(project_root))) +elif args.cmd == "clean": + for p in build_dir.iterdir(): + if _is_output(p): + p.unlink() +else: + raise AssertionError() diff --git a/docs/sample_config.ini b/docs/sample_config.ini deleted file mode 100644 index 28e510e..0000000 --- a/docs/sample_config.ini +++ /dev/null @@ -1,53 +0,0 @@ -# see Python `configparser' docs for precise file syntax - -[config] - -######## basic operating parameters - -# space-separated authorization IDs for Reddit API token lookup -auth IDs = 13 15 - -bot user = count_better -enforcing = true -thread ID = abc123 - - -######## API pool error handling - -# (seconds) for error backoff, pause the API pool for this much time -API pool error delay = 5.5 - -# (seconds) pause the pool when the Reddit API returns two errors within any time span of this length -API pool error window = 5.5 - -# terminate if the Reddit API request queue exceeds this size -request queue limit = 100 - - -######## live thread update handling - -# (seconds) maximum time to hold updates for reordering -reorder buffer time = 0.25 - -# (seconds) minimum time to retain updates to enable future resyncs -update retention time = 120.0 - - -######## WebSocket pool - -# (seconds) maximum allowable spread in WebSocket message arrival times -WS parity time = 1.0 - -# goal size of WebSocket connection pool -WS pool size = 5 - -# (seconds) after receiving no messages for this much time, attempt to replace the connection if the pool isn't depleted -WS silent limit = 30 - -# (seconds) time after WebSocket 
handshake completion during which missed updates are excused -WS warmup time = 0.3 - - -# Postgres database configuration; same options as in a connect string -[db connect params] -host = example.org diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 8fe2f47..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,3 +0,0 @@ -[build-system] -requires = ["setuptools>=42", "wheel"] -build-backend = "setuptools.build_meta" diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index eb1fd25..0000000 --- a/setup.cfg +++ /dev/null @@ -1,15 +0,0 @@ -[metadata] -name = strikebot -version = 0.0.0 - -[options] -package_dir = - = src -packages = strikebot -python_requires = ~= 3.9 -install_requires = - asks ~= 2.4 - beautifulsoup4 ~= 4.11 - trio == 0.19 - trio-websocket == 0.9.2 - triopg == 0.6.0 diff --git a/setup.py b/setup.py deleted file mode 100644 index 056ba45..0000000 --- a/setup.py +++ /dev/null @@ -1,4 +0,0 @@ -import setuptools - - -setuptools.setup() diff --git a/src/strikebot/__init__.py b/src/strikebot/__init__.py deleted file mode 100644 index 7c37d14..0000000 --- a/src/strikebot/__init__.py +++ /dev/null @@ -1,321 +0,0 @@ -from contextlib import nullcontext as nullcontext -from dataclasses import dataclass -from functools import total_ordering -from itertools import islice -from typing import Optional -from uuid import UUID -import bisect -import datetime as dt -import importlib.metadata -import itertools -import logging - -from trio import CancelScope, EndOfChannel -import trio - -from strikebot.updates import Command, parse_update - - -__version__ = importlib.metadata.version(__package__) - - -_READ_ONLY: bool = False # suppress any API requests that modify the thread? - - -@dataclass -class _Update: - id: str - name: str - author: str - number: Optional[int] - command: Optional[Command] - ts: dt.datetime - count_attempt: bool - deletable: bool - stricken: bool - - def can_follow(self, prior: "_Update") -> bool: - """Determine whether this count update can follow another count update.""" - return self.number == prior.number + 1 and self.author != prior.author - - def __str__(self): - content = str(self.number) - if self.command: - content += "+" + self.command.name - return "{}({} by {})".format(self.id[4 : 8], content, self.author) - - -@total_ordering -@dataclass -class _BufferedUpdate: - update: _Update - release_at: float # Trio time - - def __lt__(self, other): - return self.update.ts < other.update.ts - - -@total_ordering -@dataclass -class _TimelineUpdate: - update: _Update - accepted: bool - - def __lt__(self, other): - return self.update.ts < other.update.ts - - def rejected(self) -> bool: - """Whether the update should be stricken.""" - return self.update.count_attempt and not self.accepted - - -def _parse_rfc_4122_ts(ts: int) -> dt.datetime: - epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc) - return epoch + dt.timedelta(microseconds = ts / 10) - - -def _format_update_ref(update: _Update, thread_id: str) -> str: - url = f"https://www.reddit.com/live/{thread_id}/updates/{update.id}" - author_md = update.author.replace("_", "\\_") # Reddit allows letters, digits, underscore, and hyphen - return f"[{update.number:,} by {author_md}]({url})" - - -def _format_bad_strike_alert(update: _Update, thread_id: str) -> str: - return _format_update_ref(update, thread_id) + " was incorrectly stricken." 
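For reference, a minimal standalone sketch of the timestamp conversion performed by _parse_rfc_4122_ts above (illustrative code, not part of this commit). A version-1 UUID's time field counts 100-nanosecond intervals since the RFC 4122 epoch of 1582-10-15 00:00:00 UTC, so dividing by ten yields microseconds from that epoch; the tracker relies on this when it reads UUID(payload_data["id"]).time for live-thread updates.

import datetime as dt
import uuid

def parse_rfc_4122_ts(ticks: int) -> dt.datetime:
    # ticks: count of 100 ns intervals since the Gregorian reform epoch,
    # which is exactly what uuid.UUID.time returns for a version-1 UUID
    epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc)
    return epoch + dt.timedelta(microseconds = ticks / 10)

u = uuid.uuid1()                    # any version-1 UUID
print(parse_rfc_4122_ts(u.time))    # prints approximately the current UTC time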
- - -def _format_curr_count(last_valid: Optional[_Update]) -> str: - if last_valid and last_valid.number is not None: - count_str = f"{last_valid.number + 1:,}" - else: - count_str = "unknown" - return f"_current count:_ {count_str}" - - -async def count_tracker_impl( - message_rx, - api_pool: "strikebot.reddit_api.ApiClientPool", - reorder_buffer_time: dt.timedelta, # duration to hold some updates for reordering - thread_id: str, - bot_user: str, - enforcing: bool, - update_retention: dt.timedelta, - logger: logging.Logger, -) -> None: - from strikebot.reddit_api import CorrectionUpdateRequest, DeleteRequest, ReportUpdateRequest, StrikeRequest - - enforcing = enforcing and not _READ_ONLY - - buffer_ = [] - timeline = [] - last_valid: Optional[_Update] = None - pending_strikes: set[str] = set() # names of updates to mark stricken on arrival - forgotten: bool = False # whether the retention period for an update has passed during the run - - def handle_update(update): - nonlocal last_valid - - tu = _TimelineUpdate(update, accepted = False) - - pos = bisect.bisect(timeline, tu) - if pos != len(timeline): - logger.warning(f"long transpo: {update.id}") - - if pos > 0 and timeline[pos - 1].update.id == update.id: - # The pool sent this update message multiple times. This could be because the message reached parity and - # its bucket was dropped just before a new connection was opened and sent its own copy, but connections also - # seem to send duplicate messages. - return - - pred = next( - (timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted), - None - ) - contender = update.command is Command.RESET or update.number is not None - if contender and pred is None and last_valid and forgotten: - # This is a really long transpo which we can't judge valid or invalid since we've forgotten the surrounding - # updates. The best we can do is just ignore it. 
- logger.warning(f"ignoring {update.id}: no valid prior count on record") - else: - timeline.insert(pos, tu) - tu.accepted = ( - update.command is Command.RESET - or ( - update.number is not None - and (pred is None or pred.number is None or update.can_follow(pred)) - ) - ) - logger.debug("{} {} -> {}".format("accepted" if tu.accepted else "rejected", pred, update)) - if tu.accepted: - # resync subsequent updates - newly_valid = [] - newly_invalid = [] - resync_last_valid = update - converged = False - for scan_tu in islice(timeline, pos + 1, None): - if scan_tu.update.command is Command.RESET: - resync_last_valid = scan_tu.update - if last_valid: - converged = True - elif scan_tu.update.number is not None: - accept = resync_last_valid.number is None or scan_tu.update.can_follow(resync_last_valid) - if accept and scan_tu.accepted: - # resync would have no effect past this point - if last_valid: - converged = True - elif accept: - newly_valid.append(scan_tu) - resync_last_valid = scan_tu.update - elif scan_tu.accepted: - newly_invalid.append(scan_tu) - if scan_tu.update is last_valid: - last_valid = None - scan_tu.accepted = accept - if converged: - break - - if converged: - assert last_valid.ts >= resync_last_valid.ts - else: - last_valid = resync_last_valid - - parts = [] - if newly_valid: - parts.append( - "The following counts are valid:\n\n" - + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in newly_valid) - ) - - unstrikable = [tu for tu in newly_invalid if tu.update.stricken] - if unstrikable: - parts.append( - "The following counts are invalid:\n\n" - + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in unstrikable) - ) - - if update.stricken: - logger.info(f"bad strike of {update.id}") - parts.append(_format_bad_strike_alert(update, thread_id)) - - if parts: - parts.append(_format_curr_count(last_valid)) - if not _READ_ONLY: - api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts))) - - for invalid_tu in newly_invalid: - if not invalid_tu.update.stricken: - if enforcing: - api_pool.enqueue_request(StrikeRequest(thread_id, invalid_tu.update.name, invalid_tu.update.ts)) - invalid_tu.update.stricken = True - elif update.count_attempt: - if enforcing: - api_pool.enqueue_request(StrikeRequest(thread_id, update.name, update.ts)) - update.stricken = True - - if not _READ_ONLY and update.command is Command.REPORT: - api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid))) - - with message_rx: - while True: - if buffer_: - deadline = min(bu.release_at for bu in buffer_) - cancel_scope = CancelScope(deadline = deadline) - else: - cancel_scope = nullcontext() - - msg = None - with cancel_scope: - try: - msg = await message_rx.receive() - except EndOfChannel: - break - - if msg: - if msg.data["type"] == "update": - payload_data = msg.data["payload"]["data"] - stricken = payload_data["name"] in pending_strikes - pending_strikes.discard(payload_data["name"]) - if payload_data["author"] != bot_user: - if last_valid and last_valid.number is not None: - next_up = last_valid.number + 1 - else: - next_up = None - pu = parse_update(payload_data, next_up, bot_user) - rfc_ts = UUID(payload_data["id"]).time - update = _Update( - id = payload_data["id"], - name = payload_data["name"], - author = payload_data["author"], - number = pu.number, - command = pu.command, - ts = _parse_rfc_4122_ts(rfc_ts), - count_attempt = pu.count_attempt, - deletable = pu.deletable, - stricken = stricken or 
payload_data["stricken"], - ) - release_at = msg.timestamp + reorder_buffer_time.total_seconds() - bisect.insort(buffer_, _BufferedUpdate(update, release_at)) - elif msg.data["type"] == "strike": - slot = next( - ( - slot for slot in itertools.chain(buffer_, reversed(timeline)) - if slot.update.name == msg.data["payload"] - ), - None - ) - if slot: - if not slot.update.stricken: - slot.update.stricken = True - if isinstance(slot, _TimelineUpdate) and slot.accepted: - logger.info(f"bad strike of {slot.update.id}") - if not _READ_ONLY: - body = _format_bad_strike_alert(slot.update, thread_id) - api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body)) - else: - pending_strikes.add(msg.data["payload"]) - - threshold = dt.datetime.now(dt.timezone.utc) - update_retention - - # pull updates from the reorder buffer and process - new_buffer = [] - for bu in buffer_: - process = ( - # not a count - bu.update.number is None - - # valid count - or (last_valid is None or last_valid.number is None or bu.update.can_follow(last_valid)) - or bu.update.command is Command.RESET - - or trio.current_time() >= bu.release_at - or bu.update.command is Command.RESET - ) - - if process: - if bu.update.ts < threshold: - logger.warning(f"ignoring {bu.update}: arrived past retention window") - else: - logger.debug(f"processing {bu.update}") - handle_update(bu.update) - else: - logger.debug(f"holding {bu.update}, checked against {last_valid})") - new_buffer.append(bu) - buffer_ = new_buffer - logger.debug("last count {}".format(last_valid)) - - # delete/forget old updates - new_timeline = [] - i = 0 - for (i, tu) in enumerate(timeline): - if i >= len(timeline) - 10 or tu.update.ts >= threshold: - break - elif tu.rejected() and tu.update.deletable: - if enforcing: - api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts)) - elif tu.update is last_valid: - new_timeline.append(tu) - - if i: - forgotten = True - new_timeline.extend(islice(timeline, i, None)) - timeline = new_timeline diff --git a/src/strikebot/__main__.py b/src/strikebot/__main__.py deleted file mode 100644 index 63da2e1..0000000 --- a/src/strikebot/__main__.py +++ /dev/null @@ -1,190 +0,0 @@ -from enum import Enum -from inspect import getmodule -from logging import FileHandler, getLogger, StreamHandler -from signal import SIGUSR1 -from sys import stdout -from traceback import StackSummary -from typing import Optional -import argparse -import configparser -import datetime as dt -import logging - -from trio import open_memory_channel, open_nursery, open_signal_receiver -from trio.lowlevel import current_root_task, Task -import trio_asyncio -import triopg - -from strikebot import count_tracker_impl -from strikebot.db import Client, Messenger -from strikebot.live_ws import HealingReadPool, PoolMerger -from strikebot.reddit_api import ApiClientPool - - -_DEBUG_LOG_PATH: Optional[str] = None # path to file to write debug logs to, if any - - -_TaskKind = Enum( - "_TaskKind", - [ - "API_WORKER", - "DB_CLIENT", - "TOKEN_UPDATE", - "TRACKER", - "WS_MERGE_READER", - "WS_MERGE_TIMER", - "WS_REFRESHER", - "WS_WORKER", - ] -) - - -def _classify_task(task: Task) -> _TaskKind: - mod_name = getmodule(task.coro.cr_code).__name__ - if mod_name.startswith("trio_asyncio."): - return None - elif task.coro.__name__ == "_signal_handler_impl": - return None - else: - by_coro_name = { - "db_client_impl": _TaskKind.DB_CLIENT, - "token_updater_impl": _TaskKind.TOKEN_UPDATE, - "event_reader_impl": _TaskKind.WS_MERGE_READER, - "timeout_handler_impl": 
_TaskKind.WS_MERGE_TIMER, - "conn_refresher_impl": _TaskKind.WS_REFRESHER, - "count_tracker_impl": _TaskKind.TRACKER, - "_reader_impl": _TaskKind.WS_WORKER, - "worker_impl": _TaskKind.API_WORKER, - } - return by_coro_name[task.coro.__name__] - - -def _get_all_tasks(): - [ta_loop] = [ - task - for nursery in current_root_task().child_nurseries - for task in nursery.child_tasks - if getmodule(task.coro.cr_code).__name__.startswith("trio_asyncio.") - ] - for nursery in ta_loop.child_nurseries: - yield from nursery.child_tasks - - -def _print_task_info(): - groups = {kind: [] for kind in _TaskKind} - for task in _get_all_tasks(): - kind = _classify_task(task) - if kind: - coro = task.coro - frame_tuples = [] - while coro is not None: - if hasattr(coro, "cr_frame"): - frame_tuples.append((coro.cr_frame, coro.cr_frame.f_lineno)) - coro = coro.cr_await - else: - frame_tuples.append((coro.gi_frame, coro.gi_frame.f_lineno)) - coro = coro.gi_yieldfrom - groups[kind].append(StackSummary.extract(iter(frame_tuples))) - for kind in _TaskKind: - print(kind.name) - for ss in groups[kind]: - print(" task") - for line in ss.format(): - print(" " + line, end = "") - - -async def _signal_handler_impl(): - with open_signal_receiver(SIGUSR1) as sig_src: - async for _ in sig_src: - _print_task_info() - - -ap = argparse.ArgumentParser(__package__) -ap.add_argument("config_path") -args = ap.parse_args() - - -parser = configparser.ConfigParser() -with open(args.config_path) as config_file: - parser.read_file(config_file) - -main_cfg = parser["config"] - -api_pool_error_delay = dt.timedelta(seconds = main_cfg.getfloat("API pool error delay")) -api_pool_error_window = dt.timedelta(seconds = main_cfg.getfloat("API pool error window")) -auth_ids = set(map(int, main_cfg["auth IDs"].split())) -bot_user = main_cfg["bot user"] -enforcing = main_cfg.getboolean("enforcing") -reorder_buffer_time = dt.timedelta(seconds = main_cfg.getfloat("reorder buffer time")) -request_queue_limit = main_cfg.getint("request queue limit") -thread_id = main_cfg["thread ID"] -update_retention = dt.timedelta(seconds = main_cfg.getfloat("update retention time")) -ws_parity_time = dt.timedelta(seconds = main_cfg.getfloat("WS parity time")) -ws_pool_size = main_cfg.getint("WS pool size") -ws_silent_limit = dt.timedelta(seconds = main_cfg.getfloat("WS silent limit")) -ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup seconds")) - -db_cfg = parser["db connect params"] -getters = { - "port": db_cfg.getint, -} -db_connect_params = {k: getters.get(k, db_cfg.get)(k) for k in db_cfg} - - -logger = getLogger(__package__) -logger.setLevel(logging.DEBUG) - -handler = StreamHandler(stdout) -handler.setLevel(logging.WARNING) -handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{")) -logger.addHandler(handler) - -if _DEBUG_LOG_PATH: - debug_handler = FileHandler(_DEBUG_LOG_PATH, mode = "w") - debug_handler.setLevel(logging.DEBUG) - debug_handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{")) - logger.addHandler(debug_handler) - - -async def main(): - async with ( - triopg.connect(**db_connect_params) as db_conn, - open_nursery() as nursery_a, - open_nursery() as nursery_b, - open_nursery() as ws_pool_nursery, - open_nursery() as nursery_c, - ): - nursery_a.start_soon(_signal_handler_impl) - - client = Client(db_conn) - db_messenger = Messenger(client) - nursery_a.start_soon(db_messenger.db_client_impl) - - api_pool = ApiClientPool( - auth_ids, 
db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay, - logger.getChild("api") - ) - nursery_a.start_soon(api_pool.token_updater_impl) - for _ in auth_ids: - await nursery_a.start(api_pool.worker_impl) - - (message_tx, message_rx) = open_memory_channel(0) - (pool_event_tx, pool_event_rx) = open_memory_channel(0) - - merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger.getChild("merge")) - nursery_b.start_soon(merger.event_reader_impl) - nursery_b.start_soon(merger.timeout_handler_impl) - - ws_pool = HealingReadPool( - ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger.getChild("live_ws"), - ws_silent_limit - ) - nursery_c.start_soon(ws_pool.conn_refresher_impl) - - nursery_c.start_soon( - count_tracker_impl, message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing, - update_retention, logger.getChild("track") - ) - await ws_pool.init_workers() - -trio_asyncio.run(main) diff --git a/src/strikebot/common.py b/src/strikebot/common.py deleted file mode 100644 index 519037a..0000000 --- a/src/strikebot/common.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Any - - -def int_digest(val: int) -> str: - return "{:04x}".format(val & 0xffff) - - -def obj_digest(obj: Any) -> str: - """Makes a short digest of the identity of an object.""" - return int_digest(id(obj)) diff --git a/src/strikebot/db.py b/src/strikebot/db.py deleted file mode 100644 index c9a060a..0000000 --- a/src/strikebot/db.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Single-connection Postgres client with high-level wrappers for operations.""" - -from dataclasses import dataclass -from typing import Any, Iterable - -import trio - - -def _channel_sender(method): - async def wrapped(self, resp_channel, *args, **kwargs): - ret = await method(self, *args, **kwargs) - with resp_channel: - await resp_channel.send(ret) - - return wrapped - - -@dataclass -class Client: - """High-level wrappers for DB operations.""" - conn: Any - - @_channel_sender - async def get_auth_tokens(self, auth_ids: set[int]) -> dict[int, str]: - results = await self.conn.fetch( - """ - select - distinct on (id) - id, access_token - from - public.reddit_app_authorization - join unnest($1::integer[]) as request_id on id = request_id - where expires > current_timestamp - """, - list(auth_ids), - ) - return {r["id"]: r["access_token"] for r in results} - - -class Messenger: - def __init__(self, client: Client): - self._client = client - (self._request_tx, self._request_rx) = trio.open_memory_channel(0) - - async def do(self, method_name: str, args: Iterable) -> Any: - """This is run by consumers of the DB wrapper.""" - method = getattr(self._client, method_name) - (resp_tx, resp_rx) = trio.open_memory_channel(0) - coro = method(resp_tx, *args) - await self._request_tx.send(coro) - async with resp_rx: - return await resp_rx.receive() - - async def db_client_impl(self): - """This is the DB client Trio task.""" - async with self._client.conn, self._request_rx: - async for coro in self._request_rx: - await coro diff --git a/src/strikebot/live_ws.py b/src/strikebot/live_ws.py deleted file mode 100644 index cb8b3d0..0000000 --- a/src/strikebot/live_ws.py +++ /dev/null @@ -1,266 +0,0 @@ -from collections import deque -from contextlib import suppress -from dataclasses import dataclass -from typing import Any, AsyncContextManager, Optional -import datetime as dt -import json -import logging -import math - -from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, 
WebSocketConnection -import trio - -from strikebot.common import int_digest, obj_digest -from strikebot.reddit_api import AboutLiveThreadRequest - - -@dataclass -class _PoolEvent: - timestamp: float # Trio time - - -@dataclass -class _Message(_PoolEvent): - data: Any - scope: Any - - def dedup_tag(self): - """A hashable essence of the contents of this message.""" - if self.data["type"] == "update": - return ("update", self.data["payload"]["data"]["id"]) - elif self.data["type"] in {"strike", "delete"}: - return (self.data["type"], self.data["payload"]) - else: - raise ValueError(self.data["type"]) - - -@dataclass -class _ConnectionDown(_PoolEvent): - scope: trio.CancelScope - - -@dataclass -class _ConnectionUp(_PoolEvent): - scope: trio.CancelScope - - -class HealingReadPool: - def __init__( - self, pool_nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger, - silent_limit: dt.timedelta - ): - assert size >= 2 - self._nursery = pool_nursery - self._size = size - self._live_thread_id = live_thread_id - self._api_client_pool = api_client_pool - self._pool_event_tx = pool_event_tx - self._logger = logger - self._silent_limit = silent_limit - - (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size) - - # number of workers who have stopped receiving updates but whose tasks aren't yet stopped - self._closing = 0 - - async def init_workers(self): - for _ in range(self._size): - await self._spawn_reader() - if not self._nursery.child_tasks: - raise RuntimeError("Unable to create any WS connections") - - async def _reader_impl(self, conn_ctx: AsyncContextManager, cancel_scope, refresh_tx): - # TODO could use task's implicit cancel scope - try: - async with conn_ctx as conn: - self._logger.debug("scope up: {}".format(obj_digest(cancel_scope))) - with refresh_tx: - silent_timeout = False - with cancel_scope, suppress(ConnectionClosed): - while True: - with trio.move_on_after(self._silent_limit.total_seconds()) as timeout_scope: - message = await conn.get_message() - - if timeout_scope.cancelled_caught: - if len(self._nursery.child_tasks) - self._closing == self._size: - silent_timeout = True - break - else: - self._logger.debug("not replacing connection {}; {} tasks, {} closing".format( - obj_digest(cancel_scope), - len(self._nursery.child_tasks), - self._closing, - )) - else: - event = _Message(trio.current_time(), json.loads(message), cancel_scope) - await self._pool_event_tx.send(event) - - self._closing += 1 - if silent_timeout: - await conn.aclose() - self._logger.debug("replacing WS connection {} due to silent timeout".format(obj_digest(cancel_scope))) - elif cancel_scope.cancelled_caught: - await conn.aclose(1008, "Server unexpectedly stopped sending messages") - self._logger.warning("replacing WS connection {} due to missed update".format(obj_digest(cancel_scope))) - else: - self._logger.warning("replacing WS connection {} closed by server".format(obj_digest(cancel_scope))) - - refresh_tx.send_nowait(cancel_scope) - self._logger.debug("scope down: {} ({} tasks, {} closing)".format( - obj_digest(cancel_scope), - len(self._nursery.child_tasks), - self._closing, - )) - self._closing -= 1 - except HandshakeError: - self._logger.error("handshake error while opening WS connection") - - async def _new_connection(self) -> Optional[AsyncContextManager[WebSocketConnection]]: - request = AboutLiveThreadRequest(self._live_thread_id) - resp = await self._api_client_pool.make_request(request) - if resp.status_code == 200: - url = 
json.loads(resp.text)["data"]["websocket_url"] - return open_websocket_url(url) - else: - self._logger.error(f"thread info request failed (HTTP {resp.status_code}); can't open WS connection") - return None - - async def _spawn_reader(self) -> None: - conn = await self._new_connection() - if conn: - new_scope = trio.CancelScope() - await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope)) - refresh_tx = self._refresh_queue_tx.clone() - self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx) - - async def conn_refresher_impl(self): - """Task to monitor and replace WS connections as they disconnect or go silent.""" - with self._refresh_queue_rx: - async for old_scope in self._refresh_queue_rx: - await self._pool_event_tx.send(_ConnectionDown(trio.current_time(), old_scope)) - await self._spawn_reader() - if not self._nursery.child_tasks: - raise RuntimeError("WS pool depleted") - - -def _tag_digest(tag): - return int_digest(hash(tag)) - - -def _format_scope_list(scopes): - return ", ".join(sorted(map(obj_digest, scopes))) - - -class PoolMerger: - @dataclass - class _Bucket: - start: float - recipients: set - - def __init__(self, pool_event_rx, message_tx, parity_timeout: dt.timedelta, conn_warmup: dt.timedelta, logger: logging.Logger): - """ - :param parity_timeout: max delay between message receipt on different connections - :param conn_warmup: max interval after connection within which missed updates are allowed - """ - self._pool_event_rx = pool_event_rx - self._message_tx = message_tx - self._parity_timeout = parity_timeout - self._conn_warmup = conn_warmup - self._logger = logger - - self._buckets = {} - self._scope_activations = {} - self._pending = deque() - self._outgoing_scopes = set() - (self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf) - - def _log_current_buckets(self): - self._logger.debug("current buckets: {}".format(", ".join(map(_tag_digest, self._buckets.keys())))) - - async def event_reader_impl(self): - """Drop unused messages, deduplicate useful ones, and communicate with the timeout handler.""" - with self._pool_event_rx, self._timer_poke_tx: - async for event in self._pool_event_rx: - if isinstance(event, _ConnectionUp): - # An early add of an active scope could mean it's expected on a message that fired before it opened, - # resulting in a false positive for replacement. The timer task merges these in among the buckets by - # timestamp to avoid that. - self._pending.append(event) - elif isinstance(event, _Message): - if event.data["type"] in ["update", "strike", "delete"]: - tag = event.dedup_tag() - self._logger.debug("recv {} from {}".format(_tag_digest(tag), obj_digest(event.scope))) - if tag in self._buckets: - b = self._buckets[tag] - b.recipients.add(event.scope) - # If this scope is the last one for this bucket we could clear the bucket here, but since - # connections sometimes get repeat messages and second copies can arrive before the first - # copy has arrived on all connections, leaving the bucket open to absorb repeats can reduce - # the likelihood of a second bucket being allocated late for the second copy of a message - # and causing unnecessary connection replacement. 
- elif event.scope not in self._outgoing_scopes: - sane = ( - event.scope in self._scope_activations - or any(e.scope == event.scope for e in self._pending) - ) - if sane: - self._logger.debug("new bucket {}: {}".format(_tag_digest(tag), tag)) - self._buckets[tag] = self._Bucket(event.timestamp, {event.scope}) - self._log_current_buckets() - await self._message_tx.send(event) - else: - raise RuntimeError("recieved message from unrecognized WS connection") - else: - self._logger.debug("recv type {!r} from {} (discarding)".format(event.data["type"], obj_digest(event.scope))) - elif isinstance(event, _ConnectionDown): - # We don't need to worry about canceling this scope at all, so no need to require it for parity for - # any message, even older ones. The scope may be gone already, if we canceled it previously. - self._scope_activations.pop(event.scope, None) - self._pending = deque(ev for ev in self._pending if ev.scope is not event.scope) - self._outgoing_scopes.discard(event.scope) - else: - raise TypeError(f"Expected pool event, found {event!r}") - self._timer_poke_tx.send_nowait(None) # may have new work for the timer - - async def timeout_handler_impl(self): - """When connections have had enough time to reach parity on a message, signal replacement of any slackers.""" - with self._timer_poke_rx: - while True: - if self._buckets: - now = trio.current_time() - (tag, bucket) = min(self._buckets.items(), key = lambda i: i[1].start) - - # Make sure the right scope set is ready for the next bucket up - while self._pending and self._pending[0].timestamp < bucket.start: - ev = self._pending.popleft() - self._scope_activations[ev.scope] = ev.timestamp - - target_scopes = { - scope for (scope, active) in self._scope_activations.items() - if active + self._conn_warmup.total_seconds() < now - } - if now > bucket.start + self._parity_timeout.total_seconds(): - filled = target_scopes & bucket.recipients - missing = target_scopes - bucket.recipients - extra = bucket.recipients - target_scopes - - self._logger.debug("expiring bucket {}".format(_tag_digest(tag))) - if filled: - self._logger.debug(" filled {}: {}".format(len(filled), _format_scope_list(filled))) - if missing: - self._logger.debug(" missing {}: {}".format(len(missing), _format_scope_list(missing))) - if extra: - self._logger.debug(" extra {}: {}".format(len(extra), _format_scope_list(extra))) - - for scope in missing: - self._outgoing_scopes.add(scope) - scope.cancel() - del self._buckets[tag] - self._log_current_buckets() - else: - await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds()) - else: - try: - await self._timer_poke_rx.receive() - except trio.EndOfChannel: - break diff --git a/src/strikebot/queue.py b/src/strikebot/queue.py deleted file mode 100644 index adfbbe6..0000000 --- a/src/strikebot/queue.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Unbounded blocking queues for Trio""" - -from collections import deque -from dataclasses import dataclass -from functools import total_ordering -from typing import Any, Iterable -import heapq - -from trio.lowlevel import ParkingLot - - -class Queue: - def __init__(self): - self._deque = deque() - self._empty_wait = ParkingLot() - - def push(self, el: Any) -> None: - self._deque.append(el) - self._empty_wait.unpark() - - def extend(self, els: Iterable[Any]) -> None: - for el in els: - self.push(el) - - async def pop(self) -> Any: - if not self._deque: - await self._empty_wait.park() - return self._deque.popleft() - - def __len__(self): - return len(self._deque) - - 
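For reference, a minimal usage sketch of the ParkingLot-backed Queue above (illustrative, not part of this commit): push never blocks because the queue is unbounded, while pop parks the calling task until an item arrives. In this commit the ApiClientPool uses Queue for its per-app cooldown queue and MaxHeap (defined below) to prioritize pending API requests.

import trio

async def producer(queue):
    for i in range(3):
        queue.push(i)               # never blocks; the queue is unbounded
        await trio.sleep(0.1)

async def consumer(queue):
    for _ in range(3):
        item = await queue.pop()    # parks until an item is available
        print("popped", item)

async def demo():
    queue = Queue()                 # the Queue class defined above
    async with trio.open_nursery() as nursery:
        nursery.start_soon(producer, queue)
        nursery.start_soon(consumer, queue)

trio.run(demo)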
-@total_ordering -@dataclass -class _ReverseOrdWrapper: - inner: Any - - def __lt__(self, other): - return other.inner < self.inner - - -class MaxHeap: - def __init__(self): - self._heap = [] - self._empty_wait = ParkingLot() - - def push(self, item): - heapq.heappush(self._heap, _ReverseOrdWrapper(item)) - self._empty_wait.unpark() - - async def pop(self): - if not self._heap: - await self._empty_wait.park() - return heapq.heappop(self._heap).inner - - def __len__(self) -> int: - return len(self._heap) diff --git a/src/strikebot/reddit_api.py b/src/strikebot/reddit_api.py deleted file mode 100644 index 3e69e6c..0000000 --- a/src/strikebot/reddit_api.py +++ /dev/null @@ -1,291 +0,0 @@ -"""Multi-auth HTTP connection pool with Reddit API wrappers and rate limiting.""" - -from abc import ABCMeta, abstractmethod -from dataclasses import dataclass, field -from functools import total_ordering -from socket import EAI_AGAIN, EAI_FAIL, gaierror -import datetime as dt -import logging - -from asks.response_objects import Response -import asks -import trio - -from strikebot import __version__ as VERSION -from strikebot.common import obj_digest -from strikebot.queue import MaxHeap, Queue - - -REQUEST_DELAY = dt.timedelta(seconds = 1) - -TOKEN_UPDATE_DELAY = dt.timedelta(minutes = 15) - -USER_AGENT_FMT = f"any:net.jcornell.strikebot.{{auth_id}}:v{VERSION} (by /u/jaxklax)" - -API_BASE_URL = "https://oauth.reddit.com" - - -@total_ordering -class _Request(metaclass = ABCMeta): - _SUBTYPE_PRECEDENCE = None # assigned later - - def __lt__(self, other): - if type(self) is type(other): - return self._subtype_cmp_key() < other._subtype_cmp_key() - else: - prec = self._SUBTYPE_PRECEDENCE - return prec.index(type(other)) < prec.index(type(self)) - - def __eq__(self, other): - if type(self) is type(other): - return self._subtype_cmp_key() == other._subtype_cmp_key() - else: - return False - - @abstractmethod - def _subtype_cmp_key(self): - # a large key corresponds to a high priority - raise NotImplementedError() - - @abstractmethod - def to_asks_kwargs(self): - raise NotImplementedError() - - -@dataclass(eq = False) -class StrikeRequest(_Request): - thread_id: str - update_name: str - update_ts: dt.datetime - - def to_asks_kwargs(self): - return { - "method": "POST", - "path": f"/api/live/{self.thread_id}/strike_update", - "data": { - "api_type": "json", - "id": self.update_name, - }, - } - - def _subtype_cmp_key(self): - return self.update_ts - - -@dataclass(eq = False) -class DeleteRequest(_Request): - thread_id: str - update_name: str - update_ts: dt.datetime - - def to_asks_kwargs(self): - return { - "method": "POST", - "path": f"/api/live/{self.thread_id}/delete_update", - "data": { - "api_type": "json", - "id": self.update_name, - }, - } - - def _subtype_cmp_key(self): - return -self.update_ts - - -@dataclass(eq = False) -class AboutLiveThreadRequest(_Request): - thread_id: str - created: float = field(default_factory = trio.current_time) - - def to_asks_kwargs(self): - return { - "method": "GET", - "path": f"/live/{self.thread_id}/about", - "params": { - "raw_json": "1", - }, - } - - def _subtype_cmp_key(self): - return -self.created - - -@dataclass(eq = False) -class _BaseUpdatePostRequest(_Request, metaclass = ABCMeta): - thread_id: str - body: str - created: float = field(default_factory = trio.current_time) - - def to_asks_kwargs(self): - return { - "method": "POST", - "path": f"/live/{self.thread_id}/update", - "data": { - "api_type": "json", - "body": self.body, - }, - } - - -@dataclass(eq = False) 
-class ReportUpdateRequest(_BaseUpdatePostRequest): - def _subtype_cmp_key(self): - return -self.created - - -@dataclass(eq = False) -class CorrectionUpdateRequest(_BaseUpdatePostRequest): - def _subtype_cmp_key(self): - return -self.created - - -# highest priority first -_Request._SUBTYPE_PRECEDENCE = [ - AboutLiveThreadRequest, - StrikeRequest, - CorrectionUpdateRequest, - ReportUpdateRequest, - DeleteRequest, -] - - -@dataclass -class AppCooldown: - auth_id: int - ready_at: float # Trio time - - -class ApiClientPool: - def __init__( - self, - auth_ids: set[int], - db_messenger: "strikebot.db.Messenger", - request_queue_limit: int, - error_window: dt.timedelta, - error_delay: dt.timedelta, - logger: logging.Logger, - ): - self._auth_ids = auth_ids - self._db_messenger = db_messenger - self._request_queue_limit = request_queue_limit - self._error_window = error_window - self._error_delay = error_delay - self._logger = logger - - now = trio.current_time() - self._tokens = {} - self._app_queue = Queue() - self._waiting = {id_: AppCooldown(id_, now) for id_ in auth_ids} - self._request_queue = MaxHeap() - - # pool-wide API error backoff - self._last_error = None - self._global_resume = None - - self._session = asks.Session(connections = len(auth_ids)) - self._session.base_location = API_BASE_URL - - async def _update_tokens(self): - tokens = await self._db_messenger.do("get_auth_tokens", (self._auth_ids,)) - self._tokens.update(tokens) - - awaken_auths = self._waiting.keys() & tokens.keys() - self._app_queue.extend(self._waiting.pop(auth_id) for auth_id in awaken_auths) - self._logger.debug("updated API tokens") - - async def token_updater_impl(self): - last_update = None - while True: - if last_update is not None: - await trio.sleep_until(last_update + TOKEN_UPDATE_DELAY.total_seconds()) - - if last_update is None: - last_update = trio.current_time() - else: - last_update += TOKEN_UPDATE_DELAY.total_seconds() - - await self._update_tokens() - - def _check_queue_size(self) -> None: - if len(self._request_queue) > self._request_queue_limit: - raise RuntimeError("request queue size exceeded limit") - - async def make_request(self, request: _Request) -> Response: - (resp_tx, resp_rx) = trio.open_memory_channel(0) - self.enqueue_request(request, resp_tx) - async with resp_rx: - return await resp_rx.receive() - - def enqueue_request(self, request: _Request, resp_tx = None) -> None: - self._logger.debug("request {}: {}".format(obj_digest(request), type(request).__name__)) - self._request_queue.push((request, resp_tx)) - self._check_queue_size() - self._logger.debug(f"{len(self._request_queue)} requests in queue") - - async def worker_impl(self, task_status): - task_status.started() - while True: - (request, resp_tx) = await self._request_queue.pop() - cooldown = await self._app_queue.pop() - await trio.sleep_until(cooldown.ready_at) - if self._global_resume: - await trio.sleep_until(self._global_resume) - - asks_kwargs = request.to_asks_kwargs() - headers = asks_kwargs.setdefault("headers", {}) - headers.update({ - "Authorization": "Bearer {}".format(self._tokens[cooldown.auth_id]), - "User-Agent": USER_AGENT_FMT.format(auth_id = cooldown.auth_id), - }) - - request_time = trio.current_time() - try: - resp = await self._session.request(**asks_kwargs) - except gaierror as e: - if e.errno in [EAI_FAIL, EAI_AGAIN]: - # DNS failure, probably temporary - error = True - else: - raise - else: - resp.body # read response - error = False - wait_for_token = False - log_suffix = " (request 
{})".format(obj_digest(request)) - if resp.status_code == 429: - # We disagreed about the rate limit state; just try again later. - self._logger.warning("rate limited by Reddit API" + log_suffix) - error = True - elif resp.status_code == 401: - self._logger.warning("got HTTP 401 from Reddit API" + log_suffix) - error = True - wait_for_token = True - elif resp.status_code in [404, 500, 503]: - self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying" + log_suffix) - error = True - elif 400 <= resp.status_code < 500: - # If we're doing something wrong, let's catch it right away. - raise RuntimeError(f"unexpected client error response: {resp.status_code}" + log_suffix) - else: - if resp.status_code != 200: - raise RuntimeError(f"unexpected status code {resp.status_code}") - self._logger.debug("success" + log_suffix) - if resp_tx: - await resp_tx.send(resp) - - if error: - self._request_queue.push((request, resp_tx)) - self._check_queue_size() - if self._last_error: - spread = dt.timedelta(seconds = request_time - self._last_error) - if spread <= self._error_window: - self._global_resume = request_time + self._error_delay.total_seconds() - self._last_error = request_time - - cooldown.ready_at = request_time + REQUEST_DELAY.total_seconds() - if wait_for_token: - self._waiting[cooldown.auth_id] = cooldown - if not self._app_queue: - self._logger.error("all workers waiting for API tokens") - else: - self._app_queue.push(cooldown) diff --git a/src/strikebot/tests.py b/src/strikebot/tests.py deleted file mode 100644 index 7bbf226..0000000 --- a/src/strikebot/tests.py +++ /dev/null @@ -1,50 +0,0 @@ -from unittest import TestCase - -from strikebot.updates import parse_update - - -def _build_payload(body_html: str) -> dict[str, str]: - return {"body_html": body_html} - - -class UpdateParsingTests(TestCase): - def test_successful_counts(self): - pu = parse_update(_build_payload("
12,345,678 spaghetti
"), None, "") - self.assertEqual(pu.number, 12345678) - self.assertFalse(pu.deletable) - - pu = parse_update(_build_payload("

0


oz
"), None, "") - self.assertEqual(pu.number, 0) - self.assertFalse(pu.deletable) - - pu = parse_update(_build_payload("
121 345 621
"), 121, "") - self.assertEqual(pu.number, 121) - self.assertFalse(pu.deletable) - - pu = parse_update(_build_payload("28 336 816"), 28336816, "") - self.assertEqual(pu.number, 28336816) - self.assertTrue(pu.deletable) - - def test_non_counts(self): - pu = parse_update(_build_payload("
zoo
"), None, "") - self.assertFalse(pu.count_attempt) - self.assertFalse(pu.deletable) - - def test_typos(self): - pu = parse_update(_build_payload("v9"), 888, "") - self.assertIsNone(pu.number) - self.assertTrue(pu.count_attempt) - - pu = parse_update(_build_payload("
v11.585 Empire
"), None, "") - self.assertIsNone(pu.number) - self.assertTrue(pu.count_attempt) - self.assertFalse(pu.deletable) - - pu = parse_update(_build_payload("
11, 585, 22
"), 11_585_202, "") - self.assertIsNone(pu.number) - self.assertTrue(pu.count_attempt) - self.assertTrue(pu.deletable) - - pu = parse_update(_build_payload("0490499"), 4999, "") - self.assertIsNone(pu.number) - self.assertTrue(pu.count_attempt) diff --git a/src/strikebot/updates.py b/src/strikebot/updates.py deleted file mode 100644 index e3f58cd..0000000 --- a/src/strikebot/updates.py +++ /dev/null @@ -1,216 +0,0 @@ -from __future__ import annotations -from dataclasses import dataclass -from enum import Enum -from typing import Optional -import re - -from bs4 import BeautifulSoup - - -Command = Enum("Command", ["RESET", "REPORT"]) - - -@dataclass -class ParsedUpdate: - number: Optional[int] - command: Optional[Command] - count_attempt: bool # either well-formed or typo - deletable: bool - - -def _parse_command(line: str, bot_user: str) -> Optional[Command]: - if line.lower() == f"/u/{bot_user} reset".lower(): - return Command.RESET - elif line.lower() in ["sidebar count", "current count"]: - return Command.REPORT - else: - return None - - -def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -> ParsedUpdate: - # curr_count is the next number up, one more than the last count - - NEW_LINE = object() - SPACE = object() - - # flatten the update content to plain text - tree = BeautifulSoup(payload_data["body_html"], "html.parser") - worklist = tree.contents - out = [[]] - while worklist: - el = worklist.pop() - if isinstance(el, str): - out[-1].append(el) - elif el is SPACE: - out[-1].append(el) - elif el is NEW_LINE or el.name == "br": - if out[-1]: - out.append([]) - elif el.name in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]: - worklist.extend(reversed(el.contents)) - elif el.name in ["ul", "ol", "table", "thead", "tbody"]: - worklist.extend(reversed(el.contents)) - elif el.name in ["li", "p", "div", "h1", "h2", "blockquote"]: - worklist.append(NEW_LINE) - worklist.extend(reversed(el.contents)) - worklist.append(NEW_LINE) - elif el.name == "pre": - worklist.append(NEW_LINE) - worklist.extend([l] for l in reversed(el.text.splitlines())) - worklist.append(NEW_LINE) - elif el.name == "tr": - worklist.append(NEW_LINE) - for (i, cell) in enumerate(reversed(el.contents)): - worklist.append(cell) - if i != len(el.contents) - 1: - worklist.append(SPACE) - worklist.append(NEW_LINE) - else: - raise RuntimeError(f"can't parse tag {el.tag}") - - tmp_lines = ( - "".join(" " if part is SPACE else part for part in parts).strip() - for parts in out - ) - pre_strip_lines = list(filter(None, tmp_lines)) - - # normalize whitespace according to HTML rendering rules - # https://developer.mozilla.org/en-US/docs/Web/API/Document_Object_Model/Whitespace#explanation - stripped_lines = [ - re.sub(" +", " ", l.replace("\t", " ").replace("\n", " ")).strip(" ") - for l in pre_strip_lines - ] - - return _parse_from_lines(stripped_lines, curr_count, bot_user) - - -def _parse_from_lines(lines: list[str], curr_count: Optional[int], bot_user: str) -> ParsedUpdate: - command = next( - filter(None, (_parse_command(l, bot_user) for l in lines)), - None - ) - if lines: - # look for groups of digits (as many as possible) separated by a uniform separator from the valid set - first = lines[0] - match = re.match( - "(?Pv)?(?P-)?(?P\\d+((?P[,. 
\u2009]|, )\\d+((?P=sep)\\d+)*)?)", - first, - re.ASCII, # only recognize ASCII digits - ) - if match: - raw_digits = match["num"] - sep = match["sep"] - post = first[match.end() :] - - zeros = False - while len(raw_digits) > 1 and raw_digits[0] == "0": - zeros = True - raw_digits = raw_digits.removeprefix("0").removeprefix(sep or "") - - parts = raw_digits.split(sep) if sep else [raw_digits] - lone = len(lines) == 1 and (not post or post.isspace()) - typo = False - if lone: - all_parts_valid = ( - sep is None - or ( - 1 <= len(parts[0]) <= 3 - and all(len(p) == 3 for p in parts[1:]) - ) - ) - if match["v"] and len(parts) == 1 and len(parts[0]) <= 2: - # failed paste of leading digits - typo = True - elif match["v"] and all_parts_valid: - # v followed by count - typo = True - elif curr_count is not None and abs(curr_count) >= 100 and bool(match["neg"]) == (curr_count < 0): - goal_parts = _separate(str(abs(curr_count))) - partials = [ - goal_parts[: -1] + [goal_parts[-1][: -1]], # missing last digit - goal_parts[: -1] + [goal_parts[-1][: -2]], # missing last two digits - goal_parts[: -1] + [goal_parts[-1][: -2] + goal_parts[-1][-1]], # missing second-last digit - ] - if parts in partials: - # missing any of last two digits - typo = True - elif parts in [p[: -1] + [p[-1] + goal_parts[0]] + goal_parts[1 :] for p in partials]: - # double paste - typo = True - - if match["v"] or zeros or typo or (parts == ["0"] and match["neg"]): - number = None - count_attempt = True - deletable = lone - else: - if curr_count is not None and sep and sep.isspace(): - # Presume that the intended count consists of as many valid digit groups as necessary to match the - # number of digits in the expected count, if possible. - digit_count = len(str(abs(curr_count))) - use_parts = [] - accum = 0 - for (i, part) in enumerate(parts): - part_valid = len(part) <= 3 if i == 0 else len(part) == 3 - if part_valid and accum < digit_count: - use_parts.append(part) - accum += len(part) - else: - break - - # could still be a no-separator count with some extra digit groups on the same line - if not use_parts: - use_parts = [parts[0]] - - lone = lone and len(use_parts) == len(parts) - else: - # current count is unknown or no separator was used - use_parts = parts - - digits = "".join(use_parts) - number = -int(digits) if match["neg"] else int(digits) - special = ( - curr_count is not None - and abs(number - curr_count) <= 25 - and _is_special_number(number) - ) - deletable = lone and not special - if len(use_parts) == len(parts) and post and not post[0].isspace(): - count_attempt = curr_count is not None and abs(number - curr_count) <= 25 - number = None - else: - count_attempt = True - else: - # no count attempt found - number = None - count_attempt = False - deletable = False - else: - # no lines in update - number = None - count_attempt = False - deletable = True - - return ParsedUpdate( - number = number, - command = command, - count_attempt = count_attempt, - deletable = deletable, - ) - - -def _separate(digits: str) -> list[str]: - mod = len(digits) % 3 - out = [] - if mod: - out.append(digits[: mod]) - out.extend(digits[i : i + 3] for i in range(mod, len(digits), 3)) - return out - - -def _is_special_number(num: int) -> bool: - num_str = str(num) - return bool( - num % 1000 in [0, 1, 333, 999] - or (num > 10_000_000 and "".join(reversed(num_str)) == num_str) - or re.match(r"(.+)\1+$", num_str) # repeated sequence - ) diff --git a/strikebot/docs/sample_config.ini b/strikebot/docs/sample_config.ini new file mode 100644 index 
0000000..28e510e --- /dev/null +++ b/strikebot/docs/sample_config.ini @@ -0,0 +1,53 @@ +# see Python `configparser' docs for precise file syntax + +[config] + +######## basic operating parameters + +# space-separated authorization IDs for Reddit API token lookup +auth IDs = 13 15 + +bot user = count_better +enforcing = true +thread ID = abc123 + + +######## API pool error handling + +# (seconds) for error backoff, pause the API pool for this much time +API pool error delay = 5.5 + +# (seconds) pause the pool when the Reddit API returns two errors within any time span of this length +API pool error window = 5.5 + +# terminate if the Reddit API request queue exceeds this size +request queue limit = 100 + + +######## live thread update handling + +# (seconds) maximum time to hold updates for reordering +reorder buffer time = 0.25 + +# (seconds) minimum time to retain updates to enable future resyncs +update retention time = 120.0 + + +######## WebSocket pool + +# (seconds) maximum allowable spread in WebSocket message arrival times +WS parity time = 1.0 + +# goal size of WebSocket connection pool +WS pool size = 5 + +# (seconds) after receiving no messages for this much time, attempt to replace the connection if the pool isn't depleted +WS silent limit = 30 + +# (seconds) time after WebSocket handshake completion during which missed updates are excused +WS warmup time = 0.3 + + +# Postgres database configuration; same options as in a connect string +[db connect params] +host = example.org diff --git a/strikebot/pyproject.toml b/strikebot/pyproject.toml new file mode 100644 index 0000000..8fe2f47 --- /dev/null +++ b/strikebot/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/strikebot/setup.cfg b/strikebot/setup.cfg new file mode 100644 index 0000000..f200179 --- /dev/null +++ b/strikebot/setup.cfg @@ -0,0 +1,15 @@ +[metadata] +name = strikebot +version = 0.0.0 + +[options] +package_dir = + = src +packages = strikebot +python_requires = ~= 3.9 +install_requires = + asks ~= 2.4 + beautifulsoup4 ~= 4.9 + trio == 0.19 + trio-websocket == 0.9.2 + triopg == 0.6.0 diff --git a/strikebot/setup.py b/strikebot/setup.py new file mode 100644 index 0000000..056ba45 --- /dev/null +++ b/strikebot/setup.py @@ -0,0 +1,4 @@ +import setuptools + + +setuptools.setup() diff --git a/strikebot/src/strikebot/__init__.py b/strikebot/src/strikebot/__init__.py new file mode 100644 index 0000000..7c37d14 --- /dev/null +++ b/strikebot/src/strikebot/__init__.py @@ -0,0 +1,321 @@ +from contextlib import nullcontext as nullcontext +from dataclasses import dataclass +from functools import total_ordering +from itertools import islice +from typing import Optional +from uuid import UUID +import bisect +import datetime as dt +import importlib.metadata +import itertools +import logging + +from trio import CancelScope, EndOfChannel +import trio + +from strikebot.updates import Command, parse_update + + +__version__ = importlib.metadata.version(__package__) + + +_READ_ONLY: bool = False # suppress any API requests that modify the thread? 
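For reference, a short sketch of how the keys in the re-added sample_config.ini are consumed (illustrative, not part of this commit; the real parsing lives in strikebot.__main__). Keys may contain spaces, durations are given in seconds and converted to timedelta, and the auth IDs are a space-separated list of integers.

import configparser
import datetime as dt

parser = configparser.ConfigParser()
parser.read_string("""
[config]
auth IDs = 13 15
enforcing = true
WS parity time = 1.0
""")

cfg = parser["config"]
auth_ids = set(map(int, cfg["auth IDs"].split()))                        # {13, 15}
enforcing = cfg.getboolean("enforcing")                                  # True
ws_parity_time = dt.timedelta(seconds = cfg.getfloat("WS parity time"))  # 1 second
print(auth_ids, enforcing, ws_parity_time)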
+ + +@dataclass +class _Update: + id: str + name: str + author: str + number: Optional[int] + command: Optional[Command] + ts: dt.datetime + count_attempt: bool + deletable: bool + stricken: bool + + def can_follow(self, prior: "_Update") -> bool: + """Determine whether this count update can follow another count update.""" + return self.number == prior.number + 1 and self.author != prior.author + + def __str__(self): + content = str(self.number) + if self.command: + content += "+" + self.command.name + return "{}({} by {})".format(self.id[4 : 8], content, self.author) + + +@total_ordering +@dataclass +class _BufferedUpdate: + update: _Update + release_at: float # Trio time + + def __lt__(self, other): + return self.update.ts < other.update.ts + + +@total_ordering +@dataclass +class _TimelineUpdate: + update: _Update + accepted: bool + + def __lt__(self, other): + return self.update.ts < other.update.ts + + def rejected(self) -> bool: + """Whether the update should be stricken.""" + return self.update.count_attempt and not self.accepted + + +def _parse_rfc_4122_ts(ts: int) -> dt.datetime: + epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc) + return epoch + dt.timedelta(microseconds = ts / 10) + + +def _format_update_ref(update: _Update, thread_id: str) -> str: + url = f"https://www.reddit.com/live/{thread_id}/updates/{update.id}" + author_md = update.author.replace("_", "\\_") # Reddit allows letters, digits, underscore, and hyphen + return f"[{update.number:,} by {author_md}]({url})" + + +def _format_bad_strike_alert(update: _Update, thread_id: str) -> str: + return _format_update_ref(update, thread_id) + " was incorrectly stricken." + + +def _format_curr_count(last_valid: Optional[_Update]) -> str: + if last_valid and last_valid.number is not None: + count_str = f"{last_valid.number + 1:,}" + else: + count_str = "unknown" + return f"_current count:_ {count_str}" + + +async def count_tracker_impl( + message_rx, + api_pool: "strikebot.reddit_api.ApiClientPool", + reorder_buffer_time: dt.timedelta, # duration to hold some updates for reordering + thread_id: str, + bot_user: str, + enforcing: bool, + update_retention: dt.timedelta, + logger: logging.Logger, +) -> None: + from strikebot.reddit_api import CorrectionUpdateRequest, DeleteRequest, ReportUpdateRequest, StrikeRequest + + enforcing = enforcing and not _READ_ONLY + + buffer_ = [] + timeline = [] + last_valid: Optional[_Update] = None + pending_strikes: set[str] = set() # names of updates to mark stricken on arrival + forgotten: bool = False # whether the retention period for an update has passed during the run + + def handle_update(update): + nonlocal last_valid + + tu = _TimelineUpdate(update, accepted = False) + + pos = bisect.bisect(timeline, tu) + if pos != len(timeline): + logger.warning(f"long transpo: {update.id}") + + if pos > 0 and timeline[pos - 1].update.id == update.id: + # The pool sent this update message multiple times. This could be because the message reached parity and + # its bucket was dropped just before a new connection was opened and sent its own copy, but connections also + # seem to send duplicate messages. + return + + pred = next( + (timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted), + None + ) + contender = update.command is Command.RESET or update.number is not None + if contender and pred is None and last_valid and forgotten: + # This is a really long transpo which we can't judge valid or invalid since we've forgotten the surrounding + # updates. 
The best we can do is just ignore it. + logger.warning(f"ignoring {update.id}: no valid prior count on record") + else: + timeline.insert(pos, tu) + tu.accepted = ( + update.command is Command.RESET + or ( + update.number is not None + and (pred is None or pred.number is None or update.can_follow(pred)) + ) + ) + logger.debug("{} {} -> {}".format("accepted" if tu.accepted else "rejected", pred, update)) + if tu.accepted: + # resync subsequent updates + newly_valid = [] + newly_invalid = [] + resync_last_valid = update + converged = False + for scan_tu in islice(timeline, pos + 1, None): + if scan_tu.update.command is Command.RESET: + resync_last_valid = scan_tu.update + if last_valid: + converged = True + elif scan_tu.update.number is not None: + accept = resync_last_valid.number is None or scan_tu.update.can_follow(resync_last_valid) + if accept and scan_tu.accepted: + # resync would have no effect past this point + if last_valid: + converged = True + elif accept: + newly_valid.append(scan_tu) + resync_last_valid = scan_tu.update + elif scan_tu.accepted: + newly_invalid.append(scan_tu) + if scan_tu.update is last_valid: + last_valid = None + scan_tu.accepted = accept + if converged: + break + + if converged: + assert last_valid.ts >= resync_last_valid.ts + else: + last_valid = resync_last_valid + + parts = [] + if newly_valid: + parts.append( + "The following counts are valid:\n\n" + + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in newly_valid) + ) + + unstrikable = [tu for tu in newly_invalid if tu.update.stricken] + if unstrikable: + parts.append( + "The following counts are invalid:\n\n" + + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in unstrikable) + ) + + if update.stricken: + logger.info(f"bad strike of {update.id}") + parts.append(_format_bad_strike_alert(update, thread_id)) + + if parts: + parts.append(_format_curr_count(last_valid)) + if not _READ_ONLY: + api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts))) + + for invalid_tu in newly_invalid: + if not invalid_tu.update.stricken: + if enforcing: + api_pool.enqueue_request(StrikeRequest(thread_id, invalid_tu.update.name, invalid_tu.update.ts)) + invalid_tu.update.stricken = True + elif update.count_attempt: + if enforcing: + api_pool.enqueue_request(StrikeRequest(thread_id, update.name, update.ts)) + update.stricken = True + + if not _READ_ONLY and update.command is Command.REPORT: + api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid))) + + with message_rx: + while True: + if buffer_: + deadline = min(bu.release_at for bu in buffer_) + cancel_scope = CancelScope(deadline = deadline) + else: + cancel_scope = nullcontext() + + msg = None + with cancel_scope: + try: + msg = await message_rx.receive() + except EndOfChannel: + break + + if msg: + if msg.data["type"] == "update": + payload_data = msg.data["payload"]["data"] + stricken = payload_data["name"] in pending_strikes + pending_strikes.discard(payload_data["name"]) + if payload_data["author"] != bot_user: + if last_valid and last_valid.number is not None: + next_up = last_valid.number + 1 + else: + next_up = None + pu = parse_update(payload_data, next_up, bot_user) + rfc_ts = UUID(payload_data["id"]).time + update = _Update( + id = payload_data["id"], + name = payload_data["name"], + author = payload_data["author"], + number = pu.number, + command = pu.command, + ts = _parse_rfc_4122_ts(rfc_ts), + count_attempt = pu.count_attempt, + deletable = pu.deletable, + 
stricken = stricken or payload_data["stricken"], + ) + release_at = msg.timestamp + reorder_buffer_time.total_seconds() + bisect.insort(buffer_, _BufferedUpdate(update, release_at)) + elif msg.data["type"] == "strike": + slot = next( + ( + slot for slot in itertools.chain(buffer_, reversed(timeline)) + if slot.update.name == msg.data["payload"] + ), + None + ) + if slot: + if not slot.update.stricken: + slot.update.stricken = True + if isinstance(slot, _TimelineUpdate) and slot.accepted: + logger.info(f"bad strike of {slot.update.id}") + if not _READ_ONLY: + body = _format_bad_strike_alert(slot.update, thread_id) + api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body)) + else: + pending_strikes.add(msg.data["payload"]) + + threshold = dt.datetime.now(dt.timezone.utc) - update_retention + + # pull updates from the reorder buffer and process + new_buffer = [] + for bu in buffer_: + process = ( + # not a count + bu.update.number is None + + # valid count + or (last_valid is None or last_valid.number is None or bu.update.can_follow(last_valid)) + or bu.update.command is Command.RESET + + or trio.current_time() >= bu.release_at + or bu.update.command is Command.RESET + ) + + if process: + if bu.update.ts < threshold: + logger.warning(f"ignoring {bu.update}: arrived past retention window") + else: + logger.debug(f"processing {bu.update}") + handle_update(bu.update) + else: + logger.debug(f"holding {bu.update}, checked against {last_valid})") + new_buffer.append(bu) + buffer_ = new_buffer + logger.debug("last count {}".format(last_valid)) + + # delete/forget old updates + new_timeline = [] + i = 0 + for (i, tu) in enumerate(timeline): + if i >= len(timeline) - 10 or tu.update.ts >= threshold: + break + elif tu.rejected() and tu.update.deletable: + if enforcing: + api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts)) + elif tu.update is last_valid: + new_timeline.append(tu) + + if i: + forgotten = True + new_timeline.extend(islice(timeline, i, None)) + timeline = new_timeline diff --git a/strikebot/src/strikebot/__main__.py b/strikebot/src/strikebot/__main__.py new file mode 100644 index 0000000..4c2b1d9 --- /dev/null +++ b/strikebot/src/strikebot/__main__.py @@ -0,0 +1,191 @@ +from enum import Enum +from inspect import getmodule +from logging import FileHandler, getLogger, StreamHandler +from pathlib import Path +from signal import SIGUSR1 +from sys import stdout +from traceback import StackSummary +from typing import Optional +import argparse +import configparser +import datetime as dt +import logging + +from trio import open_memory_channel, open_nursery, open_signal_receiver +from trio.lowlevel import current_root_task, Task +import trio_asyncio +import triopg + +from strikebot import count_tracker_impl +from strikebot.db import Client, Messenger +from strikebot.live_ws import HealingReadPool, PoolMerger +from strikebot.reddit_api import ApiClientPool + + +_DEBUG_LOG_PATH: Optional[Path] = None # file to write debug logs to, if any + + +_TaskKind = Enum( + "_TaskKind", + [ + "API_WORKER", + "DB_CLIENT", + "TOKEN_UPDATE", + "TRACKER", + "WS_MERGE_READER", + "WS_MERGE_TIMER", + "WS_REFRESHER", + "WS_WORKER", + ] +) + + +def _classify_task(task: Task) -> _TaskKind: + mod_name = getmodule(task.coro.cr_code).__name__ + if mod_name.startswith("trio_asyncio."): + return None + elif task.coro.__name__ == "_signal_handler_impl": + return None + else: + by_coro_name = { + "db_client_impl": _TaskKind.DB_CLIENT, + "token_updater_impl": _TaskKind.TOKEN_UPDATE, + 
"event_reader_impl": _TaskKind.WS_MERGE_READER, + "timeout_handler_impl": _TaskKind.WS_MERGE_TIMER, + "conn_refresher_impl": _TaskKind.WS_REFRESHER, + "count_tracker_impl": _TaskKind.TRACKER, + "_reader_impl": _TaskKind.WS_WORKER, + "worker_impl": _TaskKind.API_WORKER, + } + return by_coro_name[task.coro.__name__] + + +def _get_all_tasks(): + [ta_loop] = [ + task + for nursery in current_root_task().child_nurseries + for task in nursery.child_tasks + if getmodule(task.coro.cr_code).__name__.startswith("trio_asyncio.") + ] + for nursery in ta_loop.child_nurseries: + yield from nursery.child_tasks + + +def _print_task_info(): + groups = {kind: [] for kind in _TaskKind} + for task in _get_all_tasks(): + kind = _classify_task(task) + if kind: + coro = task.coro + frame_tuples = [] + while coro is not None: + if hasattr(coro, "cr_frame"): + frame_tuples.append((coro.cr_frame, coro.cr_frame.f_lineno)) + coro = coro.cr_await + else: + frame_tuples.append((coro.gi_frame, coro.gi_frame.f_lineno)) + coro = coro.gi_yieldfrom + groups[kind].append(StackSummary.extract(iter(frame_tuples))) + for kind in _TaskKind: + print(kind.name) + for ss in groups[kind]: + print(" task") + for line in ss.format(): + print(" " + line, end = "") + + +async def _signal_handler_impl(): + with open_signal_receiver(SIGUSR1) as sig_src: + async for _ in sig_src: + _print_task_info() + + +ap = argparse.ArgumentParser(__package__) +ap.add_argument("config_path") +args = ap.parse_args() + + +parser = configparser.ConfigParser() +with open(args.config_path) as config_file: + parser.read_file(config_file) + +main_cfg = parser["config"] + +api_pool_error_delay = dt.timedelta(seconds = main_cfg.getfloat("API pool error delay")) +api_pool_error_window = dt.timedelta(seconds = main_cfg.getfloat("API pool error window")) +auth_ids = set(map(int, main_cfg["auth IDs"].split())) +bot_user = main_cfg["bot user"] +enforcing = main_cfg.getboolean("enforcing") +reorder_buffer_time = dt.timedelta(seconds = main_cfg.getfloat("reorder buffer time")) +request_queue_limit = main_cfg.getint("request queue limit") +thread_id = main_cfg["thread ID"] +update_retention = dt.timedelta(seconds = main_cfg.getfloat("update retention time")) +ws_parity_time = dt.timedelta(seconds = main_cfg.getfloat("WS parity time")) +ws_pool_size = main_cfg.getint("WS pool size") +ws_silent_limit = dt.timedelta(seconds = main_cfg.getfloat("WS silent limit")) +ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup seconds")) + +db_cfg = parser["db connect params"] +getters = { + "port": db_cfg.getint, +} +db_connect_params = {k: getters.get(k, db_cfg.get)(k) for k in db_cfg} + + +logger = getLogger(__package__) +logger.setLevel(logging.DEBUG) + +handler = StreamHandler(stdout) +handler.setLevel(logging.WARNING) +handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{")) +logger.addHandler(handler) + +if _DEBUG_LOG_PATH: + debug_handler = FileHandler(_DEBUG_LOG_PATH, mode = "w") + debug_handler.setLevel(logging.DEBUG) + debug_handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{")) + logger.addHandler(debug_handler) + + +async def main(): + async with ( + triopg.connect(**db_connect_params) as db_conn, + open_nursery() as nursery_a, + open_nursery() as nursery_b, + open_nursery() as ws_pool_nursery, + open_nursery() as nursery_c, + ): + nursery_a.start_soon(_signal_handler_impl) + + client = Client(db_conn) + db_messenger = Messenger(client) + 
nursery_a.start_soon(db_messenger.db_client_impl) + + api_pool = ApiClientPool( + auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay, + logger.getChild("api") + ) + nursery_a.start_soon(api_pool.token_updater_impl) + for _ in auth_ids: + await nursery_a.start(api_pool.worker_impl) + + (message_tx, message_rx) = open_memory_channel(0) + (pool_event_tx, pool_event_rx) = open_memory_channel(0) + + merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger.getChild("merge")) + nursery_b.start_soon(merger.event_reader_impl) + nursery_b.start_soon(merger.timeout_handler_impl) + + ws_pool = HealingReadPool( + ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger.getChild("live_ws"), + ws_silent_limit + ) + nursery_c.start_soon(ws_pool.conn_refresher_impl) + + nursery_c.start_soon( + count_tracker_impl, message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing, + update_retention, logger.getChild("track") + ) + await ws_pool.init_workers() + +trio_asyncio.run(main) diff --git a/strikebot/src/strikebot/common.py b/strikebot/src/strikebot/common.py new file mode 100644 index 0000000..519037a --- /dev/null +++ b/strikebot/src/strikebot/common.py @@ -0,0 +1,10 @@ +from typing import Any + + +def int_digest(val: int) -> str: + return "{:04x}".format(val & 0xffff) + + +def obj_digest(obj: Any) -> str: + """Makes a short digest of the identity of an object.""" + return int_digest(id(obj)) diff --git a/strikebot/src/strikebot/db.py b/strikebot/src/strikebot/db.py new file mode 100644 index 0000000..c9a060a --- /dev/null +++ b/strikebot/src/strikebot/db.py @@ -0,0 +1,58 @@ +"""Single-connection Postgres client with high-level wrappers for operations.""" + +from dataclasses import dataclass +from typing import Any, Iterable + +import trio + + +def _channel_sender(method): + async def wrapped(self, resp_channel, *args, **kwargs): + ret = await method(self, *args, **kwargs) + with resp_channel: + await resp_channel.send(ret) + + return wrapped + + +@dataclass +class Client: + """High-level wrappers for DB operations.""" + conn: Any + + @_channel_sender + async def get_auth_tokens(self, auth_ids: set[int]) -> dict[int, str]: + results = await self.conn.fetch( + """ + select + distinct on (id) + id, access_token + from + public.reddit_app_authorization + join unnest($1::integer[]) as request_id on id = request_id + where expires > current_timestamp + """, + list(auth_ids), + ) + return {r["id"]: r["access_token"] for r in results} + + +class Messenger: + def __init__(self, client: Client): + self._client = client + (self._request_tx, self._request_rx) = trio.open_memory_channel(0) + + async def do(self, method_name: str, args: Iterable) -> Any: + """This is run by consumers of the DB wrapper.""" + method = getattr(self._client, method_name) + (resp_tx, resp_rx) = trio.open_memory_channel(0) + coro = method(resp_tx, *args) + await self._request_tx.send(coro) + async with resp_rx: + return await resp_rx.receive() + + async def db_client_impl(self): + """This is the DB client Trio task.""" + async with self._client.conn, self._request_rx: + async for coro in self._request_rx: + await coro diff --git a/strikebot/src/strikebot/live_ws.py b/strikebot/src/strikebot/live_ws.py new file mode 100644 index 0000000..cb8b3d0 --- /dev/null +++ b/strikebot/src/strikebot/live_ws.py @@ -0,0 +1,266 @@ +from collections import deque +from contextlib import suppress +from dataclasses import dataclass +from typing import Any, 
AsyncContextManager, Optional +import datetime as dt +import json +import logging +import math + +from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection +import trio + +from strikebot.common import int_digest, obj_digest +from strikebot.reddit_api import AboutLiveThreadRequest + + +@dataclass +class _PoolEvent: + timestamp: float # Trio time + + +@dataclass +class _Message(_PoolEvent): + data: Any + scope: Any + + def dedup_tag(self): + """A hashable essence of the contents of this message.""" + if self.data["type"] == "update": + return ("update", self.data["payload"]["data"]["id"]) + elif self.data["type"] in {"strike", "delete"}: + return (self.data["type"], self.data["payload"]) + else: + raise ValueError(self.data["type"]) + + +@dataclass +class _ConnectionDown(_PoolEvent): + scope: trio.CancelScope + + +@dataclass +class _ConnectionUp(_PoolEvent): + scope: trio.CancelScope + + +class HealingReadPool: + def __init__( + self, pool_nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger, + silent_limit: dt.timedelta + ): + assert size >= 2 + self._nursery = pool_nursery + self._size = size + self._live_thread_id = live_thread_id + self._api_client_pool = api_client_pool + self._pool_event_tx = pool_event_tx + self._logger = logger + self._silent_limit = silent_limit + + (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size) + + # number of workers who have stopped receiving updates but whose tasks aren't yet stopped + self._closing = 0 + + async def init_workers(self): + for _ in range(self._size): + await self._spawn_reader() + if not self._nursery.child_tasks: + raise RuntimeError("Unable to create any WS connections") + + async def _reader_impl(self, conn_ctx: AsyncContextManager, cancel_scope, refresh_tx): + # TODO could use task's implicit cancel scope + try: + async with conn_ctx as conn: + self._logger.debug("scope up: {}".format(obj_digest(cancel_scope))) + with refresh_tx: + silent_timeout = False + with cancel_scope, suppress(ConnectionClosed): + while True: + with trio.move_on_after(self._silent_limit.total_seconds()) as timeout_scope: + message = await conn.get_message() + + if timeout_scope.cancelled_caught: + if len(self._nursery.child_tasks) - self._closing == self._size: + silent_timeout = True + break + else: + self._logger.debug("not replacing connection {}; {} tasks, {} closing".format( + obj_digest(cancel_scope), + len(self._nursery.child_tasks), + self._closing, + )) + else: + event = _Message(trio.current_time(), json.loads(message), cancel_scope) + await self._pool_event_tx.send(event) + + self._closing += 1 + if silent_timeout: + await conn.aclose() + self._logger.debug("replacing WS connection {} due to silent timeout".format(obj_digest(cancel_scope))) + elif cancel_scope.cancelled_caught: + await conn.aclose(1008, "Server unexpectedly stopped sending messages") + self._logger.warning("replacing WS connection {} due to missed update".format(obj_digest(cancel_scope))) + else: + self._logger.warning("replacing WS connection {} closed by server".format(obj_digest(cancel_scope))) + + refresh_tx.send_nowait(cancel_scope) + self._logger.debug("scope down: {} ({} tasks, {} closing)".format( + obj_digest(cancel_scope), + len(self._nursery.child_tasks), + self._closing, + )) + self._closing -= 1 + except HandshakeError: + self._logger.error("handshake error while opening WS connection") + + async def _new_connection(self) -> 
Optional[AsyncContextManager[WebSocketConnection]]: + request = AboutLiveThreadRequest(self._live_thread_id) + resp = await self._api_client_pool.make_request(request) + if resp.status_code == 200: + url = json.loads(resp.text)["data"]["websocket_url"] + return open_websocket_url(url) + else: + self._logger.error(f"thread info request failed (HTTP {resp.status_code}); can't open WS connection") + return None + + async def _spawn_reader(self) -> None: + conn = await self._new_connection() + if conn: + new_scope = trio.CancelScope() + await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope)) + refresh_tx = self._refresh_queue_tx.clone() + self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx) + + async def conn_refresher_impl(self): + """Task to monitor and replace WS connections as they disconnect or go silent.""" + with self._refresh_queue_rx: + async for old_scope in self._refresh_queue_rx: + await self._pool_event_tx.send(_ConnectionDown(trio.current_time(), old_scope)) + await self._spawn_reader() + if not self._nursery.child_tasks: + raise RuntimeError("WS pool depleted") + + +def _tag_digest(tag): + return int_digest(hash(tag)) + + +def _format_scope_list(scopes): + return ", ".join(sorted(map(obj_digest, scopes))) + + +class PoolMerger: + @dataclass + class _Bucket: + start: float + recipients: set + + def __init__(self, pool_event_rx, message_tx, parity_timeout: dt.timedelta, conn_warmup: dt.timedelta, logger: logging.Logger): + """ + :param parity_timeout: max delay between message receipt on different connections + :param conn_warmup: max interval after connection within which missed updates are allowed + """ + self._pool_event_rx = pool_event_rx + self._message_tx = message_tx + self._parity_timeout = parity_timeout + self._conn_warmup = conn_warmup + self._logger = logger + + self._buckets = {} + self._scope_activations = {} + self._pending = deque() + self._outgoing_scopes = set() + (self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf) + + def _log_current_buckets(self): + self._logger.debug("current buckets: {}".format(", ".join(map(_tag_digest, self._buckets.keys())))) + + async def event_reader_impl(self): + """Drop unused messages, deduplicate useful ones, and communicate with the timeout handler.""" + with self._pool_event_rx, self._timer_poke_tx: + async for event in self._pool_event_rx: + if isinstance(event, _ConnectionUp): + # An early add of an active scope could mean it's expected on a message that fired before it opened, + # resulting in a false positive for replacement. The timer task merges these in among the buckets by + # timestamp to avoid that. + self._pending.append(event) + elif isinstance(event, _Message): + if event.data["type"] in ["update", "strike", "delete"]: + tag = event.dedup_tag() + self._logger.debug("recv {} from {}".format(_tag_digest(tag), obj_digest(event.scope))) + if tag in self._buckets: + b = self._buckets[tag] + b.recipients.add(event.scope) + # If this scope is the last one for this bucket we could clear the bucket here, but since + # connections sometimes get repeat messages and second copies can arrive before the first + # copy has arrived on all connections, leaving the bucket open to absorb repeats can reduce + # the likelihood of a second bucket being allocated late for the second copy of a message + # and causing unnecessary connection replacement. 
+ elif event.scope not in self._outgoing_scopes: + sane = ( + event.scope in self._scope_activations + or any(e.scope == event.scope for e in self._pending) + ) + if sane: + self._logger.debug("new bucket {}: {}".format(_tag_digest(tag), tag)) + self._buckets[tag] = self._Bucket(event.timestamp, {event.scope}) + self._log_current_buckets() + await self._message_tx.send(event) + else: + raise RuntimeError("recieved message from unrecognized WS connection") + else: + self._logger.debug("recv type {!r} from {} (discarding)".format(event.data["type"], obj_digest(event.scope))) + elif isinstance(event, _ConnectionDown): + # We don't need to worry about canceling this scope at all, so no need to require it for parity for + # any message, even older ones. The scope may be gone already, if we canceled it previously. + self._scope_activations.pop(event.scope, None) + self._pending = deque(ev for ev in self._pending if ev.scope is not event.scope) + self._outgoing_scopes.discard(event.scope) + else: + raise TypeError(f"Expected pool event, found {event!r}") + self._timer_poke_tx.send_nowait(None) # may have new work for the timer + + async def timeout_handler_impl(self): + """When connections have had enough time to reach parity on a message, signal replacement of any slackers.""" + with self._timer_poke_rx: + while True: + if self._buckets: + now = trio.current_time() + (tag, bucket) = min(self._buckets.items(), key = lambda i: i[1].start) + + # Make sure the right scope set is ready for the next bucket up + while self._pending and self._pending[0].timestamp < bucket.start: + ev = self._pending.popleft() + self._scope_activations[ev.scope] = ev.timestamp + + target_scopes = { + scope for (scope, active) in self._scope_activations.items() + if active + self._conn_warmup.total_seconds() < now + } + if now > bucket.start + self._parity_timeout.total_seconds(): + filled = target_scopes & bucket.recipients + missing = target_scopes - bucket.recipients + extra = bucket.recipients - target_scopes + + self._logger.debug("expiring bucket {}".format(_tag_digest(tag))) + if filled: + self._logger.debug(" filled {}: {}".format(len(filled), _format_scope_list(filled))) + if missing: + self._logger.debug(" missing {}: {}".format(len(missing), _format_scope_list(missing))) + if extra: + self._logger.debug(" extra {}: {}".format(len(extra), _format_scope_list(extra))) + + for scope in missing: + self._outgoing_scopes.add(scope) + scope.cancel() + del self._buckets[tag] + self._log_current_buckets() + else: + await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds()) + else: + try: + await self._timer_poke_rx.receive() + except trio.EndOfChannel: + break diff --git a/strikebot/src/strikebot/queue.py b/strikebot/src/strikebot/queue.py new file mode 100644 index 0000000..adfbbe6 --- /dev/null +++ b/strikebot/src/strikebot/queue.py @@ -0,0 +1,58 @@ +"""Unbounded blocking queues for Trio""" + +from collections import deque +from dataclasses import dataclass +from functools import total_ordering +from typing import Any, Iterable +import heapq + +from trio.lowlevel import ParkingLot + + +class Queue: + def __init__(self): + self._deque = deque() + self._empty_wait = ParkingLot() + + def push(self, el: Any) -> None: + self._deque.append(el) + self._empty_wait.unpark() + + def extend(self, els: Iterable[Any]) -> None: + for el in els: + self.push(el) + + async def pop(self) -> Any: + if not self._deque: + await self._empty_wait.park() + return self._deque.popleft() + + def __len__(self): + return 
len(self._deque) + + +@total_ordering +@dataclass +class _ReverseOrdWrapper: + inner: Any + + def __lt__(self, other): + return other.inner < self.inner + + +class MaxHeap: + def __init__(self): + self._heap = [] + self._empty_wait = ParkingLot() + + def push(self, item): + heapq.heappush(self._heap, _ReverseOrdWrapper(item)) + self._empty_wait.unpark() + + async def pop(self): + if not self._heap: + await self._empty_wait.park() + return heapq.heappop(self._heap).inner + + def __len__(self) -> int: + return len(self._heap) diff --git a/strikebot/src/strikebot/reddit_api.py b/strikebot/src/strikebot/reddit_api.py new file mode 100644 index 0000000..3e69e6c --- /dev/null +++ b/strikebot/src/strikebot/reddit_api.py @@ -0,0 +1,291 @@ +"""Multi-auth HTTP connection pool with Reddit API wrappers and rate limiting.""" + +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass, field +from functools import total_ordering +from socket import EAI_AGAIN, EAI_FAIL, gaierror +import datetime as dt +import logging + +from asks.response_objects import Response +import asks +import trio + +from strikebot import __version__ as VERSION +from strikebot.common import obj_digest +from strikebot.queue import MaxHeap, Queue + + +REQUEST_DELAY = dt.timedelta(seconds = 1) + +TOKEN_UPDATE_DELAY = dt.timedelta(minutes = 15) + +USER_AGENT_FMT = f"any:net.jcornell.strikebot.{{auth_id}}:v{VERSION} (by /u/jaxklax)" + +API_BASE_URL = "https://oauth.reddit.com" + + +@total_ordering +class _Request(metaclass = ABCMeta): + _SUBTYPE_PRECEDENCE = None # assigned later + + def __lt__(self, other): + if type(self) is type(other): + return self._subtype_cmp_key() < other._subtype_cmp_key() + else: + prec = self._SUBTYPE_PRECEDENCE + return prec.index(type(other)) < prec.index(type(self)) + + def __eq__(self, other): + if type(self) is type(other): + return self._subtype_cmp_key() == other._subtype_cmp_key() + else: + return False + + @abstractmethod + def _subtype_cmp_key(self): + # a large key corresponds to a high priority + raise NotImplementedError() + + @abstractmethod + def to_asks_kwargs(self): + raise NotImplementedError() + + +@dataclass(eq = False) +class StrikeRequest(_Request): + thread_id: str + update_name: str + update_ts: dt.datetime + + def to_asks_kwargs(self): + return { + "method": "POST", + "path": f"/api/live/{self.thread_id}/strike_update", + "data": { + "api_type": "json", + "id": self.update_name, + }, + } + + def _subtype_cmp_key(self): + return self.update_ts + + +@dataclass(eq = False) +class DeleteRequest(_Request): + thread_id: str + update_name: str + update_ts: dt.datetime + + def to_asks_kwargs(self): + return { + "method": "POST", + "path": f"/api/live/{self.thread_id}/delete_update", + "data": { + "api_type": "json", + "id": self.update_name, + }, + } + + def _subtype_cmp_key(self): + return -self.update_ts + + +@dataclass(eq = False) +class AboutLiveThreadRequest(_Request): + thread_id: str + created: float = field(default_factory = trio.current_time) + + def to_asks_kwargs(self): + return { + "method": "GET", + "path": f"/live/{self.thread_id}/about", + "params": { + "raw_json": "1", + }, + } + + def _subtype_cmp_key(self): + return -self.created + + +@dataclass(eq = False) +class _BaseUpdatePostRequest(_Request, metaclass = ABCMeta): + thread_id: str + body: str + created: float = field(default_factory = trio.current_time) + + def to_asks_kwargs(self): + return { + "method": "POST", + "path": f"/live/{self.thread_id}/update", + "data": { + "api_type": "json", + "body": 
self.body, + }, + } + + +@dataclass(eq = False) +class ReportUpdateRequest(_BaseUpdatePostRequest): + def _subtype_cmp_key(self): + return -self.created + + +@dataclass(eq = False) +class CorrectionUpdateRequest(_BaseUpdatePostRequest): + def _subtype_cmp_key(self): + return -self.created + + +# highest priority first +_Request._SUBTYPE_PRECEDENCE = [ + AboutLiveThreadRequest, + StrikeRequest, + CorrectionUpdateRequest, + ReportUpdateRequest, + DeleteRequest, +] + + +@dataclass +class AppCooldown: + auth_id: int + ready_at: float # Trio time + + +class ApiClientPool: + def __init__( + self, + auth_ids: set[int], + db_messenger: "strikebot.db.Messenger", + request_queue_limit: int, + error_window: dt.timedelta, + error_delay: dt.timedelta, + logger: logging.Logger, + ): + self._auth_ids = auth_ids + self._db_messenger = db_messenger + self._request_queue_limit = request_queue_limit + self._error_window = error_window + self._error_delay = error_delay + self._logger = logger + + now = trio.current_time() + self._tokens = {} + self._app_queue = Queue() + self._waiting = {id_: AppCooldown(id_, now) for id_ in auth_ids} + self._request_queue = MaxHeap() + + # pool-wide API error backoff + self._last_error = None + self._global_resume = None + + self._session = asks.Session(connections = len(auth_ids)) + self._session.base_location = API_BASE_URL + + async def _update_tokens(self): + tokens = await self._db_messenger.do("get_auth_tokens", (self._auth_ids,)) + self._tokens.update(tokens) + + awaken_auths = self._waiting.keys() & tokens.keys() + self._app_queue.extend(self._waiting.pop(auth_id) for auth_id in awaken_auths) + self._logger.debug("updated API tokens") + + async def token_updater_impl(self): + last_update = None + while True: + if last_update is not None: + await trio.sleep_until(last_update + TOKEN_UPDATE_DELAY.total_seconds()) + + if last_update is None: + last_update = trio.current_time() + else: + last_update += TOKEN_UPDATE_DELAY.total_seconds() + + await self._update_tokens() + + def _check_queue_size(self) -> None: + if len(self._request_queue) > self._request_queue_limit: + raise RuntimeError("request queue size exceeded limit") + + async def make_request(self, request: _Request) -> Response: + (resp_tx, resp_rx) = trio.open_memory_channel(0) + self.enqueue_request(request, resp_tx) + async with resp_rx: + return await resp_rx.receive() + + def enqueue_request(self, request: _Request, resp_tx = None) -> None: + self._logger.debug("request {}: {}".format(obj_digest(request), type(request).__name__)) + self._request_queue.push((request, resp_tx)) + self._check_queue_size() + self._logger.debug(f"{len(self._request_queue)} requests in queue") + + async def worker_impl(self, task_status): + task_status.started() + while True: + (request, resp_tx) = await self._request_queue.pop() + cooldown = await self._app_queue.pop() + await trio.sleep_until(cooldown.ready_at) + if self._global_resume: + await trio.sleep_until(self._global_resume) + + asks_kwargs = request.to_asks_kwargs() + headers = asks_kwargs.setdefault("headers", {}) + headers.update({ + "Authorization": "Bearer {}".format(self._tokens[cooldown.auth_id]), + "User-Agent": USER_AGENT_FMT.format(auth_id = cooldown.auth_id), + }) + + request_time = trio.current_time() + try: + resp = await self._session.request(**asks_kwargs) + except gaierror as e: + if e.errno in [EAI_FAIL, EAI_AGAIN]: + # DNS failure, probably temporary + error = True + else: + raise + else: + resp.body # read response + error = False + wait_for_token = 
False + log_suffix = " (request {})".format(obj_digest(request)) + if resp.status_code == 429: + # We disagreed about the rate limit state; just try again later. + self._logger.warning("rate limited by Reddit API" + log_suffix) + error = True + elif resp.status_code == 401: + self._logger.warning("got HTTP 401 from Reddit API" + log_suffix) + error = True + wait_for_token = True + elif resp.status_code in [404, 500, 503]: + self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying" + log_suffix) + error = True + elif 400 <= resp.status_code < 500: + # If we're doing something wrong, let's catch it right away. + raise RuntimeError(f"unexpected client error response: {resp.status_code}" + log_suffix) + else: + if resp.status_code != 200: + raise RuntimeError(f"unexpected status code {resp.status_code}") + self._logger.debug("success" + log_suffix) + if resp_tx: + await resp_tx.send(resp) + + if error: + self._request_queue.push((request, resp_tx)) + self._check_queue_size() + if self._last_error: + spread = dt.timedelta(seconds = request_time - self._last_error) + if spread <= self._error_window: + self._global_resume = request_time + self._error_delay.total_seconds() + self._last_error = request_time + + cooldown.ready_at = request_time + REQUEST_DELAY.total_seconds() + if wait_for_token: + self._waiting[cooldown.auth_id] = cooldown + if not self._app_queue: + self._logger.error("all workers waiting for API tokens") + else: + self._app_queue.push(cooldown) diff --git a/strikebot/src/strikebot/tests.py b/strikebot/src/strikebot/tests.py new file mode 100644 index 0000000..7bbf226 --- /dev/null +++ b/strikebot/src/strikebot/tests.py @@ -0,0 +1,50 @@ +from unittest import TestCase + +from strikebot.updates import parse_update + + +def _build_payload(body_html: str) -> dict[str, str]: + return {"body_html": body_html} + + +class UpdateParsingTests(TestCase): + def test_successful_counts(self): + pu = parse_update(_build_payload("
12,345,678 spaghetti
"), None, "") + self.assertEqual(pu.number, 12345678) + self.assertFalse(pu.deletable) + + pu = parse_update(_build_payload("

0


oz
"), None, "") + self.assertEqual(pu.number, 0) + self.assertFalse(pu.deletable) + + pu = parse_update(_build_payload("
121 345 621
"), 121, "") + self.assertEqual(pu.number, 121) + self.assertFalse(pu.deletable) + + pu = parse_update(_build_payload("28 336 816"), 28336816, "") + self.assertEqual(pu.number, 28336816) + self.assertTrue(pu.deletable) + + def test_non_counts(self): + pu = parse_update(_build_payload("
zoo
"), None, "") + self.assertFalse(pu.count_attempt) + self.assertFalse(pu.deletable) + + def test_typos(self): + pu = parse_update(_build_payload("v9"), 888, "") + self.assertIsNone(pu.number) + self.assertTrue(pu.count_attempt) + + pu = parse_update(_build_payload("
v11.585 Empire
"), None, "") + self.assertIsNone(pu.number) + self.assertTrue(pu.count_attempt) + self.assertFalse(pu.deletable) + + pu = parse_update(_build_payload("
11, 585, 22
"), 11_585_202, "") + self.assertIsNone(pu.number) + self.assertTrue(pu.count_attempt) + self.assertTrue(pu.deletable) + + pu = parse_update(_build_payload("0490499"), 4999, "") + self.assertIsNone(pu.number) + self.assertTrue(pu.count_attempt) diff --git a/strikebot/src/strikebot/updates.py b/strikebot/src/strikebot/updates.py new file mode 100644 index 0000000..e3f58cd --- /dev/null +++ b/strikebot/src/strikebot/updates.py @@ -0,0 +1,216 @@ +from __future__ import annotations +from dataclasses import dataclass +from enum import Enum +from typing import Optional +import re + +from bs4 import BeautifulSoup + + +Command = Enum("Command", ["RESET", "REPORT"]) + + +@dataclass +class ParsedUpdate: + number: Optional[int] + command: Optional[Command] + count_attempt: bool # either well-formed or typo + deletable: bool + + +def _parse_command(line: str, bot_user: str) -> Optional[Command]: + if line.lower() == f"/u/{bot_user} reset".lower(): + return Command.RESET + elif line.lower() in ["sidebar count", "current count"]: + return Command.REPORT + else: + return None + + +def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -> ParsedUpdate: + # curr_count is the next number up, one more than the last count + + NEW_LINE = object() + SPACE = object() + + # flatten the update content to plain text + tree = BeautifulSoup(payload_data["body_html"], "html.parser") + worklist = tree.contents + out = [[]] + while worklist: + el = worklist.pop() + if isinstance(el, str): + out[-1].append(el) + elif el is SPACE: + out[-1].append(el) + elif el is NEW_LINE or el.name == "br": + if out[-1]: + out.append([]) + elif el.name in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]: + worklist.extend(reversed(el.contents)) + elif el.name in ["ul", "ol", "table", "thead", "tbody"]: + worklist.extend(reversed(el.contents)) + elif el.name in ["li", "p", "div", "h1", "h2", "blockquote"]: + worklist.append(NEW_LINE) + worklist.extend(reversed(el.contents)) + worklist.append(NEW_LINE) + elif el.name == "pre": + worklist.append(NEW_LINE) + worklist.extend([l] for l in reversed(el.text.splitlines())) + worklist.append(NEW_LINE) + elif el.name == "tr": + worklist.append(NEW_LINE) + for (i, cell) in enumerate(reversed(el.contents)): + worklist.append(cell) + if i != len(el.contents) - 1: + worklist.append(SPACE) + worklist.append(NEW_LINE) + else: + raise RuntimeError(f"can't parse tag {el.tag}") + + tmp_lines = ( + "".join(" " if part is SPACE else part for part in parts).strip() + for parts in out + ) + pre_strip_lines = list(filter(None, tmp_lines)) + + # normalize whitespace according to HTML rendering rules + # https://developer.mozilla.org/en-US/docs/Web/API/Document_Object_Model/Whitespace#explanation + stripped_lines = [ + re.sub(" +", " ", l.replace("\t", " ").replace("\n", " ")).strip(" ") + for l in pre_strip_lines + ] + + return _parse_from_lines(stripped_lines, curr_count, bot_user) + + +def _parse_from_lines(lines: list[str], curr_count: Optional[int], bot_user: str) -> ParsedUpdate: + command = next( + filter(None, (_parse_command(l, bot_user) for l in lines)), + None + ) + if lines: + # look for groups of digits (as many as possible) separated by a uniform separator from the valid set + first = lines[0] + match = re.match( + "(?P<v>v)?(?P<neg>-)?(?P<num>\\d+((?P<sep>[,. 
\u2009]|, )\\d+((?P=sep)\\d+)*)?)", + first, + re.ASCII, # only recognize ASCII digits + ) + if match: + raw_digits = match["num"] + sep = match["sep"] + post = first[match.end() :] + + zeros = False + while len(raw_digits) > 1 and raw_digits[0] == "0": + zeros = True + raw_digits = raw_digits.removeprefix("0").removeprefix(sep or "") + + parts = raw_digits.split(sep) if sep else [raw_digits] + lone = len(lines) == 1 and (not post or post.isspace()) + typo = False + if lone: + all_parts_valid = ( + sep is None + or ( + 1 <= len(parts[0]) <= 3 + and all(len(p) == 3 for p in parts[1:]) + ) + ) + if match["v"] and len(parts) == 1 and len(parts[0]) <= 2: + # failed paste of leading digits + typo = True + elif match["v"] and all_parts_valid: + # v followed by count + typo = True + elif curr_count is not None and abs(curr_count) >= 100 and bool(match["neg"]) == (curr_count < 0): + goal_parts = _separate(str(abs(curr_count))) + partials = [ + goal_parts[: -1] + [goal_parts[-1][: -1]], # missing last digit + goal_parts[: -1] + [goal_parts[-1][: -2]], # missing last two digits + goal_parts[: -1] + [goal_parts[-1][: -2] + goal_parts[-1][-1]], # missing second-last digit + ] + if parts in partials: + # missing any of last two digits + typo = True + elif parts in [p[: -1] + [p[-1] + goal_parts[0]] + goal_parts[1 :] for p in partials]: + # double paste + typo = True + + if match["v"] or zeros or typo or (parts == ["0"] and match["neg"]): + number = None + count_attempt = True + deletable = lone + else: + if curr_count is not None and sep and sep.isspace(): + # Presume that the intended count consists of as many valid digit groups as necessary to match the + # number of digits in the expected count, if possible. + digit_count = len(str(abs(curr_count))) + use_parts = [] + accum = 0 + for (i, part) in enumerate(parts): + part_valid = len(part) <= 3 if i == 0 else len(part) == 3 + if part_valid and accum < digit_count: + use_parts.append(part) + accum += len(part) + else: + break + + # could still be a no-separator count with some extra digit groups on the same line + if not use_parts: + use_parts = [parts[0]] + + lone = lone and len(use_parts) == len(parts) + else: + # current count is unknown or no separator was used + use_parts = parts + + digits = "".join(use_parts) + number = -int(digits) if match["neg"] else int(digits) + special = ( + curr_count is not None + and abs(number - curr_count) <= 25 + and _is_special_number(number) + ) + deletable = lone and not special + if len(use_parts) == len(parts) and post and not post[0].isspace(): + count_attempt = curr_count is not None and abs(number - curr_count) <= 25 + number = None + else: + count_attempt = True + else: + # no count attempt found + number = None + count_attempt = False + deletable = False + else: + # no lines in update + number = None + count_attempt = False + deletable = True + + return ParsedUpdate( + number = number, + command = command, + count_attempt = count_attempt, + deletable = deletable, + ) + + +def _separate(digits: str) -> list[str]: + mod = len(digits) % 3 + out = [] + if mod: + out.append(digits[: mod]) + out.extend(digits[i : i + 3] for i in range(mod, len(digits), 3)) + return out + + +def _is_special_number(num: int) -> bool: + num_str = str(num) + return bool( + num % 1000 in [0, 1, 333, 999] + or (num > 10_000_000 and "".join(reversed(num_str)) == num_str) + or re.match(r"(.+)\1+$", num_str) # repeated sequence + )
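To round out the picture, a minimal sketch of calling the parser above; the markup follows Reddit's body_html convention of a <div class="md"> wrapper, but the exact fragment, expected count, and bot name are assumptions for the example rather than values from the test suite:

    from strikebot.updates import parse_update

    payload = {"body_html": "<div class=\"md\"><p>12,345,679</p></div>"}
    parsed = parse_update(payload, 12_345_679, "count_better")  # second argument is the next expected count
    assert parsed.number == 12345679
    assert parsed.count_attempt and parsed.deletable  # a lone, well-formed count may be deleted if rejected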