"""
-A few tools to help keep Debian packages and artifacts well organized.
+A few tools to help keep Debian packages and artifacts well organized. Run me from the counting repo root.
+
+Run as root if doing a build, as the pbuilder-based binary package build needs root privileges.
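+
+Usage sketch (subcommand and argument names come from the parser below; the script filename is an assumption):
+
+    python3 debian.py build <project>   # build Debian source and binary packages for <project>
+    python3 debian.py clean <project>   # delete previous build outputs from <project>/build/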
"""
from argparse import ArgumentParser
from pathlib import Path
from subprocess import run
import re
-project_root = Path(__file__).parent
-upstream_root = project_root.joinpath("strikebot")
-build_dir = project_root.joinpath("build")
-
parser = ArgumentParser()
+
tmp = parser.add_subparsers(dest = "cmd", required = True)
+project_cmd_parsers = [
+ tmp.add_parser("build"),
+ tmp.add_parser("clean"),
+]
-tmp.add_parser("build")
-tmp.add_parser("clean")
+for subparser in project_cmd_parsers:
+ subparser.add_argument("project")
args = parser.parse_args()
+
+project_root = Path(__file__).parent.joinpath(args.project)
+upstream_root = project_root.joinpath(args.project)
+build_dir = project_root.joinpath("build")
+
if args.cmd == "build":
+ build_dir.mkdir(exist_ok = True)
+
# extract version from Python package config
patt = re.compile("version *= *(.+?) *$")
	with upstream_root.joinpath("setup.cfg").open() as f:
		for line in f:
			m = patt.match(line)
			if m:
				break
		version = m[1]
# delete stale "orig" tarball
- for p in project_root.glob("strikebot_*.orig.tar.xz"):
+ for p in project_root.glob(f"{args.project}_*.orig.tar.xz"):
p.unlink()
# regenerate the "orig" tarball from the current source
- run(["dh_make", "--yes", "--python", "--createorig", "-p", "strikebot_" + version], cwd = upstream_root)
+ run(["dh_make", "--yes", "--python", "--createorig", "-p", f"{args.project}_{version}"], cwd = upstream_root)
- # build the source package
- run(["debuild", "-i", "-us", "-uc", "-S"], cwd = upstream_root, check = True)
+ try:
+ # build the source package
+ run(["debuild", "-i", "-us", "-uc", "-S"], cwd = upstream_root, check = True)
- [orig_path] = project_root.glob("strikebot_*.orig.tar.xz")
- orig_name = orig_path.name
-
- # move outputs to build directory
- for p in project_root.iterdir():
- if _is_output(p):
- p.rename(build_dir.joinpath(p.relative_to(project_root)))
+ [orig_path] = project_root.glob(f"{args.project}_*.orig.tar.xz")
+ orig_name = orig_path.name
- # create temporary link for orig tarball to satisfy binary package build
- orig_path.symlink_to(build_dir.joinpath(orig_name).relative_to(project_root))
+ try:
+ # build binary package
+ run(["pdebuild"], cwd = upstream_root, check = True)
+ finally:
+ # clean up
+ orig_path.unlink()
+ finally:
+ # move source package and intermediates to build directory
+ for p in project_root.iterdir():
+ if _is_output(p):
+ p.rename(build_dir.joinpath(p.relative_to(project_root)))
elif args.cmd == "clean":
for p in build_dir.iterdir():
		if _is_output(p):
			p.unlink()
+++ /dev/null
-# see Python `configparser' docs for precise file syntax
-
-[config]
-
-# authorization ID for Reddit API token lookup
-auth ID = 3
-
-# (int, minutes)
-token update period = 15
-
-thread ID = abc123
-
-# (int) maximum number of mentions within an update for which to send messages
-mention limit = 10
-
-# (float, seconds) wait this much time before sending notifications, in case of update deletion
-message buffer time = 10
-
-# (float, seconds) wait this much time between API polls (and WS connection checks)
-API poll period = 90
-
-# (float, seconds) replace the WS connection if it falls this far behind the API in update timestamps
-max WS delay = 2
-
-
-# Postgres database configuration; same options as in a connect string.
-[db connect params]
-host = example.org
--- /dev/null
+# see Python `configparser' docs for precise file syntax
+
+[config]
+
+# authorization ID for Reddit API token lookup
+auth ID = 3
+
+# (int, minutes) how often to refresh the Reddit API access token from the database
+token update period = 15
+
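+# ID of the live thread to watch for mentions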
+thread ID = abc123
+
+# (int) maximum number of mentions within an update for which to send messages
+mention limit = 10
+
+# (float, seconds) wait this much time before sending notifications, in case of update deletion
+message buffer time = 10
+
+# (float, seconds) wait this much time between API polls (and WS connection checks)
+API poll period = 90
+
+# (float, seconds) replace the WS connection if it falls this far behind the API in update timestamps
+max WS delay = 2
+
+
+# Postgres database configuration; same options as in a connect string.
+[db connect params]
+host = example.org
--- /dev/null
+[build-system]
+requires = ["setuptools>=42", "wheel"]
+build-backend = "setuptools.build_meta"
--- /dev/null
+[metadata]
+name = mentionbot
+version = 0.0.0
+
+[options]
+package_dir =
+ = src
+packages = mentionbot
+python_requires = ~= 3.9
+install_requires =
+ beautifulsoup4 ~= 4.9
+ #psycopg2 ~= 2.8
+ psycopg2-binary ~= 2.8
+ trio == 0.19
+ trio-websocket == 0.9.2
--- /dev/null
+import setuptools
+
+
+setuptools.setup()
--- /dev/null
+from dataclasses import dataclass
+from datetime import timedelta
+from http.client import HTTPResponse
+from logging import Logger
+from socket import EAI_AGAIN, EAI_FAIL, gaierror
+from typing import AsyncContextManager, Callable, Iterable, Optional, Type
+from urllib.error import HTTPError
+from urllib.request import Request, urlopen
+from uuid import UUID
+import importlib.metadata
+import re
+
+from bs4 import BeautifulSoup
+from trio import CancelScope, MemoryReceiveChannel, MemorySendChannel, move_on_at
+import trio
+
+from .abc import DeleteEvent, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
+
+
+_USER_AGENT = "any:net.jcornell.mentionbot:v{} (by /u/jaxklax)".format(
+ importlib.metadata.version(__package__)
+)
+
+
+@dataclass
+class TokenRef:
+ """Mutable container for an access token."""
+
+ token: str
+
+ def get_common_headers(self) -> dict[str, str]:
+ return {
+ "User-Agent": _USER_AGENT,
+ "Authorization": "Bearer {}".format(self.token)
+ }
+
+
+def try_request(
+ request: Request,
+ logger: Logger,
+ suppress_codes: list[int] = [404, 429, 500, 503],
+) -> Optional[HTTPResponse]:
+
+ try:
+ return urlopen(request)
+ except gaierror as e:
+ if e.errno in [EAI_FAIL, EAI_AGAIN]:
+ logger.error(f"DNS failure: {e}")
+ return None
+ else:
+ raise
+ except HTTPError as e:
+ e.read()
+ e.close()
+ if e.code in suppress_codes:
+ logger.error(f"HTTP {e.code}")
+ return None
+ else:
+ raise
+
+
+@dataclass
+class WsListenerState:
+ """The WebSockets listener uses this object to communicate with the API poll and message sender tasks."""
+
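+	# Typical interaction (a sketch inferred from the tasks that use it): the WS listener fills in `last_update_ts` as
+	# update messages arrive; the API poll task cancels `scope` when that timestamp falls too far behind the REST API,
+	# forcing a reconnect; `down` and `resync_ts` tell the message sender which buffered updates still need an explicit
+	# existence check before notifying.
+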
+ # The API poll task uses this scope to force a refresh of the WS connection if it's been too quiet.
+ scope: Optional[CancelScope] = None
+
+	# The WS task uses this to track the UUID timestamp of the most recent update message, so the API poll task can
+	# detect when the connection goes silent.
+ last_update_ts: Optional[int] = None
+
+ # The WS task sets this when the WS connection goes down and clears it when it's recovered.
+ down: bool = True
+
+	# The WS task sets this to the UUID timestamp of the first update received on connection recovery so the message
+ # sender task knows when to stop checking for deletion.
+ resync_ts: Optional[int] = None
+
+
+def _parse_mentions(body: BeautifulSoup) -> set[str]:
+ """Extract the names of users tagged in an update."""
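+	# e.g. (hypothetical markup): '<a href="/u/someone">/u/someone</a>' yields {"someone"}; hrefs that don't look like
+	# /u/<username> are ignored.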
+ return {
+ el["href"].removeprefix("/u/")
+ for el in body.find_all("a", href = re.compile(r"^/u/[\w-]{3,20}$", re.ASCII))
+ }
+
+
+async def ws_listener_impl(
+ event_stream_factory: Callable[[], AsyncContextManager[EventStream]],
+ event_stream_catch: Iterable[Type[BaseException]],
+ state: WsListenerState,
+ event_tx: MemorySendChannel, # message type `Event'
+ logger: Logger,
+) -> None:
+
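+	# Repeatedly (re)open the event stream, forwarding events to `event_tx` and mirroring connection health into
+	# `state`; another task may cancel `state.scope` to force a reconnect.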
+ while True:
+ state.scope = CancelScope()
+ with state.scope:
+ logger.debug("opening a new connection")
+ try:
+ async with event_stream_factory() as event_stream:
+ first_update = True
+ while True:
+ event = await event_stream.get_event()
+ if isinstance(event, UpdateEvent):
+ update_id = UUID(event.payload_data["id"])
+ state.last_update_ts = update_id.time
+ if first_update:
+ # first update after discontinuity
+ state.down = False
+ state.resync_ts = update_id.time
+ first_update = False
+ elif isinstance(event, DeleteEvent):
+ logger.debug("delete event for " + event.payload)
+ else:
+ raise RuntimeError(f"expected Event, got {event}")
+
+ await event_tx.send(event)
+ except tuple(event_stream_catch) as e:
+ logger.warning(f"WebSockets error: {e}")
+
+ state.down = True
+
+
+async def message_sender_impl(
+ notifier: Notifier,
+ title_provider: ThreadTitleProvider,
+ update_checker: UpdateChecker,
+ ws_state: WsListenerState,
+ thread_id: str,
+ event_rx: MemoryReceiveChannel, # message type `Event'
+ buffer_time: timedelta,
+ mention_limit: int,
+ logger: Logger,
+) -> None:
+
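+	# Updates sit in `queue` for `buffer_time` before notifications go out, so a delete event that arrives in the
+	# meantime can silently drop them (see test_buffering in the tests for the intended behavior).
+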
+ @dataclass
+ class QueueEntry:
+ payload_data: dict
+ mentions: set[str]
+ arrival: float # Trio time
+
+ queue: list[QueueEntry] = []
+
+ while True:
+ logger.debug("queue: {}".format([e.payload_data["id"] for e in queue]))
+ if queue:
+ scope = move_on_at(queue[0].arrival + buffer_time.total_seconds())
+ else:
+ scope = CancelScope() # no timeout
+
+ with scope:
+ event = await event_rx.receive()
+
+ if scope.cancelled_caught:
+ # timed out
+ entry = queue.pop(0)
+ update_id = entry.payload_data["id"]
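+			# Only ask the API whether the update still exists if its deletion could have been missed: the WS
+			# connection is down, or the update predates the point where the connection resynced.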
+ if ws_state.down or ws_state.resync_ts and UUID(update_id).time < ws_state.resync_ts:
+ notify = update_checker.exists(thread_id, update_id)
+ else:
+ notify = True
+
+ if notify:
+ thread_title = title_provider.title_for(thread_id)
+ for recipient in entry.mentions:
+ notifier.notify(NotifyInfo(
+ recipient, thread_id, thread_title, update_id, entry.payload_data["author"],
+ entry.payload_data["body"]
+ ))
+ else:
+ # got an event
+ if isinstance(event, UpdateEvent):
+ update_id = event.payload_data["id"]
+ if not any(entry.payload_data["id"] == update_id for entry in queue):
+ mentions = _parse_mentions(BeautifulSoup(event.payload_data["body_html"], "html.parser"))
+ if mentions:
+ if len(mentions) > mention_limit:
+ logger.info(f"ignoring {update_id} due to too many mentions ({len(mentions)})")
+ else:
+ queue.append(QueueEntry(event.payload_data, mentions, trio.current_time()))
+ elif isinstance(event, DeleteEvent):
+ logger.debug("delete event for " + event.payload)
+ queue = [entry for entry in queue if entry.payload_data["name"] != event.payload]
+ else:
+ raise RuntimeError(f"invalid event type {type(event)}")
--- /dev/null
+"""
+Mentionbot: Notify users by private message when they're /u/-mentioned in a live thread.
+"""
+
+from argparse import ArgumentParser
+from configparser import ConfigParser
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from datetime import timedelta
+from logging import Formatter, getLogger, Logger, NOTSET, StreamHandler
+from sys import stdout
+from typing import ClassVar, Iterator, Optional, Type
+from urllib.error import HTTPError
+from urllib.parse import urlencode
+from urllib.request import Request, urlopen
+from uuid import UUID
+import json
+import logging
+
+from trio import MemorySendChannel, open_memory_channel, open_nursery
+from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
+import psycopg2
+import trio
+
+from . import TokenRef, message_sender_impl, try_request, ws_listener_impl, WsListenerState
+from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
+
+
+async def _token_update_impl(period: timedelta, db_conn, auth_id: int, task_status) -> None:
+ def fetch_token() -> str:
+ cursor = db_conn.cursor()
+ cursor.execute(
+ """
+ select access_token from public.reddit_app_authorization
+ where id = %s
+ """,
+ (auth_id,)
+ )
+ [(token,)] = cursor.fetchall()
+ return token
+
+ ref = TokenRef(fetch_token())
+ task_status.started(ref)
+ while True:
+ await trio.sleep(period.total_seconds())
+ ref.token = fetch_token()
+
+
+async def _api_poll_impl(
+ ws_state: WsListenerState,
+ token_ref: TokenRef,
+ thread_id: str,
+ event_tx: MemorySendChannel, # message type `Event'
+ poll_period: timedelta,
+ max_ws_delay: timedelta,
+ logger: Logger,
+) -> None:
+
+ base_url = f"https://oauth.reddit.com/live/{thread_id}"
+
+ latest_payload: Optional[dict] = None
+ while True:
+ # fetch updates and emit events
+ params = {
+ "raw_json": "1",
+ }
+ if latest_payload:
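+			# Page forward through everything newer than the last update seen, using the newest update's fullname as
+			# Reddit's `before` cursor.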
+ updates = []
+ while True:
+ params["before"] = latest_payload["data"]["name"]
+ url = base_url + "?" + urlencode(params)
+ resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger)
+ if resp:
+ with resp:
+ children = json.load(resp)["data"]["children"]
+ if children:
+ updates.extend(reversed(children))
+ latest_payload = children[0]
+ else:
+ break
+ else:
+ break
+ else:
+ params["limit"] = "1"
+ url = base_url + "?" + urlencode(params)
+ resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger)
+ if resp:
+ with resp:
+ updates = json.load(resp)["data"]["children"]
+ if updates:
+ latest_payload = updates[0]
+ else:
+ updates = []
+
+ for update in updates:
+ await event_tx.send(UpdateEvent(update["data"]))
+
+ # check if WS connection is behind
+ if not ws_state.down and latest_payload:
+ max_delta = max_ws_delay.total_seconds() * 10 ** 7 # in 100*nanosecond
+ if UUID(latest_payload["data"]["id"]).time - ws_state.last_update_ts > max_delta:
+ ws_state.scope.cancel()
+
+ await trio.sleep(poll_period.total_seconds())
+
+
+def _build_notification_md(info: NotifyInfo) -> str:
+ # TODO escape thread title?
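+	# Renders (with hypothetical values) roughly as:
+	#   from [alice](/user/alice) in [Some thread](/live/abc123)
+	#
+	#   >each line of the update body, quoted
+	#
+	#   [^^permalink](...) [^^context](...) [^^about ^^these ^^notifications](/r/live_mentions/wiki)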
+ return "\n\n".join([
+ f"from [{info.author}](/user/{info.author}) in [{info.thread_title}](/live/{info.thread_id})",
+ "\n".join(">" + line for line in info.update_body.splitlines()),
+ (
+ f"[^^permalink](/live/{info.thread_id}/updates/{info.update_id})"
+ + f" [^^context](/live/{info.thread_id}?after=LiveUpdate_{info.update_id})"
+ + " [^^about ^^these ^^notifications](/r/live_mentions/wiki)"
+ ),
+ ])
+
+
+@dataclass
+class RedditNotifier(Notifier):
+ token_ref: TokenRef
+ logger: Logger
+
+ def notify(self, info: NotifyInfo) -> None:
+ params = {
+ "raw_json": "1",
+ "api_type": "json",
+ "subject": "username mention",
+ "text": _build_notification_md(info),
+ }
+ req = Request(
+ "https://oauth.reddit.com/api/compose",
+ method = "POST",
+ data = urlencode({**params, "to": info.recipient}).encode("ascii"),
+ headers = self.token_ref.get_common_headers(),
+ )
+ resp = try_request(req, self.logger)
+ if resp:
+ with resp:
+ data = json.load(resp)["json"]
+ if data["errors"]:
+ [(code, _, _)] = data["errors"]
+ if code != "USER_DOESNT_EXIST":
+						raise RuntimeError(f"PM send error: {code}")
+
+
+@dataclass
+class RedditThreadTitleProvider(ThreadTitleProvider):
+ token_ref: TokenRef
+
+ def title_for(self, thread_id: str) -> str:
+ req = Request(
+ f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1",
+ headers = self.token_ref.get_common_headers()
+ )
+ with urlopen(req) as resp:
+ return json.load(resp)["data"]["title"]
+
+
+@dataclass
+class RedditUpdateChecker(UpdateChecker):
+ token_ref: TokenRef
+ logger: Logger
+
+ def exists(self, thread_id: str, update_id: str) -> bool:
+ req = Request(
+ f"https://oauth.reddit.com/live/{thread_id}/updates/{update_id}?raw_json=1",
+ headers = self.token_ref.get_common_headers(),
+ )
+ try:
+ resp = try_request(req, self.logger, suppress_codes = [429, 500, 503])
+ except HTTPError as e:
+ e.read()
+ e.close()
+ if e.code == 404:
+ return False
+ else:
+ raise
+ else:
+ if resp:
+ resp.read()
+ resp.close()
+ return True
+
+
+@dataclass
+class WsConnectionEventStream(EventStream):
+ EXCEPTION_TYPES: ClassVar[set[Type[BaseException]]] = {HandshakeError, ConnectionClosed}
+
+ ws_conn: WebSocketConnection
+
+ async def get_event(self) -> Event:
+ while True:
+ message = json.loads(await self.ws_conn.get_message())
+ if message["type"] == "update":
+ return UpdateEvent(message["payload"]["data"])
+ elif message["type"] == "delete":
+ return DeleteEvent(message["payload"])
+
+
+async def _main(
+ auth_id: int,
+ buffer_time: timedelta,
+ db_conn,
+ logger: Logger,
+ max_ws_delay: timedelta,
+ mention_limit: int,
+ poll_period: timedelta,
+ thread_id: str,
+ token_period: timedelta,
+) -> None:
+
+ @asynccontextmanager
+ async def event_stream_factory() -> Iterator[WsConnectionEventStream]:
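+		# Look up the live thread's current websocket_url via the REST API, then hold the WS connection open while the
+		# caller uses the stream.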
+ resp = urlopen(Request(
+ f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1",
+ headers = token_ref.get_common_headers(),
+ ))
+ with resp:
+ ws_url = json.load(resp)["data"]["websocket_url"]
+
+ async with open_websocket_url(ws_url) as ws_conn:
+ yield WsConnectionEventStream(ws_conn)
+
+ async with open_nursery() as nursery:
+ token_ref = await nursery.start(_token_update_impl, token_period, db_conn, auth_id)
+
+ ws_state = WsListenerState()
+ (event_tx, event_rx) = open_memory_channel(0)
+ nursery.start_soon(
+ ws_listener_impl,
+ event_stream_factory, WsConnectionEventStream.EXCEPTION_TYPES, ws_state, event_tx, logger.getChild("WS")
+ )
+ nursery.start_soon(
+ _api_poll_impl,
+ ws_state, token_ref, thread_id, event_tx, poll_period, max_ws_delay, logger.getChild("poll")
+ )
+
+ notifier = RedditNotifier(token_ref, logger)
+ title_provider = RedditThreadTitleProvider(token_ref)
+ deletion_checker = RedditUpdateChecker(token_ref, logger)
+ nursery.start_soon(
+ message_sender_impl,
+ notifier, title_provider, deletion_checker, ws_state, thread_id, event_rx, buffer_time, mention_limit,
+ logger.getChild("sender")
+ )
+
+
+if __name__ == "__main__":
+ arg_parser = ArgumentParser(__package__)
+ arg_parser.add_argument("config_path")
+ args = arg_parser.parse_args()
+
+ config_parser = ConfigParser()
+ with open(args.config_path) as config_file:
+ config_parser.read_file(config_file)
+
+ main_cfg = config_parser["config"]
+
+ auth_id = main_cfg.getint("auth ID")
+ assert auth_id is not None
+
+ token_period = timedelta(minutes = main_cfg.getint("token update period"))
+
+ thread_id = main_cfg["thread ID"]
+
+ mention_limit = main_cfg.getint("mention limit")
+ assert mention_limit is not None and mention_limit >= 0
+
+ buffer_time = timedelta(seconds = main_cfg.getfloat("message buffer time"))
+
+ poll_period = timedelta(seconds = main_cfg.getfloat("API poll period"))
+
+ max_ws_delay = timedelta(seconds = main_cfg.getfloat("max WS delay"))
+
+ db_conn = psycopg2.connect(**config_parser["db connect params"])
+ db_conn.autocommit = True
+
+ logger = getLogger(__package__)
+ logger.setLevel(NOTSET - 1) # filter in handler(s) instead
+
+ handler = StreamHandler(stdout)
+ handler.setLevel(logging.DEBUG)
+ handler.setFormatter(Formatter("{asctime:23}: {name:17}: {levelname:8}: {message}", style = "{"))
+ logger.addHandler(handler)
+
+ trio.run(
+ _main,
+ auth_id, buffer_time, db_conn, logger, max_ws_delay, mention_limit, poll_period, thread_id, token_period
+ )
--- /dev/null
+from abc import ABCMeta, abstractmethod
+from dataclasses import dataclass
+
+
+class Event(metaclass = ABCMeta):
+ pass
+
+
+@dataclass
+class UpdateEvent(Event):
+ payload_data: dict # object at "payload" -> "data" in the message JSON
+
+
+@dataclass
+class DeleteEvent(Event):
+ payload: str # update name
+
+
+class EventStream(metaclass = ABCMeta):
+ @abstractmethod
+ async def get_event(self) -> Event:
+ raise NotImplementedError()
+
+
+@dataclass
+class NotifyInfo:
+ recipient: str
+ thread_id: str
+ thread_title: str
+ update_id: str
+ author: str
+ update_body: str
+
+
+class Notifier(metaclass = ABCMeta):
+ @abstractmethod
+ def notify(self, info: NotifyInfo) -> None:
+ raise NotImplementedError()
+
+
+class ThreadTitleProvider(metaclass = ABCMeta):
+ @abstractmethod
+ def title_for(self, thread_id: str) -> str:
+ raise NotImplementedError()
+
+
+class UpdateChecker(metaclass = ABCMeta):
+ @abstractmethod
+ def exists(self, thread_id: str, update_id: str) -> bool:
+ raise NotImplementedError()
--- /dev/null
+from contextlib import asynccontextmanager, contextmanager
+from dataclasses import dataclass, field
+from datetime import timedelta
+from logging import getLogger
+from numbers import Real
+from os import environ
+from typing import Iterator, TypeVar
+from unittest import TestCase
+from urllib.request import Request
+from uuid import UUID, uuid1
+import logging
+
+from trio import CancelScope, MemoryReceiveChannel, move_on_after, open_memory_channel, open_nursery, sleep_until
+from trio.testing import MockClock, wait_all_tasks_blocked
+import trio
+
+from . import message_sender_impl, TokenRef, try_request, WsListenerState, ws_listener_impl
+from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
+
+
+T = TypeVar("T")
+
+
+@dataclass
+class MockNotifier(Notifier):
+ notifies: list[NotifyInfo] = field(default_factory = list)
+
+ def notify(self, info: NotifyInfo) -> None:
+ self.notifies.append(info)
+
+
+class MockTitleProvider(ThreadTitleProvider):
+ def title_for(self, thread_id: str) -> str:
+ return thread_id
+
+
+class UnusedUpdateChecker(UpdateChecker):
+ def exists(self, thread_id: str, update_id: str) -> bool:
+ raise AssertionError("unexpected deletion check")
+
+
+@dataclass
+class DictBackedUpdateChecker(UpdateChecker):
+ exists_by_id: dict[str, bool] = field(default_factory = dict)
+ queries: list[tuple[str, str]] = field(default_factory = list)
+
+ def exists(self, thread_id: str, update_id: str) -> bool:
+ self.queries.append((thread_id, update_id))
+ return self.exists_by_id[update_id]
+
+
+@dataclass
+class ChannelBackedEventStream(EventStream):
+ event_rx: MemoryReceiveChannel
+
+ async def get_event(self) -> Event:
+ obj = await self.event_rx.receive()
+ if isinstance(obj, BaseException):
+ raise obj
+ else:
+ return obj
+
+
+@asynccontextmanager
+async def null_async_context(val: T) -> Iterator[T]:
+ yield val
+
+
+def _update_event(author: str, mentions: list[str]) -> UpdateEvent:
+ id_ = uuid1()
+ body = str(id_) + " " + " ".join("/u/" + user for user in mentions)
+ body_html = str(id_) + " " + " ".join(
+ f'<a href="/u/{user}">/u/{user}</a>'
+ for user in mentions
+ )
+ return UpdateEvent(
+ {
+ "id": str(id_),
+ "name": "LiveUpdate_" + str(id_),
+ "author": author,
+ "body": body,
+ "body_html": body_html,
+ }
+ )
+
+
+def setUpModule() -> None:
+ getLogger().setLevel(logging.CRITICAL + 1) # disable logging
+
+
+class TaskTests(TestCase):
+ @contextmanager
+ def _assert_completes_in(self, duration_seconds: Real) -> Iterator[None]:
+ with move_on_after(duration_seconds) as scope:
+ yield
+ if scope.cancelled_caught:
+ self.fail("operation timed out")
+
+ def test_buffering(self) -> None:
+ """
+		There is no API poll task: a custom task stands in for the WS listener by replaying events over a channel, and
+		a real message sender task is wired to mocks to capture output. The WS state is hard-coded as up, so recovery
+		from WS connection problems is not exercised here.
+
+		Primarily this makes sure that updates pass through the buffer and are processed if and only if no delete event
+		arrives while they're buffered.
+ """
+
+ buffer_time = timedelta(seconds = 1)
+ notifier = MockNotifier()
+ (event_tx, event_rx) = open_memory_channel(0)
+
+ ign_event = _update_event("sender", [])
+ event_a = _update_event("sender", ["rec1", "rec2"])
+ event_b = _update_event("sender", ["rec1", "rec2"])
+
+ async def event_stream_impl(scope: CancelScope) -> None:
+ await sleep_until(1.0)
+ await event_tx.send(ign_event)
+ await event_tx.send(event_a)
+
+ await sleep_until(1.5)
+ self.assertFalse(notifier.notifies)
+			await event_tx.send(DeleteEvent(event_a.payload_data["name"]))
+
+ await sleep_until(2.0)
+ await event_tx.send(event_b)
+
+ # let things settle and terminate both tasks
+ await sleep_until(4.0)
+ scope.cancel()
+
+ async def main() -> None:
+ async with open_nursery() as nursery:
+ nursery.start_soon(event_stream_impl, nursery.cancel_scope)
+ nursery.start_soon(
+ message_sender_impl,
+ notifier, MockTitleProvider(), UnusedUpdateChecker(), WsListenerState(down = False), "ThreadId",
+					event_rx, buffer_time, 10, getLogger()  # 10 = mention limit, high enough not to filter these events
+ )
+
+ trio.run(main, clock = MockClock(autojump_threshold = 0))
+
+ # only notifications sent are for event B
+ self.assertEqual(len(notifier.notifies), 2)
+ self.assertEqual({i.recipient for i in notifier.notifies}, {"rec1", "rec2"})
+ self.assertTrue(all(i.author == "sender" for i in notifier.notifies))
+ self.assertTrue(all(i.update_id == event_b.payload_data["id"] for i in notifier.notifies))
+
+ def test_ws_trouble(self) -> None:
+ """
+ Ensure that the notifier task responds correctly to the WS connection going down and coming back up.
+ """
+
+ thread_id = "ThreadId"
+ buffer_time = timedelta(seconds = 3)
+ update_checker = DictBackedUpdateChecker()
+ notifier = MockNotifier()
+ (event_tx, event_rx) = open_memory_channel(0)
+
+ # sent while WS is up, checked and notifies sent
+ event_a = _update_event("sender", ["rec"])
+
+ # sent and deleted while WS is down, checked and discarded from buffer
+ event_b = _update_event("sender", ["rec"])
+
+ # sent on WS reconnect, not checked and notifies sent
+ event_c = _update_event("sender", ["rec"])
+
+ # sent after WS reconnect, not checked and notifies sent
+ event_d = _update_event("sender", ["rec"])
+
+ ws_state = WsListenerState(down = False)
+
+ async def ws_simulator_impl() -> None:
+ update_checker.exists_by_id[event_a.payload_data["id"]] = True
+ await event_tx.send(event_a)
+
+ await sleep_until(1.0)
+ ws_state.down = True
+ update_checker.exists_by_id[event_b.payload_data["id"]] = True
+ await event_tx.send(event_b)
+
+ await sleep_until(2.0)
+ update_checker.exists_by_id[event_b.payload_data["id"]] = False
+ ws_state.down = False
+ ws_state.resync_ts = UUID(event_c.payload_data["id"]).time
+ await event_tx.send(event_c)
+
+ await sleep_until(3.0)
+ await event_tx.send(event_d)
+
+ async def main() -> None:
+ async with open_nursery() as nursery:
+ nursery.start_soon(ws_simulator_impl)
+ nursery.start_soon(
+ message_sender_impl,
+ notifier, MockTitleProvider(), update_checker, ws_state, "ThreadId", event_rx, buffer_time,
+					10, getLogger()  # 10 = mention limit, high enough not to filter these events
+ )
+
+ await sleep_until(3.5)
+ self.assertEqual([i.update_id for i in notifier.notifies], [event_a.payload_data["id"]])
+ notifier.notifies.clear()
+ self.assertEqual(update_checker.queries, [(thread_id, event_a.payload_data["id"])])
+ update_checker.queries.clear()
+
+ await sleep_until(4.5)
+ self.assertFalse(notifier.notifies)
+ self.assertEqual(update_checker.queries, [(thread_id, event_b.payload_data["id"])])
+ update_checker.queries.clear()
+
+ await sleep_until(5.5)
+ self.assertEqual([i.update_id for i in notifier.notifies], [event_c.payload_data["id"]])
+ notifier.notifies.clear()
+ self.assertFalse(update_checker.queries)
+
+ await sleep_until(6.5)
+ self.assertEqual([i.update_id for i in notifier.notifies], [event_d.payload_data["id"]])
+ notifier.notifies.clear()
+ self.assertFalse(update_checker.queries)
+
+ nursery.cancel_scope.cancel()
+
+ trio.run(main, clock = MockClock(autojump_threshold = 0))
+
+ def test_ws_listener_cancel(self) -> None:
+ ws_state = WsListenerState()
+
+ (event_in_tx, event_in_rx) = open_memory_channel(0)
+ event_stream = ChannelBackedEventStream(event_in_rx)
+ event_stream_factory = lambda: null_async_context(event_stream)
+
+ (event_out_tx, event_out_rx) = open_memory_channel(0)
+
+ async def main():
+ async with open_nursery() as nursery:
+ nursery.start_soon(
+ ws_listener_impl,
+ event_stream_factory, set(), ws_state, event_out_tx, getLogger()
+ )
+
+ event_1 = _update_event("author", [])
+ await event_in_tx.send(event_1)
+ with self._assert_completes_in(1):
+ event_out = await event_out_rx.receive()
+ self.assertIs(event_out, event_1)
+
+ event_2 = _update_event("author", ["recipient"])
+ await event_in_tx.send(event_2)
+ with self._assert_completes_in(1):
+ event_out = await event_out_rx.receive()
+ self.assertIs(event_out, event_2)
+
+ initial_scope = ws_state.scope
+ initial_scope.cancel()
+ await wait_all_tasks_blocked()
+ self.assertIsNotNone(ws_state.scope)
+ self.assertIsNot(ws_state.scope, initial_scope)
+ self.assertTrue(ws_state.down)
+
+ event_3 = _update_event("author", [])
+ ts = UUID(event_3.payload_data["id"]).time
+ await event_in_tx.send(event_3)
+ with self._assert_completes_in(1):
+ event_out = await event_out_rx.receive()
+ self.assertIs(event_out, event_3)
+ self.assertEqual(ws_state.last_update_ts, ts)
+ self.assertFalse(ws_state.down)
+ self.assertEqual(ws_state.resync_ts, ts)
+
+ nursery.cancel_scope.cancel()
+
+ trio.run(main, clock = MockClock(autojump_threshold = 0))
+
+ def test_ws_listener_error(self) -> None:
+ """Test that the WS listener responds correctly to errors from the input event stream."""
+
+ class MockStreamException(Exception):
+ pass
+
+ ws_state = WsListenerState()
+
+ (event_in_tx, event_in_rx) = open_memory_channel(0)
+ event_stream = ChannelBackedEventStream(event_in_rx)
+ event_stream_factory = lambda: null_async_context(event_stream)
+
+ (event_out_tx, event_out_rx) = open_memory_channel(0)
+
+ async def main():
+ async with open_nursery() as nursery:
+ nursery.start_soon(
+ ws_listener_impl,
+ event_stream_factory, {MockStreamException}, ws_state, event_out_tx, getLogger()
+ )
+
+ event_1 = _update_event("author", [])
+ await event_in_tx.send(event_1)
+ with self._assert_completes_in(1):
+ event_out = await event_out_rx.receive()
+ self.assertIs(event_out, event_1)
+
+ initial_scope = ws_state.scope
+ await event_in_tx.send(MockStreamException())
+ await wait_all_tasks_blocked()
+ self.assertIsNotNone(ws_state.scope)
+ self.assertIsNot(ws_state.scope, initial_scope)
+ self.assertTrue(ws_state.down)
+
+ event_2 = _update_event("author", ["recipient"])
+ ts = UUID(event_2.payload_data["id"]).time
+ await event_in_tx.send(event_2)
+ with self._assert_completes_in(1):
+ event_out = await event_out_rx.receive()
+ self.assertIs(event_out, event_2)
+ self.assertEqual(ws_state.last_update_ts, ts)
+ self.assertFalse(ws_state.down)
+ self.assertEqual(ws_state.resync_ts, ts)
+
+ nursery.cancel_scope.cancel()
+
+ trio.run(main, clock = MockClock(autojump_threshold = 0))
+
+
+class RedditTests(TestCase):
+ """Tests for interactions with Reddit."""
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ cls._TOKEN_REF = TokenRef(environ["API_TOKEN"])
+
+ def test_try_request(self) -> None:
+ self.assertIsNotNone(
+ try_request(
+ Request(
+ "https://oauth.reddit.com/api/live/happening_now",
+ headers = self._TOKEN_REF.get_common_headers(),
+ ),
+ getLogger(),
+ )
+ )
+
+ def test_try_request_http_error(self) -> None:
+ self.assertIsNone(
+ try_request(
+ Request("https://oauth.reddit.com/api/live/happening_now"),
+ getLogger(),
+ suppress_codes = [403],
+ )
+ )
+++ /dev/null
-[build-system]
-requires = ["setuptools>=42", "wheel"]
-build-backend = "setuptools.build_meta"
+++ /dev/null
-[metadata]
-name = mentionbot
-version = 0.0.0
-
-[options]
-package_dir =
- = src
-packages = mentionbot
-python_requires = ~= 3.9
-install_requires =
- beautifulsoup4 ~= 4.9
- #psycopg2 ~= 2.8
- psycopg2-binary ~= 2.8
- trio == 0.19
- trio-websocket == 0.9.2
+++ /dev/null
-import setuptools
-
-
-setuptools.setup()
+++ /dev/null
-from dataclasses import dataclass
-from datetime import timedelta
-from http.client import HTTPResponse
-from logging import Logger
-from socket import EAI_AGAIN, EAI_FAIL, gaierror
-from typing import AsyncContextManager, Callable, Iterable, Optional, Type
-from urllib.error import HTTPError
-from urllib.request import Request, urlopen
-from uuid import UUID
-import importlib.metadata
-import re
-
-from bs4 import BeautifulSoup
-from trio import CancelScope, MemoryReceiveChannel, MemorySendChannel, move_on_at
-import trio
-
-from .abc import DeleteEvent, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
-
-
-_USER_AGENT = "any:net.jcornell.mentionbot:v{} (by /u/jaxklax)".format(
- importlib.metadata.version(__package__)
-)
-
-
-@dataclass
-class TokenRef:
- """Mutable container for an access token."""
-
- token: str
-
- def get_common_headers(self) -> dict[str, str]:
- return {
- "User-Agent": _USER_AGENT,
- "Authorization": "Bearer {}".format(self.token)
- }
-
-
-def try_request(
- request: Request,
- logger: Logger,
- suppress_codes: list[int] = [404, 429, 500, 503],
-) -> Optional[HTTPResponse]:
-
- try:
- return urlopen(request)
- except gaierror as e:
- if e.errno in [EAI_FAIL, EAI_AGAIN]:
- logger.error(f"DNS failure: {e}")
- return None
- else:
- raise
- except HTTPError as e:
- e.read()
- e.close()
- if e.code in suppress_codes:
- logger.error(f"HTTP {e.code}")
- return None
- else:
- raise
-
-
-@dataclass
-class WsListenerState:
- """The WebSockets listener uses this object to communicate with the API poll and message sender tasks."""
-
- # The API poll task uses this scope to force a refresh of the WS connection if it's been too quiet.
- scope: Optional[CancelScope] = None
-
- # The WS task uses this to track the UUID timestamp of the most recent update message update, so the API poll task
- # can detect when the connection goes silent.
- last_update_ts: Optional[int] = None
-
- # The WS task sets this when the WS connection goes down and clears it when it's recovered.
- down: bool = True
-
- # The WS task sets this to the UUID timestamp of the first update recieved on connection recovery so the message
- # sender task knows when to stop checking for deletion.
- resync_ts: Optional[int] = None
-
-
-def _parse_mentions(body: BeautifulSoup) -> set[str]:
- """Extract the names of users tagged in an update."""
- return {
- el["href"].removeprefix("/u/")
- for el in body.find_all("a", href = re.compile(r"^/u/[\w-]{3,20}$", re.ASCII))
- }
-
-
-async def ws_listener_impl(
- event_stream_factory: Callable[[], AsyncContextManager[EventStream]],
- event_stream_catch: Iterable[Type[BaseException]],
- state: WsListenerState,
- event_tx: MemorySendChannel, # message type `Event'
- logger: Logger,
-) -> None:
-
- while True:
- state.scope = CancelScope()
- with state.scope:
- logger.debug("opening a new connection")
- try:
- async with event_stream_factory() as event_stream:
- first_update = True
- while True:
- event = await event_stream.get_event()
- if isinstance(event, UpdateEvent):
- update_id = UUID(event.payload_data["id"])
- state.last_update_ts = update_id.time
- if first_update:
- # first update after discontinuity
- state.down = False
- state.resync_ts = update_id.time
- first_update = False
- elif isinstance(event, DeleteEvent):
- logger.debug("delete event for " + event.payload)
- else:
- raise RuntimeError(f"expected Event, got {event}")
-
- await event_tx.send(event)
- except tuple(event_stream_catch) as e:
- logger.warning(f"WebSockets error: {e}")
-
- state.down = True
-
-
-async def message_sender_impl(
- notifier: Notifier,
- title_provider: ThreadTitleProvider,
- update_checker: UpdateChecker,
- ws_state: WsListenerState,
- thread_id: str,
- event_rx: MemoryReceiveChannel, # message type `Event'
- buffer_time: timedelta,
- mention_limit: int,
- logger: Logger,
-) -> None:
-
- @dataclass
- class QueueEntry:
- payload_data: dict
- mentions: set[str]
- arrival: float # Trio time
-
- queue: list[QueueEntry] = []
-
- while True:
- logger.debug("queue: {}".format([e.payload_data["id"] for e in queue]))
- if queue:
- scope = move_on_at(queue[0].arrival + buffer_time.total_seconds())
- else:
- scope = CancelScope() # no timeout
-
- with scope:
- event = await event_rx.receive()
-
- if scope.cancelled_caught:
- # timed out
- entry = queue.pop(0)
- update_id = entry.payload_data["id"]
- if ws_state.down or ws_state.resync_ts and UUID(update_id).time < ws_state.resync_ts:
- notify = update_checker.exists(thread_id, update_id)
- else:
- notify = True
-
- if notify:
- thread_title = title_provider.title_for(thread_id)
- for recipient in entry.mentions:
- notifier.notify(NotifyInfo(
- recipient, thread_id, thread_title, update_id, entry.payload_data["author"],
- entry.payload_data["body"]
- ))
- else:
- # got an event
- if isinstance(event, UpdateEvent):
- update_id = event.payload_data["id"]
- if not any(entry.payload_data["id"] == update_id for entry in queue):
- mentions = _parse_mentions(BeautifulSoup(event.payload_data["body_html"], "html.parser"))
- if mentions:
- if len(mentions) > mention_limit:
- logger.info(f"ignoring {update_id} due to too many mentions ({len(mentions)})")
- else:
- queue.append(QueueEntry(event.payload_data, mentions, trio.current_time()))
- elif isinstance(event, DeleteEvent):
- logger.debug("delete event for " + event.payload)
- queue = [entry for entry in queue if entry.payload_data["name"] != event.payload]
- else:
- raise RuntimeError(f"invalid event type {type(event)}")
+++ /dev/null
-"""
-Mentionbot: Notify users by private message when they're /u/-mentioned in a live thread.
-"""
-
-from argparse import ArgumentParser
-from configparser import ConfigParser
-from contextlib import asynccontextmanager
-from dataclasses import dataclass
-from datetime import timedelta
-from logging import Formatter, getLogger, Logger, NOTSET, StreamHandler
-from sys import stdout
-from typing import ClassVar, Iterator, Optional, Type
-from urllib.error import HTTPError
-from urllib.parse import urlencode
-from urllib.request import Request, urlopen
-from uuid import UUID
-import json
-import logging
-
-from trio import MemorySendChannel, open_memory_channel, open_nursery
-from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
-import psycopg2
-import trio
-
-from . import TokenRef, message_sender_impl, try_request, ws_listener_impl, WsListenerState
-from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
-
-
-async def _token_update_impl(period: timedelta, db_conn, auth_id: int, task_status) -> None:
- def fetch_token() -> str:
- cursor = db_conn.cursor()
- cursor.execute(
- """
- select access_token from public.reddit_app_authorization
- where id = %s
- """,
- (auth_id,)
- )
- [(token,)] = cursor.fetchall()
- return token
-
- ref = TokenRef(fetch_token())
- task_status.started(ref)
- while True:
- await trio.sleep(period.total_seconds())
- ref.token = fetch_token()
-
-
-async def _api_poll_impl(
- ws_state: WsListenerState,
- token_ref: TokenRef,
- thread_id: str,
- event_tx: MemorySendChannel, # message type `Event'
- poll_period: timedelta,
- max_ws_delay: timedelta,
- logger: Logger,
-) -> None:
-
- base_url = f"https://oauth.reddit.com/live/{thread_id}"
-
- latest_payload: Optional[dict] = None
- while True:
- # fetch updates and emit events
- params = {
- "raw_json": "1",
- }
- if latest_payload:
- updates = []
- while True:
- params["before"] = latest_payload["data"]["name"]
- url = base_url + "?" + urlencode(params)
- resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger)
- if resp:
- with resp:
- children = json.load(resp)["data"]["children"]
- if children:
- updates.extend(reversed(children))
- latest_payload = children[0]
- else:
- break
- else:
- break
- else:
- params["limit"] = "1"
- url = base_url + "?" + urlencode(params)
- resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger)
- if resp:
- with resp:
- updates = json.load(resp)["data"]["children"]
- if updates:
- latest_payload = updates[0]
- else:
- updates = []
-
- for update in updates:
- await event_tx.send(UpdateEvent(update["data"]))
-
- # check if WS connection is behind
- if not ws_state.down and latest_payload:
- max_delta = max_ws_delay.total_seconds() * 10 ** 7 # in 100*nanosecond
- if UUID(latest_payload["data"]["id"]).time - ws_state.last_update_ts > max_delta:
- ws_state.scope.cancel()
-
- await trio.sleep(poll_period.total_seconds())
-
-
-def _build_notification_md(info: NotifyInfo) -> str:
- # TODO escape thread title?
- return "\n\n".join([
- f"from [{info.author}](/user/{info.author}) in [{info.thread_title}](/live/{info.thread_id})",
- "\n".join(">" + line for line in info.update_body.splitlines()),
- (
- f"[^^permalink](/live/{info.thread_id}/updates/{info.update_id})"
- + f" [^^context](/live/{info.thread_id}?after=LiveUpdate_{info.update_id})"
- + " [^^about ^^these ^^notifications](/r/live_mentions/wiki)"
- ),
- ])
-
-
-@dataclass
-class RedditNotifier(Notifier):
- token_ref: TokenRef
- logger: Logger
-
- def notify(self, info: NotifyInfo) -> None:
- params = {
- "raw_json": "1",
- "api_type": "json",
- "subject": "username mention",
- "text": _build_notification_md(info),
- }
- req = Request(
- "https://oauth.reddit.com/api/compose",
- method = "POST",
- data = urlencode({**params, "to": info.recipient}).encode("ascii"),
- headers = self.token_ref.get_common_headers(),
- )
- resp = try_request(req, self.logger)
- if resp:
- with resp:
- data = json.load(resp)["json"]
- if data["errors"]:
- [(code, _, _)] = data["errors"]
- if code != "USER_DOESNT_EXIST":
- raise RuntimeError("PM send error: {code}")
-
-
-@dataclass
-class RedditThreadTitleProvider(ThreadTitleProvider):
- token_ref: TokenRef
-
- def title_for(self, thread_id: str) -> str:
- req = Request(
- f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1",
- headers = self.token_ref.get_common_headers()
- )
- with urlopen(req) as resp:
- return json.load(resp)["data"]["title"]
-
-
-@dataclass
-class RedditUpdateChecker(UpdateChecker):
- token_ref: TokenRef
- logger: Logger
-
- def exists(self, thread_id: str, update_id: str) -> bool:
- req = Request(
- f"https://oauth.reddit.com/live/{thread_id}/updates/{update_id}?raw_json=1",
- headers = self.token_ref.get_common_headers(),
- )
- try:
- resp = try_request(req, self.logger, suppress_codes = [429, 500, 503])
- except HTTPError as e:
- e.read()
- e.close()
- if e.code == 404:
- return False
- else:
- raise
- else:
- if resp:
- resp.read()
- resp.close()
- return True
-
-
-@dataclass
-class WsConnectionEventStream(EventStream):
- EXCEPTION_TYPES: ClassVar[set[Type[BaseException]]] = {HandshakeError, ConnectionClosed}
-
- ws_conn: WebSocketConnection
-
- async def get_event(self) -> Event:
- while True:
- message = json.loads(await self.ws_conn.get_message())
- if message["type"] == "update":
- return UpdateEvent(message["payload"]["data"])
- elif message["type"] == "delete":
- return DeleteEvent(message["payload"])
-
-
-async def _main(
- auth_id: int,
- buffer_time: timedelta,
- db_conn,
- logger: Logger,
- max_ws_delay: timedelta,
- mention_limit: int,
- poll_period: timedelta,
- thread_id: str,
- token_period: timedelta,
-) -> None:
-
- @asynccontextmanager
- async def event_stream_factory() -> Iterator[WsConnectionEventStream]:
- resp = urlopen(Request(
- f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1",
- headers = token_ref.get_common_headers(),
- ))
- with resp:
- ws_url = json.load(resp)["data"]["websocket_url"]
-
- async with open_websocket_url(ws_url) as ws_conn:
- yield WsConnectionEventStream(ws_conn)
-
- async with open_nursery() as nursery:
- token_ref = await nursery.start(_token_update_impl, token_period, db_conn, auth_id)
-
- ws_state = WsListenerState()
- (event_tx, event_rx) = open_memory_channel(0)
- nursery.start_soon(
- ws_listener_impl,
- event_stream_factory, WsConnectionEventStream.EXCEPTION_TYPES, ws_state, event_tx, logger.getChild("WS")
- )
- nursery.start_soon(
- _api_poll_impl,
- ws_state, token_ref, thread_id, event_tx, poll_period, max_ws_delay, logger.getChild("poll")
- )
-
- notifier = RedditNotifier(token_ref, logger)
- title_provider = RedditThreadTitleProvider(token_ref)
- deletion_checker = RedditUpdateChecker(token_ref, logger)
- nursery.start_soon(
- message_sender_impl,
- notifier, title_provider, deletion_checker, ws_state, thread_id, event_rx, buffer_time, mention_limit,
- logger.getChild("sender")
- )
-
-
-if __name__ == "__main__":
- arg_parser = ArgumentParser(__package__)
- arg_parser.add_argument("config_path")
- args = arg_parser.parse_args()
-
- config_parser = ConfigParser()
- with open(args.config_path) as config_file:
- config_parser.read_file(config_file)
-
- main_cfg = config_parser["config"]
-
- auth_id = main_cfg.getint("auth ID")
- assert auth_id is not None
-
- token_period = timedelta(minutes = main_cfg.getint("token update period"))
-
- thread_id = main_cfg["thread ID"]
-
- mention_limit = main_cfg.getint("mention limit")
- assert mention_limit is not None and mention_limit >= 0
-
- buffer_time = timedelta(seconds = main_cfg.getfloat("message buffer time"))
-
- poll_period = timedelta(seconds = main_cfg.getfloat("API poll period"))
-
- max_ws_delay = timedelta(seconds = main_cfg.getfloat("max WS delay"))
-
- db_conn = psycopg2.connect(**config_parser["db connect params"])
- db_conn.autocommit = True
-
- logger = getLogger(__package__)
- logger.setLevel(NOTSET - 1) # filter in handler(s) instead
-
- handler = StreamHandler(stdout)
- handler.setLevel(logging.DEBUG)
- handler.setFormatter(Formatter("{asctime:23}: {name:17}: {levelname:8}: {message}", style = "{"))
- logger.addHandler(handler)
-
- trio.run(
- _main,
- auth_id, buffer_time, db_conn, logger, max_ws_delay, mention_limit, poll_period, thread_id, token_period
- )
+++ /dev/null
-from abc import ABCMeta, abstractmethod
-from dataclasses import dataclass
-
-
-class Event(metaclass = ABCMeta):
- pass
-
-
-@dataclass
-class UpdateEvent(Event):
- payload_data: dict # object at "payload" -> "data" in the message JSON
-
-
-@dataclass
-class DeleteEvent(Event):
- payload: str # update name
-
-
-class EventStream(metaclass = ABCMeta):
- @abstractmethod
- async def get_event(self) -> Event:
- raise NotImplementedError()
-
-
-@dataclass
-class NotifyInfo:
- recipient: str
- thread_id: str
- thread_title: str
- update_id: str
- author: str
- update_body: str
-
-
-class Notifier(metaclass = ABCMeta):
- @abstractmethod
- def notify(self, info: NotifyInfo) -> None:
- raise NotImplementedError()
-
-
-class ThreadTitleProvider(metaclass = ABCMeta):
- @abstractmethod
- def title_for(self, thread_id: str) -> str:
- raise NotImplementedError()
-
-
-class UpdateChecker(metaclass = ABCMeta):
- @abstractmethod
- def exists(self, thread_id: str, update_id: str) -> bool:
- raise NotImplementedError()
+++ /dev/null
-from contextlib import asynccontextmanager, contextmanager
-from dataclasses import dataclass, field
-from datetime import timedelta
-from logging import getLogger
-from numbers import Real
-from os import environ
-from typing import Iterator, TypeVar
-from unittest import TestCase
-from urllib.request import Request
-from uuid import UUID, uuid1
-import logging
-
-from trio import CancelScope, MemoryReceiveChannel, move_on_after, open_memory_channel, open_nursery, sleep_until
-from trio.testing import MockClock, wait_all_tasks_blocked
-import trio
-
-from . import message_sender_impl, TokenRef, try_request, WsListenerState, ws_listener_impl
-from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
-
-
-T = TypeVar("T")
-
-
-@dataclass
-class MockNotifier(Notifier):
- notifies: list[NotifyInfo] = field(default_factory = list)
-
- def notify(self, info: NotifyInfo) -> None:
- self.notifies.append(info)
-
-
-class MockTitleProvider(ThreadTitleProvider):
- def title_for(self, thread_id: str) -> str:
- return thread_id
-
-
-class UnusedUpdateChecker(UpdateChecker):
- def exists(self, thread_id: str, update_id: str) -> bool:
- raise AssertionError("unexpected deletion check")
-
-
-@dataclass
-class DictBackedUpdateChecker(UpdateChecker):
- exists_by_id: dict[str, bool] = field(default_factory = dict)
- queries: list[tuple[str, str]] = field(default_factory = list)
-
- def exists(self, thread_id: str, update_id: str) -> bool:
- self.queries.append((thread_id, update_id))
- return self.exists_by_id[update_id]
-
-
-@dataclass
-class ChannelBackedEventStream(EventStream):
- event_rx: MemoryReceiveChannel
-
- async def get_event(self) -> Event:
- obj = await self.event_rx.receive()
- if isinstance(obj, BaseException):
- raise obj
- else:
- return obj
-
-
-@asynccontextmanager
-async def null_async_context(val: T) -> Iterator[T]:
- yield val
-
-
-def _update_event(author: str, mentions: list[str]) -> UpdateEvent:
- id_ = uuid1()
- body = str(id_) + " " + " ".join("/u/" + user for user in mentions)
- body_html = str(id_) + " " + " ".join(
- f'<a href="/u/{user}">/u/{user}</a>'
- for user in mentions
- )
- return UpdateEvent(
- {
- "id": str(id_),
- "name": "LiveUpdate_" + str(id_),
- "author": author,
- "body": body,
- "body_html": body_html,
- }
- )
-
-
-def setUpModule() -> None:
- getLogger().setLevel(logging.CRITICAL + 1) # disable logging
-
-
-class TaskTests(TestCase):
- @contextmanager
- def _assert_completes_in(self, duration_seconds: Real) -> Iterator[None]:
- with move_on_after(duration_seconds) as scope:
- yield
- if scope.cancelled_caught:
- self.fail("operation timed out")
-
- def test_buffering(self) -> None:
- """
- No API poll task, just a custom task mocking the WS listener by playing events over a channel, and a real
- notifier task injected with mocks to capture output. The WS state is hard-coded as up, so no testing of recovery
- from WS connection issues is included in this test.
-
- Primarily it makes sure that updates are passed through the buffer and processed if and only if a delete event
- doesn't arrive while they're buffered.
- """
-
- buffer_time = timedelta(seconds = 1)
- notifier = MockNotifier()
- (event_tx, event_rx) = open_memory_channel(0)
-
- ign_event = _update_event("sender", [])
- event_a = _update_event("sender", ["rec1", "rec2"])
- event_b = _update_event("sender", ["rec1", "rec2"])
-
- async def event_stream_impl(scope: CancelScope) -> None:
- await sleep_until(1.0)
- await event_tx.send(ign_event)
- await event_tx.send(event_a)
-
- await sleep_until(1.5)
- self.assertFalse(notifier.notifies)
- await event_tx.send(DeleteEvent(event_a.payload_data["id"]))
-
- await sleep_until(2.0)
- await event_tx.send(event_b)
-
- # let things settle and terminate both tasks
- await sleep_until(4.0)
- scope.cancel()
-
- async def main() -> None:
- async with open_nursery() as nursery:
- nursery.start_soon(event_stream_impl, nursery.cancel_scope)
- nursery.start_soon(
- message_sender_impl,
- notifier, MockTitleProvider(), UnusedUpdateChecker(), WsListenerState(down = False), "ThreadId",
- event_rx, buffer_time, getLogger()
- )
-
- trio.run(main, clock = MockClock(autojump_threshold = 0))
-
- # only notifications sent are for event B
- self.assertEqual(len(notifier.notifies), 2)
- self.assertEqual({i.recipient for i in notifier.notifies}, {"rec1", "rec2"})
- self.assertTrue(all(i.author == "sender" for i in notifier.notifies))
- self.assertTrue(all(i.update_id == event_b.payload_data["id"] for i in notifier.notifies))
-
- def test_ws_trouble(self) -> None:
- """
- Ensure that the notifier task responds correctly to the WS connection going down and coming back up.
- """
-
- thread_id = "ThreadId"
- buffer_time = timedelta(seconds = 3)
- update_checker = DictBackedUpdateChecker()
- notifier = MockNotifier()
- (event_tx, event_rx) = open_memory_channel(0)
-
- # sent while WS is up, checked and notifies sent
- event_a = _update_event("sender", ["rec"])
-
- # sent and deleted while WS is down, checked and discarded from buffer
- event_b = _update_event("sender", ["rec"])
-
- # sent on WS reconnect, not checked and notifies sent
- event_c = _update_event("sender", ["rec"])
-
- # sent after WS reconnect, not checked and notifies sent
- event_d = _update_event("sender", ["rec"])
-
- ws_state = WsListenerState(down = False)
-
- async def ws_simulator_impl() -> None:
- update_checker.exists_by_id[event_a.payload_data["id"]] = True
- await event_tx.send(event_a)
-
- await sleep_until(1.0)
- ws_state.down = True
- update_checker.exists_by_id[event_b.payload_data["id"]] = True
- await event_tx.send(event_b)
-
- await sleep_until(2.0)
- update_checker.exists_by_id[event_b.payload_data["id"]] = False
- ws_state.down = False
- ws_state.resync_ts = UUID(event_c.payload_data["id"]).time
- await event_tx.send(event_c)
-
- await sleep_until(3.0)
- await event_tx.send(event_d)
-
- async def main() -> None:
- async with open_nursery() as nursery:
- nursery.start_soon(ws_simulator_impl)
- nursery.start_soon(
- message_sender_impl,
- notifier, MockTitleProvider(), update_checker, ws_state, "ThreadId", event_rx, buffer_time,
- getLogger()
- )
-
- await sleep_until(3.5)
- self.assertEqual([i.update_id for i in notifier.notifies], [event_a.payload_data["id"]])
- notifier.notifies.clear()
- self.assertEqual(update_checker.queries, [(thread_id, event_a.payload_data["id"])])
- update_checker.queries.clear()
-
- await sleep_until(4.5)
- self.assertFalse(notifier.notifies)
- self.assertEqual(update_checker.queries, [(thread_id, event_b.payload_data["id"])])
- update_checker.queries.clear()
-
- await sleep_until(5.5)
- self.assertEqual([i.update_id for i in notifier.notifies], [event_c.payload_data["id"]])
- notifier.notifies.clear()
- self.assertFalse(update_checker.queries)
-
- await sleep_until(6.5)
- self.assertEqual([i.update_id for i in notifier.notifies], [event_d.payload_data["id"]])
- notifier.notifies.clear()
- self.assertFalse(update_checker.queries)
-
- nursery.cancel_scope.cancel()
-
- trio.run(main, clock = MockClock(autojump_threshold = 0))
-
- def test_ws_listener_cancel(self) -> None:
- ws_state = WsListenerState()
-
- (event_in_tx, event_in_rx) = open_memory_channel(0)
- event_stream = ChannelBackedEventStream(event_in_rx)
- event_stream_factory = lambda: null_async_context(event_stream)
-
- (event_out_tx, event_out_rx) = open_memory_channel(0)
-
- async def main():
- async with open_nursery() as nursery:
- nursery.start_soon(
- ws_listener_impl,
- event_stream_factory, set(), ws_state, event_out_tx, getLogger()
- )
-
- event_1 = _update_event("author", [])
- await event_in_tx.send(event_1)
- with self._assert_completes_in(1):
- event_out = await event_out_rx.receive()
- self.assertIs(event_out, event_1)
-
- event_2 = _update_event("author", ["recipient"])
- await event_in_tx.send(event_2)
- with self._assert_completes_in(1):
- event_out = await event_out_rx.receive()
- self.assertIs(event_out, event_2)
-
- initial_scope = ws_state.scope
- initial_scope.cancel()
- await wait_all_tasks_blocked()
- self.assertIsNotNone(ws_state.scope)
- self.assertIsNot(ws_state.scope, initial_scope)
- self.assertTrue(ws_state.down)
-
- event_3 = _update_event("author", [])
- ts = UUID(event_3.payload_data["id"]).time
- await event_in_tx.send(event_3)
- with self._assert_completes_in(1):
- event_out = await event_out_rx.receive()
- self.assertIs(event_out, event_3)
- self.assertEqual(ws_state.last_update_ts, ts)
- self.assertFalse(ws_state.down)
- self.assertEqual(ws_state.resync_ts, ts)
-
- nursery.cancel_scope.cancel()
-
- trio.run(main, clock = MockClock(autojump_threshold = 0))
-
- def test_ws_listener_error(self) -> None:
- """Test that the WS listener responds correctly to errors from the input event stream."""
-
- class MockStreamException(Exception):
- pass
-
- ws_state = WsListenerState()
-
- (event_in_tx, event_in_rx) = open_memory_channel(0)
- event_stream = ChannelBackedEventStream(event_in_rx)
- event_stream_factory = lambda: null_async_context(event_stream)
-
- (event_out_tx, event_out_rx) = open_memory_channel(0)
-
- async def main():
- async with open_nursery() as nursery:
- nursery.start_soon(
- ws_listener_impl,
- event_stream_factory, {MockStreamException}, ws_state, event_out_tx, getLogger()
- )
-
- event_1 = _update_event("author", [])
- await event_in_tx.send(event_1)
- with self._assert_completes_in(1):
- event_out = await event_out_rx.receive()
- self.assertIs(event_out, event_1)
-
- initial_scope = ws_state.scope
- await event_in_tx.send(MockStreamException())
- await wait_all_tasks_blocked()
- self.assertIsNotNone(ws_state.scope)
- self.assertIsNot(ws_state.scope, initial_scope)
- self.assertTrue(ws_state.down)
-
- event_2 = _update_event("author", ["recipient"])
- ts = UUID(event_2.payload_data["id"]).time
- await event_in_tx.send(event_2)
- with self._assert_completes_in(1):
- event_out = await event_out_rx.receive()
- self.assertIs(event_out, event_2)
- self.assertEqual(ws_state.last_update_ts, ts)
- self.assertFalse(ws_state.down)
- self.assertEqual(ws_state.resync_ts, ts)
-
- nursery.cancel_scope.cancel()
-
- trio.run(main, clock = MockClock(autojump_threshold = 0))
-
-
-class RedditTests(TestCase):
- """Tests for interactions with Reddit."""
-
- @classmethod
- def setUpClass(cls) -> None:
- cls._TOKEN_REF = TokenRef(environ["API_TOKEN"])
-
- def test_try_request(self) -> None:
- self.assertIsNotNone(
- try_request(
- Request(
- "https://oauth.reddit.com/api/live/happening_now",
- headers = self._TOKEN_REF.get_common_headers(),
- ),
- getLogger(),
- )
- )
-
- def test_try_request_http_error(self) -> None:
- self.assertIsNone(
- try_request(
- Request("https://oauth.reddit.com/api/live/happening_now"),
- getLogger(),
- suppress_codes = [403],
- )
- )
+++ /dev/null
-# see Python `configparser' docs for precise file syntax
-
-[config]
-
-######## basic operating parameters
-
-# space-separated authorization IDs for Reddit API token lookup
-auth IDs = 13 15
-
-bot user = count_better
-
-# if false, don't strike or delete updates or report incorrectly stricken updates (optional, default true)
-enforcing = true
-
-# log messages at or above this severity to standard out; level names and numbers are supported (optional, default
-# WARNING); see https://docs.python.org/3/library/logging.html#levels
-log level = WARNING
-
-# if true, don't modify the live thread at all (implies enforcing = false) (optional, default false)
-read-only = false
-
-thread ID = abc123
-
-
-######## API pool error handling
-
-# (seconds) for error backoff, pause the API pool for this much time
-API pool error delay = 5.5
-
-# (seconds) pause the pool when the Reddit API returns two errors within any time span of this length
-API pool error window = 5.5
-
-# terminate if the Reddit API request queue exceeds this size
-request queue limit = 100
-
-
-######## live thread update handling
-
-# (seconds) maximum time to hold updates for reordering
-reorder buffer time = 0.25
-
-# (seconds) minimum time to retain updates to enable future resyncs
-update retention time = 120.0
-
-
-######## WebSocket pool
-
-# (seconds) maximum allowable spread in WebSocket message arrival times
-WS parity time = 1.0
-
-# goal size of WebSocket connection pool
-WS pool size = 5
-
-# (seconds) after receiving no messages for this much time, attempt to replace the connection if the pool isn't depleted
-WS silent limit = 30
-
-# (seconds) time after WebSocket handshake completion during which missed updates are excused
-WS warmup time = 0.3
-
-
-# Postgres database configuration; same options as in a connect string. Note that because Python's `SSLContext' is used
-# to implement TLS client auth, using `sslkey' to specify a key obtained from an OpenSSL engine may not be supported (a
-# file path is the only option).
-[db connect params]
-host = example.org
+++ /dev/null
-[build-system]
-requires = ["setuptools>=42", "wheel"]
-build-backend = "setuptools.build_meta"
+++ /dev/null
-[metadata]
-name = strikebot
-version = 0.0.7
-
-[options]
-package_dir =
- = src
-packages = strikebot
-python_requires = ~= 3.9
-install_requires =
- asks ~= 2.4
- beautifulsoup4 ~= 4.9
- trio == 0.19
- trio-websocket == 0.9.2
- triopg == 0.6.0
+++ /dev/null
-import setuptools
-
-
-setuptools.setup()
+++ /dev/null
-"""Count tracking logic and bot's external interface."""
-
-from contextlib import nullcontext as nullcontext
-from dataclasses import dataclass
-from functools import total_ordering
-from itertools import islice
-from typing import Optional
-from uuid import UUID
-import bisect
-import datetime as dt
-import importlib.metadata
-import itertools
-import logging
-
-from trio import CancelScope, EndOfChannel
-import trio
-
-from strikebot.updates import Command, parse_update
-
-
-__version__ = importlib.metadata.version(__package__)
-
-
-@dataclass
-class _Update:
- id: str
- name: str
- author: str
- number: Optional[int]
- command: Optional[Command]
- ts: dt.datetime
- count_attempt: bool
- deletable: bool
- stricken: bool
-
- def can_follow(self, prior: "_Update") -> bool:
- """Determine whether this count update can follow another count update."""
- return self.number == prior.number + 1 and self.author != prior.author
-
- def __str__(self):
- content = str(self.number)
- if self.command:
- content += "+" + self.command.name
- return "{}({} by {})".format(self.id[4 : 8], content, self.author)
-
-
-@total_ordering
-@dataclass
-class _BufferedUpdate:
- update: _Update
- release_at: float # Trio time
-
- def __lt__(self, other):
- return self.update.ts < other.update.ts
-
-
-@total_ordering
-@dataclass
-class _TimelineUpdate:
- update: _Update
- accepted: bool
-
- def __lt__(self, other):
- return self.update.ts < other.update.ts
-
- def rejected(self) -> bool:
- """Whether the update should be stricken."""
- return self.update.count_attempt and not self.accepted
-
-
-def _parse_rfc_4122_ts(ts: int) -> dt.datetime:
- epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc)
- return epoch + dt.timedelta(microseconds = ts / 10)
-
-
-def _format_update_ref(update: _Update, thread_id: str) -> str:
- url = f"https://www.reddit.com/live/{thread_id}/updates/{update.id}"
- author_md = update.author.replace("_", "\\_") # Reddit allows letters, digits, underscore, and hyphen
- return f"[{update.number:,} by {author_md}]({url})"
-
-
-def _format_bad_strike_alert(update: _Update, thread_id: str) -> str:
- return "_incorrectly stricken:_ " + _format_update_ref(update, thread_id)
-
-
-def _format_curr_count(last_valid: Optional[_Update]) -> str:
- if last_valid and last_valid.number is not None:
- count_str = f"{last_valid.number + 1:,}"
- else:
- count_str = "unknown"
- return f"_current count:_ {count_str}"
-
-
-async def count_tracker_impl(
- message_rx,
- api_pool: "strikebot.reddit_api.ApiClientPool",
- reorder_buffer_time: dt.timedelta, # duration to hold some updates for reordering
- thread_id: str,
- bot_user: str,
- enforcing: bool,
- read_only: bool,
- update_retention: dt.timedelta,
- logger: logging.Logger,
-) -> None:
- from strikebot.reddit_api import CorrectionUpdateRequest, DeleteRequest, ReportUpdateRequest, StrikeRequest
-
- enforcing = enforcing and not read_only
-
- buffer_ = []
- timeline = []
- last_valid: Optional[_Update] = None
- pending_strikes: set[str] = set() # names of updates to mark stricken on arrival
- forgotten: bool = False # whether the retention period for an update has passed during the run
-
- def handle_update(update):
- nonlocal last_valid
-
- tu = _TimelineUpdate(update, accepted = False)
-
- pos = bisect.bisect(timeline, tu)
- if pos != len(timeline):
- logger.warning(f"long transpo: {update.id}")
-
- if pos > 0 and timeline[pos - 1].update.id == update.id:
- # The pool sent this update message multiple times. This could be because the message reached parity and
- # its bucket was dropped just before a new connection was opened and sent its own copy, but connections also
- # seem to send duplicate messages.
- return
-
- pred = next(
- (timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted),
- None
- )
- contender = update.command is Command.RESET or update.number is not None
- if contender and pred is None and last_valid and forgotten:
- # This is a really long transpo which we can't judge valid or invalid since we've forgotten the surrounding
- # updates. The best we can do is just ignore it.
- logger.warning(f"ignoring {update.id}: no valid prior count on record")
- else:
- timeline.insert(pos, tu)
- tu.accepted = (
- update.command is Command.RESET
- or (
- update.number is not None
- and (pred is None or pred.number is None or update.can_follow(pred))
- and not update.stricken
- )
- )
- logger.debug("{} {} -> {}".format("accepted" if tu.accepted else "rejected", pred, update))
- if tu.accepted:
- # resync subsequent updates
- newly_invalid = []
- resync_last_valid = update
- converged = False
- for scan_tu in islice(timeline, pos + 1, None):
- if scan_tu.update.command is Command.RESET:
- resync_last_valid = scan_tu.update
- if last_valid:
- converged = True
- elif scan_tu.update.number is not None:
- accept = (
- (resync_last_valid.number is None or scan_tu.update.can_follow(resync_last_valid))
- and not scan_tu.update.stricken
- )
- if accept:
- assert scan_tu.accepted
- resync_last_valid = scan_tu.update
- # resync would have no effect past this point
- if last_valid:
- converged = True
- elif scan_tu.accepted:
- newly_invalid.append(scan_tu)
- if scan_tu.update is last_valid:
- last_valid = None
- scan_tu.accepted = accept
- if converged:
- break
-
- if converged:
- assert last_valid.ts >= resync_last_valid.ts
- else:
- last_valid = resync_last_valid
-
- parts = []
- unstrikable = [tu for tu in newly_invalid if tu.update.stricken]
- if unstrikable:
- parts.append(
- "The following counts are invalid:\n\n"
- + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in unstrikable)
- )
-
- if update.stricken:
- logger.info(f"bad strike of {update.id}")
- parts.append(_format_bad_strike_alert(update, thread_id))
-
- if parts:
- parts.append(_format_curr_count(last_valid))
- if not read_only:
- api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts)))
-
- for invalid_tu in newly_invalid:
- if not invalid_tu.update.stricken:
- if enforcing:
- api_pool.enqueue_request(StrikeRequest(thread_id, invalid_tu.update.name, invalid_tu.update.ts))
- invalid_tu.update.stricken = True
- elif update.count_attempt and not update.stricken:
- if enforcing:
- api_pool.enqueue_request(StrikeRequest(thread_id, update.name, update.ts))
- update.stricken = True
-
- if not read_only and update.command is Command.REPORT:
- api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
-
- async with message_rx:
- while True:
- if buffer_:
- deadline = min(bu.release_at for bu in buffer_)
- cancel_scope = CancelScope(deadline = deadline)
- else:
- cancel_scope = nullcontext()
-
- msg = None
- with cancel_scope:
- try:
- msg = await message_rx.receive()
- except EndOfChannel:
- break
-
- if msg:
- if msg.data["type"] == "update":
- payload_data = msg.data["payload"]["data"]
- stricken = payload_data["name"] in pending_strikes
- pending_strikes.discard(payload_data["name"])
- if payload_data["author"] != bot_user:
- if last_valid and last_valid.number is not None:
- next_up = last_valid.number + 1
- else:
- next_up = None
- pu = parse_update(payload_data, next_up, bot_user)
- rfc_ts = UUID(payload_data["id"]).time
- update = _Update(
- id = payload_data["id"],
- name = payload_data["name"],
- author = payload_data["author"],
- number = pu.number,
- command = pu.command,
- ts = _parse_rfc_4122_ts(rfc_ts),
- count_attempt = pu.count_attempt,
- deletable = pu.deletable,
- stricken = stricken or payload_data["stricken"],
- )
- release_at = msg.timestamp + reorder_buffer_time.total_seconds()
- bisect.insort(buffer_, _BufferedUpdate(update, release_at))
- elif msg.data["type"] == "strike":
- slot = next(
- (
- slot for slot in itertools.chain(buffer_, reversed(timeline))
- if slot.update.name == msg.data["payload"]
- ),
- None
- )
- if slot:
- if not slot.update.stricken:
- slot.update.stricken = True
- if isinstance(slot, _TimelineUpdate) and slot.accepted:
- logger.info(f"bad strike of {slot.update.id}")
- if not read_only:
- body = _format_bad_strike_alert(slot.update, thread_id)
- api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body))
- else:
- pending_strikes.add(msg.data["payload"])
-
- threshold = dt.datetime.now(dt.timezone.utc) - update_retention
-
- # pull updates from the reorder buffer and process
- new_buffer = []
- for bu in buffer_:
- process = (
- # not a count
- bu.update.number is None
-
- # valid count
- or (last_valid is None or last_valid.number is None or bu.update.can_follow(last_valid))
- or bu.update.command is Command.RESET
-
- or trio.current_time() >= bu.release_at
- or bu.update.command is Command.RESET
- )
-
- if process:
- if bu.update.ts < threshold:
- logger.warning(f"ignoring {bu.update}: arrived past retention window")
- else:
- logger.debug(f"processing {bu.update}")
- handle_update(bu.update)
- else:
- logger.debug(f"holding {bu.update}, checked against {last_valid})")
- new_buffer.append(bu)
- buffer_ = new_buffer
- logger.debug("last count {}".format(last_valid))
-
- # delete/forget old updates
- new_timeline = []
- i = 0
- for (i, tu) in enumerate(timeline):
- if i >= len(timeline) - 10 or tu.update.ts >= threshold:
- break
- elif tu.rejected() and tu.update.deletable:
- if enforcing:
- api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts))
- elif tu.update is last_valid:
- new_timeline.append(tu)
-
- if i:
- forgotten = True
- new_timeline.extend(islice(timeline, i, None))
- timeline = new_timeline
+++ /dev/null
-from enum import Enum
-from inspect import getmodule
-from logging import FileHandler, getLogger, StreamHandler
-from pathlib import Path
-from signal import SIGUSR1
-from sys import stdout
-from traceback import StackSummary
-from typing import Optional
-import argparse
-import configparser
-import datetime as dt
-import logging
-import ssl
-
-from trio import open_memory_channel, open_nursery, open_signal_receiver
-
-try:
- from trio.lowlevel import current_root_task, Task
-except ImportError:
- from trio.hazmat import current_root_task, Task
-
-import trio_asyncio
-import triopg
-
-from strikebot import count_tracker_impl
-from strikebot.db import Client, Messenger
-from strikebot.live_ws import HealingReadPool, PoolMerger
-from strikebot.reddit_api import ApiClientPool
-
-
-_DEBUG_LOG_PATH: Optional[Path] = None # file to write debug logs to, if any
-
-
-_TaskKind = Enum(
- "_TaskKind",
- [
- "API_WORKER",
- "DB_CLIENT",
- "TOKEN_UPDATE",
- "TRACKER",
- "WS_MERGE_READER",
- "WS_MERGE_TIMER",
- "WS_REFRESHER",
- "WS_WORKER",
- ]
-)
-
-
-def _classify_task(task: Task) -> _TaskKind:
- mod_name = getmodule(task.coro.cr_code).__name__
- if mod_name.startswith("trio_asyncio."):
- return None
- elif task.coro.__name__ == "_signal_handler_impl":
- return None
- else:
- by_coro_name = {
- "db_client_impl": _TaskKind.DB_CLIENT,
- "token_updater_impl": _TaskKind.TOKEN_UPDATE,
- "event_reader_impl": _TaskKind.WS_MERGE_READER,
- "timeout_handler_impl": _TaskKind.WS_MERGE_TIMER,
- "conn_refresher_impl": _TaskKind.WS_REFRESHER,
- "count_tracker_impl": _TaskKind.TRACKER,
- "_reader_impl": _TaskKind.WS_WORKER,
- "worker_impl": _TaskKind.API_WORKER,
- }
- return by_coro_name[task.coro.__name__]
-
-
-def _get_all_tasks():
- [ta_loop] = [
- task
- for nursery in current_root_task().child_nurseries
- for task in nursery.child_tasks
- if getmodule(task.coro.cr_code).__name__.startswith("trio_asyncio.")
- ]
- for nursery in ta_loop.child_nurseries:
- yield from nursery.child_tasks
-
-
-def _print_task_info():
- groups = {kind: [] for kind in _TaskKind}
- for task in _get_all_tasks():
- kind = _classify_task(task)
- if kind:
- coro = task.coro
- frame_tuples = []
- while coro is not None:
- if hasattr(coro, "cr_frame"):
- frame_tuples.append((coro.cr_frame, coro.cr_frame.f_lineno))
- coro = coro.cr_await
- else:
- frame_tuples.append((coro.gi_frame, coro.gi_frame.f_lineno))
- coro = coro.gi_yieldfrom
- groups[kind].append(StackSummary.extract(iter(frame_tuples)))
- for kind in _TaskKind:
- print(kind.name)
- for ss in groups[kind]:
- print(" task")
- for line in ss.format():
- print(" " + line, end = "")
-
-
-async def _signal_handler_impl():
- with open_signal_receiver(SIGUSR1) as sig_src:
- async for _ in sig_src:
- _print_task_info()
-
-
-ap = argparse.ArgumentParser(__package__)
-ap.add_argument("config_path")
-args = ap.parse_args()
-
-
-parser = configparser.ConfigParser()
-with open(args.config_path) as config_file:
- parser.read_file(config_file)
-
-main_cfg = parser["config"]
-
-api_pool_error_delay = dt.timedelta(seconds = main_cfg.getfloat("API pool error delay"))
-api_pool_error_window = dt.timedelta(seconds = main_cfg.getfloat("API pool error window"))
-auth_ids = set(map(int, main_cfg["auth IDs"].split()))
-bot_user = main_cfg["bot user"]
-enforcing = main_cfg.getboolean("enforcing", fallback = True)
-
-raw_log_level = main_cfg.get("log level", "WARNING")
-try:
- log_level = int(raw_log_level)
-except ValueError:
- log_level = raw_log_level
-
-read_only = main_cfg.getboolean("read-only", fallback = False)
-reorder_buffer_time = dt.timedelta(seconds = main_cfg.getfloat("reorder buffer time"))
-request_queue_limit = main_cfg.getint("request queue limit")
-thread_id = main_cfg["thread ID"]
-update_retention = dt.timedelta(seconds = main_cfg.getfloat("update retention time"))
-ws_parity_time = dt.timedelta(seconds = main_cfg.getfloat("WS parity time"))
-ws_pool_size = main_cfg.getint("WS pool size")
-ws_silent_limit = dt.timedelta(seconds = main_cfg.getfloat("WS silent limit"))
-ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup time"))
-
-
-db_cfg = parser["db connect params"]
-cert_path = db_cfg.pop("sslcert", None)
-key_path = db_cfg.pop("sslkey", None)
-key_pass = db_cfg.pop("sslpassword", None)
-
-getters = {
- "port": db_cfg.getint,
-}
-db_connect_kwargs = {k: getters.get(k, db_cfg.get)(k) for k in db_cfg}
-
-if cert_path:
- # patch TLS client cert auth support; asyncpg adds it in a later version
- db_ssl_ctx = ssl.create_default_context()
- db_ssl_ctx.load_cert_chain(
- cert_path,
- key_path,
- None if key_pass is None else lambda: key_pass,
- )
- db_connect_kwargs["ssl"] = db_ssl_ctx
-
-
-if read_only and enforcing:
- raise RuntimeError("can't use read-only mode with enforcing on")
-
-
-logger = getLogger(__package__)
-logger.setLevel(logging.NOTSET - 1) # NOTSET means inherit from parent; we use handlers to filter
-
-handler = StreamHandler(stdout)
-handler.setLevel(log_level)
-handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{"))
-logger.addHandler(handler)
-
-if _DEBUG_LOG_PATH: # TODO remove this ad hoc setup
- debug_handler = FileHandler(_DEBUG_LOG_PATH, mode = "w")
- debug_handler.setLevel(logging.DEBUG)
- debug_handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{"))
- logger.addHandler(debug_handler)
-
-
-async def main():
- async with (
- triopg.connect(**db_connect_kwargs) as db_conn,
- open_nursery() as nursery_a,
- open_nursery() as nursery_b,
- open_nursery() as ws_pool_nursery,
- open_nursery() as nursery_c,
- ):
- nursery_a.start_soon(_signal_handler_impl)
-
- client = Client(db_conn)
- db_messenger = Messenger(client)
- nursery_a.start_soon(db_messenger.db_client_impl)
-
- api_pool = ApiClientPool(
- auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay,
- logger.getChild("api")
- )
- nursery_a.start_soon(api_pool.token_updater_impl)
- for _ in auth_ids:
- await nursery_a.start(api_pool.worker_impl)
-
- (message_tx, message_rx) = open_memory_channel(0)
- (pool_event_tx, pool_event_rx) = open_memory_channel(0)
-
- merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger.getChild("merge"))
- nursery_b.start_soon(merger.event_reader_impl)
- nursery_b.start_soon(merger.timeout_handler_impl)
-
- ws_pool = HealingReadPool(
- ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger.getChild("live_ws"),
- ws_silent_limit
- )
- nursery_c.start_soon(ws_pool.conn_refresher_impl)
-
- nursery_c.start_soon(
- count_tracker_impl, message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing, read_only,
- update_retention, logger.getChild("track")
- )
- await ws_pool.init_workers()
-
-trio_asyncio.run(main)
+++ /dev/null
-from typing import Any
-
-
-def int_digest(val: int) -> str:
- return "{:04x}".format(val & 0xffff)
-
-
-def obj_digest(obj: Any) -> str:
- """Makes a short digest of the identity of an object."""
- return int_digest(id(obj))
+++ /dev/null
-"""Single-connection Postgres client with high-level wrappers for operations."""
-
-from dataclasses import dataclass
-from typing import Any, Iterable
-
-import trio
-
-
-def _channel_sender(method):
- async def wrapped(self, resp_channel, *args, **kwargs):
- ret = await method(self, *args, **kwargs)
- async with resp_channel:
- await resp_channel.send(ret)
-
- return wrapped
-
-
-@dataclass
-class Client:
- """High-level wrappers for DB operations."""
- conn: Any
-
- @_channel_sender
- async def get_auth_tokens(self, auth_ids: set[int]) -> dict[int, str]:
- results = await self.conn.fetch(
- """
- select
- distinct on (id)
- id, access_token
- from
- public.reddit_app_authorization
- join unnest($1::integer[]) as request_id on id = request_id
- where expires > current_timestamp
- """,
- list(auth_ids),
- )
- return {r["id"]: r["access_token"] for r in results}
-
-
-class Messenger:
- def __init__(self, client: Client):
- self._client = client
- (self._request_tx, self._request_rx) = trio.open_memory_channel(0)
-
- async def do(self, method_name: str, args: Iterable) -> Any:
- """This is run by consumers of the DB wrapper."""
- method = getattr(self._client, method_name)
- (resp_tx, resp_rx) = trio.open_memory_channel(0)
- coro = method(resp_tx, *args)
- await self._request_tx.send(coro)
- async with resp_rx:
- return await resp_rx.receive()
-
- async def db_client_impl(self):
- """This is the DB client Trio task."""
- async with self._client.conn, self._request_rx:
- async for coro in self._request_rx:
- await coro
+++ /dev/null
-from collections import deque
-from contextlib import suppress
-from dataclasses import dataclass
-from typing import Any, AsyncContextManager, Optional
-import datetime as dt
-import json
-import logging
-import math
-
-from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
-import trio
-
-from strikebot.common import int_digest, obj_digest
-from strikebot.reddit_api import AboutLiveThreadRequest
-
-
-@dataclass
-class _PoolEvent:
- timestamp: float # Trio time
-
-
-@dataclass
-class _Message(_PoolEvent):
- data: Any
- scope: Any
-
- def dedup_tag(self):
- """A hashable essence of the contents of this message."""
- if self.data["type"] == "update":
- return ("update", self.data["payload"]["data"]["id"])
- elif self.data["type"] in {"strike", "delete"}:
- return (self.data["type"], self.data["payload"])
- else:
- raise ValueError(self.data["type"])
-
-
-@dataclass
-class _ConnectionDown(_PoolEvent):
- scope: trio.CancelScope
-
-
-@dataclass
-class _ConnectionUp(_PoolEvent):
- scope: trio.CancelScope
-
-
-class HealingReadPool:
- def __init__(
- self, pool_nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger,
- silent_limit: dt.timedelta
- ):
- assert size >= 2
- self._nursery = pool_nursery
- self._size = size
- self._live_thread_id = live_thread_id
- self._api_client_pool = api_client_pool
- self._pool_event_tx = pool_event_tx
- self._logger = logger
- self._silent_limit = silent_limit
-
- (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size)
-
- # number of workers who have stopped receiving updates but whose tasks aren't yet stopped
- self._closing = 0
-
- async def init_workers(self):
- for _ in range(self._size):
- await self._spawn_reader()
- if not self._nursery.child_tasks:
- raise RuntimeError("Unable to create any WS connections")
-
- async def _reader_impl(self, conn_ctx: AsyncContextManager, cancel_scope, refresh_tx):
- # TODO could use task's implicit cancel scope
- try:
- async with conn_ctx as conn:
- self._logger.debug("scope up: {}".format(obj_digest(cancel_scope)))
- async with refresh_tx:
- silent_timeout = False
- with cancel_scope, suppress(ConnectionClosed):
- while True:
- with trio.move_on_after(self._silent_limit.total_seconds()) as timeout_scope:
- message = await conn.get_message()
-
- if timeout_scope.cancelled_caught:
- if len(self._nursery.child_tasks) - self._closing == self._size:
- silent_timeout = True
- break
- else:
- self._logger.debug("not replacing connection {}; {} tasks, {} closing".format(
- obj_digest(cancel_scope),
- len(self._nursery.child_tasks),
- self._closing,
- ))
- else:
- event = _Message(trio.current_time(), json.loads(message), cancel_scope)
- await self._pool_event_tx.send(event)
-
- self._closing += 1
- if silent_timeout:
- await conn.aclose()
- self._logger.debug("replacing WS connection {} due to silent timeout".format(obj_digest(cancel_scope)))
- elif cancel_scope.cancelled_caught:
- await conn.aclose(1008, "Server unexpectedly stopped sending messages")
- self._logger.warning("replacing WS connection {} due to missed update".format(obj_digest(cancel_scope)))
- else:
- self._logger.warning("replacing WS connection {} closed by server".format(obj_digest(cancel_scope)))
-
- refresh_tx.send_nowait(cancel_scope)
- self._logger.debug("scope down: {} ({} tasks, {} closing)".format(
- obj_digest(cancel_scope),
- len(self._nursery.child_tasks),
- self._closing,
- ))
- self._closing -= 1
- except HandshakeError:
- self._logger.error("handshake error while opening WS connection")
-
- async def _new_connection(self) -> Optional[AsyncContextManager[WebSocketConnection]]:
- request = AboutLiveThreadRequest(self._live_thread_id)
- resp = await self._api_client_pool.make_request(request)
- if resp.status_code == 200:
- url = json.loads(resp.text)["data"]["websocket_url"]
- return open_websocket_url(url)
- else:
- self._logger.error(f"thread info request failed (HTTP {resp.status_code}); can't open WS connection")
- return None
-
- async def _spawn_reader(self) -> None:
- conn = await self._new_connection()
- if conn:
- new_scope = trio.CancelScope()
- await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope))
- refresh_tx = self._refresh_queue_tx.clone()
- self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx)
-
- async def conn_refresher_impl(self):
- """Task to monitor and replace WS connections as they disconnect or go silent."""
- async with self._refresh_queue_rx:
- async for old_scope in self._refresh_queue_rx:
- await self._pool_event_tx.send(_ConnectionDown(trio.current_time(), old_scope))
- await self._spawn_reader()
- if not self._nursery.child_tasks:
- raise RuntimeError("WS pool depleted")
-
-
-def _tag_digest(tag):
- return int_digest(hash(tag))
-
-
-def _format_scope_list(scopes):
- return ", ".join(sorted(map(obj_digest, scopes)))
-
-
-class PoolMerger:
- @dataclass
- class _Bucket:
- start: float
- recipients: set
-
- def __init__(self, pool_event_rx, message_tx, parity_timeout: dt.timedelta, conn_warmup: dt.timedelta, logger: logging.Logger):
- """
- :param parity_timeout: max delay between message receipt on different connections
- :param conn_warmup: max interval after connection within which missed updates are allowed
- """
- self._pool_event_rx = pool_event_rx
- self._message_tx = message_tx
- self._parity_timeout = parity_timeout
- self._conn_warmup = conn_warmup
- self._logger = logger
-
- self._buckets = {}
- self._scope_activations = {}
- self._pending = deque()
- self._outgoing_scopes = set()
- (self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf)
-
- def _log_current_buckets(self):
- self._logger.debug("current buckets: {}".format(", ".join(map(_tag_digest, self._buckets.keys()))))
-
- async def event_reader_impl(self):
- """Drop unused messages, deduplicate useful ones, and communicate with the timeout handler."""
- async with self._pool_event_rx, self._timer_poke_tx:
- async for event in self._pool_event_rx:
- if isinstance(event, _ConnectionUp):
- # An early add of an active scope could mean it's expected on a message that fired before it opened,
- # resulting in a false positive for replacement. The timer task merges these in among the buckets by
- # timestamp to avoid that.
- self._pending.append(event)
- elif isinstance(event, _Message):
- if event.data["type"] in ["update", "strike", "delete"]:
- tag = event.dedup_tag()
- self._logger.debug("recv {} from {}".format(_tag_digest(tag), obj_digest(event.scope)))
- if tag in self._buckets:
- b = self._buckets[tag]
- b.recipients.add(event.scope)
- # If this scope is the last one for this bucket we could clear the bucket here, but since
- # connections sometimes get repeat messages and second copies can arrive before the first
- # copy has arrived on all connections, leaving the bucket open to absorb repeats can reduce
- # the likelihood of a second bucket being allocated late for the second copy of a message
- # and causing unnecessary connection replacement.
- elif event.scope not in self._outgoing_scopes:
- sane = (
- event.scope in self._scope_activations
- or any(e.scope == event.scope for e in self._pending)
- )
- if sane:
- self._logger.debug("new bucket {}: {}".format(_tag_digest(tag), tag))
- self._buckets[tag] = self._Bucket(event.timestamp, {event.scope})
- self._log_current_buckets()
- await self._message_tx.send(event)
- else:
-								raise RuntimeError("received message from unrecognized WS connection")
- else:
- self._logger.debug("recv type {!r} from {} (discarding)".format(event.data["type"], obj_digest(event.scope)))
- elif isinstance(event, _ConnectionDown):
- # We don't need to worry about canceling this scope at all, so no need to require it for parity for
- # any message, even older ones. The scope may be gone already, if we canceled it previously.
- self._scope_activations.pop(event.scope, None)
- self._pending = deque(ev for ev in self._pending if ev.scope is not event.scope)
- self._outgoing_scopes.discard(event.scope)
- else:
- raise TypeError(f"Expected pool event, found {event!r}")
- self._timer_poke_tx.send_nowait(None) # may have new work for the timer
-
- async def timeout_handler_impl(self):
- """When connections have had enough time to reach parity on a message, signal replacement of any slackers."""
- async with self._timer_poke_rx:
- while True:
- if self._buckets:
- now = trio.current_time()
- (tag, bucket) = min(self._buckets.items(), key = lambda i: i[1].start)
-
- # Make sure the right scope set is ready for the next bucket up
- while self._pending and self._pending[0].timestamp < bucket.start:
- ev = self._pending.popleft()
- self._scope_activations[ev.scope] = ev.timestamp
-
- target_scopes = {
- scope for (scope, active) in self._scope_activations.items()
- if active + self._conn_warmup.total_seconds() < now
- }
- if now > bucket.start + self._parity_timeout.total_seconds():
- filled = target_scopes & bucket.recipients
- missing = target_scopes - bucket.recipients
- extra = bucket.recipients - target_scopes
-
- self._logger.debug("expiring bucket {}".format(_tag_digest(tag)))
- if filled:
- self._logger.debug(" filled {}: {}".format(len(filled), _format_scope_list(filled)))
- if missing:
- self._logger.debug(" missing {}: {}".format(len(missing), _format_scope_list(missing)))
- if extra:
- self._logger.debug(" extra {}: {}".format(len(extra), _format_scope_list(extra)))
-
- for scope in missing:
- self._outgoing_scopes.add(scope)
- scope.cancel()
- del self._buckets[tag]
- self._log_current_buckets()
- else:
- await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds())
- else:
- try:
- await self._timer_poke_rx.receive()
- except trio.EndOfChannel:
- break
+++ /dev/null
-"""Unbounded blocking queues for Trio"""
-
-from collections import deque
-from dataclasses import dataclass
-from functools import total_ordering
-from typing import Any, Iterable
-import heapq
-
-try:
- from trio.lowlevel import ParkingLot
-except ImportError:
- from trio.hazmat import ParkingLot
-
-
-class Queue:
- def __init__(self):
- self._deque = deque()
- self._empty_wait = ParkingLot()
-
- def push(self, el: Any) -> None:
- self._deque.append(el)
- self._empty_wait.unpark()
-
- def extend(self, els: Iterable[Any]) -> None:
- for el in els:
- self.push(el)
-
- async def pop(self) -> Any:
- if not self._deque:
- await self._empty_wait.park()
- return self._deque.popleft()
-
- def __len__(self):
- return len(self._deque)
-
-
-@total_ordering
-@dataclass
-class _ReverseOrdWrapper:
- inner: Any
-
- def __lt__(self, other):
- return other.inner < self.inner
-
-
-class MaxHeap:
- def __init__(self):
- self._heap = []
- self._empty_wait = ParkingLot()
-
- def push(self, item):
- heapq.heappush(self._heap, _ReverseOrdWrapper(item))
- self._empty_wait.unpark()
-
- async def pop(self):
- if not self._heap:
- await self._empty_wait.park()
- return heapq.heappop(self._heap).inner
-
- def __len__(self) -> int:
- return len(self._heap)
+++ /dev/null
-"""Multi-auth HTTP connection pool with Reddit API wrappers and rate limiting."""
-
-from abc import ABCMeta, abstractmethod
-from dataclasses import dataclass, field
-from functools import total_ordering
-from socket import EAI_AGAIN, EAI_FAIL, gaierror
-import datetime as dt
-import logging
-
-from asks.response_objects import Response
-import asks
-import trio
-
-from strikebot import __version__ as VERSION
-from strikebot.common import obj_digest
-from strikebot.queue import MaxHeap, Queue
-
-
-REQUEST_DELAY = dt.timedelta(seconds = 1)
-
-TOKEN_UPDATE_DELAY = dt.timedelta(minutes = 15)
-
-USER_AGENT_FMT = f"any:net.jcornell.strikebot.{{auth_id}}:v{VERSION} (by /u/jaxklax)"
-
-API_BASE_URL = "https://oauth.reddit.com"
-
-
-@total_ordering
-class _Request(metaclass = ABCMeta):
- _SUBTYPE_PRECEDENCE = None # assigned later
-
- def __lt__(self, other):
- if type(self) is type(other):
- return self._subtype_cmp_key() < other._subtype_cmp_key()
- else:
- prec = self._SUBTYPE_PRECEDENCE
- return prec.index(type(other)) < prec.index(type(self))
-
- def __eq__(self, other):
- if type(self) is type(other):
- return self._subtype_cmp_key() == other._subtype_cmp_key()
- else:
- return False
-
- @abstractmethod
- def _subtype_cmp_key(self):
- # a large key corresponds to a high priority
- raise NotImplementedError()
-
- @abstractmethod
- def to_asks_kwargs(self):
- raise NotImplementedError()
-
-
-@dataclass(eq = False)
-class StrikeRequest(_Request):
- thread_id: str
- update_name: str
- update_ts: dt.datetime
-
- def to_asks_kwargs(self):
- return {
- "method": "POST",
- "path": f"/api/live/{self.thread_id}/strike_update",
- "data": {
- "api_type": "json",
- "id": self.update_name,
- },
- }
-
- def _subtype_cmp_key(self):
- return self.update_ts
-
-
-@dataclass(eq = False)
-class DeleteRequest(_Request):
- thread_id: str
- update_name: str
- update_ts: dt.datetime
-
- def to_asks_kwargs(self):
- return {
- "method": "POST",
- "path": f"/api/live/{self.thread_id}/delete_update",
- "data": {
- "api_type": "json",
- "id": self.update_name,
- },
- }
-
- def _subtype_cmp_key(self):
- return -self.update_ts.timestamp()
-
-
-@dataclass(eq = False)
-class AboutLiveThreadRequest(_Request):
- thread_id: str
- created: float = field(default_factory = trio.current_time)
-
- def to_asks_kwargs(self):
- return {
- "method": "GET",
- "path": f"/live/{self.thread_id}/about",
- "params": {
- "raw_json": "1",
- },
- }
-
- def _subtype_cmp_key(self):
- return -self.created
-
-
-@dataclass(eq = False)
-class _BaseUpdatePostRequest(_Request, metaclass = ABCMeta):
- thread_id: str
- body: str
- created: float = field(default_factory = trio.current_time)
-
- def to_asks_kwargs(self):
- return {
- "method": "POST",
- "path": f"/live/{self.thread_id}/update",
- "data": {
- "api_type": "json",
- "body": self.body,
- },
- }
-
-
-@dataclass(eq = False)
-class ReportUpdateRequest(_BaseUpdatePostRequest):
- def _subtype_cmp_key(self):
- return -self.created
-
-
-@dataclass(eq = False)
-class CorrectionUpdateRequest(_BaseUpdatePostRequest):
- def _subtype_cmp_key(self):
- return -self.created
-
-
-# highest priority first
-_Request._SUBTYPE_PRECEDENCE = [
- AboutLiveThreadRequest,
- StrikeRequest,
- CorrectionUpdateRequest,
- ReportUpdateRequest,
- DeleteRequest,
-]
-
-
-@dataclass
-class AppCooldown:
- auth_id: int
- ready_at: float # Trio time
-
-
-class ApiClientPool:
- def __init__(
- self,
- auth_ids: set[int],
- db_messenger: "strikebot.db.Messenger",
- request_queue_limit: int,
- error_window: dt.timedelta,
- error_delay: dt.timedelta,
- logger: logging.Logger,
- ):
- self._auth_ids = auth_ids
- self._db_messenger = db_messenger
- self._request_queue_limit = request_queue_limit
- self._error_window = error_window
- self._error_delay = error_delay
- self._logger = logger
-
- now = trio.current_time()
- self._tokens = {}
- self._app_queue = Queue()
- self._waiting = {id_: AppCooldown(id_, now) for id_ in auth_ids}
- self._request_queue = MaxHeap()
-
- # pool-wide API error backoff
- self._last_error = None
- self._global_resume = None
-
- self._session = asks.Session(connections = len(auth_ids))
- self._session.base_location = API_BASE_URL
-
- async def _update_tokens(self):
- tokens = await self._db_messenger.do("get_auth_tokens", (self._auth_ids,))
- self._tokens.update(tokens)
-
- awaken_auths = self._waiting.keys() & tokens.keys()
- self._app_queue.extend(self._waiting.pop(auth_id) for auth_id in awaken_auths)
- self._logger.debug("updated API tokens")
-
- async def token_updater_impl(self):
- last_update = None
- while True:
- if last_update is not None:
- await trio.sleep_until(last_update + TOKEN_UPDATE_DELAY.total_seconds())
-
- if last_update is None:
- last_update = trio.current_time()
- else:
- last_update += TOKEN_UPDATE_DELAY.total_seconds()
-
- await self._update_tokens()
-
- def _check_queue_size(self) -> None:
- if len(self._request_queue) > self._request_queue_limit:
- raise RuntimeError("request queue size exceeded limit")
-
- async def make_request(self, request: _Request) -> Response:
- (resp_tx, resp_rx) = trio.open_memory_channel(0)
- self.enqueue_request(request, resp_tx)
- async with resp_rx:
- return await resp_rx.receive()
-
- def enqueue_request(self, request: _Request, resp_tx = None) -> None:
- self._logger.debug("request {}: {}".format(obj_digest(request), type(request).__name__))
- self._request_queue.push((request, resp_tx))
- self._check_queue_size()
- self._logger.debug(f"{len(self._request_queue)} requests in queue")
-
- async def worker_impl(self, task_status):
- task_status.started()
- while True:
- (request, resp_tx) = await self._request_queue.pop()
- cooldown = await self._app_queue.pop()
- await trio.sleep_until(cooldown.ready_at)
- if self._global_resume:
- await trio.sleep_until(self._global_resume)
-
- asks_kwargs = request.to_asks_kwargs()
- headers = asks_kwargs.setdefault("headers", {})
- headers.update({
- "Authorization": "Bearer {}".format(self._tokens[cooldown.auth_id]),
- "User-Agent": USER_AGENT_FMT.format(auth_id = cooldown.auth_id),
- })
-
- request_time = trio.current_time()
- try:
- resp = await self._session.request(**asks_kwargs)
- except gaierror as e:
- if e.errno in [EAI_FAIL, EAI_AGAIN]:
- # DNS failure, probably temporary
- error = True
- else:
- raise
- else:
- resp.body # read response
- error = False
- wait_for_token = False
- log_suffix = " (request {})".format(obj_digest(request))
- if resp.status_code == 429:
- # We disagreed about the rate limit state; just try again later.
- self._logger.warning("rate limited by Reddit API" + log_suffix)
- error = True
- elif resp.status_code == 401:
- self._logger.warning("got HTTP 401 from Reddit API" + log_suffix)
- error = True
- wait_for_token = True
- elif resp.status_code in [404, 500, 503]:
- self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying" + log_suffix)
- error = True
- elif 400 <= resp.status_code < 500:
- # If we're doing something wrong, let's catch it right away.
- raise RuntimeError(f"unexpected client error response: {resp.status_code}" + log_suffix)
- else:
- if resp.status_code != 200:
- raise RuntimeError(f"unexpected status code {resp.status_code}")
- self._logger.debug("success" + log_suffix)
- if resp_tx:
- await resp_tx.send(resp)
-
- if error:
- self._request_queue.push((request, resp_tx))
- self._check_queue_size()
- if self._last_error:
- spread = dt.timedelta(seconds = request_time - self._last_error)
- if spread <= self._error_window:
- self._global_resume = request_time + self._error_delay.total_seconds()
- self._last_error = request_time
-
- cooldown.ready_at = request_time + REQUEST_DELAY.total_seconds()
- if wait_for_token:
- self._waiting[cooldown.auth_id] = cooldown
- if not self._app_queue:
- self._logger.error("all workers waiting for API tokens")
- else:
- self._app_queue.push(cooldown)
+++ /dev/null
-from unittest import TestCase
-
-from strikebot.updates import parse_update
-
-
-def _build_payload(body_html: str) -> dict[str, str]:
- return {"body_html": body_html}
-
-
-class UpdateParsingTests(TestCase):
- def test_successful_counts(self):
- pu = parse_update(_build_payload("<div>12,345,678 spaghetti</div>"), None, "")
- self.assertEqual(pu.number, 12345678)
- self.assertFalse(pu.deletable)
-
- pu = parse_update(_build_payload("<div><p>0</p><br/>oz</div>"), None, "")
- self.assertEqual(pu.number, 0)
- self.assertFalse(pu.deletable)
-
- pu = parse_update(_build_payload("<div>121 345 621</div>"), 121, "")
- self.assertEqual(pu.number, 121)
- self.assertFalse(pu.deletable)
-
- pu = parse_update(_build_payload("28 336 816"), 28336816, "")
- self.assertEqual(pu.number, 28336816)
- self.assertTrue(pu.deletable)
-
- def test_non_counts(self):
- pu = parse_update(_build_payload("<div><table><tbody><tr><td>zoo</td></tr></tbody></table></div>"), None, "")
- self.assertFalse(pu.count_attempt)
- self.assertFalse(pu.deletable)
-
- def test_typos(self):
- pu = parse_update(_build_payload("<span>v9</span>"), 888, "")
- self.assertIsNone(pu.number)
- self.assertTrue(pu.count_attempt)
-
- pu = parse_update(_build_payload("<div>v11.585 Empire</div>"), None, "")
- self.assertIsNone(pu.number)
- self.assertTrue(pu.count_attempt)
- self.assertFalse(pu.deletable)
-
- pu = parse_update(_build_payload("<div>11, 585, 22 </div>"), 11_585_202, "")
- self.assertIsNone(pu.number)
- self.assertTrue(pu.count_attempt)
- self.assertTrue(pu.deletable)
-
- pu = parse_update(_build_payload("<span>0490499</span>"), 4999, "")
- self.assertIsNone(pu.number)
- self.assertTrue(pu.count_attempt)
-
- # this update is marked non-deletable, but I'm not sure it should be
- pu = parse_update(_build_payload("12,123123"), 12_123_123, "")
- self.assertIsNone(pu.number)
- self.assertTrue(pu.count_attempt)
-
- def test_html_handling(self):
- pu = parse_update(_build_payload("123<hr>,456"), None, "")
- self.assertEqual(pu.number, 123)
-
- pu = parse_update(_build_payload("<pre>123\n456</pre>"), None, "")
- self.assertEqual(pu.number, 123)
+++ /dev/null
-from __future__ import annotations
-from dataclasses import dataclass
-from enum import Enum
-from typing import Optional
-import re
-
-from bs4 import BeautifulSoup
-
-
-Command = Enum("Command", ["RESET", "REPORT"])
-
-
-@dataclass
-class ParsedUpdate:
- number: Optional[int]
- command: Optional[Command]
- count_attempt: bool # either well-formed or typo
- deletable: bool
-
-
-def _parse_command(line: str, bot_user: str) -> Optional[Command]:
- if line.lower() == f"/u/{bot_user} reset".lower():
- return Command.RESET
- elif line.lower() in ["sidebar count", "current count"]:
- return Command.REPORT
- else:
- return None
-
-
-def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
- # curr_count is the next number up, one more than the last count
-
- NEW_LINE = object()
- SPACE = object()
-
- # flatten the update content to plain text
- tree = BeautifulSoup(payload_data["body_html"], "html.parser")
- worklist = list(reversed(tree.contents))
- out = [[]]
- while worklist:
- el = worklist.pop()
- if isinstance(el, str):
- out[-1].append(el)
- elif el is SPACE:
- out[-1].append(el)
- elif el is NEW_LINE or el.name in ["br", "hr"]:
- if out[-1]:
- out.append([])
- elif el.name in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]:
- worklist.extend(reversed(el.contents))
- elif el.name in ["ul", "ol", "table", "thead", "tbody"]:
- worklist.extend(reversed(el.contents))
- elif el.name in ["li", "p", "div", "blockquote"] or re.match(r"h[1-6]$", el.name):
- worklist.append(NEW_LINE)
- worklist.extend(reversed(el.contents))
- worklist.append(NEW_LINE)
- elif el.name == "pre":
- out.extend([l] for l in el.text.splitlines())
- out.append([])
- elif el.name == "tr":
- worklist.append(NEW_LINE)
- for (i, cell) in enumerate(reversed(el.contents)):
- worklist.append(cell)
- if i != len(el.contents) - 1:
- worklist.append(SPACE)
- worklist.append(NEW_LINE)
- else:
- raise RuntimeError(f"can't parse tag {el.name}")
-
- tmp_lines = (
- "".join(" " if part is SPACE else part for part in parts).strip()
- for parts in out
- )
- pre_strip_lines = list(filter(None, tmp_lines))
-
- # normalize whitespace according to HTML rendering rules
- # https://developer.mozilla.org/en-US/docs/Web/API/Document_Object_Model/Whitespace#explanation
- stripped_lines = [
- re.sub(" +", " ", l.replace("\t", " ").replace("\n", " ")).strip(" ")
- for l in pre_strip_lines
- ]
-
- return _parse_from_lines(stripped_lines, curr_count, bot_user)
-
-
-def _parse_from_lines(lines: list[str], curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
- command = next(
- (cmd for l in lines if (cmd := _parse_command(l, bot_user))),
- None
- )
- if lines:
- # look for groups of digits (as many as possible) separated by a uniform separator from the valid set
- first = lines[0]
- match = re.match(
- "(?P<v>v)?(?P<neg>-)?(?P<num>\\d+((?P<sep>[,. \u2009]|, )\\d+((?P=sep)\\d+)*)?)",
- first,
- re.ASCII, # only recognize ASCII digits
- )
- if match:
- raw_digits = match["num"]
- sep = match["sep"]
- post = first[match.end() :]
-
- zeros = False
- while len(raw_digits) > 1 and raw_digits[0] == "0":
- zeros = True
- raw_digits = raw_digits.removeprefix("0").removeprefix(sep or "")
-
- parts = raw_digits.split(sep) if sep else [raw_digits]
- all_parts_valid = (
- sep is None
- or (
- 1 <= len(parts[0]) <= 3
- and all(len(p) == 3 for p in parts[1:])
- )
- )
- lone = len(lines) == 1 and (not post or post.isspace())
- typo = False
- if lone:
- if match["v"] and len(parts) == 1 and len(parts[0]) <= 2:
- # failed paste of leading digits
- typo = True
- elif match["v"] and all_parts_valid:
- # v followed by count
- typo = True
- elif curr_count is not None and abs(curr_count) >= 100 and bool(match["neg"]) == (curr_count < 0):
- goal_parts = _separate(str(abs(curr_count)))
- partials = [
- goal_parts[: -1] + [goal_parts[-1][: -1]], # missing last digit
- goal_parts[: -1] + [goal_parts[-1][: -2]], # missing last two digits
- goal_parts[: -1] + [goal_parts[-1][: -2] + goal_parts[-1][-1]], # missing second-last digit
- ]
- if parts in partials:
- # missing any of last two digits
- typo = True
- elif any(parts == p[: -1] + [p[-1] + goal_parts[0]] + goal_parts[1 :] for p in partials):
- # double paste
- typo = True
-
- if match["v"] or zeros or typo or (parts == ["0"] and match["neg"]):
- number = None
- count_attempt = True
- deletable = lone
- else:
- groups_okay = True
- if curr_count is not None and sep and sep.isspace():
- # Presume that the intended count consists of as many valid digit groups as necessary to match the
- # number of digits in the expected count, if possible.
- digit_count = len(str(abs(curr_count)))
- use_parts = []
- accum = 0
- for (i, part) in enumerate(parts):
- part_valid = len(part) <= 3 if i == 0 else len(part) == 3
- if part_valid and accum < digit_count:
- use_parts.append(part)
- accum += len(part)
- else:
- break
-
- # could still be a no-separator count with some extra digit groups on the same line
- if not use_parts:
- use_parts = [parts[0]]
-
- lone = lone and len(use_parts) == len(parts)
- else:
- # current count is unknown or no separator was used
- groups_okay = all_parts_valid
- use_parts = parts
-
- digits = "".join(use_parts)
- number = -int(digits) if match["neg"] else int(digits)
- special = (
- curr_count is not None
- and abs(number - curr_count) <= 25
- and _is_special_number(number)
- )
- if groups_okay:
- deletable = lone and not special
- if len(use_parts) == len(parts) and post and not post[0].isspace():
- count_attempt = curr_count is not None and abs(number - curr_count) <= 25
- number = None
- else:
- count_attempt = True
- else:
- number = None
- count_attempt = True
- deletable = False
-
- if first[0].isdigit() and not count_attempt:
- count_attempt = True
- deletable = False
- else:
- # no count attempt found
- number = None
- count_attempt = False
- deletable = False
- else:
- # no lines in update
- number = None
- count_attempt = False
- deletable = True
-
- return ParsedUpdate(
- number = number,
- command = command,
- count_attempt = count_attempt,
- deletable = deletable,
- )
-
-
-def _separate(digits: str) -> list[str]:
- mod = len(digits) % 3
- out = []
- if mod:
- out.append(digits[: mod])
- out.extend(digits[i : i + 3] for i in range(mod, len(digits), 3))
- return out
-
-
-def _is_special_number(num: int) -> bool:
- num_str = str(num)
- return bool(
- num % 1000 in [0, 1, 333, 999]
- or (num > 10_000_000 and "".join(reversed(num_str)) == num_str)
- or re.match(r"(.+)\1+$", num_str) # repeated sequence
- )
--- /dev/null
+# see Python `configparser' docs for precise file syntax
+
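+# The bot reads this file at startup; pass its path as the only command-line argument
+# (e.g. `python -m strikebot my_config.ini', where the file name is a placeholder).
+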
+[config]
+
+######## basic operating parameters
+
+# space-separated authorization IDs for Reddit API token lookup
+auth IDs = 13 15
+
+bot user = count_better
+
+# if false, don't strike or delete updates or report incorrectly stricken updates (optional, default true)
+enforcing = true
+
+# log messages at or above this severity to standard out; level names and numbers are supported (optional, default
+# WARNING); see https://docs.python.org/3/library/logging.html#levels
+log level = WARNING
+
+# if true, don't modify the live thread at all (implies enforcing = false) (optional, default false)
+read-only = false
+
+thread ID = abc123
+
+
+######## API pool error handling
+
+# (seconds) for error backoff, pause the API pool for this much time
+API pool error delay = 5.5
+
+# (seconds) pause the pool when the Reddit API returns two errors within any time span of this length
+API pool error window = 5.5
+
+# terminate if the Reddit API request queue exceeds this size
+request queue limit = 100
+
+
+######## live thread update handling
+
+# (seconds) maximum time to hold updates for reordering
+reorder buffer time = 0.25
+
+# (seconds) minimum time to retain updates to enable future resyncs
+update retention time = 120.0
+
+
+######## WebSocket pool
+
+# (seconds) maximum allowable spread in WebSocket message arrival times
+WS parity time = 1.0
+
+# goal size of WebSocket connection pool
+WS pool size = 5
+
+# (seconds) after receiving no messages for this much time, attempt to replace the connection if the pool isn't depleted
+WS silent limit = 30
+
+# (seconds) time after WebSocket handshake completion during which missed updates are excused
+WS warmup time = 0.3
+
+
+# Postgres database configuration; same options as in a connect string. Note that because Python's `SSLContext' is used
+# to implement TLS client auth, using `sslkey' to specify a key obtained from an OpenSSL engine may not be supported (a
+# file path is the only option).
+[db connect params]
+host = example.org
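+
+# TLS client auth can also be configured here; `sslcert', `sslkey', and `sslpassword' are consumed by the bot itself
+# to build an SSL context rather than being passed through to the driver. Placeholder example:
+#sslcert = /path/to/client.crt
+#sslkey = /path/to/client.key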
--- /dev/null
+[build-system]
+requires = ["setuptools>=42", "wheel"]
+build-backend = "setuptools.build_meta"
--- /dev/null
+[metadata]
+name = strikebot
+version = 0.0.7
+
+[options]
+package_dir =
+ = src
+packages = strikebot
+python_requires = ~= 3.9
+install_requires =
+ asks ~= 2.4
+ beautifulsoup4 ~= 4.9
+ trio == 0.19
+ trio-websocket == 0.9.2
+ triopg == 0.6.0
--- /dev/null
+import setuptools
+
+
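+# all project metadata lives in setup.cfg; this stub exists for tools that still invoke setup.py directly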
+setuptools.setup()
--- /dev/null
+"""Count tracking logic and bot's external interface."""
+
+from contextlib import nullcontext as nullcontext
+from dataclasses import dataclass
+from functools import total_ordering
+from itertools import islice
+from typing import Optional
+from uuid import UUID
+import bisect
+import datetime as dt
+import importlib.metadata
+import itertools
+import logging
+
+from trio import CancelScope, EndOfChannel
+import trio
+
+from strikebot.updates import Command, parse_update
+
+
+__version__ = importlib.metadata.version(__package__)
+
+
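+# a single live thread update, as parsed and tracked by the bot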
+@dataclass
+class _Update:
+ id: str
+ name: str
+ author: str
+ number: Optional[int]
+ command: Optional[Command]
+ ts: dt.datetime
+ count_attempt: bool
+ deletable: bool
+ stricken: bool
+
+ def can_follow(self, prior: "_Update") -> bool:
+ """Determine whether this count update can follow another count update."""
+ return self.number == prior.number + 1 and self.author != prior.author
+
+ def __str__(self):
+ content = str(self.number)
+ if self.command:
+ content += "+" + self.command.name
+ return "{}({} by {})".format(self.id[4 : 8], content, self.author)
+
+
+@total_ordering
+@dataclass
+class _BufferedUpdate:
+ update: _Update
+ release_at: float # Trio time
+
+ def __lt__(self, other):
+ return self.update.ts < other.update.ts
+
+
+@total_ordering
+@dataclass
+class _TimelineUpdate:
+ update: _Update
+ accepted: bool
+
+ def __lt__(self, other):
+ return self.update.ts < other.update.ts
+
+ def rejected(self) -> bool:
+ """Whether the update should be stricken."""
+ return self.update.count_attempt and not self.accepted
+
+
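+# RFC 4122 (UUID version 1) timestamps count 100 ns intervals since the Gregorian reform (1582-10-15);
+# dividing by 10 converts those ticks to microseconds.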
+def _parse_rfc_4122_ts(ts: int) -> dt.datetime:
+ epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc)
+ return epoch + dt.timedelta(microseconds = ts / 10)
+
+
+def _format_update_ref(update: _Update, thread_id: str) -> str:
+ url = f"https://www.reddit.com/live/{thread_id}/updates/{update.id}"
+ author_md = update.author.replace("_", "\\_") # Reddit allows letters, digits, underscore, and hyphen
+ return f"[{update.number:,} by {author_md}]({url})"
+
+
+def _format_bad_strike_alert(update: _Update, thread_id: str) -> str:
+ return "_incorrectly stricken:_ " + _format_update_ref(update, thread_id)
+
+
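+# the reported "current count" is the next number up, one past the last valid update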
+def _format_curr_count(last_valid: Optional[_Update]) -> str:
+ if last_valid and last_valid.number is not None:
+ count_str = f"{last_valid.number + 1:,}"
+ else:
+ count_str = "unknown"
+ return f"_current count:_ {count_str}"
+
+
+async def count_tracker_impl(
+ message_rx,
+ api_pool: "strikebot.reddit_api.ApiClientPool",
+ reorder_buffer_time: dt.timedelta, # duration to hold some updates for reordering
+ thread_id: str,
+ bot_user: str,
+ enforcing: bool,
+ read_only: bool,
+ update_retention: dt.timedelta,
+ logger: logging.Logger,
+) -> None:
+ from strikebot.reddit_api import CorrectionUpdateRequest, DeleteRequest, ReportUpdateRequest, StrikeRequest
+
+ enforcing = enforcing and not read_only
+
+ buffer_ = []
+ timeline = []
+ last_valid: Optional[_Update] = None
+ pending_strikes: set[str] = set() # names of updates to mark stricken on arrival
+ forgotten: bool = False # whether the retention period for an update has passed during the run
+
+ def handle_update(update):
+ nonlocal last_valid
+
+ tu = _TimelineUpdate(update, accepted = False)
+
+ pos = bisect.bisect(timeline, tu)
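+		# if the update doesn't sort at the end of the timeline, it arrived out of order (a "transpo")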
+ if pos != len(timeline):
+ logger.warning(f"long transpo: {update.id}")
+
+ if pos > 0 and timeline[pos - 1].update.id == update.id:
+ # The pool sent this update message multiple times. This could be because the message reached parity and
+ # its bucket was dropped just before a new connection was opened and sent its own copy, but connections also
+ # seem to send duplicate messages.
+ return
+
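+		# most recent accepted update earlier in the timeline, used to validate this count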
+ pred = next(
+ (timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted),
+ None
+ )
+ contender = update.command is Command.RESET or update.number is not None
+ if contender and pred is None and last_valid and forgotten:
+ # This is a really long transpo which we can't judge valid or invalid since we've forgotten the surrounding
+ # updates. The best we can do is just ignore it.
+ logger.warning(f"ignoring {update.id}: no valid prior count on record")
+ else:
+ timeline.insert(pos, tu)
+ tu.accepted = (
+ update.command is Command.RESET
+ or (
+ update.number is not None
+ and (pred is None or pred.number is None or update.can_follow(pred))
+ and not update.stricken
+ )
+ )
+ logger.debug("{} {} -> {}".format("accepted" if tu.accepted else "rejected", pred, update))
+ if tu.accepted:
+ # resync subsequent updates
+ newly_invalid = []
+ resync_last_valid = update
+ converged = False
+ for scan_tu in islice(timeline, pos + 1, None):
+ if scan_tu.update.command is Command.RESET:
+ resync_last_valid = scan_tu.update
+ if last_valid:
+ converged = True
+ elif scan_tu.update.number is not None:
+ accept = (
+ (resync_last_valid.number is None or scan_tu.update.can_follow(resync_last_valid))
+ and not scan_tu.update.stricken
+ )
+ if accept:
+ assert scan_tu.accepted
+ resync_last_valid = scan_tu.update
+ # resync would have no effect past this point
+ if last_valid:
+ converged = True
+ elif scan_tu.accepted:
+ newly_invalid.append(scan_tu)
+ if scan_tu.update is last_valid:
+ last_valid = None
+ scan_tu.accepted = accept
+ if converged:
+ break
+
+ if converged:
+ assert last_valid.ts >= resync_last_valid.ts
+ else:
+ last_valid = resync_last_valid
+
+ parts = []
+ unstrikable = [tu for tu in newly_invalid if tu.update.stricken]
+ if unstrikable:
+ parts.append(
+ "The following counts are invalid:\n\n"
+ + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in unstrikable)
+ )
+
+ if update.stricken:
+ logger.info(f"bad strike of {update.id}")
+ parts.append(_format_bad_strike_alert(update, thread_id))
+
+ if parts:
+ parts.append(_format_curr_count(last_valid))
+ if not read_only:
+ api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts)))
+
+ for invalid_tu in newly_invalid:
+ if not invalid_tu.update.stricken:
+ if enforcing:
+ api_pool.enqueue_request(StrikeRequest(thread_id, invalid_tu.update.name, invalid_tu.update.ts))
+ invalid_tu.update.stricken = True
+ elif update.count_attempt and not update.stricken:
+ if enforcing:
+ api_pool.enqueue_request(StrikeRequest(thread_id, update.name, update.ts))
+ update.stricken = True
+
+ if not read_only and update.command is Command.REPORT:
+ api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
+
+ async with message_rx:
+ while True:
+ if buffer_:
+ deadline = min(bu.release_at for bu in buffer_)
+ cancel_scope = CancelScope(deadline = deadline)
+ else:
+ cancel_scope = nullcontext()
+
+ msg = None
+ with cancel_scope:
+ try:
+ msg = await message_rx.receive()
+ except EndOfChannel:
+ break
+
+ if msg:
+ if msg.data["type"] == "update":
+ payload_data = msg.data["payload"]["data"]
+ stricken = payload_data["name"] in pending_strikes
+ pending_strikes.discard(payload_data["name"])
+ if payload_data["author"] != bot_user:
+ if last_valid and last_valid.number is not None:
+ next_up = last_valid.number + 1
+ else:
+ next_up = None
+ pu = parse_update(payload_data, next_up, bot_user)
+ rfc_ts = UUID(payload_data["id"]).time
+ update = _Update(
+ id = payload_data["id"],
+ name = payload_data["name"],
+ author = payload_data["author"],
+ number = pu.number,
+ command = pu.command,
+ ts = _parse_rfc_4122_ts(rfc_ts),
+ count_attempt = pu.count_attempt,
+ deletable = pu.deletable,
+ stricken = stricken or payload_data["stricken"],
+ )
+ release_at = msg.timestamp + reorder_buffer_time.total_seconds()
+ bisect.insort(buffer_, _BufferedUpdate(update, release_at))
+ elif msg.data["type"] == "strike":
+ slot = next(
+ (
+ slot for slot in itertools.chain(buffer_, reversed(timeline))
+ if slot.update.name == msg.data["payload"]
+ ),
+ None
+ )
+ if slot:
+ if not slot.update.stricken:
+ slot.update.stricken = True
+ if isinstance(slot, _TimelineUpdate) and slot.accepted:
+ logger.info(f"bad strike of {slot.update.id}")
+ if not read_only:
+ body = _format_bad_strike_alert(slot.update, thread_id)
+ api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body))
+ else:
+ pending_strikes.add(msg.data["payload"])
+
+ threshold = dt.datetime.now(dt.timezone.utc) - update_retention
+
+ # pull updates from the reorder buffer and process
+ new_buffer = []
+ for bu in buffer_:
+ process = (
+ # not a count
+ bu.update.number is None
+
+					# valid count or a reset
+					or (last_valid is None or last_valid.number is None or bu.update.can_follow(last_valid))
+					or bu.update.command is Command.RESET
+
+					# reorder buffer time has elapsed
+					or trio.current_time() >= bu.release_at
+ )
+
+ if process:
+ if bu.update.ts < threshold:
+ logger.warning(f"ignoring {bu.update}: arrived past retention window")
+ else:
+ logger.debug(f"processing {bu.update}")
+ handle_update(bu.update)
+ else:
+				logger.debug(f"holding {bu.update}, checked against {last_valid}")
+ new_buffer.append(bu)
+ buffer_ = new_buffer
+ logger.debug("last count {}".format(last_valid))
+
+ # delete/forget old updates
+ new_timeline = []
+ i = 0
+ for (i, tu) in enumerate(timeline):
+ if i >= len(timeline) - 10 or tu.update.ts >= threshold:
+ break
+ elif tu.rejected() and tu.update.deletable:
+ if enforcing:
+ api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts))
+ elif tu.update is last_valid:
+ new_timeline.append(tu)
+
+ if i:
+ forgotten = True
+ new_timeline.extend(islice(timeline, i, None))
+ timeline = new_timeline
--- /dev/null
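+"""Bot entry point: load configuration, set up logging, and run all bot tasks under trio-asyncio."""
+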
+from enum import Enum
+from inspect import getmodule
+from logging import FileHandler, getLogger, StreamHandler
+from pathlib import Path
+from signal import SIGUSR1
+from sys import stdout
+from traceback import StackSummary
+from typing import Optional
+import argparse
+import configparser
+import datetime as dt
+import logging
+import ssl
+
+from trio import open_memory_channel, open_nursery, open_signal_receiver
+
+try:
+ from trio.lowlevel import current_root_task, Task
+except ImportError:
+ from trio.hazmat import current_root_task, Task
+
+import trio_asyncio
+import triopg
+
+from strikebot import count_tracker_impl
+from strikebot.db import Client, Messenger
+from strikebot.live_ws import HealingReadPool, PoolMerger
+from strikebot.reddit_api import ApiClientPool
+
+
+_DEBUG_LOG_PATH: Optional[Path] = None # file to write debug logs to, if any
+
+
+_TaskKind = Enum(
+ "_TaskKind",
+ [
+ "API_WORKER",
+ "DB_CLIENT",
+ "TOKEN_UPDATE",
+ "TRACKER",
+ "WS_MERGE_READER",
+ "WS_MERGE_TIMER",
+ "WS_REFRESHER",
+ "WS_WORKER",
+ ]
+)
+
+
+def _classify_task(task: Task) -> _TaskKind:
+ mod_name = getmodule(task.coro.cr_code).__name__
+ if mod_name.startswith("trio_asyncio."):
+ return None
+ elif task.coro.__name__ == "_signal_handler_impl":
+ return None
+ else:
+ by_coro_name = {
+ "db_client_impl": _TaskKind.DB_CLIENT,
+ "token_updater_impl": _TaskKind.TOKEN_UPDATE,
+ "event_reader_impl": _TaskKind.WS_MERGE_READER,
+ "timeout_handler_impl": _TaskKind.WS_MERGE_TIMER,
+ "conn_refresher_impl": _TaskKind.WS_REFRESHER,
+ "count_tracker_impl": _TaskKind.TRACKER,
+ "_reader_impl": _TaskKind.WS_WORKER,
+ "worker_impl": _TaskKind.API_WORKER,
+ }
+ return by_coro_name[task.coro.__name__]
+
+
+def _get_all_tasks():
+ [ta_loop] = [
+ task
+ for nursery in current_root_task().child_nurseries
+ for task in nursery.child_tasks
+ if getmodule(task.coro.cr_code).__name__.startswith("trio_asyncio.")
+ ]
+ for nursery in ta_loop.child_nurseries:
+ yield from nursery.child_tasks
+
+
+def _print_task_info():
+ groups = {kind: [] for kind in _TaskKind}
+ for task in _get_all_tasks():
+ kind = _classify_task(task)
+ if kind:
+ coro = task.coro
+ frame_tuples = []
+ while coro is not None:
+ if hasattr(coro, "cr_frame"):
+ frame_tuples.append((coro.cr_frame, coro.cr_frame.f_lineno))
+ coro = coro.cr_await
+ else:
+ frame_tuples.append((coro.gi_frame, coro.gi_frame.f_lineno))
+ coro = coro.gi_yieldfrom
+ groups[kind].append(StackSummary.extract(iter(frame_tuples)))
+ for kind in _TaskKind:
+ print(kind.name)
+ for ss in groups[kind]:
+ print(" task")
+ for line in ss.format():
+ print(" " + line, end = "")
+
+
+async def _signal_handler_impl():
+ with open_signal_receiver(SIGUSR1) as sig_src:
+ async for _ in sig_src:
+ _print_task_info()
+
+
+ap = argparse.ArgumentParser(__package__)
+ap.add_argument("config_path")
+args = ap.parse_args()
+
+
+parser = configparser.ConfigParser()
+with open(args.config_path) as config_file:
+ parser.read_file(config_file)
+
+main_cfg = parser["config"]
+
+api_pool_error_delay = dt.timedelta(seconds = main_cfg.getfloat("API pool error delay"))
+api_pool_error_window = dt.timedelta(seconds = main_cfg.getfloat("API pool error window"))
+auth_ids = set(map(int, main_cfg["auth IDs"].split()))
+bot_user = main_cfg["bot user"]
+enforcing = main_cfg.getboolean("enforcing", fallback = True)
+
+raw_log_level = main_cfg.get("log level", "WARNING")
+try:
+ log_level = int(raw_log_level)
+except ValueError:
+ log_level = raw_log_level
+
+read_only = main_cfg.getboolean("read-only", fallback = False)
+reorder_buffer_time = dt.timedelta(seconds = main_cfg.getfloat("reorder buffer time"))
+request_queue_limit = main_cfg.getint("request queue limit")
+thread_id = main_cfg["thread ID"]
+update_retention = dt.timedelta(seconds = main_cfg.getfloat("update retention time"))
+ws_parity_time = dt.timedelta(seconds = main_cfg.getfloat("WS parity time"))
+ws_pool_size = main_cfg.getint("WS pool size")
+ws_silent_limit = dt.timedelta(seconds = main_cfg.getfloat("WS silent limit"))
+ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup time"))
+
+
+db_cfg = parser["db connect params"]
+cert_path = db_cfg.pop("sslcert", None)
+key_path = db_cfg.pop("sslkey", None)
+key_pass = db_cfg.pop("sslpassword", None)
+
+getters = {
+ "port": db_cfg.getint,
+}
+db_connect_kwargs = {k: getters.get(k, db_cfg.get)(k) for k in db_cfg}
+
+if cert_path:
+ # patch TLS client cert auth support; asyncpg adds it in a later version
+ db_ssl_ctx = ssl.create_default_context()
+ db_ssl_ctx.load_cert_chain(
+ cert_path,
+ key_path,
+ None if key_pass is None else lambda: key_pass,
+ )
+ db_connect_kwargs["ssl"] = db_ssl_ctx
+
+
+if read_only and enforcing:
+ raise RuntimeError("can't use read-only mode with enforcing on")
+
+
+logger = getLogger(__package__)
+logger.setLevel(logging.NOTSET - 1) # NOTSET means inherit from parent; we use handlers to filter
+
+handler = StreamHandler(stdout)
+handler.setLevel(log_level)
+handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{"))
+logger.addHandler(handler)
+
+if _DEBUG_LOG_PATH: # TODO remove this ad hoc setup
+ debug_handler = FileHandler(_DEBUG_LOG_PATH, mode = "w")
+ debug_handler.setLevel(logging.DEBUG)
+ debug_handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{"))
+ logger.addHandler(debug_handler)
+
+
+async def main():
+ async with (
+ triopg.connect(**db_connect_kwargs) as db_conn,
+ open_nursery() as nursery_a,
+ open_nursery() as nursery_b,
+ open_nursery() as ws_pool_nursery,
+ open_nursery() as nursery_c,
+ ):
+ nursery_a.start_soon(_signal_handler_impl)
+
+ client = Client(db_conn)
+ db_messenger = Messenger(client)
+ nursery_a.start_soon(db_messenger.db_client_impl)
+
+ api_pool = ApiClientPool(
+ auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay,
+ logger.getChild("api")
+ )
+ nursery_a.start_soon(api_pool.token_updater_impl)
+ for _ in auth_ids:
+ await nursery_a.start(api_pool.worker_impl)
+
+ (message_tx, message_rx) = open_memory_channel(0)
+ (pool_event_tx, pool_event_rx) = open_memory_channel(0)
+
+ merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger.getChild("merge"))
+ nursery_b.start_soon(merger.event_reader_impl)
+ nursery_b.start_soon(merger.timeout_handler_impl)
+
+ ws_pool = HealingReadPool(
+ ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger.getChild("live_ws"),
+ ws_silent_limit
+ )
+ nursery_c.start_soon(ws_pool.conn_refresher_impl)
+
+ nursery_c.start_soon(
+ count_tracker_impl, message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing, read_only,
+ update_retention, logger.getChild("track")
+ )
+ await ws_pool.init_workers()
+
+trio_asyncio.run(main)
--- /dev/null
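+"""Small helpers for the short hex digests used in log messages."""
+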
+from typing import Any
+
+
+def int_digest(val: int) -> str:
+ return "{:04x}".format(val & 0xffff)
+
+
+def obj_digest(obj: Any) -> str:
+ """Makes a short digest of the identity of an object."""
+ return int_digest(id(obj))
--- /dev/null
+"""Single-connection Postgres client with high-level wrappers for operations."""
+
+from dataclasses import dataclass
+from typing import Any, Iterable
+
+import trio
+
+
+def _channel_sender(method):
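+	"""Wrap an async method so its return value is sent over a response channel passed as the first argument."""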
+ async def wrapped(self, resp_channel, *args, **kwargs):
+ ret = await method(self, *args, **kwargs)
+ async with resp_channel:
+ await resp_channel.send(ret)
+
+ return wrapped
+
+
+@dataclass
+class Client:
+ """High-level wrappers for DB operations."""
+ conn: Any
+
+ @_channel_sender
+ async def get_auth_tokens(self, auth_ids: set[int]) -> dict[int, str]:
+ results = await self.conn.fetch(
+ """
+ select
+ distinct on (id)
+ id, access_token
+ from
+ public.reddit_app_authorization
+ join unnest($1::integer[]) as request_id on id = request_id
+ where expires > current_timestamp
+ """,
+ list(auth_ids),
+ )
+ return {r["id"]: r["access_token"] for r in results}
+
+
+class Messenger:
+ def __init__(self, client: Client):
+ self._client = client
+ (self._request_tx, self._request_rx) = trio.open_memory_channel(0)
+
+ async def do(self, method_name: str, args: Iterable) -> Any:
+ """This is run by consumers of the DB wrapper."""
+ method = getattr(self._client, method_name)
+ (resp_tx, resp_rx) = trio.open_memory_channel(0)
+ coro = method(resp_tx, *args)
+ await self._request_tx.send(coro)
+ async with resp_rx:
+ return await resp_rx.receive()
+
+ async def db_client_impl(self):
+ """This is the DB client Trio task."""
+ async with self._client.conn, self._request_rx:
+ async for coro in self._request_rx:
+ await coro
--- /dev/null
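+"""Self-healing pool of live thread WebSocket connections, plus a merger that deduplicates messages across them."""
+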
+from collections import deque
+from contextlib import suppress
+from dataclasses import dataclass
+from typing import Any, AsyncContextManager, Optional
+import datetime as dt
+import json
+import logging
+import math
+
+from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
+import trio
+
+from strikebot.common import int_digest, obj_digest
+from strikebot.reddit_api import AboutLiveThreadRequest
+
+
+@dataclass
+class _PoolEvent:
+ timestamp: float # Trio time
+
+
+@dataclass
+class _Message(_PoolEvent):
+ data: Any
+ scope: Any
+
+ def dedup_tag(self):
+ """A hashable essence of the contents of this message."""
+ if self.data["type"] == "update":
+ return ("update", self.data["payload"]["data"]["id"])
+ elif self.data["type"] in {"strike", "delete"}:
+ return (self.data["type"], self.data["payload"])
+ else:
+ raise ValueError(self.data["type"])
+
+
+@dataclass
+class _ConnectionDown(_PoolEvent):
+ scope: trio.CancelScope
+
+
+@dataclass
+class _ConnectionUp(_PoolEvent):
+ scope: trio.CancelScope
+
+
+class HealingReadPool:
+ def __init__(
+ self, pool_nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger,
+ silent_limit: dt.timedelta
+ ):
+ assert size >= 2
+ self._nursery = pool_nursery
+ self._size = size
+ self._live_thread_id = live_thread_id
+ self._api_client_pool = api_client_pool
+ self._pool_event_tx = pool_event_tx
+ self._logger = logger
+ self._silent_limit = silent_limit
+
+ (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size)
+
+ # number of workers who have stopped receiving updates but whose tasks aren't yet stopped
+ self._closing = 0
+
+ async def init_workers(self):
+ for _ in range(self._size):
+ await self._spawn_reader()
+ if not self._nursery.child_tasks:
+ raise RuntimeError("Unable to create any WS connections")
+
+ async def _reader_impl(self, conn_ctx: AsyncContextManager, cancel_scope, refresh_tx):
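+		"""
+		Forward messages from one WS connection to the pool event channel, then request a replacement when the
+		connection closes, is cancelled for missing an update, or goes silent for too long.
+		"""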
+ # TODO could use task's implicit cancel scope
+ try:
+ async with conn_ctx as conn:
+ self._logger.debug("scope up: {}".format(obj_digest(cancel_scope)))
+ async with refresh_tx:
+ silent_timeout = False
+ with cancel_scope, suppress(ConnectionClosed):
+ while True:
+ with trio.move_on_after(self._silent_limit.total_seconds()) as timeout_scope:
+ message = await conn.get_message()
+
+ if timeout_scope.cancelled_caught:
+ if len(self._nursery.child_tasks) - self._closing == self._size:
+ silent_timeout = True
+ break
+ else:
+ self._logger.debug("not replacing connection {}; {} tasks, {} closing".format(
+ obj_digest(cancel_scope),
+ len(self._nursery.child_tasks),
+ self._closing,
+ ))
+ else:
+ event = _Message(trio.current_time(), json.loads(message), cancel_scope)
+ await self._pool_event_tx.send(event)
+
+ self._closing += 1
+ if silent_timeout:
+ await conn.aclose()
+ self._logger.debug("replacing WS connection {} due to silent timeout".format(obj_digest(cancel_scope)))
+ elif cancel_scope.cancelled_caught:
+ await conn.aclose(1008, "Server unexpectedly stopped sending messages")
+ self._logger.warning("replacing WS connection {} due to missed update".format(obj_digest(cancel_scope)))
+ else:
+ self._logger.warning("replacing WS connection {} closed by server".format(obj_digest(cancel_scope)))
+
+ refresh_tx.send_nowait(cancel_scope)
+ self._logger.debug("scope down: {} ({} tasks, {} closing)".format(
+ obj_digest(cancel_scope),
+ len(self._nursery.child_tasks),
+ self._closing,
+ ))
+ self._closing -= 1
+ except HandshakeError:
+ self._logger.error("handshake error while opening WS connection")
+
+ async def _new_connection(self) -> Optional[AsyncContextManager[WebSocketConnection]]:
+ request = AboutLiveThreadRequest(self._live_thread_id)
+ resp = await self._api_client_pool.make_request(request)
+ if resp.status_code == 200:
+ url = json.loads(resp.text)["data"]["websocket_url"]
+ return open_websocket_url(url)
+ else:
+ self._logger.error(f"thread info request failed (HTTP {resp.status_code}); can't open WS connection")
+ return None
+
+ async def _spawn_reader(self) -> None:
+ conn = await self._new_connection()
+ if conn:
+ new_scope = trio.CancelScope()
+ await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope))
+ refresh_tx = self._refresh_queue_tx.clone()
+ self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx)
+
+ async def conn_refresher_impl(self):
+ """Task to monitor and replace WS connections as they disconnect or go silent."""
+ async with self._refresh_queue_rx:
+ async for old_scope in self._refresh_queue_rx:
+ await self._pool_event_tx.send(_ConnectionDown(trio.current_time(), old_scope))
+ await self._spawn_reader()
+ if not self._nursery.child_tasks:
+ raise RuntimeError("WS pool depleted")
+
+
+def _tag_digest(tag):
+ return int_digest(hash(tag))
+
+
+def _format_scope_list(scopes):
+ return ", ".join(sorted(map(obj_digest, scopes)))
+
+
+class PoolMerger:
+ @dataclass
+ class _Bucket:
+ start: float
+ recipients: set
+
+ def __init__(self, pool_event_rx, message_tx, parity_timeout: dt.timedelta, conn_warmup: dt.timedelta, logger: logging.Logger):
+ """
+ :param parity_timeout: max delay between message receipt on different connections
+ :param conn_warmup: max interval after connection within which missed updates are allowed
+ """
+ self._pool_event_rx = pool_event_rx
+ self._message_tx = message_tx
+ self._parity_timeout = parity_timeout
+ self._conn_warmup = conn_warmup
+ self._logger = logger
+
+ self._buckets = {}
+ self._scope_activations = {}
+ self._pending = deque()
+ self._outgoing_scopes = set()
+ (self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf)
+
+ def _log_current_buckets(self):
+ self._logger.debug("current buckets: {}".format(", ".join(map(_tag_digest, self._buckets.keys()))))
+
+ async def event_reader_impl(self):
+ """Drop unused messages, deduplicate useful ones, and communicate with the timeout handler."""
+ async with self._pool_event_rx, self._timer_poke_tx:
+ async for event in self._pool_event_rx:
+ if isinstance(event, _ConnectionUp):
+					# Activating a scope too early could make it expected for a message that fired before its connection
+					# opened, producing a false positive for replacement. To avoid that, the timer task merges these
+					# events in among the buckets by timestamp.
+ self._pending.append(event)
+ elif isinstance(event, _Message):
+ if event.data["type"] in ["update", "strike", "delete"]:
+ tag = event.dedup_tag()
+ self._logger.debug("recv {} from {}".format(_tag_digest(tag), obj_digest(event.scope)))
+ if tag in self._buckets:
+ b = self._buckets[tag]
+ b.recipients.add(event.scope)
+							# If this scope is the last one for this bucket we could clear the bucket here. But connections
+							# sometimes get repeat messages, and a second copy can arrive before the first copy has reached
+							# all connections; leaving the bucket open to absorb repeats makes it less likely that a late
+							# second bucket is allocated for the repeat and forces an unnecessary connection replacement.
+ elif event.scope not in self._outgoing_scopes:
+ sane = (
+ event.scope in self._scope_activations
+ or any(e.scope == event.scope for e in self._pending)
+ )
+ if sane:
+ self._logger.debug("new bucket {}: {}".format(_tag_digest(tag), tag))
+ self._buckets[tag] = self._Bucket(event.timestamp, {event.scope})
+ self._log_current_buckets()
+ await self._message_tx.send(event)
+ else:
+								raise RuntimeError("received message from unrecognized WS connection")
+ else:
+ self._logger.debug("recv type {!r} from {} (discarding)".format(event.data["type"], obj_digest(event.scope)))
+ elif isinstance(event, _ConnectionDown):
+ # We don't need to worry about canceling this scope at all, so no need to require it for parity for
+ # any message, even older ones. The scope may be gone already, if we canceled it previously.
+ self._scope_activations.pop(event.scope, None)
+ self._pending = deque(ev for ev in self._pending if ev.scope is not event.scope)
+ self._outgoing_scopes.discard(event.scope)
+ else:
+ raise TypeError(f"Expected pool event, found {event!r}")
+ self._timer_poke_tx.send_nowait(None) # may have new work for the timer
+
+ async def timeout_handler_impl(self):
+ """When connections have had enough time to reach parity on a message, signal replacement of any slackers."""
+ async with self._timer_poke_rx:
+ while True:
+ if self._buckets:
+ now = trio.current_time()
+ (tag, bucket) = min(self._buckets.items(), key = lambda i: i[1].start)
+
+ # Make sure the right scope set is ready for the next bucket up
+ while self._pending and self._pending[0].timestamp < bucket.start:
+ ev = self._pending.popleft()
+ self._scope_activations[ev.scope] = ev.timestamp
+
+ target_scopes = {
+ scope for (scope, active) in self._scope_activations.items()
+ if active + self._conn_warmup.total_seconds() < now
+ }
+ if now > bucket.start + self._parity_timeout.total_seconds():
+ filled = target_scopes & bucket.recipients
+ missing = target_scopes - bucket.recipients
+ extra = bucket.recipients - target_scopes
+
+ self._logger.debug("expiring bucket {}".format(_tag_digest(tag)))
+ if filled:
+ self._logger.debug(" filled {}: {}".format(len(filled), _format_scope_list(filled)))
+ if missing:
+ self._logger.debug(" missing {}: {}".format(len(missing), _format_scope_list(missing)))
+ if extra:
+ self._logger.debug(" extra {}: {}".format(len(extra), _format_scope_list(extra)))
+
+ for scope in missing:
+ self._outgoing_scopes.add(scope)
+ scope.cancel()
+ del self._buckets[tag]
+ self._log_current_buckets()
+ else:
+ await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds())
+ else:
+ try:
+ await self._timer_poke_rx.receive()
+ except trio.EndOfChannel:
+ break
--- /dev/null
+"""Unbounded blocking queues for Trio"""
+
+from collections import deque
+from dataclasses import dataclass
+from functools import total_ordering
+from typing import Any, Iterable
+import heapq
+
+try:
+ from trio.lowlevel import ParkingLot
+except ImportError:
+ from trio.hazmat import ParkingLot
+
+
+class Queue:
+ def __init__(self):
+ self._deque = deque()
+ self._empty_wait = ParkingLot()
+
+ def push(self, el: Any) -> None:
+ self._deque.append(el)
+ self._empty_wait.unpark()
+
+ def extend(self, els: Iterable[Any]) -> None:
+ for el in els:
+ self.push(el)
+
+ async def pop(self) -> Any:
+ if not self._deque:
+ await self._empty_wait.park()
+ return self._deque.popleft()
+
+ def __len__(self):
+ return len(self._deque)
+
+
+@total_ordering
+@dataclass
+class _ReverseOrdWrapper:
+ inner: Any
+
+ def __lt__(self, other):
+ return other.inner < self.inner
+
+
+class MaxHeap:
+ def __init__(self):
+ self._heap = []
+ self._empty_wait = ParkingLot()
+
+ def push(self, item):
+ heapq.heappush(self._heap, _ReverseOrdWrapper(item))
+ self._empty_wait.unpark()
+
+ async def pop(self):
+ if not self._heap:
+ await self._empty_wait.park()
+ return heapq.heappop(self._heap).inner
+
+ def __len__(self) -> int:
+ return len(self._heap)
--- /dev/null
+"""Multi-auth HTTP connection pool with Reddit API wrappers and rate limiting."""
+
+from abc import ABCMeta, abstractmethod
+from dataclasses import dataclass, field
+from functools import total_ordering
+from socket import EAI_AGAIN, EAI_FAIL, gaierror
+import datetime as dt
+import logging
+
+from asks.response_objects import Response
+import asks
+import trio
+
+from strikebot import __version__ as VERSION
+from strikebot.common import obj_digest
+from strikebot.queue import MaxHeap, Queue
+
+
+REQUEST_DELAY = dt.timedelta(seconds = 1)
+
+TOKEN_UPDATE_DELAY = dt.timedelta(minutes = 15)
+
+USER_AGENT_FMT = f"any:net.jcornell.strikebot.{{auth_id}}:v{VERSION} (by /u/jaxklax)"
+
+API_BASE_URL = "https://oauth.reddit.com"
+
+
+@total_ordering
+class _Request(metaclass = ABCMeta):
+ _SUBTYPE_PRECEDENCE = None # assigned later
+
+ def __lt__(self, other):
+ if type(self) is type(other):
+ return self._subtype_cmp_key() < other._subtype_cmp_key()
+ else:
+ prec = self._SUBTYPE_PRECEDENCE
+ return prec.index(type(other)) < prec.index(type(self))
+
+ def __eq__(self, other):
+ if type(self) is type(other):
+ return self._subtype_cmp_key() == other._subtype_cmp_key()
+ else:
+ return False
+
+ @abstractmethod
+ def _subtype_cmp_key(self):
+ # a large key corresponds to a high priority
+ raise NotImplementedError()
+
+ @abstractmethod
+ def to_asks_kwargs(self):
+ raise NotImplementedError()
+
+
+@dataclass(eq = False)
+class StrikeRequest(_Request):
+ thread_id: str
+ update_name: str
+ update_ts: dt.datetime
+
+ def to_asks_kwargs(self):
+ return {
+ "method": "POST",
+ "path": f"/api/live/{self.thread_id}/strike_update",
+ "data": {
+ "api_type": "json",
+ "id": self.update_name,
+ },
+ }
+
+ def _subtype_cmp_key(self):
+ return self.update_ts
+
+
+@dataclass(eq = False)
+class DeleteRequest(_Request):
+ thread_id: str
+ update_name: str
+ update_ts: dt.datetime
+
+ def to_asks_kwargs(self):
+ return {
+ "method": "POST",
+ "path": f"/api/live/{self.thread_id}/delete_update",
+ "data": {
+ "api_type": "json",
+ "id": self.update_name,
+ },
+ }
+
+ def _subtype_cmp_key(self):
+ return -self.update_ts.timestamp()
+
+
+@dataclass(eq = False)
+class AboutLiveThreadRequest(_Request):
+ thread_id: str
+ created: float = field(default_factory = trio.current_time)
+
+ def to_asks_kwargs(self):
+ return {
+ "method": "GET",
+ "path": f"/live/{self.thread_id}/about",
+ "params": {
+ "raw_json": "1",
+ },
+ }
+
+ def _subtype_cmp_key(self):
+ return -self.created
+
+
+@dataclass(eq = False)
+class _BaseUpdatePostRequest(_Request, metaclass = ABCMeta):
+ thread_id: str
+ body: str
+ created: float = field(default_factory = trio.current_time)
+
+ def to_asks_kwargs(self):
+ return {
+ "method": "POST",
+ "path": f"/live/{self.thread_id}/update",
+ "data": {
+ "api_type": "json",
+ "body": self.body,
+ },
+ }
+
+
+@dataclass(eq = False)
+class ReportUpdateRequest(_BaseUpdatePostRequest):
+ def _subtype_cmp_key(self):
+ return -self.created
+
+
+@dataclass(eq = False)
+class CorrectionUpdateRequest(_BaseUpdatePostRequest):
+ def _subtype_cmp_key(self):
+ return -self.created
+
+
+# highest priority first
+_Request._SUBTYPE_PRECEDENCE = [
+ AboutLiveThreadRequest,
+ StrikeRequest,
+ CorrectionUpdateRequest,
+ ReportUpdateRequest,
+ DeleteRequest,
+]
+
+
+@dataclass
+class AppCooldown:
+ auth_id: int
+ ready_at: float # Trio time
+
+
+class ApiClientPool:
+ def __init__(
+ self,
+ auth_ids: set[int],
+ db_messenger: "strikebot.db.Messenger",
+ request_queue_limit: int,
+ error_window: dt.timedelta,
+ error_delay: dt.timedelta,
+ logger: logging.Logger,
+ ):
+ self._auth_ids = auth_ids
+ self._db_messenger = db_messenger
+ self._request_queue_limit = request_queue_limit
+ self._error_window = error_window
+ self._error_delay = error_delay
+ self._logger = logger
+
+ now = trio.current_time()
+ self._tokens = {}
+ self._app_queue = Queue()
+ self._waiting = {id_: AppCooldown(id_, now) for id_ in auth_ids}
+ self._request_queue = MaxHeap()
+
+ # pool-wide API error backoff
+ self._last_error = None
+ self._global_resume = None
+
+ self._session = asks.Session(connections = len(auth_ids))
+ self._session.base_location = API_BASE_URL
+
+ async def _update_tokens(self):
+ tokens = await self._db_messenger.do("get_auth_tokens", (self._auth_ids,))
+ self._tokens.update(tokens)
+
+ awaken_auths = self._waiting.keys() & tokens.keys()
+ self._app_queue.extend(self._waiting.pop(auth_id) for auth_id in awaken_auths)
+ self._logger.debug("updated API tokens")
+
+ async def token_updater_impl(self):
+ last_update = None
+ while True:
+ if last_update is not None:
+ await trio.sleep_until(last_update + TOKEN_UPDATE_DELAY.total_seconds())
+
+ if last_update is None:
+ last_update = trio.current_time()
+ else:
+ last_update += TOKEN_UPDATE_DELAY.total_seconds()
+
+ await self._update_tokens()
+
+ def _check_queue_size(self) -> None:
+ if len(self._request_queue) > self._request_queue_limit:
+ raise RuntimeError("request queue size exceeded limit")
+
+ async def make_request(self, request: _Request) -> Response:
+ (resp_tx, resp_rx) = trio.open_memory_channel(0)
+ self.enqueue_request(request, resp_tx)
+ async with resp_rx:
+ return await resp_rx.receive()
+
+ def enqueue_request(self, request: _Request, resp_tx = None) -> None:
+ self._logger.debug("request {}: {}".format(obj_digest(request), type(request).__name__))
+ self._request_queue.push((request, resp_tx))
+ self._logger.debug(f"{len(self._request_queue)} requests in queue")
+ self._check_queue_size()
+
+ async def worker_impl(self, task_status):
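+		"""Send queued requests one at a time, rotating app authorizations and backing off on API errors."""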
+ task_status.started()
+ while True:
+ (request, resp_tx) = await self._request_queue.pop()
+ cooldown = await self._app_queue.pop()
+ await trio.sleep_until(cooldown.ready_at)
+ if self._global_resume:
+ await trio.sleep_until(self._global_resume)
+
+ asks_kwargs = request.to_asks_kwargs()
+ headers = asks_kwargs.setdefault("headers", {})
+ headers.update({
+ "Authorization": "Bearer {}".format(self._tokens[cooldown.auth_id]),
+ "User-Agent": USER_AGENT_FMT.format(auth_id = cooldown.auth_id),
+ })
+
+ request_time = trio.current_time()
+ try:
+ resp = await self._session.request(**asks_kwargs)
+ except gaierror as e:
+				if e.errno in [EAI_FAIL, EAI_AGAIN]:
+					# DNS failure, probably temporary
+					error = True
+					wait_for_token = False
+ else:
+ raise
+ else:
+ resp.body # read response
+ error = False
+ wait_for_token = False
+ log_suffix = " (request {})".format(obj_digest(request))
+ if resp.status_code == 429:
+ # We disagreed about the rate limit state; just try again later.
+ self._logger.warning("rate limited by Reddit API" + log_suffix)
+ error = True
+ elif resp.status_code == 401:
+ self._logger.warning("got HTTP 401 from Reddit API" + log_suffix)
+ error = True
+ wait_for_token = True
+ elif resp.status_code in [404, 500, 503]:
+ self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying" + log_suffix)
+ error = True
+ elif 400 <= resp.status_code < 500:
+ # If we're doing something wrong, let's catch it right away.
+ raise RuntimeError(f"unexpected client error response: {resp.status_code}" + log_suffix)
+ else:
+ if resp.status_code != 200:
+ raise RuntimeError(f"unexpected status code {resp.status_code}")
+ self._logger.debug("success" + log_suffix)
+ if resp_tx:
+ await resp_tx.send(resp)
+
+ if error:
+ self.enqueue_request(request, resp_tx)
+ if self._last_error:
+ spread = dt.timedelta(seconds = request_time - self._last_error)
+ if spread <= self._error_window:
+ self._global_resume = request_time + self._error_delay.total_seconds()
+ self._last_error = request_time
+
+ cooldown.ready_at = request_time + REQUEST_DELAY.total_seconds()
+ if wait_for_token:
+ self._waiting[cooldown.auth_id] = cooldown
+ if not self._app_queue:
+ self._logger.error("all workers waiting for API tokens")
+ else:
+ self._app_queue.push(cooldown)
--- /dev/null
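+"""Tests for strikebot.updates.parse_update."""
+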
+from unittest import TestCase
+
+from strikebot.updates import parse_update
+
+
+def _build_payload(body_html: str) -> dict[str, str]:
+ return {"body_html": body_html}
+
+
+class UpdateParsingTests(TestCase):
+ def test_successful_counts(self):
+ pu = parse_update(_build_payload("<div>12,345,678 spaghetti</div>"), None, "")
+ self.assertEqual(pu.number, 12345678)
+ self.assertFalse(pu.deletable)
+
+ pu = parse_update(_build_payload("<div><p>0</p><br/>oz</div>"), None, "")
+ self.assertEqual(pu.number, 0)
+ self.assertFalse(pu.deletable)
+
+ pu = parse_update(_build_payload("<div>121 345 621</div>"), 121, "")
+ self.assertEqual(pu.number, 121)
+ self.assertFalse(pu.deletable)
+
+ pu = parse_update(_build_payload("28 336 816"), 28336816, "")
+ self.assertEqual(pu.number, 28336816)
+ self.assertTrue(pu.deletable)
+
+ def test_non_counts(self):
+ pu = parse_update(_build_payload("<div><table><tbody><tr><td>zoo</td></tr></tbody></table></div>"), None, "")
+ self.assertFalse(pu.count_attempt)
+ self.assertFalse(pu.deletable)
+
+ def test_typos(self):
+ pu = parse_update(_build_payload("<span>v9</span>"), 888, "")
+ self.assertIsNone(pu.number)
+ self.assertTrue(pu.count_attempt)
+
+ pu = parse_update(_build_payload("<div>v11.585 Empire</div>"), None, "")
+ self.assertIsNone(pu.number)
+ self.assertTrue(pu.count_attempt)
+ self.assertFalse(pu.deletable)
+
+ pu = parse_update(_build_payload("<div>11, 585, 22 </div>"), 11_585_202, "")
+ self.assertIsNone(pu.number)
+ self.assertTrue(pu.count_attempt)
+ self.assertTrue(pu.deletable)
+
+ pu = parse_update(_build_payload("<span>0490499</span>"), 4999, "")
+ self.assertIsNone(pu.number)
+ self.assertTrue(pu.count_attempt)
+
+ # this update is marked non-deletable, but I'm not sure it should be
+ pu = parse_update(_build_payload("12,123123"), 12_123_123, "")
+ self.assertIsNone(pu.number)
+ self.assertTrue(pu.count_attempt)
+
+ def test_html_handling(self):
+ pu = parse_update(_build_payload("123<hr>,456"), None, "")
+ self.assertEqual(pu.number, 123)
+
+ pu = parse_update(_build_payload("<pre>123\n456</pre>"), None, "")
+ self.assertEqual(pu.number, 123)
--- /dev/null
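+"""Parsing of live thread update bodies into counts and bot commands."""
+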
+from __future__ import annotations
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional
+import re
+
+from bs4 import BeautifulSoup
+
+
+Command = Enum("Command", ["RESET", "REPORT"])
+
+
+@dataclass
+class ParsedUpdate:
+ number: Optional[int]
+ command: Optional[Command]
+ count_attempt: bool # either well-formed or typo
+ deletable: bool
+
+
+def _parse_command(line: str, bot_user: str) -> Optional[Command]:
+ if line.lower() == f"/u/{bot_user} reset".lower():
+ return Command.RESET
+ elif line.lower() in ["sidebar count", "current count"]:
+ return Command.REPORT
+ else:
+ return None
+
+
+def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
+ # curr_count is the next number up, one more than the last count
+
+ NEW_LINE = object()
+ SPACE = object()
+
+ # flatten the update content to plain text
+ tree = BeautifulSoup(payload_data["body_html"], "html.parser")
+ worklist = list(reversed(tree.contents))
+ out = [[]]
+ while worklist:
+ el = worklist.pop()
+ if isinstance(el, str):
+ out[-1].append(el)
+ elif el is SPACE:
+ out[-1].append(el)
+ elif el is NEW_LINE or el.name in ["br", "hr"]:
+ if out[-1]:
+ out.append([])
+ elif el.name in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]:
+ worklist.extend(reversed(el.contents))
+ elif el.name in ["ul", "ol", "table", "thead", "tbody"]:
+ worklist.extend(reversed(el.contents))
+ elif el.name in ["li", "p", "div", "blockquote"] or re.match(r"h[1-6]$", el.name):
+ worklist.append(NEW_LINE)
+ worklist.extend(reversed(el.contents))
+ worklist.append(NEW_LINE)
+ elif el.name == "pre":
+ out.extend([l] for l in el.text.splitlines())
+ out.append([])
+ elif el.name == "tr":
+ worklist.append(NEW_LINE)
+ for (i, cell) in enumerate(reversed(el.contents)):
+ worklist.append(cell)
+ if i != len(el.contents) - 1:
+ worklist.append(SPACE)
+ worklist.append(NEW_LINE)
+ else:
+ raise RuntimeError(f"can't parse tag {el.name}")
+
+ tmp_lines = (
+ "".join(" " if part is SPACE else part for part in parts).strip()
+ for parts in out
+ )
+ pre_strip_lines = list(filter(None, tmp_lines))
+
+ # normalize whitespace according to HTML rendering rules
+ # https://developer.mozilla.org/en-US/docs/Web/API/Document_Object_Model/Whitespace#explanation
+ stripped_lines = [
+ re.sub(" +", " ", l.replace("\t", " ").replace("\n", " ")).strip(" ")
+ for l in pre_strip_lines
+ ]
+
+ return _parse_from_lines(stripped_lines, curr_count, bot_user)
+
+
+def _parse_from_lines(lines: list[str], curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
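+	"""Extract a count and/or command from the flattened update lines, tolerating common typo patterns."""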
+ command = next(
+ (cmd for l in lines if (cmd := _parse_command(l, bot_user))),
+ None
+ )
+ if lines:
+ # look for groups of digits (as many as possible) separated by a uniform separator from the valid set
+ first = lines[0]
+ match = re.match(
+ "(?P<v>v)?(?P<neg>-)?(?P<num>\\d+((?P<sep>[,. \u2009]|, )\\d+((?P=sep)\\d+)*)?)",
+ first,
+ re.ASCII, # only recognize ASCII digits
+ )
+ if match:
+ raw_digits = match["num"]
+ sep = match["sep"]
+ post = first[match.end() :]
+
+ zeros = False
+ while len(raw_digits) > 1 and raw_digits[0] == "0":
+ zeros = True
+ raw_digits = raw_digits.removeprefix("0").removeprefix(sep or "")
+
+ parts = raw_digits.split(sep) if sep else [raw_digits]
+ all_parts_valid = (
+ sep is None
+ or (
+ 1 <= len(parts[0]) <= 3
+ and all(len(p) == 3 for p in parts[1:])
+ )
+ )
+ lone = len(lines) == 1 and (not post or post.isspace())
+ typo = False
+ if lone:
+ if match["v"] and len(parts) == 1 and len(parts[0]) <= 2:
+ # failed paste of leading digits
+ typo = True
+ elif match["v"] and all_parts_valid:
+ # v followed by count
+ typo = True
+ elif curr_count is not None and abs(curr_count) >= 100 and bool(match["neg"]) == (curr_count < 0):
+ goal_parts = _separate(str(abs(curr_count)))
+ partials = [
+ goal_parts[: -1] + [goal_parts[-1][: -1]], # missing last digit
+ goal_parts[: -1] + [goal_parts[-1][: -2]], # missing last two digits
+ goal_parts[: -1] + [goal_parts[-1][: -2] + goal_parts[-1][-1]], # missing second-last digit
+ ]
+ if parts in partials:
+ # missing any of last two digits
+ typo = True
+ elif any(parts == p[: -1] + [p[-1] + goal_parts[0]] + goal_parts[1 :] for p in partials):
+ # double paste
+ typo = True
+
+ if match["v"] or zeros or typo or (parts == ["0"] and match["neg"]):
+ number = None
+ count_attempt = True
+ deletable = lone
+ else:
+ groups_okay = True
+ if curr_count is not None and sep and sep.isspace():
+ # Presume that the intended count consists of as many valid digit groups as necessary to match the
+ # number of digits in the expected count, if possible.
+ digit_count = len(str(abs(curr_count)))
+ use_parts = []
+ accum = 0
+ for (i, part) in enumerate(parts):
+ part_valid = len(part) <= 3 if i == 0 else len(part) == 3
+ if part_valid and accum < digit_count:
+ use_parts.append(part)
+ accum += len(part)
+ else:
+ break
+
+ # could still be a no-separator count with some extra digit groups on the same line
+ if not use_parts:
+ use_parts = [parts[0]]
+
+ lone = lone and len(use_parts) == len(parts)
+ else:
+ # current count is unknown or no separator was used
+ groups_okay = all_parts_valid
+ use_parts = parts
+
+ digits = "".join(use_parts)
+ number = -int(digits) if match["neg"] else int(digits)
+ special = (
+ curr_count is not None
+ and abs(number - curr_count) <= 25
+ and _is_special_number(number)
+ )
+ if groups_okay:
+ deletable = lone and not special
+ if len(use_parts) == len(parts) and post and not post[0].isspace():
+ count_attempt = curr_count is not None and abs(number - curr_count) <= 25
+ number = None
+ else:
+ count_attempt = True
+ else:
+ number = None
+ count_attempt = True
+ deletable = False
+
+ if first[0].isdigit() and not count_attempt:
+ count_attempt = True
+ deletable = False
+ else:
+ # no count attempt found
+ number = None
+ count_attempt = False
+ deletable = False
+ else:
+ # no lines in update
+ number = None
+ count_attempt = False
+ deletable = True
+
+ return ParsedUpdate(
+ number = number,
+ command = command,
+ count_attempt = count_attempt,
+ deletable = deletable,
+ )
+
+
+def _separate(digits: str) -> list[str]:
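+	"""Split a digit string into thousands-style groups of three, e.g. '1234567' -> ['1', '234', '567']."""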
+ mod = len(digits) % 3
+ out = []
+ if mod:
+ out.append(digits[: mod])
+ out.extend(digits[i : i + 3] for i in range(mod, len(digits), 3))
+ return out
+
+
+def _is_special_number(num: int) -> bool:
+ num_str = str(num)
+ return bool(
+ num % 1000 in [0, 1, 333, 999]
+ or (num > 10_000_000 and "".join(reversed(num_str)) == num_str)
+ or re.match(r"(.+)\1+$", num_str) # repeated sequence
+ )