Reorganize for multi-project builds, improve build helper
author     Jakob Cornell <jakob+gpg@jcornell.net>  Mon, 9 Jan 2023 00:07:32 +0000 (18:07 -0600)
committer  Jakob Cornell <jakob+gpg@jcornell.net>  Mon, 9 Jan 2023 00:07:32 +0000 (18:07 -0600)
43 files changed:
build_helper.py
mentionbot/docs/sample_config.ini [deleted file]
mentionbot/mentionbot/docs/sample_config.ini [new file with mode: 0644]
mentionbot/mentionbot/pyproject.toml [new file with mode: 0644]
mentionbot/mentionbot/setup.cfg [new file with mode: 0644]
mentionbot/mentionbot/setup.py [new file with mode: 0644]
mentionbot/mentionbot/src/mentionbot/__init__.py [new file with mode: 0644]
mentionbot/mentionbot/src/mentionbot/__main__.py [new file with mode: 0644]
mentionbot/mentionbot/src/mentionbot/abc.py [new file with mode: 0644]
mentionbot/mentionbot/src/mentionbot/tests.py [new file with mode: 0644]
mentionbot/pyproject.toml [deleted file]
mentionbot/setup.cfg [deleted file]
mentionbot/setup.py [deleted file]
mentionbot/src/mentionbot/__init__.py [deleted file]
mentionbot/src/mentionbot/__main__.py [deleted file]
mentionbot/src/mentionbot/abc.py [deleted file]
mentionbot/src/mentionbot/tests.py [deleted file]
strikebot/docs/sample_config.ini [deleted file]
strikebot/pyproject.toml [deleted file]
strikebot/setup.cfg [deleted file]
strikebot/setup.py [deleted file]
strikebot/src/strikebot/__init__.py [deleted file]
strikebot/src/strikebot/__main__.py [deleted file]
strikebot/src/strikebot/common.py [deleted file]
strikebot/src/strikebot/db.py [deleted file]
strikebot/src/strikebot/live_ws.py [deleted file]
strikebot/src/strikebot/queue.py [deleted file]
strikebot/src/strikebot/reddit_api.py [deleted file]
strikebot/src/strikebot/tests.py [deleted file]
strikebot/src/strikebot/updates.py [deleted file]
strikebot/strikebot/docs/sample_config.ini [new file with mode: 0644]
strikebot/strikebot/pyproject.toml [new file with mode: 0644]
strikebot/strikebot/setup.cfg [new file with mode: 0644]
strikebot/strikebot/setup.py [new file with mode: 0644]
strikebot/strikebot/src/strikebot/__init__.py [new file with mode: 0644]
strikebot/strikebot/src/strikebot/__main__.py [new file with mode: 0644]
strikebot/strikebot/src/strikebot/common.py [new file with mode: 0644]
strikebot/strikebot/src/strikebot/db.py [new file with mode: 0644]
strikebot/strikebot/src/strikebot/live_ws.py [new file with mode: 0644]
strikebot/strikebot/src/strikebot/queue.py [new file with mode: 0644]
strikebot/strikebot/src/strikebot/reddit_api.py [new file with mode: 0644]
strikebot/strikebot/src/strikebot/tests.py [new file with mode: 0644]
strikebot/strikebot/src/strikebot/updates.py [new file with mode: 0644]

diff --git a/build_helper.py b/build_helper.py
index c7478941fca2a8a2aa03b502e90f737b1b84e2fd..07d3be218e2eaa914e625da807922770bc86a579 100644 (file)
@@ -1,5 +1,7 @@
 """
-A few tools to help keep Debian packages and artifacts well organized.
+A few tools to help keep Debian packages and artifacts well organized. Run me from the counting repo root.
+
+Run as root if doing a build, as the binary package build (pdebuild/pbuilder) requires root privileges.
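+
+Usage (e.g.):
+
+    python3 build_helper.py build strikebot
+    python3 build_helper.py clean strikebot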
 """
 
 from argparse import ArgumentParser
@@ -24,18 +26,26 @@ def _is_output(path: Path) -> bool:
        )
 
 
-project_root = Path(__file__).parent
-upstream_root = project_root.joinpath("strikebot")
-build_dir = project_root.joinpath("build")
-
 parser = ArgumentParser()
+
 tmp = parser.add_subparsers(dest = "cmd", required = True)
+project_cmd_parsers = [
+       tmp.add_parser("build"),
+       tmp.add_parser("clean"),
+]
 
-tmp.add_parser("build")
-tmp.add_parser("clean")
+for subparser in project_cmd_parsers:
+       subparser.add_argument("project")
 
 args = parser.parse_args()
+
+project_root = Path(__file__).parent.joinpath(args.project)
+upstream_root = project_root.joinpath(args.project)
+build_dir = project_root.joinpath("build")
+
 if args.cmd == "build":
+       build_dir.mkdir(exist_ok = True)
+
        # extract version from Python package config
        patt = re.compile("version *= *(.+?) *$")
        with upstream_root.joinpath("setup.cfg").open() as f:
@@ -43,25 +53,30 @@ if args.cmd == "build":
        version = m[1]
 
        # delete stale "orig" tarball
-       for p in project_root.glob("strikebot_*.orig.tar.xz"):
+       for p in project_root.glob(f"{args.project}_*.orig.tar.xz"):
                p.unlink()
 
        # regenerate the "orig" tarball from the current source
-       run(["dh_make", "--yes", "--python", "--createorig", "-p", "strikebot_" + version], cwd = upstream_root)
+       run(["dh_make", "--yes", "--python", "--createorig", "-p", f"{args.project}_{version}"], cwd = upstream_root)
 
-       # build the source package
-       run(["debuild", "-i", "-us", "-uc", "-S"], cwd = upstream_root, check = True)
+       try:
+               # build the source package
+               run(["debuild", "-i", "-us", "-uc", "-S"], cwd = upstream_root, check = True)
 
-       [orig_path] = project_root.glob("strikebot_*.orig.tar.xz")
-       orig_name = orig_path.name
-
-       # move outputs to build directory
-       for p in project_root.iterdir():
-               if _is_output(p):
-                       p.rename(build_dir.joinpath(p.relative_to(project_root)))
+               [orig_path] = project_root.glob(f"{args.project}_*.orig.tar.xz")
+               orig_name = orig_path.name
 
-       # create temporary link for orig tarball to satisfy binary package build
-       orig_path.symlink_to(build_dir.joinpath(orig_name).relative_to(project_root))
+               try:
+                       # build binary package
+                       run(["pdebuild"], cwd = upstream_root, check = True)
+               finally:
+                       # clean up
+                       orig_path.unlink()
+       finally:
+               # move source package and intermediates to build directory
+               for p in project_root.iterdir():
+                       if _is_output(p):
+                               p.rename(build_dir.joinpath(p.relative_to(project_root)))
 elif args.cmd == "clean":
        for p in build_dir.iterdir():
                if _is_output(p):
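
For reference, the helper wraps the standard Debian packaging flow; done by hand it would look roughly like this (a sketch, run from a project's upstream directory, with an illustrative version number):

    dh_make --yes --python --createorig -p strikebot_1.0.0
    debuild -i -us -uc -S
    pdebuild

dh_make, debuild, and pdebuild come from the dh-make, devscripts, and pbuilder packages respectively.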
diff --git a/mentionbot/docs/sample_config.ini b/mentionbot/docs/sample_config.ini
deleted file mode 100644 (file)
index 4228b6a..0000000
+++ /dev/null
@@ -1,28 +0,0 @@
-# see Python `configparser' docs for precise file syntax
-
-[config]
-
-# authorization ID for Reddit API token lookup
-auth ID = 3
-
-# (int, minutes) 
-token update period = 15
-
-thread ID = abc123
-
-# (int) maximum number of mentions within an update for which to send messages
-mention limit = 10
-
-# (float, seconds) wait this much time before sending notifications, in case of update deletion
-message buffer time = 10
-
-# (float, seconds) wait this much time between API polls (and WS connection checks)
-API poll period = 90
-
-# (float, seconds) replace the WS connection if it falls this far behind the API in update timestamps
-max WS delay = 2
-
-
-# Postgres database configuration; same options as in a connect string.
-[db connect params]
-host = example.org
diff --git a/mentionbot/mentionbot/docs/sample_config.ini b/mentionbot/mentionbot/docs/sample_config.ini
new file mode 100644 (file)
index 0000000..4228b6a
--- /dev/null
@@ -0,0 +1,28 @@
+# see Python `configparser' docs for precise file syntax
+
+[config]
+
+# authorization ID for Reddit API token lookup
+auth ID = 3
+
+# (int, minutes) how often to refresh the Reddit API access token from the database
+token update period = 15
+
+thread ID = abc123
+
+# (int) maximum number of mentions within an update for which to send messages
+mention limit = 10
+
+# (float, seconds) wait this much time before sending notifications, in case of update deletion
+message buffer time = 10
+
+# (float, seconds) wait this much time between API polls (and WS connection checks)
+API poll period = 90
+
+# (float, seconds) replace the WS connection if it falls this far behind the API in update timestamps
+max WS delay = 2
+
+
+# Postgres database configuration; same options as in a connect string.
+[db connect params]
+host = example.org
diff --git a/mentionbot/mentionbot/pyproject.toml b/mentionbot/mentionbot/pyproject.toml
new file mode 100644 (file)
index 0000000..8fe2f47
--- /dev/null
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools>=42", "wheel"]
+build-backend = "setuptools.build_meta"
diff --git a/mentionbot/mentionbot/setup.cfg b/mentionbot/mentionbot/setup.cfg
new file mode 100644 (file)
index 0000000..b3f9866
--- /dev/null
@@ -0,0 +1,15 @@
+[metadata]
+name = mentionbot
+version = 0.0.0
+
+[options]
+package_dir =
+       = src
+packages = mentionbot
+python_requires = ~= 3.9
+install_requires =
+       beautifulsoup4 ~= 4.9
+       #psycopg2 ~= 2.8
+       psycopg2-binary ~= 2.8
+       trio == 0.19
+       trio-websocket == 0.9.2
diff --git a/mentionbot/mentionbot/setup.py b/mentionbot/mentionbot/setup.py
new file mode 100644 (file)
index 0000000..056ba45
--- /dev/null
@@ -0,0 +1,4 @@
+import setuptools
+
+
+setuptools.setup()
diff --git a/mentionbot/mentionbot/src/mentionbot/__init__.py b/mentionbot/mentionbot/src/mentionbot/__init__.py
new file mode 100644 (file)
index 0000000..76fd7e2
--- /dev/null
@@ -0,0 +1,187 @@
+from dataclasses import dataclass
+from datetime import timedelta
+from http.client import HTTPResponse
+from logging import Logger
+from socket import EAI_AGAIN, EAI_FAIL, gaierror
+from typing import AsyncContextManager, Callable, Iterable, Optional, Type
+from urllib.error import HTTPError
+from urllib.request import Request, urlopen
+from uuid import UUID
+import importlib.metadata
+import re
+
+from bs4 import BeautifulSoup
+from trio import CancelScope, MemoryReceiveChannel, MemorySendChannel, move_on_at
+import trio
+
+from .abc import DeleteEvent, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
+
+
+_USER_AGENT = "any:net.jcornell.mentionbot:v{} (by /u/jaxklax)".format(
+       importlib.metadata.version(__package__)
+)
+
+
+@dataclass
+class TokenRef:
+       """Mutable container for an access token."""
+
+       token: str
+
+       def get_common_headers(self) -> dict[str, str]:
+               return {
+                       "User-Agent": _USER_AGENT,
+                       "Authorization": "Bearer {}".format(self.token)
+               }
+
+
+def try_request(
+       request: Request,
+       logger: Logger,
+       suppress_codes: list[int] = [404, 429, 500, 503],
+) -> Optional[HTTPResponse]:
+
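+       """Send `request', returning None (after logging) on DNS failure or a suppressed HTTP error code."""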
+       try:
+               return urlopen(request)
+       except gaierror as e:
+               if e.errno in [EAI_FAIL, EAI_AGAIN]:
+                       logger.error(f"DNS failure: {e}")
+                       return None
+               else:
+                       raise
+       except HTTPError as e:
+               e.read()
+               e.close()
+               if e.code in suppress_codes:
+                       logger.error(f"HTTP {e.code}")
+                       return None
+               else:
+                       raise
+
+
+@dataclass
+class WsListenerState:
+       """The WebSockets listener uses this object to communicate with the API poll and message sender tasks."""
+
+       # The API poll task uses this scope to force a refresh of the WS connection if it's been too quiet.
+       scope: Optional[CancelScope] = None
+
+       # The WS task uses this to track the UUID timestamp of the most recent update message, so the API poll task
+       # can detect when the connection goes silent.
+       last_update_ts: Optional[int] = None
+
+       # The WS task sets this when the WS connection goes down and clears it when it's recovered.
+       down: bool = True
+
+       # The WS task sets this to the UUID timestamp of the first update received on connection recovery so the message
+       # sender task knows when to stop checking for deletion.
+       resync_ts: Optional[int] = None
+
+
+def _parse_mentions(body: BeautifulSoup) -> set[str]:
+       """Extract the names of users tagged in an update."""
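+       # e.g. (illustrative): _parse_mentions(BeautifulSoup('<a href="/u/spez">/u/spez</a>', "html.parser")) == {"spez"}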
+       return {
+               el["href"].removeprefix("/u/")
+               for el in body.find_all("a", href = re.compile(r"^/u/[\w-]{3,20}$", re.ASCII))
+       }
+
+
+async def ws_listener_impl(
+       event_stream_factory: Callable[[], AsyncContextManager[EventStream]],
+       event_stream_catch: Iterable[Type[BaseException]],
+       state: WsListenerState,
+       event_tx: MemorySendChannel,  # message type `Event'
+       logger: Logger,
+) -> None:
+
+       while True:
+               state.scope = CancelScope()
+               with state.scope:
+                       logger.debug("opening a new connection")
+                       try:
+                               async with event_stream_factory() as event_stream:
+                                       first_update = True
+                                       while True:
+                                               event = await event_stream.get_event()
+                                               if isinstance(event, UpdateEvent):
+                                                       update_id = UUID(event.payload_data["id"])
+                                                       state.last_update_ts = update_id.time
+                                                       if first_update:
+                                                               # first update after discontinuity
+                                                               state.down = False
+                                                               state.resync_ts = update_id.time
+                                                               first_update = False
+                                               elif isinstance(event, DeleteEvent):
+                                                       logger.debug("delete event for " + event.payload)
+                                               else:
+                                                       raise RuntimeError(f"expected Event, got {event}")
+
+                                               await event_tx.send(event)
+                       except tuple(event_stream_catch) as e:
+                               logger.warning(f"WebSockets error: {e}")
+
+               state.down = True
+
+
+async def message_sender_impl(
+       notifier: Notifier,
+       title_provider: ThreadTitleProvider,
+       update_checker: UpdateChecker,
+       ws_state: WsListenerState,
+       thread_id: str,
+       event_rx: MemoryReceiveChannel,  # message type `Event'
+       buffer_time: timedelta,
+       mention_limit: int,
+       logger: Logger,
+) -> None:
+
+       @dataclass
+       class QueueEntry:
+               payload_data: dict
+               mentions: set[str]
+               arrival: float  # Trio time
+
+       queue: list[QueueEntry] = []
+
+       while True:
+               logger.debug("queue: {}".format([e.payload_data["id"] for e in queue]))
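+               # Wait for the next event, waking early once the oldest queued update's buffer window expires.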
+               if queue:
+                       scope = move_on_at(queue[0].arrival + buffer_time.total_seconds())
+               else:
+                       scope = CancelScope()  # no timeout
+
+               with scope:
+                       event = await event_rx.receive()
+
+               if scope.cancelled_caught:
+                       # timed out
+                       entry = queue.pop(0)
+                       update_id = entry.payload_data["id"]
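+                       # The WS feed may have missed a delete while it was down; re-check existence for updates from
+                       # before the post-recovery resync point.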
+                       if ws_state.down or (ws_state.resync_ts and UUID(update_id).time < ws_state.resync_ts):
+                               notify = update_checker.exists(thread_id, update_id)
+                       else:
+                               notify = True
+
+                       if notify:
+                               thread_title = title_provider.title_for(thread_id)
+                               for recipient in entry.mentions:
+                                       notifier.notify(NotifyInfo(
+                                               recipient, thread_id, thread_title, update_id, entry.payload_data["author"],
+                                               entry.payload_data["body"]
+                                       ))
+               else:
+                       # got an event
+                       if isinstance(event, UpdateEvent):
+                               update_id = event.payload_data["id"]
+                               if not any(entry.payload_data["id"] == update_id for entry in queue):
+                                       mentions = _parse_mentions(BeautifulSoup(event.payload_data["body_html"], "html.parser"))
+                                       if mentions:
+                                               if len(mentions) > mention_limit:
+                                                       logger.info(f"ignoring {update_id} due to too many mentions ({len(mentions)})")
+                                               else:
+                                                       queue.append(QueueEntry(event.payload_data, mentions, trio.current_time()))
+                       elif isinstance(event, DeleteEvent):
+                               logger.debug("delete event for " + event.payload)
+                               queue = [entry for entry in queue if entry.payload_data["name"] != event.payload]
+                       else:
+                               raise RuntimeError(f"invalid event type {type(event)}")
diff --git a/mentionbot/mentionbot/src/mentionbot/__main__.py b/mentionbot/mentionbot/src/mentionbot/__main__.py
new file mode 100644 (file)
index 0000000..960a228
--- /dev/null
@@ -0,0 +1,291 @@
+"""
+Mentionbot: Notify users by private message when they're /u/-mentioned in a live thread.
+"""
+
+from argparse import ArgumentParser
+from configparser import ConfigParser
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from datetime import timedelta
+from logging import Formatter, getLogger, Logger, NOTSET, StreamHandler
+from sys import stdout
+from typing import ClassVar, Iterator, Optional, Type
+from urllib.error import HTTPError
+from urllib.parse import urlencode
+from urllib.request import Request, urlopen
+from uuid import UUID
+import json
+import logging
+
+from trio import MemorySendChannel, open_memory_channel, open_nursery
+from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
+import psycopg2
+import trio
+
+from . import TokenRef, message_sender_impl, try_request, ws_listener_impl, WsListenerState
+from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
+
+
+async def _token_update_impl(period: timedelta, db_conn, auth_id: int, task_status) -> None:
+       def fetch_token() -> str:
+               cursor = db_conn.cursor()
+               cursor.execute(
+                       """
+                               select access_token from public.reddit_app_authorization
+                               where id = %s
+                       """,
+                       (auth_id,)
+               )
+               [(token,)] = cursor.fetchall()
+               return token
+
+       ref = TokenRef(fetch_token())
+       task_status.started(ref)
+       while True:
+               await trio.sleep(period.total_seconds())
+               ref.token = fetch_token()
+
+
+async def _api_poll_impl(
+       ws_state: WsListenerState,
+       token_ref: TokenRef,
+       thread_id: str,
+       event_tx: MemorySendChannel,  # message type `Event'
+       poll_period: timedelta,
+       max_ws_delay: timedelta,
+       logger: Logger,
+) -> None:
+
+       base_url = f"https://oauth.reddit.com/live/{thread_id}"
+
+       latest_payload: Optional[dict] = None
+       while True:
+               # fetch updates and emit events
+               params = {
+                       "raw_json": "1",
+               }
+               if latest_payload:
+                       updates = []
+                       while True:
+                               params["before"] = latest_payload["data"]["name"]
+                               url = base_url + "?" + urlencode(params)
+                               resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger)
+                               if resp:
+                                       with resp:
+                                               children = json.load(resp)["data"]["children"]
+                                       if children:
+                                               updates.extend(reversed(children))
+                                               latest_payload = children[0]
+                                       else:
+                                               break
+                               else:
+                                       break
+               else:
+                       params["limit"] = "1"
+                       url = base_url + "?" + urlencode(params)
+                       resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger)
+                       if resp:
+                               with resp:
+                                       updates = json.load(resp)["data"]["children"]
+                               if updates:
+                                       latest_payload = updates[0]
+                       else:
+                               updates = []
+
+               for update in updates:
+                       await event_tx.send(UpdateEvent(update["data"]))
+
+               # check if WS connection is behind
+               if not ws_state.down and latest_payload:
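+                       # UUID v1 timestamps (UUID.time) count 100 ns intervals, hence the factor of 10 ** 7 per second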
+                       max_delta = max_ws_delay.total_seconds() * 10 ** 7
+                       if UUID(latest_payload["data"]["id"]).time - ws_state.last_update_ts > max_delta:
+                               ws_state.scope.cancel()
+
+               await trio.sleep(poll_period.total_seconds())
+
+
+def _build_notification_md(info: NotifyInfo) -> str:
+       # TODO escape thread title?
+       return "\n\n".join([
+               f"from [{info.author}](/user/{info.author}) in [{info.thread_title}](/live/{info.thread_id})",
+               "\n".join(">" + line for line in info.update_body.splitlines()),
+               (
+                       f"[^^permalink](/live/{info.thread_id}/updates/{info.update_id})"
+                       + f" [^^context](/live/{info.thread_id}?after=LiveUpdate_{info.update_id})"
+                       + " [^^about ^^these ^^notifications](/r/live_mentions/wiki)"
+               ),
+       ])
+
+
+@dataclass
+class RedditNotifier(Notifier):
+       token_ref: TokenRef
+       logger: Logger
+
+       def notify(self, info: NotifyInfo) -> None:
+               params = {
+                       "raw_json": "1",
+                       "api_type": "json",
+                       "subject": "username mention",
+                       "text": _build_notification_md(info),
+               }
+               req = Request(
+                       "https://oauth.reddit.com/api/compose",
+                       method = "POST",
+                       data = urlencode({**params, "to": info.recipient}).encode("ascii"),
+                       headers = self.token_ref.get_common_headers(),
+               )
+               resp = try_request(req, self.logger)
+               if resp:
+                       with resp:
+                               data = json.load(resp)["json"]
+                       if data["errors"]:
+                               [(code, _, _)] = data["errors"]
+                               if code != "USER_DOESNT_EXIST":
+                                       raise RuntimeError(f"PM send error: {code}")
+
+
+@dataclass
+class RedditThreadTitleProvider(ThreadTitleProvider):
+       token_ref: TokenRef
+
+       def title_for(self, thread_id: str) -> str:
+               req = Request(
+                       f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1",
+                       headers = self.token_ref.get_common_headers()
+               )
+               with urlopen(req) as resp:
+                       return json.load(resp)["data"]["title"]
+
+
+@dataclass
+class RedditUpdateChecker(UpdateChecker):
+       token_ref: TokenRef
+       logger: Logger
+
+       def exists(self, thread_id: str, update_id: str) -> bool:
+               req = Request(
+                       f"https://oauth.reddit.com/live/{thread_id}/updates/{update_id}?raw_json=1",
+                       headers = self.token_ref.get_common_headers(),
+               )
+               try:
+                       resp = try_request(req, self.logger, suppress_codes = [429, 500, 503])
+               except HTTPError as e:
+                       e.read()
+                       e.close()
+                       if e.code == 404:
+                               return False
+                       else:
+                               raise
+               else:
+                       if resp:
+                               resp.read()
+                               resp.close()
+                       return True
+
+
+@dataclass
+class WsConnectionEventStream(EventStream):
+       EXCEPTION_TYPES: ClassVar[set[Type[BaseException]]] = {HandshakeError, ConnectionClosed}
+
+       ws_conn: WebSocketConnection
+
+       async def get_event(self) -> Event:
+               while True:
+                       message = json.loads(await self.ws_conn.get_message())
+                       if message["type"] == "update":
+                               return UpdateEvent(message["payload"]["data"])
+                       elif message["type"] == "delete":
+                               return DeleteEvent(message["payload"])
+
+
+async def _main(
+       auth_id: int,
+       buffer_time: timedelta,
+       db_conn,
+       logger: Logger,
+       max_ws_delay: timedelta,
+       mention_limit: int,
+       poll_period: timedelta,
+       thread_id: str,
+       token_period: timedelta,
+) -> None:
+
+       @asynccontextmanager
+       async def event_stream_factory() -> Iterator[WsConnectionEventStream]:
+               resp = urlopen(Request(
+                       f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1",
+                       headers = token_ref.get_common_headers(),
+               ))
+               with resp:
+                       ws_url = json.load(resp)["data"]["websocket_url"]
+
+               async with open_websocket_url(ws_url) as ws_conn:
+                       yield WsConnectionEventStream(ws_conn)
+
+       async with open_nursery() as nursery:
+               token_ref = await nursery.start(_token_update_impl, token_period, db_conn, auth_id)
+
+               ws_state = WsListenerState()
+               (event_tx, event_rx) = open_memory_channel(0)
+               nursery.start_soon(
+                       ws_listener_impl,
+                       event_stream_factory, WsConnectionEventStream.EXCEPTION_TYPES, ws_state, event_tx, logger.getChild("WS")
+               )
+               nursery.start_soon(
+                       _api_poll_impl,
+                       ws_state, token_ref, thread_id, event_tx, poll_period, max_ws_delay, logger.getChild("poll")
+               )
+
+               notifier = RedditNotifier(token_ref, logger)
+               title_provider = RedditThreadTitleProvider(token_ref)
+               deletion_checker = RedditUpdateChecker(token_ref, logger)
+               nursery.start_soon(
+                       message_sender_impl,
+                       notifier, title_provider, deletion_checker, ws_state, thread_id, event_rx, buffer_time, mention_limit,
+                       logger.getChild("sender")
+               )
+
+
+if __name__ == "__main__":
+       arg_parser = ArgumentParser(__package__)
+       arg_parser.add_argument("config_path")
+       args = arg_parser.parse_args()
+
+       config_parser = ConfigParser()
+       with open(args.config_path) as config_file:
+               config_parser.read_file(config_file)
+
+       main_cfg = config_parser["config"]
+
+       auth_id = main_cfg.getint("auth ID")
+       assert auth_id is not None
+
+       token_period = timedelta(minutes = main_cfg.getint("token update period"))
+
+       thread_id = main_cfg["thread ID"]
+
+       mention_limit = main_cfg.getint("mention limit")
+       assert mention_limit is not None and mention_limit >= 0
+
+       buffer_time = timedelta(seconds = main_cfg.getfloat("message buffer time"))
+
+       poll_period = timedelta(seconds = main_cfg.getfloat("API poll period"))
+
+       max_ws_delay = timedelta(seconds = main_cfg.getfloat("max WS delay"))
+
+       db_conn = psycopg2.connect(**config_parser["db connect params"])
+       db_conn.autocommit = True
+
+       logger = getLogger(__package__)
+       logger.setLevel(NOTSET - 1)  # filter in handler(s) instead
+
+       handler = StreamHandler(stdout)
+       handler.setLevel(logging.DEBUG)
+       handler.setFormatter(Formatter("{asctime:23}: {name:17}: {levelname:8}: {message}", style = "{"))
+       logger.addHandler(handler)
+
+       trio.run(
+               _main,
+               auth_id, buffer_time, db_conn, logger, max_ws_delay, mention_limit, poll_period, thread_id, token_period
+       )
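
With a config file like docs/sample_config.ini in place, the bot can then be launched as a module (illustrative; assumes the package is installed):

    python3 -m mentionbot path/to/config.ini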
diff --git a/mentionbot/mentionbot/src/mentionbot/abc.py b/mentionbot/mentionbot/src/mentionbot/abc.py
new file mode 100644 (file)
index 0000000..3c1d7fd
--- /dev/null
@@ -0,0 +1,50 @@
+from abc import ABCMeta, abstractmethod
+from dataclasses import dataclass
+
+
+class Event(metaclass = ABCMeta):
+       pass
+
+
+@dataclass
+class UpdateEvent(Event):
+       payload_data: dict  # object at "payload" -> "data" in the message JSON
+
+
+@dataclass
+class DeleteEvent(Event):
+       payload: str  # update name
+
+
+class EventStream(metaclass = ABCMeta):
+       @abstractmethod
+       async def get_event(self) -> Event:
+               raise NotImplementedError()
+
+
+@dataclass
+class NotifyInfo:
+       recipient: str
+       thread_id: str
+       thread_title: str
+       update_id: str
+       author: str
+       update_body: str
+
+
+class Notifier(metaclass = ABCMeta):
+       @abstractmethod
+       def notify(self, info: NotifyInfo) -> None:
+               raise NotImplementedError()
+
+
+class ThreadTitleProvider(metaclass = ABCMeta):
+       @abstractmethod
+       def title_for(self, thread_id: str) -> str:
+               raise NotImplementedError()
+
+
+class UpdateChecker(metaclass = ABCMeta):
+       @abstractmethod
+       def exists(self, thread_id: str, update_id: str) -> bool:
+               raise NotImplementedError()
diff --git a/mentionbot/mentionbot/src/mentionbot/tests.py b/mentionbot/mentionbot/src/mentionbot/tests.py
new file mode 100644 (file)
index 0000000..0617c90
--- /dev/null
@@ -0,0 +1,350 @@
+from contextlib import asynccontextmanager, contextmanager
+from dataclasses import dataclass, field
+from datetime import timedelta
+from logging import getLogger
+from numbers import Real
+from os import environ
+from typing import Iterator, TypeVar
+from unittest import TestCase
+from urllib.request import Request
+from uuid import UUID, uuid1
+import logging
+
+from trio import CancelScope, MemoryReceiveChannel, move_on_after, open_memory_channel, open_nursery, sleep_until
+from trio.testing import MockClock, wait_all_tasks_blocked
+import trio
+
+from . import message_sender_impl, TokenRef, try_request, WsListenerState, ws_listener_impl
+from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
+
+
+T = TypeVar("T")
+
+
+@dataclass
+class MockNotifier(Notifier):
+       notifies: list[NotifyInfo] = field(default_factory = list)
+
+       def notify(self, info: NotifyInfo) -> None:
+               self.notifies.append(info)
+
+
+class MockTitleProvider(ThreadTitleProvider):
+       def title_for(self, thread_id: str) -> str:
+               return thread_id
+
+
+class UnusedUpdateChecker(UpdateChecker):
+       def exists(self, thread_id: str, update_id: str) -> bool:
+               raise AssertionError("unexpected deletion check")
+
+
+@dataclass
+class DictBackedUpdateChecker(UpdateChecker):
+       exists_by_id: dict[str, bool] = field(default_factory = dict)
+       queries: list[tuple[str, str]] = field(default_factory = list)
+
+       def exists(self, thread_id: str, update_id: str) -> bool:
+               self.queries.append((thread_id, update_id))
+               return self.exists_by_id[update_id]
+
+
+@dataclass
+class ChannelBackedEventStream(EventStream):
+       event_rx: MemoryReceiveChannel
+
+       async def get_event(self) -> Event:
+               obj = await self.event_rx.receive()
+               if isinstance(obj, BaseException):
+                       raise obj
+               else:
+                       return obj
+
+
+@asynccontextmanager
+async def null_async_context(val: T) -> Iterator[T]:
+       yield val
+
+
+def _update_event(author: str, mentions: list[str]) -> UpdateEvent:
+       id_ = uuid1()
+       body = str(id_) + " " + " ".join("/u/" + user for user in mentions)
+       body_html = str(id_) + " " + " ".join(
+               f'<a href="/u/{user}">/u/{user}</a>'
+               for user in mentions
+       )
+       return UpdateEvent(
+               {
+                       "id": str(id_),
+                       "name": "LiveUpdate_" + str(id_),
+                       "author": author,
+                       "body": body,
+                       "body_html": body_html,
+               }
+       )
+
+
+def setUpModule() -> None:
+       getLogger().setLevel(logging.CRITICAL + 1)  # disable logging
+
+
+class TaskTests(TestCase):
+       @contextmanager
+       def _assert_completes_in(self, duration_seconds: Real) -> Iterator[None]:
+               with move_on_after(duration_seconds) as scope:
+                       yield
+               if scope.cancelled_caught:
+                       self.fail("operation timed out")
+
+       def test_buffering(self) -> None:
+               """
+               There's no API poll task: a custom task mocks the WS listener by replaying events over a channel, and the
+               real message sender task is injected with mocks to capture its output. The WS state is hard-coded as up,
+               so recovery from WS connection issues isn't exercised here.
+
+               Primarily this verifies that updates pass through the buffer and are processed if and only if no delete
+               event arrives while they're buffered.
+               """
+
+               buffer_time = timedelta(seconds = 1)
+               notifier = MockNotifier()
+               (event_tx, event_rx) = open_memory_channel(0)
+
+               ign_event = _update_event("sender", [])
+               event_a = _update_event("sender", ["rec1", "rec2"])
+               event_b = _update_event("sender", ["rec1", "rec2"])
+
+               async def event_stream_impl(scope: CancelScope) -> None:
+                       await sleep_until(1.0)
+                       await event_tx.send(ign_event)
+                       await event_tx.send(event_a)
+
+                       await sleep_until(1.5)
+                       self.assertFalse(notifier.notifies)
+                       await event_tx.send(DeleteEvent(event_a.payload_data["name"]))
+
+                       await sleep_until(2.0)
+                       await event_tx.send(event_b)
+
+                       # let things settle and terminate both tasks
+                       await sleep_until(4.0)
+                       scope.cancel()
+
+               async def main() -> None:
+                       async with open_nursery() as nursery:
+                               nursery.start_soon(event_stream_impl, nursery.cancel_scope)
+                               nursery.start_soon(
+                                       message_sender_impl,
+                                       notifier, MockTitleProvider(), UnusedUpdateChecker(), WsListenerState(down = False), "ThreadId",
+                                       event_rx, buffer_time, 10, getLogger()  # mention_limit of 10 assumed; any value >= 2 works here
+                               )
+
+               trio.run(main, clock = MockClock(autojump_threshold = 0))
+
+               # only notifications sent are for event B
+               self.assertEqual(len(notifier.notifies), 2)
+               self.assertEqual({i.recipient for i in notifier.notifies}, {"rec1", "rec2"})
+               self.assertTrue(all(i.author == "sender" for i in notifier.notifies))
+               self.assertTrue(all(i.update_id == event_b.payload_data["id"] for i in notifier.notifies))
+
+       def test_ws_trouble(self) -> None:
+               """
+               Ensure that the notifier task responds correctly to the WS connection going down and coming back up.
+               """
+
+               thread_id = "ThreadId"
+               buffer_time = timedelta(seconds = 3)
+               update_checker = DictBackedUpdateChecker()
+               notifier = MockNotifier()
+               (event_tx, event_rx) = open_memory_channel(0)
+
+               # sent while WS is up, checked and notifies sent
+               event_a = _update_event("sender", ["rec"])
+
+               # sent and deleted while WS is down, checked and discarded from buffer
+               event_b = _update_event("sender", ["rec"])
+
+               # sent on WS reconnect, not checked and notifies sent
+               event_c = _update_event("sender", ["rec"])
+
+               # sent after WS reconnect, not checked and notifies sent
+               event_d = _update_event("sender", ["rec"])
+
+               ws_state = WsListenerState(down = False)
+
+               async def ws_simulator_impl() -> None:
+                       update_checker.exists_by_id[event_a.payload_data["id"]] = True
+                       await event_tx.send(event_a)
+
+                       await sleep_until(1.0)
+                       ws_state.down = True
+                       update_checker.exists_by_id[event_b.payload_data["id"]] = True
+                       await event_tx.send(event_b)
+
+                       await sleep_until(2.0)
+                       update_checker.exists_by_id[event_b.payload_data["id"]] = False
+                       ws_state.down = False
+                       ws_state.resync_ts = UUID(event_c.payload_data["id"]).time
+                       await event_tx.send(event_c)
+
+                       await sleep_until(3.0)
+                       await event_tx.send(event_d)
+
+               async def main() -> None:
+                       async with open_nursery() as nursery:
+                               nursery.start_soon(ws_simulator_impl)
+                               nursery.start_soon(
+                                       message_sender_impl,
+                                       notifier, MockTitleProvider(), update_checker, ws_state, thread_id, event_rx, buffer_time,
+                                       10, getLogger()  # mention_limit of 10 assumed; any value >= 1 works here
+                               )
+
+                               await sleep_until(3.5)
+                               self.assertEqual([i.update_id for i in notifier.notifies], [event_a.payload_data["id"]])
+                               notifier.notifies.clear()
+                               self.assertEqual(update_checker.queries, [(thread_id, event_a.payload_data["id"])])
+                               update_checker.queries.clear()
+
+                               await sleep_until(4.5)
+                               self.assertFalse(notifier.notifies)
+                               self.assertEqual(update_checker.queries, [(thread_id, event_b.payload_data["id"])])
+                               update_checker.queries.clear()
+
+                               await sleep_until(5.5)
+                               self.assertEqual([i.update_id for i in notifier.notifies], [event_c.payload_data["id"]])
+                               notifier.notifies.clear()
+                               self.assertFalse(update_checker.queries)
+
+                               await sleep_until(6.5)
+                               self.assertEqual([i.update_id for i in notifier.notifies], [event_d.payload_data["id"]])
+                               notifier.notifies.clear()
+                               self.assertFalse(update_checker.queries)
+
+                               nursery.cancel_scope.cancel()
+
+               trio.run(main, clock = MockClock(autojump_threshold = 0))
+
+       def test_ws_listener_cancel(self) -> None:
+               ws_state = WsListenerState()
+
+               (event_in_tx, event_in_rx) = open_memory_channel(0)
+               event_stream = ChannelBackedEventStream(event_in_rx)
+               event_stream_factory = lambda: null_async_context(event_stream)
+
+               (event_out_tx, event_out_rx) = open_memory_channel(0)
+
+               async def main():
+                       async with open_nursery() as nursery:
+                               nursery.start_soon(
+                                       ws_listener_impl,
+                                       event_stream_factory, set(), ws_state, event_out_tx, getLogger()
+                               )
+
+                               event_1 = _update_event("author", [])
+                               await event_in_tx.send(event_1)
+                               with self._assert_completes_in(1):
+                                       event_out = await event_out_rx.receive()
+                               self.assertIs(event_out, event_1)
+
+                               event_2 = _update_event("author", ["recipient"])
+                               await event_in_tx.send(event_2)
+                               with self._assert_completes_in(1):
+                                       event_out = await event_out_rx.receive()
+                               self.assertIs(event_out, event_2)
+
+                               initial_scope = ws_state.scope
+                               initial_scope.cancel()
+                               await wait_all_tasks_blocked()
+                               self.assertIsNotNone(ws_state.scope)
+                               self.assertIsNot(ws_state.scope, initial_scope)
+                               self.assertTrue(ws_state.down)
+
+                               event_3 = _update_event("author", [])
+                               ts = UUID(event_3.payload_data["id"]).time
+                               await event_in_tx.send(event_3)
+                               with self._assert_completes_in(1):
+                                       event_out = await event_out_rx.receive()
+                               self.assertIs(event_out, event_3)
+                               self.assertEqual(ws_state.last_update_ts, ts)
+                               self.assertFalse(ws_state.down)
+                               self.assertEqual(ws_state.resync_ts, ts)
+
+                               nursery.cancel_scope.cancel()
+
+               trio.run(main, clock = MockClock(autojump_threshold = 0))
+
+       def test_ws_listener_error(self) -> None:
+               """Test that the WS listener responds correctly to errors from the input event stream."""
+
+               class MockStreamException(Exception):
+                       pass
+
+               ws_state = WsListenerState()
+
+               (event_in_tx, event_in_rx) = open_memory_channel(0)
+               event_stream = ChannelBackedEventStream(event_in_rx)
+               event_stream_factory = lambda: null_async_context(event_stream)
+
+               (event_out_tx, event_out_rx) = open_memory_channel(0)
+
+               async def main():
+                       async with open_nursery() as nursery:
+                               nursery.start_soon(
+                                       ws_listener_impl,
+                                       event_stream_factory, {MockStreamException}, ws_state, event_out_tx, getLogger()
+                               )
+
+                               event_1 = _update_event("author", [])
+                               await event_in_tx.send(event_1)
+                               with self._assert_completes_in(1):
+                                       event_out = await event_out_rx.receive()
+                               self.assertIs(event_out, event_1)
+
+                               initial_scope = ws_state.scope
+                               await event_in_tx.send(MockStreamException())
+                               await wait_all_tasks_blocked()
+                               self.assertIsNotNone(ws_state.scope)
+                               self.assertIsNot(ws_state.scope, initial_scope)
+                               self.assertTrue(ws_state.down)
+
+                               event_2 = _update_event("author", ["recipient"])
+                               ts = UUID(event_2.payload_data["id"]).time
+                               await event_in_tx.send(event_2)
+                               with self._assert_completes_in(1):
+                                       event_out = await event_out_rx.receive()
+                               self.assertIs(event_out, event_2)
+                               self.assertEqual(ws_state.last_update_ts, ts)
+                               self.assertFalse(ws_state.down)
+                               self.assertEqual(ws_state.resync_ts, ts)
+
+                               nursery.cancel_scope.cancel()
+
+               trio.run(main, clock = MockClock(autojump_threshold = 0))
+
+
+class RedditTests(TestCase):
+       """Tests for interactions with Reddit."""
+
+       @classmethod
+       def setUpClass(cls) -> None:
+               cls._TOKEN_REF = TokenRef(environ["API_TOKEN"])
+
+       def test_try_request(self) -> None:
+               self.assertIsNotNone(
+                       try_request(
+                               Request(
+                                       "https://oauth.reddit.com/api/live/happening_now",
+                                       headers = self._TOKEN_REF.get_common_headers(),
+                               ),
+                               getLogger(),
+                       )
+               )
+
+       def test_try_request_http_error(self) -> None:
+               self.assertIsNone(
+                       try_request(
+                               Request("https://oauth.reddit.com/api/live/happening_now"),
+                               getLogger(),
+                               suppress_codes = [403],
+                       )
+               )
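
The Reddit-facing tests read an OAuth token from the environment, so a full test run looks something like this (sketch):

    API_TOKEN=... python3 -m unittest mentionbot.tests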
diff --git a/mentionbot/pyproject.toml b/mentionbot/pyproject.toml
deleted file mode 100644 (file)
index 8fe2f47..0000000
+++ /dev/null
@@ -1,3 +0,0 @@
-[build-system]
-requires = ["setuptools>=42", "wheel"]
-build-backend = "setuptools.build_meta"
diff --git a/mentionbot/setup.cfg b/mentionbot/setup.cfg
deleted file mode 100644 (file)
index b3f9866..0000000
+++ /dev/null
@@ -1,15 +0,0 @@
-[metadata]
-name = mentionbot
-version = 0.0.0
-
-[options]
-package_dir =
-       = src
-packages = mentionbot
-python_requires = ~= 3.9
-install_requires =
-       beautifulsoup4 ~= 4.9
-       #psycopg2 ~= 2.8
-       psycopg2-binary ~= 2.8
-       trio == 0.19
-       trio-websocket == 0.9.2
diff --git a/mentionbot/setup.py b/mentionbot/setup.py
deleted file mode 100644 (file)
index 056ba45..0000000
+++ /dev/null
@@ -1,4 +0,0 @@
-import setuptools
-
-
-setuptools.setup()
diff --git a/mentionbot/src/mentionbot/__init__.py b/mentionbot/src/mentionbot/__init__.py
deleted file mode 100644 (file)
index 76fd7e2..0000000
+++ /dev/null
@@ -1,187 +0,0 @@
-from dataclasses import dataclass
-from datetime import timedelta
-from http.client import HTTPResponse
-from logging import Logger
-from socket import EAI_AGAIN, EAI_FAIL, gaierror
-from typing import AsyncContextManager, Callable, Iterable, Optional, Type
-from urllib.error import HTTPError
-from urllib.request import Request, urlopen
-from uuid import UUID
-import importlib.metadata
-import re
-
-from bs4 import BeautifulSoup
-from trio import CancelScope, MemoryReceiveChannel, MemorySendChannel, move_on_at
-import trio
-
-from .abc import DeleteEvent, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
-
-
-_USER_AGENT = "any:net.jcornell.mentionbot:v{} (by /u/jaxklax)".format(
-       importlib.metadata.version(__package__)
-)
-
-
-@dataclass
-class TokenRef:
-       """Mutable container for an access token."""
-
-       token: str
-
-       def get_common_headers(self) -> dict[str, str]:
-               return {
-                       "User-Agent": _USER_AGENT,
-                       "Authorization": "Bearer {}".format(self.token)
-               }
-
-
-def try_request(
-       request: Request,
-       logger: Logger,
-       suppress_codes: list[int] = [404, 429, 500, 503],
-) -> Optional[HTTPResponse]:
-
-       try:
-               return urlopen(request)
-       except gaierror as e:
-               if e.errno in [EAI_FAIL, EAI_AGAIN]:
-                       logger.error(f"DNS failure: {e}")
-                       return None
-               else:
-                       raise
-       except HTTPError as e:
-               e.read()
-               e.close()
-               if e.code in suppress_codes:
-                       logger.error(f"HTTP {e.code}")
-                       return None
-               else:
-                       raise
-
-
-@dataclass
-class WsListenerState:
-       """The WebSockets listener uses this object to communicate with the API poll and message sender tasks."""
-
-       # The API poll task uses this scope to force a refresh of the WS connection if it's been too quiet.
-       scope: Optional[CancelScope] = None
-
-       # The WS task uses this to track the UUID timestamp of the most recent update message, so the API poll task
-       # can detect when the connection goes silent.
-       last_update_ts: Optional[int] = None
-
-       # The WS task sets this when the WS connection goes down and clears it when it's recovered.
-       down: bool = True
-
-       # The WS task sets this to the UUID timestamp of the first update received on connection recovery so the message
-       # sender task knows when to stop checking for deletion.
-       resync_ts: Optional[int] = None
-
-
-def _parse_mentions(body: BeautifulSoup) -> set[str]:
-       """Extract the names of users tagged in an update."""
-       return {
-               el["href"].removeprefix("/u/")
-               for el in body.find_all("a", href = re.compile(r"^/u/[\w-]{3,20}$", re.ASCII))
-       }
-
-
-async def ws_listener_impl(
-       event_stream_factory: Callable[[], AsyncContextManager[EventStream]],
-       event_stream_catch: Iterable[Type[BaseException]],
-       state: WsListenerState,
-       event_tx: MemorySendChannel,  # message type `Event'
-       logger: Logger,
-) -> None:
-
-       while True:
-               state.scope = CancelScope()
-               with state.scope:
-                       logger.debug("opening a new connection")
-                       try:
-                               async with event_stream_factory() as event_stream:
-                                       first_update = True
-                                       while True:
-                                               event = await event_stream.get_event()
-                                               if isinstance(event, UpdateEvent):
-                                                       update_id = UUID(event.payload_data["id"])
-                                                       state.last_update_ts = update_id.time
-                                                       if first_update:
-                                                               # first update after discontinuity
-                                                               state.down = False
-                                                               state.resync_ts = update_id.time
-                                                               first_update = False
-                                               elif isinstance(event, DeleteEvent):
-                                                       logger.debug("delete event for " + event.payload)
-                                               else:
-                                                       raise RuntimeError(f"expected Event, got {event}")
-
-                                               await event_tx.send(event)
-                       except tuple(event_stream_catch) as e:
-                               logger.warning(f"WebSockets error: {e}")
-
-               state.down = True
-
-
-async def message_sender_impl(
-       notifier: Notifier,
-       title_provider: ThreadTitleProvider,
-       update_checker: UpdateChecker,
-       ws_state: WsListenerState,
-       thread_id: str,
-       event_rx: MemoryReceiveChannel,  # message type `Event'
-       buffer_time: timedelta,
-       mention_limit: int,
-       logger: Logger,
-) -> None:
-
-       @dataclass
-       class QueueEntry:
-               payload_data: dict
-               mentions: set[str]
-               arrival: float  # Trio time
-
-       queue: list[QueueEntry] = []
-
-       while True:
-               logger.debug("queue: {}".format([e.payload_data["id"] for e in queue]))
-               if queue:
-                       scope = move_on_at(queue[0].arrival + buffer_time.total_seconds())
-               else:
-                       scope = CancelScope()  # no timeout
-
-               with scope:
-                       event = await event_rx.receive()
-
-               if scope.cancelled_caught:
-                       # timed out
-                       entry = queue.pop(0)
-                       update_id = entry.payload_data["id"]
-                       if ws_state.down or (ws_state.resync_ts and UUID(update_id).time < ws_state.resync_ts):
-                               notify = update_checker.exists(thread_id, update_id)
-                       else:
-                               notify = True
-
-                       if notify:
-                               thread_title = title_provider.title_for(thread_id)
-                               for recipient in entry.mentions:
-                                       notifier.notify(NotifyInfo(
-                                               recipient, thread_id, thread_title, update_id, entry.payload_data["author"],
-                                               entry.payload_data["body"]
-                                       ))
-               else:
-                       # got an event
-                       if isinstance(event, UpdateEvent):
-                               update_id = event.payload_data["id"]
-                               if not any(entry.payload_data["id"] == update_id for entry in queue):
-                                       mentions = _parse_mentions(BeautifulSoup(event.payload_data["body_html"], "html.parser"))
-                                       if mentions:
-                                               if len(mentions) > mention_limit:
-                                                       logger.info(f"ignoring {update_id} due to too many mentions ({len(mentions)})")
-                                               else:
-                                                       queue.append(QueueEntry(event.payload_data, mentions, trio.current_time()))
-                       elif isinstance(event, DeleteEvent):
-                               logger.debug("delete event for " + event.payload)
-                               queue = [entry for entry in queue if entry.payload_data["name"] != event.payload]
-                       else:
-                               raise RuntimeError(f"invalid event type {type(event)}")
diff --git a/mentionbot/src/mentionbot/__main__.py b/mentionbot/src/mentionbot/__main__.py
deleted file mode 100644 (file)
index 960a228..0000000
+++ /dev/null
@@ -1,291 +0,0 @@
-"""
-Mentionbot: Notify users by private message when they're /u/-mentioned in a live thread.
-"""
-
-from argparse import ArgumentParser
-from configparser import ConfigParser
-from contextlib import asynccontextmanager
-from dataclasses import dataclass
-from datetime import timedelta
-from logging import Formatter, getLogger, Logger, NOTSET, StreamHandler
-from sys import stdout
-from typing import ClassVar, Iterator, Optional, Type
-from urllib.error import HTTPError
-from urllib.parse import urlencode
-from urllib.request import Request, urlopen
-from uuid import UUID
-import json
-import logging
-
-from trio import MemorySendChannel, open_memory_channel, open_nursery
-from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
-import psycopg2
-import trio
-
-from . import TokenRef, message_sender_impl, try_request, ws_listener_impl, WsListenerState
-from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
-
-
-async def _token_update_impl(period: timedelta, db_conn, auth_id: int, task_status) -> None:
-       def fetch_token() -> str:
-               cursor = db_conn.cursor()
-               cursor.execute(
-                       """
-                               select access_token from public.reddit_app_authorization
-                               where id = %s
-                       """,
-                       (auth_id,)
-               )
-               [(token,)] = cursor.fetchall()
-               return token
-
-       ref = TokenRef(fetch_token())
-       task_status.started(ref)
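-	# hand the ref to the caller, then refresh the token in place so all holders see new values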
-       while True:
-               await trio.sleep(period.total_seconds())
-               ref.token = fetch_token()
-
-
-async def _api_poll_impl(
-       ws_state: WsListenerState,
-       token_ref: TokenRef,
-       thread_id: str,
-       event_tx: MemorySendChannel,  # message type `Event'
-       poll_period: timedelta,
-       max_ws_delay: timedelta,
-       logger: Logger,
-) -> None:
-
-       base_url = f"https://oauth.reddit.com/live/{thread_id}"
-
-       latest_payload: Optional[dict] = None
-       while True:
-               # fetch updates and emit events
-               params = {
-                       "raw_json": "1",
-               }
-               if latest_payload:
-                       updates = []
-                       while True:
-                               params["before"] = latest_payload["data"]["name"]
-                               url = base_url + "?" + urlencode(params)
-                               resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger)
-                               if resp:
-                                       with resp:
-                                               children = json.load(resp)["data"]["children"]
-                                       if children:
-                                               updates.extend(reversed(children))
-                                               latest_payload = children[0]
-                                       else:
-                                               break
-                               else:
-                                       break
-               else:
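-			# first poll: fetch only the newest update to establish a cursor for the `before' parameter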
-                       params["limit"] = "1"
-                       url = base_url + "?" + urlencode(params)
-                       resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger)
-                       if resp:
-                               with resp:
-                                       updates = json.load(resp)["data"]["children"]
-                               if updates:
-                                       latest_payload = updates[0]
-                       else:
-                               updates = []
-
-               for update in updates:
-                       await event_tx.send(UpdateEvent(update["data"]))
-
-               # check if WS connection is behind
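-		# If the newest polled update outruns the last WS-delivered message by more than `max_ws_delay',
-		# cancel the listener's scope so the WS connection is reopened and resynced.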
-               if not ws_state.down and latest_payload:
-                       max_delta = max_ws_delay.total_seconds() * 10 ** 7  # in 100*nanosecond
-                       if UUID(latest_payload["data"]["id"]).time - ws_state.last_update_ts > max_delta:
-                               ws_state.scope.cancel()
-
-               await trio.sleep(poll_period.total_seconds())
-
-
-def _build_notification_md(info: NotifyInfo) -> str:
-       # TODO escape thread title?
-       return "\n\n".join([
-               f"from [{info.author}](/user/{info.author}) in [{info.thread_title}](/live/{info.thread_id})",
-               "\n".join(">" + line for line in info.update_body.splitlines()),
-               (
-                       f"[^^permalink](/live/{info.thread_id}/updates/{info.update_id})"
-                       + f" [^^context](/live/{info.thread_id}?after=LiveUpdate_{info.update_id})"
-                       + " [^^about ^^these ^^notifications](/r/live_mentions/wiki)"
-               ),
-       ])
-
-
-@dataclass
-class RedditNotifier(Notifier):
-       token_ref: TokenRef
-       logger: Logger
-
-       def notify(self, info: NotifyInfo) -> None:
-               params = {
-                       "raw_json": "1",
-                       "api_type": "json",
-                       "subject": "username mention",
-                       "text": _build_notification_md(info),
-               }
-               req = Request(
-                       "https://oauth.reddit.com/api/compose",
-                       method = "POST",
-                       data = urlencode({**params, "to": info.recipient}).encode("ascii"),
-                       headers = self.token_ref.get_common_headers(),
-               )
-               resp = try_request(req, self.logger)
-               if resp:
-                       with resp:
-                               data = json.load(resp)["json"]
-                       if data["errors"]:
-                               [(code, _, _)] = data["errors"]
-                               if code != "USER_DOESNT_EXIST":
-						raise RuntimeError(f"PM send error: {code}")
-
-
-@dataclass
-class RedditThreadTitleProvider(ThreadTitleProvider):
-       token_ref: TokenRef
-
-       def title_for(self, thread_id: str) -> str:
-               req = Request(
-                       f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1",
-                       headers = self.token_ref.get_common_headers()
-               )
-               with urlopen(req) as resp:
-                       return json.load(resp)["data"]["title"]
-
-
-@dataclass
-class RedditUpdateChecker(UpdateChecker):
-       token_ref: TokenRef
-       logger: Logger
-
-       def exists(self, thread_id: str, update_id: str) -> bool:
-               req = Request(
-                       f"https://oauth.reddit.com/live/{thread_id}/updates/{update_id}?raw_json=1",
-                       headers = self.token_ref.get_common_headers(),
-               )
-               try:
-                       resp = try_request(req, self.logger, suppress_codes = [429, 500, 503])
-               except HTTPError as e:
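-			# drain and close the error response so the underlying connection is released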
-                       e.read()
-                       e.close()
-                       if e.code == 404:
-                               return False
-                       else:
-                               raise
-               else:
-                       if resp:
-                               resp.read()
-                               resp.close()
-                       return True
-
-
-@dataclass
-class WsConnectionEventStream(EventStream):
-       EXCEPTION_TYPES: ClassVar[set[Type[BaseException]]] = {HandshakeError, ConnectionClosed}
-
-       ws_conn: WebSocketConnection
-
-       async def get_event(self) -> Event:
-               while True:
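-			# messages of other types (e.g. activity readings) are ignored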
-                       message = json.loads(await self.ws_conn.get_message())
-                       if message["type"] == "update":
-                               return UpdateEvent(message["payload"]["data"])
-                       elif message["type"] == "delete":
-                               return DeleteEvent(message["payload"])
-
-
-async def _main(
-       auth_id: int,
-       buffer_time: timedelta,
-       db_conn,
-       logger: Logger,
-       max_ws_delay: timedelta,
-       mention_limit: int,
-       poll_period: timedelta,
-       thread_id: str,
-       token_period: timedelta,
-) -> None:
-
-       @asynccontextmanager
-	async def event_stream_factory() -> AsyncIterator[WsConnectionEventStream]:
-               resp = urlopen(Request(
-                       f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1",
-                       headers = token_ref.get_common_headers(),
-               ))
-               with resp:
-                       ws_url = json.load(resp)["data"]["websocket_url"]
-
-               async with open_websocket_url(ws_url) as ws_conn:
-                       yield WsConnectionEventStream(ws_conn)
-
-       async with open_nursery() as nursery:
-               token_ref = await nursery.start(_token_update_impl, token_period, db_conn, auth_id)
-
-               ws_state = WsListenerState()
-               (event_tx, event_rx) = open_memory_channel(0)
-               nursery.start_soon(
-                       ws_listener_impl,
-                       event_stream_factory, WsConnectionEventStream.EXCEPTION_TYPES, ws_state, event_tx, logger.getChild("WS")
-               )
-               nursery.start_soon(
-                       _api_poll_impl,
-                       ws_state, token_ref, thread_id, event_tx, poll_period, max_ws_delay, logger.getChild("poll")
-               )
-
-               notifier = RedditNotifier(token_ref, logger)
-               title_provider = RedditThreadTitleProvider(token_ref)
-               deletion_checker = RedditUpdateChecker(token_ref, logger)
-               nursery.start_soon(
-                       message_sender_impl,
-                       notifier, title_provider, deletion_checker, ws_state, thread_id, event_rx, buffer_time, mention_limit,
-                       logger.getChild("sender")
-               )
-
-
-if __name__ == "__main__":
-       arg_parser = ArgumentParser(__package__)
-       arg_parser.add_argument("config_path")
-       args = arg_parser.parse_args()
-
-       config_parser = ConfigParser()
-       with open(args.config_path) as config_file:
-               config_parser.read_file(config_file)
-
-       main_cfg = config_parser["config"]
-
-       auth_id = main_cfg.getint("auth ID")
-       assert auth_id is not None
-
-       token_period = timedelta(minutes = main_cfg.getint("token update period"))
-
-       thread_id = main_cfg["thread ID"]
-
-       mention_limit = main_cfg.getint("mention limit")
-       assert mention_limit is not None and mention_limit >= 0
-
-       buffer_time = timedelta(seconds = main_cfg.getfloat("message buffer time"))
-
-       poll_period = timedelta(seconds = main_cfg.getfloat("API poll period"))
-
-       max_ws_delay = timedelta(seconds = main_cfg.getfloat("max WS delay"))
-
-       db_conn = psycopg2.connect(**config_parser["db connect params"])
-       db_conn.autocommit = True
-
-       logger = getLogger(__package__)
-       logger.setLevel(NOTSET - 1)  # filter in handler(s) instead
-
-       handler = StreamHandler(stdout)
-       handler.setLevel(logging.DEBUG)
-       handler.setFormatter(Formatter("{asctime:23}: {name:17}: {levelname:8}: {message}", style = "{"))
-       logger.addHandler(handler)
-
-       trio.run(
-               _main,
-               auth_id, buffer_time, db_conn, logger, max_ws_delay, mention_limit, poll_period, thread_id, token_period
-       )
diff --git a/mentionbot/src/mentionbot/abc.py b/mentionbot/src/mentionbot/abc.py
deleted file mode 100644 (file)
index 3c1d7fd..0000000
+++ /dev/null
@@ -1,50 +0,0 @@
-from abc import ABCMeta, abstractmethod
-from dataclasses import dataclass
-
-
-class Event(metaclass = ABCMeta):
-       pass
-
-
-@dataclass
-class UpdateEvent(Event):
-       payload_data: dict  # object at "payload" -> "data" in the message JSON
-
-
-@dataclass
-class DeleteEvent(Event):
-       payload: str  # update name
-
-
-class EventStream(metaclass = ABCMeta):
-       @abstractmethod
-       async def get_event(self) -> Event:
-               raise NotImplementedError()
-
-
-@dataclass
-class NotifyInfo:
-       recipient: str
-       thread_id: str
-       thread_title: str
-       update_id: str
-       author: str
-       update_body: str
-
-
-class Notifier(metaclass = ABCMeta):
-       @abstractmethod
-       def notify(self, info: NotifyInfo) -> None:
-               raise NotImplementedError()
-
-
-class ThreadTitleProvider(metaclass = ABCMeta):
-       @abstractmethod
-       def title_for(self, thread_id: str) -> str:
-               raise NotImplementedError()
-
-
-class UpdateChecker(metaclass = ABCMeta):
-       @abstractmethod
-       def exists(self, thread_id: str, update_id: str) -> bool:
-               raise NotImplementedError()
diff --git a/mentionbot/src/mentionbot/tests.py b/mentionbot/src/mentionbot/tests.py
deleted file mode 100644 (file)
index 0617c90..0000000
+++ /dev/null
@@ -1,350 +0,0 @@
-from contextlib import asynccontextmanager, contextmanager
-from dataclasses import dataclass, field
-from datetime import timedelta
-from logging import getLogger
-from numbers import Real
-from os import environ
-from typing import AsyncIterator, Iterator, TypeVar
-from unittest import TestCase
-from urllib.request import Request
-from uuid import UUID, uuid1
-import logging
-
-from trio import CancelScope, MemoryReceiveChannel, move_on_after, open_memory_channel, open_nursery, sleep_until
-from trio.testing import MockClock, wait_all_tasks_blocked
-import trio
-
-from . import message_sender_impl, TokenRef, try_request, WsListenerState, ws_listener_impl
-from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
-
-
-T = TypeVar("T")
-
-
-@dataclass
-class MockNotifier(Notifier):
-       notifies: list[NotifyInfo] = field(default_factory = list)
-
-       def notify(self, info: NotifyInfo) -> None:
-               self.notifies.append(info)
-
-
-class MockTitleProvider(ThreadTitleProvider):
-       def title_for(self, thread_id: str) -> str:
-               return thread_id
-
-
-class UnusedUpdateChecker(UpdateChecker):
-       def exists(self, thread_id: str, update_id: str) -> bool:
-               raise AssertionError("unexpected deletion check")
-
-
-@dataclass
-class DictBackedUpdateChecker(UpdateChecker):
-       exists_by_id: dict[str, bool] = field(default_factory = dict)
-       queries: list[tuple[str, str]] = field(default_factory = list)
-
-       def exists(self, thread_id: str, update_id: str) -> bool:
-               self.queries.append((thread_id, update_id))
-               return self.exists_by_id[update_id]
-
-
-@dataclass
-class ChannelBackedEventStream(EventStream):
-       event_rx: MemoryReceiveChannel
-
-       async def get_event(self) -> Event:
-               obj = await self.event_rx.receive()
-               if isinstance(obj, BaseException):
-                       raise obj
-               else:
-                       return obj
-
-
-@asynccontextmanager
-async def null_async_context(val: T) -> AsyncIterator[T]:
-       yield val
-
-
-def _update_event(author: str, mentions: list[str]) -> UpdateEvent:
-       id_ = uuid1()
-       body = str(id_) + " " + " ".join("/u/" + user for user in mentions)
-       body_html = str(id_) + " " + " ".join(
-               f'<a href="/u/{user}">/u/{user}</a>'
-               for user in mentions
-       )
-       return UpdateEvent(
-               {
-                       "id": str(id_),
-                       "name": "LiveUpdate_" + str(id_),
-                       "author": author,
-                       "body": body,
-                       "body_html": body_html,
-               }
-       )
-
-
-def setUpModule() -> None:
-       getLogger().setLevel(logging.CRITICAL + 1)  # disable logging
-
-
-class TaskTests(TestCase):
-       @contextmanager
-       def _assert_completes_in(self, duration_seconds: Real) -> Iterator[None]:
-               with move_on_after(duration_seconds) as scope:
-                       yield
-               if scope.cancelled_caught:
-                       self.fail("operation timed out")
-
-       def test_buffering(self) -> None:
-               """
-               No API poll task, just a custom task mocking the WS listener by playing events over a channel, and a real
-               notifier task injected with mocks to capture output. The WS state is hard-coded as up, so no testing of recovery
-               from WS connection issues is included in this test.
-
-               Primarily it makes sure that updates are passed through the buffer and processed if and only if a delete event
-               doesn't arrive while they're buffered.
-               """
-
-               buffer_time = timedelta(seconds = 1)
-               notifier = MockNotifier()
-               (event_tx, event_rx) = open_memory_channel(0)
-
-               ign_event = _update_event("sender", [])
-               event_a = _update_event("sender", ["rec1", "rec2"])
-               event_b = _update_event("sender", ["rec1", "rec2"])
-
-               async def event_stream_impl(scope: CancelScope) -> None:
-                       await sleep_until(1.0)
-                       await event_tx.send(ign_event)
-                       await event_tx.send(event_a)
-
-                       await sleep_until(1.5)
-                       self.assertFalse(notifier.notifies)
-			await event_tx.send(DeleteEvent(event_a.payload_data["name"]))
-
-                       await sleep_until(2.0)
-                       await event_tx.send(event_b)
-
-                       # let things settle and terminate both tasks
-                       await sleep_until(4.0)
-                       scope.cancel()
-
-               async def main() -> None:
-                       async with open_nursery() as nursery:
-                               nursery.start_soon(event_stream_impl, nursery.cancel_scope)
-				nursery.start_soon(
-					message_sender_impl,
-					notifier, MockTitleProvider(), UnusedUpdateChecker(), WsListenerState(down = False), "ThreadId",
-					event_rx, buffer_time, 10, getLogger()  # mention limit of 10 comfortably exceeds any event here
-				)
-
-               trio.run(main, clock = MockClock(autojump_threshold = 0))
-
-               # only notifications sent are for event B
-               self.assertEqual(len(notifier.notifies), 2)
-               self.assertEqual({i.recipient for i in notifier.notifies}, {"rec1", "rec2"})
-               self.assertTrue(all(i.author == "sender" for i in notifier.notifies))
-               self.assertTrue(all(i.update_id == event_b.payload_data["id"] for i in notifier.notifies))
-
-       def test_ws_trouble(self) -> None:
-               """
-               Ensure that the notifier task responds correctly to the WS connection going down and coming back up.
-               """
-
-               thread_id = "ThreadId"
-               buffer_time = timedelta(seconds = 3)
-               update_checker = DictBackedUpdateChecker()
-               notifier = MockNotifier()
-               (event_tx, event_rx) = open_memory_channel(0)
-
-               # sent while WS is up, checked and notifies sent
-               event_a = _update_event("sender", ["rec"])
-
-               # sent and deleted while WS is down, checked and discarded from buffer
-               event_b = _update_event("sender", ["rec"])
-
-               # sent on WS reconnect, not checked and notifies sent
-               event_c = _update_event("sender", ["rec"])
-
-               # sent after WS reconnect, not checked and notifies sent
-               event_d = _update_event("sender", ["rec"])
-
-               ws_state = WsListenerState(down = False)
-
-               async def ws_simulator_impl() -> None:
-                       update_checker.exists_by_id[event_a.payload_data["id"]] = True
-                       await event_tx.send(event_a)
-
-                       await sleep_until(1.0)
-                       ws_state.down = True
-                       update_checker.exists_by_id[event_b.payload_data["id"]] = True
-                       await event_tx.send(event_b)
-
-                       await sleep_until(2.0)
-                       update_checker.exists_by_id[event_b.payload_data["id"]] = False
-                       ws_state.down = False
-                       ws_state.resync_ts = UUID(event_c.payload_data["id"]).time
-                       await event_tx.send(event_c)
-
-                       await sleep_until(3.0)
-                       await event_tx.send(event_d)
-
-               async def main() -> None:
-                       async with open_nursery() as nursery:
-                               nursery.start_soon(ws_simulator_impl)
-				nursery.start_soon(
-					message_sender_impl,
-					notifier, MockTitleProvider(), update_checker, ws_state, thread_id, event_rx, buffer_time,
-					10, getLogger()  # mention limit of 10 comfortably exceeds any event here
-				)
-
-                               await sleep_until(3.5)
-                               self.assertEqual([i.update_id for i in notifier.notifies], [event_a.payload_data["id"]])
-                               notifier.notifies.clear()
-                               self.assertEqual(update_checker.queries, [(thread_id, event_a.payload_data["id"])])
-                               update_checker.queries.clear()
-
-                               await sleep_until(4.5)
-                               self.assertFalse(notifier.notifies)
-                               self.assertEqual(update_checker.queries, [(thread_id, event_b.payload_data["id"])])
-                               update_checker.queries.clear()
-
-                               await sleep_until(5.5)
-                               self.assertEqual([i.update_id for i in notifier.notifies], [event_c.payload_data["id"]])
-                               notifier.notifies.clear()
-                               self.assertFalse(update_checker.queries)
-
-                               await sleep_until(6.5)
-                               self.assertEqual([i.update_id for i in notifier.notifies], [event_d.payload_data["id"]])
-                               notifier.notifies.clear()
-                               self.assertFalse(update_checker.queries)
-
-                               nursery.cancel_scope.cancel()
-
-               trio.run(main, clock = MockClock(autojump_threshold = 0))
-
-       def test_ws_listener_cancel(self) -> None:
-               ws_state = WsListenerState()
-
-               (event_in_tx, event_in_rx) = open_memory_channel(0)
-               event_stream = ChannelBackedEventStream(event_in_rx)
-               event_stream_factory = lambda: null_async_context(event_stream)
-
-               (event_out_tx, event_out_rx) = open_memory_channel(0)
-
-               async def main():
-                       async with open_nursery() as nursery:
-                               nursery.start_soon(
-                                       ws_listener_impl,
-                                       event_stream_factory, set(), ws_state, event_out_tx, getLogger()
-                               )
-
-                               event_1 = _update_event("author", [])
-                               await event_in_tx.send(event_1)
-                               with self._assert_completes_in(1):
-                                       event_out = await event_out_rx.receive()
-                               self.assertIs(event_out, event_1)
-
-                               event_2 = _update_event("author", ["recipient"])
-                               await event_in_tx.send(event_2)
-                               with self._assert_completes_in(1):
-                                       event_out = await event_out_rx.receive()
-                               self.assertIs(event_out, event_2)
-
-                               initial_scope = ws_state.scope
-                               initial_scope.cancel()
-                               await wait_all_tasks_blocked()
-                               self.assertIsNotNone(ws_state.scope)
-                               self.assertIsNot(ws_state.scope, initial_scope)
-                               self.assertTrue(ws_state.down)
-
-                               event_3 = _update_event("author", [])
-                               ts = UUID(event_3.payload_data["id"]).time
-                               await event_in_tx.send(event_3)
-                               with self._assert_completes_in(1):
-                                       event_out = await event_out_rx.receive()
-                               self.assertIs(event_out, event_3)
-                               self.assertEqual(ws_state.last_update_ts, ts)
-                               self.assertFalse(ws_state.down)
-                               self.assertEqual(ws_state.resync_ts, ts)
-
-                               nursery.cancel_scope.cancel()
-
-               trio.run(main, clock = MockClock(autojump_threshold = 0))
-
-       def test_ws_listener_error(self) -> None:
-               """Test that the WS listener responds correctly to errors from the input event stream."""
-
-               class MockStreamException(Exception):
-                       pass
-
-               ws_state = WsListenerState()
-
-               (event_in_tx, event_in_rx) = open_memory_channel(0)
-               event_stream = ChannelBackedEventStream(event_in_rx)
-               event_stream_factory = lambda: null_async_context(event_stream)
-
-               (event_out_tx, event_out_rx) = open_memory_channel(0)
-
-               async def main():
-                       async with open_nursery() as nursery:
-                               nursery.start_soon(
-                                       ws_listener_impl,
-                                       event_stream_factory, {MockStreamException}, ws_state, event_out_tx, getLogger()
-                               )
-
-                               event_1 = _update_event("author", [])
-                               await event_in_tx.send(event_1)
-                               with self._assert_completes_in(1):
-                                       event_out = await event_out_rx.receive()
-                               self.assertIs(event_out, event_1)
-
-                               initial_scope = ws_state.scope
-                               await event_in_tx.send(MockStreamException())
-                               await wait_all_tasks_blocked()
-                               self.assertIsNotNone(ws_state.scope)
-                               self.assertIsNot(ws_state.scope, initial_scope)
-                               self.assertTrue(ws_state.down)
-
-                               event_2 = _update_event("author", ["recipient"])
-                               ts = UUID(event_2.payload_data["id"]).time
-                               await event_in_tx.send(event_2)
-                               with self._assert_completes_in(1):
-                                       event_out = await event_out_rx.receive()
-                               self.assertIs(event_out, event_2)
-                               self.assertEqual(ws_state.last_update_ts, ts)
-                               self.assertFalse(ws_state.down)
-                               self.assertEqual(ws_state.resync_ts, ts)
-
-                               nursery.cancel_scope.cancel()
-
-               trio.run(main, clock = MockClock(autojump_threshold = 0))
-
-
-class RedditTests(TestCase):
-       """Tests for interactions with Reddit."""
-
-       @classmethod
-       def setUpClass(cls) -> None:
-               cls._TOKEN_REF = TokenRef(environ["API_TOKEN"])
-
-       def test_try_request(self) -> None:
-               self.assertIsNotNone(
-                       try_request(
-                               Request(
-                                       "https://oauth.reddit.com/api/live/happening_now",
-                                       headers = self._TOKEN_REF.get_common_headers(),
-                               ),
-                               getLogger(),
-                       )
-               )
-
-       def test_try_request_http_error(self) -> None:
-               self.assertIsNone(
-                       try_request(
-                               Request("https://oauth.reddit.com/api/live/happening_now"),
-                               getLogger(),
-                               suppress_codes = [403],
-                       )
-               )
diff --git a/strikebot/docs/sample_config.ini b/strikebot/docs/sample_config.ini
deleted file mode 100644 (file)
index 34dec07..0000000
+++ /dev/null
@@ -1,65 +0,0 @@
-# see Python `configparser' docs for precise file syntax
-
-[config]
-
-######## basic operating parameters
-
-# space-separated authorization IDs for Reddit API token lookup
-auth IDs = 13 15
-
-bot user = count_better
-
-# if false, don't strike or delete updates or report incorrectly stricken updates (optional, default true)
-enforcing = true
-
-# log messages at or above this severity to standard out; level names and numbers are supported (optional, default
-# WARNING); see https://docs.python.org/3/library/logging.html#levels
-log level = WARNING
-
-# if true, don't modify the live thread at all (implies enforcing = false) (optional, default false)
-read-only = false
-
-thread ID = abc123
-
-
-######## API pool error handling
-
-# (seconds) when backing off after an error, pause the API pool for this much time
-API pool error delay = 5.5
-
-# (seconds) pause the pool when the Reddit API returns two errors within any time span of this length
-API pool error window = 5.5
-
-# terminate if the Reddit API request queue exceeds this size
-request queue limit = 100
-
-
-######## live thread update handling
-
-# (seconds) maximum time to hold updates for reordering
-reorder buffer time = 0.25
-
-# (seconds) minimum time to retain updates to enable future resyncs
-update retention time = 120.0
-
-
-######## WebSocket pool
-
-# (seconds) maximum allowable spread in WebSocket message arrival times
-WS parity time = 1.0
-
-# goal size of WebSocket connection pool
-WS pool size = 5
-
-# (seconds) after receiving no messages for this much time, attempt to replace the connection if the pool isn't depleted
-WS silent limit = 30
-
-# (seconds) time after WebSocket handshake completion during which missed updates are excused
-WS warmup time = 0.3
-
-
-# Postgres database configuration; same options as in a connect string. Note that because Python's `SSLContext' is used
-# to implement TLS client auth, using `sslkey' to specify a key obtained from an OpenSSL engine may not be supported (a
-# file path is the only option).
-[db connect params]
-host = example.org
diff --git a/strikebot/pyproject.toml b/strikebot/pyproject.toml
deleted file mode 100644 (file)
index 8fe2f47..0000000
+++ /dev/null
@@ -1,3 +0,0 @@
-[build-system]
-requires = ["setuptools>=42", "wheel"]
-build-backend = "setuptools.build_meta"
diff --git a/strikebot/setup.cfg b/strikebot/setup.cfg
deleted file mode 100644 (file)
index cdd2e66..0000000
+++ /dev/null
@@ -1,15 +0,0 @@
-[metadata]
-name = strikebot
-version = 0.0.7
-
-[options]
-package_dir =
-       = src
-packages = strikebot
-python_requires = ~= 3.9
-install_requires =
-       asks ~= 2.4
-       beautifulsoup4 ~= 4.9
-       trio == 0.19
-       trio-websocket == 0.9.2
-       triopg == 0.6.0
diff --git a/strikebot/setup.py b/strikebot/setup.py
deleted file mode 100644 (file)
index 056ba45..0000000
+++ /dev/null
@@ -1,4 +0,0 @@
-import setuptools
-
-
-setuptools.setup()
diff --git a/strikebot/src/strikebot/__init__.py b/strikebot/src/strikebot/__init__.py
deleted file mode 100644 (file)
index daa9944..0000000
+++ /dev/null
@@ -1,317 +0,0 @@
-"""Count tracking logic and bot's external interface."""
-
-from contextlib import nullcontext
-from dataclasses import dataclass
-from functools import total_ordering
-from itertools import islice
-from typing import Optional
-from uuid import UUID
-import bisect
-import datetime as dt
-import importlib.metadata
-import itertools
-import logging
-
-from trio import CancelScope, EndOfChannel
-import trio
-
-from strikebot.updates import Command, parse_update
-
-
-__version__ = importlib.metadata.version(__package__)
-
-
-@dataclass
-class _Update:
-       id: str
-       name: str
-       author: str
-       number: Optional[int]
-       command: Optional[Command]
-       ts: dt.datetime
-       count_attempt: bool
-       deletable: bool
-       stricken: bool
-
-       def can_follow(self, prior: "_Update") -> bool:
-               """Determine whether this count update can follow another count update."""
-               return self.number == prior.number + 1 and self.author != prior.author
-
-       def __str__(self):
-               content = str(self.number)
-               if self.command:
-                       content += "+" + self.command.name
-               return "{}({} by {})".format(self.id[4 : 8], content, self.author)
-
-
-@total_ordering
-@dataclass
-class _BufferedUpdate:
-       update: _Update
-       release_at: float  # Trio time
-
-       def __lt__(self, other):
-               return self.update.ts < other.update.ts
-
-
-@total_ordering
-@dataclass
-class _TimelineUpdate:
-       update: _Update
-       accepted: bool
-
-       def __lt__(self, other):
-               return self.update.ts < other.update.ts
-
-       def rejected(self) -> bool:
-               """Whether the update should be stricken."""
-               return self.update.count_attempt and not self.accepted
-
-
-def _parse_rfc_4122_ts(ts: int) -> dt.datetime:
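-	# UUID version 1 timestamps count 100 ns intervals from the Gregorian reform epoch; dividing by 10 gives microseconds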
-       epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc)
-       return epoch + dt.timedelta(microseconds = ts / 10)
-
-
-def _format_update_ref(update: _Update, thread_id: str) -> str:
-       url = f"https://www.reddit.com/live/{thread_id}/updates/{update.id}"
-       author_md = update.author.replace("_", "\\_")  # Reddit allows letters, digits, underscore, and hyphen
-       return f"[{update.number:,} by {author_md}]({url})"
-
-
-def _format_bad_strike_alert(update: _Update, thread_id: str) -> str:
-       return "_incorrectly stricken:_ " + _format_update_ref(update, thread_id)
-
-
-def _format_curr_count(last_valid: Optional[_Update]) -> str:
-       if last_valid and last_valid.number is not None:
-               count_str = f"{last_valid.number + 1:,}"
-       else:
-               count_str = "unknown"
-       return f"_current count:_ {count_str}"
-
-
-async def count_tracker_impl(
-       message_rx,
-       api_pool: "strikebot.reddit_api.ApiClientPool",
-       reorder_buffer_time: dt.timedelta,  # duration to hold some updates for reordering
-       thread_id: str,
-       bot_user: str,
-       enforcing: bool,
-       read_only: bool,
-       update_retention: dt.timedelta,
-       logger: logging.Logger,
-) -> None:
-       from strikebot.reddit_api import CorrectionUpdateRequest, DeleteRequest, ReportUpdateRequest, StrikeRequest
-
-       enforcing = enforcing and not read_only
-
-       buffer_ = []
-       timeline = []
-       last_valid: Optional[_Update] = None
-       pending_strikes: set[str] = set()  # names of updates to mark stricken on arrival
-       forgotten: bool = False  # whether the retention period for an update has passed during the run
-
-       def handle_update(update):
-               nonlocal last_valid
-
-               tu = _TimelineUpdate(update, accepted = False)
-
-               pos = bisect.bisect(timeline, tu)
-               if pos != len(timeline):
-                       logger.warning(f"long transpo: {update.id}")
-
-               if pos > 0 and timeline[pos - 1].update.id == update.id:
-                       # The pool sent this update message multiple times.  This could be because the message reached parity and
-                       # its bucket was dropped just before a new connection was opened and sent its own copy, but connections also
-                       # seem to send duplicate messages.
-                       return
-
-               pred = next(
-                       (timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted),
-                       None
-               )
-               contender = update.command is Command.RESET or update.number is not None
-               if contender and pred is None and last_valid and forgotten:
-                       # This is a really long transpo which we can't judge valid or invalid since we've forgotten the surrounding
-                       # updates. The best we can do is just ignore it.
-                       logger.warning(f"ignoring {update.id}: no valid prior count on record")
-               else:
-                       timeline.insert(pos, tu)
-                       tu.accepted = (
-                               update.command is Command.RESET
-                               or (
-                                       update.number is not None
-                                       and (pred is None or pred.number is None or update.can_follow(pred))
-                                       and not update.stricken
-                               )
-                       )
-                       logger.debug("{} {} -> {}".format("accepted" if tu.accepted else "rejected", pred, update))
-                       if tu.accepted:
-                               # resync subsequent updates
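-				# Re-validate the updates after the insertion point, collecting previously accepted
-				# counts that no longer follow; stop once the chain reconverges with the known-valid count.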
-                               newly_invalid = []
-                               resync_last_valid = update
-                               converged = False
-                               for scan_tu in islice(timeline, pos + 1, None):
-                                       if scan_tu.update.command is Command.RESET:
-                                               resync_last_valid = scan_tu.update
-                                               if last_valid:
-                                                       converged = True
-                                       elif scan_tu.update.number is not None:
-                                               accept = (
-                                                       (resync_last_valid.number is None or scan_tu.update.can_follow(resync_last_valid))
-                                                       and not scan_tu.update.stricken
-                                               )
-                                               if accept:
-                                                       assert scan_tu.accepted
-                                                       resync_last_valid = scan_tu.update
-                                                       # resync would have no effect past this point
-                                                       if last_valid:
-                                                               converged = True
-                                               elif scan_tu.accepted:
-                                                       newly_invalid.append(scan_tu)
-                                                       if scan_tu.update is last_valid:
-                                                               last_valid = None
-                                               scan_tu.accepted = accept
-                                       if converged:
-                                               break
-
-                               if converged:
-                                       assert last_valid.ts >= resync_last_valid.ts
-                               else:
-                                       last_valid = resync_last_valid
-
-                               parts = []
-                               unstrikable = [tu for tu in newly_invalid if tu.update.stricken]
-                               if unstrikable:
-                                       parts.append(
-                                               "The following counts are invalid:\n\n"
-                                               + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in unstrikable)
-                                       )
-
-                               if update.stricken:
-                                       logger.info(f"bad strike of {update.id}")
-                                       parts.append(_format_bad_strike_alert(update, thread_id))
-
-                               if parts:
-                                       parts.append(_format_curr_count(last_valid))
-                                       if not read_only:
-                                               api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts)))
-
-                               for invalid_tu in newly_invalid:
-                                       if not invalid_tu.update.stricken:
-                                               if enforcing:
-                                                       api_pool.enqueue_request(StrikeRequest(thread_id, invalid_tu.update.name, invalid_tu.update.ts))
-                                               invalid_tu.update.stricken = True
-                       elif update.count_attempt and not update.stricken:
-                               if enforcing:
-                                       api_pool.enqueue_request(StrikeRequest(thread_id, update.name, update.ts))
-                               update.stricken = True
-
-               if not read_only and update.command is Command.REPORT:
-                       api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
-
-       async with message_rx:
-               while True:
-                       if buffer_:
-                               deadline = min(bu.release_at for bu in buffer_)
-                               cancel_scope = CancelScope(deadline = deadline)
-                       else:
-                               cancel_scope = nullcontext()
-
-                       msg = None
-                       with cancel_scope:
-                               try:
-                                       msg = await message_rx.receive()
-                               except EndOfChannel:
-                                       break
-
-                       if msg:
-                               if msg.data["type"] == "update":
-                                       payload_data = msg.data["payload"]["data"]
-                                       stricken = payload_data["name"] in pending_strikes
-                                       pending_strikes.discard(payload_data["name"])
-                                       if payload_data["author"] != bot_user:
-                                               if last_valid and last_valid.number is not None:
-                                                       next_up = last_valid.number + 1
-                                               else:
-                                                       next_up = None
-                                               pu = parse_update(payload_data, next_up, bot_user)
-                                               rfc_ts = UUID(payload_data["id"]).time
-                                               update = _Update(
-                                                       id = payload_data["id"],
-                                                       name = payload_data["name"],
-                                                       author = payload_data["author"],
-                                                       number = pu.number,
-                                                       command = pu.command,
-                                                       ts = _parse_rfc_4122_ts(rfc_ts),
-                                                       count_attempt = pu.count_attempt,
-                                                       deletable = pu.deletable,
-                                                       stricken = stricken or payload_data["stricken"],
-                                               )
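-						# hold the update briefly so out-of-order WS arrivals can be re-sequenced by timestamp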
-                                               release_at = msg.timestamp + reorder_buffer_time.total_seconds()
-                                               bisect.insort(buffer_, _BufferedUpdate(update, release_at))
-                               elif msg.data["type"] == "strike":
-                                       slot = next(
-                                               (
-                                                       slot for slot in itertools.chain(buffer_, reversed(timeline))
-                                                       if slot.update.name == msg.data["payload"]
-                                               ),
-                                               None
-                                       )
-                                       if slot:
-                                               if not slot.update.stricken:
-                                                       slot.update.stricken = True
-                                                       if isinstance(slot, _TimelineUpdate) and slot.accepted:
-                                                               logger.info(f"bad strike of {slot.update.id}")
-                                                               if not read_only:
-                                                                       body = _format_bad_strike_alert(slot.update, thread_id)
-                                                                       api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body))
-                                       else:
-                                               pending_strikes.add(msg.data["payload"])
-
-                       threshold = dt.datetime.now(dt.timezone.utc) - update_retention
-
-                       # pull updates from the reorder buffer and process
-                       new_buffer = []
-                       for bu in buffer_:
-				process = (
-					# not a count
-					bu.update.number is None
-
-					# valid count
-					or (last_valid is None or last_valid.number is None or bu.update.can_follow(last_valid))
-					or bu.update.command is Command.RESET
-
-					or trio.current_time() >= bu.release_at
-				)
-
-                               if process:
-                                       if bu.update.ts < threshold:
-                                               logger.warning(f"ignoring {bu.update}: arrived past retention window")
-                                       else:
-                                               logger.debug(f"processing {bu.update}")
-                                               handle_update(bu.update)
-                               else:
-					logger.debug(f"holding {bu.update}, checked against {last_valid}")
-                                       new_buffer.append(bu)
-                       buffer_ = new_buffer
-                       logger.debug("last count {}".format(last_valid))
-
-                       # delete/forget old updates
-                       new_timeline = []
-                       i = 0
-                       for (i, tu) in enumerate(timeline):
-                               if i >= len(timeline) - 10 or tu.update.ts >= threshold:
-                                       break
-                               elif tu.rejected() and tu.update.deletable:
-                                       if enforcing:
-                                               api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts))
-                               elif tu.update is last_valid:
-                                       new_timeline.append(tu)
-
-                       if i:
-                               forgotten = True
-                       new_timeline.extend(islice(timeline, i, None))
-                       timeline = new_timeline
diff --git a/strikebot/src/strikebot/__main__.py b/strikebot/src/strikebot/__main__.py
deleted file mode 100644 (file)
index fc80e23..0000000
+++ /dev/null
@@ -1,224 +0,0 @@
-from enum import Enum
-from inspect import getmodule
-from logging import FileHandler, getLogger, StreamHandler
-from pathlib import Path
-from signal import SIGUSR1
-from sys import stdout
-from traceback import StackSummary
-from typing import Optional
-import argparse
-import configparser
-import datetime as dt
-import logging
-import ssl
-
-from trio import open_memory_channel, open_nursery, open_signal_receiver
-
-try:
-       from trio.lowlevel import current_root_task, Task
-except ImportError:
-       from trio.hazmat import current_root_task, Task
-
-import trio_asyncio
-import triopg
-
-from strikebot import count_tracker_impl
-from strikebot.db import Client, Messenger
-from strikebot.live_ws import HealingReadPool, PoolMerger
-from strikebot.reddit_api import ApiClientPool
-
-
-_DEBUG_LOG_PATH: Optional[Path] = None  # file to write debug logs to, if any
-
-
-_TaskKind = Enum(
-       "_TaskKind",
-       [
-               "API_WORKER",
-               "DB_CLIENT",
-               "TOKEN_UPDATE",
-               "TRACKER",
-               "WS_MERGE_READER",
-               "WS_MERGE_TIMER",
-               "WS_REFRESHER",
-               "WS_WORKER",
-       ]
-)
-
-
-def _classify_task(task: Task) -> _TaskKind:
-       mod_name = getmodule(task.coro.cr_code).__name__
-       if mod_name.startswith("trio_asyncio."):
-               return None
-       elif task.coro.__name__ == "_signal_handler_impl":
-               return None
-       else:
-               by_coro_name = {
-                       "db_client_impl": _TaskKind.DB_CLIENT,
-                       "token_updater_impl": _TaskKind.TOKEN_UPDATE,
-                       "event_reader_impl": _TaskKind.WS_MERGE_READER,
-                       "timeout_handler_impl": _TaskKind.WS_MERGE_TIMER,
-                       "conn_refresher_impl": _TaskKind.WS_REFRESHER,
-                       "count_tracker_impl": _TaskKind.TRACKER,
-                       "_reader_impl": _TaskKind.WS_WORKER,
-                       "worker_impl": _TaskKind.API_WORKER,
-               }
-               return by_coro_name[task.coro.__name__]
-
-
-def _get_all_tasks():
-       [ta_loop] = [
-               task
-               for nursery in current_root_task().child_nurseries
-               for task in nursery.child_tasks
-               if getmodule(task.coro.cr_code).__name__.startswith("trio_asyncio.")
-       ]
-       for nursery in ta_loop.child_nurseries:
-               yield from nursery.child_tasks
-
-
-def _print_task_info():
-       groups = {kind: [] for kind in _TaskKind}
-       for task in _get_all_tasks():
-               kind = _classify_task(task)
-               if kind:
-                       coro = task.coro
-                       frame_tuples = []
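-			# walk the coroutine await/yield-from chain to reconstruct this task's stack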
-                       while coro is not None:
-                               if hasattr(coro, "cr_frame"):
-                                       frame_tuples.append((coro.cr_frame, coro.cr_frame.f_lineno))
-                                       coro = coro.cr_await
-                               else:
-                                       frame_tuples.append((coro.gi_frame, coro.gi_frame.f_lineno))
-                                       coro = coro.gi_yieldfrom
-                       groups[kind].append(StackSummary.extract(iter(frame_tuples)))
-       for kind in _TaskKind:
-               print(kind.name)
-               for ss in groups[kind]:
-                       print("  task")
-                       for line in ss.format():
-                               print("    " + line, end = "")
-
-
-async def _signal_handler_impl():
-       with open_signal_receiver(SIGUSR1) as sig_src:
-               async for _ in sig_src:
-                       _print_task_info()
-
-
-ap = argparse.ArgumentParser(__package__)
-ap.add_argument("config_path")
-args = ap.parse_args()
-
-
-parser = configparser.ConfigParser()
-with open(args.config_path) as config_file:
-       parser.read_file(config_file)
-
-main_cfg = parser["config"]
-
-api_pool_error_delay = dt.timedelta(seconds = main_cfg.getfloat("API pool error delay"))
-api_pool_error_window = dt.timedelta(seconds = main_cfg.getfloat("API pool error window"))
-auth_ids = set(map(int, main_cfg["auth IDs"].split()))
-bot_user = main_cfg["bot user"]
-enforcing = main_cfg.getboolean("enforcing", fallback = True)
-
-raw_log_level = main_cfg.get("log level", "WARNING")
-try:
-       log_level = int(raw_log_level)
-except ValueError:
-       log_level = raw_log_level
-
-read_only = main_cfg.getboolean("read-only", fallback = False)
-reorder_buffer_time = dt.timedelta(seconds = main_cfg.getfloat("reorder buffer time"))
-request_queue_limit = main_cfg.getint("request queue limit")
-thread_id = main_cfg["thread ID"]
-update_retention = dt.timedelta(seconds = main_cfg.getfloat("update retention time"))
-ws_parity_time = dt.timedelta(seconds = main_cfg.getfloat("WS parity time"))
-ws_pool_size = main_cfg.getint("WS pool size")
-ws_silent_limit = dt.timedelta(seconds = main_cfg.getfloat("WS silent limit"))
-ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup time"))
-
-
-db_cfg = parser["db connect params"]
-cert_path = db_cfg.pop("sslcert", None)
-key_path = db_cfg.pop("sslkey", None)
-key_pass = db_cfg.pop("sslpassword", None)
-
-getters = {
-       "port": db_cfg.getint,
-}
-db_connect_kwargs = {k: getters.get(k, db_cfg.get)(k) for k in db_cfg}
-
-if cert_path:
-       # patch TLS client cert auth support; asyncpg adds it in a later version
-       db_ssl_ctx = ssl.create_default_context()
-       db_ssl_ctx.load_cert_chain(
-               cert_path,
-               key_path,
-               None if key_pass is None else lambda: key_pass,
-       )
-       db_connect_kwargs["ssl"] = db_ssl_ctx
-
-
-if read_only and enforcing:
-       raise RuntimeError("can't use read-only mode with enforcing on")
-
-
-logger = getLogger(__package__)
-logger.setLevel(logging.NOTSET - 1)  # NOTSET means inherit from parent; we use handlers to filter
-
-handler = StreamHandler(stdout)
-handler.setLevel(log_level)
-handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{"))
-logger.addHandler(handler)
-
-if _DEBUG_LOG_PATH:  # TODO remove this ad hoc setup
-       debug_handler = FileHandler(_DEBUG_LOG_PATH, mode = "w")
-       debug_handler.setLevel(logging.DEBUG)
-       debug_handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{"))
-       logger.addHandler(debug_handler)
-
-
-async def main():
-       async with (
-               triopg.connect(**db_connect_kwargs) as db_conn,
-               open_nursery() as nursery_a,
-               open_nursery() as nursery_b,
-               open_nursery() as ws_pool_nursery,
-               open_nursery() as nursery_c,
-       ):
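-		# `async with' exits in reverse order, so on shutdown the tracker and WS pool wind down
-		# before the merger, API workers, and DB client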
-               nursery_a.start_soon(_signal_handler_impl)
-
-               client = Client(db_conn)
-               db_messenger = Messenger(client)
-               nursery_a.start_soon(db_messenger.db_client_impl)
-
-               api_pool = ApiClientPool(
-                       auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay,
-                       logger.getChild("api")
-               )
-               nursery_a.start_soon(api_pool.token_updater_impl)
-               for _ in auth_ids:
-                       await nursery_a.start(api_pool.worker_impl)
-
-               (message_tx, message_rx) = open_memory_channel(0)
-               (pool_event_tx, pool_event_rx) = open_memory_channel(0)
-
-               merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger.getChild("merge"))
-               nursery_b.start_soon(merger.event_reader_impl)
-               nursery_b.start_soon(merger.timeout_handler_impl)
-
-               ws_pool = HealingReadPool(
-                       ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger.getChild("live_ws"),
-                       ws_silent_limit
-               )
-               nursery_c.start_soon(ws_pool.conn_refresher_impl)
-
-               nursery_c.start_soon(
-                       count_tracker_impl, message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing, read_only,
-                       update_retention, logger.getChild("track")
-               )
-               await ws_pool.init_workers()
-
-trio_asyncio.run(main)
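
The `getters` table above dispatches each config key to a typed getter, defaulting to plain strings for everything else. A minimal runnable sketch of the same pattern, using a hypothetical two-key section:

    from configparser import ConfigParser

    parser = ConfigParser()
    parser.read_string("[db connect params]\nhost = example.org\nport = 5432\n")
    section = parser["db connect params"]
    getters = {"port": section.getint}  # every other key falls back to section.get
    kwargs = {k: getters.get(k, section.get)(k) for k in section}
    assert kwargs == {"host": "example.org", "port": 5432}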
diff --git a/strikebot/src/strikebot/common.py b/strikebot/src/strikebot/common.py
deleted file mode 100644 (file)
index 519037a..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-from typing import Any
-
-
-def int_digest(val: int) -> str:
-       return "{:04x}".format(val & 0xffff)
-
-
-def obj_digest(obj: Any) -> str:
-       """Makes a short digest of the identity of an object."""
-       return int_digest(id(obj))
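
Both helpers produce a fixed-width four-hex-digit tag for correlating log lines. A quick sketch of their behavior, assuming the package is importable:

    from strikebot.common import int_digest, obj_digest

    assert int_digest(0x12345) == "2345"  # low 16 bits, zero-padded hex
    print(obj_digest(object()))           # e.g. "9f10"; varies with object identity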
diff --git a/strikebot/src/strikebot/db.py b/strikebot/src/strikebot/db.py
deleted file mode 100644 (file)
index d1a2f42..0000000
+++ /dev/null
@@ -1,58 +0,0 @@
-"""Single-connection Postgres client with high-level wrappers for operations."""
-
-from dataclasses import dataclass
-from typing import Any, Iterable
-
-import trio
-
-
-def _channel_sender(method):
-       async def wrapped(self, resp_channel, *args, **kwargs):
-               ret = await method(self, *args, **kwargs)
-               async with resp_channel:
-                       await resp_channel.send(ret)
-
-       return wrapped
-
-
-@dataclass
-class Client:
-       """High-level wrappers for DB operations."""
-       conn: Any
-
-       @_channel_sender
-       async def get_auth_tokens(self, auth_ids: set[int]) -> dict[int, str]:
-               results = await self.conn.fetch(
-                       """
-                               select
-                                       distinct on (id)
-                                       id, access_token
-                               from
-                                       public.reddit_app_authorization
-                                       join unnest($1::integer[]) as request_id on id = request_id
-                               where expires > current_timestamp
-                       """,
-                       list(auth_ids),
-               )
-               return {r["id"]: r["access_token"] for r in results}
-
-
-class Messenger:
-       def __init__(self, client: Client):
-               self._client = client
-               (self._request_tx, self._request_rx) = trio.open_memory_channel(0)
-
-       async def do(self, method_name: str, args: Iterable) -> Any:
-               """This is run by consumers of the DB wrapper."""
-               method = getattr(self._client, method_name)
-               (resp_tx, resp_rx) = trio.open_memory_channel(0)
-               coro = method(resp_tx, *args)
-               await self._request_tx.send(coro)
-               async with resp_rx:
-                       return await resp_rx.receive()
-
-       async def db_client_impl(self):
-               """This is the DB client Trio task."""
-               async with self._client.conn, self._request_rx:
-                       async for coro in self._request_rx:
-                               await coro
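
A minimal sketch of driving the `Messenger` from a consumer task, assuming `conn` is an already-open triopg connection (connection setup omitted):

    import trio
    from strikebot.db import Client, Messenger

    async def demo(conn):
        messenger = Messenger(Client(conn))
        async with trio.open_nursery() as nursery:
            nursery.start_soon(messenger.db_client_impl)  # serializes DB access in one task
            tokens = await messenger.do("get_auth_tokens", ({13, 15},))
            print(tokens)  # e.g. {13: "token-a", 15: "token-b"}
            nursery.cancel_scope.cancel()  # shut down the client task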
diff --git a/strikebot/src/strikebot/live_ws.py b/strikebot/src/strikebot/live_ws.py
deleted file mode 100644 (file)
index 031d652..0000000
+++ /dev/null
@@ -1,266 +0,0 @@
-from collections import deque
-from contextlib import suppress
-from dataclasses import dataclass
-from typing import Any, AsyncContextManager, Optional
-import datetime as dt
-import json
-import logging
-import math
-
-from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
-import trio
-
-from strikebot.common import int_digest, obj_digest
-from strikebot.reddit_api import AboutLiveThreadRequest
-
-
-@dataclass
-class _PoolEvent:
-       timestamp: float  # Trio time
-
-
-@dataclass
-class _Message(_PoolEvent):
-       data: Any
-       scope: Any
-
-       def dedup_tag(self):
-               """A hashable essence of the contents of this message."""
-               if self.data["type"] == "update":
-                       return ("update", self.data["payload"]["data"]["id"])
-               elif self.data["type"] in {"strike", "delete"}:
-                       return (self.data["type"], self.data["payload"])
-               else:
-                       raise ValueError(self.data["type"])
-
-
-@dataclass
-class _ConnectionDown(_PoolEvent):
-       scope: trio.CancelScope
-
-
-@dataclass
-class _ConnectionUp(_PoolEvent):
-       scope: trio.CancelScope
-
-
-class HealingReadPool:
-       def __init__(
-               self, pool_nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger,
-               silent_limit: dt.timedelta
-       ):
-               assert size >= 2
-               self._nursery = pool_nursery
-               self._size = size
-               self._live_thread_id = live_thread_id
-               self._api_client_pool = api_client_pool
-               self._pool_event_tx = pool_event_tx
-               self._logger = logger
-               self._silent_limit = silent_limit
-
-               (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size)
-
-               # number of workers who have stopped receiving updates but whose tasks aren't yet stopped
-               self._closing = 0
-
-       async def init_workers(self):
-               for _ in range(self._size):
-                       await self._spawn_reader()
-               if not self._nursery.child_tasks:
-                       raise RuntimeError("Unable to create any WS connections")
-
-       async def _reader_impl(self, conn_ctx: AsyncContextManager, cancel_scope, refresh_tx):
-               # TODO could use task's implicit cancel scope
-               try:
-                       async with conn_ctx as conn:
-                               self._logger.debug("scope up: {}".format(obj_digest(cancel_scope)))
-                               async with refresh_tx:
-                                       silent_timeout = False
-                                       with cancel_scope, suppress(ConnectionClosed):
-                                               while True:
-                                                       with trio.move_on_after(self._silent_limit.total_seconds()) as timeout_scope:
-                                                               message = await conn.get_message()
-
-                                                       if timeout_scope.cancelled_caught:
-                                                               if len(self._nursery.child_tasks) - self._closing == self._size:
-                                                                       silent_timeout = True
-                                                                       break
-                                                               else:
-                                                                       self._logger.debug("not replacing connection {}; {} tasks, {} closing".format(
-                                                                               obj_digest(cancel_scope),
-                                                                               len(self._nursery.child_tasks),
-                                                                               self._closing,
-                                                                       ))
-                                                       else:
-                                                               event = _Message(trio.current_time(), json.loads(message), cancel_scope)
-                                                               await self._pool_event_tx.send(event)
-
-                                       self._closing += 1
-                                       if silent_timeout:
-                                               await conn.aclose()
-                                               self._logger.debug("replacing WS connection {} due to silent timeout".format(obj_digest(cancel_scope)))
-                                       elif cancel_scope.cancelled_caught:
-                                               await conn.aclose(1008, "Server unexpectedly stopped sending messages")
-                                               self._logger.warning("replacing WS connection {} due to missed update".format(obj_digest(cancel_scope)))
-                                       else:
-                                               self._logger.warning("replacing WS connection {} closed by server".format(obj_digest(cancel_scope)))
-
-                                       refresh_tx.send_nowait(cancel_scope)
-                                       self._logger.debug("scope down: {} ({} tasks, {} closing)".format(
-                                               obj_digest(cancel_scope),
-                                               len(self._nursery.child_tasks),
-                                               self._closing,
-                                       ))
-                                       self._closing -= 1
-               except HandshakeError:
-                       self._logger.error("handshake error while opening WS connection")
-
-       async def _new_connection(self) -> Optional[AsyncContextManager[WebSocketConnection]]:
-               request = AboutLiveThreadRequest(self._live_thread_id)
-               resp = await self._api_client_pool.make_request(request)
-               if resp.status_code == 200:
-                       url = json.loads(resp.text)["data"]["websocket_url"]
-                       return open_websocket_url(url)
-               else:
-                       self._logger.error(f"thread info request failed (HTTP {resp.status_code}); can't open WS connection")
-                       return None
-
-       async def _spawn_reader(self) -> None:
-               conn = await self._new_connection()
-               if conn:
-                       new_scope = trio.CancelScope()
-                       await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope))
-                       refresh_tx = self._refresh_queue_tx.clone()
-                       self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx)
-
-       async def conn_refresher_impl(self):
-               """Task to monitor and replace WS connections as they disconnect or go silent."""
-               async with self._refresh_queue_rx:
-                       async for old_scope in self._refresh_queue_rx:
-                               await self._pool_event_tx.send(_ConnectionDown(trio.current_time(), old_scope))
-                               await self._spawn_reader()
-                               if not self._nursery.child_tasks:
-                                       raise RuntimeError("WS pool depleted")
-
-
-def _tag_digest(tag):
-       return int_digest(hash(tag))
-
-
-def _format_scope_list(scopes):
-       return ", ".join(sorted(map(obj_digest, scopes)))
-
-
-class PoolMerger:
-       @dataclass
-       class _Bucket:
-               start: float
-               recipients: set
-
-       def __init__(self, pool_event_rx, message_tx, parity_timeout: dt.timedelta, conn_warmup: dt.timedelta, logger: logging.Logger):
-               """
-               :param parity_timeout: max delay between message receipt on different connections
-               :param conn_warmup: max interval after connection within which missed updates are allowed
-               """
-               self._pool_event_rx = pool_event_rx
-               self._message_tx = message_tx
-               self._parity_timeout = parity_timeout
-               self._conn_warmup = conn_warmup
-               self._logger = logger
-
-               self._buckets = {}
-               self._scope_activations = {}
-               self._pending = deque()
-               self._outgoing_scopes = set()
-               (self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf)
-
-       def _log_current_buckets(self):
-               self._logger.debug("current buckets: {}".format(", ".join(map(_tag_digest, self._buckets.keys()))))
-
-       async def event_reader_impl(self):
-               """Drop unused messages, deduplicate useful ones, and communicate with the timeout handler."""
-               async with self._pool_event_rx, self._timer_poke_tx:
-                       async for event in self._pool_event_rx:
-                               if isinstance(event, _ConnectionUp):
-                                       # An early add of an active scope could mean it's expected on a message that fired before it opened,
-                                       # resulting in a false positive for replacement. The timer task merges these in among the buckets by
-                                       # timestamp to avoid that.
-                                       self._pending.append(event)
-                               elif isinstance(event, _Message):
-                                       if event.data["type"] in ["update", "strike", "delete"]:
-                                               tag = event.dedup_tag()
-                                               self._logger.debug("recv {} from {}".format(_tag_digest(tag), obj_digest(event.scope)))
-                                               if tag in self._buckets:
-                                                       b = self._buckets[tag]
-                                                       b.recipients.add(event.scope)
-                                                       # If this scope is the last one for this bucket we could clear the bucket here, but since
-                                                       # connections sometimes get repeat messages and second copies can arrive before the first
-                                                       # copy has arrived on all connections, leaving the bucket open to absorb repeats can reduce
-                                                       # the likelihood of a second bucket being allocated late for the second copy of a message
-                                                       # and causing unnecessary connection replacement.
-                                               elif event.scope not in self._outgoing_scopes:
-                                                       sane = (
-                                                               event.scope in self._scope_activations
-                                                               or any(e.scope == event.scope for e in self._pending)
-                                                       )
-                                                       if sane:
-                                                               self._logger.debug("new bucket {}: {}".format(_tag_digest(tag), tag))
-                                                               self._buckets[tag] = self._Bucket(event.timestamp, {event.scope})
-                                                               self._log_current_buckets()
-                                                               await self._message_tx.send(event)
-                                                       else:
-							raise RuntimeError("received message from unrecognized WS connection")
-                                       else:
-                                               self._logger.debug("recv type {!r} from {} (discarding)".format(event.data["type"], obj_digest(event.scope)))
-                               elif isinstance(event, _ConnectionDown):
-                                       # We don't need to worry about canceling this scope at all, so no need to require it for parity for
-                                       # any message, even older ones.  The scope may be gone already, if we canceled it previously.
-                                       self._scope_activations.pop(event.scope, None)
-                                       self._pending = deque(ev for ev in self._pending if ev.scope is not event.scope)
-                                       self._outgoing_scopes.discard(event.scope)
-                               else:
-                                       raise TypeError(f"Expected pool event, found {event!r}")
-                               self._timer_poke_tx.send_nowait(None)  # may have new work for the timer
-
-       async def timeout_handler_impl(self):
-               """When connections have had enough time to reach parity on a message, signal replacement of any slackers."""
-               async with self._timer_poke_rx:
-                       while True:
-                               if self._buckets:
-                                       now = trio.current_time()
-                                       (tag, bucket) = min(self._buckets.items(), key = lambda i: i[1].start)
-
-                                       # Make sure the right scope set is ready for the next bucket up
-                                       while self._pending and self._pending[0].timestamp < bucket.start:
-                                               ev = self._pending.popleft()
-                                               self._scope_activations[ev.scope] = ev.timestamp
-
-                                       target_scopes = {
-                                               scope for (scope, active) in self._scope_activations.items()
-                                               if active + self._conn_warmup.total_seconds() < now
-                                       }
-                                       if now > bucket.start + self._parity_timeout.total_seconds():
-                                               filled = target_scopes & bucket.recipients
-                                               missing = target_scopes - bucket.recipients
-                                               extra = bucket.recipients - target_scopes
-
-                                               self._logger.debug("expiring bucket {}".format(_tag_digest(tag)))
-                                               if filled:
-                                                       self._logger.debug("  filled {}: {}".format(len(filled), _format_scope_list(filled)))
-                                               if missing:
-                                                       self._logger.debug("  missing {}: {}".format(len(missing), _format_scope_list(missing)))
-                                               if extra:
-                                                       self._logger.debug("  extra {}: {}".format(len(extra), _format_scope_list(extra)))
-
-                                               for scope in missing:
-                                                       self._outgoing_scopes.add(scope)
-                                                       scope.cancel()
-                                               del self._buckets[tag]
-                                               self._log_current_buckets()
-                                       else:
-                                               await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds())
-                               else:
-                                       try:
-                                               await self._timer_poke_rx.receive()
-                                       except trio.EndOfChannel:
-                                               break
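
`_Message.dedup_tag` is what lets the merger recognize copies of one event across connections. A small illustration with a hypothetical update id (`_Message` is module-private; it appears here only to show the tag shape):

    from strikebot.live_ws import _Message

    payload = {"type": "update", "payload": {"data": {"id": "abc-123"}}}
    a = _Message(0.0, payload, scope=None)
    b = _Message(1.5, dict(payload), scope=None)  # a duplicate from another connection
    assert a.dedup_tag() == b.dedup_tag() == ("update", "abc-123")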
diff --git a/strikebot/src/strikebot/queue.py b/strikebot/src/strikebot/queue.py
deleted file mode 100644 (file)
index acf5ef0..0000000
+++ /dev/null
@@ -1,61 +0,0 @@
-"""Unbounded blocking queues for Trio"""
-
-from collections import deque
-from dataclasses import dataclass
-from functools import total_ordering
-from typing import Any, Iterable
-import heapq
-
-try:
-       from trio.lowlevel import ParkingLot
-except ImportError:
-       from trio.hazmat import ParkingLot
-
-
-class Queue:
-       def __init__(self):
-               self._deque = deque()
-               self._empty_wait = ParkingLot()
-
-       def push(self, el: Any) -> None:
-               self._deque.append(el)
-               self._empty_wait.unpark()
-
-       def extend(self, els: Iterable[Any]) -> None:
-               for el in els:
-                       self.push(el)
-
-       async def pop(self) -> Any:
-               if not self._deque:
-                       await self._empty_wait.park()
-               return self._deque.popleft()
-
-       def __len__(self):
-               return len(self._deque)
-
-
-@total_ordering
-@dataclass
-class _ReverseOrdWrapper:
-       inner: Any
-
-       def __lt__(self, other):
-               return other.inner < self.inner
-
-
-class MaxHeap:
-       def __init__(self):
-               self._heap = []
-               self._empty_wait = ParkingLot()
-
-       def push(self, item):
-               heapq.heappush(self._heap, _ReverseOrdWrapper(item))
-               self._empty_wait.unpark()
-
-       async def pop(self):
-               if not self._heap:
-                       await self._empty_wait.park()
-               return heapq.heappop(self._heap).inner
-
-       def __len__(self) -> int:
-               return len(self._heap)
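
A self-contained check of the ordering: `MaxHeap.pop` returns the largest item first and parks only while the heap is empty.

    import trio
    from strikebot.queue import MaxHeap

    async def demo():
        heap = MaxHeap()
        for n in (3, 1, 2):
            heap.push(n)
        assert [await heap.pop() for _ in range(3)] == [3, 2, 1]

    trio.run(demo)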
diff --git a/strikebot/src/strikebot/reddit_api.py b/strikebot/src/strikebot/reddit_api.py
deleted file mode 100644 (file)
index 2059707..0000000
+++ /dev/null
@@ -1,291 +0,0 @@
-"""Multi-auth HTTP connection pool with Reddit API wrappers and rate limiting."""
-
-from abc import ABCMeta, abstractmethod
-from dataclasses import dataclass, field
-from functools import total_ordering
-from socket import EAI_AGAIN, EAI_FAIL, gaierror
-import datetime as dt
-import logging
-
-from asks.response_objects import Response
-import asks
-import trio
-
-from strikebot import __version__ as VERSION
-from strikebot.common import obj_digest
-from strikebot.queue import MaxHeap, Queue
-
-
-REQUEST_DELAY = dt.timedelta(seconds = 1)
-
-TOKEN_UPDATE_DELAY = dt.timedelta(minutes = 15)
-
-USER_AGENT_FMT = f"any:net.jcornell.strikebot.{{auth_id}}:v{VERSION} (by /u/jaxklax)"
-
-API_BASE_URL = "https://oauth.reddit.com"
-
-
-@total_ordering
-class _Request(metaclass = ABCMeta):
-       _SUBTYPE_PRECEDENCE = None  # assigned later
-
-       def __lt__(self, other):
-               if type(self) is type(other):
-                       return self._subtype_cmp_key() < other._subtype_cmp_key()
-               else:
-                       prec = self._SUBTYPE_PRECEDENCE
-                       return prec.index(type(other)) < prec.index(type(self))
-
-       def __eq__(self, other):
-               if type(self) is type(other):
-                       return self._subtype_cmp_key() == other._subtype_cmp_key()
-               else:
-                       return False
-
-       @abstractmethod
-       def _subtype_cmp_key(self):
-               # a large key corresponds to a high priority
-               raise NotImplementedError()
-
-       @abstractmethod
-       def to_asks_kwargs(self):
-               raise NotImplementedError()
-
-
-@dataclass(eq = False)
-class StrikeRequest(_Request):
-       thread_id: str
-       update_name: str
-       update_ts: dt.datetime
-
-       def to_asks_kwargs(self):
-               return {
-                       "method": "POST",
-                       "path": f"/api/live/{self.thread_id}/strike_update",
-                       "data": {
-                               "api_type": "json",
-                               "id": self.update_name,
-                       },
-               }
-
-       def _subtype_cmp_key(self):
-               return self.update_ts
-
-
-@dataclass(eq = False)
-class DeleteRequest(_Request):
-       thread_id: str
-       update_name: str
-       update_ts: dt.datetime
-
-       def to_asks_kwargs(self):
-               return {
-                       "method": "POST",
-                       "path": f"/api/live/{self.thread_id}/delete_update",
-                       "data": {
-                               "api_type": "json",
-                               "id": self.update_name,
-                       },
-               }
-
-       def _subtype_cmp_key(self):
-               return -self.update_ts.timestamp()
-
-
-@dataclass(eq = False)
-class AboutLiveThreadRequest(_Request):
-       thread_id: str
-       created: float = field(default_factory = trio.current_time)
-
-       def to_asks_kwargs(self):
-               return {
-                       "method": "GET",
-                       "path": f"/live/{self.thread_id}/about",
-                       "params": {
-                               "raw_json": "1",
-                       },
-               }
-
-       def _subtype_cmp_key(self):
-               return -self.created
-
-
-@dataclass(eq = False)
-class _BaseUpdatePostRequest(_Request, metaclass = ABCMeta):
-       thread_id: str
-       body: str
-       created: float = field(default_factory = trio.current_time)
-
-       def to_asks_kwargs(self):
-               return {
-                       "method": "POST",
-                       "path": f"/live/{self.thread_id}/update",
-                       "data": {
-                               "api_type": "json",
-                               "body": self.body,
-                       },
-               }
-
-
-@dataclass(eq = False)
-class ReportUpdateRequest(_BaseUpdatePostRequest):
-       def _subtype_cmp_key(self):
-               return -self.created
-
-
-@dataclass(eq = False)
-class CorrectionUpdateRequest(_BaseUpdatePostRequest):
-       def _subtype_cmp_key(self):
-               return -self.created
-
-
-# highest priority first
-_Request._SUBTYPE_PRECEDENCE = [
-       AboutLiveThreadRequest,
-       StrikeRequest,
-       CorrectionUpdateRequest,
-       ReportUpdateRequest,
-       DeleteRequest,
-]
-
-
-@dataclass
-class AppCooldown:
-       auth_id: int
-       ready_at: float  # Trio time
-
-
-class ApiClientPool:
-       def __init__(
-               self,
-               auth_ids: set[int],
-               db_messenger: "strikebot.db.Messenger",
-               request_queue_limit: int,
-               error_window: dt.timedelta,
-               error_delay: dt.timedelta,
-               logger: logging.Logger,
-       ):
-               self._auth_ids = auth_ids
-               self._db_messenger = db_messenger
-               self._request_queue_limit = request_queue_limit
-               self._error_window = error_window
-               self._error_delay = error_delay
-               self._logger = logger
-
-               now = trio.current_time()
-               self._tokens = {}
-               self._app_queue = Queue()
-               self._waiting = {id_: AppCooldown(id_, now) for id_ in auth_ids}
-               self._request_queue = MaxHeap()
-
-               # pool-wide API error backoff
-               self._last_error = None
-               self._global_resume = None
-
-               self._session = asks.Session(connections = len(auth_ids))
-               self._session.base_location = API_BASE_URL
-
-       async def _update_tokens(self):
-               tokens = await self._db_messenger.do("get_auth_tokens", (self._auth_ids,))
-               self._tokens.update(tokens)
-
-               awaken_auths = self._waiting.keys() & tokens.keys()
-               self._app_queue.extend(self._waiting.pop(auth_id) for auth_id in awaken_auths)
-               self._logger.debug("updated API tokens")
-
-       async def token_updater_impl(self):
-               last_update = None
-               while True:
-                       if last_update is not None:
-                               await trio.sleep_until(last_update + TOKEN_UPDATE_DELAY.total_seconds())
-
-                       if last_update is None:
-                               last_update = trio.current_time()
-                       else:
-                               last_update += TOKEN_UPDATE_DELAY.total_seconds()
-
-                       await self._update_tokens()
-
-       def _check_queue_size(self) -> None:
-               if len(self._request_queue) > self._request_queue_limit:
-                       raise RuntimeError("request queue size exceeded limit")
-
-       async def make_request(self, request: _Request) -> Response:
-               (resp_tx, resp_rx) = trio.open_memory_channel(0)
-               self.enqueue_request(request, resp_tx)
-               async with resp_rx:
-                       return await resp_rx.receive()
-
-       def enqueue_request(self, request: _Request, resp_tx = None) -> None:
-               self._logger.debug("request {}: {}".format(obj_digest(request), type(request).__name__))
-               self._request_queue.push((request, resp_tx))
-               self._check_queue_size()
-               self._logger.debug(f"{len(self._request_queue)} requests in queue")
-
-       async def worker_impl(self, task_status):
-               task_status.started()
-               while True:
-                       (request, resp_tx) = await self._request_queue.pop()
-                       cooldown = await self._app_queue.pop()
-                       await trio.sleep_until(cooldown.ready_at)
-                       if self._global_resume:
-                               await trio.sleep_until(self._global_resume)
-
-                       asks_kwargs = request.to_asks_kwargs()
-                       headers = asks_kwargs.setdefault("headers", {})
-                       headers.update({
-                               "Authorization": "Bearer {}".format(self._tokens[cooldown.auth_id]),
-                               "User-Agent": USER_AGENT_FMT.format(auth_id = cooldown.auth_id),
-                       })
-
-			request_time = trio.current_time()
-			error = False
-			wait_for_token = False  # pre-set: the gaierror path below would otherwise leave these unbound
-			try:
-                               resp = await self._session.request(**asks_kwargs)
-                       except gaierror as e:
-                               if e.errno in [EAI_FAIL, EAI_AGAIN]:
-                                       # DNS failure, probably temporary
-                                       error = True
-                               else:
-                                       raise
-                       else:
-                               resp.body  # read response
-                               log_suffix = " (request {})".format(obj_digest(request))
-                               if resp.status_code == 429:
-                                       # We disagreed about the rate limit state; just try again later.
-                                       self._logger.warning("rate limited by Reddit API" + log_suffix)
-                                       error = True
-                               elif resp.status_code == 401:
-                                       self._logger.warning("got HTTP 401 from Reddit API" + log_suffix)
-                                       error = True
-                                       wait_for_token = True
-                               elif resp.status_code in [404, 500, 503]:
-                                       self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying" + log_suffix)
-                                       error = True
-                               elif 400 <= resp.status_code < 500:
-                                       # If we're doing something wrong, let's catch it right away.
-                                       raise RuntimeError(f"unexpected client error response: {resp.status_code}" + log_suffix)
-                               else:
-                                       if resp.status_code != 200:
-                                               raise RuntimeError(f"unexpected status code {resp.status_code}")
-                                       self._logger.debug("success" + log_suffix)
-                                       if resp_tx:
-                                               await resp_tx.send(resp)
-
-                       if error:
-                               self._request_queue.push((request, resp_tx))
-                               self._check_queue_size()
-                               if self._last_error:
-                                       spread = dt.timedelta(seconds = request_time - self._last_error)
-                                       if spread <= self._error_window:
-                                               self._global_resume = request_time + self._error_delay.total_seconds()
-                               self._last_error = request_time
-
-                       cooldown.ready_at = request_time + REQUEST_DELAY.total_seconds()
-                       if wait_for_token:
-                               self._waiting[cooldown.auth_id] = cooldown
-                               if not self._app_queue:
-                                       self._logger.error("all workers waiting for API tokens")
-                       else:
-                               self._app_queue.push(cooldown)
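
The precedence list gives cross-type ordering while `_subtype_cmp_key` breaks ties within a type. A sketch of the effect, run under Trio because `created` defaults to the Trio clock (the update name here is hypothetical):

    import datetime as dt
    import trio
    from strikebot.reddit_api import AboutLiveThreadRequest, DeleteRequest

    async def demo():
        about = AboutLiveThreadRequest("abc123")
        delete = DeleteRequest("abc123", "LiveUpdate_x", dt.datetime.now(dt.timezone.utc))
        assert delete < about  # thread-info requests outrank deletions in the max-heap

    trio.run(demo)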
diff --git a/strikebot/src/strikebot/tests.py b/strikebot/src/strikebot/tests.py
deleted file mode 100644 (file)
index 966bca6..0000000
+++ /dev/null
@@ -1,62 +0,0 @@
-from unittest import TestCase
-
-from strikebot.updates import parse_update
-
-
-def _build_payload(body_html: str) -> dict[str, str]:
-       return {"body_html": body_html}
-
-
-class UpdateParsingTests(TestCase):
-       def test_successful_counts(self):
-               pu = parse_update(_build_payload("<div>12,345,678 spaghetti</div>"), None, "")
-               self.assertEqual(pu.number, 12345678)
-               self.assertFalse(pu.deletable)
-
-               pu = parse_update(_build_payload("<div><p>0</p><br/>oz</div>"), None, "")
-               self.assertEqual(pu.number, 0)
-               self.assertFalse(pu.deletable)
-
-               pu = parse_update(_build_payload("<div>121 345 621</div>"), 121, "")
-               self.assertEqual(pu.number, 121)
-               self.assertFalse(pu.deletable)
-
-               pu = parse_update(_build_payload("28 336 816"), 28336816, "")
-               self.assertEqual(pu.number, 28336816)
-               self.assertTrue(pu.deletable)
-
-       def test_non_counts(self):
-               pu = parse_update(_build_payload("<div><table><tbody><tr><td>zoo</td></tr></tbody></table></div>"), None, "")
-               self.assertFalse(pu.count_attempt)
-               self.assertFalse(pu.deletable)
-
-       def test_typos(self):
-               pu = parse_update(_build_payload("<span>v9</span>"), 888, "")
-               self.assertIsNone(pu.number)
-               self.assertTrue(pu.count_attempt)
-
-               pu = parse_update(_build_payload("<div>v11.585 Empire</div>"), None, "")
-               self.assertIsNone(pu.number)
-               self.assertTrue(pu.count_attempt)
-               self.assertFalse(pu.deletable)
-
-               pu = parse_update(_build_payload("<div>11, 585, 22 </div>"), 11_585_202, "")
-               self.assertIsNone(pu.number)
-               self.assertTrue(pu.count_attempt)
-               self.assertTrue(pu.deletable)
-
-               pu = parse_update(_build_payload("<span>0490499</span>"), 4999, "")
-               self.assertIsNone(pu.number)
-               self.assertTrue(pu.count_attempt)
-
-               # this update is marked non-deletable, but I'm not sure it should be
-               pu = parse_update(_build_payload("12,123123"), 12_123_123, "")
-               self.assertIsNone(pu.number)
-               self.assertTrue(pu.count_attempt)
-
-       def test_html_handling(self):
-               pu = parse_update(_build_payload("123<hr>,456"), None, "")
-               self.assertEqual(pu.number, 123)
-
-               pu = parse_update(_build_payload("<pre>123\n456</pre>"), None, "")
-               self.assertEqual(pu.number, 123)
diff --git a/strikebot/src/strikebot/updates.py b/strikebot/src/strikebot/updates.py
deleted file mode 100644 (file)
index 01eb67b..0000000
+++ /dev/null
@@ -1,226 +0,0 @@
-from __future__ import annotations
-from dataclasses import dataclass
-from enum import Enum
-from typing import Optional
-import re
-
-from bs4 import BeautifulSoup
-
-
-Command = Enum("Command", ["RESET", "REPORT"])
-
-
-@dataclass
-class ParsedUpdate:
-       number: Optional[int]
-       command: Optional[Command]
-       count_attempt: bool  # either well-formed or typo
-       deletable: bool
-
-
-def _parse_command(line: str, bot_user: str) -> Optional[Command]:
-       if line.lower() == f"/u/{bot_user} reset".lower():
-               return Command.RESET
-       elif line.lower() in ["sidebar count", "current count"]:
-               return Command.REPORT
-       else:
-               return None
-
-
-def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
-       # curr_count is the next number up, one more than the last count
-
-       NEW_LINE = object()
-       SPACE = object()
-
-       # flatten the update content to plain text
-       tree = BeautifulSoup(payload_data["body_html"], "html.parser")
-       worklist = list(reversed(tree.contents))
-       out = [[]]
-       while worklist:
-               el = worklist.pop()
-               if isinstance(el, str):
-                       out[-1].append(el)
-               elif el is SPACE:
-                       out[-1].append(el)
-               elif el is NEW_LINE or el.name in ["br", "hr"]:
-                       if out[-1]:
-                               out.append([])
-               elif el.name in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]:
-                       worklist.extend(reversed(el.contents))
-               elif el.name in ["ul", "ol", "table", "thead", "tbody"]:
-                       worklist.extend(reversed(el.contents))
-               elif el.name in ["li", "p", "div", "blockquote"] or re.match(r"h[1-6]$", el.name):
-                       worklist.append(NEW_LINE)
-                       worklist.extend(reversed(el.contents))
-                       worklist.append(NEW_LINE)
-               elif el.name == "pre":
-                       out.extend([l] for l in el.text.splitlines())
-                       out.append([])
-               elif el.name == "tr":
-                       worklist.append(NEW_LINE)
-                       for (i, cell) in enumerate(reversed(el.contents)):
-                               worklist.append(cell)
-                               if i != len(el.contents) - 1:
-                                       worklist.append(SPACE)
-                       worklist.append(NEW_LINE)
-               else:
-                       raise RuntimeError(f"can't parse tag {el.name}")
-
-       tmp_lines = (
-               "".join(" " if part is SPACE else part for part in parts).strip()
-               for parts in out
-       )
-       pre_strip_lines = list(filter(None, tmp_lines))
-
-       # normalize whitespace according to HTML rendering rules
-       # https://developer.mozilla.org/en-US/docs/Web/API/Document_Object_Model/Whitespace#explanation
-       stripped_lines = [
-               re.sub(" +", " ", l.replace("\t", " ").replace("\n", " ")).strip(" ")
-               for l in pre_strip_lines
-       ]
-
-       return _parse_from_lines(stripped_lines, curr_count, bot_user)
-
-
-def _parse_from_lines(lines: list[str], curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
-       command = next(
-               (cmd for l in lines if (cmd := _parse_command(l, bot_user))),
-               None
-       )
-       if lines:
-               # look for groups of digits (as many as possible) separated by a uniform separator from the valid set
-               first = lines[0]
-               match = re.match(
-                       "(?P<v>v)?(?P<neg>-)?(?P<num>\\d+((?P<sep>[,. \u2009]|, )\\d+((?P=sep)\\d+)*)?)",
-                       first,
-                       re.ASCII,  # only recognize ASCII digits
-               )
-               if match:
-                       raw_digits = match["num"]
-                       sep = match["sep"]
-                       post = first[match.end() :]
-
-                       zeros = False
-                       while len(raw_digits) > 1 and raw_digits[0] == "0":
-                               zeros = True
-                               raw_digits = raw_digits.removeprefix("0").removeprefix(sep or "")
-
-                       parts = raw_digits.split(sep) if sep else [raw_digits]
-                       all_parts_valid = (
-                               sep is None
-                               or (
-                                       1 <= len(parts[0]) <= 3
-                                       and all(len(p) == 3 for p in parts[1:])
-                               )
-                       )
-                       lone = len(lines) == 1 and (not post or post.isspace())
-                       typo = False
-                       if lone:
-                               if match["v"] and len(parts) == 1 and len(parts[0]) <= 2:
-                                       # failed paste of leading digits
-                                       typo = True
-                               elif match["v"] and all_parts_valid:
-                                       # v followed by count
-                                       typo = True
-                               elif curr_count is not None and abs(curr_count) >= 100 and bool(match["neg"]) == (curr_count < 0):
-                                       goal_parts = _separate(str(abs(curr_count)))
-                                       partials = [
-                                               goal_parts[: -1] + [goal_parts[-1][: -1]],  # missing last digit
-                                               goal_parts[: -1] + [goal_parts[-1][: -2]],  # missing last two digits
-                                               goal_parts[: -1] + [goal_parts[-1][: -2] + goal_parts[-1][-1]],  # missing second-last digit
-                                       ]
-                                       if parts in partials:
-                                               # missing any of last two digits
-                                               typo = True
-                                       elif any(parts == p[: -1] + [p[-1] + goal_parts[0]] + goal_parts[1 :] for p in partials):
-                                               # double paste
-                                               typo = True
-
-                       if match["v"] or zeros or typo or (parts == ["0"] and match["neg"]):
-                               number = None
-                               count_attempt = True
-                               deletable = lone
-                       else:
-                               groups_okay = True
-                               if curr_count is not None and sep and sep.isspace():
-                                       # Presume that the intended count consists of as many valid digit groups as necessary to match the
-                                       # number of digits in the expected count, if possible.
-                                       digit_count = len(str(abs(curr_count)))
-                                       use_parts = []
-                                       accum = 0
-                                       for (i, part) in enumerate(parts):
-                                               part_valid = len(part) <= 3 if i == 0 else len(part) == 3
-                                               if part_valid and accum < digit_count:
-                                                       use_parts.append(part)
-                                                       accum += len(part)
-                                               else:
-                                                       break
-
-                                       # could still be a no-separator count with some extra digit groups on the same line
-                                       if not use_parts:
-                                               use_parts = [parts[0]]
-
-                                       lone = lone and len(use_parts) == len(parts)
-                               else:
-                                       # current count is unknown or no separator was used
-                                       groups_okay = all_parts_valid
-                                       use_parts = parts
-
-                               digits = "".join(use_parts)
-                               number = -int(digits) if match["neg"] else int(digits)
-                               special = (
-                                       curr_count is not None
-                                       and abs(number - curr_count) <= 25
-                                       and _is_special_number(number)
-                               )
-                               if groups_okay:
-                                       deletable = lone and not special
-                                       if len(use_parts) == len(parts) and post and not post[0].isspace():
-                                               count_attempt = curr_count is not None and abs(number - curr_count) <= 25
-                                               number = None
-                                       else:
-                                               count_attempt = True
-                               else:
-                                       number = None
-                                       count_attempt = True
-                                       deletable = False
-
-                       if first[0].isdigit() and not count_attempt:
-                               count_attempt = True
-                               deletable = False
-               else:
-                       # no count attempt found
-                       number = None
-                       count_attempt = False
-                       deletable = False
-       else:
-               # no lines in update
-               number = None
-               count_attempt = False
-               deletable = True
-
-       return ParsedUpdate(
-               number = number,
-               command = command,
-               count_attempt = count_attempt,
-               deletable = deletable,
-       )
-
-
-def _separate(digits: str) -> list[str]:
-       mod = len(digits) % 3
-       out = []
-       if mod:
-               out.append(digits[: mod])
-       out.extend(digits[i : i + 3] for i in range(mod, len(digits), 3))
-       return out
-
-
-def _is_special_number(num: int) -> bool:
-       num_str = str(num)
-       return bool(
-               num % 1000 in [0, 1, 333, 999]
-               or (num > 10_000_000 and "".join(reversed(num_str)) == num_str)
-               or re.match(r"(.+)\1+$", num_str)  # repeated sequence
-       )
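
End to end, the parser flattens the update HTML to lines and classifies the first one. A quick demonstration on a well-formed count, where 1_234_568 is the number expected next and the body is hypothetical:

    from strikebot.updates import parse_update

    pu = parse_update({"body_html": "<div>1,234,568</div>"}, 1_234_568, "count_better")
    assert pu.number == 1_234_568
    assert pu.count_attempt and pu.deletable and pu.command is None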
diff --git a/strikebot/strikebot/docs/sample_config.ini b/strikebot/strikebot/docs/sample_config.ini
new file mode 100644 (file)
index 0000000..34dec07
--- /dev/null
@@ -0,0 +1,65 @@
+# see Python `configparser' docs for precise file syntax
+
+[config]
+
+######## basic operating parameters
+
+# space-separated authorization IDs for Reddit API token lookup
+auth IDs = 13 15
+
+bot user = count_better
+
+# if false, don't strike or delete updates or report incorrectly stricken updates (optional, default true)
+enforcing = true
+
+# log messages at or above this severity to standard out; level names and numbers are supported (optional, default
+# WARNING); see https://docs.python.org/3/library/logging.html#levels
+log level = WARNING
+
+# if true, don't modify the live thread at all (implies enforcing = false) (optional, default false)
+read-only = false
+
+thread ID = abc123
+
+
+######## API pool error handling
+
+# (seconds) when backing off after API errors, pause the pool for this much time
+API pool error delay = 5.5
+
+# (seconds) pause the pool when the Reddit API returns two errors within any time span of this length
+API pool error window = 5.5
+
+# terminate if the Reddit API request queue exceeds this size
+request queue limit = 100
+
+
+######## live thread update handling
+
+# (seconds) maximum time to hold updates for reordering
+reorder buffer time = 0.25
+
+# (seconds) minimum time to retain updates to enable future resyncs
+update retention time = 120.0
+
+
+######## WebSocket pool
+
+# (seconds) maximum allowable spread in WebSocket message arrival times
+WS parity time = 1.0
+
+# goal size of WebSocket connection pool
+WS pool size = 5
+
+# (seconds) after receiving no messages for this much time, attempt to replace the connection if the pool isn't depleted
+WS silent limit = 30
+
+# (seconds) time after WebSocket handshake completion during which missed updates are excused
+WS warmup time = 0.3
+
+
+# Postgres database configuration; same options as in a connect string. Note that because Python's `SSLContext' is used
+# to implement TLS client auth, using `sslkey' to specify a key obtained from an OpenSSL engine may not be supported (a
+# file path is the only option).
+[db connect params]
+host = example.org
diff --git a/strikebot/strikebot/pyproject.toml b/strikebot/strikebot/pyproject.toml
new file mode 100644 (file)
index 0000000..8fe2f47
--- /dev/null
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools>=42", "wheel"]
+build-backend = "setuptools.build_meta"
diff --git a/strikebot/strikebot/setup.cfg b/strikebot/strikebot/setup.cfg
new file mode 100644 (file)
index 0000000..cdd2e66
--- /dev/null
@@ -0,0 +1,15 @@
+[metadata]
+name = strikebot
+version = 0.0.7
+
+[options]
+package_dir =
+       = src
+packages = strikebot
+python_requires = ~= 3.9
+install_requires =
+       asks ~= 2.4
+       beautifulsoup4 ~= 4.9
+       trio == 0.19
+       trio-websocket == 0.9.2
+       triopg == 0.6.0
diff --git a/strikebot/strikebot/setup.py b/strikebot/strikebot/setup.py
new file mode 100644 (file)
index 0000000..056ba45
--- /dev/null
@@ -0,0 +1,4 @@
+import setuptools
+
+
+setuptools.setup()
diff --git a/strikebot/strikebot/src/strikebot/__init__.py b/strikebot/strikebot/src/strikebot/__init__.py
new file mode 100644 (file)
index 0000000..daa9944
--- /dev/null
@@ -0,0 +1,317 @@
+"""Count tracking logic and bot's external interface."""
+
+from contextlib import nullcontext as nullcontext
+from dataclasses import dataclass
+from functools import total_ordering
+from itertools import islice
+from typing import Optional
+from uuid import UUID
+import bisect
+import datetime as dt
+import importlib.metadata
+import itertools
+import logging
+
+from trio import CancelScope, EndOfChannel
+import trio
+
+from strikebot.updates import Command, parse_update
+
+
+__version__ = importlib.metadata.version(__package__)
+
+
+@dataclass
+class _Update:
+       id: str
+       name: str
+       author: str
+       number: Optional[int]
+       command: Optional[Command]
+       ts: dt.datetime
+       count_attempt: bool
+       deletable: bool
+       stricken: bool
+
+       def can_follow(self, prior: "_Update") -> bool:
+               """Determine whether this count update can follow another count update."""
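+		# e.g. 1000 by "alice" can follow 999 by "bob", but never two counts in a row by the same author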
+               return self.number == prior.number + 1 and self.author != prior.author
+
+       def __str__(self):
+               content = str(self.number)
+               if self.command:
+                       content += "+" + self.command.name
+               return "{}({} by {})".format(self.id[4 : 8], content, self.author)
+
+
+@total_ordering
+@dataclass
+class _BufferedUpdate:
+       update: _Update
+       release_at: float  # Trio time
+
+       def __lt__(self, other):
+               return self.update.ts < other.update.ts
+
+
+@total_ordering
+@dataclass
+class _TimelineUpdate:
+       update: _Update
+       accepted: bool
+
+       def __lt__(self, other):
+               return self.update.ts < other.update.ts
+
+       def rejected(self) -> bool:
+               """Whether the update should be stricken."""
+               return self.update.count_attempt and not self.accepted
+
+
+def _parse_rfc_4122_ts(ts: int) -> dt.datetime:
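+	# `ts` counts 100 ns intervals since the UUID epoch (1582-10-15), as in the
+	# timestamp field of the version 1 UUIDs that identify live updates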
+       epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc)
+       return epoch + dt.timedelta(microseconds = ts / 10)
+
+
+def _format_update_ref(update: _Update, thread_id: str) -> str:
+       url = f"https://www.reddit.com/live/{thread_id}/updates/{update.id}"
+       author_md = update.author.replace("_", "\\_")  # Reddit allows letters, digits, underscore, and hyphen
+       return f"[{update.number:,} by {author_md}]({url})"
+
+
+def _format_bad_strike_alert(update: _Update, thread_id: str) -> str:
+       return "_incorrectly stricken:_ " + _format_update_ref(update, thread_id)
+
+
+def _format_curr_count(last_valid: Optional[_Update]) -> str:
+       if last_valid and last_valid.number is not None:
+               count_str = f"{last_valid.number + 1:,}"
+       else:
+               count_str = "unknown"
+       return f"_current count:_ {count_str}"
+
+
+async def count_tracker_impl(
+       message_rx,
+       api_pool: "strikebot.reddit_api.ApiClientPool",
+       reorder_buffer_time: dt.timedelta,  # duration to hold some updates for reordering
+       thread_id: str,
+       bot_user: str,
+       enforcing: bool,
+       read_only: bool,
+       update_retention: dt.timedelta,
+       logger: logging.Logger,
+) -> None:
+       from strikebot.reddit_api import CorrectionUpdateRequest, DeleteRequest, ReportUpdateRequest, StrikeRequest
+
+       enforcing = enforcing and not read_only
+
+       buffer_ = []
+       timeline = []
+       last_valid: Optional[_Update] = None
+       pending_strikes: set[str] = set()  # names of updates to mark stricken on arrival
+       forgotten: bool = False  # whether the retention period for an update has passed during the run
+
+       def handle_update(update):
+               nonlocal last_valid
+
+               tu = _TimelineUpdate(update, accepted = False)
+
+               pos = bisect.bisect(timeline, tu)
+               if pos != len(timeline):
+                       logger.warning(f"long transpo: {update.id}")
+
+               if pos > 0 and timeline[pos - 1].update.id == update.id:
+                       # The pool sent this update message multiple times.  This could be because the message reached parity and
+                       # its bucket was dropped just before a new connection was opened and sent its own copy, but connections also
+                       # seem to send duplicate messages.
+                       return
+
+               pred = next(
+                       (timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted),
+                       None
+               )
+               contender = update.command is Command.RESET or update.number is not None
+               if contender and pred is None and last_valid and forgotten:
+                       # This is a really long transpo which we can't judge valid or invalid since we've forgotten the surrounding
+                       # updates. The best we can do is just ignore it.
+                       logger.warning(f"ignoring {update.id}: no valid prior count on record")
+               else:
+                       timeline.insert(pos, tu)
+                       tu.accepted = (
+                               update.command is Command.RESET
+                               or (
+                                       update.number is not None
+                                       and (pred is None or pred.number is None or update.can_follow(pred))
+                                       and not update.stricken
+                               )
+                       )
+                       logger.debug("{} {} -> {}".format("accepted" if tu.accepted else "rejected", pred, update))
+                       if tu.accepted:
+                               # resync subsequent updates
+                               newly_invalid = []
+                               resync_last_valid = update
+                               converged = False
+                               for scan_tu in islice(timeline, pos + 1, None):
+                                       if scan_tu.update.command is Command.RESET:
+                                               resync_last_valid = scan_tu.update
+                                               if last_valid:
+                                                       converged = True
+                                       elif scan_tu.update.number is not None:
+                                               accept = (
+                                                       (resync_last_valid.number is None or scan_tu.update.can_follow(resync_last_valid))
+                                                       and not scan_tu.update.stricken
+                                               )
+                                               if accept:
+                                                       assert scan_tu.accepted
+                                                       resync_last_valid = scan_tu.update
+                                                       # resync would have no effect past this point
+                                                       if last_valid:
+                                                               converged = True
+                                               elif scan_tu.accepted:
+                                                       newly_invalid.append(scan_tu)
+                                                       if scan_tu.update is last_valid:
+                                                               last_valid = None
+                                               scan_tu.accepted = accept
+                                       if converged:
+                                               break
+
+                               if converged:
+                                       assert last_valid.ts >= resync_last_valid.ts
+                               else:
+                                       last_valid = resync_last_valid
+
+                               parts = []
+                               unstrikable = [tu for tu in newly_invalid if tu.update.stricken]
+                               if unstrikable:
+                                       parts.append(
+                                               "The following counts are invalid:\n\n"
+                                               + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in unstrikable)
+                                       )
+
+                               if update.stricken:
+                                       logger.info(f"bad strike of {update.id}")
+                                       parts.append(_format_bad_strike_alert(update, thread_id))
+
+                               if parts:
+                                       parts.append(_format_curr_count(last_valid))
+                                       if not read_only:
+                                               api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts)))
+
+                               for invalid_tu in newly_invalid:
+                                       if not invalid_tu.update.stricken:
+                                               if enforcing:
+                                                       api_pool.enqueue_request(StrikeRequest(thread_id, invalid_tu.update.name, invalid_tu.update.ts))
+                                               invalid_tu.update.stricken = True
+                       elif update.count_attempt and not update.stricken:
+                               if enforcing:
+                                       api_pool.enqueue_request(StrikeRequest(thread_id, update.name, update.ts))
+                               update.stricken = True
+
+               if not read_only and update.command is Command.REPORT:
+                       api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
+
+       async with message_rx:
+               while True:
+                       if buffer_:
+                               deadline = min(bu.release_at for bu in buffer_)
+                               cancel_scope = CancelScope(deadline = deadline)
+                       else:
+                               cancel_scope = nullcontext()
+
+                       msg = None
+                       with cancel_scope:
+                               try:
+                                       msg = await message_rx.receive()
+                               except EndOfChannel:
+                                       break
+
+                       if msg:
+                               if msg.data["type"] == "update":
+                                       payload_data = msg.data["payload"]["data"]
+                                       stricken = payload_data["name"] in pending_strikes
+                                       pending_strikes.discard(payload_data["name"])
+                                       if payload_data["author"] != bot_user:
+                                               if last_valid and last_valid.number is not None:
+                                                       next_up = last_valid.number + 1
+                                               else:
+                                                       next_up = None
+                                               pu = parse_update(payload_data, next_up, bot_user)
+                                               rfc_ts = UUID(payload_data["id"]).time
+                                               update = _Update(
+                                                       id = payload_data["id"],
+                                                       name = payload_data["name"],
+                                                       author = payload_data["author"],
+                                                       number = pu.number,
+                                                       command = pu.command,
+                                                       ts = _parse_rfc_4122_ts(rfc_ts),
+                                                       count_attempt = pu.count_attempt,
+                                                       deletable = pu.deletable,
+                                                       stricken = stricken or payload_data["stricken"],
+                                               )
+                                               release_at = msg.timestamp + reorder_buffer_time.total_seconds()
+                                               bisect.insort(buffer_, _BufferedUpdate(update, release_at))
+                               elif msg.data["type"] == "strike":
+                                       slot = next(
+                                               (
+                                                       slot for slot in itertools.chain(buffer_, reversed(timeline))
+                                                       if slot.update.name == msg.data["payload"]
+                                               ),
+                                               None
+                                       )
+                                       if slot:
+                                               if not slot.update.stricken:
+                                                       slot.update.stricken = True
+                                                       if isinstance(slot, _TimelineUpdate) and slot.accepted:
+                                                               logger.info(f"bad strike of {slot.update.id}")
+                                                               if not read_only:
+                                                                       body = _format_bad_strike_alert(slot.update, thread_id)
+                                                                       api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body))
+                                       else:
+                                               pending_strikes.add(msg.data["payload"])
+
+                       threshold = dt.datetime.now(dt.timezone.utc) - update_retention
+
+                       # pull updates from the reorder buffer and process
+                       new_buffer = []
+                       for bu in buffer_:
+				process = (
+					# not a count
+					bu.update.number is None
+
+					# valid count or reset
+					or (last_valid is None or last_valid.number is None or bu.update.can_follow(last_valid))
+					or bu.update.command is Command.RESET
+
+					# held long enough for reordering
+					or trio.current_time() >= bu.release_at
+				)
+
+                               if process:
+                                       if bu.update.ts < threshold:
+                                               logger.warning(f"ignoring {bu.update}: arrived past retention window")
+                                       else:
+                                               logger.debug(f"processing {bu.update}")
+                                               handle_update(bu.update)
+                               else:
+					logger.debug(f"holding {bu.update}, checked against {last_valid}")
+                                       new_buffer.append(bu)
+                       buffer_ = new_buffer
+                       logger.debug("last count {}".format(last_valid))
+
+			# delete/forget old updates, always retaining the 10 most recent timeline entries
+                       new_timeline = []
+                       i = 0
+                       for (i, tu) in enumerate(timeline):
+                               if i >= len(timeline) - 10 or tu.update.ts >= threshold:
+                                       break
+                               elif tu.rejected() and tu.update.deletable:
+                                       if enforcing:
+                                               api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts))
+                               elif tu.update is last_valid:
+                                       new_timeline.append(tu)
+
+                       if i:
+                               forgotten = True
+                       new_timeline.extend(islice(timeline, i, None))
+                       timeline = new_timeline
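+
+
+# Worked example of the acceptance rules above (hedged; numbers illustrative).
+# Suppose the timeline holds accepted counts 41 and 42, and a transposed update
+# arrives whose UUID timestamp bisects between them:
+#
+#	- if it carries 44, its pred is 41, can_follow fails, and it is rejected;
+#	- if it carries 42, it is accepted, and the resync pass re-evaluates the
+#	  old 42 against it, rejecting it and enqueueing a StrikeRequest when
+#	  enforcing.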
diff --git a/strikebot/strikebot/src/strikebot/__main__.py b/strikebot/strikebot/src/strikebot/__main__.py
new file mode 100644 (file)
index 0000000..fc80e23
--- /dev/null
@@ -0,0 +1,224 @@
+from enum import Enum
+from inspect import getmodule
+from logging import FileHandler, getLogger, StreamHandler
+from pathlib import Path
+from signal import SIGUSR1
+from sys import stdout
+from traceback import StackSummary
+from typing import Optional
+import argparse
+import configparser
+import datetime as dt
+import logging
+import ssl
+
+from trio import open_memory_channel, open_nursery, open_signal_receiver
+
+try:
+       from trio.lowlevel import current_root_task, Task
+except ImportError:
+       from trio.hazmat import current_root_task, Task
+
+import trio_asyncio
+import triopg
+
+from strikebot import count_tracker_impl
+from strikebot.db import Client, Messenger
+from strikebot.live_ws import HealingReadPool, PoolMerger
+from strikebot.reddit_api import ApiClientPool
+
+
+_DEBUG_LOG_PATH: Optional[Path] = None  # file to write debug logs to, if any
+
+
+_TaskKind = Enum(
+       "_TaskKind",
+       [
+               "API_WORKER",
+               "DB_CLIENT",
+               "TOKEN_UPDATE",
+               "TRACKER",
+               "WS_MERGE_READER",
+               "WS_MERGE_TIMER",
+               "WS_REFRESHER",
+               "WS_WORKER",
+       ]
+)
+
+
+def _classify_task(task: Task) -> _TaskKind:
+       mod_name = getmodule(task.coro.cr_code).__name__
+       if mod_name.startswith("trio_asyncio."):
+               return None
+       elif task.coro.__name__ == "_signal_handler_impl":
+               return None
+       else:
+               by_coro_name = {
+                       "db_client_impl": _TaskKind.DB_CLIENT,
+                       "token_updater_impl": _TaskKind.TOKEN_UPDATE,
+                       "event_reader_impl": _TaskKind.WS_MERGE_READER,
+                       "timeout_handler_impl": _TaskKind.WS_MERGE_TIMER,
+                       "conn_refresher_impl": _TaskKind.WS_REFRESHER,
+                       "count_tracker_impl": _TaskKind.TRACKER,
+                       "_reader_impl": _TaskKind.WS_WORKER,
+                       "worker_impl": _TaskKind.API_WORKER,
+               }
+               return by_coro_name[task.coro.__name__]
+
+
+def _get_all_tasks():
+       [ta_loop] = [
+               task
+               for nursery in current_root_task().child_nurseries
+               for task in nursery.child_tasks
+               if getmodule(task.coro.cr_code).__name__.startswith("trio_asyncio.")
+       ]
+       for nursery in ta_loop.child_nurseries:
+               yield from nursery.child_tasks
+
+
+def _print_task_info():
+       groups = {kind: [] for kind in _TaskKind}
+       for task in _get_all_tasks():
+               kind = _classify_task(task)
+               if kind:
+                       coro = task.coro
+                       frame_tuples = []
+                       while coro is not None:
+                               if hasattr(coro, "cr_frame"):
+                                       frame_tuples.append((coro.cr_frame, coro.cr_frame.f_lineno))
+                                       coro = coro.cr_await
+                               else:
+                                       frame_tuples.append((coro.gi_frame, coro.gi_frame.f_lineno))
+                                       coro = coro.gi_yieldfrom
+                       groups[kind].append(StackSummary.extract(iter(frame_tuples)))
+       for kind in _TaskKind:
+               print(kind.name)
+               for ss in groups[kind]:
+                       print("  task")
+                       for line in ss.format():
+                               print("    " + line, end = "")
+
+
+async def _signal_handler_impl():
+       with open_signal_receiver(SIGUSR1) as sig_src:
+               async for _ in sig_src:
+                       _print_task_info()
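+
+
+# Hedged operational note: with the handler above installed, a running bot can
+# be asked for a task-stack dump from a shell (process lookup illustrative):
+#
+#	kill -USR1 "$(pgrep -f strikebot)"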
+
+
+ap = argparse.ArgumentParser(__package__)
+ap.add_argument("config_path")
+args = ap.parse_args()
+
+
+parser = configparser.ConfigParser()
+with open(args.config_path) as config_file:
+       parser.read_file(config_file)
+
+main_cfg = parser["config"]
+
+api_pool_error_delay = dt.timedelta(seconds = main_cfg.getfloat("API pool error delay"))
+api_pool_error_window = dt.timedelta(seconds = main_cfg.getfloat("API pool error window"))
+auth_ids = set(map(int, main_cfg["auth IDs"].split()))
+bot_user = main_cfg["bot user"]
+enforcing = main_cfg.getboolean("enforcing", fallback = True)
+
+raw_log_level = main_cfg.get("log level", "WARNING")
+try:
+       log_level = int(raw_log_level)
+except ValueError:
+       log_level = raw_log_level
+
+read_only = main_cfg.getboolean("read-only", fallback = False)
+reorder_buffer_time = dt.timedelta(seconds = main_cfg.getfloat("reorder buffer time"))
+request_queue_limit = main_cfg.getint("request queue limit")
+thread_id = main_cfg["thread ID"]
+update_retention = dt.timedelta(seconds = main_cfg.getfloat("update retention time"))
+ws_parity_time = dt.timedelta(seconds = main_cfg.getfloat("WS parity time"))
+ws_pool_size = main_cfg.getint("WS pool size")
+ws_silent_limit = dt.timedelta(seconds = main_cfg.getfloat("WS silent limit"))
+ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup time"))
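+
+# For reference, a hedged sketch of the [config] section implied by the reads
+# above (every value here is illustrative, not a recommendation):
+#
+#	[config]
+#	API pool error delay = 5.0
+#	API pool error window = 60.0
+#	auth IDs = 101 102
+#	bot user = example_bot
+#	enforcing = true
+#	log level = INFO
+#	read-only = false
+#	reorder buffer time = 0.25
+#	request queue limit = 500
+#	thread ID = abc123xyz
+#	update retention time = 300.0
+#	WS parity time = 1.0
+#	WS pool size = 4
+#	WS silent limit = 60.0
+#	WS warmup time = 5.0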
+
+
+db_cfg = parser["db connect params"]
+cert_path = db_cfg.pop("sslcert", None)
+key_path = db_cfg.pop("sslkey", None)
+key_pass = db_cfg.pop("sslpassword", None)
+
+getters = {
+       "port": db_cfg.getint,
+}
+db_connect_kwargs = {k: getters.get(k, db_cfg.get)(k) for k in db_cfg}
+
+if cert_path:
+       # patch TLS client cert auth support; asyncpg adds it in a later version
+       db_ssl_ctx = ssl.create_default_context()
+       db_ssl_ctx.load_cert_chain(
+               cert_path,
+               key_path,
+               None if key_pass is None else lambda: key_pass,
+       )
+       db_connect_kwargs["ssl"] = db_ssl_ctx
+
+
+if read_only and enforcing:
+       raise RuntimeError("can't use read-only mode with enforcing on")
+
+
+logger = getLogger(__package__)
+logger.setLevel(logging.NOTSET - 1)  # NOTSET means inherit from parent; we use handlers to filter
+
+handler = StreamHandler(stdout)
+handler.setLevel(log_level)
+handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{"))
+logger.addHandler(handler)
+
+if _DEBUG_LOG_PATH:  # TODO remove this ad hoc setup
+       debug_handler = FileHandler(_DEBUG_LOG_PATH, mode = "w")
+       debug_handler.setLevel(logging.DEBUG)
+       debug_handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{"))
+       logger.addHandler(debug_handler)
+
+
+async def main():
+       async with (
+               triopg.connect(**db_connect_kwargs) as db_conn,
+               open_nursery() as nursery_a,
+               open_nursery() as nursery_b,
+               open_nursery() as ws_pool_nursery,
+               open_nursery() as nursery_c,
+       ):
+               nursery_a.start_soon(_signal_handler_impl)
+
+               client = Client(db_conn)
+               db_messenger = Messenger(client)
+               nursery_a.start_soon(db_messenger.db_client_impl)
+
+               api_pool = ApiClientPool(
+                       auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay,
+                       logger.getChild("api")
+               )
+               nursery_a.start_soon(api_pool.token_updater_impl)
+               for _ in auth_ids:
+                       await nursery_a.start(api_pool.worker_impl)
+
+               (message_tx, message_rx) = open_memory_channel(0)
+               (pool_event_tx, pool_event_rx) = open_memory_channel(0)
+
+               merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger.getChild("merge"))
+               nursery_b.start_soon(merger.event_reader_impl)
+               nursery_b.start_soon(merger.timeout_handler_impl)
+
+               ws_pool = HealingReadPool(
+                       ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger.getChild("live_ws"),
+                       ws_silent_limit
+               )
+               nursery_c.start_soon(ws_pool.conn_refresher_impl)
+
+               nursery_c.start_soon(
+                       count_tracker_impl, message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing, read_only,
+                       update_retention, logger.getChild("track")
+               )
+               await ws_pool.init_workers()
+
+trio_asyncio.run(main)
diff --git a/strikebot/strikebot/src/strikebot/common.py b/strikebot/strikebot/src/strikebot/common.py
new file mode 100644 (file)
index 0000000..519037a
--- /dev/null
@@ -0,0 +1,10 @@
+from typing import Any
+
+
+def int_digest(val: int) -> str:
+       return "{:04x}".format(val & 0xffff)
+
+
+def obj_digest(obj: Any) -> str:
+       """Makes a short digest of the identity of an object."""
+       return int_digest(id(obj))
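+
+
+# Illustrative usage (a hedged sketch, not part of the original module): the
+# digests make compact, stable identifiers for log lines.
+if __name__ == "__main__":
+	assert int_digest(0x12345) == "2345"  # low 16 bits as zero-padded hex
+	x = object()
+	assert obj_digest(x) == obj_digest(x)  # stable while x is alive
+	print(obj_digest(x))  # e.g. "9f10"; value depends on id(x)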
diff --git a/strikebot/strikebot/src/strikebot/db.py b/strikebot/strikebot/src/strikebot/db.py
new file mode 100644 (file)
index 0000000..d1a2f42
--- /dev/null
@@ -0,0 +1,58 @@
+"""Single-connection Postgres client with high-level wrappers for operations."""
+
+from dataclasses import dataclass
+from typing import Any, Iterable
+
+import trio
+
+
+def _channel_sender(method):
+       async def wrapped(self, resp_channel, *args, **kwargs):
+               ret = await method(self, *args, **kwargs)
+               async with resp_channel:
+                       await resp_channel.send(ret)
+
+       return wrapped
+
+
+@dataclass
+class Client:
+       """High-level wrappers for DB operations."""
+       conn: Any
+
+       @_channel_sender
+       async def get_auth_tokens(self, auth_ids: set[int]) -> dict[int, str]:
+               results = await self.conn.fetch(
+                       """
+                               select
+                                       distinct on (id)
+                                       id, access_token
+                               from
+                                       public.reddit_app_authorization
+                                       join unnest($1::integer[]) as request_id on id = request_id
+                               where expires > current_timestamp
+                       """,
+                       list(auth_ids),
+               )
+               return {r["id"]: r["access_token"] for r in results}
+
+
+class Messenger:
+       def __init__(self, client: Client):
+               self._client = client
+               (self._request_tx, self._request_rx) = trio.open_memory_channel(0)
+
+       async def do(self, method_name: str, args: Iterable) -> Any:
+               """This is run by consumers of the DB wrapper."""
+               method = getattr(self._client, method_name)
+               (resp_tx, resp_rx) = trio.open_memory_channel(0)
+               coro = method(resp_tx, *args)
+               await self._request_tx.send(coro)
+               async with resp_rx:
+                       return await resp_rx.receive()
+
+       async def db_client_impl(self):
+               """This is the DB client Trio task."""
+               async with self._client.conn, self._request_rx:
+                       async for coro in self._request_rx:
+                               await coro
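+
+
+# Hedged usage sketch (comments only; assumes a live Postgres connection and a
+# running nursery, wired up as in __main__): consumers pass the Client method
+# name and its arguments to Messenger.do, and db_client_impl serializes the
+# resulting coroutines onto the single connection.
+#
+#	client = Client(db_conn)
+#	messenger = Messenger(client)
+#	nursery.start_soon(messenger.db_client_impl)
+#	tokens = await messenger.do("get_auth_tokens", ({101, 102},))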
diff --git a/strikebot/strikebot/src/strikebot/live_ws.py b/strikebot/strikebot/src/strikebot/live_ws.py
new file mode 100644 (file)
index 0000000..031d652
--- /dev/null
@@ -0,0 +1,266 @@
+from collections import deque
+from contextlib import suppress
+from dataclasses import dataclass
+from typing import Any, AsyncContextManager, Optional
+import datetime as dt
+import json
+import logging
+import math
+
+from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
+import trio
+
+from strikebot.common import int_digest, obj_digest
+from strikebot.reddit_api import AboutLiveThreadRequest
+
+
+@dataclass
+class _PoolEvent:
+       timestamp: float  # Trio time
+
+
+@dataclass
+class _Message(_PoolEvent):
+       data: Any
+       scope: Any
+
+       def dedup_tag(self):
+               """A hashable essence of the contents of this message."""
+               if self.data["type"] == "update":
+                       return ("update", self.data["payload"]["data"]["id"])
+               elif self.data["type"] in {"strike", "delete"}:
+                       return (self.data["type"], self.data["payload"])
+               else:
+                       raise ValueError(self.data["type"])
+
+
+@dataclass
+class _ConnectionDown(_PoolEvent):
+       scope: trio.CancelScope
+
+
+@dataclass
+class _ConnectionUp(_PoolEvent):
+       scope: trio.CancelScope
+
+
+class HealingReadPool:
+       def __init__(
+               self, pool_nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger,
+               silent_limit: dt.timedelta
+       ):
+               assert size >= 2
+               self._nursery = pool_nursery
+               self._size = size
+               self._live_thread_id = live_thread_id
+               self._api_client_pool = api_client_pool
+               self._pool_event_tx = pool_event_tx
+               self._logger = logger
+               self._silent_limit = silent_limit
+
+               (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size)
+
+               # number of workers who have stopped receiving updates but whose tasks aren't yet stopped
+               self._closing = 0
+
+       async def init_workers(self):
+               for _ in range(self._size):
+                       await self._spawn_reader()
+               if not self._nursery.child_tasks:
+                       raise RuntimeError("Unable to create any WS connections")
+
+       async def _reader_impl(self, conn_ctx: AsyncContextManager, cancel_scope, refresh_tx):
+               # TODO could use task's implicit cancel scope
+               try:
+                       async with conn_ctx as conn:
+                               self._logger.debug("scope up: {}".format(obj_digest(cancel_scope)))
+                               async with refresh_tx:
+                                       silent_timeout = False
+                                       with cancel_scope, suppress(ConnectionClosed):
+                                               while True:
+                                                       with trio.move_on_after(self._silent_limit.total_seconds()) as timeout_scope:
+                                                               message = await conn.get_message()
+
+                                                       if timeout_scope.cancelled_caught:
+                                                               if len(self._nursery.child_tasks) - self._closing == self._size:
+                                                                       silent_timeout = True
+                                                                       break
+                                                               else:
+                                                                       self._logger.debug("not replacing connection {}; {} tasks, {} closing".format(
+                                                                               obj_digest(cancel_scope),
+                                                                               len(self._nursery.child_tasks),
+                                                                               self._closing,
+                                                                       ))
+                                                       else:
+                                                               event = _Message(trio.current_time(), json.loads(message), cancel_scope)
+                                                               await self._pool_event_tx.send(event)
+
+                                       self._closing += 1
+                                       if silent_timeout:
+                                               await conn.aclose()
+                                               self._logger.debug("replacing WS connection {} due to silent timeout".format(obj_digest(cancel_scope)))
+                                       elif cancel_scope.cancelled_caught:
+                                               await conn.aclose(1008, "Server unexpectedly stopped sending messages")
+                                               self._logger.warning("replacing WS connection {} due to missed update".format(obj_digest(cancel_scope)))
+                                       else:
+                                               self._logger.warning("replacing WS connection {} closed by server".format(obj_digest(cancel_scope)))
+
+                                       refresh_tx.send_nowait(cancel_scope)
+                                       self._logger.debug("scope down: {} ({} tasks, {} closing)".format(
+                                               obj_digest(cancel_scope),
+                                               len(self._nursery.child_tasks),
+                                               self._closing,
+                                       ))
+                                       self._closing -= 1
+               except HandshakeError:
+                       self._logger.error("handshake error while opening WS connection")
+
+       async def _new_connection(self) -> Optional[AsyncContextManager[WebSocketConnection]]:
+               request = AboutLiveThreadRequest(self._live_thread_id)
+               resp = await self._api_client_pool.make_request(request)
+               if resp.status_code == 200:
+                       url = json.loads(resp.text)["data"]["websocket_url"]
+                       return open_websocket_url(url)
+               else:
+                       self._logger.error(f"thread info request failed (HTTP {resp.status_code}); can't open WS connection")
+                       return None
+
+       async def _spawn_reader(self) -> None:
+               conn = await self._new_connection()
+               if conn:
+                       new_scope = trio.CancelScope()
+                       await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope))
+                       refresh_tx = self._refresh_queue_tx.clone()
+                       self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx)
+
+       async def conn_refresher_impl(self):
+               """Task to monitor and replace WS connections as they disconnect or go silent."""
+               async with self._refresh_queue_rx:
+                       async for old_scope in self._refresh_queue_rx:
+                               await self._pool_event_tx.send(_ConnectionDown(trio.current_time(), old_scope))
+                               await self._spawn_reader()
+                               if not self._nursery.child_tasks:
+                                       raise RuntimeError("WS pool depleted")
+
+
+def _tag_digest(tag):
+       return int_digest(hash(tag))
+
+
+def _format_scope_list(scopes):
+       return ", ".join(sorted(map(obj_digest, scopes)))
+
+
+class PoolMerger:
+       @dataclass
+       class _Bucket:
+               start: float
+               recipients: set
+
+       def __init__(self, pool_event_rx, message_tx, parity_timeout: dt.timedelta, conn_warmup: dt.timedelta, logger: logging.Logger):
+               """
+               :param parity_timeout: max delay between message receipt on different connections
+               :param conn_warmup: max interval after connection within which missed updates are allowed
+               """
+               self._pool_event_rx = pool_event_rx
+               self._message_tx = message_tx
+               self._parity_timeout = parity_timeout
+               self._conn_warmup = conn_warmup
+               self._logger = logger
+
+               self._buckets = {}
+               self._scope_activations = {}
+               self._pending = deque()
+               self._outgoing_scopes = set()
+               (self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf)
+
+       def _log_current_buckets(self):
+               self._logger.debug("current buckets: {}".format(", ".join(map(_tag_digest, self._buckets.keys()))))
+
+       async def event_reader_impl(self):
+               """Drop unused messages, deduplicate useful ones, and communicate with the timeout handler."""
+               async with self._pool_event_rx, self._timer_poke_tx:
+                       async for event in self._pool_event_rx:
+                               if isinstance(event, _ConnectionUp):
+					# Adding a scope to the active set too early could make it expected on a message that
+					# fired before the connection opened, producing a false positive for replacement.  The
+					# timer task merges these events in among the buckets by timestamp to avoid that.
+                                       self._pending.append(event)
+                               elif isinstance(event, _Message):
+                                       if event.data["type"] in ["update", "strike", "delete"]:
+                                               tag = event.dedup_tag()
+                                               self._logger.debug("recv {} from {}".format(_tag_digest(tag), obj_digest(event.scope)))
+                                               if tag in self._buckets:
+                                                       b = self._buckets[tag]
+                                                       b.recipients.add(event.scope)
+							# If this scope is the last recipient for this bucket, we could clear the bucket
+							# now.  But connections sometimes receive repeat messages, and a second copy can
+							# arrive before the first has reached every connection; leaving the bucket open
+							# to absorb repeats makes it less likely that a late second bucket is allocated
+							# for the repeat, forcing an unnecessary connection replacement.
+                                               elif event.scope not in self._outgoing_scopes:
+                                                       sane = (
+                                                               event.scope in self._scope_activations
+                                                               or any(e.scope == event.scope for e in self._pending)
+                                                       )
+                                                       if sane:
+                                                               self._logger.debug("new bucket {}: {}".format(_tag_digest(tag), tag))
+                                                               self._buckets[tag] = self._Bucket(event.timestamp, {event.scope})
+                                                               self._log_current_buckets()
+                                                               await self._message_tx.send(event)
+                                                       else:
+								raise RuntimeError("received message from unrecognized WS connection")
+                                       else:
+                                               self._logger.debug("recv type {!r} from {} (discarding)".format(event.data["type"], obj_digest(event.scope)))
+                               elif isinstance(event, _ConnectionDown):
+                                       # We don't need to worry about canceling this scope at all, so no need to require it for parity for
+                                       # any message, even older ones.  The scope may be gone already, if we canceled it previously.
+                                       self._scope_activations.pop(event.scope, None)
+                                       self._pending = deque(ev for ev in self._pending if ev.scope is not event.scope)
+                                       self._outgoing_scopes.discard(event.scope)
+                               else:
+                                       raise TypeError(f"Expected pool event, found {event!r}")
+                               self._timer_poke_tx.send_nowait(None)  # may have new work for the timer
+
+       async def timeout_handler_impl(self):
+               """When connections have had enough time to reach parity on a message, signal replacement of any slackers."""
+               async with self._timer_poke_rx:
+                       while True:
+                               if self._buckets:
+                                       now = trio.current_time()
+                                       (tag, bucket) = min(self._buckets.items(), key = lambda i: i[1].start)
+
+                                       # Make sure the right scope set is ready for the next bucket up
+                                       while self._pending and self._pending[0].timestamp < bucket.start:
+                                               ev = self._pending.popleft()
+                                               self._scope_activations[ev.scope] = ev.timestamp
+
+                                       target_scopes = {
+                                               scope for (scope, active) in self._scope_activations.items()
+                                               if active + self._conn_warmup.total_seconds() < now
+                                       }
+                                       if now > bucket.start + self._parity_timeout.total_seconds():
+                                               filled = target_scopes & bucket.recipients
+                                               missing = target_scopes - bucket.recipients
+                                               extra = bucket.recipients - target_scopes
+
+                                               self._logger.debug("expiring bucket {}".format(_tag_digest(tag)))
+                                               if filled:
+                                                       self._logger.debug("  filled {}: {}".format(len(filled), _format_scope_list(filled)))
+                                               if missing:
+                                                       self._logger.debug("  missing {}: {}".format(len(missing), _format_scope_list(missing)))
+                                               if extra:
+                                                       self._logger.debug("  extra {}: {}".format(len(extra), _format_scope_list(extra)))
+
+                                               for scope in missing:
+                                                       self._outgoing_scopes.add(scope)
+                                                       scope.cancel()
+                                               del self._buckets[tag]
+                                               self._log_current_buckets()
+                                       else:
+                                               await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds())
+                               else:
+                                       try:
+                                               await self._timer_poke_rx.receive()
+                                       except trio.EndOfChannel:
+                                               break
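+
+
+# Worked example of the parity mechanism above (hedged; timings illustrative).
+# With parity_timeout = 2s and three warmed-up connections A, B, and C:
+#
+#	t=0.0  A delivers update U; a bucket is created and U is forwarded once
+#	t=0.5  B delivers its copy of U and joins the existing bucket
+#	t=2.0  the timer expires the bucket; C never delivered U, so C's scope
+#	       is cancelled and the refresher replaces that connection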
diff --git a/strikebot/strikebot/src/strikebot/queue.py b/strikebot/strikebot/src/strikebot/queue.py
new file mode 100644 (file)
index 0000000..acf5ef0
--- /dev/null
@@ -0,0 +1,61 @@
+"""Unbounded blocking queues for Trio"""
+
+from collections import deque
+from dataclasses import dataclass
+from functools import total_ordering
+from typing import Any, Iterable
+import heapq
+
+try:
+       from trio.lowlevel import ParkingLot
+except ImportError:
+       from trio.hazmat import ParkingLot
+
+
+class Queue:
+       def __init__(self):
+               self._deque = deque()
+               self._empty_wait = ParkingLot()
+
+       def push(self, el: Any) -> None:
+               self._deque.append(el)
+               self._empty_wait.unpark()
+
+       def extend(self, els: Iterable[Any]) -> None:
+               for el in els:
+                       self.push(el)
+
+       async def pop(self) -> Any:
+               if not self._deque:
+                       await self._empty_wait.park()
+               return self._deque.popleft()
+
+       def __len__(self):
+               return len(self._deque)
+
+
+@total_ordering
+@dataclass
+class _ReverseOrdWrapper:
+       inner: Any
+
+       def __lt__(self, other):
+               return other.inner < self.inner
+
+
+class MaxHeap:
+       def __init__(self):
+               self._heap = []
+               self._empty_wait = ParkingLot()
+
+       def push(self, item):
+               heapq.heappush(self._heap, _ReverseOrdWrapper(item))
+               self._empty_wait.unpark()
+
+       async def pop(self):
+               if not self._heap:
+                       await self._empty_wait.park()
+               return heapq.heappop(self._heap).inner
+
+       def __len__(self) -> int:
+               return len(self._heap)
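+
+
+# Hedged usage sketch (not part of the original module): MaxHeap pops the
+# largest item first and parks when empty, which is what gives the API request
+# queue its priority behavior.
+if __name__ == "__main__":
+	import trio
+
+	async def _demo():
+		heap = MaxHeap()
+		for n in [2, 9, 4]:
+			heap.push(n)
+		assert await heap.pop() == 9  # max first
+		assert await heap.pop() == 4
+
+	trio.run(_demo)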
diff --git a/strikebot/strikebot/src/strikebot/reddit_api.py b/strikebot/strikebot/src/strikebot/reddit_api.py
new file mode 100644 (file)
index 0000000..0cf5daa
--- /dev/null
@@ -0,0 +1,290 @@
+"""Multi-auth HTTP connection pool with Reddit API wrappers and rate limiting."""
+
+from abc import ABCMeta, abstractmethod
+from dataclasses import dataclass, field
+from functools import total_ordering
+from socket import EAI_AGAIN, EAI_FAIL, gaierror
+import datetime as dt
+import logging
+
+from asks.response_objects import Response
+import asks
+import trio
+
+from strikebot import __version__ as VERSION
+from strikebot.common import obj_digest
+from strikebot.queue import MaxHeap, Queue
+
+
+REQUEST_DELAY = dt.timedelta(seconds = 1)
+
+TOKEN_UPDATE_DELAY = dt.timedelta(minutes = 15)
+
+USER_AGENT_FMT = f"any:net.jcornell.strikebot.{{auth_id}}:v{VERSION} (by /u/jaxklax)"
+
+API_BASE_URL = "https://oauth.reddit.com"
+
+
+@total_ordering
+class _Request(metaclass = ABCMeta):
+       _SUBTYPE_PRECEDENCE = None  # assigned later
+
+       def __lt__(self, other):
+               if type(self) is type(other):
+                       return self._subtype_cmp_key() < other._subtype_cmp_key()
+               else:
+                       prec = self._SUBTYPE_PRECEDENCE
+                       return prec.index(type(other)) < prec.index(type(self))
+
+       def __eq__(self, other):
+               if type(self) is type(other):
+                       return self._subtype_cmp_key() == other._subtype_cmp_key()
+               else:
+                       return False
+
+       @abstractmethod
+       def _subtype_cmp_key(self):
+               # a large key corresponds to a high priority
+               raise NotImplementedError()
+
+       @abstractmethod
+       def to_asks_kwargs(self):
+               raise NotImplementedError()
+
+
+@dataclass(eq = False)
+class StrikeRequest(_Request):
+       thread_id: str
+       update_name: str
+       update_ts: dt.datetime
+
+       def to_asks_kwargs(self):
+               return {
+                       "method": "POST",
+                       "path": f"/api/live/{self.thread_id}/strike_update",
+                       "data": {
+                               "api_type": "json",
+                               "id": self.update_name,
+                       },
+               }
+
+       def _subtype_cmp_key(self):
+               return self.update_ts
+
+
+@dataclass(eq = False)
+class DeleteRequest(_Request):
+       thread_id: str
+       update_name: str
+       update_ts: dt.datetime
+
+       def to_asks_kwargs(self):
+               return {
+                       "method": "POST",
+                       "path": f"/api/live/{self.thread_id}/delete_update",
+                       "data": {
+                               "api_type": "json",
+                               "id": self.update_name,
+                       },
+               }
+
+       def _subtype_cmp_key(self):
+               return -self.update_ts.timestamp()
+
+
+@dataclass(eq = False)
+class AboutLiveThreadRequest(_Request):
+       thread_id: str
+       created: float = field(default_factory = trio.current_time)
+
+       def to_asks_kwargs(self):
+               return {
+                       "method": "GET",
+                       "path": f"/live/{self.thread_id}/about",
+                       "params": {
+                               "raw_json": "1",
+                       },
+               }
+
+       def _subtype_cmp_key(self):
+               return -self.created
+
+
+@dataclass(eq = False)
+class _BaseUpdatePostRequest(_Request, metaclass = ABCMeta):
+       thread_id: str
+       body: str
+       created: float = field(default_factory = trio.current_time)
+
+       def to_asks_kwargs(self):
+               return {
+                       "method": "POST",
+                       "path": f"/live/{self.thread_id}/update",
+                       "data": {
+                               "api_type": "json",
+                               "body": self.body,
+                       },
+               }
+
+
+@dataclass(eq = False)
+class ReportUpdateRequest(_BaseUpdatePostRequest):
+       def _subtype_cmp_key(self):
+               return -self.created
+
+
+@dataclass(eq = False)
+class CorrectionUpdateRequest(_BaseUpdatePostRequest):
+       def _subtype_cmp_key(self):
+               return -self.created
+
+
+# highest priority first
+_Request._SUBTYPE_PRECEDENCE = [
+       AboutLiveThreadRequest,
+       StrikeRequest,
+       CorrectionUpdateRequest,
+       ReportUpdateRequest,
+       DeleteRequest,
+]
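+
+
+# Hedged illustration of the ordering (thread and update identifiers made up):
+# across types, priority follows the precedence list above; within a type,
+# _subtype_cmp_key decides, so newer strikes outrank older ones.  With t1 > t0:
+#
+#	StrikeRequest("t", "u1", t0) > DeleteRequest("t", "u1", t0)    # True
+#	StrikeRequest("t", "u2", t1) > StrikeRequest("t", "u1", t0)    # True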
+
+
+@dataclass
+class AppCooldown:
+       auth_id: int
+       ready_at: float  # Trio time
+
+
+class ApiClientPool:
+       def __init__(
+               self,
+               auth_ids: set[int],
+               db_messenger: "strikebot.db.Messenger",
+               request_queue_limit: int,
+               error_window: dt.timedelta,
+               error_delay: dt.timedelta,
+               logger: logging.Logger,
+       ):
+               self._auth_ids = auth_ids
+               self._db_messenger = db_messenger
+               self._request_queue_limit = request_queue_limit
+               self._error_window = error_window
+               self._error_delay = error_delay
+               self._logger = logger
+
+               now = trio.current_time()
+               self._tokens = {}
+               self._app_queue = Queue()
+               self._waiting = {id_: AppCooldown(id_, now) for id_ in auth_ids}
+               self._request_queue = MaxHeap()
+
+               # pool-wide API error backoff
+               self._last_error = None
+               self._global_resume = None
+
+               self._session = asks.Session(connections = len(auth_ids))
+               self._session.base_location = API_BASE_URL
+
+       async def _update_tokens(self):
+               tokens = await self._db_messenger.do("get_auth_tokens", (self._auth_ids,))
+               self._tokens.update(tokens)
+
+               awaken_auths = self._waiting.keys() & tokens.keys()
+               self._app_queue.extend(self._waiting.pop(auth_id) for auth_id in awaken_auths)
+               self._logger.debug("updated API tokens")
+
+       async def token_updater_impl(self):
+               last_update = None
+               while True:
+                       if last_update is not None:
+                               await trio.sleep_until(last_update + TOKEN_UPDATE_DELAY.total_seconds())
+
+                       if last_update is None:
+                               last_update = trio.current_time()
+                       else:
+                               last_update += TOKEN_UPDATE_DELAY.total_seconds()
+
+                       await self._update_tokens()
+
+       def _check_queue_size(self) -> None:
+               if len(self._request_queue) > self._request_queue_limit:
+                       raise RuntimeError("request queue size exceeded limit")
+
+       async def make_request(self, request: _Request) -> Response:
+               (resp_tx, resp_rx) = trio.open_memory_channel(0)
+               self.enqueue_request(request, resp_tx)
+               async with resp_rx:
+                       return await resp_rx.receive()
+
+       def enqueue_request(self, request: _Request, resp_tx = None) -> None:
+               self._logger.debug("request {}: {}".format(obj_digest(request), type(request).__name__))
+               self._request_queue.push((request, resp_tx))
+               self._logger.debug(f"{len(self._request_queue)} requests in queue")
+               self._check_queue_size()
+
+       async def worker_impl(self, task_status):
+               task_status.started()
+               while True:
+                       (request, resp_tx) = await self._request_queue.pop()
+                       cooldown = await self._app_queue.pop()
+                       await trio.sleep_until(cooldown.ready_at)
+                       if self._global_resume:
+                               await trio.sleep_until(self._global_resume)
+
+                       asks_kwargs = request.to_asks_kwargs()
+                       headers = asks_kwargs.setdefault("headers", {})
+                       headers.update({
+                               "Authorization": "Bearer {}".format(self._tokens[cooldown.auth_id]),
+                               "User-Agent": USER_AGENT_FMT.format(auth_id = cooldown.auth_id),
+                       })
+
+			request_time = trio.current_time()
+			wait_for_token = False  # set before the try: the DNS-failure path never reaches the else branch
+			try:
+				resp = await self._session.request(**asks_kwargs)
+			except gaierror as e:
+				if e.errno in [EAI_FAIL, EAI_AGAIN]:
+					# DNS failure, probably temporary
+					error = True
+				else:
+					raise
+			else:
+				resp.body  # read response
+				error = False
+                               log_suffix = " (request {})".format(obj_digest(request))
+                               if resp.status_code == 429:
+                                       # We disagreed about the rate limit state; just try again later.
+                                       self._logger.warning("rate limited by Reddit API" + log_suffix)
+                                       error = True
+                               elif resp.status_code == 401:
+                                       self._logger.warning("got HTTP 401 from Reddit API" + log_suffix)
+                                       error = True
+                                       wait_for_token = True
+                               elif resp.status_code in [404, 500, 503]:
+                                       self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying" + log_suffix)
+                                       error = True
+                               elif 400 <= resp.status_code < 500:
+                                       # If we're doing something wrong, let's catch it right away.
+                                       raise RuntimeError(f"unexpected client error response: {resp.status_code}" + log_suffix)
+                               else:
+                                       if resp.status_code != 200:
+                                               raise RuntimeError(f"unexpected status code {resp.status_code}")
+                                       self._logger.debug("success" + log_suffix)
+                                       if resp_tx:
+                                               await resp_tx.send(resp)
+
+                       if error:
+                               self.enqueue_request(request, resp_tx)
+                               if self._last_error:
+                                       spread = dt.timedelta(seconds = request_time - self._last_error)
+                                       if spread <= self._error_window:
+                                               self._global_resume = request_time + self._error_delay.total_seconds()
+                               self._last_error = request_time
+
+                       cooldown.ready_at = request_time + REQUEST_DELAY.total_seconds()
+                       if wait_for_token:
+                               self._waiting[cooldown.auth_id] = cooldown
+                               if not self._app_queue:
+                                       self._logger.error("all workers waiting for API tokens")
+                       else:
+                               self._app_queue.push(cooldown)
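+
+
+# Hedged timing sketch of the scheme above: with REQUEST_DELAY = 1s and two
+# auth IDs, each token serves at most one request per second, so the pool
+# sustains roughly two requests per second; a second API error arriving within
+# error_window pushes _global_resume out by error_delay for every worker.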
diff --git a/strikebot/strikebot/src/strikebot/tests.py b/strikebot/strikebot/src/strikebot/tests.py
new file mode 100644 (file)
index 0000000..966bca6
--- /dev/null
@@ -0,0 +1,62 @@
+from unittest import TestCase
+
+from strikebot.updates import parse_update
+
+
+def _build_payload(body_html: str) -> dict[str, str]:
+       return {"body_html": body_html}
+
+
+class UpdateParsingTests(TestCase):
+       def test_successful_counts(self):
+               pu = parse_update(_build_payload("<div>12,345,678 spaghetti</div>"), None, "")
+               self.assertEqual(pu.number, 12345678)
+               self.assertFalse(pu.deletable)
+
+               pu = parse_update(_build_payload("<div><p>0</p><br/>oz</div>"), None, "")
+               self.assertEqual(pu.number, 0)
+               self.assertFalse(pu.deletable)
+
+               pu = parse_update(_build_payload("<div>121 345 621</div>"), 121, "")
+               self.assertEqual(pu.number, 121)
+               self.assertFalse(pu.deletable)
+
+               pu = parse_update(_build_payload("28 336 816"), 28336816, "")
+               self.assertEqual(pu.number, 28336816)
+               self.assertTrue(pu.deletable)
+
+       def test_non_counts(self):
+               pu = parse_update(_build_payload("<div><table><tbody><tr><td>zoo</td></tr></tbody></table></div>"), None, "")
+               self.assertFalse(pu.count_attempt)
+               self.assertFalse(pu.deletable)
+
+       def test_typos(self):
+               pu = parse_update(_build_payload("<span>v9</span>"), 888, "")
+               self.assertIsNone(pu.number)
+               self.assertTrue(pu.count_attempt)
+
+               pu = parse_update(_build_payload("<div>v11.585 Empire</div>"), None, "")
+               self.assertIsNone(pu.number)
+               self.assertTrue(pu.count_attempt)
+               self.assertFalse(pu.deletable)
+
+               pu = parse_update(_build_payload("<div>11, 585, 22 </div>"), 11_585_202, "")
+               self.assertIsNone(pu.number)
+               self.assertTrue(pu.count_attempt)
+               self.assertTrue(pu.deletable)
+
+               pu = parse_update(_build_payload("<span>0490499</span>"), 4999, "")
+               self.assertIsNone(pu.number)
+               self.assertTrue(pu.count_attempt)
+
+               # this update is marked non-deletable, but I'm not sure it should be
+               pu = parse_update(_build_payload("12,123123"), 12_123_123, "")
+               self.assertIsNone(pu.number)
+               self.assertTrue(pu.count_attempt)
+
+       def test_html_handling(self):
+               pu = parse_update(_build_payload("123<hr>,456"), None, "")
+               self.assertEqual(pu.number, 123)
+
+               pu = parse_update(_build_payload("<pre>123\n456</pre>"), None, "")
+               self.assertEqual(pu.number, 123)
diff --git a/strikebot/strikebot/src/strikebot/updates.py b/strikebot/strikebot/src/strikebot/updates.py
new file mode 100644 (file)
index 0000000..01eb67b
--- /dev/null
@@ -0,0 +1,226 @@
+from __future__ import annotations
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional
+import re
+
+from bs4 import BeautifulSoup
+
+
+Command = Enum("Command", ["RESET", "REPORT"])
+
+
+@dataclass
+class ParsedUpdate:
+       number: Optional[int]
+       command: Optional[Command]
+       count_attempt: bool  # either well-formed or typo
+       deletable: bool
+
+
+def _parse_command(line: str, bot_user: str) -> Optional[Command]:
+       if line.lower() == f"/u/{bot_user} reset".lower():
+               return Command.RESET
+       elif line.lower() in ["sidebar count", "current count"]:
+               return Command.REPORT
+       else:
+               return None
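+
+
+# Hedged examples (bot name illustrative):
+#
+#	_parse_command("/u/strike_bot reset", "strike_bot")  -> Command.RESET
+#	_parse_command("Sidebar Count", "strike_bot")        -> Command.REPORT
+#	_parse_command("12,345", "strike_bot")               -> None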
+
+
+def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
+       # curr_count is the next number up, one more than the last count
+
+       NEW_LINE = object()
+       SPACE = object()
+
+       # flatten the update content to plain text
+       tree = BeautifulSoup(payload_data["body_html"], "html.parser")
+       worklist = list(reversed(tree.contents))
+       out = [[]]
+       while worklist:
+               el = worklist.pop()
+		if isinstance(el, str):
+			out[-1].append(el)
+		elif el is SPACE:
+			out[-1].append(el)
+		elif el is NEW_LINE or el.name in ["br", "hr"]:
+			# start a new line, unless the current one is still empty
+			if out[-1]:
+				out.append([])
+		elif el.name in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]:
+			# inline elements: recurse into the contents
+			worklist.extend(reversed(el.contents))
+		elif el.name in ["ul", "ol", "table", "thead", "tbody"]:
+			# container elements: recurse; the children provide their own line breaks
+			worklist.extend(reversed(el.contents))
+		elif el.name in ["li", "p", "div", "blockquote"] or re.match(r"h[1-6]$", el.name):
+			# block elements: break the line on both sides
+			worklist.append(NEW_LINE)
+			worklist.extend(reversed(el.contents))
+			worklist.append(NEW_LINE)
+		elif el.name == "pre":
+			# preformatted text: take its lines verbatim
+			out.extend([l] for l in el.text.splitlines())
+			out.append([])
+		elif el.name == "tr":
+			# table rows: one line per row, cells joined by single spaces
+			worklist.append(NEW_LINE)
+			for (i, cell) in enumerate(reversed(el.contents)):
+				worklist.append(cell)
+				if i != len(el.contents) - 1:
+					worklist.append(SPACE)
+			worklist.append(NEW_LINE)
+		else:
+			raise RuntimeError(f"can't parse tag {el.name}")
+
+       tmp_lines = (
+               "".join(" " if part is SPACE else part for part in parts).strip()
+               for parts in out
+       )
+       pre_strip_lines = list(filter(None, tmp_lines))
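+	# at this point, e.g. "<div>123<br>456</div>" has been flattened to ["123", "456"]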
+
+       # normalize whitespace according to HTML rendering rules
+       # https://developer.mozilla.org/en-US/docs/Web/API/Document_Object_Model/Whitespace#explanation
+       stripped_lines = [
+               re.sub(" +", " ", l.replace("\t", " ").replace("\n", " ")).strip(" ")
+               for l in pre_strip_lines
+       ]
+
+       return _parse_from_lines(stripped_lines, curr_count, bot_user)
+
+
+def _parse_from_lines(lines: list[str], curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
+       command = next(
+               (cmd for l in lines if (cmd := _parse_command(l, bot_user))),
+               None
+       )
+       if lines:
+               # look for groups of digits (as many as possible) separated by a uniform separator from the valid set
+               first = lines[0]
+               match = re.match(
+                       "(?P<v>v)?(?P<neg>-)?(?P<num>\\d+((?P<sep>[,. \u2009]|, )\\d+((?P=sep)\\d+)*)?)",
+                       first,
+                       re.ASCII,  # only recognize ASCII digits
+               )
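+		# e.g. matches "12,345", "11, 585, 202", "v11585", and "-100"; the digit-group
+		# separator may be ",", ".", a space, a thin space (U+2009), or ", "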
+               if match:
+                       raw_digits = match["num"]
+                       sep = match["sep"]
+                       post = first[match.end() :]
+
+			# strip leading zeros (and any separator that follows them)
+			zeros = False
+			while len(raw_digits) > 1 and raw_digits[0] == "0":
+				zeros = True
+				raw_digits = raw_digits.removeprefix("0").removeprefix(sep or "")
+
+                       parts = raw_digits.split(sep) if sep else [raw_digits]
+                       all_parts_valid = (
+                               sep is None
+                               or (
+                                       1 <= len(parts[0]) <= 3
+                                       and all(len(p) == 3 for p in parts[1:])
+                               )
+                       )
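+			# e.g. "1,234,567" has valid grouping; "12,3456" does not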
+			lone = len(lines) == 1 and (not post or post.isspace())  # the count is the update's only content
+                       typo = False
+                       if lone:
+                               if match["v"] and len(parts) == 1 and len(parts[0]) <= 2:
+                                       # failed paste of leading digits
+                                       typo = True
+                               elif match["v"] and all_parts_valid:
+                                       # v followed by count
+                                       typo = True
+                               elif curr_count is not None and abs(curr_count) >= 100 and bool(match["neg"]) == (curr_count < 0):
+                                       goal_parts = _separate(str(abs(curr_count)))
+                                       partials = [
+                                               goal_parts[: -1] + [goal_parts[-1][: -1]],  # missing last digit
+                                               goal_parts[: -1] + [goal_parts[-1][: -2]],  # missing last two digits
+                                               goal_parts[: -1] + [goal_parts[-1][: -2] + goal_parts[-1][-1]],  # missing second-last digit
+                                       ]
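+					# e.g. expecting 12,345: "12,34", "12,3", and "12,35" are treated as typos,
+					# as is the double paste "12,3412,345"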
+                                       if parts in partials:
+                                               # missing any of last two digits
+                                               typo = True
+                                       elif any(parts == p[: -1] + [p[-1] + goal_parts[0]] + goal_parts[1 :] for p in partials):
+                                               # double paste
+                                               typo = True
+
+                       if match["v"] or zeros or typo or (parts == ["0"] and match["neg"]):
+                               number = None
+                               count_attempt = True
+                               deletable = lone
+                       else:
+                               groups_okay = True
+                               if curr_count is not None and sep and sep.isspace():
+                                       # Presume that the intended count consists of as many valid digit groups as necessary to match the
+                                       # number of digits in the expected count, if possible.
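+					# e.g. expecting 11585, "11 585 22" yields use_parts ["11", "585"] and the
+					# number 11585, leaving the trailing "22" ignored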
+                                       digit_count = len(str(abs(curr_count)))
+                                       use_parts = []
+                                       accum = 0
+                                       for (i, part) in enumerate(parts):
+                                               part_valid = len(part) <= 3 if i == 0 else len(part) == 3
+                                               if part_valid and accum < digit_count:
+                                                       use_parts.append(part)
+                                                       accum += len(part)
+                                               else:
+                                                       break
+
+                                       # could still be a no-separator count with some extra digit groups on the same line
+                                       if not use_parts:
+                                               use_parts = [parts[0]]
+
+                                       lone = lone and len(use_parts) == len(parts)
+				else:
+					# the current count is unknown, or the separator (if any) isn't whitespace
+                                       groups_okay = all_parts_valid
+                                       use_parts = parts
+
+                               digits = "".join(use_parts)
+                               number = -int(digits) if match["neg"] else int(digits)
+                               special = (
+                                       curr_count is not None
+                                       and abs(number - curr_count) <= 25
+                                       and _is_special_number(number)
+                               )
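+				# e.g. 11000 or 11111 would be special if within 25 of the expected count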
+                               if groups_okay:
+                                       deletable = lone and not special
+					if len(use_parts) == len(parts) and post and not post[0].isspace():
+						# the digits run straight into other text (e.g. "12,345th"); count it as
+						# an attempt only when it's close to the expected count
+						count_attempt = curr_count is not None and abs(number - curr_count) <= 25
+						number = None
+                                       else:
+                                               count_attempt = True
+                               else:
+                                       number = None
+                                       count_attempt = True
+                                       deletable = False
+
+			if first[0].isdigit() and not count_attempt:
+				# anything starting with a digit still reads as a count attempt, but don't delete it
+				count_attempt = True
+				deletable = False
+               else:
+                       # no count attempt found
+                       number = None
+                       count_attempt = False
+                       deletable = False
+       else:
+               # no lines in update
+               number = None
+               count_attempt = False
+               deletable = True
+
+       return ParsedUpdate(
+               number = number,
+               command = command,
+               count_attempt = count_attempt,
+               deletable = deletable,
+       )
+
+
+def _separate(digits: str) -> list[str]:
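+	# e.g. _separate("1234567") -> ["1", "234", "567"]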
+       mod = len(digits) % 3
+       out = []
+       if mod:
+               out.append(digits[: mod])
+       out.extend(digits[i : i + 3] for i in range(mod, len(digits), 3))
+       return out
+
+
+def _is_special_number(num: int) -> bool:
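+	# e.g. 12000, 12001, 12333, 12999 (modulus rule), 12344321 (palindrome over
+	# ten million), and 123123 (repeated sequence) are all special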
+       num_str = str(num)
+       return bool(
+               num % 1000 in [0, 1, 333, 999]
+               or (num > 10_000_000 and "".join(reversed(num_str)) == num_str)
+               or re.match(r"(.+)\1+$", num_str)  # repeated sequence
+       )