From: Jakob Cornell Date: Mon, 9 Jan 2023 00:07:32 +0000 (-0600) Subject: Reorganize for multi-project builds, improve build helper X-Git-Tag: strikebot-0.0.8~1 X-Git-Url: https://jcornell.net/gitweb/gitweb.cgi?a=commitdiff_plain;h=3d9ff30a00e4701540424a5b42fa391c6173c31f;p=counting.git Reorganize for multi-project builds, improve build helper --- diff --git a/build_helper.py b/build_helper.py index c747894..07d3be2 100644 --- a/build_helper.py +++ b/build_helper.py @@ -1,5 +1,7 @@ """ -A few tools to help keep Debian packages and artifacts well organized. +A few tools to help keep Debian packages and artifacts well organized. Run me from the counting repo root. + +Run as root if doing a build, as the """ from argparse import ArgumentParser @@ -24,18 +26,26 @@ def _is_output(path: Path) -> bool: ) -project_root = Path(__file__).parent -upstream_root = project_root.joinpath("strikebot") -build_dir = project_root.joinpath("build") - parser = ArgumentParser() + tmp = parser.add_subparsers(dest = "cmd", required = True) +project_cmd_parsers = [ + tmp.add_parser("build"), + tmp.add_parser("clean"), +] -tmp.add_parser("build") -tmp.add_parser("clean") +for subparser in project_cmd_parsers: + subparser.add_argument("project") args = parser.parse_args() + +project_root = Path(__file__).parent.joinpath(args.project) +upstream_root = project_root.joinpath(args.project) +build_dir = project_root.joinpath("build") + if args.cmd == "build": + build_dir.mkdir(exist_ok = True) + # extract version from Python package config patt = re.compile("version *= *(.+?) *$") with upstream_root.joinpath("setup.cfg").open() as f: @@ -43,25 +53,30 @@ if args.cmd == "build": version = m[1] # delete stale "orig" tarball - for p in project_root.glob("strikebot_*.orig.tar.xz"): + for p in project_root.glob(f"{args.project}_*.orig.tar.xz"): p.unlink() # regenerate the "orig" tarball from the current source - run(["dh_make", "--yes", "--python", "--createorig", "-p", "strikebot_" + version], cwd = upstream_root) + run(["dh_make", "--yes", "--python", "--createorig", "-p", f"{args.project}_{version}"], cwd = upstream_root) - # build the source package - run(["debuild", "-i", "-us", "-uc", "-S"], cwd = upstream_root, check = True) + try: + # build the source package + run(["debuild", "-i", "-us", "-uc", "-S"], cwd = upstream_root, check = True) - [orig_path] = project_root.glob("strikebot_*.orig.tar.xz") - orig_name = orig_path.name - - # move outputs to build directory - for p in project_root.iterdir(): - if _is_output(p): - p.rename(build_dir.joinpath(p.relative_to(project_root))) + [orig_path] = project_root.glob(f"{args.project}_*.orig.tar.xz") + orig_name = orig_path.name - # create temporary link for orig tarball to satisfy binary package build - orig_path.symlink_to(build_dir.joinpath(orig_name).relative_to(project_root)) + try: + # build binary package + run(["pdebuild"], cwd = upstream_root, check = True) + finally: + # clean up + orig_path.unlink() + finally: + # move source package and intermediates to build directory + for p in project_root.iterdir(): + if _is_output(p): + p.rename(build_dir.joinpath(p.relative_to(project_root))) elif args.cmd == "clean": for p in build_dir.iterdir(): if _is_output(p): diff --git a/mentionbot/docs/sample_config.ini b/mentionbot/docs/sample_config.ini deleted file mode 100644 index 4228b6a..0000000 --- a/mentionbot/docs/sample_config.ini +++ /dev/null @@ -1,28 +0,0 @@ -# see Python `configparser' docs for precise file syntax - -[config] - -# authorization ID for Reddit API token 
lookup -auth ID = 3 - -# (int, minutes) -token update period = 15 - -thread ID = abc123 - -# (int) maximum number of mentions within an update for which to send messages -mention limit = 10 - -# (float, seconds) wait this much time before sending notifications, in case of update deletion -message buffer time = 10 - -# (float, seconds) wait this much time between API polls (and WS connection checks) -API poll period = 90 - -# (float, seconds) replace the WS connection if it falls this far behind the API in update timestamps -max WS delay = 2 - - -# Postgres database configuration; same options as in a connect string. -[db connect params] -host = example.org diff --git a/mentionbot/mentionbot/docs/sample_config.ini b/mentionbot/mentionbot/docs/sample_config.ini new file mode 100644 index 0000000..4228b6a --- /dev/null +++ b/mentionbot/mentionbot/docs/sample_config.ini @@ -0,0 +1,28 @@ +# see Python `configparser' docs for precise file syntax + +[config] + +# authorization ID for Reddit API token lookup +auth ID = 3 + +# (int, minutes) +token update period = 15 + +thread ID = abc123 + +# (int) maximum number of mentions within an update for which to send messages +mention limit = 10 + +# (float, seconds) wait this much time before sending notifications, in case of update deletion +message buffer time = 10 + +# (float, seconds) wait this much time between API polls (and WS connection checks) +API poll period = 90 + +# (float, seconds) replace the WS connection if it falls this far behind the API in update timestamps +max WS delay = 2 + + +# Postgres database configuration; same options as in a connect string. +[db connect params] +host = example.org diff --git a/mentionbot/mentionbot/pyproject.toml b/mentionbot/mentionbot/pyproject.toml new file mode 100644 index 0000000..8fe2f47 --- /dev/null +++ b/mentionbot/mentionbot/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/mentionbot/mentionbot/setup.cfg b/mentionbot/mentionbot/setup.cfg new file mode 100644 index 0000000..b3f9866 --- /dev/null +++ b/mentionbot/mentionbot/setup.cfg @@ -0,0 +1,15 @@ +[metadata] +name = mentionbot +version = 0.0.0 + +[options] +package_dir = + = src +packages = mentionbot +python_requires = ~= 3.9 +install_requires = + beautifulsoup4 ~= 4.9 + #psycopg2 ~= 2.8 + psycopg2-binary ~= 2.8 + trio == 0.19 + trio-websocket == 0.9.2 diff --git a/mentionbot/mentionbot/setup.py b/mentionbot/mentionbot/setup.py new file mode 100644 index 0000000..056ba45 --- /dev/null +++ b/mentionbot/mentionbot/setup.py @@ -0,0 +1,4 @@ +import setuptools + + +setuptools.setup() diff --git a/mentionbot/mentionbot/src/mentionbot/__init__.py b/mentionbot/mentionbot/src/mentionbot/__init__.py new file mode 100644 index 0000000..76fd7e2 --- /dev/null +++ b/mentionbot/mentionbot/src/mentionbot/__init__.py @@ -0,0 +1,187 @@ +from dataclasses import dataclass +from datetime import timedelta +from http.client import HTTPResponse +from logging import Logger +from socket import EAI_AGAIN, EAI_FAIL, gaierror +from typing import AsyncContextManager, Callable, Iterable, Optional, Type +from urllib.error import HTTPError +from urllib.request import Request, urlopen +from uuid import UUID +import importlib.metadata +import re + +from bs4 import BeautifulSoup +from trio import CancelScope, MemoryReceiveChannel, MemorySendChannel, move_on_at +import trio + +from .abc import DeleteEvent, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent + 
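[Illustrative sketch, not part of the patch: the UUID-v1 timestamp arithmetic this module leans on for last_update_ts/resync_ts and the poller's max_ws_delay check. uuid1().time counts 100-nanosecond ticks since the Gregorian epoch (1582-10-15), so seconds convert to the same unit by multiplying by 10**7, mirroring strikebot's _parse_rfc_4122_ts; the helper name below is hypothetical.]

import datetime as dt
from uuid import uuid1

GREGORIAN_EPOCH = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc)

def uuid1_ticks_to_datetime(ticks: int) -> dt.datetime:
    # `ticks' is a UUID-v1 `time' field: 100 ns intervals since 1582-10-15 (sketch helper).
    return GREGORIAN_EPOCH + dt.timedelta(microseconds = ticks / 10)

max_ws_delay = dt.timedelta(seconds = 2)
max_delta_ticks = max_ws_delay.total_seconds() * 10 ** 7  # same unit as uuid1().time
print(uuid1_ticks_to_datetime(uuid1().time), max_delta_ticks)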
+ +_USER_AGENT = "any:net.jcornell.mentionbot:v{} (by /u/jaxklax)".format( + importlib.metadata.version(__package__) +) + + +@dataclass +class TokenRef: + """Mutable container for an access token.""" + + token: str + + def get_common_headers(self) -> dict[str, str]: + return { + "User-Agent": _USER_AGENT, + "Authorization": "Bearer {}".format(self.token) + } + + +def try_request( + request: Request, + logger: Logger, + suppress_codes: list[int] = [404, 429, 500, 503], +) -> Optional[HTTPResponse]: + + try: + return urlopen(request) + except gaierror as e: + if e.errno in [EAI_FAIL, EAI_AGAIN]: + logger.error(f"DNS failure: {e}") + return None + else: + raise + except HTTPError as e: + e.read() + e.close() + if e.code in suppress_codes: + logger.error(f"HTTP {e.code}") + return None + else: + raise + + +@dataclass +class WsListenerState: + """The WebSockets listener uses this object to communicate with the API poll and message sender tasks.""" + + # The API poll task uses this scope to force a refresh of the WS connection if it's been too quiet. + scope: Optional[CancelScope] = None + + # The WS task uses this to track the UUID timestamp of the most recent update message update, so the API poll task + # can detect when the connection goes silent. + last_update_ts: Optional[int] = None + + # The WS task sets this when the WS connection goes down and clears it when it's recovered. + down: bool = True + + # The WS task sets this to the UUID timestamp of the first update recieved on connection recovery so the message + # sender task knows when to stop checking for deletion. + resync_ts: Optional[int] = None + + +def _parse_mentions(body: BeautifulSoup) -> set[str]: + """Extract the names of users tagged in an update.""" + return { + el["href"].removeprefix("/u/") + for el in body.find_all("a", href = re.compile(r"^/u/[\w-]{3,20}$", re.ASCII)) + } + + +async def ws_listener_impl( + event_stream_factory: Callable[[], AsyncContextManager[EventStream]], + event_stream_catch: Iterable[Type[BaseException]], + state: WsListenerState, + event_tx: MemorySendChannel, # message type `Event' + logger: Logger, +) -> None: + + while True: + state.scope = CancelScope() + with state.scope: + logger.debug("opening a new connection") + try: + async with event_stream_factory() as event_stream: + first_update = True + while True: + event = await event_stream.get_event() + if isinstance(event, UpdateEvent): + update_id = UUID(event.payload_data["id"]) + state.last_update_ts = update_id.time + if first_update: + # first update after discontinuity + state.down = False + state.resync_ts = update_id.time + first_update = False + elif isinstance(event, DeleteEvent): + logger.debug("delete event for " + event.payload) + else: + raise RuntimeError(f"expected Event, got {event}") + + await event_tx.send(event) + except tuple(event_stream_catch) as e: + logger.warning(f"WebSockets error: {e}") + + state.down = True + + +async def message_sender_impl( + notifier: Notifier, + title_provider: ThreadTitleProvider, + update_checker: UpdateChecker, + ws_state: WsListenerState, + thread_id: str, + event_rx: MemoryReceiveChannel, # message type `Event' + buffer_time: timedelta, + mention_limit: int, + logger: Logger, +) -> None: + + @dataclass + class QueueEntry: + payload_data: dict + mentions: set[str] + arrival: float # Trio time + + queue: list[QueueEntry] = [] + + while True: + logger.debug("queue: {}".format([e.payload_data["id"] for e in queue])) + if queue: + scope = move_on_at(queue[0].arrival + 
buffer_time.total_seconds()) + else: + scope = CancelScope() # no timeout + + with scope: + event = await event_rx.receive() + + if scope.cancelled_caught: + # timed out + entry = queue.pop(0) + update_id = entry.payload_data["id"] + if ws_state.down or ws_state.resync_ts and UUID(update_id).time < ws_state.resync_ts: + notify = update_checker.exists(thread_id, update_id) + else: + notify = True + + if notify: + thread_title = title_provider.title_for(thread_id) + for recipient in entry.mentions: + notifier.notify(NotifyInfo( + recipient, thread_id, thread_title, update_id, entry.payload_data["author"], + entry.payload_data["body"] + )) + else: + # got an event + if isinstance(event, UpdateEvent): + update_id = event.payload_data["id"] + if not any(entry.payload_data["id"] == update_id for entry in queue): + mentions = _parse_mentions(BeautifulSoup(event.payload_data["body_html"], "html.parser")) + if mentions: + if len(mentions) > mention_limit: + logger.info(f"ignoring {update_id} due to too many mentions ({len(mentions)})") + else: + queue.append(QueueEntry(event.payload_data, mentions, trio.current_time())) + elif isinstance(event, DeleteEvent): + logger.debug("delete event for " + event.payload) + queue = [entry for entry in queue if entry.payload_data["name"] != event.payload] + else: + raise RuntimeError(f"invalid event type {type(event)}") diff --git a/mentionbot/mentionbot/src/mentionbot/__main__.py b/mentionbot/mentionbot/src/mentionbot/__main__.py new file mode 100644 index 0000000..960a228 --- /dev/null +++ b/mentionbot/mentionbot/src/mentionbot/__main__.py @@ -0,0 +1,291 @@ +""" +Mentionbot: Notify users by private message when they're /u/-mentioned in a live thread. +""" + +from argparse import ArgumentParser +from configparser import ConfigParser +from contextlib import asynccontextmanager +from dataclasses import dataclass +from datetime import timedelta +from logging import Formatter, getLogger, Logger, NOTSET, StreamHandler +from sys import stdout +from typing import ClassVar, Iterator, Optional, Type +from urllib.error import HTTPError +from urllib.parse import urlencode +from urllib.request import Request, urlopen +from uuid import UUID +import json +import logging + +from trio import MemorySendChannel, open_memory_channel, open_nursery +from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection +import psycopg2 +import trio + +from . 
import TokenRef, message_sender_impl, try_request, ws_listener_impl, WsListenerState +from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent + + +async def _token_update_impl(period: timedelta, db_conn, auth_id: int, task_status) -> None: + def fetch_token() -> str: + cursor = db_conn.cursor() + cursor.execute( + """ + select access_token from public.reddit_app_authorization + where id = %s + """, + (auth_id,) + ) + [(token,)] = cursor.fetchall() + return token + + ref = TokenRef(fetch_token()) + task_status.started(ref) + while True: + await trio.sleep(period.total_seconds()) + ref.token = fetch_token() + + +async def _api_poll_impl( + ws_state: WsListenerState, + token_ref: TokenRef, + thread_id: str, + event_tx: MemorySendChannel, # message type `Event' + poll_period: timedelta, + max_ws_delay: timedelta, + logger: Logger, +) -> None: + + base_url = f"https://oauth.reddit.com/live/{thread_id}" + + latest_payload: Optional[dict] = None + while True: + # fetch updates and emit events + params = { + "raw_json": "1", + } + if latest_payload: + updates = [] + while True: + params["before"] = latest_payload["data"]["name"] + url = base_url + "?" + urlencode(params) + resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger) + if resp: + with resp: + children = json.load(resp)["data"]["children"] + if children: + updates.extend(reversed(children)) + latest_payload = children[0] + else: + break + else: + break + else: + params["limit"] = "1" + url = base_url + "?" + urlencode(params) + resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger) + if resp: + with resp: + updates = json.load(resp)["data"]["children"] + if updates: + latest_payload = updates[0] + else: + updates = [] + + for update in updates: + await event_tx.send(UpdateEvent(update["data"])) + + # check if WS connection is behind + if not ws_state.down and latest_payload: + max_delta = max_ws_delay.total_seconds() * 10 ** 7 # in 100*nanosecond + if UUID(latest_payload["data"]["id"]).time - ws_state.last_update_ts > max_delta: + ws_state.scope.cancel() + + await trio.sleep(poll_period.total_seconds()) + + +def _build_notification_md(info: NotifyInfo) -> str: + # TODO escape thread title? 
+ return "\n\n".join([ + f"from [{info.author}](/user/{info.author}) in [{info.thread_title}](/live/{info.thread_id})", + "\n".join(">" + line for line in info.update_body.splitlines()), + ( + f"[^^permalink](/live/{info.thread_id}/updates/{info.update_id})" + + f" [^^context](/live/{info.thread_id}?after=LiveUpdate_{info.update_id})" + + " [^^about ^^these ^^notifications](/r/live_mentions/wiki)" + ), + ]) + + +@dataclass +class RedditNotifier(Notifier): + token_ref: TokenRef + logger: Logger + + def notify(self, info: NotifyInfo) -> None: + params = { + "raw_json": "1", + "api_type": "json", + "subject": "username mention", + "text": _build_notification_md(info), + } + req = Request( + "https://oauth.reddit.com/api/compose", + method = "POST", + data = urlencode({**params, "to": info.recipient}).encode("ascii"), + headers = self.token_ref.get_common_headers(), + ) + resp = try_request(req, self.logger) + if resp: + with resp: + data = json.load(resp)["json"] + if data["errors"]: + [(code, _, _)] = data["errors"] + if code != "USER_DOESNT_EXIST": + raise RuntimeError("PM send error: {code}") + + +@dataclass +class RedditThreadTitleProvider(ThreadTitleProvider): + token_ref: TokenRef + + def title_for(self, thread_id: str) -> str: + req = Request( + f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1", + headers = self.token_ref.get_common_headers() + ) + with urlopen(req) as resp: + return json.load(resp)["data"]["title"] + + +@dataclass +class RedditUpdateChecker(UpdateChecker): + token_ref: TokenRef + logger: Logger + + def exists(self, thread_id: str, update_id: str) -> bool: + req = Request( + f"https://oauth.reddit.com/live/{thread_id}/updates/{update_id}?raw_json=1", + headers = self.token_ref.get_common_headers(), + ) + try: + resp = try_request(req, self.logger, suppress_codes = [429, 500, 503]) + except HTTPError as e: + e.read() + e.close() + if e.code == 404: + return False + else: + raise + else: + if resp: + resp.read() + resp.close() + return True + + +@dataclass +class WsConnectionEventStream(EventStream): + EXCEPTION_TYPES: ClassVar[set[Type[BaseException]]] = {HandshakeError, ConnectionClosed} + + ws_conn: WebSocketConnection + + async def get_event(self) -> Event: + while True: + message = json.loads(await self.ws_conn.get_message()) + if message["type"] == "update": + return UpdateEvent(message["payload"]["data"]) + elif message["type"] == "delete": + return DeleteEvent(message["payload"]) + + +async def _main( + auth_id: int, + buffer_time: timedelta, + db_conn, + logger: Logger, + max_ws_delay: timedelta, + mention_limit: int, + poll_period: timedelta, + thread_id: str, + token_period: timedelta, +) -> None: + + @asynccontextmanager + async def event_stream_factory() -> Iterator[WsConnectionEventStream]: + resp = urlopen(Request( + f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1", + headers = token_ref.get_common_headers(), + )) + with resp: + ws_url = json.load(resp)["data"]["websocket_url"] + + async with open_websocket_url(ws_url) as ws_conn: + yield WsConnectionEventStream(ws_conn) + + async with open_nursery() as nursery: + token_ref = await nursery.start(_token_update_impl, token_period, db_conn, auth_id) + + ws_state = WsListenerState() + (event_tx, event_rx) = open_memory_channel(0) + nursery.start_soon( + ws_listener_impl, + event_stream_factory, WsConnectionEventStream.EXCEPTION_TYPES, ws_state, event_tx, logger.getChild("WS") + ) + nursery.start_soon( + _api_poll_impl, + ws_state, token_ref, thread_id, event_tx, poll_period, 
max_ws_delay, logger.getChild("poll") + ) + + notifier = RedditNotifier(token_ref, logger) + title_provider = RedditThreadTitleProvider(token_ref) + deletion_checker = RedditUpdateChecker(token_ref, logger) + nursery.start_soon( + message_sender_impl, + notifier, title_provider, deletion_checker, ws_state, thread_id, event_rx, buffer_time, mention_limit, + logger.getChild("sender") + ) + + +if __name__ == "__main__": + arg_parser = ArgumentParser(__package__) + arg_parser.add_argument("config_path") + args = arg_parser.parse_args() + + config_parser = ConfigParser() + with open(args.config_path) as config_file: + config_parser.read_file(config_file) + + main_cfg = config_parser["config"] + + auth_id = main_cfg.getint("auth ID") + assert auth_id is not None + + token_period = timedelta(minutes = main_cfg.getint("token update period")) + + thread_id = main_cfg["thread ID"] + + mention_limit = main_cfg.getint("mention limit") + assert mention_limit is not None and mention_limit >= 0 + + buffer_time = timedelta(seconds = main_cfg.getfloat("message buffer time")) + + poll_period = timedelta(seconds = main_cfg.getfloat("API poll period")) + + max_ws_delay = timedelta(seconds = main_cfg.getfloat("max WS delay")) + + db_conn = psycopg2.connect(**config_parser["db connect params"]) + db_conn.autocommit = True + + logger = getLogger(__package__) + logger.setLevel(NOTSET - 1) # filter in handler(s) instead + + handler = StreamHandler(stdout) + handler.setLevel(logging.DEBUG) + handler.setFormatter(Formatter("{asctime:23}: {name:17}: {levelname:8}: {message}", style = "{")) + logger.addHandler(handler) + + trio.run( + _main, + auth_id, buffer_time, db_conn, logger, max_ws_delay, mention_limit, poll_period, thread_id, token_period + ) diff --git a/mentionbot/mentionbot/src/mentionbot/abc.py b/mentionbot/mentionbot/src/mentionbot/abc.py new file mode 100644 index 0000000..3c1d7fd --- /dev/null +++ b/mentionbot/mentionbot/src/mentionbot/abc.py @@ -0,0 +1,50 @@ +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass + + +class Event(metaclass = ABCMeta): + pass + + +@dataclass +class UpdateEvent(Event): + payload_data: dict # object at "payload" -> "data" in the message JSON + + +@dataclass +class DeleteEvent(Event): + payload: str # update name + + +class EventStream(metaclass = ABCMeta): + @abstractmethod + async def get_event(self) -> Event: + raise NotImplementedError() + + +@dataclass +class NotifyInfo: + recipient: str + thread_id: str + thread_title: str + update_id: str + author: str + update_body: str + + +class Notifier(metaclass = ABCMeta): + @abstractmethod + def notify(self, info: NotifyInfo) -> None: + raise NotImplementedError() + + +class ThreadTitleProvider(metaclass = ABCMeta): + @abstractmethod + def title_for(self, thread_id: str) -> str: + raise NotImplementedError() + + +class UpdateChecker(metaclass = ABCMeta): + @abstractmethod + def exists(self, thread_id: str, update_id: str) -> bool: + raise NotImplementedError() diff --git a/mentionbot/mentionbot/src/mentionbot/tests.py b/mentionbot/mentionbot/src/mentionbot/tests.py new file mode 100644 index 0000000..0617c90 --- /dev/null +++ b/mentionbot/mentionbot/src/mentionbot/tests.py @@ -0,0 +1,350 @@ +from contextlib import asynccontextmanager, contextmanager +from dataclasses import dataclass, field +from datetime import timedelta +from logging import getLogger +from numbers import Real +from os import environ +from typing import Iterator, TypeVar +from unittest import TestCase +from urllib.request import Request 
+from uuid import UUID, uuid1 +import logging + +from trio import CancelScope, MemoryReceiveChannel, move_on_after, open_memory_channel, open_nursery, sleep_until +from trio.testing import MockClock, wait_all_tasks_blocked +import trio + +from . import message_sender_impl, TokenRef, try_request, WsListenerState, ws_listener_impl +from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent + + +T = TypeVar("T") + + +@dataclass +class MockNotifier(Notifier): + notifies: list[NotifyInfo] = field(default_factory = list) + + def notify(self, info: NotifyInfo) -> None: + self.notifies.append(info) + + +class MockTitleProvider(ThreadTitleProvider): + def title_for(self, thread_id: str) -> str: + return thread_id + + +class UnusedUpdateChecker(UpdateChecker): + def exists(self, thread_id: str, update_id: str) -> bool: + raise AssertionError("unexpected deletion check") + + +@dataclass +class DictBackedUpdateChecker(UpdateChecker): + exists_by_id: dict[str, bool] = field(default_factory = dict) + queries: list[tuple[str, str]] = field(default_factory = list) + + def exists(self, thread_id: str, update_id: str) -> bool: + self.queries.append((thread_id, update_id)) + return self.exists_by_id[update_id] + + +@dataclass +class ChannelBackedEventStream(EventStream): + event_rx: MemoryReceiveChannel + + async def get_event(self) -> Event: + obj = await self.event_rx.receive() + if isinstance(obj, BaseException): + raise obj + else: + return obj + + +@asynccontextmanager +async def null_async_context(val: T) -> Iterator[T]: + yield val + + +def _update_event(author: str, mentions: list[str]) -> UpdateEvent: + id_ = uuid1() + body = str(id_) + " " + " ".join("/u/" + user for user in mentions) + body_html = str(id_) + " " + " ".join( + f'/u/{user}' + for user in mentions + ) + return UpdateEvent( + { + "id": str(id_), + "name": "LiveUpdate_" + str(id_), + "author": author, + "body": body, + "body_html": body_html, + } + ) + + +def setUpModule() -> None: + getLogger().setLevel(logging.CRITICAL + 1) # disable logging + + +class TaskTests(TestCase): + @contextmanager + def _assert_completes_in(self, duration_seconds: Real) -> Iterator[None]: + with move_on_after(duration_seconds) as scope: + yield + if scope.cancelled_caught: + self.fail("operation timed out") + + def test_buffering(self) -> None: + """ + No API poll task, just a custom task mocking the WS listener by playing events over a channel, and a real + notifier task injected with mocks to capture output. The WS state is hard-coded as up, so no testing of recovery + from WS connection issues is included in this test. + + Primarily it makes sure that updates are passed through the buffer and processed if and only if a delete event + doesn't arrive while they're buffered. 
+ """ + + buffer_time = timedelta(seconds = 1) + notifier = MockNotifier() + (event_tx, event_rx) = open_memory_channel(0) + + ign_event = _update_event("sender", []) + event_a = _update_event("sender", ["rec1", "rec2"]) + event_b = _update_event("sender", ["rec1", "rec2"]) + + async def event_stream_impl(scope: CancelScope) -> None: + await sleep_until(1.0) + await event_tx.send(ign_event) + await event_tx.send(event_a) + + await sleep_until(1.5) + self.assertFalse(notifier.notifies) + await event_tx.send(DeleteEvent(event_a.payload_data["id"])) + + await sleep_until(2.0) + await event_tx.send(event_b) + + # let things settle and terminate both tasks + await sleep_until(4.0) + scope.cancel() + + async def main() -> None: + async with open_nursery() as nursery: + nursery.start_soon(event_stream_impl, nursery.cancel_scope) + nursery.start_soon( + message_sender_impl, + notifier, MockTitleProvider(), UnusedUpdateChecker(), WsListenerState(down = False), "ThreadId", + event_rx, buffer_time, getLogger() + ) + + trio.run(main, clock = MockClock(autojump_threshold = 0)) + + # only notifications sent are for event B + self.assertEqual(len(notifier.notifies), 2) + self.assertEqual({i.recipient for i in notifier.notifies}, {"rec1", "rec2"}) + self.assertTrue(all(i.author == "sender" for i in notifier.notifies)) + self.assertTrue(all(i.update_id == event_b.payload_data["id"] for i in notifier.notifies)) + + def test_ws_trouble(self) -> None: + """ + Ensure that the notifier task responds correctly to the WS connection going down and coming back up. + """ + + thread_id = "ThreadId" + buffer_time = timedelta(seconds = 3) + update_checker = DictBackedUpdateChecker() + notifier = MockNotifier() + (event_tx, event_rx) = open_memory_channel(0) + + # sent while WS is up, checked and notifies sent + event_a = _update_event("sender", ["rec"]) + + # sent and deleted while WS is down, checked and discarded from buffer + event_b = _update_event("sender", ["rec"]) + + # sent on WS reconnect, not checked and notifies sent + event_c = _update_event("sender", ["rec"]) + + # sent after WS reconnect, not checked and notifies sent + event_d = _update_event("sender", ["rec"]) + + ws_state = WsListenerState(down = False) + + async def ws_simulator_impl() -> None: + update_checker.exists_by_id[event_a.payload_data["id"]] = True + await event_tx.send(event_a) + + await sleep_until(1.0) + ws_state.down = True + update_checker.exists_by_id[event_b.payload_data["id"]] = True + await event_tx.send(event_b) + + await sleep_until(2.0) + update_checker.exists_by_id[event_b.payload_data["id"]] = False + ws_state.down = False + ws_state.resync_ts = UUID(event_c.payload_data["id"]).time + await event_tx.send(event_c) + + await sleep_until(3.0) + await event_tx.send(event_d) + + async def main() -> None: + async with open_nursery() as nursery: + nursery.start_soon(ws_simulator_impl) + nursery.start_soon( + message_sender_impl, + notifier, MockTitleProvider(), update_checker, ws_state, "ThreadId", event_rx, buffer_time, + getLogger() + ) + + await sleep_until(3.5) + self.assertEqual([i.update_id for i in notifier.notifies], [event_a.payload_data["id"]]) + notifier.notifies.clear() + self.assertEqual(update_checker.queries, [(thread_id, event_a.payload_data["id"])]) + update_checker.queries.clear() + + await sleep_until(4.5) + self.assertFalse(notifier.notifies) + self.assertEqual(update_checker.queries, [(thread_id, event_b.payload_data["id"])]) + update_checker.queries.clear() + + await sleep_until(5.5) + 
self.assertEqual([i.update_id for i in notifier.notifies], [event_c.payload_data["id"]]) + notifier.notifies.clear() + self.assertFalse(update_checker.queries) + + await sleep_until(6.5) + self.assertEqual([i.update_id for i in notifier.notifies], [event_d.payload_data["id"]]) + notifier.notifies.clear() + self.assertFalse(update_checker.queries) + + nursery.cancel_scope.cancel() + + trio.run(main, clock = MockClock(autojump_threshold = 0)) + + def test_ws_listener_cancel(self) -> None: + ws_state = WsListenerState() + + (event_in_tx, event_in_rx) = open_memory_channel(0) + event_stream = ChannelBackedEventStream(event_in_rx) + event_stream_factory = lambda: null_async_context(event_stream) + + (event_out_tx, event_out_rx) = open_memory_channel(0) + + async def main(): + async with open_nursery() as nursery: + nursery.start_soon( + ws_listener_impl, + event_stream_factory, set(), ws_state, event_out_tx, getLogger() + ) + + event_1 = _update_event("author", []) + await event_in_tx.send(event_1) + with self._assert_completes_in(1): + event_out = await event_out_rx.receive() + self.assertIs(event_out, event_1) + + event_2 = _update_event("author", ["recipient"]) + await event_in_tx.send(event_2) + with self._assert_completes_in(1): + event_out = await event_out_rx.receive() + self.assertIs(event_out, event_2) + + initial_scope = ws_state.scope + initial_scope.cancel() + await wait_all_tasks_blocked() + self.assertIsNotNone(ws_state.scope) + self.assertIsNot(ws_state.scope, initial_scope) + self.assertTrue(ws_state.down) + + event_3 = _update_event("author", []) + ts = UUID(event_3.payload_data["id"]).time + await event_in_tx.send(event_3) + with self._assert_completes_in(1): + event_out = await event_out_rx.receive() + self.assertIs(event_out, event_3) + self.assertEqual(ws_state.last_update_ts, ts) + self.assertFalse(ws_state.down) + self.assertEqual(ws_state.resync_ts, ts) + + nursery.cancel_scope.cancel() + + trio.run(main, clock = MockClock(autojump_threshold = 0)) + + def test_ws_listener_error(self) -> None: + """Test that the WS listener responds correctly to errors from the input event stream.""" + + class MockStreamException(Exception): + pass + + ws_state = WsListenerState() + + (event_in_tx, event_in_rx) = open_memory_channel(0) + event_stream = ChannelBackedEventStream(event_in_rx) + event_stream_factory = lambda: null_async_context(event_stream) + + (event_out_tx, event_out_rx) = open_memory_channel(0) + + async def main(): + async with open_nursery() as nursery: + nursery.start_soon( + ws_listener_impl, + event_stream_factory, {MockStreamException}, ws_state, event_out_tx, getLogger() + ) + + event_1 = _update_event("author", []) + await event_in_tx.send(event_1) + with self._assert_completes_in(1): + event_out = await event_out_rx.receive() + self.assertIs(event_out, event_1) + + initial_scope = ws_state.scope + await event_in_tx.send(MockStreamException()) + await wait_all_tasks_blocked() + self.assertIsNotNone(ws_state.scope) + self.assertIsNot(ws_state.scope, initial_scope) + self.assertTrue(ws_state.down) + + event_2 = _update_event("author", ["recipient"]) + ts = UUID(event_2.payload_data["id"]).time + await event_in_tx.send(event_2) + with self._assert_completes_in(1): + event_out = await event_out_rx.receive() + self.assertIs(event_out, event_2) + self.assertEqual(ws_state.last_update_ts, ts) + self.assertFalse(ws_state.down) + self.assertEqual(ws_state.resync_ts, ts) + + nursery.cancel_scope.cancel() + + trio.run(main, clock = MockClock(autojump_threshold = 0)) + + 
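[Illustrative sketch, not part of the patch: the virtual-clock pattern TaskTests relies on. With trio.testing.MockClock and autojump_threshold = 0, the clock jumps to the next deadline whenever every task is blocked, so the sleep_until checkpoints above complete instantly in real time.]

from trio.testing import MockClock
import trio

async def _demo() -> None:
    start = trio.current_time()
    await trio.sleep(3600)  # returns immediately: the mock clock jumps straight to the deadline
    assert trio.current_time() - start == 3600

trio.run(_demo, clock = MockClock(autojump_threshold = 0))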
+class RedditTests(TestCase): + """Tests for interactions with Reddit.""" + + @classmethod + def setUpClass(cls) -> None: + cls._TOKEN_REF = TokenRef(environ["API_TOKEN"]) + + def test_try_request(self) -> None: + self.assertIsNotNone( + try_request( + Request( + "https://oauth.reddit.com/api/live/happening_now", + headers = self._TOKEN_REF.get_common_headers(), + ), + getLogger(), + ) + ) + + def test_try_request_http_error(self) -> None: + self.assertIsNone( + try_request( + Request("https://oauth.reddit.com/api/live/happening_now"), + getLogger(), + suppress_codes = [403], + ) + ) diff --git a/mentionbot/pyproject.toml b/mentionbot/pyproject.toml deleted file mode 100644 index 8fe2f47..0000000 --- a/mentionbot/pyproject.toml +++ /dev/null @@ -1,3 +0,0 @@ -[build-system] -requires = ["setuptools>=42", "wheel"] -build-backend = "setuptools.build_meta" diff --git a/mentionbot/setup.cfg b/mentionbot/setup.cfg deleted file mode 100644 index b3f9866..0000000 --- a/mentionbot/setup.cfg +++ /dev/null @@ -1,15 +0,0 @@ -[metadata] -name = mentionbot -version = 0.0.0 - -[options] -package_dir = - = src -packages = mentionbot -python_requires = ~= 3.9 -install_requires = - beautifulsoup4 ~= 4.9 - #psycopg2 ~= 2.8 - psycopg2-binary ~= 2.8 - trio == 0.19 - trio-websocket == 0.9.2 diff --git a/mentionbot/setup.py b/mentionbot/setup.py deleted file mode 100644 index 056ba45..0000000 --- a/mentionbot/setup.py +++ /dev/null @@ -1,4 +0,0 @@ -import setuptools - - -setuptools.setup() diff --git a/mentionbot/src/mentionbot/__init__.py b/mentionbot/src/mentionbot/__init__.py deleted file mode 100644 index 76fd7e2..0000000 --- a/mentionbot/src/mentionbot/__init__.py +++ /dev/null @@ -1,187 +0,0 @@ -from dataclasses import dataclass -from datetime import timedelta -from http.client import HTTPResponse -from logging import Logger -from socket import EAI_AGAIN, EAI_FAIL, gaierror -from typing import AsyncContextManager, Callable, Iterable, Optional, Type -from urllib.error import HTTPError -from urllib.request import Request, urlopen -from uuid import UUID -import importlib.metadata -import re - -from bs4 import BeautifulSoup -from trio import CancelScope, MemoryReceiveChannel, MemorySendChannel, move_on_at -import trio - -from .abc import DeleteEvent, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent - - -_USER_AGENT = "any:net.jcornell.mentionbot:v{} (by /u/jaxklax)".format( - importlib.metadata.version(__package__) -) - - -@dataclass -class TokenRef: - """Mutable container for an access token.""" - - token: str - - def get_common_headers(self) -> dict[str, str]: - return { - "User-Agent": _USER_AGENT, - "Authorization": "Bearer {}".format(self.token) - } - - -def try_request( - request: Request, - logger: Logger, - suppress_codes: list[int] = [404, 429, 500, 503], -) -> Optional[HTTPResponse]: - - try: - return urlopen(request) - except gaierror as e: - if e.errno in [EAI_FAIL, EAI_AGAIN]: - logger.error(f"DNS failure: {e}") - return None - else: - raise - except HTTPError as e: - e.read() - e.close() - if e.code in suppress_codes: - logger.error(f"HTTP {e.code}") - return None - else: - raise - - -@dataclass -class WsListenerState: - """The WebSockets listener uses this object to communicate with the API poll and message sender tasks.""" - - # The API poll task uses this scope to force a refresh of the WS connection if it's been too quiet. 
- scope: Optional[CancelScope] = None - - # The WS task uses this to track the UUID timestamp of the most recent update message update, so the API poll task - # can detect when the connection goes silent. - last_update_ts: Optional[int] = None - - # The WS task sets this when the WS connection goes down and clears it when it's recovered. - down: bool = True - - # The WS task sets this to the UUID timestamp of the first update recieved on connection recovery so the message - # sender task knows when to stop checking for deletion. - resync_ts: Optional[int] = None - - -def _parse_mentions(body: BeautifulSoup) -> set[str]: - """Extract the names of users tagged in an update.""" - return { - el["href"].removeprefix("/u/") - for el in body.find_all("a", href = re.compile(r"^/u/[\w-]{3,20}$", re.ASCII)) - } - - -async def ws_listener_impl( - event_stream_factory: Callable[[], AsyncContextManager[EventStream]], - event_stream_catch: Iterable[Type[BaseException]], - state: WsListenerState, - event_tx: MemorySendChannel, # message type `Event' - logger: Logger, -) -> None: - - while True: - state.scope = CancelScope() - with state.scope: - logger.debug("opening a new connection") - try: - async with event_stream_factory() as event_stream: - first_update = True - while True: - event = await event_stream.get_event() - if isinstance(event, UpdateEvent): - update_id = UUID(event.payload_data["id"]) - state.last_update_ts = update_id.time - if first_update: - # first update after discontinuity - state.down = False - state.resync_ts = update_id.time - first_update = False - elif isinstance(event, DeleteEvent): - logger.debug("delete event for " + event.payload) - else: - raise RuntimeError(f"expected Event, got {event}") - - await event_tx.send(event) - except tuple(event_stream_catch) as e: - logger.warning(f"WebSockets error: {e}") - - state.down = True - - -async def message_sender_impl( - notifier: Notifier, - title_provider: ThreadTitleProvider, - update_checker: UpdateChecker, - ws_state: WsListenerState, - thread_id: str, - event_rx: MemoryReceiveChannel, # message type `Event' - buffer_time: timedelta, - mention_limit: int, - logger: Logger, -) -> None: - - @dataclass - class QueueEntry: - payload_data: dict - mentions: set[str] - arrival: float # Trio time - - queue: list[QueueEntry] = [] - - while True: - logger.debug("queue: {}".format([e.payload_data["id"] for e in queue])) - if queue: - scope = move_on_at(queue[0].arrival + buffer_time.total_seconds()) - else: - scope = CancelScope() # no timeout - - with scope: - event = await event_rx.receive() - - if scope.cancelled_caught: - # timed out - entry = queue.pop(0) - update_id = entry.payload_data["id"] - if ws_state.down or ws_state.resync_ts and UUID(update_id).time < ws_state.resync_ts: - notify = update_checker.exists(thread_id, update_id) - else: - notify = True - - if notify: - thread_title = title_provider.title_for(thread_id) - for recipient in entry.mentions: - notifier.notify(NotifyInfo( - recipient, thread_id, thread_title, update_id, entry.payload_data["author"], - entry.payload_data["body"] - )) - else: - # got an event - if isinstance(event, UpdateEvent): - update_id = event.payload_data["id"] - if not any(entry.payload_data["id"] == update_id for entry in queue): - mentions = _parse_mentions(BeautifulSoup(event.payload_data["body_html"], "html.parser")) - if mentions: - if len(mentions) > mention_limit: - logger.info(f"ignoring {update_id} due to too many mentions ({len(mentions)})") - else: - 
queue.append(QueueEntry(event.payload_data, mentions, trio.current_time())) - elif isinstance(event, DeleteEvent): - logger.debug("delete event for " + event.payload) - queue = [entry for entry in queue if entry.payload_data["name"] != event.payload] - else: - raise RuntimeError(f"invalid event type {type(event)}") diff --git a/mentionbot/src/mentionbot/__main__.py b/mentionbot/src/mentionbot/__main__.py deleted file mode 100644 index 960a228..0000000 --- a/mentionbot/src/mentionbot/__main__.py +++ /dev/null @@ -1,291 +0,0 @@ -""" -Mentionbot: Notify users by private message when they're /u/-mentioned in a live thread. -""" - -from argparse import ArgumentParser -from configparser import ConfigParser -from contextlib import asynccontextmanager -from dataclasses import dataclass -from datetime import timedelta -from logging import Formatter, getLogger, Logger, NOTSET, StreamHandler -from sys import stdout -from typing import ClassVar, Iterator, Optional, Type -from urllib.error import HTTPError -from urllib.parse import urlencode -from urllib.request import Request, urlopen -from uuid import UUID -import json -import logging - -from trio import MemorySendChannel, open_memory_channel, open_nursery -from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection -import psycopg2 -import trio - -from . import TokenRef, message_sender_impl, try_request, ws_listener_impl, WsListenerState -from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent - - -async def _token_update_impl(period: timedelta, db_conn, auth_id: int, task_status) -> None: - def fetch_token() -> str: - cursor = db_conn.cursor() - cursor.execute( - """ - select access_token from public.reddit_app_authorization - where id = %s - """, - (auth_id,) - ) - [(token,)] = cursor.fetchall() - return token - - ref = TokenRef(fetch_token()) - task_status.started(ref) - while True: - await trio.sleep(period.total_seconds()) - ref.token = fetch_token() - - -async def _api_poll_impl( - ws_state: WsListenerState, - token_ref: TokenRef, - thread_id: str, - event_tx: MemorySendChannel, # message type `Event' - poll_period: timedelta, - max_ws_delay: timedelta, - logger: Logger, -) -> None: - - base_url = f"https://oauth.reddit.com/live/{thread_id}" - - latest_payload: Optional[dict] = None - while True: - # fetch updates and emit events - params = { - "raw_json": "1", - } - if latest_payload: - updates = [] - while True: - params["before"] = latest_payload["data"]["name"] - url = base_url + "?" + urlencode(params) - resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger) - if resp: - with resp: - children = json.load(resp)["data"]["children"] - if children: - updates.extend(reversed(children)) - latest_payload = children[0] - else: - break - else: - break - else: - params["limit"] = "1" - url = base_url + "?" 
+ urlencode(params) - resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger) - if resp: - with resp: - updates = json.load(resp)["data"]["children"] - if updates: - latest_payload = updates[0] - else: - updates = [] - - for update in updates: - await event_tx.send(UpdateEvent(update["data"])) - - # check if WS connection is behind - if not ws_state.down and latest_payload: - max_delta = max_ws_delay.total_seconds() * 10 ** 7 # in 100*nanosecond - if UUID(latest_payload["data"]["id"]).time - ws_state.last_update_ts > max_delta: - ws_state.scope.cancel() - - await trio.sleep(poll_period.total_seconds()) - - -def _build_notification_md(info: NotifyInfo) -> str: - # TODO escape thread title? - return "\n\n".join([ - f"from [{info.author}](/user/{info.author}) in [{info.thread_title}](/live/{info.thread_id})", - "\n".join(">" + line for line in info.update_body.splitlines()), - ( - f"[^^permalink](/live/{info.thread_id}/updates/{info.update_id})" - + f" [^^context](/live/{info.thread_id}?after=LiveUpdate_{info.update_id})" - + " [^^about ^^these ^^notifications](/r/live_mentions/wiki)" - ), - ]) - - -@dataclass -class RedditNotifier(Notifier): - token_ref: TokenRef - logger: Logger - - def notify(self, info: NotifyInfo) -> None: - params = { - "raw_json": "1", - "api_type": "json", - "subject": "username mention", - "text": _build_notification_md(info), - } - req = Request( - "https://oauth.reddit.com/api/compose", - method = "POST", - data = urlencode({**params, "to": info.recipient}).encode("ascii"), - headers = self.token_ref.get_common_headers(), - ) - resp = try_request(req, self.logger) - if resp: - with resp: - data = json.load(resp)["json"] - if data["errors"]: - [(code, _, _)] = data["errors"] - if code != "USER_DOESNT_EXIST": - raise RuntimeError("PM send error: {code}") - - -@dataclass -class RedditThreadTitleProvider(ThreadTitleProvider): - token_ref: TokenRef - - def title_for(self, thread_id: str) -> str: - req = Request( - f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1", - headers = self.token_ref.get_common_headers() - ) - with urlopen(req) as resp: - return json.load(resp)["data"]["title"] - - -@dataclass -class RedditUpdateChecker(UpdateChecker): - token_ref: TokenRef - logger: Logger - - def exists(self, thread_id: str, update_id: str) -> bool: - req = Request( - f"https://oauth.reddit.com/live/{thread_id}/updates/{update_id}?raw_json=1", - headers = self.token_ref.get_common_headers(), - ) - try: - resp = try_request(req, self.logger, suppress_codes = [429, 500, 503]) - except HTTPError as e: - e.read() - e.close() - if e.code == 404: - return False - else: - raise - else: - if resp: - resp.read() - resp.close() - return True - - -@dataclass -class WsConnectionEventStream(EventStream): - EXCEPTION_TYPES: ClassVar[set[Type[BaseException]]] = {HandshakeError, ConnectionClosed} - - ws_conn: WebSocketConnection - - async def get_event(self) -> Event: - while True: - message = json.loads(await self.ws_conn.get_message()) - if message["type"] == "update": - return UpdateEvent(message["payload"]["data"]) - elif message["type"] == "delete": - return DeleteEvent(message["payload"]) - - -async def _main( - auth_id: int, - buffer_time: timedelta, - db_conn, - logger: Logger, - max_ws_delay: timedelta, - mention_limit: int, - poll_period: timedelta, - thread_id: str, - token_period: timedelta, -) -> None: - - @asynccontextmanager - async def event_stream_factory() -> Iterator[WsConnectionEventStream]: - resp = urlopen(Request( - 
f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1", - headers = token_ref.get_common_headers(), - )) - with resp: - ws_url = json.load(resp)["data"]["websocket_url"] - - async with open_websocket_url(ws_url) as ws_conn: - yield WsConnectionEventStream(ws_conn) - - async with open_nursery() as nursery: - token_ref = await nursery.start(_token_update_impl, token_period, db_conn, auth_id) - - ws_state = WsListenerState() - (event_tx, event_rx) = open_memory_channel(0) - nursery.start_soon( - ws_listener_impl, - event_stream_factory, WsConnectionEventStream.EXCEPTION_TYPES, ws_state, event_tx, logger.getChild("WS") - ) - nursery.start_soon( - _api_poll_impl, - ws_state, token_ref, thread_id, event_tx, poll_period, max_ws_delay, logger.getChild("poll") - ) - - notifier = RedditNotifier(token_ref, logger) - title_provider = RedditThreadTitleProvider(token_ref) - deletion_checker = RedditUpdateChecker(token_ref, logger) - nursery.start_soon( - message_sender_impl, - notifier, title_provider, deletion_checker, ws_state, thread_id, event_rx, buffer_time, mention_limit, - logger.getChild("sender") - ) - - -if __name__ == "__main__": - arg_parser = ArgumentParser(__package__) - arg_parser.add_argument("config_path") - args = arg_parser.parse_args() - - config_parser = ConfigParser() - with open(args.config_path) as config_file: - config_parser.read_file(config_file) - - main_cfg = config_parser["config"] - - auth_id = main_cfg.getint("auth ID") - assert auth_id is not None - - token_period = timedelta(minutes = main_cfg.getint("token update period")) - - thread_id = main_cfg["thread ID"] - - mention_limit = main_cfg.getint("mention limit") - assert mention_limit is not None and mention_limit >= 0 - - buffer_time = timedelta(seconds = main_cfg.getfloat("message buffer time")) - - poll_period = timedelta(seconds = main_cfg.getfloat("API poll period")) - - max_ws_delay = timedelta(seconds = main_cfg.getfloat("max WS delay")) - - db_conn = psycopg2.connect(**config_parser["db connect params"]) - db_conn.autocommit = True - - logger = getLogger(__package__) - logger.setLevel(NOTSET - 1) # filter in handler(s) instead - - handler = StreamHandler(stdout) - handler.setLevel(logging.DEBUG) - handler.setFormatter(Formatter("{asctime:23}: {name:17}: {levelname:8}: {message}", style = "{")) - logger.addHandler(handler) - - trio.run( - _main, - auth_id, buffer_time, db_conn, logger, max_ws_delay, mention_limit, poll_period, thread_id, token_period - ) diff --git a/mentionbot/src/mentionbot/abc.py b/mentionbot/src/mentionbot/abc.py deleted file mode 100644 index 3c1d7fd..0000000 --- a/mentionbot/src/mentionbot/abc.py +++ /dev/null @@ -1,50 +0,0 @@ -from abc import ABCMeta, abstractmethod -from dataclasses import dataclass - - -class Event(metaclass = ABCMeta): - pass - - -@dataclass -class UpdateEvent(Event): - payload_data: dict # object at "payload" -> "data" in the message JSON - - -@dataclass -class DeleteEvent(Event): - payload: str # update name - - -class EventStream(metaclass = ABCMeta): - @abstractmethod - async def get_event(self) -> Event: - raise NotImplementedError() - - -@dataclass -class NotifyInfo: - recipient: str - thread_id: str - thread_title: str - update_id: str - author: str - update_body: str - - -class Notifier(metaclass = ABCMeta): - @abstractmethod - def notify(self, info: NotifyInfo) -> None: - raise NotImplementedError() - - -class ThreadTitleProvider(metaclass = ABCMeta): - @abstractmethod - def title_for(self, thread_id: str) -> str: - raise NotImplementedError() - - 
-class UpdateChecker(metaclass = ABCMeta): - @abstractmethod - def exists(self, thread_id: str, update_id: str) -> bool: - raise NotImplementedError() diff --git a/mentionbot/src/mentionbot/tests.py b/mentionbot/src/mentionbot/tests.py deleted file mode 100644 index 0617c90..0000000 --- a/mentionbot/src/mentionbot/tests.py +++ /dev/null @@ -1,350 +0,0 @@ -from contextlib import asynccontextmanager, contextmanager -from dataclasses import dataclass, field -from datetime import timedelta -from logging import getLogger -from numbers import Real -from os import environ -from typing import Iterator, TypeVar -from unittest import TestCase -from urllib.request import Request -from uuid import UUID, uuid1 -import logging - -from trio import CancelScope, MemoryReceiveChannel, move_on_after, open_memory_channel, open_nursery, sleep_until -from trio.testing import MockClock, wait_all_tasks_blocked -import trio - -from . import message_sender_impl, TokenRef, try_request, WsListenerState, ws_listener_impl -from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent - - -T = TypeVar("T") - - -@dataclass -class MockNotifier(Notifier): - notifies: list[NotifyInfo] = field(default_factory = list) - - def notify(self, info: NotifyInfo) -> None: - self.notifies.append(info) - - -class MockTitleProvider(ThreadTitleProvider): - def title_for(self, thread_id: str) -> str: - return thread_id - - -class UnusedUpdateChecker(UpdateChecker): - def exists(self, thread_id: str, update_id: str) -> bool: - raise AssertionError("unexpected deletion check") - - -@dataclass -class DictBackedUpdateChecker(UpdateChecker): - exists_by_id: dict[str, bool] = field(default_factory = dict) - queries: list[tuple[str, str]] = field(default_factory = list) - - def exists(self, thread_id: str, update_id: str) -> bool: - self.queries.append((thread_id, update_id)) - return self.exists_by_id[update_id] - - -@dataclass -class ChannelBackedEventStream(EventStream): - event_rx: MemoryReceiveChannel - - async def get_event(self) -> Event: - obj = await self.event_rx.receive() - if isinstance(obj, BaseException): - raise obj - else: - return obj - - -@asynccontextmanager -async def null_async_context(val: T) -> Iterator[T]: - yield val - - -def _update_event(author: str, mentions: list[str]) -> UpdateEvent: - id_ = uuid1() - body = str(id_) + " " + " ".join("/u/" + user for user in mentions) - body_html = str(id_) + " " + " ".join( - f'/u/{user}' - for user in mentions - ) - return UpdateEvent( - { - "id": str(id_), - "name": "LiveUpdate_" + str(id_), - "author": author, - "body": body, - "body_html": body_html, - } - ) - - -def setUpModule() -> None: - getLogger().setLevel(logging.CRITICAL + 1) # disable logging - - -class TaskTests(TestCase): - @contextmanager - def _assert_completes_in(self, duration_seconds: Real) -> Iterator[None]: - with move_on_after(duration_seconds) as scope: - yield - if scope.cancelled_caught: - self.fail("operation timed out") - - def test_buffering(self) -> None: - """ - No API poll task, just a custom task mocking the WS listener by playing events over a channel, and a real - notifier task injected with mocks to capture output. The WS state is hard-coded as up, so no testing of recovery - from WS connection issues is included in this test. - - Primarily it makes sure that updates are passed through the buffer and processed if and only if a delete event - doesn't arrive while they're buffered. 
- """ - - buffer_time = timedelta(seconds = 1) - notifier = MockNotifier() - (event_tx, event_rx) = open_memory_channel(0) - - ign_event = _update_event("sender", []) - event_a = _update_event("sender", ["rec1", "rec2"]) - event_b = _update_event("sender", ["rec1", "rec2"]) - - async def event_stream_impl(scope: CancelScope) -> None: - await sleep_until(1.0) - await event_tx.send(ign_event) - await event_tx.send(event_a) - - await sleep_until(1.5) - self.assertFalse(notifier.notifies) - await event_tx.send(DeleteEvent(event_a.payload_data["id"])) - - await sleep_until(2.0) - await event_tx.send(event_b) - - # let things settle and terminate both tasks - await sleep_until(4.0) - scope.cancel() - - async def main() -> None: - async with open_nursery() as nursery: - nursery.start_soon(event_stream_impl, nursery.cancel_scope) - nursery.start_soon( - message_sender_impl, - notifier, MockTitleProvider(), UnusedUpdateChecker(), WsListenerState(down = False), "ThreadId", - event_rx, buffer_time, getLogger() - ) - - trio.run(main, clock = MockClock(autojump_threshold = 0)) - - # only notifications sent are for event B - self.assertEqual(len(notifier.notifies), 2) - self.assertEqual({i.recipient for i in notifier.notifies}, {"rec1", "rec2"}) - self.assertTrue(all(i.author == "sender" for i in notifier.notifies)) - self.assertTrue(all(i.update_id == event_b.payload_data["id"] for i in notifier.notifies)) - - def test_ws_trouble(self) -> None: - """ - Ensure that the notifier task responds correctly to the WS connection going down and coming back up. - """ - - thread_id = "ThreadId" - buffer_time = timedelta(seconds = 3) - update_checker = DictBackedUpdateChecker() - notifier = MockNotifier() - (event_tx, event_rx) = open_memory_channel(0) - - # sent while WS is up, checked and notifies sent - event_a = _update_event("sender", ["rec"]) - - # sent and deleted while WS is down, checked and discarded from buffer - event_b = _update_event("sender", ["rec"]) - - # sent on WS reconnect, not checked and notifies sent - event_c = _update_event("sender", ["rec"]) - - # sent after WS reconnect, not checked and notifies sent - event_d = _update_event("sender", ["rec"]) - - ws_state = WsListenerState(down = False) - - async def ws_simulator_impl() -> None: - update_checker.exists_by_id[event_a.payload_data["id"]] = True - await event_tx.send(event_a) - - await sleep_until(1.0) - ws_state.down = True - update_checker.exists_by_id[event_b.payload_data["id"]] = True - await event_tx.send(event_b) - - await sleep_until(2.0) - update_checker.exists_by_id[event_b.payload_data["id"]] = False - ws_state.down = False - ws_state.resync_ts = UUID(event_c.payload_data["id"]).time - await event_tx.send(event_c) - - await sleep_until(3.0) - await event_tx.send(event_d) - - async def main() -> None: - async with open_nursery() as nursery: - nursery.start_soon(ws_simulator_impl) - nursery.start_soon( - message_sender_impl, - notifier, MockTitleProvider(), update_checker, ws_state, "ThreadId", event_rx, buffer_time, - getLogger() - ) - - await sleep_until(3.5) - self.assertEqual([i.update_id for i in notifier.notifies], [event_a.payload_data["id"]]) - notifier.notifies.clear() - self.assertEqual(update_checker.queries, [(thread_id, event_a.payload_data["id"])]) - update_checker.queries.clear() - - await sleep_until(4.5) - self.assertFalse(notifier.notifies) - self.assertEqual(update_checker.queries, [(thread_id, event_b.payload_data["id"])]) - update_checker.queries.clear() - - await sleep_until(5.5) - 
self.assertEqual([i.update_id for i in notifier.notifies], [event_c.payload_data["id"]]) - notifier.notifies.clear() - self.assertFalse(update_checker.queries) - - await sleep_until(6.5) - self.assertEqual([i.update_id for i in notifier.notifies], [event_d.payload_data["id"]]) - notifier.notifies.clear() - self.assertFalse(update_checker.queries) - - nursery.cancel_scope.cancel() - - trio.run(main, clock = MockClock(autojump_threshold = 0)) - - def test_ws_listener_cancel(self) -> None: - ws_state = WsListenerState() - - (event_in_tx, event_in_rx) = open_memory_channel(0) - event_stream = ChannelBackedEventStream(event_in_rx) - event_stream_factory = lambda: null_async_context(event_stream) - - (event_out_tx, event_out_rx) = open_memory_channel(0) - - async def main(): - async with open_nursery() as nursery: - nursery.start_soon( - ws_listener_impl, - event_stream_factory, set(), ws_state, event_out_tx, getLogger() - ) - - event_1 = _update_event("author", []) - await event_in_tx.send(event_1) - with self._assert_completes_in(1): - event_out = await event_out_rx.receive() - self.assertIs(event_out, event_1) - - event_2 = _update_event("author", ["recipient"]) - await event_in_tx.send(event_2) - with self._assert_completes_in(1): - event_out = await event_out_rx.receive() - self.assertIs(event_out, event_2) - - initial_scope = ws_state.scope - initial_scope.cancel() - await wait_all_tasks_blocked() - self.assertIsNotNone(ws_state.scope) - self.assertIsNot(ws_state.scope, initial_scope) - self.assertTrue(ws_state.down) - - event_3 = _update_event("author", []) - ts = UUID(event_3.payload_data["id"]).time - await event_in_tx.send(event_3) - with self._assert_completes_in(1): - event_out = await event_out_rx.receive() - self.assertIs(event_out, event_3) - self.assertEqual(ws_state.last_update_ts, ts) - self.assertFalse(ws_state.down) - self.assertEqual(ws_state.resync_ts, ts) - - nursery.cancel_scope.cancel() - - trio.run(main, clock = MockClock(autojump_threshold = 0)) - - def test_ws_listener_error(self) -> None: - """Test that the WS listener responds correctly to errors from the input event stream.""" - - class MockStreamException(Exception): - pass - - ws_state = WsListenerState() - - (event_in_tx, event_in_rx) = open_memory_channel(0) - event_stream = ChannelBackedEventStream(event_in_rx) - event_stream_factory = lambda: null_async_context(event_stream) - - (event_out_tx, event_out_rx) = open_memory_channel(0) - - async def main(): - async with open_nursery() as nursery: - nursery.start_soon( - ws_listener_impl, - event_stream_factory, {MockStreamException}, ws_state, event_out_tx, getLogger() - ) - - event_1 = _update_event("author", []) - await event_in_tx.send(event_1) - with self._assert_completes_in(1): - event_out = await event_out_rx.receive() - self.assertIs(event_out, event_1) - - initial_scope = ws_state.scope - await event_in_tx.send(MockStreamException()) - await wait_all_tasks_blocked() - self.assertIsNotNone(ws_state.scope) - self.assertIsNot(ws_state.scope, initial_scope) - self.assertTrue(ws_state.down) - - event_2 = _update_event("author", ["recipient"]) - ts = UUID(event_2.payload_data["id"]).time - await event_in_tx.send(event_2) - with self._assert_completes_in(1): - event_out = await event_out_rx.receive() - self.assertIs(event_out, event_2) - self.assertEqual(ws_state.last_update_ts, ts) - self.assertFalse(ws_state.down) - self.assertEqual(ws_state.resync_ts, ts) - - nursery.cancel_scope.cancel() - - trio.run(main, clock = MockClock(autojump_threshold = 0)) - - 
-class RedditTests(TestCase): - """Tests for interactions with Reddit.""" - - @classmethod - def setUpClass(cls) -> None: - cls._TOKEN_REF = TokenRef(environ["API_TOKEN"]) - - def test_try_request(self) -> None: - self.assertIsNotNone( - try_request( - Request( - "https://oauth.reddit.com/api/live/happening_now", - headers = self._TOKEN_REF.get_common_headers(), - ), - getLogger(), - ) - ) - - def test_try_request_http_error(self) -> None: - self.assertIsNone( - try_request( - Request("https://oauth.reddit.com/api/live/happening_now"), - getLogger(), - suppress_codes = [403], - ) - ) diff --git a/strikebot/docs/sample_config.ini b/strikebot/docs/sample_config.ini deleted file mode 100644 index 34dec07..0000000 --- a/strikebot/docs/sample_config.ini +++ /dev/null @@ -1,65 +0,0 @@ -# see Python `configparser' docs for precise file syntax - -[config] - -######## basic operating parameters - -# space-separated authorization IDs for Reddit API token lookup -auth IDs = 13 15 - -bot user = count_better - -# if false, don't strike or delete updates or report incorrectly stricken updates (optional, default true) -enforcing = true - -# log messages at or above this severity to standard out; level names and numbers are supported (optional, default -# WARNING); see https://docs.python.org/3/library/logging.html#levels -log level = WARNING - -# if true, don't modify the live thread at all (implies enforcing = false) (optional, default false) -read-only = false - -thread ID = abc123 - - -######## API pool error handling - -# (seconds) for error backoff, pause the API pool for this much time -API pool error delay = 5.5 - -# (seconds) pause the pool when the Reddit API returns two errors within any time span of this length -API pool error window = 5.5 - -# terminate if the Reddit API request queue exceeds this size -request queue limit = 100 - - -######## live thread update handling - -# (seconds) maximum time to hold updates for reordering -reorder buffer time = 0.25 - -# (seconds) minimum time to retain updates to enable future resyncs -update retention time = 120.0 - - -######## WebSocket pool - -# (seconds) maximum allowable spread in WebSocket message arrival times -WS parity time = 1.0 - -# goal size of WebSocket connection pool -WS pool size = 5 - -# (seconds) after receiving no messages for this much time, attempt to replace the connection if the pool isn't depleted -WS silent limit = 30 - -# (seconds) time after WebSocket handshake completion during which missed updates are excused -WS warmup time = 0.3 - - -# Postgres database configuration; same options as in a connect string. Note that because Python's `SSLContext' is used -# to implement TLS client auth, using `sslkey' to specify a key obtained from an OpenSSL engine may not be supported (a -# file path is the only option). 
-[db connect params] -host = example.org diff --git a/strikebot/pyproject.toml b/strikebot/pyproject.toml deleted file mode 100644 index 8fe2f47..0000000 --- a/strikebot/pyproject.toml +++ /dev/null @@ -1,3 +0,0 @@ -[build-system] -requires = ["setuptools>=42", "wheel"] -build-backend = "setuptools.build_meta" diff --git a/strikebot/setup.cfg b/strikebot/setup.cfg deleted file mode 100644 index cdd2e66..0000000 --- a/strikebot/setup.cfg +++ /dev/null @@ -1,15 +0,0 @@ -[metadata] -name = strikebot -version = 0.0.7 - -[options] -package_dir = - = src -packages = strikebot -python_requires = ~= 3.9 -install_requires = - asks ~= 2.4 - beautifulsoup4 ~= 4.9 - trio == 0.19 - trio-websocket == 0.9.2 - triopg == 0.6.0 diff --git a/strikebot/setup.py b/strikebot/setup.py deleted file mode 100644 index 056ba45..0000000 --- a/strikebot/setup.py +++ /dev/null @@ -1,4 +0,0 @@ -import setuptools - - -setuptools.setup() diff --git a/strikebot/src/strikebot/__init__.py b/strikebot/src/strikebot/__init__.py deleted file mode 100644 index daa9944..0000000 --- a/strikebot/src/strikebot/__init__.py +++ /dev/null @@ -1,317 +0,0 @@ -"""Count tracking logic and bot's external interface.""" - -from contextlib import nullcontext as nullcontext -from dataclasses import dataclass -from functools import total_ordering -from itertools import islice -from typing import Optional -from uuid import UUID -import bisect -import datetime as dt -import importlib.metadata -import itertools -import logging - -from trio import CancelScope, EndOfChannel -import trio - -from strikebot.updates import Command, parse_update - - -__version__ = importlib.metadata.version(__package__) - - -@dataclass -class _Update: - id: str - name: str - author: str - number: Optional[int] - command: Optional[Command] - ts: dt.datetime - count_attempt: bool - deletable: bool - stricken: bool - - def can_follow(self, prior: "_Update") -> bool: - """Determine whether this count update can follow another count update.""" - return self.number == prior.number + 1 and self.author != prior.author - - def __str__(self): - content = str(self.number) - if self.command: - content += "+" + self.command.name - return "{}({} by {})".format(self.id[4 : 8], content, self.author) - - -@total_ordering -@dataclass -class _BufferedUpdate: - update: _Update - release_at: float # Trio time - - def __lt__(self, other): - return self.update.ts < other.update.ts - - -@total_ordering -@dataclass -class _TimelineUpdate: - update: _Update - accepted: bool - - def __lt__(self, other): - return self.update.ts < other.update.ts - - def rejected(self) -> bool: - """Whether the update should be stricken.""" - return self.update.count_attempt and not self.accepted - - -def _parse_rfc_4122_ts(ts: int) -> dt.datetime: - epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc) - return epoch + dt.timedelta(microseconds = ts / 10) - - -def _format_update_ref(update: _Update, thread_id: str) -> str: - url = f"https://www.reddit.com/live/{thread_id}/updates/{update.id}" - author_md = update.author.replace("_", "\\_") # Reddit allows letters, digits, underscore, and hyphen - return f"[{update.number:,} by {author_md}]({url})" - - -def _format_bad_strike_alert(update: _Update, thread_id: str) -> str: - return "_incorrectly stricken:_ " + _format_update_ref(update, thread_id) - - -def _format_curr_count(last_valid: Optional[_Update]) -> str: - if last_valid and last_valid.number is not None: - count_str = f"{last_valid.number + 1:,}" - else: - count_str = "unknown" - return f"_current 
count:_ {count_str}" - - -async def count_tracker_impl( - message_rx, - api_pool: "strikebot.reddit_api.ApiClientPool", - reorder_buffer_time: dt.timedelta, # duration to hold some updates for reordering - thread_id: str, - bot_user: str, - enforcing: bool, - read_only: bool, - update_retention: dt.timedelta, - logger: logging.Logger, -) -> None: - from strikebot.reddit_api import CorrectionUpdateRequest, DeleteRequest, ReportUpdateRequest, StrikeRequest - - enforcing = enforcing and not read_only - - buffer_ = [] - timeline = [] - last_valid: Optional[_Update] = None - pending_strikes: set[str] = set() # names of updates to mark stricken on arrival - forgotten: bool = False # whether the retention period for an update has passed during the run - - def handle_update(update): - nonlocal last_valid - - tu = _TimelineUpdate(update, accepted = False) - - pos = bisect.bisect(timeline, tu) - if pos != len(timeline): - logger.warning(f"long transpo: {update.id}") - - if pos > 0 and timeline[pos - 1].update.id == update.id: - # The pool sent this update message multiple times. This could be because the message reached parity and - # its bucket was dropped just before a new connection was opened and sent its own copy, but connections also - # seem to send duplicate messages. - return - - pred = next( - (timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted), - None - ) - contender = update.command is Command.RESET or update.number is not None - if contender and pred is None and last_valid and forgotten: - # This is a really long transpo which we can't judge valid or invalid since we've forgotten the surrounding - # updates. The best we can do is just ignore it. - logger.warning(f"ignoring {update.id}: no valid prior count on record") - else: - timeline.insert(pos, tu) - tu.accepted = ( - update.command is Command.RESET - or ( - update.number is not None - and (pred is None or pred.number is None or update.can_follow(pred)) - and not update.stricken - ) - ) - logger.debug("{} {} -> {}".format("accepted" if tu.accepted else "rejected", pred, update)) - if tu.accepted: - # resync subsequent updates - newly_invalid = [] - resync_last_valid = update - converged = False - for scan_tu in islice(timeline, pos + 1, None): - if scan_tu.update.command is Command.RESET: - resync_last_valid = scan_tu.update - if last_valid: - converged = True - elif scan_tu.update.number is not None: - accept = ( - (resync_last_valid.number is None or scan_tu.update.can_follow(resync_last_valid)) - and not scan_tu.update.stricken - ) - if accept: - assert scan_tu.accepted - resync_last_valid = scan_tu.update - # resync would have no effect past this point - if last_valid: - converged = True - elif scan_tu.accepted: - newly_invalid.append(scan_tu) - if scan_tu.update is last_valid: - last_valid = None - scan_tu.accepted = accept - if converged: - break - - if converged: - assert last_valid.ts >= resync_last_valid.ts - else: - last_valid = resync_last_valid - - parts = [] - unstrikable = [tu for tu in newly_invalid if tu.update.stricken] - if unstrikable: - parts.append( - "The following counts are invalid:\n\n" - + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in unstrikable) - ) - - if update.stricken: - logger.info(f"bad strike of {update.id}") - parts.append(_format_bad_strike_alert(update, thread_id)) - - if parts: - parts.append(_format_curr_count(last_valid)) - if not read_only: - api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts))) - - for invalid_tu in 
newly_invalid: - if not invalid_tu.update.stricken: - if enforcing: - api_pool.enqueue_request(StrikeRequest(thread_id, invalid_tu.update.name, invalid_tu.update.ts)) - invalid_tu.update.stricken = True - elif update.count_attempt and not update.stricken: - if enforcing: - api_pool.enqueue_request(StrikeRequest(thread_id, update.name, update.ts)) - update.stricken = True - - if not read_only and update.command is Command.REPORT: - api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid))) - - async with message_rx: - while True: - if buffer_: - deadline = min(bu.release_at for bu in buffer_) - cancel_scope = CancelScope(deadline = deadline) - else: - cancel_scope = nullcontext() - - msg = None - with cancel_scope: - try: - msg = await message_rx.receive() - except EndOfChannel: - break - - if msg: - if msg.data["type"] == "update": - payload_data = msg.data["payload"]["data"] - stricken = payload_data["name"] in pending_strikes - pending_strikes.discard(payload_data["name"]) - if payload_data["author"] != bot_user: - if last_valid and last_valid.number is not None: - next_up = last_valid.number + 1 - else: - next_up = None - pu = parse_update(payload_data, next_up, bot_user) - rfc_ts = UUID(payload_data["id"]).time - update = _Update( - id = payload_data["id"], - name = payload_data["name"], - author = payload_data["author"], - number = pu.number, - command = pu.command, - ts = _parse_rfc_4122_ts(rfc_ts), - count_attempt = pu.count_attempt, - deletable = pu.deletable, - stricken = stricken or payload_data["stricken"], - ) - release_at = msg.timestamp + reorder_buffer_time.total_seconds() - bisect.insort(buffer_, _BufferedUpdate(update, release_at)) - elif msg.data["type"] == "strike": - slot = next( - ( - slot for slot in itertools.chain(buffer_, reversed(timeline)) - if slot.update.name == msg.data["payload"] - ), - None - ) - if slot: - if not slot.update.stricken: - slot.update.stricken = True - if isinstance(slot, _TimelineUpdate) and slot.accepted: - logger.info(f"bad strike of {slot.update.id}") - if not read_only: - body = _format_bad_strike_alert(slot.update, thread_id) - api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body)) - else: - pending_strikes.add(msg.data["payload"]) - - threshold = dt.datetime.now(dt.timezone.utc) - update_retention - - # pull updates from the reorder buffer and process - new_buffer = [] - for bu in buffer_: - process = ( - # not a count - bu.update.number is None - - # valid count - or (last_valid is None or last_valid.number is None or bu.update.can_follow(last_valid)) - or bu.update.command is Command.RESET - - or trio.current_time() >= bu.release_at - or bu.update.command is Command.RESET - ) - - if process: - if bu.update.ts < threshold: - logger.warning(f"ignoring {bu.update}: arrived past retention window") - else: - logger.debug(f"processing {bu.update}") - handle_update(bu.update) - else: - logger.debug(f"holding {bu.update}, checked against {last_valid})") - new_buffer.append(bu) - buffer_ = new_buffer - logger.debug("last count {}".format(last_valid)) - - # delete/forget old updates - new_timeline = [] - i = 0 - for (i, tu) in enumerate(timeline): - if i >= len(timeline) - 10 or tu.update.ts >= threshold: - break - elif tu.rejected() and tu.update.deletable: - if enforcing: - api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts)) - elif tu.update is last_valid: - new_timeline.append(tu) - - if i: - forgotten = True - new_timeline.extend(islice(timeline, i, None)) - 
timeline = new_timeline diff --git a/strikebot/src/strikebot/__main__.py b/strikebot/src/strikebot/__main__.py deleted file mode 100644 index fc80e23..0000000 --- a/strikebot/src/strikebot/__main__.py +++ /dev/null @@ -1,224 +0,0 @@ -from enum import Enum -from inspect import getmodule -from logging import FileHandler, getLogger, StreamHandler -from pathlib import Path -from signal import SIGUSR1 -from sys import stdout -from traceback import StackSummary -from typing import Optional -import argparse -import configparser -import datetime as dt -import logging -import ssl - -from trio import open_memory_channel, open_nursery, open_signal_receiver - -try: - from trio.lowlevel import current_root_task, Task -except ImportError: - from trio.hazmat import current_root_task, Task - -import trio_asyncio -import triopg - -from strikebot import count_tracker_impl -from strikebot.db import Client, Messenger -from strikebot.live_ws import HealingReadPool, PoolMerger -from strikebot.reddit_api import ApiClientPool - - -_DEBUG_LOG_PATH: Optional[Path] = None # file to write debug logs to, if any - - -_TaskKind = Enum( - "_TaskKind", - [ - "API_WORKER", - "DB_CLIENT", - "TOKEN_UPDATE", - "TRACKER", - "WS_MERGE_READER", - "WS_MERGE_TIMER", - "WS_REFRESHER", - "WS_WORKER", - ] -) - - -def _classify_task(task: Task) -> _TaskKind: - mod_name = getmodule(task.coro.cr_code).__name__ - if mod_name.startswith("trio_asyncio."): - return None - elif task.coro.__name__ == "_signal_handler_impl": - return None - else: - by_coro_name = { - "db_client_impl": _TaskKind.DB_CLIENT, - "token_updater_impl": _TaskKind.TOKEN_UPDATE, - "event_reader_impl": _TaskKind.WS_MERGE_READER, - "timeout_handler_impl": _TaskKind.WS_MERGE_TIMER, - "conn_refresher_impl": _TaskKind.WS_REFRESHER, - "count_tracker_impl": _TaskKind.TRACKER, - "_reader_impl": _TaskKind.WS_WORKER, - "worker_impl": _TaskKind.API_WORKER, - } - return by_coro_name[task.coro.__name__] - - -def _get_all_tasks(): - [ta_loop] = [ - task - for nursery in current_root_task().child_nurseries - for task in nursery.child_tasks - if getmodule(task.coro.cr_code).__name__.startswith("trio_asyncio.") - ] - for nursery in ta_loop.child_nurseries: - yield from nursery.child_tasks - - -def _print_task_info(): - groups = {kind: [] for kind in _TaskKind} - for task in _get_all_tasks(): - kind = _classify_task(task) - if kind: - coro = task.coro - frame_tuples = [] - while coro is not None: - if hasattr(coro, "cr_frame"): - frame_tuples.append((coro.cr_frame, coro.cr_frame.f_lineno)) - coro = coro.cr_await - else: - frame_tuples.append((coro.gi_frame, coro.gi_frame.f_lineno)) - coro = coro.gi_yieldfrom - groups[kind].append(StackSummary.extract(iter(frame_tuples))) - for kind in _TaskKind: - print(kind.name) - for ss in groups[kind]: - print(" task") - for line in ss.format(): - print(" " + line, end = "") - - -async def _signal_handler_impl(): - with open_signal_receiver(SIGUSR1) as sig_src: - async for _ in sig_src: - _print_task_info() - - -ap = argparse.ArgumentParser(__package__) -ap.add_argument("config_path") -args = ap.parse_args() - - -parser = configparser.ConfigParser() -with open(args.config_path) as config_file: - parser.read_file(config_file) - -main_cfg = parser["config"] - -api_pool_error_delay = dt.timedelta(seconds = main_cfg.getfloat("API pool error delay")) -api_pool_error_window = dt.timedelta(seconds = main_cfg.getfloat("API pool error window")) -auth_ids = set(map(int, main_cfg["auth IDs"].split())) -bot_user = main_cfg["bot user"] -enforcing = 
main_cfg.getboolean("enforcing", fallback = True) - -raw_log_level = main_cfg.get("log level", "WARNING") -try: - log_level = int(raw_log_level) -except ValueError: - log_level = raw_log_level - -read_only = main_cfg.getboolean("read-only", fallback = False) -reorder_buffer_time = dt.timedelta(seconds = main_cfg.getfloat("reorder buffer time")) -request_queue_limit = main_cfg.getint("request queue limit") -thread_id = main_cfg["thread ID"] -update_retention = dt.timedelta(seconds = main_cfg.getfloat("update retention time")) -ws_parity_time = dt.timedelta(seconds = main_cfg.getfloat("WS parity time")) -ws_pool_size = main_cfg.getint("WS pool size") -ws_silent_limit = dt.timedelta(seconds = main_cfg.getfloat("WS silent limit")) -ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup time")) - - -db_cfg = parser["db connect params"] -cert_path = db_cfg.pop("sslcert", None) -key_path = db_cfg.pop("sslkey", None) -key_pass = db_cfg.pop("sslpassword", None) - -getters = { - "port": db_cfg.getint, -} -db_connect_kwargs = {k: getters.get(k, db_cfg.get)(k) for k in db_cfg} - -if cert_path: - # patch TLS client cert auth support; asyncpg adds it in a later version - db_ssl_ctx = ssl.create_default_context() - db_ssl_ctx.load_cert_chain( - cert_path, - key_path, - None if key_pass is None else lambda: key_pass, - ) - db_connect_kwargs["ssl"] = db_ssl_ctx - - -if read_only and enforcing: - raise RuntimeError("can't use read-only mode with enforcing on") - - -logger = getLogger(__package__) -logger.setLevel(logging.NOTSET - 1) # NOTSET means inherit from parent; we use handlers to filter - -handler = StreamHandler(stdout) -handler.setLevel(log_level) -handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{")) -logger.addHandler(handler) - -if _DEBUG_LOG_PATH: # TODO remove this ad hoc setup - debug_handler = FileHandler(_DEBUG_LOG_PATH, mode = "w") - debug_handler.setLevel(logging.DEBUG) - debug_handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{")) - logger.addHandler(debug_handler) - - -async def main(): - async with ( - triopg.connect(**db_connect_kwargs) as db_conn, - open_nursery() as nursery_a, - open_nursery() as nursery_b, - open_nursery() as ws_pool_nursery, - open_nursery() as nursery_c, - ): - nursery_a.start_soon(_signal_handler_impl) - - client = Client(db_conn) - db_messenger = Messenger(client) - nursery_a.start_soon(db_messenger.db_client_impl) - - api_pool = ApiClientPool( - auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay, - logger.getChild("api") - ) - nursery_a.start_soon(api_pool.token_updater_impl) - for _ in auth_ids: - await nursery_a.start(api_pool.worker_impl) - - (message_tx, message_rx) = open_memory_channel(0) - (pool_event_tx, pool_event_rx) = open_memory_channel(0) - - merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger.getChild("merge")) - nursery_b.start_soon(merger.event_reader_impl) - nursery_b.start_soon(merger.timeout_handler_impl) - - ws_pool = HealingReadPool( - ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger.getChild("live_ws"), - ws_silent_limit - ) - nursery_c.start_soon(ws_pool.conn_refresher_impl) - - nursery_c.start_soon( - count_tracker_impl, message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing, read_only, - update_retention, logger.getChild("track") - ) - await ws_pool.init_workers() - -trio_asyncio.run(main) diff --git 
a/strikebot/src/strikebot/common.py b/strikebot/src/strikebot/common.py deleted file mode 100644 index 519037a..0000000 --- a/strikebot/src/strikebot/common.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Any - - -def int_digest(val: int) -> str: - return "{:04x}".format(val & 0xffff) - - -def obj_digest(obj: Any) -> str: - """Makes a short digest of the identity of an object.""" - return int_digest(id(obj)) diff --git a/strikebot/src/strikebot/db.py b/strikebot/src/strikebot/db.py deleted file mode 100644 index d1a2f42..0000000 --- a/strikebot/src/strikebot/db.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Single-connection Postgres client with high-level wrappers for operations.""" - -from dataclasses import dataclass -from typing import Any, Iterable - -import trio - - -def _channel_sender(method): - async def wrapped(self, resp_channel, *args, **kwargs): - ret = await method(self, *args, **kwargs) - async with resp_channel: - await resp_channel.send(ret) - - return wrapped - - -@dataclass -class Client: - """High-level wrappers for DB operations.""" - conn: Any - - @_channel_sender - async def get_auth_tokens(self, auth_ids: set[int]) -> dict[int, str]: - results = await self.conn.fetch( - """ - select - distinct on (id) - id, access_token - from - public.reddit_app_authorization - join unnest($1::integer[]) as request_id on id = request_id - where expires > current_timestamp - """, - list(auth_ids), - ) - return {r["id"]: r["access_token"] for r in results} - - -class Messenger: - def __init__(self, client: Client): - self._client = client - (self._request_tx, self._request_rx) = trio.open_memory_channel(0) - - async def do(self, method_name: str, args: Iterable) -> Any: - """This is run by consumers of the DB wrapper.""" - method = getattr(self._client, method_name) - (resp_tx, resp_rx) = trio.open_memory_channel(0) - coro = method(resp_tx, *args) - await self._request_tx.send(coro) - async with resp_rx: - return await resp_rx.receive() - - async def db_client_impl(self): - """This is the DB client Trio task.""" - async with self._client.conn, self._request_rx: - async for coro in self._request_rx: - await coro diff --git a/strikebot/src/strikebot/live_ws.py b/strikebot/src/strikebot/live_ws.py deleted file mode 100644 index 031d652..0000000 --- a/strikebot/src/strikebot/live_ws.py +++ /dev/null @@ -1,266 +0,0 @@ -from collections import deque -from contextlib import suppress -from dataclasses import dataclass -from typing import Any, AsyncContextManager, Optional -import datetime as dt -import json -import logging -import math - -from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection -import trio - -from strikebot.common import int_digest, obj_digest -from strikebot.reddit_api import AboutLiveThreadRequest - - -@dataclass -class _PoolEvent: - timestamp: float # Trio time - - -@dataclass -class _Message(_PoolEvent): - data: Any - scope: Any - - def dedup_tag(self): - """A hashable essence of the contents of this message.""" - if self.data["type"] == "update": - return ("update", self.data["payload"]["data"]["id"]) - elif self.data["type"] in {"strike", "delete"}: - return (self.data["type"], self.data["payload"]) - else: - raise ValueError(self.data["type"]) - - -@dataclass -class _ConnectionDown(_PoolEvent): - scope: trio.CancelScope - - -@dataclass -class _ConnectionUp(_PoolEvent): - scope: trio.CancelScope - - -class HealingReadPool: - def __init__( - self, pool_nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, 
logger: logging.Logger, - silent_limit: dt.timedelta - ): - assert size >= 2 - self._nursery = pool_nursery - self._size = size - self._live_thread_id = live_thread_id - self._api_client_pool = api_client_pool - self._pool_event_tx = pool_event_tx - self._logger = logger - self._silent_limit = silent_limit - - (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size) - - # number of workers who have stopped receiving updates but whose tasks aren't yet stopped - self._closing = 0 - - async def init_workers(self): - for _ in range(self._size): - await self._spawn_reader() - if not self._nursery.child_tasks: - raise RuntimeError("Unable to create any WS connections") - - async def _reader_impl(self, conn_ctx: AsyncContextManager, cancel_scope, refresh_tx): - # TODO could use task's implicit cancel scope - try: - async with conn_ctx as conn: - self._logger.debug("scope up: {}".format(obj_digest(cancel_scope))) - async with refresh_tx: - silent_timeout = False - with cancel_scope, suppress(ConnectionClosed): - while True: - with trio.move_on_after(self._silent_limit.total_seconds()) as timeout_scope: - message = await conn.get_message() - - if timeout_scope.cancelled_caught: - if len(self._nursery.child_tasks) - self._closing == self._size: - silent_timeout = True - break - else: - self._logger.debug("not replacing connection {}; {} tasks, {} closing".format( - obj_digest(cancel_scope), - len(self._nursery.child_tasks), - self._closing, - )) - else: - event = _Message(trio.current_time(), json.loads(message), cancel_scope) - await self._pool_event_tx.send(event) - - self._closing += 1 - if silent_timeout: - await conn.aclose() - self._logger.debug("replacing WS connection {} due to silent timeout".format(obj_digest(cancel_scope))) - elif cancel_scope.cancelled_caught: - await conn.aclose(1008, "Server unexpectedly stopped sending messages") - self._logger.warning("replacing WS connection {} due to missed update".format(obj_digest(cancel_scope))) - else: - self._logger.warning("replacing WS connection {} closed by server".format(obj_digest(cancel_scope))) - - refresh_tx.send_nowait(cancel_scope) - self._logger.debug("scope down: {} ({} tasks, {} closing)".format( - obj_digest(cancel_scope), - len(self._nursery.child_tasks), - self._closing, - )) - self._closing -= 1 - except HandshakeError: - self._logger.error("handshake error while opening WS connection") - - async def _new_connection(self) -> Optional[AsyncContextManager[WebSocketConnection]]: - request = AboutLiveThreadRequest(self._live_thread_id) - resp = await self._api_client_pool.make_request(request) - if resp.status_code == 200: - url = json.loads(resp.text)["data"]["websocket_url"] - return open_websocket_url(url) - else: - self._logger.error(f"thread info request failed (HTTP {resp.status_code}); can't open WS connection") - return None - - async def _spawn_reader(self) -> None: - conn = await self._new_connection() - if conn: - new_scope = trio.CancelScope() - await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope)) - refresh_tx = self._refresh_queue_tx.clone() - self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx) - - async def conn_refresher_impl(self): - """Task to monitor and replace WS connections as they disconnect or go silent.""" - async with self._refresh_queue_rx: - async for old_scope in self._refresh_queue_rx: - await self._pool_event_tx.send(_ConnectionDown(trio.current_time(), old_scope)) - await self._spawn_reader() - if not self._nursery.child_tasks: - 
raise RuntimeError("WS pool depleted") - - -def _tag_digest(tag): - return int_digest(hash(tag)) - - -def _format_scope_list(scopes): - return ", ".join(sorted(map(obj_digest, scopes))) - - -class PoolMerger: - @dataclass - class _Bucket: - start: float - recipients: set - - def __init__(self, pool_event_rx, message_tx, parity_timeout: dt.timedelta, conn_warmup: dt.timedelta, logger: logging.Logger): - """ - :param parity_timeout: max delay between message receipt on different connections - :param conn_warmup: max interval after connection within which missed updates are allowed - """ - self._pool_event_rx = pool_event_rx - self._message_tx = message_tx - self._parity_timeout = parity_timeout - self._conn_warmup = conn_warmup - self._logger = logger - - self._buckets = {} - self._scope_activations = {} - self._pending = deque() - self._outgoing_scopes = set() - (self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf) - - def _log_current_buckets(self): - self._logger.debug("current buckets: {}".format(", ".join(map(_tag_digest, self._buckets.keys())))) - - async def event_reader_impl(self): - """Drop unused messages, deduplicate useful ones, and communicate with the timeout handler.""" - async with self._pool_event_rx, self._timer_poke_tx: - async for event in self._pool_event_rx: - if isinstance(event, _ConnectionUp): - # An early add of an active scope could mean it's expected on a message that fired before it opened, - # resulting in a false positive for replacement. The timer task merges these in among the buckets by - # timestamp to avoid that. - self._pending.append(event) - elif isinstance(event, _Message): - if event.data["type"] in ["update", "strike", "delete"]: - tag = event.dedup_tag() - self._logger.debug("recv {} from {}".format(_tag_digest(tag), obj_digest(event.scope))) - if tag in self._buckets: - b = self._buckets[tag] - b.recipients.add(event.scope) - # If this scope is the last one for this bucket we could clear the bucket here, but since - # connections sometimes get repeat messages and second copies can arrive before the first - # copy has arrived on all connections, leaving the bucket open to absorb repeats can reduce - # the likelihood of a second bucket being allocated late for the second copy of a message - # and causing unnecessary connection replacement. - elif event.scope not in self._outgoing_scopes: - sane = ( - event.scope in self._scope_activations - or any(e.scope == event.scope for e in self._pending) - ) - if sane: - self._logger.debug("new bucket {}: {}".format(_tag_digest(tag), tag)) - self._buckets[tag] = self._Bucket(event.timestamp, {event.scope}) - self._log_current_buckets() - await self._message_tx.send(event) - else: - raise RuntimeError("recieved message from unrecognized WS connection") - else: - self._logger.debug("recv type {!r} from {} (discarding)".format(event.data["type"], obj_digest(event.scope))) - elif isinstance(event, _ConnectionDown): - # We don't need to worry about canceling this scope at all, so no need to require it for parity for - # any message, even older ones. The scope may be gone already, if we canceled it previously. 
- self._scope_activations.pop(event.scope, None) - self._pending = deque(ev for ev in self._pending if ev.scope is not event.scope) - self._outgoing_scopes.discard(event.scope) - else: - raise TypeError(f"Expected pool event, found {event!r}") - self._timer_poke_tx.send_nowait(None) # may have new work for the timer - - async def timeout_handler_impl(self): - """When connections have had enough time to reach parity on a message, signal replacement of any slackers.""" - async with self._timer_poke_rx: - while True: - if self._buckets: - now = trio.current_time() - (tag, bucket) = min(self._buckets.items(), key = lambda i: i[1].start) - - # Make sure the right scope set is ready for the next bucket up - while self._pending and self._pending[0].timestamp < bucket.start: - ev = self._pending.popleft() - self._scope_activations[ev.scope] = ev.timestamp - - target_scopes = { - scope for (scope, active) in self._scope_activations.items() - if active + self._conn_warmup.total_seconds() < now - } - if now > bucket.start + self._parity_timeout.total_seconds(): - filled = target_scopes & bucket.recipients - missing = target_scopes - bucket.recipients - extra = bucket.recipients - target_scopes - - self._logger.debug("expiring bucket {}".format(_tag_digest(tag))) - if filled: - self._logger.debug(" filled {}: {}".format(len(filled), _format_scope_list(filled))) - if missing: - self._logger.debug(" missing {}: {}".format(len(missing), _format_scope_list(missing))) - if extra: - self._logger.debug(" extra {}: {}".format(len(extra), _format_scope_list(extra))) - - for scope in missing: - self._outgoing_scopes.add(scope) - scope.cancel() - del self._buckets[tag] - self._log_current_buckets() - else: - await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds()) - else: - try: - await self._timer_poke_rx.receive() - except trio.EndOfChannel: - break diff --git a/strikebot/src/strikebot/queue.py b/strikebot/src/strikebot/queue.py deleted file mode 100644 index acf5ef0..0000000 --- a/strikebot/src/strikebot/queue.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Unbounded blocking queues for Trio""" - -from collections import deque -from dataclasses import dataclass -from functools import total_ordering -from typing import Any, Iterable -import heapq - -try: - from trio.lowlevel import ParkingLot -except ImportError: - from trio.hazmat import ParkingLot - - -class Queue: - def __init__(self): - self._deque = deque() - self._empty_wait = ParkingLot() - - def push(self, el: Any) -> None: - self._deque.append(el) - self._empty_wait.unpark() - - def extend(self, els: Iterable[Any]) -> None: - for el in els: - self.push(el) - - async def pop(self) -> Any: - if not self._deque: - await self._empty_wait.park() - return self._deque.popleft() - - def __len__(self): - return len(self._deque) - - -@total_ordering -@dataclass -class _ReverseOrdWrapper: - inner: Any - - def __lt__(self, other): - return other.inner < self.inner - - -class MaxHeap: - def __init__(self): - self._heap = [] - self._empty_wait = ParkingLot() - - def push(self, item): - heapq.heappush(self._heap, _ReverseOrdWrapper(item)) - self._empty_wait.unpark() - - async def pop(self): - if not self._heap: - await self._empty_wait.park() - return heapq.heappop(self._heap).inner - - def __len__(self) -> int: - return len(self._heap) diff --git a/strikebot/src/strikebot/reddit_api.py b/strikebot/src/strikebot/reddit_api.py deleted file mode 100644 index 2059707..0000000 --- a/strikebot/src/strikebot/reddit_api.py +++ /dev/null @@ -1,291 +0,0 @@ 
-"""Multi-auth HTTP connection pool with Reddit API wrappers and rate limiting.""" - -from abc import ABCMeta, abstractmethod -from dataclasses import dataclass, field -from functools import total_ordering -from socket import EAI_AGAIN, EAI_FAIL, gaierror -import datetime as dt -import logging - -from asks.response_objects import Response -import asks -import trio - -from strikebot import __version__ as VERSION -from strikebot.common import obj_digest -from strikebot.queue import MaxHeap, Queue - - -REQUEST_DELAY = dt.timedelta(seconds = 1) - -TOKEN_UPDATE_DELAY = dt.timedelta(minutes = 15) - -USER_AGENT_FMT = f"any:net.jcornell.strikebot.{{auth_id}}:v{VERSION} (by /u/jaxklax)" - -API_BASE_URL = "https://oauth.reddit.com" - - -@total_ordering -class _Request(metaclass = ABCMeta): - _SUBTYPE_PRECEDENCE = None # assigned later - - def __lt__(self, other): - if type(self) is type(other): - return self._subtype_cmp_key() < other._subtype_cmp_key() - else: - prec = self._SUBTYPE_PRECEDENCE - return prec.index(type(other)) < prec.index(type(self)) - - def __eq__(self, other): - if type(self) is type(other): - return self._subtype_cmp_key() == other._subtype_cmp_key() - else: - return False - - @abstractmethod - def _subtype_cmp_key(self): - # a large key corresponds to a high priority - raise NotImplementedError() - - @abstractmethod - def to_asks_kwargs(self): - raise NotImplementedError() - - -@dataclass(eq = False) -class StrikeRequest(_Request): - thread_id: str - update_name: str - update_ts: dt.datetime - - def to_asks_kwargs(self): - return { - "method": "POST", - "path": f"/api/live/{self.thread_id}/strike_update", - "data": { - "api_type": "json", - "id": self.update_name, - }, - } - - def _subtype_cmp_key(self): - return self.update_ts - - -@dataclass(eq = False) -class DeleteRequest(_Request): - thread_id: str - update_name: str - update_ts: dt.datetime - - def to_asks_kwargs(self): - return { - "method": "POST", - "path": f"/api/live/{self.thread_id}/delete_update", - "data": { - "api_type": "json", - "id": self.update_name, - }, - } - - def _subtype_cmp_key(self): - return -self.update_ts.timestamp() - - -@dataclass(eq = False) -class AboutLiveThreadRequest(_Request): - thread_id: str - created: float = field(default_factory = trio.current_time) - - def to_asks_kwargs(self): - return { - "method": "GET", - "path": f"/live/{self.thread_id}/about", - "params": { - "raw_json": "1", - }, - } - - def _subtype_cmp_key(self): - return -self.created - - -@dataclass(eq = False) -class _BaseUpdatePostRequest(_Request, metaclass = ABCMeta): - thread_id: str - body: str - created: float = field(default_factory = trio.current_time) - - def to_asks_kwargs(self): - return { - "method": "POST", - "path": f"/live/{self.thread_id}/update", - "data": { - "api_type": "json", - "body": self.body, - }, - } - - -@dataclass(eq = False) -class ReportUpdateRequest(_BaseUpdatePostRequest): - def _subtype_cmp_key(self): - return -self.created - - -@dataclass(eq = False) -class CorrectionUpdateRequest(_BaseUpdatePostRequest): - def _subtype_cmp_key(self): - return -self.created - - -# highest priority first -_Request._SUBTYPE_PRECEDENCE = [ - AboutLiveThreadRequest, - StrikeRequest, - CorrectionUpdateRequest, - ReportUpdateRequest, - DeleteRequest, -] - - -@dataclass -class AppCooldown: - auth_id: int - ready_at: float # Trio time - - -class ApiClientPool: - def __init__( - self, - auth_ids: set[int], - db_messenger: "strikebot.db.Messenger", - request_queue_limit: int, - error_window: dt.timedelta, - 
error_delay: dt.timedelta, - logger: logging.Logger, - ): - self._auth_ids = auth_ids - self._db_messenger = db_messenger - self._request_queue_limit = request_queue_limit - self._error_window = error_window - self._error_delay = error_delay - self._logger = logger - - now = trio.current_time() - self._tokens = {} - self._app_queue = Queue() - self._waiting = {id_: AppCooldown(id_, now) for id_ in auth_ids} - self._request_queue = MaxHeap() - - # pool-wide API error backoff - self._last_error = None - self._global_resume = None - - self._session = asks.Session(connections = len(auth_ids)) - self._session.base_location = API_BASE_URL - - async def _update_tokens(self): - tokens = await self._db_messenger.do("get_auth_tokens", (self._auth_ids,)) - self._tokens.update(tokens) - - awaken_auths = self._waiting.keys() & tokens.keys() - self._app_queue.extend(self._waiting.pop(auth_id) for auth_id in awaken_auths) - self._logger.debug("updated API tokens") - - async def token_updater_impl(self): - last_update = None - while True: - if last_update is not None: - await trio.sleep_until(last_update + TOKEN_UPDATE_DELAY.total_seconds()) - - if last_update is None: - last_update = trio.current_time() - else: - last_update += TOKEN_UPDATE_DELAY.total_seconds() - - await self._update_tokens() - - def _check_queue_size(self) -> None: - if len(self._request_queue) > self._request_queue_limit: - raise RuntimeError("request queue size exceeded limit") - - async def make_request(self, request: _Request) -> Response: - (resp_tx, resp_rx) = trio.open_memory_channel(0) - self.enqueue_request(request, resp_tx) - async with resp_rx: - return await resp_rx.receive() - - def enqueue_request(self, request: _Request, resp_tx = None) -> None: - self._logger.debug("request {}: {}".format(obj_digest(request), type(request).__name__)) - self._request_queue.push((request, resp_tx)) - self._check_queue_size() - self._logger.debug(f"{len(self._request_queue)} requests in queue") - - async def worker_impl(self, task_status): - task_status.started() - while True: - (request, resp_tx) = await self._request_queue.pop() - cooldown = await self._app_queue.pop() - await trio.sleep_until(cooldown.ready_at) - if self._global_resume: - await trio.sleep_until(self._global_resume) - - asks_kwargs = request.to_asks_kwargs() - headers = asks_kwargs.setdefault("headers", {}) - headers.update({ - "Authorization": "Bearer {}".format(self._tokens[cooldown.auth_id]), - "User-Agent": USER_AGENT_FMT.format(auth_id = cooldown.auth_id), - }) - - request_time = trio.current_time() - try: - resp = await self._session.request(**asks_kwargs) - except gaierror as e: - if e.errno in [EAI_FAIL, EAI_AGAIN]: - # DNS failure, probably temporary - error = True - else: - raise - else: - resp.body # read response - error = False - wait_for_token = False - log_suffix = " (request {})".format(obj_digest(request)) - if resp.status_code == 429: - # We disagreed about the rate limit state; just try again later. - self._logger.warning("rate limited by Reddit API" + log_suffix) - error = True - elif resp.status_code == 401: - self._logger.warning("got HTTP 401 from Reddit API" + log_suffix) - error = True - wait_for_token = True - elif resp.status_code in [404, 500, 503]: - self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying" + log_suffix) - error = True - elif 400 <= resp.status_code < 500: - # If we're doing something wrong, let's catch it right away. 
- raise RuntimeError(f"unexpected client error response: {resp.status_code}" + log_suffix) - else: - if resp.status_code != 200: - raise RuntimeError(f"unexpected status code {resp.status_code}") - self._logger.debug("success" + log_suffix) - if resp_tx: - await resp_tx.send(resp) - - if error: - self._request_queue.push((request, resp_tx)) - self._check_queue_size() - if self._last_error: - spread = dt.timedelta(seconds = request_time - self._last_error) - if spread <= self._error_window: - self._global_resume = request_time + self._error_delay.total_seconds() - self._last_error = request_time - - cooldown.ready_at = request_time + REQUEST_DELAY.total_seconds() - if wait_for_token: - self._waiting[cooldown.auth_id] = cooldown - if not self._app_queue: - self._logger.error("all workers waiting for API tokens") - else: - self._app_queue.push(cooldown) diff --git a/strikebot/src/strikebot/tests.py b/strikebot/src/strikebot/tests.py deleted file mode 100644 index 966bca6..0000000 --- a/strikebot/src/strikebot/tests.py +++ /dev/null @@ -1,62 +0,0 @@ -from unittest import TestCase - -from strikebot.updates import parse_update - - -def _build_payload(body_html: str) -> dict[str, str]: - return {"body_html": body_html} - - -class UpdateParsingTests(TestCase): - def test_successful_counts(self): - pu = parse_update(_build_payload("
12,345,678 spaghetti
"), None, "") - self.assertEqual(pu.number, 12345678) - self.assertFalse(pu.deletable) - - pu = parse_update(_build_payload("

0


oz
"), None, "") - self.assertEqual(pu.number, 0) - self.assertFalse(pu.deletable) - - pu = parse_update(_build_payload("
121 345 621
"), 121, "") - self.assertEqual(pu.number, 121) - self.assertFalse(pu.deletable) - - pu = parse_update(_build_payload("28 336 816"), 28336816, "") - self.assertEqual(pu.number, 28336816) - self.assertTrue(pu.deletable) - - def test_non_counts(self): - pu = parse_update(_build_payload("
zoo
"), None, "") - self.assertFalse(pu.count_attempt) - self.assertFalse(pu.deletable) - - def test_typos(self): - pu = parse_update(_build_payload("v9"), 888, "") - self.assertIsNone(pu.number) - self.assertTrue(pu.count_attempt) - - pu = parse_update(_build_payload("
v11.585 Empire
"), None, "") - self.assertIsNone(pu.number) - self.assertTrue(pu.count_attempt) - self.assertFalse(pu.deletable) - - pu = parse_update(_build_payload("
11, 585, 22
"), 11_585_202, "") - self.assertIsNone(pu.number) - self.assertTrue(pu.count_attempt) - self.assertTrue(pu.deletable) - - pu = parse_update(_build_payload("0490499"), 4999, "") - self.assertIsNone(pu.number) - self.assertTrue(pu.count_attempt) - - # this update is marked non-deletable, but I'm not sure it should be - pu = parse_update(_build_payload("12,123123"), 12_123_123, "") - self.assertIsNone(pu.number) - self.assertTrue(pu.count_attempt) - - def test_html_handling(self): - pu = parse_update(_build_payload("123
,456"), None, "") - self.assertEqual(pu.number, 123) - - pu = parse_update(_build_payload("
123\n456
"), None, "") - self.assertEqual(pu.number, 123) diff --git a/strikebot/src/strikebot/updates.py b/strikebot/src/strikebot/updates.py deleted file mode 100644 index 01eb67b..0000000 --- a/strikebot/src/strikebot/updates.py +++ /dev/null @@ -1,226 +0,0 @@ -from __future__ import annotations -from dataclasses import dataclass -from enum import Enum -from typing import Optional -import re - -from bs4 import BeautifulSoup - - -Command = Enum("Command", ["RESET", "REPORT"]) - - -@dataclass -class ParsedUpdate: - number: Optional[int] - command: Optional[Command] - count_attempt: bool # either well-formed or typo - deletable: bool - - -def _parse_command(line: str, bot_user: str) -> Optional[Command]: - if line.lower() == f"/u/{bot_user} reset".lower(): - return Command.RESET - elif line.lower() in ["sidebar count", "current count"]: - return Command.REPORT - else: - return None - - -def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -> ParsedUpdate: - # curr_count is the next number up, one more than the last count - - NEW_LINE = object() - SPACE = object() - - # flatten the update content to plain text - tree = BeautifulSoup(payload_data["body_html"], "html.parser") - worklist = list(reversed(tree.contents)) - out = [[]] - while worklist: - el = worklist.pop() - if isinstance(el, str): - out[-1].append(el) - elif el is SPACE: - out[-1].append(el) - elif el is NEW_LINE or el.name in ["br", "hr"]: - if out[-1]: - out.append([]) - elif el.name in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]: - worklist.extend(reversed(el.contents)) - elif el.name in ["ul", "ol", "table", "thead", "tbody"]: - worklist.extend(reversed(el.contents)) - elif el.name in ["li", "p", "div", "blockquote"] or re.match(r"h[1-6]$", el.name): - worklist.append(NEW_LINE) - worklist.extend(reversed(el.contents)) - worklist.append(NEW_LINE) - elif el.name == "pre": - out.extend([l] for l in el.text.splitlines()) - out.append([]) - elif el.name == "tr": - worklist.append(NEW_LINE) - for (i, cell) in enumerate(reversed(el.contents)): - worklist.append(cell) - if i != len(el.contents) - 1: - worklist.append(SPACE) - worklist.append(NEW_LINE) - else: - raise RuntimeError(f"can't parse tag {el.name}") - - tmp_lines = ( - "".join(" " if part is SPACE else part for part in parts).strip() - for parts in out - ) - pre_strip_lines = list(filter(None, tmp_lines)) - - # normalize whitespace according to HTML rendering rules - # https://developer.mozilla.org/en-US/docs/Web/API/Document_Object_Model/Whitespace#explanation - stripped_lines = [ - re.sub(" +", " ", l.replace("\t", " ").replace("\n", " ")).strip(" ") - for l in pre_strip_lines - ] - - return _parse_from_lines(stripped_lines, curr_count, bot_user) - - -def _parse_from_lines(lines: list[str], curr_count: Optional[int], bot_user: str) -> ParsedUpdate: - command = next( - (cmd for l in lines if (cmd := _parse_command(l, bot_user))), - None - ) - if lines: - # look for groups of digits (as many as possible) separated by a uniform separator from the valid set - first = lines[0] - match = re.match( - "(?Pv)?(?P-)?(?P\\d+((?P[,. 
\u2009]|, )\\d+((?P=sep)\\d+)*)?)", - first, - re.ASCII, # only recognize ASCII digits - ) - if match: - raw_digits = match["num"] - sep = match["sep"] - post = first[match.end() :] - - zeros = False - while len(raw_digits) > 1 and raw_digits[0] == "0": - zeros = True - raw_digits = raw_digits.removeprefix("0").removeprefix(sep or "") - - parts = raw_digits.split(sep) if sep else [raw_digits] - all_parts_valid = ( - sep is None - or ( - 1 <= len(parts[0]) <= 3 - and all(len(p) == 3 for p in parts[1:]) - ) - ) - lone = len(lines) == 1 and (not post or post.isspace()) - typo = False - if lone: - if match["v"] and len(parts) == 1 and len(parts[0]) <= 2: - # failed paste of leading digits - typo = True - elif match["v"] and all_parts_valid: - # v followed by count - typo = True - elif curr_count is not None and abs(curr_count) >= 100 and bool(match["neg"]) == (curr_count < 0): - goal_parts = _separate(str(abs(curr_count))) - partials = [ - goal_parts[: -1] + [goal_parts[-1][: -1]], # missing last digit - goal_parts[: -1] + [goal_parts[-1][: -2]], # missing last two digits - goal_parts[: -1] + [goal_parts[-1][: -2] + goal_parts[-1][-1]], # missing second-last digit - ] - if parts in partials: - # missing any of last two digits - typo = True - elif any(parts == p[: -1] + [p[-1] + goal_parts[0]] + goal_parts[1 :] for p in partials): - # double paste - typo = True - - if match["v"] or zeros or typo or (parts == ["0"] and match["neg"]): - number = None - count_attempt = True - deletable = lone - else: - groups_okay = True - if curr_count is not None and sep and sep.isspace(): - # Presume that the intended count consists of as many valid digit groups as necessary to match the - # number of digits in the expected count, if possible. - digit_count = len(str(abs(curr_count))) - use_parts = [] - accum = 0 - for (i, part) in enumerate(parts): - part_valid = len(part) <= 3 if i == 0 else len(part) == 3 - if part_valid and accum < digit_count: - use_parts.append(part) - accum += len(part) - else: - break - - # could still be a no-separator count with some extra digit groups on the same line - if not use_parts: - use_parts = [parts[0]] - - lone = lone and len(use_parts) == len(parts) - else: - # current count is unknown or no separator was used - groups_okay = all_parts_valid - use_parts = parts - - digits = "".join(use_parts) - number = -int(digits) if match["neg"] else int(digits) - special = ( - curr_count is not None - and abs(number - curr_count) <= 25 - and _is_special_number(number) - ) - if groups_okay: - deletable = lone and not special - if len(use_parts) == len(parts) and post and not post[0].isspace(): - count_attempt = curr_count is not None and abs(number - curr_count) <= 25 - number = None - else: - count_attempt = True - else: - number = None - count_attempt = True - deletable = False - - if first[0].isdigit() and not count_attempt: - count_attempt = True - deletable = False - else: - # no count attempt found - number = None - count_attempt = False - deletable = False - else: - # no lines in update - number = None - count_attempt = False - deletable = True - - return ParsedUpdate( - number = number, - command = command, - count_attempt = count_attempt, - deletable = deletable, - ) - - -def _separate(digits: str) -> list[str]: - mod = len(digits) % 3 - out = [] - if mod: - out.append(digits[: mod]) - out.extend(digits[i : i + 3] for i in range(mod, len(digits), 3)) - return out - - -def _is_special_number(num: int) -> bool: - num_str = str(num) - return bool( - num % 1000 in [0, 1, 333, 999] 
- or (num > 10_000_000 and "".join(reversed(num_str)) == num_str) - or re.match(r"(.+)\1+$", num_str) # repeated sequence - ) diff --git a/strikebot/strikebot/docs/sample_config.ini b/strikebot/strikebot/docs/sample_config.ini new file mode 100644 index 0000000..34dec07 --- /dev/null +++ b/strikebot/strikebot/docs/sample_config.ini @@ -0,0 +1,65 @@ +# see Python `configparser' docs for precise file syntax + +[config] + +######## basic operating parameters + +# space-separated authorization IDs for Reddit API token lookup +auth IDs = 13 15 + +bot user = count_better + +# if false, don't strike or delete updates or report incorrectly stricken updates (optional, default true) +enforcing = true + +# log messages at or above this severity to standard out; level names and numbers are supported (optional, default +# WARNING); see https://docs.python.org/3/library/logging.html#levels +log level = WARNING + +# if true, don't modify the live thread at all (implies enforcing = false) (optional, default false) +read-only = false + +thread ID = abc123 + + +######## API pool error handling + +# (seconds) for error backoff, pause the API pool for this much time +API pool error delay = 5.5 + +# (seconds) pause the pool when the Reddit API returns two errors within any time span of this length +API pool error window = 5.5 + +# terminate if the Reddit API request queue exceeds this size +request queue limit = 100 + + +######## live thread update handling + +# (seconds) maximum time to hold updates for reordering +reorder buffer time = 0.25 + +# (seconds) minimum time to retain updates to enable future resyncs +update retention time = 120.0 + + +######## WebSocket pool + +# (seconds) maximum allowable spread in WebSocket message arrival times +WS parity time = 1.0 + +# goal size of WebSocket connection pool +WS pool size = 5 + +# (seconds) after receiving no messages for this much time, attempt to replace the connection if the pool isn't depleted +WS silent limit = 30 + +# (seconds) time after WebSocket handshake completion during which missed updates are excused +WS warmup time = 0.3 + + +# Postgres database configuration; same options as in a connect string. Note that because Python's `SSLContext' is used +# to implement TLS client auth, using `sslkey' to specify a key obtained from an OpenSSL engine may not be supported (a +# file path is the only option). 
+[db connect params] +host = example.org diff --git a/strikebot/strikebot/pyproject.toml b/strikebot/strikebot/pyproject.toml new file mode 100644 index 0000000..8fe2f47 --- /dev/null +++ b/strikebot/strikebot/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/strikebot/strikebot/setup.cfg b/strikebot/strikebot/setup.cfg new file mode 100644 index 0000000..cdd2e66 --- /dev/null +++ b/strikebot/strikebot/setup.cfg @@ -0,0 +1,15 @@ +[metadata] +name = strikebot +version = 0.0.7 + +[options] +package_dir = + = src +packages = strikebot +python_requires = ~= 3.9 +install_requires = + asks ~= 2.4 + beautifulsoup4 ~= 4.9 + trio == 0.19 + trio-websocket == 0.9.2 + triopg == 0.6.0 diff --git a/strikebot/strikebot/setup.py b/strikebot/strikebot/setup.py new file mode 100644 index 0000000..056ba45 --- /dev/null +++ b/strikebot/strikebot/setup.py @@ -0,0 +1,4 @@ +import setuptools + + +setuptools.setup() diff --git a/strikebot/strikebot/src/strikebot/__init__.py b/strikebot/strikebot/src/strikebot/__init__.py new file mode 100644 index 0000000..daa9944 --- /dev/null +++ b/strikebot/strikebot/src/strikebot/__init__.py @@ -0,0 +1,317 @@ +"""Count tracking logic and bot's external interface.""" + +from contextlib import nullcontext as nullcontext +from dataclasses import dataclass +from functools import total_ordering +from itertools import islice +from typing import Optional +from uuid import UUID +import bisect +import datetime as dt +import importlib.metadata +import itertools +import logging + +from trio import CancelScope, EndOfChannel +import trio + +from strikebot.updates import Command, parse_update + + +__version__ = importlib.metadata.version(__package__) + + +@dataclass +class _Update: + id: str + name: str + author: str + number: Optional[int] + command: Optional[Command] + ts: dt.datetime + count_attempt: bool + deletable: bool + stricken: bool + + def can_follow(self, prior: "_Update") -> bool: + """Determine whether this count update can follow another count update.""" + return self.number == prior.number + 1 and self.author != prior.author + + def __str__(self): + content = str(self.number) + if self.command: + content += "+" + self.command.name + return "{}({} by {})".format(self.id[4 : 8], content, self.author) + + +@total_ordering +@dataclass +class _BufferedUpdate: + update: _Update + release_at: float # Trio time + + def __lt__(self, other): + return self.update.ts < other.update.ts + + +@total_ordering +@dataclass +class _TimelineUpdate: + update: _Update + accepted: bool + + def __lt__(self, other): + return self.update.ts < other.update.ts + + def rejected(self) -> bool: + """Whether the update should be stricken.""" + return self.update.count_attempt and not self.accepted + + +def _parse_rfc_4122_ts(ts: int) -> dt.datetime: + epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc) + return epoch + dt.timedelta(microseconds = ts / 10) + + +def _format_update_ref(update: _Update, thread_id: str) -> str: + url = f"https://www.reddit.com/live/{thread_id}/updates/{update.id}" + author_md = update.author.replace("_", "\\_") # Reddit allows letters, digits, underscore, and hyphen + return f"[{update.number:,} by {author_md}]({url})" + + +def _format_bad_strike_alert(update: _Update, thread_id: str) -> str: + return "_incorrectly stricken:_ " + _format_update_ref(update, thread_id) + + +def _format_curr_count(last_valid: Optional[_Update]) -> str: + if last_valid and last_valid.number is not 
None: + count_str = f"{last_valid.number + 1:,}" + else: + count_str = "unknown" + return f"_current count:_ {count_str}" + + +async def count_tracker_impl( + message_rx, + api_pool: "strikebot.reddit_api.ApiClientPool", + reorder_buffer_time: dt.timedelta, # duration to hold some updates for reordering + thread_id: str, + bot_user: str, + enforcing: bool, + read_only: bool, + update_retention: dt.timedelta, + logger: logging.Logger, +) -> None: + from strikebot.reddit_api import CorrectionUpdateRequest, DeleteRequest, ReportUpdateRequest, StrikeRequest + + enforcing = enforcing and not read_only + + buffer_ = [] + timeline = [] + last_valid: Optional[_Update] = None + pending_strikes: set[str] = set() # names of updates to mark stricken on arrival + forgotten: bool = False # whether the retention period for an update has passed during the run + + def handle_update(update): + nonlocal last_valid + + tu = _TimelineUpdate(update, accepted = False) + + pos = bisect.bisect(timeline, tu) + if pos != len(timeline): + logger.warning(f"long transpo: {update.id}") + + if pos > 0 and timeline[pos - 1].update.id == update.id: + # The pool sent this update message multiple times. This could be because the message reached parity and + # its bucket was dropped just before a new connection was opened and sent its own copy, but connections also + # seem to send duplicate messages. + return + + pred = next( + (timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted), + None + ) + contender = update.command is Command.RESET or update.number is not None + if contender and pred is None and last_valid and forgotten: + # This is a really long transpo which we can't judge valid or invalid since we've forgotten the surrounding + # updates. The best we can do is just ignore it. 
+ logger.warning(f"ignoring {update.id}: no valid prior count on record") + else: + timeline.insert(pos, tu) + tu.accepted = ( + update.command is Command.RESET + or ( + update.number is not None + and (pred is None or pred.number is None or update.can_follow(pred)) + and not update.stricken + ) + ) + logger.debug("{} {} -> {}".format("accepted" if tu.accepted else "rejected", pred, update)) + if tu.accepted: + # resync subsequent updates + newly_invalid = [] + resync_last_valid = update + converged = False + for scan_tu in islice(timeline, pos + 1, None): + if scan_tu.update.command is Command.RESET: + resync_last_valid = scan_tu.update + if last_valid: + converged = True + elif scan_tu.update.number is not None: + accept = ( + (resync_last_valid.number is None or scan_tu.update.can_follow(resync_last_valid)) + and not scan_tu.update.stricken + ) + if accept: + assert scan_tu.accepted + resync_last_valid = scan_tu.update + # resync would have no effect past this point + if last_valid: + converged = True + elif scan_tu.accepted: + newly_invalid.append(scan_tu) + if scan_tu.update is last_valid: + last_valid = None + scan_tu.accepted = accept + if converged: + break + + if converged: + assert last_valid.ts >= resync_last_valid.ts + else: + last_valid = resync_last_valid + + parts = [] + unstrikable = [tu for tu in newly_invalid if tu.update.stricken] + if unstrikable: + parts.append( + "The following counts are invalid:\n\n" + + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in unstrikable) + ) + + if update.stricken: + logger.info(f"bad strike of {update.id}") + parts.append(_format_bad_strike_alert(update, thread_id)) + + if parts: + parts.append(_format_curr_count(last_valid)) + if not read_only: + api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts))) + + for invalid_tu in newly_invalid: + if not invalid_tu.update.stricken: + if enforcing: + api_pool.enqueue_request(StrikeRequest(thread_id, invalid_tu.update.name, invalid_tu.update.ts)) + invalid_tu.update.stricken = True + elif update.count_attempt and not update.stricken: + if enforcing: + api_pool.enqueue_request(StrikeRequest(thread_id, update.name, update.ts)) + update.stricken = True + + if not read_only and update.command is Command.REPORT: + api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid))) + + async with message_rx: + while True: + if buffer_: + deadline = min(bu.release_at for bu in buffer_) + cancel_scope = CancelScope(deadline = deadline) + else: + cancel_scope = nullcontext() + + msg = None + with cancel_scope: + try: + msg = await message_rx.receive() + except EndOfChannel: + break + + if msg: + if msg.data["type"] == "update": + payload_data = msg.data["payload"]["data"] + stricken = payload_data["name"] in pending_strikes + pending_strikes.discard(payload_data["name"]) + if payload_data["author"] != bot_user: + if last_valid and last_valid.number is not None: + next_up = last_valid.number + 1 + else: + next_up = None + pu = parse_update(payload_data, next_up, bot_user) + rfc_ts = UUID(payload_data["id"]).time + update = _Update( + id = payload_data["id"], + name = payload_data["name"], + author = payload_data["author"], + number = pu.number, + command = pu.command, + ts = _parse_rfc_4122_ts(rfc_ts), + count_attempt = pu.count_attempt, + deletable = pu.deletable, + stricken = stricken or payload_data["stricken"], + ) + release_at = msg.timestamp + reorder_buffer_time.total_seconds() + bisect.insort(buffer_, _BufferedUpdate(update, 
release_at)) + elif msg.data["type"] == "strike": + slot = next( + ( + slot for slot in itertools.chain(buffer_, reversed(timeline)) + if slot.update.name == msg.data["payload"] + ), + None + ) + if slot: + if not slot.update.stricken: + slot.update.stricken = True + if isinstance(slot, _TimelineUpdate) and slot.accepted: + logger.info(f"bad strike of {slot.update.id}") + if not read_only: + body = _format_bad_strike_alert(slot.update, thread_id) + api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body)) + else: + pending_strikes.add(msg.data["payload"]) + + threshold = dt.datetime.now(dt.timezone.utc) - update_retention + + # pull updates from the reorder buffer and process + new_buffer = [] + for bu in buffer_: + process = ( + # not a count + bu.update.number is None + + # valid count + or (last_valid is None or last_valid.number is None or bu.update.can_follow(last_valid)) + or bu.update.command is Command.RESET + + or trio.current_time() >= bu.release_at + or bu.update.command is Command.RESET + ) + + if process: + if bu.update.ts < threshold: + logger.warning(f"ignoring {bu.update}: arrived past retention window") + else: + logger.debug(f"processing {bu.update}") + handle_update(bu.update) + else: + logger.debug(f"holding {bu.update}, checked against {last_valid})") + new_buffer.append(bu) + buffer_ = new_buffer + logger.debug("last count {}".format(last_valid)) + + # delete/forget old updates + new_timeline = [] + i = 0 + for (i, tu) in enumerate(timeline): + if i >= len(timeline) - 10 or tu.update.ts >= threshold: + break + elif tu.rejected() and tu.update.deletable: + if enforcing: + api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts)) + elif tu.update is last_valid: + new_timeline.append(tu) + + if i: + forgotten = True + new_timeline.extend(islice(timeline, i, None)) + timeline = new_timeline diff --git a/strikebot/strikebot/src/strikebot/__main__.py b/strikebot/strikebot/src/strikebot/__main__.py new file mode 100644 index 0000000..fc80e23 --- /dev/null +++ b/strikebot/strikebot/src/strikebot/__main__.py @@ -0,0 +1,224 @@ +from enum import Enum +from inspect import getmodule +from logging import FileHandler, getLogger, StreamHandler +from pathlib import Path +from signal import SIGUSR1 +from sys import stdout +from traceback import StackSummary +from typing import Optional +import argparse +import configparser +import datetime as dt +import logging +import ssl + +from trio import open_memory_channel, open_nursery, open_signal_receiver + +try: + from trio.lowlevel import current_root_task, Task +except ImportError: + from trio.hazmat import current_root_task, Task + +import trio_asyncio +import triopg + +from strikebot import count_tracker_impl +from strikebot.db import Client, Messenger +from strikebot.live_ws import HealingReadPool, PoolMerger +from strikebot.reddit_api import ApiClientPool + + +_DEBUG_LOG_PATH: Optional[Path] = None # file to write debug logs to, if any + + +_TaskKind = Enum( + "_TaskKind", + [ + "API_WORKER", + "DB_CLIENT", + "TOKEN_UPDATE", + "TRACKER", + "WS_MERGE_READER", + "WS_MERGE_TIMER", + "WS_REFRESHER", + "WS_WORKER", + ] +) + + +def _classify_task(task: Task) -> _TaskKind: + mod_name = getmodule(task.coro.cr_code).__name__ + if mod_name.startswith("trio_asyncio."): + return None + elif task.coro.__name__ == "_signal_handler_impl": + return None + else: + by_coro_name = { + "db_client_impl": _TaskKind.DB_CLIENT, + "token_updater_impl": _TaskKind.TOKEN_UPDATE, + "event_reader_impl": _TaskKind.WS_MERGE_READER, + 
"timeout_handler_impl": _TaskKind.WS_MERGE_TIMER, + "conn_refresher_impl": _TaskKind.WS_REFRESHER, + "count_tracker_impl": _TaskKind.TRACKER, + "_reader_impl": _TaskKind.WS_WORKER, + "worker_impl": _TaskKind.API_WORKER, + } + return by_coro_name[task.coro.__name__] + + +def _get_all_tasks(): + [ta_loop] = [ + task + for nursery in current_root_task().child_nurseries + for task in nursery.child_tasks + if getmodule(task.coro.cr_code).__name__.startswith("trio_asyncio.") + ] + for nursery in ta_loop.child_nurseries: + yield from nursery.child_tasks + + +def _print_task_info(): + groups = {kind: [] for kind in _TaskKind} + for task in _get_all_tasks(): + kind = _classify_task(task) + if kind: + coro = task.coro + frame_tuples = [] + while coro is not None: + if hasattr(coro, "cr_frame"): + frame_tuples.append((coro.cr_frame, coro.cr_frame.f_lineno)) + coro = coro.cr_await + else: + frame_tuples.append((coro.gi_frame, coro.gi_frame.f_lineno)) + coro = coro.gi_yieldfrom + groups[kind].append(StackSummary.extract(iter(frame_tuples))) + for kind in _TaskKind: + print(kind.name) + for ss in groups[kind]: + print(" task") + for line in ss.format(): + print(" " + line, end = "") + + +async def _signal_handler_impl(): + with open_signal_receiver(SIGUSR1) as sig_src: + async for _ in sig_src: + _print_task_info() + + +ap = argparse.ArgumentParser(__package__) +ap.add_argument("config_path") +args = ap.parse_args() + + +parser = configparser.ConfigParser() +with open(args.config_path) as config_file: + parser.read_file(config_file) + +main_cfg = parser["config"] + +api_pool_error_delay = dt.timedelta(seconds = main_cfg.getfloat("API pool error delay")) +api_pool_error_window = dt.timedelta(seconds = main_cfg.getfloat("API pool error window")) +auth_ids = set(map(int, main_cfg["auth IDs"].split())) +bot_user = main_cfg["bot user"] +enforcing = main_cfg.getboolean("enforcing", fallback = True) + +raw_log_level = main_cfg.get("log level", "WARNING") +try: + log_level = int(raw_log_level) +except ValueError: + log_level = raw_log_level + +read_only = main_cfg.getboolean("read-only", fallback = False) +reorder_buffer_time = dt.timedelta(seconds = main_cfg.getfloat("reorder buffer time")) +request_queue_limit = main_cfg.getint("request queue limit") +thread_id = main_cfg["thread ID"] +update_retention = dt.timedelta(seconds = main_cfg.getfloat("update retention time")) +ws_parity_time = dt.timedelta(seconds = main_cfg.getfloat("WS parity time")) +ws_pool_size = main_cfg.getint("WS pool size") +ws_silent_limit = dt.timedelta(seconds = main_cfg.getfloat("WS silent limit")) +ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup time")) + + +db_cfg = parser["db connect params"] +cert_path = db_cfg.pop("sslcert", None) +key_path = db_cfg.pop("sslkey", None) +key_pass = db_cfg.pop("sslpassword", None) + +getters = { + "port": db_cfg.getint, +} +db_connect_kwargs = {k: getters.get(k, db_cfg.get)(k) for k in db_cfg} + +if cert_path: + # patch TLS client cert auth support; asyncpg adds it in a later version + db_ssl_ctx = ssl.create_default_context() + db_ssl_ctx.load_cert_chain( + cert_path, + key_path, + None if key_pass is None else lambda: key_pass, + ) + db_connect_kwargs["ssl"] = db_ssl_ctx + + +if read_only and enforcing: + raise RuntimeError("can't use read-only mode with enforcing on") + + +logger = getLogger(__package__) +logger.setLevel(logging.NOTSET - 1) # NOTSET means inherit from parent; we use handlers to filter + +handler = StreamHandler(stdout) +handler.setLevel(log_level) 
+handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{")) +logger.addHandler(handler) + +if _DEBUG_LOG_PATH: # TODO remove this ad hoc setup + debug_handler = FileHandler(_DEBUG_LOG_PATH, mode = "w") + debug_handler.setLevel(logging.DEBUG) + debug_handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{")) + logger.addHandler(debug_handler) + + +async def main(): + async with ( + triopg.connect(**db_connect_kwargs) as db_conn, + open_nursery() as nursery_a, + open_nursery() as nursery_b, + open_nursery() as ws_pool_nursery, + open_nursery() as nursery_c, + ): + nursery_a.start_soon(_signal_handler_impl) + + client = Client(db_conn) + db_messenger = Messenger(client) + nursery_a.start_soon(db_messenger.db_client_impl) + + api_pool = ApiClientPool( + auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay, + logger.getChild("api") + ) + nursery_a.start_soon(api_pool.token_updater_impl) + for _ in auth_ids: + await nursery_a.start(api_pool.worker_impl) + + (message_tx, message_rx) = open_memory_channel(0) + (pool_event_tx, pool_event_rx) = open_memory_channel(0) + + merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger.getChild("merge")) + nursery_b.start_soon(merger.event_reader_impl) + nursery_b.start_soon(merger.timeout_handler_impl) + + ws_pool = HealingReadPool( + ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger.getChild("live_ws"), + ws_silent_limit + ) + nursery_c.start_soon(ws_pool.conn_refresher_impl) + + nursery_c.start_soon( + count_tracker_impl, message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing, read_only, + update_retention, logger.getChild("track") + ) + await ws_pool.init_workers() + +trio_asyncio.run(main) diff --git a/strikebot/strikebot/src/strikebot/common.py b/strikebot/strikebot/src/strikebot/common.py new file mode 100644 index 0000000..519037a --- /dev/null +++ b/strikebot/strikebot/src/strikebot/common.py @@ -0,0 +1,10 @@ +from typing import Any + + +def int_digest(val: int) -> str: + return "{:04x}".format(val & 0xffff) + + +def obj_digest(obj: Any) -> str: + """Makes a short digest of the identity of an object.""" + return int_digest(id(obj)) diff --git a/strikebot/strikebot/src/strikebot/db.py b/strikebot/strikebot/src/strikebot/db.py new file mode 100644 index 0000000..d1a2f42 --- /dev/null +++ b/strikebot/strikebot/src/strikebot/db.py @@ -0,0 +1,58 @@ +"""Single-connection Postgres client with high-level wrappers for operations.""" + +from dataclasses import dataclass +from typing import Any, Iterable + +import trio + + +def _channel_sender(method): + async def wrapped(self, resp_channel, *args, **kwargs): + ret = await method(self, *args, **kwargs) + async with resp_channel: + await resp_channel.send(ret) + + return wrapped + + +@dataclass +class Client: + """High-level wrappers for DB operations.""" + conn: Any + + @_channel_sender + async def get_auth_tokens(self, auth_ids: set[int]) -> dict[int, str]: + results = await self.conn.fetch( + """ + select + distinct on (id) + id, access_token + from + public.reddit_app_authorization + join unnest($1::integer[]) as request_id on id = request_id + where expires > current_timestamp + """, + list(auth_ids), + ) + return {r["id"]: r["access_token"] for r in results} + + +class Messenger: + def __init__(self, client: Client): + self._client = client + (self._request_tx, self._request_rx) = trio.open_memory_channel(0) + + 
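+    # Illustrative use from a consumer task (mirrors ApiClientPool._update_tokens in strikebot.reddit_api):
+    #   tokens = await db_messenger.do("get_auth_tokens", (auth_ids,))
+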
async def do(self, method_name: str, args: Iterable) -> Any: + """This is run by consumers of the DB wrapper.""" + method = getattr(self._client, method_name) + (resp_tx, resp_rx) = trio.open_memory_channel(0) + coro = method(resp_tx, *args) + await self._request_tx.send(coro) + async with resp_rx: + return await resp_rx.receive() + + async def db_client_impl(self): + """This is the DB client Trio task.""" + async with self._client.conn, self._request_rx: + async for coro in self._request_rx: + await coro diff --git a/strikebot/strikebot/src/strikebot/live_ws.py b/strikebot/strikebot/src/strikebot/live_ws.py new file mode 100644 index 0000000..031d652 --- /dev/null +++ b/strikebot/strikebot/src/strikebot/live_ws.py @@ -0,0 +1,266 @@ +from collections import deque +from contextlib import suppress +from dataclasses import dataclass +from typing import Any, AsyncContextManager, Optional +import datetime as dt +import json +import logging +import math + +from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection +import trio + +from strikebot.common import int_digest, obj_digest +from strikebot.reddit_api import AboutLiveThreadRequest + + +@dataclass +class _PoolEvent: + timestamp: float # Trio time + + +@dataclass +class _Message(_PoolEvent): + data: Any + scope: Any + + def dedup_tag(self): + """A hashable essence of the contents of this message.""" + if self.data["type"] == "update": + return ("update", self.data["payload"]["data"]["id"]) + elif self.data["type"] in {"strike", "delete"}: + return (self.data["type"], self.data["payload"]) + else: + raise ValueError(self.data["type"]) + + +@dataclass +class _ConnectionDown(_PoolEvent): + scope: trio.CancelScope + + +@dataclass +class _ConnectionUp(_PoolEvent): + scope: trio.CancelScope + + +class HealingReadPool: + def __init__( + self, pool_nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger, + silent_limit: dt.timedelta + ): + assert size >= 2 + self._nursery = pool_nursery + self._size = size + self._live_thread_id = live_thread_id + self._api_client_pool = api_client_pool + self._pool_event_tx = pool_event_tx + self._logger = logger + self._silent_limit = silent_limit + + (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size) + + # number of workers who have stopped receiving updates but whose tasks aren't yet stopped + self._closing = 0 + + async def init_workers(self): + for _ in range(self._size): + await self._spawn_reader() + if not self._nursery.child_tasks: + raise RuntimeError("Unable to create any WS connections") + + async def _reader_impl(self, conn_ctx: AsyncContextManager, cancel_scope, refresh_tx): + # TODO could use task's implicit cancel scope + try: + async with conn_ctx as conn: + self._logger.debug("scope up: {}".format(obj_digest(cancel_scope))) + async with refresh_tx: + silent_timeout = False + with cancel_scope, suppress(ConnectionClosed): + while True: + with trio.move_on_after(self._silent_limit.total_seconds()) as timeout_scope: + message = await conn.get_message() + + if timeout_scope.cancelled_caught: + if len(self._nursery.child_tasks) - self._closing == self._size: + silent_timeout = True + break + else: + self._logger.debug("not replacing connection {}; {} tasks, {} closing".format( + obj_digest(cancel_scope), + len(self._nursery.child_tasks), + self._closing, + )) + else: + event = _Message(trio.current_time(), json.loads(message), cancel_scope) + await self._pool_event_tx.send(event) + + 
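+                    # Reading from this connection has stopped; count the worker as closing (see `_closing' in
+                    # __init__) until its cleanup below finishes.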
self._closing += 1 + if silent_timeout: + await conn.aclose() + self._logger.debug("replacing WS connection {} due to silent timeout".format(obj_digest(cancel_scope))) + elif cancel_scope.cancelled_caught: + await conn.aclose(1008, "Server unexpectedly stopped sending messages") + self._logger.warning("replacing WS connection {} due to missed update".format(obj_digest(cancel_scope))) + else: + self._logger.warning("replacing WS connection {} closed by server".format(obj_digest(cancel_scope))) + + refresh_tx.send_nowait(cancel_scope) + self._logger.debug("scope down: {} ({} tasks, {} closing)".format( + obj_digest(cancel_scope), + len(self._nursery.child_tasks), + self._closing, + )) + self._closing -= 1 + except HandshakeError: + self._logger.error("handshake error while opening WS connection") + + async def _new_connection(self) -> Optional[AsyncContextManager[WebSocketConnection]]: + request = AboutLiveThreadRequest(self._live_thread_id) + resp = await self._api_client_pool.make_request(request) + if resp.status_code == 200: + url = json.loads(resp.text)["data"]["websocket_url"] + return open_websocket_url(url) + else: + self._logger.error(f"thread info request failed (HTTP {resp.status_code}); can't open WS connection") + return None + + async def _spawn_reader(self) -> None: + conn = await self._new_connection() + if conn: + new_scope = trio.CancelScope() + await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope)) + refresh_tx = self._refresh_queue_tx.clone() + self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx) + + async def conn_refresher_impl(self): + """Task to monitor and replace WS connections as they disconnect or go silent.""" + async with self._refresh_queue_rx: + async for old_scope in self._refresh_queue_rx: + await self._pool_event_tx.send(_ConnectionDown(trio.current_time(), old_scope)) + await self._spawn_reader() + if not self._nursery.child_tasks: + raise RuntimeError("WS pool depleted") + + +def _tag_digest(tag): + return int_digest(hash(tag)) + + +def _format_scope_list(scopes): + return ", ".join(sorted(map(obj_digest, scopes))) + + +class PoolMerger: + @dataclass + class _Bucket: + start: float + recipients: set + + def __init__(self, pool_event_rx, message_tx, parity_timeout: dt.timedelta, conn_warmup: dt.timedelta, logger: logging.Logger): + """ + :param parity_timeout: max delay between message receipt on different connections + :param conn_warmup: max interval after connection within which missed updates are allowed + """ + self._pool_event_rx = pool_event_rx + self._message_tx = message_tx + self._parity_timeout = parity_timeout + self._conn_warmup = conn_warmup + self._logger = logger + + self._buckets = {} + self._scope_activations = {} + self._pending = deque() + self._outgoing_scopes = set() + (self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf) + + def _log_current_buckets(self): + self._logger.debug("current buckets: {}".format(", ".join(map(_tag_digest, self._buckets.keys())))) + + async def event_reader_impl(self): + """Drop unused messages, deduplicate useful ones, and communicate with the timeout handler.""" + async with self._pool_event_rx, self._timer_poke_tx: + async for event in self._pool_event_rx: + if isinstance(event, _ConnectionUp): + # An early add of an active scope could mean it's expected on a message that fired before it opened, + # resulting in a false positive for replacement. The timer task merges these in among the buckets by + # timestamp to avoid that. 
+ self._pending.append(event) + elif isinstance(event, _Message): + if event.data["type"] in ["update", "strike", "delete"]: + tag = event.dedup_tag() + self._logger.debug("recv {} from {}".format(_tag_digest(tag), obj_digest(event.scope))) + if tag in self._buckets: + b = self._buckets[tag] + b.recipients.add(event.scope) + # If this scope is the last one for this bucket we could clear the bucket here, but since + # connections sometimes get repeat messages and second copies can arrive before the first + # copy has arrived on all connections, leaving the bucket open to absorb repeats can reduce + # the likelihood of a second bucket being allocated late for the second copy of a message + # and causing unnecessary connection replacement. + elif event.scope not in self._outgoing_scopes: + sane = ( + event.scope in self._scope_activations + or any(e.scope == event.scope for e in self._pending) + ) + if sane: + self._logger.debug("new bucket {}: {}".format(_tag_digest(tag), tag)) + self._buckets[tag] = self._Bucket(event.timestamp, {event.scope}) + self._log_current_buckets() + await self._message_tx.send(event) + else: + raise RuntimeError("recieved message from unrecognized WS connection") + else: + self._logger.debug("recv type {!r} from {} (discarding)".format(event.data["type"], obj_digest(event.scope))) + elif isinstance(event, _ConnectionDown): + # We don't need to worry about canceling this scope at all, so no need to require it for parity for + # any message, even older ones. The scope may be gone already, if we canceled it previously. + self._scope_activations.pop(event.scope, None) + self._pending = deque(ev for ev in self._pending if ev.scope is not event.scope) + self._outgoing_scopes.discard(event.scope) + else: + raise TypeError(f"Expected pool event, found {event!r}") + self._timer_poke_tx.send_nowait(None) # may have new work for the timer + + async def timeout_handler_impl(self): + """When connections have had enough time to reach parity on a message, signal replacement of any slackers.""" + async with self._timer_poke_rx: + while True: + if self._buckets: + now = trio.current_time() + (tag, bucket) = min(self._buckets.items(), key = lambda i: i[1].start) + + # Make sure the right scope set is ready for the next bucket up + while self._pending and self._pending[0].timestamp < bucket.start: + ev = self._pending.popleft() + self._scope_activations[ev.scope] = ev.timestamp + + target_scopes = { + scope for (scope, active) in self._scope_activations.items() + if active + self._conn_warmup.total_seconds() < now + } + if now > bucket.start + self._parity_timeout.total_seconds(): + filled = target_scopes & bucket.recipients + missing = target_scopes - bucket.recipients + extra = bucket.recipients - target_scopes + + self._logger.debug("expiring bucket {}".format(_tag_digest(tag))) + if filled: + self._logger.debug(" filled {}: {}".format(len(filled), _format_scope_list(filled))) + if missing: + self._logger.debug(" missing {}: {}".format(len(missing), _format_scope_list(missing))) + if extra: + self._logger.debug(" extra {}: {}".format(len(extra), _format_scope_list(extra))) + + for scope in missing: + self._outgoing_scopes.add(scope) + scope.cancel() + del self._buckets[tag] + self._log_current_buckets() + else: + await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds()) + else: + try: + await self._timer_poke_rx.receive() + except trio.EndOfChannel: + break diff --git a/strikebot/strikebot/src/strikebot/queue.py b/strikebot/strikebot/src/strikebot/queue.py new 
file mode 100644 index 0000000..acf5ef0 --- /dev/null +++ b/strikebot/strikebot/src/strikebot/queue.py @@ -0,0 +1,61 @@ +"""Unbounded blocking queues for Trio""" + +from collections import deque +from dataclasses import dataclass +from functools import total_ordering +from typing import Any, Iterable +import heapq + +try: + from trio.lowlevel import ParkingLot +except ImportError: + from trio.hazmat import ParkingLot + + +class Queue: + def __init__(self): + self._deque = deque() + self._empty_wait = ParkingLot() + + def push(self, el: Any) -> None: + self._deque.append(el) + self._empty_wait.unpark() + + def extend(self, els: Iterable[Any]) -> None: + for el in els: + self.push(el) + + async def pop(self) -> Any: + if not self._deque: + await self._empty_wait.park() + return self._deque.popleft() + + def __len__(self): + return len(self._deque) + + +@total_ordering +@dataclass +class _ReverseOrdWrapper: + inner: Any + + def __lt__(self, other): + return other.inner < self.inner + + +class MaxHeap: + def __init__(self): + self._heap = [] + self._empty_wait = ParkingLot() + + def push(self, item): + heapq.heappush(self._heap, _ReverseOrdWrapper(item)) + self._empty_wait.unpark() + + async def pop(self): + if not self._heap: + await self._empty_wait.park() + return heapq.heappop(self._heap).inner + + def __len__(self) -> int: + return len(self._heap) diff --git a/strikebot/strikebot/src/strikebot/reddit_api.py b/strikebot/strikebot/src/strikebot/reddit_api.py new file mode 100644 index 0000000..0cf5daa --- /dev/null +++ b/strikebot/strikebot/src/strikebot/reddit_api.py @@ -0,0 +1,290 @@ +"""Multi-auth HTTP connection pool with Reddit API wrappers and rate limiting.""" + +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass, field +from functools import total_ordering +from socket import EAI_AGAIN, EAI_FAIL, gaierror +import datetime as dt +import logging + +from asks.response_objects import Response +import asks +import trio + +from strikebot import __version__ as VERSION +from strikebot.common import obj_digest +from strikebot.queue import MaxHeap, Queue + + +REQUEST_DELAY = dt.timedelta(seconds = 1) + +TOKEN_UPDATE_DELAY = dt.timedelta(minutes = 15) + +USER_AGENT_FMT = f"any:net.jcornell.strikebot.{{auth_id}}:v{VERSION} (by /u/jaxklax)" + +API_BASE_URL = "https://oauth.reddit.com" + + +@total_ordering +class _Request(metaclass = ABCMeta): + _SUBTYPE_PRECEDENCE = None # assigned later + + def __lt__(self, other): + if type(self) is type(other): + return self._subtype_cmp_key() < other._subtype_cmp_key() + else: + prec = self._SUBTYPE_PRECEDENCE + return prec.index(type(other)) < prec.index(type(self)) + + def __eq__(self, other): + if type(self) is type(other): + return self._subtype_cmp_key() == other._subtype_cmp_key() + else: + return False + + @abstractmethod + def _subtype_cmp_key(self): + # a large key corresponds to a high priority + raise NotImplementedError() + + @abstractmethod + def to_asks_kwargs(self): + raise NotImplementedError() + + +@dataclass(eq = False) +class StrikeRequest(_Request): + thread_id: str + update_name: str + update_ts: dt.datetime + + def to_asks_kwargs(self): + return { + "method": "POST", + "path": f"/api/live/{self.thread_id}/strike_update", + "data": { + "api_type": "json", + "id": self.update_name, + }, + } + + def _subtype_cmp_key(self): + return self.update_ts + + +@dataclass(eq = False) +class DeleteRequest(_Request): + thread_id: str + update_name: str + update_ts: dt.datetime + + def to_asks_kwargs(self): + return { + 
"method": "POST", + "path": f"/api/live/{self.thread_id}/delete_update", + "data": { + "api_type": "json", + "id": self.update_name, + }, + } + + def _subtype_cmp_key(self): + return -self.update_ts.timestamp() + + +@dataclass(eq = False) +class AboutLiveThreadRequest(_Request): + thread_id: str + created: float = field(default_factory = trio.current_time) + + def to_asks_kwargs(self): + return { + "method": "GET", + "path": f"/live/{self.thread_id}/about", + "params": { + "raw_json": "1", + }, + } + + def _subtype_cmp_key(self): + return -self.created + + +@dataclass(eq = False) +class _BaseUpdatePostRequest(_Request, metaclass = ABCMeta): + thread_id: str + body: str + created: float = field(default_factory = trio.current_time) + + def to_asks_kwargs(self): + return { + "method": "POST", + "path": f"/live/{self.thread_id}/update", + "data": { + "api_type": "json", + "body": self.body, + }, + } + + +@dataclass(eq = False) +class ReportUpdateRequest(_BaseUpdatePostRequest): + def _subtype_cmp_key(self): + return -self.created + + +@dataclass(eq = False) +class CorrectionUpdateRequest(_BaseUpdatePostRequest): + def _subtype_cmp_key(self): + return -self.created + + +# highest priority first +_Request._SUBTYPE_PRECEDENCE = [ + AboutLiveThreadRequest, + StrikeRequest, + CorrectionUpdateRequest, + ReportUpdateRequest, + DeleteRequest, +] + + +@dataclass +class AppCooldown: + auth_id: int + ready_at: float # Trio time + + +class ApiClientPool: + def __init__( + self, + auth_ids: set[int], + db_messenger: "strikebot.db.Messenger", + request_queue_limit: int, + error_window: dt.timedelta, + error_delay: dt.timedelta, + logger: logging.Logger, + ): + self._auth_ids = auth_ids + self._db_messenger = db_messenger + self._request_queue_limit = request_queue_limit + self._error_window = error_window + self._error_delay = error_delay + self._logger = logger + + now = trio.current_time() + self._tokens = {} + self._app_queue = Queue() + self._waiting = {id_: AppCooldown(id_, now) for id_ in auth_ids} + self._request_queue = MaxHeap() + + # pool-wide API error backoff + self._last_error = None + self._global_resume = None + + self._session = asks.Session(connections = len(auth_ids)) + self._session.base_location = API_BASE_URL + + async def _update_tokens(self): + tokens = await self._db_messenger.do("get_auth_tokens", (self._auth_ids,)) + self._tokens.update(tokens) + + awaken_auths = self._waiting.keys() & tokens.keys() + self._app_queue.extend(self._waiting.pop(auth_id) for auth_id in awaken_auths) + self._logger.debug("updated API tokens") + + async def token_updater_impl(self): + last_update = None + while True: + if last_update is not None: + await trio.sleep_until(last_update + TOKEN_UPDATE_DELAY.total_seconds()) + + if last_update is None: + last_update = trio.current_time() + else: + last_update += TOKEN_UPDATE_DELAY.total_seconds() + + await self._update_tokens() + + def _check_queue_size(self) -> None: + if len(self._request_queue) > self._request_queue_limit: + raise RuntimeError("request queue size exceeded limit") + + async def make_request(self, request: _Request) -> Response: + (resp_tx, resp_rx) = trio.open_memory_channel(0) + self.enqueue_request(request, resp_tx) + async with resp_rx: + return await resp_rx.receive() + + def enqueue_request(self, request: _Request, resp_tx = None) -> None: + self._logger.debug("request {}: {}".format(obj_digest(request), type(request).__name__)) + self._request_queue.push((request, resp_tx)) + self._logger.debug(f"{len(self._request_queue)} requests in 
queue") + self._check_queue_size() + + async def worker_impl(self, task_status): + task_status.started() + while True: + (request, resp_tx) = await self._request_queue.pop() + cooldown = await self._app_queue.pop() + await trio.sleep_until(cooldown.ready_at) + if self._global_resume: + await trio.sleep_until(self._global_resume) + + asks_kwargs = request.to_asks_kwargs() + headers = asks_kwargs.setdefault("headers", {}) + headers.update({ + "Authorization": "Bearer {}".format(self._tokens[cooldown.auth_id]), + "User-Agent": USER_AGENT_FMT.format(auth_id = cooldown.auth_id), + }) + + request_time = trio.current_time() + try: + resp = await self._session.request(**asks_kwargs) + except gaierror as e: + if e.errno in [EAI_FAIL, EAI_AGAIN]: + # DNS failure, probably temporary + error = True + else: + raise + else: + resp.body # read response + error = False + wait_for_token = False + log_suffix = " (request {})".format(obj_digest(request)) + if resp.status_code == 429: + # We disagreed about the rate limit state; just try again later. + self._logger.warning("rate limited by Reddit API" + log_suffix) + error = True + elif resp.status_code == 401: + self._logger.warning("got HTTP 401 from Reddit API" + log_suffix) + error = True + wait_for_token = True + elif resp.status_code in [404, 500, 503]: + self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying" + log_suffix) + error = True + elif 400 <= resp.status_code < 500: + # If we're doing something wrong, let's catch it right away. + raise RuntimeError(f"unexpected client error response: {resp.status_code}" + log_suffix) + else: + if resp.status_code != 200: + raise RuntimeError(f"unexpected status code {resp.status_code}") + self._logger.debug("success" + log_suffix) + if resp_tx: + await resp_tx.send(resp) + + if error: + self.enqueue_request(request, resp_tx) + if self._last_error: + spread = dt.timedelta(seconds = request_time - self._last_error) + if spread <= self._error_window: + self._global_resume = request_time + self._error_delay.total_seconds() + self._last_error = request_time + + cooldown.ready_at = request_time + REQUEST_DELAY.total_seconds() + if wait_for_token: + self._waiting[cooldown.auth_id] = cooldown + if not self._app_queue: + self._logger.error("all workers waiting for API tokens") + else: + self._app_queue.push(cooldown) diff --git a/strikebot/strikebot/src/strikebot/tests.py b/strikebot/strikebot/src/strikebot/tests.py new file mode 100644 index 0000000..966bca6 --- /dev/null +++ b/strikebot/strikebot/src/strikebot/tests.py @@ -0,0 +1,62 @@ +from unittest import TestCase + +from strikebot.updates import parse_update + + +def _build_payload(body_html: str) -> dict[str, str]: + return {"body_html": body_html} + + +class UpdateParsingTests(TestCase): + def test_successful_counts(self): + pu = parse_update(_build_payload("
12,345,678 spaghetti
"), None, "") + self.assertEqual(pu.number, 12345678) + self.assertFalse(pu.deletable) + + pu = parse_update(_build_payload("

0


oz
"), None, "") + self.assertEqual(pu.number, 0) + self.assertFalse(pu.deletable) + + pu = parse_update(_build_payload("
121 345 621
"), 121, "") + self.assertEqual(pu.number, 121) + self.assertFalse(pu.deletable) + + pu = parse_update(_build_payload("28 336 816"), 28336816, "") + self.assertEqual(pu.number, 28336816) + self.assertTrue(pu.deletable) + + def test_non_counts(self): + pu = parse_update(_build_payload("
zoo
"), None, "") + self.assertFalse(pu.count_attempt) + self.assertFalse(pu.deletable) + + def test_typos(self): + pu = parse_update(_build_payload("v9"), 888, "") + self.assertIsNone(pu.number) + self.assertTrue(pu.count_attempt) + + pu = parse_update(_build_payload("
v11.585 Empire
"), None, "") + self.assertIsNone(pu.number) + self.assertTrue(pu.count_attempt) + self.assertFalse(pu.deletable) + + pu = parse_update(_build_payload("
11, 585, 22
"), 11_585_202, "") + self.assertIsNone(pu.number) + self.assertTrue(pu.count_attempt) + self.assertTrue(pu.deletable) + + pu = parse_update(_build_payload("0490499"), 4999, "") + self.assertIsNone(pu.number) + self.assertTrue(pu.count_attempt) + + # this update is marked non-deletable, but I'm not sure it should be + pu = parse_update(_build_payload("12,123123"), 12_123_123, "") + self.assertIsNone(pu.number) + self.assertTrue(pu.count_attempt) + + def test_html_handling(self): + pu = parse_update(_build_payload("123
,456"), None, "") + self.assertEqual(pu.number, 123) + + pu = parse_update(_build_payload("
123\n456
"), None, "") + self.assertEqual(pu.number, 123) diff --git a/strikebot/strikebot/src/strikebot/updates.py b/strikebot/strikebot/src/strikebot/updates.py new file mode 100644 index 0000000..01eb67b --- /dev/null +++ b/strikebot/strikebot/src/strikebot/updates.py @@ -0,0 +1,226 @@ +from __future__ import annotations +from dataclasses import dataclass +from enum import Enum +from typing import Optional +import re + +from bs4 import BeautifulSoup + + +Command = Enum("Command", ["RESET", "REPORT"]) + + +@dataclass +class ParsedUpdate: + number: Optional[int] + command: Optional[Command] + count_attempt: bool # either well-formed or typo + deletable: bool + + +def _parse_command(line: str, bot_user: str) -> Optional[Command]: + if line.lower() == f"/u/{bot_user} reset".lower(): + return Command.RESET + elif line.lower() in ["sidebar count", "current count"]: + return Command.REPORT + else: + return None + + +def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -> ParsedUpdate: + # curr_count is the next number up, one more than the last count + + NEW_LINE = object() + SPACE = object() + + # flatten the update content to plain text + tree = BeautifulSoup(payload_data["body_html"], "html.parser") + worklist = list(reversed(tree.contents)) + out = [[]] + while worklist: + el = worklist.pop() + if isinstance(el, str): + out[-1].append(el) + elif el is SPACE: + out[-1].append(el) + elif el is NEW_LINE or el.name in ["br", "hr"]: + if out[-1]: + out.append([]) + elif el.name in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]: + worklist.extend(reversed(el.contents)) + elif el.name in ["ul", "ol", "table", "thead", "tbody"]: + worklist.extend(reversed(el.contents)) + elif el.name in ["li", "p", "div", "blockquote"] or re.match(r"h[1-6]$", el.name): + worklist.append(NEW_LINE) + worklist.extend(reversed(el.contents)) + worklist.append(NEW_LINE) + elif el.name == "pre": + out.extend([l] for l in el.text.splitlines()) + out.append([]) + elif el.name == "tr": + worklist.append(NEW_LINE) + for (i, cell) in enumerate(reversed(el.contents)): + worklist.append(cell) + if i != len(el.contents) - 1: + worklist.append(SPACE) + worklist.append(NEW_LINE) + else: + raise RuntimeError(f"can't parse tag {el.name}") + + tmp_lines = ( + "".join(" " if part is SPACE else part for part in parts).strip() + for parts in out + ) + pre_strip_lines = list(filter(None, tmp_lines)) + + # normalize whitespace according to HTML rendering rules + # https://developer.mozilla.org/en-US/docs/Web/API/Document_Object_Model/Whitespace#explanation + stripped_lines = [ + re.sub(" +", " ", l.replace("\t", " ").replace("\n", " ")).strip(" ") + for l in pre_strip_lines + ] + + return _parse_from_lines(stripped_lines, curr_count, bot_user) + + +def _parse_from_lines(lines: list[str], curr_count: Optional[int], bot_user: str) -> ParsedUpdate: + command = next( + (cmd for l in lines if (cmd := _parse_command(l, bot_user))), + None + ) + if lines: + # look for groups of digits (as many as possible) separated by a uniform separator from the valid set + first = lines[0] + match = re.match( + "(?Pv)?(?P-)?(?P\\d+((?P[,. 
\u2009]|, )\\d+((?P=sep)\\d+)*)?)", + first, + re.ASCII, # only recognize ASCII digits + ) + if match: + raw_digits = match["num"] + sep = match["sep"] + post = first[match.end() :] + + zeros = False + while len(raw_digits) > 1 and raw_digits[0] == "0": + zeros = True + raw_digits = raw_digits.removeprefix("0").removeprefix(sep or "") + + parts = raw_digits.split(sep) if sep else [raw_digits] + all_parts_valid = ( + sep is None + or ( + 1 <= len(parts[0]) <= 3 + and all(len(p) == 3 for p in parts[1:]) + ) + ) + lone = len(lines) == 1 and (not post or post.isspace()) + typo = False + if lone: + if match["v"] and len(parts) == 1 and len(parts[0]) <= 2: + # failed paste of leading digits + typo = True + elif match["v"] and all_parts_valid: + # v followed by count + typo = True + elif curr_count is not None and abs(curr_count) >= 100 and bool(match["neg"]) == (curr_count < 0): + goal_parts = _separate(str(abs(curr_count))) + partials = [ + goal_parts[: -1] + [goal_parts[-1][: -1]], # missing last digit + goal_parts[: -1] + [goal_parts[-1][: -2]], # missing last two digits + goal_parts[: -1] + [goal_parts[-1][: -2] + goal_parts[-1][-1]], # missing second-last digit + ] + if parts in partials: + # missing any of last two digits + typo = True + elif any(parts == p[: -1] + [p[-1] + goal_parts[0]] + goal_parts[1 :] for p in partials): + # double paste + typo = True + + if match["v"] or zeros or typo or (parts == ["0"] and match["neg"]): + number = None + count_attempt = True + deletable = lone + else: + groups_okay = True + if curr_count is not None and sep and sep.isspace(): + # Presume that the intended count consists of as many valid digit groups as necessary to match the + # number of digits in the expected count, if possible. + digit_count = len(str(abs(curr_count))) + use_parts = [] + accum = 0 + for (i, part) in enumerate(parts): + part_valid = len(part) <= 3 if i == 0 else len(part) == 3 + if part_valid and accum < digit_count: + use_parts.append(part) + accum += len(part) + else: + break + + # could still be a no-separator count with some extra digit groups on the same line + if not use_parts: + use_parts = [parts[0]] + + lone = lone and len(use_parts) == len(parts) + else: + # current count is unknown or no separator was used + groups_okay = all_parts_valid + use_parts = parts + + digits = "".join(use_parts) + number = -int(digits) if match["neg"] else int(digits) + special = ( + curr_count is not None + and abs(number - curr_count) <= 25 + and _is_special_number(number) + ) + if groups_okay: + deletable = lone and not special + if len(use_parts) == len(parts) and post and not post[0].isspace(): + count_attempt = curr_count is not None and abs(number - curr_count) <= 25 + number = None + else: + count_attempt = True + else: + number = None + count_attempt = True + deletable = False + + if first[0].isdigit() and not count_attempt: + count_attempt = True + deletable = False + else: + # no count attempt found + number = None + count_attempt = False + deletable = False + else: + # no lines in update + number = None + count_attempt = False + deletable = True + + return ParsedUpdate( + number = number, + command = command, + count_attempt = count_attempt, + deletable = deletable, + ) + + +def _separate(digits: str) -> list[str]: + mod = len(digits) % 3 + out = [] + if mod: + out.append(digits[: mod]) + out.extend(digits[i : i + 3] for i in range(mod, len(digits), 3)) + return out + + +def _is_special_number(num: int) -> bool: + num_str = str(num) + return bool( + num % 1000 in [0, 1, 333, 999] 
+ or (num > 10_000_000 and "".join(reversed(num_str)) == num_str) + or re.match(r"(.+)\1+$", num_str) # repeated sequence + )
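+
+
+# Illustrative behaviour of `_is_special_number' (example values chosen here, not taken from the repo's tests):
+#   _is_special_number(1_000_000)   -> True   (multiple of 1000)
+#   _is_special_number(12_344_321)  -> True   (palindrome above 10,000,000)
+#   _is_special_number(123_123)     -> True   (repeated digit sequence)
+#   _is_special_number(1_234_567)   -> False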