From c63499942ce0f722691ade572fbb262dab8aafb2 Mon Sep 17 00:00:00 2001 From: Jakob Cornell Date: Thu, 28 Apr 2022 00:44:29 -0500 Subject: [PATCH 1/1] Full MVP and some unit tests --- .gitignore | 3 + docs/sample_config.ini | 28 ++++ pyproject.toml | 3 + setup.cfg | 14 ++ setup.py | 4 + src/strikebot/__init__.py | 271 ++++++++++++++++++++++++++++++++++++ src/strikebot/__main__.py | 77 ++++++++++ src/strikebot/db.py | 44 ++++++ src/strikebot/live_ws.py | 203 +++++++++++++++++++++++++++ src/strikebot/queue.py | 36 +++++ src/strikebot/reddit_api.py | 228 ++++++++++++++++++++++++++++++ src/strikebot/tests.py | 42 ++++++ src/strikebot/updates.py | 202 +++++++++++++++++++++++++++ 13 files changed, 1155 insertions(+) create mode 100644 .gitignore create mode 100644 docs/sample_config.ini create mode 100644 pyproject.toml create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 src/strikebot/__init__.py create mode 100644 src/strikebot/__main__.py create mode 100644 src/strikebot/db.py create mode 100644 src/strikebot/live_ws.py create mode 100644 src/strikebot/queue.py create mode 100644 src/strikebot/reddit_api.py create mode 100644 src/strikebot/tests.py create mode 100644 src/strikebot/updates.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fbf1bc5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +*.egg-info +__pycache__ +.mypy_cache diff --git a/docs/sample_config.ini b/docs/sample_config.ini new file mode 100644 index 0000000..66fc49e --- /dev/null +++ b/docs/sample_config.ini @@ -0,0 +1,28 @@ +[config] +bot user = count_better + +# (seconds) maximum allowable spread in WebSocket message arrival times +WS parity time = 1.0 + +# goal size of WebSocket connection pool +WS pool size = 5 + +# (seconds) time after WebSocket handshake completion during which missed updates are excused +WS warmup time = 0.3 + +# (seconds) maximum time to hold updates for reordering +reorder buffer time = 0.25 + +enforcing = true + +# (seconds) 
minimum time to retain updates to enable future resyncs +update retention time = 120.0 + +thread ID = abc123 + +# space-separated authorization IDs for Reddit API token lookup +auth IDs = 13 15 + +# Postgres database configuration; same options as in a connect string +[db connect params] +host = example.org diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8fe2f47 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..1c54cc8 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,14 @@ +[metadata] +name = strikebot +version = 0.0.0 + +[options] +package_dir = + = src +packages = strikebot +python_requires = ~= 3.9 +install_requires = + trio ~= 0.19 + triopg ~= 0.6 + trio-websocket ~= 0.9 + asks ~= 2.4 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..056ba45 --- /dev/null +++ b/setup.py @@ -0,0 +1,4 @@ +import setuptools + + +setuptools.setup() diff --git a/src/strikebot/__init__.py b/src/strikebot/__init__.py new file mode 100644 index 0000000..63fff59 --- /dev/null +++ b/src/strikebot/__init__.py @@ -0,0 +1,271 @@ +from contextlib import nullcontext as nullcontext +from dataclasses import dataclass +from functools import total_ordering +from typing import Optional, Set +from uuid import UUID +import bisect +import datetime as dt +import importlib.metadata +import itertools +import logging +import re +
+from trio import CancelScope, EndOfChannel +import trio + +from strikebot.updates import Command, parse_update + + +__version__ = importlib.metadata.version(__package__) + + +@dataclass +class _Update: + id: str + name: str + author: str + number: Optional[int] + command: Optional[Command] + ts: dt.datetime + count_attempt: bool + deletable: bool + stricken: bool + + def can_follow(self, prior: "_Update") -> bool: + """Determine whether this count update can follow 
another count update.""" + return self.number == prior.number + 1 and self.author != prior.author + + +@total_ordering +@dataclass +class _BufferedUpdate: + update: _Update + release_at: float # Trio time + + def __lt__(self, other): + return self.update.ts < other.update.ts + + +@total_ordering +@dataclass +class _TimelineUpdate: + update: _Update + accepted: bool + bad_count: bool + + def __lt__(self, other): + return self.update.ts < other.update.ts + + +def _parse_rfc_4122_ts(ts: int) -> dt.datetime: + epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc) + return epoch + dt.timedelta(microseconds = ts // 10) # integer division: as a float, ts / 10 is too large to keep microsecond precision + + +def _format_update_ref(update: _Update, thread_id: str) -> str: + url = f"https://www.reddit.com/live/{thread_id}/updates/{update.id}" + author_md = update.author.replace("_", "\\_") # Reddit allows letters, digits, underscore, and hyphen + return f"[{update.number:,} by {author_md}]({url})" if update.number is not None else f"[update by {author_md}]({url})" # accepted commands (e.g. resets) can have no number + + +def _format_bad_strike_alert(update: _Update, thread_id: str) -> str: + return _format_update_ref(update, thread_id) + " was incorrectly stricken." 
+ + +def _format_curr_count(last_valid: Optional[_Update]) -> str: + if last_valid and last_valid.number is not None: + count_str = f"{last_valid.number + 1:,}" + else: + count_str = "unknown" + return f"_current count:_ {count_str}" + + +async def count_tracker_impl( + message_rx, + api_pool: "strikebot.reddit_api.ApiClientPool", + reorder_buffer_time: dt.timedelta, # duration to hold some updates for reordering + thread_id: str, + bot_user: str, + enforcing: bool, + update_retention: dt.timedelta, + logger: logging.Logger, +) -> None: + from strikebot.reddit_api import CorrectionUpdateRequest, DeleteRequest, ReportUpdateRequest, StrikeRequest + + buffer_ = [] + timeline = [] + last_valid: Optional[_Update] = None + pending_strikes: Set[str] = set() # names of updates to mark stricken on arrival + + def handle_update(update): + nonlocal last_valid + + tu = _TimelineUpdate(update, accepted = False, bad_count = False) + + pos = bisect.bisect(timeline, tu) + if pos != len(timeline): + logger.warning(f"long transpo: {update.name}") + + pred = next( + (timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted), + None + ) + if update.command is not Command.RESET and update.number is not None and last_valid and not pred: + logger.warning(f"ignoring {update.name}: no valid prior count on record") + else: + tu.accepted = ( + update.command is Command.RESET + or ( + update.number is not None + and (pred is None or pred.number is None or update.can_follow(pred)) + ) + ) + timeline.insert(pos, tu) + if tu.accepted: + # resync subsequent updates already processed + newly_valid = [] + newly_invalid = [] + last_valid = update + for scan_tu in timeline[pos + 1:]: + if scan_tu.update.command is Command.RESET: + last_valid = scan_tu.update + elif scan_tu.update.number is not None: + accept = last_valid.number is None or scan_tu.update.can_follow(last_valid) + if accept and not scan_tu.accepted: + newly_valid.append(scan_tu) + elif not accept and 
scan_tu.accepted: + newly_invalid.append(scan_tu) + scan_tu.accepted = accept + if accept: + last_valid = scan_tu.update + + parts = [] + if newly_valid: + parts.append( + "The following counts are valid:\n\n" + + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in newly_valid) + ) + for tu in newly_valid: + tu.bad_count = False + + unstrikable = [tu for tu in newly_invalid if tu.update.stricken] + if unstrikable: + parts.append( + "The following counts are invalid:\n\n" + + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in unstrikable) + ) + for tu in unstrikable: + tu.bad_count = True + + if update.stricken: + logger.info(f"bad strike of {update.name}") + parts.append(_format_bad_strike_alert(update, thread_id)) + + for invalid_tu in newly_invalid: + if not invalid_tu.update.stricken: + if enforcing: + api_pool.enqueue_request(StrikeRequest(thread_id, invalid_tu.update.name, invalid_tu.update.ts)) + invalid_tu.bad_count = True + invalid_tu.update.stricken = True + if parts: + parts.append(_format_curr_count(last_valid)) + api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts))) + elif update.number is not None or update.count_attempt: + if enforcing: + api_pool.enqueue_request(StrikeRequest(thread_id, update.name, update.ts)) + tu.bad_count = True + update.stricken = True + + if update.command is Command.REPORT: + api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid))) + + with message_rx: + while True: + if buffer_: + deadline = min(bu.release_at for bu in buffer_) + cancel_scope = CancelScope(deadline = deadline) + else: + cancel_scope = nullcontext() + + msg = None + with cancel_scope: + try: + msg = await message_rx.receive() + except EndOfChannel: + break + + if msg: + if msg.data["type"] == "update": + payload_data = msg.data["payload"]["data"] + stricken = payload_data["name"] in pending_strikes + pending_strikes.discard(payload_data["name"]) + if 
payload_data["author"] != bot_user: + next_up = last_valid.number + 1 if last_valid and last_valid.number is not None else None # parse_update expects the count being typed, not the last valid one + pu = parse_update(payload_data, next_up, bot_user) + rfc_ts = UUID(payload_data["id"]).time + update = _Update( + id = payload_data["id"], + name = payload_data["name"], + author = payload_data["author"], + number = pu.number, + command = pu.command, + ts = _parse_rfc_4122_ts(rfc_ts), + count_attempt = pu.count_attempt, + deletable = pu.deletable, + stricken = stricken or payload_data["stricken"], + ) + release_at = msg.timestamp + reorder_buffer_time.total_seconds() + bisect.insort(buffer_, _BufferedUpdate(update, release_at)) + elif msg.data["type"] == "strike": + UUID(re.match("LiveUpdate_(.+)$", msg.data["payload"])[1]) # sanity check payload + slot = next( + ( + slot for slot in itertools.chain(buffer_, reversed(timeline)) + if slot.update.name == msg.data["payload"] + ), + None + ) + if slot: + slot.update.stricken = True + if isinstance(slot, _TimelineUpdate) and slot.accepted: + logger.info(f"bad strike of {slot.update.name}") + body = _format_bad_strike_alert(slot.update, thread_id) + api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body)) + else: + pending_strikes.add(msg.data["payload"]) + + threshold = dt.datetime.now(dt.timezone.utc) - update_retention + + # pull updates from the reorder buffer and process + new_buffer = [] + for bu in buffer_: + process = ( + # not a count + bu.update.number is None + + # valid count + or (last_valid is None or last_valid.number is None or bu.update.can_follow(last_valid)) + + # invalid and not likely to become valid by transpo + or bu.update.number > last_valid.number + 3 + + or trio.current_time() >= bu.release_at + ) + + if process: + if bu.update.ts < threshold: + logger.warning(f"ignoring {bu.update.name}: arrived past retention window") + else: + handle_update(bu.update) + else: + new_buffer.append(bu) + buffer_ = 
new_buffer + + # delete/forget old updates + for (i, tu) in enumerate(timeline): + if i >= 
len(timeline) - 10 or tu.update.ts >= threshold: + break + elif tu.bad_count and tu.update.deletable: + if enforcing: + api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts)) + del timeline[: i if timeline else 0] # i is unbound when the loop above had nothing to scan diff --git a/src/strikebot/__main__.py b/src/strikebot/__main__.py new file mode 100644 index 0000000..a1c6705 --- /dev/null +++ b/src/strikebot/__main__.py @@ -0,0 +1,77 @@ +import argparse +import configparser +import datetime as dt +import logging +import sys + +import trio +import trio_asyncio +import triopg + +from strikebot import count_tracker_impl +from strikebot.db import Client, Messenger +from strikebot.live_ws import HealingReadPool, PoolMerger +from strikebot.reddit_api import ApiClientPool + + +ap = argparse.ArgumentParser(__package__) +ap.add_argument("config_path") +args = ap.parse_args() + + +parser = configparser.ConfigParser() +with open(args.config_path) as config_file: + parser.read_file(config_file) + +main_cfg = parser["config"] + +auth_ids = set(map(int, main_cfg["auth IDs"].split())) +bot_user = main_cfg["bot user"] +enforcing = main_cfg.getboolean("enforcing") +reorder_buffer_time = dt.timedelta(seconds = main_cfg.getfloat("reorder buffer time")) +thread_id = main_cfg["thread ID"] +update_retention = dt.timedelta(seconds = main_cfg.getfloat("update retention time")) +ws_parity_time = dt.timedelta(seconds = main_cfg.getfloat("WS parity time")) +ws_pool_size = main_cfg.getint("WS pool size") +ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup time")) # key name as in docs/sample_config.ini + +db_connect_params = dict(parser["db connect params"]) + + +logger = logging.getLogger(__package__) +logger.setLevel(logging.DEBUG) + +handler = logging.StreamHandler(sys.stdout) +handler.setFormatter(logging.Formatter("{asctime:23}: {levelname:8}: {message}", style = "{")) +logger.addHandler(handler) + + +async def main(): + async with trio.open_nursery() 
as nursery_a, triopg.connect(**db_connect_params) as db_conn: + (req_tx, req_rx) = 
trio.open_memory_channel(0) + client = Client(db_conn) + db_messenger = Messenger(client) + + nursery_a.start_soon(db_messenger.db_client_impl) + + api_pool = ApiClientPool(auth_ids, db_messenger, logger) + nursery_a.start_soon(api_pool.token_updater_impl) + for _ in auth_ids: + nursery_a.start_soon(api_pool.worker_impl) + + (message_tx, message_rx) = trio.open_memory_channel(0) + (pool_event_tx, pool_event_rx) = trio.open_memory_channel(0) + merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger) + async with trio.open_nursery() as nursery_b: + nursery_b.start_soon(merger.event_reader_impl) + nursery_b.start_soon(merger.timeout_handler_impl) + pool = HealingReadPool(nursery_b, ws_pool_size, thread_id, api_pool, pool_event_tx, logger) + async with pool, trio.open_nursery() as nursery_c: + nursery_c.start_soon(pool.conn_refresher_impl) + nursery_c.start_soon( + count_tracker_impl, + message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing, update_retention, logger + ) + + +trio_asyncio.run(main) diff --git a/src/strikebot/db.py b/src/strikebot/db.py new file mode 100644 index 0000000..76a19b6 --- /dev/null +++ b/src/strikebot/db.py @@ -0,0 +1,44 @@ +"""Single-connection Postgres client with high-level wrappers for operations.""" + +from dataclasses import dataclass +from typing import Any, Iterable + +import trio + + +def _channel_sender(method): + async def wrapped(self, resp_channel, *args, **kwargs): + with resp_channel: + await resp_channel.send(await method(self, *args, **kwargs)) + + return wrapped + + +@dataclass +class Client: + """High-level wrappers for DB operations.""" + conn: Any + + @_channel_sender + async def get_auth_tokens(self, auth_ids: set[int]): + raise NotImplementedError() + + +class Messenger: + def __init__(self, client: Client): + self._client = client + (self._request_tx, self._request_rx) = trio.open_memory_channel(0) + + async def do(self, method_name: str, args: Iterable) -> Any: + """This is 
run by consumers of the DB wrapper.""" + method = getattr(self._client, method_name) + (resp_tx, resp_rx) = trio.open_memory_channel(0) + await self._request_tx.send(method(resp_tx, *args)) + async with resp_rx: + return await resp_rx.receive() + + async def db_client_impl(self): + """This is the DB client Trio task.""" + async with self._client.conn, self._request_rx: + async for coro in self._request_rx: + await coro diff --git a/src/strikebot/live_ws.py b/src/strikebot/live_ws.py new file mode 100644 index 0000000..40f02dc --- /dev/null +++ b/src/strikebot/live_ws.py @@ -0,0 +1,203 @@ +from collections import deque +from contextlib import suppress +from dataclasses import dataclass +from typing import Any, AsyncContextManager, Optional +import datetime as dt +import json +import logging +import math + +from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection +import trio + +from strikebot.reddit_api import AboutLiveThreadRequest + + +@dataclass +class _PoolEvent: + timestamp: float # Trio time + + +@dataclass +class _Message(_PoolEvent): + data: Any + scope: Any + + def dedup_tag(self): + """A hashable essence of the contents of this message.""" + if self.data["type"] == "update": + return ("update", self.data["payload"]["data"]["id"]) + elif self.data["type"] in {"strike", "delete"}: + return (self.data["type"], self.data["payload"]) + else: + raise ValueError(self.data["type"]) + + +@dataclass +class _ConnectionDown(_PoolEvent): + scope: trio.CancelScope + + +@dataclass +class _ConnectionUp(_PoolEvent): + scope: trio.CancelScope + + +class HealingReadPool: + def __init__(self, nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger): + assert size >= 2 + self._nursery = nursery + self._size = size + self._live_thread_id = live_thread_id + self._api_client_pool = api_client_pool + self._pool_event_tx = pool_event_tx + self._logger = logger + (self._refresh_queue_tx, 
self._refresh_queue_rx) = trio.open_memory_channel(size) + self._active_count = 0 + + async def __aenter__(self): + for _ in range(self._size): + await self._spawn_reader() + if not self._active_count: + raise RuntimeError("Unable to create any WS connections") + + async def __aexit__(self, exc_type, exc_value, traceback): + self._refresh_queue_tx.close() + + async def _reader_impl(self, conn_ctx: AsyncContextManager, cancel_scope, refresh_tx): + try: + async with conn_ctx as conn: + self._active_count += 1 + with refresh_tx: + with cancel_scope, suppress(ConnectionClosed): + while True: + message = await conn.get_message() + event = _Message(trio.current_time(), json.loads(message), cancel_scope) + await self._pool_event_tx.send(event) + + if cancel_scope.cancelled_caught: + await conn.aclose(4058, "Server unexpectedly stopped sending messages") + self._logger.warning("replacing WS connection due to missed updates") + else: + self._logger.warning("replacing WS connection closed by server") + refresh_tx.send_nowait(cancel_scope) + self._active_count -= 1 + except HandshakeError: + self._logger.error("handshake error while opening WS connection") + + async def _new_connection(self) -> Optional[AsyncContextManager[WebSocketConnection]]: + request = AboutLiveThreadRequest(self._live_thread_id) + resp = await self._api_client_pool.make_request(request) + if resp.status_code == 200: + url = json.loads(resp.text)["data"]["websocket_url"] + return open_websocket_url(url) + else: + self._logger.error(f"thread info request failed (HTTP {resp.status_code}); can't open WS connection") + return None + + async def _spawn_reader(self) -> None: + conn = await self._new_connection() + if conn: + new_scope = trio.CancelScope() + await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope)) + refresh_tx = self._refresh_queue_tx.clone() + self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx) + + async def conn_refresher_impl(self): + """Task to 
monitor and replace WS connections as they disconnect or go silent.""" + with self._refresh_queue_rx: + async for old_scope in self._refresh_queue_rx: + await self._pool_event_tx.send(_ConnectionDown(trio.current_time(), old_scope)) + await self._spawn_reader() + if not self._active_count: + raise RuntimeError("WS pool depleted") + + +class PoolMerger: + @dataclass + class _Bucket: + start: float + recipients: set + + def __init__(self, pool_event_rx, message_tx, parity_timeout: dt.timedelta, conn_warmup: dt.timedelta, logger: logging.Logger): + """ + :param parity_timeout: max delay between message receipt on different connections + :param conn_warmup: max interval after connection within which missed updates are allowed + """ + self._pool_event_rx = pool_event_rx + self._message_tx = message_tx + self._parity_timeout = parity_timeout + self._conn_warmup = conn_warmup + self._logger = logger + + self._buckets = {} + self._scope_activations = {} + self._pending = deque() + (self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf) + + async def event_reader_impl(self): + """Drop unused messages, deduplicate useful ones, and install info needed by the timeout handler.""" + with self._pool_event_rx, self._timer_poke_tx: + async for event in self._pool_event_rx: + if isinstance(event, _ConnectionUp): + # An early add of an active scope could mean it's expected on a message that fired before it opened, + # resulting in a false positive for replacement. The timer task merges these in among the buckets + # by timestamp to avoid that. 
+ self._pending.append(event) + elif isinstance(event, _Message): + if event.data["type"] in {"update", "strike", "delete"}: + tag = event.dedup_tag() + if tag in self._buckets: + b = self._buckets[tag] + b.recipients.add(event.scope) + if b.recipients >= self._scope_activations.keys(): + del self._buckets[tag] + else: + sane = ( + event.scope in self._scope_activations + or any(e.scope == event.scope for e in self._pending) + ) + if sane: + self._buckets[tag] = self._Bucket(event.timestamp, {event.scope}) + await self._message_tx.send(event) # zero-capacity channel: send_nowait raises WouldBlock unless the tracker is already waiting + else: + raise RuntimeError("recieved message from unrecognized WS connection") + elif isinstance(event, _ConnectionDown): + # We don't need to worry about canceling this scope at all, so no need to require it for parity for + # any message, even older ones. The scope may be gone already, if we canceled it previously. + self._scope_activations.pop(event.scope, None) + else: + raise TypeError(f"Expected pool event, found {event!r}") + self._timer_poke_tx.send_nowait(None) # may have new work for the timer + + async def timeout_handler_impl(self): + """When connections have had enough time to reach parity on a message, signal replacement of any slackers.""" + with self._timer_poke_rx: + while True: + if self._buckets: + now = trio.current_time() + (tag, bucket) = min(self._buckets.items(), key = lambda i: i[1].start) + + # Make sure the right scope set is ready for the next bucket up + while self._pending and self._pending[0].timestamp < bucket.start: + ev = self._pending.popleft() + self._scope_activations[ev.scope] = ev.timestamp + + target_scopes = { + scope for (scope, active) in self._scope_activations.items() + if active + self._conn_warmup.total_seconds() < now + } + if bucket.recipients >= target_scopes: + del self._buckets[tag] + elif now > bucket.start + self._parity_timeout.total_seconds(): + for scope in target_scopes - 
bucket.recipients: + scope.cancel() + del self._scope_activations[scope] + del 
self._buckets[tag] + else: + await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds()) + else: + try: + await self._timer_poke_rx.receive() + except trio.EndOfChannel: + break diff --git a/src/strikebot/queue.py b/src/strikebot/queue.py new file mode 100644 index 0000000..fed0757 --- /dev/null +++ b/src/strikebot/queue.py @@ -0,0 +1,36 @@ +"""Unbounded blocking priority queue for Trio""" + +from dataclasses import dataclass +from functools import total_ordering +from typing import Any +import heapq + +from trio.lowlevel import ParkingLot + + +@total_ordering +@dataclass +class _ReverseOrdWrapper: + inner: Any + + def __lt__(self, other): + return other.inner < self.inner + + +class MaxHeap: + def __init__(self): + self._heap = [] + self._empty_wait = ParkingLot() + + def push(self, item): + heapq.heappush(self._heap, _ReverseOrdWrapper(item)) + if len(self._empty_wait): + self._empty_wait.unpark() + + async def pop(self): + while not self._heap: # another task may take the item between unpark() and this task resuming + await self._empty_wait.park() + return heapq.heappop(self._heap).inner + + def __len__(self) -> int: + return len(self._heap) diff --git a/src/strikebot/reddit_api.py b/src/strikebot/reddit_api.py new file mode 100644 index 0000000..600b38e --- /dev/null +++ b/src/strikebot/reddit_api.py @@ -0,0 +1,228 @@ +"""Multi-auth HTTP connection pool with Reddit API wrappers and rate limiting.""" + +from abc import ABCMeta, abstractmethod +from collections import deque +from dataclasses import dataclass, field +from functools import total_ordering +import datetime as dt +import logging + +from asks.response_objects import Response +import asks +import trio + +from strikebot import __version__ as VERSION +from strikebot.queue import MaxHeap + + +REQUEST_DELAY = dt.timedelta(seconds = 1) + +TOKEN_UPDATE_DELAY = dt.timedelta(minutes = 15) + +USER_AGENT_FMT = f"any:net.jcornell.strikebot.{{auth_id}}:v{VERSION} (by /u/jaxklax)" + +API_BASE_URL = 
"https://oauth.reddit.com" + + +@total_ordering +class _Request(metaclass = 
ABCMeta): + _SUBTYPE_PRECEDENCE = None # assigned later + + def __lt__(self, other): + if type(self) is type(other): + return self._subtype_cmp_key() < other._subtype_cmp_key() + else: + prec = self._SUBTYPE_PRECEDENCE + return prec.index(type(other)) < prec.index(type(self)) + + def __eq__(self, other): + if type(self) is type(other): + return self._subtype_cmp_key() == other._subtype_cmp_key() + else: + return False + + @abstractmethod + def _subtype_cmp_key(self): + # a large key corresponds to a high priority + raise NotImplementedError() + + @abstractmethod + def to_asks_kwargs(self): + raise NotImplementedError() + + +@dataclass(eq = False) +class StrikeRequest(_Request): + thread_id: str + update_name: str + update_ts: dt.datetime + + def to_asks_kwargs(self): + return { + "method": "POST", + "path": f"/api/live/{self.thread_id}/strike_update", + "data": { + "api_type": "json", + "id": self.update_name, + }, + } + + def _subtype_cmp_key(self): + return self.update_ts + + +@dataclass(eq = False) +class DeleteRequest(_Request): + thread_id: str + update_name: str + update_ts: dt.datetime + + def to_asks_kwargs(self): + return { + "method": "POST", + "path": f"/api/live/{self.thread_id}/delete_update", + "data": { + "api_type": "json", + "id": self.update_name, + }, + } + + def _subtype_cmp_key(self): + return -self.update_ts.timestamp() # datetimes can't be negated; older updates get higher priority + + +@dataclass(eq = False) +class AboutLiveThreadRequest(_Request): + thread_id: str + created: float = field(default_factory = trio.current_time) + + def to_asks_kwargs(self): + return { + "method": "GET", + "path": f"/live/{self.thread_id}/about", + "params": { + "raw_json": "1", + }, + } + + def _subtype_cmp_key(self): + return -self.created + + +@dataclass(eq = False) +class _BaseUpdatePostRequest(_Request, metaclass = ABCMeta): + thread_id: str + body: str + created: float = field(default_factory = trio.current_time) + + def to_asks_kwargs(self): + return { + "method": "POST", 
+ "path": f"/api/live/{self.thread_id}/update", + "data": { + "api_type": 
"json", + "body": self.body, + }, + } + + +@dataclass(eq = False) +class ReportUpdateRequest(_BaseUpdatePostRequest): + def _subtype_cmp_key(self): + return -self.created + + +@dataclass(eq = False) +class CorrectionUpdateRequest(_BaseUpdatePostRequest): + def _subtype_cmp_key(self): + return -self.created + + +# highest priority first +_Request._SUBTYPE_PRECEDENCE = [ + AboutLiveThreadRequest, + StrikeRequest, + CorrectionUpdateRequest, + ReportUpdateRequest, + DeleteRequest, +] + + +@dataclass +class AppCooldown: + auth_id: int + ready_at: float # Trio time + + +class ApiClientPool: + def __init__(self, auth_ids: set[int], db_messenger, logger: logging.Logger): + self._db_messenger = db_messenger + self._logger = logger + + self._tokens = {id_: None for id_ in auth_ids} + self._tokens_installed = trio.Event() + now = trio.current_time() + self._app_queue = deque(AppCooldown(id_, now) for id_ in auth_ids) + self._request_queue = MaxHeap() + + self._session = asks.Session(connections = len(auth_ids)) + self._session.base_location = API_BASE_URL + + async def _update_tokens(self): + tokens = await self._db_messenger.do("get_auth_tokens", (self._tokens.keys(),)) + assert tokens.keys() == self._tokens.keys() + self._tokens = tokens + self._tokens_installed.set() + + async def token_updater_impl(self): + last_update = None + while True: + if last_update is not None: + await trio.sleep_until(last_update + TOKEN_UPDATE_DELAY.total_seconds()) + last_update = trio.current_time() + await self._update_tokens() + + def _check_queue_size(self) -> None: + if len(self._request_queue) > 5: + self._logger.warning(f"API workers may be saturated; {len(self._request_queue)} requests in queue") + + async def make_request(self, request: _Request) -> Response: + (resp_tx, resp_rx) = trio.open_memory_channel(0) + self._request_queue.push((request, resp_tx)) + async with resp_rx: + return await resp_rx.receive() + + def enqueue_request(self, request: _Request) -> None: + 
self._request_queue.push((request, None)) + + async def worker_impl(self): + await self._tokens_installed.wait() + while True: + (request, resp_tx) = await self._request_queue.pop() + while trio.current_time() < self._app_queue[0].ready_at: + await trio.sleep_until(self._app_queue[0].ready_at) + + cooldown = self._app_queue.popleft() + asks_kwargs = request.to_asks_kwargs() + asks_kwargs.setdefault("headers", {}).update({ + "Authorization": "Bearer {}".format(self._tokens[cooldown.auth_id]), + "User-Agent": USER_AGENT_FMT.format(auth_id = cooldown.auth_id), + }) + cooldown.ready_at = trio.current_time() + REQUEST_DELAY.total_seconds() + resp = await self._session.request(**asks_kwargs) + + if resp.status_code == 429: + # We disagreed about the rate limit state; just try again later. + self._request_queue.push((request, resp_tx)) + elif 400 <= resp.status_code < 500: + # If we're doing something wrong, let's catch it right away. + if resp_tx: + resp_tx.close() + raise RuntimeError("Unexpected client error response: {}".format(resp.status_code)) + else: + if resp.status_code != 200: + self._logger.warning(f"got HTTP {resp.status_code} from Reddit API") + if resp_tx: + await resp_tx.send(resp) + + self._app_queue.append(cooldown) diff --git a/src/strikebot/tests.py b/src/strikebot/tests.py new file mode 100644 index 0000000..9b9ddae --- /dev/null +++ b/src/strikebot/tests.py @@ -0,0 +1,42 @@ +from unittest import TestCase + +from strikebot.updates import parse_update + + +def _build_payload(body_html: str) -> dict[str, str]: + return {"body_html": body_html} + + +class UpdateParsingTests(TestCase): + def test_successful_counts(self): + pu = parse_update(_build_payload("
12,345,678 spaghetti
"), None, "count_better") + self.assertEqual(pu.number, 12345678) + self.assertFalse(pu.deletable) + + pu = parse_update(_build_payload("

0


oz
"), None, "count_better") + self.assertEqual(pu.number, 0) + self.assertFalse(pu.deletable) + + def test_non_counts(self): + pu = parse_update(_build_payload("
zoo
"), None, "count_better") + self.assertFalse(pu.count_attempt) + self.assertFalse(pu.deletable) + + def test_typos(self): + pu = parse_update(_build_payload("v9"), 888, "count_better") + self.assertIsNone(pu.number) + self.assertTrue(pu.count_attempt) + + pu = parse_update(_build_payload("
v11.585 Empire
"), None, "count_better") + self.assertIsNone(pu.number) + self.assertTrue(pu.count_attempt) + self.assertFalse(pu.deletable) + + pu = parse_update(_build_payload("
11, 585, 22
"), 11_585_202, "count_better") + self.assertIsNone(pu.number) + self.assertTrue(pu.count_attempt) + self.assertTrue(pu.deletable) + + pu = parse_update(_build_payload("0490499"), 4999, "count_better") + self.assertIsNone(pu.number) + self.assertTrue(pu.count_attempt) diff --git a/src/strikebot/updates.py b/src/strikebot/updates.py new file mode 100644 index 0000000..cd814bf --- /dev/null +++ b/src/strikebot/updates.py @@ -0,0 +1,202 @@ +from __future__ import annotations +from collections import deque +from dataclasses import dataclass +from enum import Enum +from typing import Optional +from xml.etree import ElementTree +import re + + +Command = Enum("Command", ["RESET", "REPORT"]) + + +@dataclass +class ParsedUpdate: + number: Optional[int] + command: Optional[Command] + count_attempt: bool + deletable: bool + + +def _parse_command(line: str, bot_user: str) -> Optional[Command]: + if line.lower() == f"/u/{bot_user} reset".lower(): + return Command.RESET + elif line.lower() in ["sidebar count", "current count"]: + return Command.REPORT + else: + return None + + +def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -> ParsedUpdate: + NEW_LINE = object() + SPACE = object() + + # flatten the update content to plain text + doc = ElementTree.fromstring(payload_data["body_html"]) + worklist = deque([doc]) + out = [[]] + while worklist: + el = worklist.popleft() + if el is NEW_LINE: + if out[-1]: + out.append([]) + elif el is SPACE: + out[-1].append(el) + elif isinstance(el, str): + out[-1].append(el.replace("\n", " ")) + elif el.tag in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]: + if el.tail: + worklist.appendleft(el.tail) + for sub in reversed(el): + worklist.appendleft(sub) + if el.text: + worklist.appendleft(el.text) + elif el.tag == "li": + assert not el.tail + worklist.appendleft(NEW_LINE) + worklist.extendleft(reversed([el.text, *el] if el.text else list(el))) # keep child elements; el.text may be None + elif el.tag in ["p", "div", "h1", 
"h2", "blockquote"]: + if el.tail: + worklist.appendleft(el.tail) + 
worklist.appendleft(NEW_LINE) + for sub in reversed(el): + worklist.appendleft(sub) + if el.text: + worklist.appendleft(el.text) + elif el.tag in ["ul", "ol"]: + if el.tail: + worklist.appendleft(el.tail) + assert not el.text + for sub in reversed(el): + worklist.appendleft(sub) + worklist.appendleft(NEW_LINE) + elif el.tag == "pre": + out.extend([l] for l in "".join(el.itertext()).splitlines()) # the text may sit in a child element (e.g. code), leaving el.text None + if el.tail: + worklist.appendleft(el.tail) + worklist.appendleft(NEW_LINE) + elif el.tag == "table": + if el.tail: + worklist.appendleft(el.tail) + worklist.appendleft(NEW_LINE) + assert not el.text + for sub in reversed(el): + assert sub.tag in ["thead", "tbody"] + assert not sub.text + for row in reversed(sub): + assert row.tag == "tr" + assert not row.text or row.tail + for (i, cell) in enumerate(reversed(row)): + assert not cell.tail + worklist.appendleft(cell) + if i != len(row) - 1: + worklist.appendleft(SPACE) + worklist.appendleft(NEW_LINE) + elif el.tag == "br": + if el.tail: + worklist.appendleft(el.tail) + worklist.appendleft(NEW_LINE) + else: + raise RuntimeError(f"can't parse tag {el.tag}") + + lines = list(filter( + None, + ( + "".join(" " if part is SPACE else part for part in parts).strip() + for parts in out + ) + )) + + command = next( + filter(None, (_parse_command(l, bot_user) for l in lines)), + None + ) + + if lines: + first = lines[0] + match = re.match( + "(?P<v>v)?(?P<neg>-)?(?P<num>\\d+((?P<sep>[,. 
\u2009]|, )\\d+((?P=sep)\\d+)*)?)( |$)", + first, + re.ASCII, # only recognize ASCII digits + ) + if match: + ct_str = match["num"] + sep = match["sep"] + + zeros = False + while len(ct_str) > 1 and ct_str[0] == "0": + zeros = True + ct_str = ct_str.removeprefix("0").removeprefix(sep or "") + + parts = ct_str.split(sep) if sep else [ct_str] + parts_valid = ( + len(parts[0]) in range(1, 4) # the leading group holds 1 to 3 digits + and all(len(p) == 3 for p in parts[1:]) + ) + digits = "".join(parts) + lone = not first[match.end() :].strip() and len(lines) == 1 + typo = False + if lone: + if match["v"] and len(ct_str) <= 2: + # failed paste of leading digits + typo = True + elif match["v"] and parts_valid: + # v followed by count + typo = True + elif curr_count and curr_count >= 100: + goal = (sep or "").join(_separate(str(curr_count))) + partials = [goal[: -2], goal[: -1], goal[: -2] + goal[-1]] + if ct_str in partials: + # missing any of last two digits + typo = True + elif ct_str in [p + goal for p in partials]: + # double paste + typo = True + if match["v"] or zeros or typo or (digits == "0" and match["neg"]): + number = None + count_attempt = True + deletable = lone + elif parts_valid: + number = -int(digits) if match["neg"] else int(digits) + count_attempt = True + special = curr_count is not None and abs(number - curr_count) <= 25 and _is_special_number(number) + deletable = lone and not special + else: + number = None + count_attempt = False + deletable = lone + else: + # no count attempt found + number = None + count_attempt = False + deletable = False + else: + # no lines in update + number = None + count_attempt = False + deletable = True + + return ParsedUpdate( + number = number, + command = command, + count_attempt = count_attempt, + deletable = deletable, + ) + + +def _separate(digits: str) -> list[str]: + mod = len(digits) % 3 + out = [] + if mod: + out.append(digits[: mod]) + out.extend(digits[i : i + 3] for i in range(mod, len(digits), 3)) + return out + + +def 
_is_special_number(num: 
int) -> bool: + num_str = str(num) + return bool( + num % 1000 in [0, 1, 333, 999] + or (num > 10_000_000 and "".join(reversed(num_str)) == num_str) + or re.match(r"(.+)\1+$", num_str) # repeated sequence + ) -- 2.30.2