Restructure for Debian packaging (WIP)
author Jakob Cornell <jakob+gpg@jcornell.net>
Tue, 6 Sep 2022 03:53:11 +0000 (22:53 -0500)
committer Jakob Cornell <jakob+gpg@jcornell.net>
Tue, 6 Sep 2022 03:53:11 +0000 (22:53 -0500)
28 files changed:
.gitignore
build_helper.py [new file with mode: 0644]
docs/sample_config.ini [deleted file]
pyproject.toml [deleted file]
setup.cfg [deleted file]
setup.py [deleted file]
src/strikebot/__init__.py [deleted file]
src/strikebot/__main__.py [deleted file]
src/strikebot/common.py [deleted file]
src/strikebot/db.py [deleted file]
src/strikebot/live_ws.py [deleted file]
src/strikebot/queue.py [deleted file]
src/strikebot/reddit_api.py [deleted file]
src/strikebot/tests.py [deleted file]
src/strikebot/updates.py [deleted file]
strikebot/docs/sample_config.ini [new file with mode: 0644]
strikebot/pyproject.toml [new file with mode: 0644]
strikebot/setup.cfg [new file with mode: 0644]
strikebot/setup.py [new file with mode: 0644]
strikebot/src/strikebot/__init__.py [new file with mode: 0644]
strikebot/src/strikebot/__main__.py [new file with mode: 0644]
strikebot/src/strikebot/common.py [new file with mode: 0644]
strikebot/src/strikebot/db.py [new file with mode: 0644]
strikebot/src/strikebot/live_ws.py [new file with mode: 0644]
strikebot/src/strikebot/queue.py [new file with mode: 0644]
strikebot/src/strikebot/reddit_api.py [new file with mode: 0644]
strikebot/src/strikebot/tests.py [new file with mode: 0644]
strikebot/src/strikebot/updates.py [new file with mode: 0644]

index fbf1bc5520f9a04775ba5e5e69037f72ac2118ed..9de7c6a248e9cf85c47c29403b2b35e53fefd303 100644 (file)
@@ -1,3 +1,4 @@
 *.egg-info
 __pycache__
 .mypy_cache
+.pybuild/
diff --git a/build_helper.py b/build_helper.py
new file mode 100644 (file)
index 0000000..c669c83
--- /dev/null
@@ -0,0 +1,71 @@
+"""
+A few tools to help keep Debian packages and artifacts well organized.
+
+Run with the "build" subcommand to build the Debian package and collect the outputs in the build directory, or with
+"clean" to delete previously collected outputs from the build directory.
+"""
+
+from argparse import ArgumentParser
+from pathlib import Path
+from subprocess import run
+import re
+
+
+def _is_output(path: Path) -> bool:
+       """Tell whether the path's file name carries one of the suffixes of the build outputs handled here."""
+       # str.endswith accepts a tuple of candidate suffixes
+       return path.name.endswith(
+               (
+                       ".build",
+                       ".buildinfo",
+                       ".changes",
+                       ".deb",
+                       ".debian.tar.xz",
+                       ".dsc",
+               )
+       )
+
+
+project_root = Path(__file__).parent
+upstream_root = project_root.joinpath("strikebot")
+build_dir = project_root.joinpath("build")
+
+parser = ArgumentParser()
+tmp = parser.add_subparsers(dest = "cmd", required = True)
+
+tmp.add_parser("build")
+tmp.add_parser("clean")
+
+args = parser.parse_args()
+if args.cmd == "build":
+       # extract version from Python package config
+       patt = re.compile("version *= *(.+?) *$")
+       with upstream_root.joinpath("setup.cfg").open() as f:
+               # the unpacking requires exactly one matching line and raises ValueError otherwise
+               [m] = filter(None, map(patt.match, f))
+       version = m[1]
+
+       # delete stale "orig" tarball
+       for p in project_root.glob("strikebot_*.orig.tar.xz"):
+               p.unlink()
+
+       # regenerate the "orig" tarball from the current source
+       # NOTE(review): unlike the debuild call below, the exit status is not checked here -- confirm this is intentional
+       run(["dh_make", "-y", "--indep", "--createorig", "-p", "strikebot_" + version], cwd = upstream_root)
+
+       # build the package
+       run(["debuild"], cwd = upstream_root, check = True)
+
+       # move outputs to build directory, creating it first if this is the first build
+       build_dir.mkdir(exist_ok = True)
+       for p in project_root.iterdir():
+               if _is_output(p):
+                       p.rename(build_dir.joinpath(p.relative_to(project_root)))
+elif args.cmd == "clean":
+       # a missing build directory just means there is nothing to clean
+       if build_dir.is_dir():
+               for p in build_dir.iterdir():
+                       if _is_output(p):
+                               p.unlink()
+else:
+       raise AssertionError()
diff --git a/docs/sample_config.ini b/docs/sample_config.ini
deleted file mode 100644 (file)
index 28e510e..0000000
+++ /dev/null
@@ -1,53 +0,0 @@
-# see Python `configparser' docs for precise file syntax
-
-[config]
-
-######## basic operating parameters
-
-# space-separated authorization IDs for Reddit API token lookup
-auth IDs = 13 15
-
-bot user = count_better
-enforcing = true
-thread ID = abc123
-
-
-######## API pool error handling
-
-# (seconds) for error backoff, pause the API pool for this much time
-API pool error delay = 5.5
-
-# (seconds) pause the pool when the Reddit API returns two errors within any time span of this length
-API pool error window = 5.5
-
-# terminate if the Reddit API request queue exceeds this size
-request queue limit = 100
-
-
-######## live thread update handling
-
-# (seconds) maximum time to hold updates for reordering
-reorder buffer time = 0.25
-
-# (seconds) minimum time to retain updates to enable future resyncs
-update retention time = 120.0
-
-
-######## WebSocket pool
-
-# (seconds) maximum allowable spread in WebSocket message arrival times
-WS parity time = 1.0
-
-# goal size of WebSocket connection pool
-WS pool size = 5
-
-# (seconds) after receiving no messages for this much time, attempt to replace the connection if the pool isn't depleted
-WS silent limit = 30
-
-# (seconds) time after WebSocket handshake completion during which missed updates are excused
-WS warmup time = 0.3
-
-
-# Postgres database configuration; same options as in a connect string
-[db connect params]
-host = example.org
diff --git a/pyproject.toml b/pyproject.toml
deleted file mode 100644 (file)
index 8fe2f47..0000000
+++ /dev/null
@@ -1,3 +0,0 @@
-[build-system]
-requires = ["setuptools>=42", "wheel"]
-build-backend = "setuptools.build_meta"
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644 (file)
index eb1fd25..0000000
--- a/setup.cfg
+++ /dev/null
@@ -1,15 +0,0 @@
-[metadata]
-name = strikebot
-version = 0.0.0
-
-[options]
-package_dir =
-       = src
-packages = strikebot
-python_requires = ~= 3.9
-install_requires =
-       asks ~= 2.4
-       beautifulsoup4 ~= 4.11
-       trio == 0.19
-       trio-websocket == 0.9.2
-       triopg == 0.6.0
diff --git a/setup.py b/setup.py
deleted file mode 100644 (file)
index 056ba45..0000000
--- a/setup.py
+++ /dev/null
@@ -1,4 +0,0 @@
-import setuptools
-
-
-setuptools.setup()
diff --git a/src/strikebot/__init__.py b/src/strikebot/__init__.py
deleted file mode 100644 (file)
index 7c37d14..0000000
+++ /dev/null
@@ -1,321 +0,0 @@
-from contextlib import nullcontext as nullcontext
-from dataclasses import dataclass
-from functools import total_ordering
-from itertools import islice
-from typing import Optional
-from uuid import UUID
-import bisect
-import datetime as dt
-import importlib.metadata
-import itertools
-import logging
-
-from trio import CancelScope, EndOfChannel
-import trio
-
-from strikebot.updates import Command, parse_update
-
-
-__version__ = importlib.metadata.version(__package__)
-
-
-_READ_ONLY: bool = False  # suppress any API requests that modify the thread?
-
-
-@dataclass
-class _Update:
-       id: str
-       name: str
-       author: str
-       number: Optional[int]
-       command: Optional[Command]
-       ts: dt.datetime
-       count_attempt: bool
-       deletable: bool
-       stricken: bool
-
-       def can_follow(self, prior: "_Update") -> bool:
-               """Determine whether this count update can follow another count update."""
-               return self.number == prior.number + 1 and self.author != prior.author
-
-       def __str__(self):
-               content = str(self.number)
-               if self.command:
-                       content += "+" + self.command.name
-               return "{}({} by {})".format(self.id[4 : 8], content, self.author)
-
-
-@total_ordering
-@dataclass
-class _BufferedUpdate:
-       update: _Update
-       release_at: float  # Trio time
-
-       def __lt__(self, other):
-               return self.update.ts < other.update.ts
-
-
-@total_ordering
-@dataclass
-class _TimelineUpdate:
-       update: _Update
-       accepted: bool
-
-       def __lt__(self, other):
-               return self.update.ts < other.update.ts
-
-       def rejected(self) -> bool:
-               """Whether the update should be stricken."""
-               return self.update.count_attempt and not self.accepted
-
-
-def _parse_rfc_4122_ts(ts: int) -> dt.datetime:
-       epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc)
-       return epoch + dt.timedelta(microseconds = ts / 10)
-
-
-def _format_update_ref(update: _Update, thread_id: str) -> str:
-       url = f"https://www.reddit.com/live/{thread_id}/updates/{update.id}"
-       author_md = update.author.replace("_", "\\_")  # Reddit allows letters, digits, underscore, and hyphen
-       return f"[{update.number:,} by {author_md}]({url})"
-
-
-def _format_bad_strike_alert(update: _Update, thread_id: str) -> str:
-       return _format_update_ref(update, thread_id) + " was incorrectly stricken."
-
-
-def _format_curr_count(last_valid: Optional[_Update]) -> str:
-       if last_valid and last_valid.number is not None:
-               count_str = f"{last_valid.number + 1:,}"
-       else:
-               count_str = "unknown"
-       return f"_current count:_ {count_str}"
-
-
-async def count_tracker_impl(
-       message_rx,
-       api_pool: "strikebot.reddit_api.ApiClientPool",
-       reorder_buffer_time: dt.timedelta,  # duration to hold some updates for reordering
-       thread_id: str,
-       bot_user: str,
-       enforcing: bool,
-       update_retention: dt.timedelta,
-       logger: logging.Logger,
-) -> None:
-       from strikebot.reddit_api import CorrectionUpdateRequest, DeleteRequest, ReportUpdateRequest, StrikeRequest
-
-       enforcing = enforcing and not _READ_ONLY
-
-       buffer_ = []
-       timeline = []
-       last_valid: Optional[_Update] = None
-       pending_strikes: set[str] = set()  # names of updates to mark stricken on arrival
-       forgotten: bool = False  # whether the retention period for an update has passed during the run
-
-       def handle_update(update):
-               nonlocal last_valid
-
-               tu = _TimelineUpdate(update, accepted = False)
-
-               pos = bisect.bisect(timeline, tu)
-               if pos != len(timeline):
-                       logger.warning(f"long transpo: {update.id}")
-
-               if pos > 0 and timeline[pos - 1].update.id == update.id:
-                       # The pool sent this update message multiple times.  This could be because the message reached parity and
-                       # its bucket was dropped just before a new connection was opened and sent its own copy, but connections also
-                       # seem to send duplicate messages.
-                       return
-
-               pred = next(
-                       (timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted),
-                       None
-               )
-               contender = update.command is Command.RESET or update.number is not None
-               if contender and pred is None and last_valid and forgotten:
-                       # This is a really long transpo which we can't judge valid or invalid since we've forgotten the surrounding
-                       # updates. The best we can do is just ignore it.
-                       logger.warning(f"ignoring {update.id}: no valid prior count on record")
-               else:
-                       timeline.insert(pos, tu)
-                       tu.accepted = (
-                               update.command is Command.RESET
-                               or (
-                                       update.number is not None
-                                       and (pred is None or pred.number is None or update.can_follow(pred))
-                               )
-                       )
-                       logger.debug("{} {} -> {}".format("accepted" if tu.accepted else "rejected", pred, update))
-                       if tu.accepted:
-                               # resync subsequent updates
-                               newly_valid = []
-                               newly_invalid = []
-                               resync_last_valid = update
-                               converged = False
-                               for scan_tu in islice(timeline, pos + 1, None):
-                                       if scan_tu.update.command is Command.RESET:
-                                               resync_last_valid = scan_tu.update
-                                               if last_valid:
-                                                       converged = True
-                                       elif scan_tu.update.number is not None:
-                                               accept = resync_last_valid.number is None or scan_tu.update.can_follow(resync_last_valid)
-                                               if accept and scan_tu.accepted:
-                                                       # resync would have no effect past this point
-                                                       if last_valid:
-                                                               converged = True
-                                               elif accept:
-                                                       newly_valid.append(scan_tu)
-                                                       resync_last_valid = scan_tu.update
-                                               elif scan_tu.accepted:
-                                                       newly_invalid.append(scan_tu)
-                                                       if scan_tu.update is last_valid:
-                                                               last_valid = None
-                                               scan_tu.accepted = accept
-                                       if converged:
-                                               break
-
-                               if converged:
-                                       assert last_valid.ts >= resync_last_valid.ts
-                               else:
-                                       last_valid = resync_last_valid
-
-                               parts = []
-                               if newly_valid:
-                                       parts.append(
-                                               "The following counts are valid:\n\n"
-                                               + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in newly_valid)
-                                       )
-
-                               unstrikable = [tu for tu in newly_invalid if tu.update.stricken]
-                               if unstrikable:
-                                       parts.append(
-                                               "The following counts are invalid:\n\n"
-                                               + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in unstrikable)
-                                       )
-
-                               if update.stricken:
-                                       logger.info(f"bad strike of {update.id}")
-                                       parts.append(_format_bad_strike_alert(update, thread_id))
-
-                               if parts:
-                                       parts.append(_format_curr_count(last_valid))
-                                       if not _READ_ONLY:
-                                               api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts)))
-
-                               for invalid_tu in newly_invalid:
-                                       if not invalid_tu.update.stricken:
-                                               if enforcing:
-                                                       api_pool.enqueue_request(StrikeRequest(thread_id, invalid_tu.update.name, invalid_tu.update.ts))
-                                               invalid_tu.update.stricken = True
-                       elif update.count_attempt:
-                               if enforcing:
-                                       api_pool.enqueue_request(StrikeRequest(thread_id, update.name, update.ts))
-                               update.stricken = True
-
-               if not _READ_ONLY and update.command is Command.REPORT:
-                       api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
-
-       with message_rx:
-               while True:
-                       if buffer_:
-                               deadline = min(bu.release_at for bu in buffer_)
-                               cancel_scope = CancelScope(deadline = deadline)
-                       else:
-                               cancel_scope = nullcontext()
-
-                       msg = None
-                       with cancel_scope:
-                               try:
-                                       msg = await message_rx.receive()
-                               except EndOfChannel:
-                                       break
-
-                       if msg:
-                               if msg.data["type"] == "update":
-                                       payload_data = msg.data["payload"]["data"]
-                                       stricken = payload_data["name"] in pending_strikes
-                                       pending_strikes.discard(payload_data["name"])
-                                       if payload_data["author"] != bot_user:
-                                               if last_valid and last_valid.number is not None:
-                                                       next_up = last_valid.number + 1
-                                               else:
-                                                       next_up = None
-                                               pu = parse_update(payload_data, next_up, bot_user)
-                                               rfc_ts = UUID(payload_data["id"]).time
-                                               update = _Update(
-                                                       id = payload_data["id"],
-                                                       name = payload_data["name"],
-                                                       author = payload_data["author"],
-                                                       number = pu.number,
-                                                       command = pu.command,
-                                                       ts = _parse_rfc_4122_ts(rfc_ts),
-                                                       count_attempt = pu.count_attempt,
-                                                       deletable = pu.deletable,
-                                                       stricken = stricken or payload_data["stricken"],
-                                               )
-                                               release_at = msg.timestamp + reorder_buffer_time.total_seconds()
-                                               bisect.insort(buffer_, _BufferedUpdate(update, release_at))
-                               elif msg.data["type"] == "strike":
-                                       slot = next(
-                                               (
-                                                       slot for slot in itertools.chain(buffer_, reversed(timeline))
-                                                       if slot.update.name == msg.data["payload"]
-                                               ),
-                                               None
-                                       )
-                                       if slot:
-                                               if not slot.update.stricken:
-                                                       slot.update.stricken = True
-                                                       if isinstance(slot, _TimelineUpdate) and slot.accepted:
-                                                               logger.info(f"bad strike of {slot.update.id}")
-                                                               if not _READ_ONLY:
-                                                                       body = _format_bad_strike_alert(slot.update, thread_id)
-                                                                       api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body))
-                                       else:
-                                               pending_strikes.add(msg.data["payload"])
-
-                       threshold = dt.datetime.now(dt.timezone.utc) - update_retention
-
-                       # pull updates from the reorder buffer and process
-                       new_buffer = []
-                       for bu in buffer_:
-                               process = (
-                                       # not a count
-                                       bu.update.number is None
-
-                                       # valid count
-                                       or (last_valid is None or last_valid.number is None or bu.update.can_follow(last_valid))
-                                       or bu.update.command is Command.RESET
-
-                                       or trio.current_time() >= bu.release_at
-                                       or bu.update.command is Command.RESET
-                               )
-
-                               if process:
-                                       if bu.update.ts < threshold:
-                                               logger.warning(f"ignoring {bu.update}: arrived past retention window")
-                                       else:
-                                               logger.debug(f"processing {bu.update}")
-                                               handle_update(bu.update)
-                               else:
-                                       logger.debug(f"holding {bu.update}, checked against {last_valid})")
-                                       new_buffer.append(bu)
-                       buffer_ = new_buffer
-                       logger.debug("last count {}".format(last_valid))
-
-                       # delete/forget old updates
-                       new_timeline = []
-                       i = 0
-                       for (i, tu) in enumerate(timeline):
-                               if i >= len(timeline) - 10 or tu.update.ts >= threshold:
-                                       break
-                               elif tu.rejected() and tu.update.deletable:
-                                       if enforcing:
-                                               api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts))
-                               elif tu.update is last_valid:
-                                       new_timeline.append(tu)
-
-                       if i:
-                               forgotten = True
-                       new_timeline.extend(islice(timeline, i, None))
-                       timeline = new_timeline
diff --git a/src/strikebot/__main__.py b/src/strikebot/__main__.py
deleted file mode 100644 (file)
index 63da2e1..0000000
+++ /dev/null
@@ -1,190 +0,0 @@
-from enum import Enum
-from inspect import getmodule
-from logging import FileHandler, getLogger, StreamHandler
-from signal import SIGUSR1
-from sys import stdout
-from traceback import StackSummary
-from typing import Optional
-import argparse
-import configparser
-import datetime as dt
-import logging
-
-from trio import open_memory_channel, open_nursery, open_signal_receiver
-from trio.lowlevel import current_root_task, Task
-import trio_asyncio
-import triopg
-
-from strikebot import count_tracker_impl
-from strikebot.db import Client, Messenger
-from strikebot.live_ws import HealingReadPool, PoolMerger
-from strikebot.reddit_api import ApiClientPool
-
-
-_DEBUG_LOG_PATH: Optional[str] = None  # path to file to write debug logs to, if any
-
-
-_TaskKind = Enum(
-       "_TaskKind",
-       [
-               "API_WORKER",
-               "DB_CLIENT",
-               "TOKEN_UPDATE",
-               "TRACKER",
-               "WS_MERGE_READER",
-               "WS_MERGE_TIMER",
-               "WS_REFRESHER",
-               "WS_WORKER",
-       ]
-)
-
-
-def _classify_task(task: Task) -> _TaskKind:
-       mod_name = getmodule(task.coro.cr_code).__name__
-       if mod_name.startswith("trio_asyncio."):
-               return None
-       elif task.coro.__name__ == "_signal_handler_impl":
-               return None
-       else:
-               by_coro_name = {
-                       "db_client_impl": _TaskKind.DB_CLIENT,
-                       "token_updater_impl": _TaskKind.TOKEN_UPDATE,
-                       "event_reader_impl": _TaskKind.WS_MERGE_READER,
-                       "timeout_handler_impl": _TaskKind.WS_MERGE_TIMER,
-                       "conn_refresher_impl": _TaskKind.WS_REFRESHER,
-                       "count_tracker_impl": _TaskKind.TRACKER,
-                       "_reader_impl": _TaskKind.WS_WORKER,
-                       "worker_impl": _TaskKind.API_WORKER,
-               }
-               return by_coro_name[task.coro.__name__]
-
-
-def _get_all_tasks():
-       [ta_loop] = [
-               task
-               for nursery in current_root_task().child_nurseries
-               for task in nursery.child_tasks
-               if getmodule(task.coro.cr_code).__name__.startswith("trio_asyncio.")
-       ]
-       for nursery in ta_loop.child_nurseries:
-               yield from nursery.child_tasks
-
-
-def _print_task_info():
-       groups = {kind: [] for kind in _TaskKind}
-       for task in _get_all_tasks():
-               kind = _classify_task(task)
-               if kind:
-                       coro = task.coro
-                       frame_tuples = []
-                       while coro is not None:
-                               if hasattr(coro, "cr_frame"):
-                                       frame_tuples.append((coro.cr_frame, coro.cr_frame.f_lineno))
-                                       coro = coro.cr_await
-                               else:
-                                       frame_tuples.append((coro.gi_frame, coro.gi_frame.f_lineno))
-                                       coro = coro.gi_yieldfrom
-                       groups[kind].append(StackSummary.extract(iter(frame_tuples)))
-       for kind in _TaskKind:
-               print(kind.name)
-               for ss in groups[kind]:
-                       print("  task")
-                       for line in ss.format():
-                               print("    " + line, end = "")
-
-
-async def _signal_handler_impl():
-       with open_signal_receiver(SIGUSR1) as sig_src:
-               async for _ in sig_src:
-                       _print_task_info()
-
-
-ap = argparse.ArgumentParser(__package__)
-ap.add_argument("config_path")
-args = ap.parse_args()
-
-
-parser = configparser.ConfigParser()
-with open(args.config_path) as config_file:
-       parser.read_file(config_file)
-
-main_cfg = parser["config"]
-
-api_pool_error_delay = dt.timedelta(seconds = main_cfg.getfloat("API pool error delay"))
-api_pool_error_window = dt.timedelta(seconds = main_cfg.getfloat("API pool error window"))
-auth_ids = set(map(int, main_cfg["auth IDs"].split()))
-bot_user = main_cfg["bot user"]
-enforcing = main_cfg.getboolean("enforcing")
-reorder_buffer_time = dt.timedelta(seconds = main_cfg.getfloat("reorder buffer time"))
-request_queue_limit = main_cfg.getint("request queue limit")
-thread_id = main_cfg["thread ID"]
-update_retention = dt.timedelta(seconds = main_cfg.getfloat("update retention time"))
-ws_parity_time = dt.timedelta(seconds = main_cfg.getfloat("WS parity time"))
-ws_pool_size = main_cfg.getint("WS pool size")
-ws_silent_limit = dt.timedelta(seconds = main_cfg.getfloat("WS silent limit"))
-ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup seconds"))
-
-db_cfg = parser["db connect params"]
-getters = {
-       "port": db_cfg.getint,
-}
-db_connect_params = {k: getters.get(k, db_cfg.get)(k) for k in db_cfg}
-
-
-logger = getLogger(__package__)
-logger.setLevel(logging.DEBUG)
-
-handler = StreamHandler(stdout)
-handler.setLevel(logging.WARNING)
-handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{"))
-logger.addHandler(handler)
-
-if _DEBUG_LOG_PATH:
-       debug_handler = FileHandler(_DEBUG_LOG_PATH, mode = "w")
-       debug_handler.setLevel(logging.DEBUG)
-       debug_handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{"))
-       logger.addHandler(debug_handler)
-
-
-async def main():
-       async with (
-               triopg.connect(**db_connect_params) as db_conn,
-               open_nursery() as nursery_a,
-               open_nursery() as nursery_b,
-               open_nursery() as ws_pool_nursery,
-               open_nursery() as nursery_c,
-       ):
-               nursery_a.start_soon(_signal_handler_impl)
-
-               client = Client(db_conn)
-               db_messenger = Messenger(client)
-               nursery_a.start_soon(db_messenger.db_client_impl)
-
-               api_pool = ApiClientPool(
-                       auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay,
-                       logger.getChild("api")
-               )
-               nursery_a.start_soon(api_pool.token_updater_impl)
-               for _ in auth_ids:
-                       await nursery_a.start(api_pool.worker_impl)
-
-               (message_tx, message_rx) = open_memory_channel(0)
-               (pool_event_tx, pool_event_rx) = open_memory_channel(0)
-
-               merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger.getChild("merge"))
-               nursery_b.start_soon(merger.event_reader_impl)
-               nursery_b.start_soon(merger.timeout_handler_impl)
-
-               ws_pool = HealingReadPool(
-                       ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger.getChild("live_ws"),
-                       ws_silent_limit
-               )
-               nursery_c.start_soon(ws_pool.conn_refresher_impl)
-
-               nursery_c.start_soon(
-                       count_tracker_impl, message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing,
-                       update_retention, logger.getChild("track")
-               )
-               await ws_pool.init_workers()
-
-trio_asyncio.run(main)
diff --git a/src/strikebot/common.py b/src/strikebot/common.py
deleted file mode 100644 (file)
index 519037a..0000000
+++ /dev/null
@@ -1,10 +0,0 @@
-from typing import Any
-
-
-def int_digest(val: int) -> str:
-       return "{:04x}".format(val & 0xffff)
-
-
-def obj_digest(obj: Any) -> str:
-       """Makes a short digest of the identity of an object."""
-       return int_digest(id(obj))
diff --git a/src/strikebot/db.py b/src/strikebot/db.py
deleted file mode 100644 (file)
index c9a060a..0000000
+++ /dev/null
@@ -1,58 +0,0 @@
-"""Single-connection Postgres client with high-level wrappers for operations."""
-
-from dataclasses import dataclass
-from typing import Any, Iterable
-
-import trio
-
-
-def _channel_sender(method):
-       async def wrapped(self, resp_channel, *args, **kwargs):
-               ret = await method(self, *args, **kwargs)
-               with resp_channel:
-                       await resp_channel.send(ret)
-
-       return wrapped
-
-
-@dataclass
-class Client:
-       """High-level wrappers for DB operations."""
-       conn: Any
-
-       @_channel_sender
-       async def get_auth_tokens(self, auth_ids: set[int]) -> dict[int, str]:
-               results = await self.conn.fetch(
-                       """
-                               select
-                                       distinct on (id)
-                                       id, access_token
-                               from
-                                       public.reddit_app_authorization
-                                       join unnest($1::integer[]) as request_id on id = request_id
-                               where expires > current_timestamp
-                       """,
-                       list(auth_ids),
-               )
-               return {r["id"]: r["access_token"] for r in results}
-
-
-class Messenger:
-       def __init__(self, client: Client):
-               self._client = client
-               (self._request_tx, self._request_rx) = trio.open_memory_channel(0)
-
-       async def do(self, method_name: str, args: Iterable) -> Any:
-               """This is run by consumers of the DB wrapper."""
-               method = getattr(self._client, method_name)
-               (resp_tx, resp_rx) = trio.open_memory_channel(0)
-               coro = method(resp_tx, *args)
-               await self._request_tx.send(coro)
-               async with resp_rx:
-                       return await resp_rx.receive()
-
-       async def db_client_impl(self):
-               """This is the DB client Trio task."""
-               async with self._client.conn, self._request_rx:
-                       async for coro in self._request_rx:
-                               await coro
diff --git a/src/strikebot/live_ws.py b/src/strikebot/live_ws.py
deleted file mode 100644 (file)
index cb8b3d0..0000000
+++ /dev/null
@@ -1,266 +0,0 @@
-from collections import deque
-from contextlib import suppress
-from dataclasses import dataclass
-from typing import Any, AsyncContextManager, Optional
-import datetime as dt
-import json
-import logging
-import math
-
-from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
-import trio
-
-from strikebot.common import int_digest, obj_digest
-from strikebot.reddit_api import AboutLiveThreadRequest
-
-
-@dataclass
-class _PoolEvent:
-       timestamp: float  # Trio time
-
-
-@dataclass
-class _Message(_PoolEvent):
-       data: Any
-       scope: Any
-
-       def dedup_tag(self):
-               """A hashable essence of the contents of this message."""
-               if self.data["type"] == "update":
-                       return ("update", self.data["payload"]["data"]["id"])
-               elif self.data["type"] in {"strike", "delete"}:
-                       return (self.data["type"], self.data["payload"])
-               else:
-                       raise ValueError(self.data["type"])
-
-
-@dataclass
-class _ConnectionDown(_PoolEvent):
-       scope: trio.CancelScope
-
-
-@dataclass
-class _ConnectionUp(_PoolEvent):
-       scope: trio.CancelScope
-
-
-class HealingReadPool:
-       def __init__(
-               self, pool_nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger,
-               silent_limit: dt.timedelta
-       ):
-               assert size >= 2
-               self._nursery = pool_nursery
-               self._size = size
-               self._live_thread_id = live_thread_id
-               self._api_client_pool = api_client_pool
-               self._pool_event_tx = pool_event_tx
-               self._logger = logger
-               self._silent_limit = silent_limit
-
-               (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size)
-
-               # number of workers who have stopped receiving updates but whose tasks aren't yet stopped
-               self._closing = 0
-
-       async def init_workers(self):
-               for _ in range(self._size):
-                       await self._spawn_reader()
-               if not self._nursery.child_tasks:
-                       raise RuntimeError("Unable to create any WS connections")
-
-       async def _reader_impl(self, conn_ctx: AsyncContextManager, cancel_scope, refresh_tx):
-               # TODO could use task's implicit cancel scope
-               try:
-                       async with conn_ctx as conn:
-                               self._logger.debug("scope up: {}".format(obj_digest(cancel_scope)))
-                               with refresh_tx:
-                                       silent_timeout = False
-                                       with cancel_scope, suppress(ConnectionClosed):
-                                               while True:
-                                                       with trio.move_on_after(self._silent_limit.total_seconds()) as timeout_scope:
-                                                               message = await conn.get_message()
-
-                                                       if timeout_scope.cancelled_caught:
-                                                               if len(self._nursery.child_tasks) - self._closing == self._size:
-                                                                       silent_timeout = True
-                                                                       break
-                                                               else:
-                                                                       self._logger.debug("not replacing connection {}; {} tasks, {} closing".format(
-                                                                               obj_digest(cancel_scope),
-                                                                               len(self._nursery.child_tasks),
-                                                                               self._closing,
-                                                                       ))
-                                                       else:
-                                                               event = _Message(trio.current_time(), json.loads(message), cancel_scope)
-                                                               await self._pool_event_tx.send(event)
-
-                                       self._closing += 1
-                                       if silent_timeout:
-                                               await conn.aclose()
-                                               self._logger.debug("replacing WS connection {} due to silent timeout".format(obj_digest(cancel_scope)))
-                                       elif cancel_scope.cancelled_caught:
-                                               await conn.aclose(1008, "Server unexpectedly stopped sending messages")
-                                               self._logger.warning("replacing WS connection {} due to missed update".format(obj_digest(cancel_scope)))
-                                       else:
-                                               self._logger.warning("replacing WS connection {} closed by server".format(obj_digest(cancel_scope)))
-
-                                       refresh_tx.send_nowait(cancel_scope)
-                                       self._logger.debug("scope down: {} ({} tasks, {} closing)".format(
-                                               obj_digest(cancel_scope),
-                                               len(self._nursery.child_tasks),
-                                               self._closing,
-                                       ))
-                                       self._closing -= 1
-               except HandshakeError:
-                       self._logger.error("handshake error while opening WS connection")
-
-       async def _new_connection(self) -> Optional[AsyncContextManager[WebSocketConnection]]:
-               request = AboutLiveThreadRequest(self._live_thread_id)
-               resp = await self._api_client_pool.make_request(request)
-               if resp.status_code == 200:
-                       url = json.loads(resp.text)["data"]["websocket_url"]
-                       return open_websocket_url(url)
-               else:
-                       self._logger.error(f"thread info request failed (HTTP {resp.status_code}); can't open WS connection")
-                       return None
-
-       async def _spawn_reader(self) -> None:
-               conn = await self._new_connection()
-               if conn:
-                       new_scope = trio.CancelScope()
-                       await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope))
-                       refresh_tx = self._refresh_queue_tx.clone()
-                       self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx)
-
-       async def conn_refresher_impl(self):
-               """Task to monitor and replace WS connections as they disconnect or go silent."""
-               with self._refresh_queue_rx:
-                       async for old_scope in self._refresh_queue_rx:
-                               await self._pool_event_tx.send(_ConnectionDown(trio.current_time(), old_scope))
-                               await self._spawn_reader()
-                               if not self._nursery.child_tasks:
-                                       raise RuntimeError("WS pool depleted")
-
-
-def _tag_digest(tag):
-       return int_digest(hash(tag))
-
-
-def _format_scope_list(scopes):
-       return ", ".join(sorted(map(obj_digest, scopes)))
-
-
-class PoolMerger:
-       @dataclass
-       class _Bucket:
-               start: float
-               recipients: set
-
-       def __init__(self, pool_event_rx, message_tx, parity_timeout: dt.timedelta, conn_warmup: dt.timedelta, logger: logging.Logger):
-               """
-               :param parity_timeout: max delay between message receipt on different connections
-               :param conn_warmup: max interval after connection within which missed updates are allowed
-               """
-               self._pool_event_rx = pool_event_rx
-               self._message_tx = message_tx
-               self._parity_timeout = parity_timeout
-               self._conn_warmup = conn_warmup
-               self._logger = logger
-
-               self._buckets = {}
-               self._scope_activations = {}
-               self._pending = deque()
-               self._outgoing_scopes = set()
-               (self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf)
-
-       def _log_current_buckets(self):
-               self._logger.debug("current buckets: {}".format(", ".join(map(_tag_digest, self._buckets.keys()))))
-
-       async def event_reader_impl(self):
-               """Drop unused messages, deduplicate useful ones, and communicate with the timeout handler."""
-               with self._pool_event_rx, self._timer_poke_tx:
-                       async for event in self._pool_event_rx:
-                               if isinstance(event, _ConnectionUp):
-                                       # An early add of an active scope could mean it's expected on a message that fired before it opened,
-                                       # resulting in a false positive for replacement. The timer task merges these in among the buckets by
-                                       # timestamp to avoid that.
-                                       self._pending.append(event)
-                               elif isinstance(event, _Message):
-                                       if event.data["type"] in ["update", "strike", "delete"]:
-                                               tag = event.dedup_tag()
-                                               self._logger.debug("recv {} from {}".format(_tag_digest(tag), obj_digest(event.scope)))
-                                               if tag in self._buckets:
-                                                       b = self._buckets[tag]
-                                                       b.recipients.add(event.scope)
-                                                       # If this scope is the last one for this bucket we could clear the bucket here, but since
-                                                       # connections sometimes get repeat messages and second copies can arrive before the first
-                                                       # copy has arrived on all connections, leaving the bucket open to absorb repeats can reduce
-                                                       # the likelihood of a second bucket being allocated late for the second copy of a message
-                                                       # and causing unnecessary connection replacement.
-                                               elif event.scope not in self._outgoing_scopes:
-                                                       sane = (
-                                                               event.scope in self._scope_activations
-                                                               or any(e.scope == event.scope for e in self._pending)
-                                                       )
-                                                       if sane:
-                                                               self._logger.debug("new bucket {}: {}".format(_tag_digest(tag), tag))
-                                                               self._buckets[tag] = self._Bucket(event.timestamp, {event.scope})
-                                                               self._log_current_buckets()
-                                                               await self._message_tx.send(event)
-                                                       else:
-                                                               raise RuntimeError("recieved message from unrecognized WS connection")
-                                       else:
-                                               self._logger.debug("recv type {!r} from {} (discarding)".format(event.data["type"], obj_digest(event.scope)))
-                               elif isinstance(event, _ConnectionDown):
-                                       # We don't need to worry about canceling this scope at all, so no need to require it for parity for
-                                       # any message, even older ones.  The scope may be gone already, if we canceled it previously.
-                                       self._scope_activations.pop(event.scope, None)
-                                       self._pending = deque(ev for ev in self._pending if ev.scope is not event.scope)
-                                       self._outgoing_scopes.discard(event.scope)
-                               else:
-                                       raise TypeError(f"Expected pool event, found {event!r}")
-                               self._timer_poke_tx.send_nowait(None)  # may have new work for the timer
-
-       async def timeout_handler_impl(self):
-               """When connections have had enough time to reach parity on a message, signal replacement of any slackers."""
-               with self._timer_poke_rx:
-                       while True:
-                               if self._buckets:
-                                       now = trio.current_time()
-                                       (tag, bucket) = min(self._buckets.items(), key = lambda i: i[1].start)
-
-                                       # Make sure the right scope set is ready for the next bucket up
-                                       while self._pending and self._pending[0].timestamp < bucket.start:
-                                               ev = self._pending.popleft()
-                                               self._scope_activations[ev.scope] = ev.timestamp
-
-                                       target_scopes = {
-                                               scope for (scope, active) in self._scope_activations.items()
-                                               if active + self._conn_warmup.total_seconds() < now
-                                       }
-                                       if now > bucket.start + self._parity_timeout.total_seconds():
-                                               filled = target_scopes & bucket.recipients
-                                               missing = target_scopes - bucket.recipients
-                                               extra = bucket.recipients - target_scopes
-
-                                               self._logger.debug("expiring bucket {}".format(_tag_digest(tag)))
-                                               if filled:
-                                                       self._logger.debug("  filled {}: {}".format(len(filled), _format_scope_list(filled)))
-                                               if missing:
-                                                       self._logger.debug("  missing {}: {}".format(len(missing), _format_scope_list(missing)))
-                                               if extra:
-                                                       self._logger.debug("  extra {}: {}".format(len(extra), _format_scope_list(extra)))
-
-                                               for scope in missing:
-                                                       self._outgoing_scopes.add(scope)
-                                                       scope.cancel()
-                                               del self._buckets[tag]
-                                               self._log_current_buckets()
-                                       else:
-                                               await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds())
-                               else:
-                                       try:
-                                               await self._timer_poke_rx.receive()
-                                       except trio.EndOfChannel:
-                                               break
diff --git a/src/strikebot/queue.py b/src/strikebot/queue.py
deleted file mode 100644 (file)
index adfbbe6..0000000
+++ /dev/null
@@ -1,58 +0,0 @@
-"""Unbounded blocking queues for Trio"""
-
-from collections import deque
-from dataclasses import dataclass
-from functools import total_ordering
-from typing import Any, Iterable
-import heapq
-
-from trio.lowlevel import ParkingLot
-
-
-class Queue:
-       def __init__(self):
-               self._deque = deque()
-               self._empty_wait = ParkingLot()
-
-       def push(self, el: Any) -> None:
-               self._deque.append(el)
-               self._empty_wait.unpark()
-
-       def extend(self, els: Iterable[Any]) -> None:
-               for el in els:
-                       self.push(el)
-
-       async def pop(self) -> Any:
-               if not self._deque:
-                       await self._empty_wait.park()
-               return self._deque.popleft()
-
-       def __len__(self):
-               return len(self._deque)
-
-
-@total_ordering
-@dataclass
-class _ReverseOrdWrapper:
-       inner: Any
-
-       def __lt__(self, other):
-               return other.inner < self.inner
-
-
-class MaxHeap:
-       def __init__(self):
-               self._heap = []
-               self._empty_wait = ParkingLot()
-
-       def push(self, item):
-               heapq.heappush(self._heap, _ReverseOrdWrapper(item))
-               self._empty_wait.unpark()
-
-       async def pop(self):
-               if not self._heap:
-                       await self._empty_wait.park()
-               return heapq.heappop(self._heap).inner
-
-       def __len__(self) -> int:
-               return len(self._heap)
diff --git a/src/strikebot/reddit_api.py b/src/strikebot/reddit_api.py
deleted file mode 100644 (file)
index 3e69e6c..0000000
+++ /dev/null
@@ -1,291 +0,0 @@
-"""Multi-auth HTTP connection pool with Reddit API wrappers and rate limiting."""
-
-from abc import ABCMeta, abstractmethod
-from dataclasses import dataclass, field
-from functools import total_ordering
-from socket import EAI_AGAIN, EAI_FAIL, gaierror
-import datetime as dt
-import logging
-
-from asks.response_objects import Response
-import asks
-import trio
-
-from strikebot import __version__ as VERSION
-from strikebot.common import obj_digest
-from strikebot.queue import MaxHeap, Queue
-
-
-REQUEST_DELAY = dt.timedelta(seconds = 1)
-
-TOKEN_UPDATE_DELAY = dt.timedelta(minutes = 15)
-
-USER_AGENT_FMT = f"any:net.jcornell.strikebot.{{auth_id}}:v{VERSION} (by /u/jaxklax)"
-
-API_BASE_URL = "https://oauth.reddit.com"
-
-
-@total_ordering
-class _Request(metaclass = ABCMeta):
-       _SUBTYPE_PRECEDENCE = None  # assigned later
-
-       def __lt__(self, other):
-               if type(self) is type(other):
-                       return self._subtype_cmp_key() < other._subtype_cmp_key()
-               else:
-                       prec = self._SUBTYPE_PRECEDENCE
-                       return prec.index(type(other)) < prec.index(type(self))
-
-       def __eq__(self, other):
-               if type(self) is type(other):
-                       return self._subtype_cmp_key() == other._subtype_cmp_key()
-               else:
-                       return False
-
-       @abstractmethod
-       def _subtype_cmp_key(self):
-               # a large key corresponds to a high priority
-               raise NotImplementedError()
-
-       @abstractmethod
-       def to_asks_kwargs(self):
-               raise NotImplementedError()
-
-
-@dataclass(eq = False)
-class StrikeRequest(_Request):
-       thread_id: str
-       update_name: str
-       update_ts: dt.datetime
-
-       def to_asks_kwargs(self):
-               return {
-                       "method": "POST",
-                       "path": f"/api/live/{self.thread_id}/strike_update",
-                       "data": {
-                               "api_type": "json",
-                               "id": self.update_name,
-                       },
-               }
-
-       def _subtype_cmp_key(self):
-               return self.update_ts
-
-
-@dataclass(eq = False)
-class DeleteRequest(_Request):
-       thread_id: str
-       update_name: str
-       update_ts: dt.datetime
-
-       def to_asks_kwargs(self):
-               return {
-                       "method": "POST",
-                       "path": f"/api/live/{self.thread_id}/delete_update",
-                       "data": {
-                               "api_type": "json",
-                               "id": self.update_name,
-                       },
-               }
-
-       def _subtype_cmp_key(self):
-               return -self.update_ts
-
-
-@dataclass(eq = False)
-class AboutLiveThreadRequest(_Request):
-       thread_id: str
-       created: float = field(default_factory = trio.current_time)
-
-       def to_asks_kwargs(self):
-               return {
-                       "method": "GET",
-                       "path": f"/live/{self.thread_id}/about",
-                       "params": {
-                               "raw_json": "1",
-                       },
-               }
-
-       def _subtype_cmp_key(self):
-               return -self.created
-
-
-@dataclass(eq = False)
-class _BaseUpdatePostRequest(_Request, metaclass = ABCMeta):
-       thread_id: str
-       body: str
-       created: float = field(default_factory = trio.current_time)
-
-       def to_asks_kwargs(self):
-               return {
-                       "method": "POST",
-                       "path": f"/live/{self.thread_id}/update",
-                       "data": {
-                               "api_type": "json",
-                               "body": self.body,
-                       },
-               }
-
-
-@dataclass(eq = False)
-class ReportUpdateRequest(_BaseUpdatePostRequest):
-       def _subtype_cmp_key(self):
-               return -self.created
-
-
-@dataclass(eq = False)
-class CorrectionUpdateRequest(_BaseUpdatePostRequest):
-       def _subtype_cmp_key(self):
-               return -self.created
-
-
-# highest priority first
-_Request._SUBTYPE_PRECEDENCE = [
-       AboutLiveThreadRequest,
-       StrikeRequest,
-       CorrectionUpdateRequest,
-       ReportUpdateRequest,
-       DeleteRequest,
-]
-
-
-@dataclass
-class AppCooldown:
-       auth_id: int
-       ready_at: float  # Trio time
-
-
-class ApiClientPool:
-       def __init__(
-               self,
-               auth_ids: set[int],
-               db_messenger: "strikebot.db.Messenger",
-               request_queue_limit: int,
-               error_window: dt.timedelta,
-               error_delay: dt.timedelta,
-               logger: logging.Logger,
-       ):
-               self._auth_ids = auth_ids
-               self._db_messenger = db_messenger
-               self._request_queue_limit = request_queue_limit
-               self._error_window = error_window
-               self._error_delay = error_delay
-               self._logger = logger
-
-               now = trio.current_time()
-               self._tokens = {}
-               self._app_queue = Queue()
-               self._waiting = {id_: AppCooldown(id_, now) for id_ in auth_ids}
-               self._request_queue = MaxHeap()
-
-               # pool-wide API error backoff
-               self._last_error = None
-               self._global_resume = None
-
-               self._session = asks.Session(connections = len(auth_ids))
-               self._session.base_location = API_BASE_URL
-
-       async def _update_tokens(self):
-               tokens = await self._db_messenger.do("get_auth_tokens", (self._auth_ids,))
-               self._tokens.update(tokens)
-
-               awaken_auths = self._waiting.keys() & tokens.keys()
-               self._app_queue.extend(self._waiting.pop(auth_id) for auth_id in awaken_auths)
-               self._logger.debug("updated API tokens")
-
-       async def token_updater_impl(self):
-               last_update = None
-               while True:
-                       if last_update is not None:
-                               await trio.sleep_until(last_update + TOKEN_UPDATE_DELAY.total_seconds())
-
-                       if last_update is None:
-                               last_update = trio.current_time()
-                       else:
-                               last_update += TOKEN_UPDATE_DELAY.total_seconds()
-
-                       await self._update_tokens()
-
-       def _check_queue_size(self) -> None:
-               if len(self._request_queue) > self._request_queue_limit:
-                       raise RuntimeError("request queue size exceeded limit")
-
-       async def make_request(self, request: _Request) -> Response:
-               (resp_tx, resp_rx) = trio.open_memory_channel(0)
-               self.enqueue_request(request, resp_tx)
-               async with resp_rx:
-                       return await resp_rx.receive()
-
-       def enqueue_request(self, request: _Request, resp_tx = None) -> None:
-               self._logger.debug("request {}: {}".format(obj_digest(request), type(request).__name__))
-               self._request_queue.push((request, resp_tx))
-               self._check_queue_size()
-               self._logger.debug(f"{len(self._request_queue)} requests in queue")
-
-       async def worker_impl(self, task_status):
-               task_status.started()
-               while True:
-                       (request, resp_tx) = await self._request_queue.pop()
-                       cooldown = await self._app_queue.pop()
-                       await trio.sleep_until(cooldown.ready_at)
-                       if self._global_resume:
-                               await trio.sleep_until(self._global_resume)
-
-                       asks_kwargs = request.to_asks_kwargs()
-                       headers = asks_kwargs.setdefault("headers", {})
-                       headers.update({
-                               "Authorization": "Bearer {}".format(self._tokens[cooldown.auth_id]),
-                               "User-Agent": USER_AGENT_FMT.format(auth_id = cooldown.auth_id),
-                       })
-
-                       request_time = trio.current_time()
-                       try:
-                               resp = await self._session.request(**asks_kwargs)
-                       except gaierror as e:
-                               if e.errno in [EAI_FAIL, EAI_AGAIN]:
-                                       # DNS failure, probably temporary
-                                       error = True
-                               else:
-                                       raise
-                       else:
-                               resp.body  # read response
-                               error = False
-                               wait_for_token = False
-                               log_suffix = " (request {})".format(obj_digest(request))
-                               if resp.status_code == 429:
-                                       # We disagreed about the rate limit state; just try again later.
-                                       self._logger.warning("rate limited by Reddit API" + log_suffix)
-                                       error = True
-                               elif resp.status_code == 401:
-                                       self._logger.warning("got HTTP 401 from Reddit API" + log_suffix)
-                                       error = True
-                                       wait_for_token = True
-                               elif resp.status_code in [404, 500, 503]:
-                                       self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying" + log_suffix)
-                                       error = True
-                               elif 400 <= resp.status_code < 500:
-                                       # If we're doing something wrong, let's catch it right away.
-                                       raise RuntimeError(f"unexpected client error response: {resp.status_code}" + log_suffix)
-                               else:
-                                       if resp.status_code != 200:
-                                               raise RuntimeError(f"unexpected status code {resp.status_code}")
-                                       self._logger.debug("success" + log_suffix)
-                                       if resp_tx:
-                                               await resp_tx.send(resp)
-
-                       if error:
-                               self._request_queue.push((request, resp_tx))
-                               self._check_queue_size()
-                               if self._last_error:
-                                       spread = dt.timedelta(seconds = request_time - self._last_error)
-                                       if spread <= self._error_window:
-                                               self._global_resume = request_time + self._error_delay.total_seconds()
-                               self._last_error = request_time
-
-                       cooldown.ready_at = request_time + REQUEST_DELAY.total_seconds()
-                       if wait_for_token:
-                               self._waiting[cooldown.auth_id] = cooldown
-                               if not self._app_queue:
-                                       self._logger.error("all workers waiting for API tokens")
-                       else:
-                               self._app_queue.push(cooldown)
diff --git a/src/strikebot/tests.py b/src/strikebot/tests.py
deleted file mode 100644 (file)
index 7bbf226..0000000
+++ /dev/null
@@ -1,50 +0,0 @@
-from unittest import TestCase
-
-from strikebot.updates import parse_update
-
-
-def _build_payload(body_html: str) -> dict[str, str]:
-       return {"body_html": body_html}
-
-
-class UpdateParsingTests(TestCase):
-       def test_successful_counts(self):
-               pu = parse_update(_build_payload("<div>12,345,678 spaghetti</div>"), None, "")
-               self.assertEqual(pu.number, 12345678)
-               self.assertFalse(pu.deletable)
-
-               pu = parse_update(_build_payload("<div><p>0</p><br/>oz</div>"), None, "")
-               self.assertEqual(pu.number, 0)
-               self.assertFalse(pu.deletable)
-
-               pu = parse_update(_build_payload("<div>121 345 621</div>"), 121, "")
-               self.assertEqual(pu.number, 121)
-               self.assertFalse(pu.deletable)
-
-               pu = parse_update(_build_payload("28 336 816"), 28336816, "")
-               self.assertEqual(pu.number, 28336816)
-               self.assertTrue(pu.deletable)
-
-       def test_non_counts(self):
-               pu = parse_update(_build_payload("<div><table><tbody><tr><td>zoo</td></tr></tbody></table></div>"), None, "")
-               self.assertFalse(pu.count_attempt)
-               self.assertFalse(pu.deletable)
-
-       def test_typos(self):
-               pu = parse_update(_build_payload("<span>v9</span>"), 888, "")
-               self.assertIsNone(pu.number)
-               self.assertTrue(pu.count_attempt)
-
-               pu = parse_update(_build_payload("<div>v11.585 Empire</div>"), None, "")
-               self.assertIsNone(pu.number)
-               self.assertTrue(pu.count_attempt)
-               self.assertFalse(pu.deletable)
-
-               pu = parse_update(_build_payload("<div>11, 585, 22 </div>"), 11_585_202, "")
-               self.assertIsNone(pu.number)
-               self.assertTrue(pu.count_attempt)
-               self.assertTrue(pu.deletable)
-
-               pu = parse_update(_build_payload("<span>0490499</span>"), 4999, "")
-               self.assertIsNone(pu.number)
-               self.assertTrue(pu.count_attempt)
diff --git a/src/strikebot/updates.py b/src/strikebot/updates.py
deleted file mode 100644 (file)
index e3f58cd..0000000
+++ /dev/null
@@ -1,216 +0,0 @@
-from __future__ import annotations
-from dataclasses import dataclass
-from enum import Enum
-from typing import Optional
-import re
-
-from bs4 import BeautifulSoup
-
-
-Command = Enum("Command", ["RESET", "REPORT"])
-
-
-@dataclass
-class ParsedUpdate:
-       number: Optional[int]
-       command: Optional[Command]
-       count_attempt: bool  # either well-formed or typo
-       deletable: bool
-
-
-def _parse_command(line: str, bot_user: str) -> Optional[Command]:
-       if line.lower() == f"/u/{bot_user} reset".lower():
-               return Command.RESET
-       elif line.lower() in ["sidebar count", "current count"]:
-               return Command.REPORT
-       else:
-               return None
-
-
-def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
-       # curr_count is the next number up, one more than the last count
-
-       NEW_LINE = object()
-       SPACE = object()
-
-       # flatten the update content to plain text
-       tree = BeautifulSoup(payload_data["body_html"], "html.parser")
-       worklist = tree.contents
-       out = [[]]
-       while worklist:
-               el = worklist.pop()
-               if isinstance(el, str):
-                       out[-1].append(el)
-               elif el is SPACE:
-                       out[-1].append(el)
-               elif el is NEW_LINE or el.name == "br":
-                       if out[-1]:
-                               out.append([])
-               elif el.name in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]:
-                       worklist.extend(reversed(el.contents))
-               elif el.name in ["ul", "ol", "table", "thead", "tbody"]:
-                       worklist.extend(reversed(el.contents))
-               elif el.name in ["li", "p", "div", "h1", "h2", "blockquote"]:
-                       worklist.append(NEW_LINE)
-                       worklist.extend(reversed(el.contents))
-                       worklist.append(NEW_LINE)
-               elif el.name == "pre":
-                       worklist.append(NEW_LINE)
-                       worklist.extend([l] for l in reversed(el.text.splitlines()))
-                       worklist.append(NEW_LINE)
-               elif el.name == "tr":
-                       worklist.append(NEW_LINE)
-                       for (i, cell) in enumerate(reversed(el.contents)):
-                               worklist.append(cell)
-                               if i != len(el.contents) - 1:
-                                       worklist.append(SPACE)
-                       worklist.append(NEW_LINE)
-               else:
-                       raise RuntimeError(f"can't parse tag {el.tag}")
-
-       tmp_lines = (
-               "".join(" " if part is SPACE else part for part in parts).strip()
-               for parts in out
-       )
-       pre_strip_lines = list(filter(None, tmp_lines))
-
-       # normalize whitespace according to HTML rendering rules
-       # https://developer.mozilla.org/en-US/docs/Web/API/Document_Object_Model/Whitespace#explanation
-       stripped_lines = [
-               re.sub(" +", " ", l.replace("\t", " ").replace("\n", " ")).strip(" ")
-               for l in pre_strip_lines
-       ]
-
-       return _parse_from_lines(stripped_lines, curr_count, bot_user)
-
-
-def _parse_from_lines(lines: list[str], curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
-       command = next(
-               filter(None, (_parse_command(l, bot_user) for l in lines)),
-               None
-       )
-       if lines:
-               # look for groups of digits (as many as possible) separated by a uniform separator from the valid set
-               first = lines[0]
-               match = re.match(
-                       "(?P<v>v)?(?P<neg>-)?(?P<num>\\d+((?P<sep>[,. \u2009]|, )\\d+((?P=sep)\\d+)*)?)",
-                       first,
-                       re.ASCII,  # only recognize ASCII digits
-               )
-               if match:
-                       raw_digits = match["num"]
-                       sep = match["sep"]
-                       post = first[match.end() :]
-
-                       zeros = False
-                       while len(raw_digits) > 1 and raw_digits[0] == "0":
-                               zeros = True
-                               raw_digits = raw_digits.removeprefix("0").removeprefix(sep or "")
-
-                       parts = raw_digits.split(sep) if sep else [raw_digits]
-                       lone = len(lines) == 1 and (not post or post.isspace())
-                       typo = False
-                       if lone:
-                               all_parts_valid = (
-                                       sep is None
-                                       or (
-                                               1 <= len(parts[0]) <= 3
-                                               and all(len(p) == 3 for p in parts[1:])
-                                       )
-                               )
-                               if match["v"] and len(parts) == 1 and len(parts[0]) <= 2:
-                                       # failed paste of leading digits
-                                       typo = True
-                               elif match["v"] and all_parts_valid:
-                                       # v followed by count
-                                       typo = True
-                               elif curr_count is not None and abs(curr_count) >= 100 and bool(match["neg"]) == (curr_count < 0):
-                                       goal_parts = _separate(str(abs(curr_count)))
-                                       partials = [
-                                               goal_parts[: -1] + [goal_parts[-1][: -1]],  # missing last digit
-                                               goal_parts[: -1] + [goal_parts[-1][: -2]],  # missing last two digits
-                                               goal_parts[: -1] + [goal_parts[-1][: -2] + goal_parts[-1][-1]],  # missing second-last digit
-                                       ]
-                                       if parts in partials:
-                                               # missing any of last two digits
-                                               typo = True
-                                       elif parts in [p[: -1] + [p[-1] + goal_parts[0]] + goal_parts[1 :] for p in partials]:
-                                               # double paste
-                                               typo = True
-
-                       if match["v"] or zeros or typo or (parts == ["0"] and match["neg"]):
-                               number = None
-                               count_attempt = True
-                               deletable = lone
-                       else:
-                               if curr_count is not None and sep and sep.isspace():
-                                       # Presume that the intended count consists of as many valid digit groups as necessary to match the
-                                       # number of digits in the expected count, if possible.
-                                       digit_count = len(str(abs(curr_count)))
-                                       use_parts = []
-                                       accum = 0
-                                       for (i, part) in enumerate(parts):
-                                               part_valid = len(part) <= 3 if i == 0 else len(part) == 3
-                                               if part_valid and accum < digit_count:
-                                                       use_parts.append(part)
-                                                       accum += len(part)
-                                               else:
-                                                       break
-
-                                       # could still be a no-separator count with some extra digit groups on the same line
-                                       if not use_parts:
-                                               use_parts = [parts[0]]
-
-                                       lone = lone and len(use_parts) == len(parts)
-                               else:
-                                       # current count is unknown or no separator was used
-                                       use_parts = parts
-
-                               digits = "".join(use_parts)
-                               number = -int(digits) if match["neg"] else int(digits)
-                               special = (
-                                       curr_count is not None
-                                       and abs(number - curr_count) <= 25
-                                       and _is_special_number(number)
-                               )
-                               deletable = lone and not special
-                               if len(use_parts) == len(parts) and post and not post[0].isspace():
-                                       count_attempt = curr_count is not None and abs(number - curr_count) <= 25
-                                       number = None
-                               else:
-                                       count_attempt = True
-               else:
-                       # no count attempt found
-                       number = None
-                       count_attempt = False
-                       deletable = False
-       else:
-               # no lines in update
-               number = None
-               count_attempt = False
-               deletable = True
-
-       return ParsedUpdate(
-               number = number,
-               command = command,
-               count_attempt = count_attempt,
-               deletable = deletable,
-       )
-
-
-def _separate(digits: str) -> list[str]:
-       mod = len(digits) % 3
-       out = []
-       if mod:
-               out.append(digits[: mod])
-       out.extend(digits[i : i + 3] for i in range(mod, len(digits), 3))
-       return out
-
-
-def _is_special_number(num: int) -> bool:
-       num_str = str(num)
-       return bool(
-               num % 1000 in [0, 1, 333, 999]
-               or (num > 10_000_000 and "".join(reversed(num_str)) == num_str)
-               or re.match(r"(.+)\1+$", num_str)  # repeated sequence
-       )
diff --git a/strikebot/docs/sample_config.ini b/strikebot/docs/sample_config.ini
new file mode 100644 (file)
index 0000000..28e510e
--- /dev/null
@@ -0,0 +1,53 @@
+# see Python `configparser' docs for precise file syntax
+
+[config]
+
+######## basic operating parameters
+
+# space-separated authorization IDs for Reddit API token lookup
+auth IDs = 13 15
+
+bot user = count_better
+enforcing = true
+thread ID = abc123
+
+
+######## API pool error handling
+
+# (seconds) for error backoff, pause the API pool for this much time
+API pool error delay = 5.5
+
+# (seconds) pause the pool when the Reddit API returns two errors within any time span of this length
+API pool error window = 5.5
+
+# terminate if the Reddit API request queue exceeds this size
+request queue limit = 100
+
+
+######## live thread update handling
+
+# (seconds) maximum time to hold updates for reordering
+reorder buffer time = 0.25
+
+# (seconds) minimum time to retain updates to enable future resyncs
+update retention time = 120.0
+
+
+######## WebSocket pool
+
+# (seconds) maximum allowable spread in WebSocket message arrival times
+WS parity time = 1.0
+
+# goal size of WebSocket connection pool
+WS pool size = 5
+
+# (seconds) after receiving no messages for this much time, attempt to replace the connection if the pool isn't depleted
+WS silent limit = 30
+
+# (seconds) time after WebSocket handshake completion during which missed updates are excused
+WS warmup time = 0.3
+
+
+# Postgres database configuration; same options as in a connect string
+[db connect params]
+host = example.org
diff --git a/strikebot/pyproject.toml b/strikebot/pyproject.toml
new file mode 100644 (file)
index 0000000..8fe2f47
--- /dev/null
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools>=42", "wheel"]
+build-backend = "setuptools.build_meta"
diff --git a/strikebot/setup.cfg b/strikebot/setup.cfg
new file mode 100644 (file)
index 0000000..f200179
--- /dev/null
@@ -0,0 +1,15 @@
+[metadata]
+name = strikebot
+version = 0.0.0
+
+[options]
+package_dir =
+       = src
+packages = strikebot
+python_requires = ~= 3.9
+install_requires =
+       asks ~= 2.4
+       beautifulsoup4 ~= 4.9
+       trio == 0.19
+       trio-websocket == 0.9.2
+       triopg == 0.6.0
diff --git a/strikebot/setup.py b/strikebot/setup.py
new file mode 100644 (file)
index 0000000..056ba45
--- /dev/null
@@ -0,0 +1,4 @@
+import setuptools
+
+
+setuptools.setup()
diff --git a/strikebot/src/strikebot/__init__.py b/strikebot/src/strikebot/__init__.py
new file mode 100644 (file)
index 0000000..7c37d14
--- /dev/null
@@ -0,0 +1,321 @@
+from contextlib import nullcontext as nullcontext
+from dataclasses import dataclass
+from functools import total_ordering
+from itertools import islice
+from typing import Optional
+from uuid import UUID
+import bisect
+import datetime as dt
+import importlib.metadata
+import itertools
+import logging
+
+from trio import CancelScope, EndOfChannel
+import trio
+
+from strikebot.updates import Command, parse_update
+
+
+__version__ = importlib.metadata.version(__package__)
+
+
+_READ_ONLY: bool = False  # suppress any API requests that modify the thread?
+
+
+@dataclass
+class _Update:
+       """A live thread update along with the results of parsing its body."""
+
+       id: str  # update ID as given in the message payload; parsed as a UUID to derive `ts'
+       name: str  # update name as given in the message payload; strike messages refer to updates by this
+       author: str
+       number: Optional[int]  # parsed count, if one was recognized
+       command: Optional[Command]  # parsed command, if any
+       ts: dt.datetime
+       count_attempt: bool
+       deletable: bool
+       stricken: bool
+
+       def can_follow(self, prior: "_Update") -> bool:
+               """Check that this update's count is one more than the prior update's and comes from a different author."""
+               if self.number != prior.number + 1:
+                       return False
+               return self.author != prior.author
+
+       def __str__(self):
+               # short form for log messages: a slice of the ID, then count (plus command name, if any) and author
+               label = str(self.number)
+               if self.command:
+                       label = f"{label}+{self.command.name}"
+               return f"{self.id[4 : 8]}({label} by {self.author})"
+
+
+@total_ordering
+@dataclass
+class _BufferedUpdate:
+       """An update waiting in the reorder buffer; instances sort by the timestamp of the wrapped update."""
+
+       update: _Update
+       release_at: float  # Trio time
+
+       def __lt__(self, other):
+               own_ts = self.update.ts
+               other_ts = other.update.ts
+               return own_ts < other_ts
+
+
+@total_ordering
+@dataclass
+class _TimelineUpdate:
+       """An update placed on the timeline together with its current verdict; instances sort by update timestamp."""
+
+       update: _Update
+       accepted: bool
+
+       def __lt__(self, other):
+               own_ts = self.update.ts
+               other_ts = other.update.ts
+               return own_ts < other_ts
+
+       def rejected(self) -> bool:
+               """Whether the update should be stricken."""
+               if self.accepted:
+                       return False
+               return self.update.count_attempt
+
+
+def _parse_rfc_4122_ts(ts: int) -> dt.datetime:
+       """
+       Convert an RFC 4122 (UUID) timestamp to a timezone-aware UTC datetime.
+
+       ts counts tenths of a microsecond elapsed since 1582-10-15 00:00 UTC.  The result has microsecond resolution; the
+       sub-microsecond remainder is truncated.
+       """
+       epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc)
+       # Integer division keeps the conversion exact.  Present-day values of ts / 10 exceed 2**53, so float division
+       # can't represent every microsecond and may map distinct timestamps to the same datetime.
+       return epoch + dt.timedelta(microseconds = ts // 10)
+
+
+def _format_update_ref(update: _Update, thread_id: str) -> str:
+       """
+       Build a Markdown link to an update in the given live thread, labeled with the update's count and author.
+
+       An update without a parsed number (for example, a reset command posted without a count) is labeled "update"
+       instead of a count, since None can't be formatted with a thousands separator.
+       """
+       url = f"https://www.reddit.com/live/{thread_id}/updates/{update.id}"
+       author_md = update.author.replace("_", "\\_")  # Reddit allows letters, digits, underscore, and hyphen
+       label = "update" if update.number is None else f"{update.number:,}"
+       return f"[{label} by {author_md}]({url})"
+
+
+def _format_bad_strike_alert(update: _Update, thread_id: str) -> str:
+       """Compose the notice saying that the referenced update was stricken incorrectly."""
+       ref = _format_update_ref(update, thread_id)
+       return f"{ref} was incorrectly stricken."
+
+
+def _format_curr_count(last_valid: Optional[_Update]) -> str:
+       """Render the current count (one more than the last valid count) as Markdown, or "unknown" if there isn't one."""
+       prefix = "_current count:_ "
+       if not last_valid or last_valid.number is None:
+               return prefix + "unknown"
+       return prefix + f"{last_valid.number + 1:,}"
+
+
+async def count_tracker_impl(
+       message_rx,
+       api_pool: "strikebot.reddit_api.ApiClientPool",
+       reorder_buffer_time: dt.timedelta,  # duration to hold some updates for reordering
+       thread_id: str,
+       bot_user: str,
+       enforcing: bool,
+       update_retention: dt.timedelta,
+       logger: logging.Logger,
+) -> None:
+       """
+       Track the count in a live thread from a stream of thread messages, returning when the stream ends.
+
+       "update" messages received from message_rx are parsed and held in a reorder buffer, then placed on a timeline
+       ordered by update timestamp, where each is judged accepted or rejected.  Updates authored by bot_user are ignored.
+       "strike" messages mark the matching update as stricken, or are remembered until that update arrives.
+
+       Correction and report updates are enqueued on api_pool unless _READ_ONLY is set.  Strike and delete requests are
+       enqueued only if enforcing is true and _READ_ONLY is unset.
+
+       Timeline entries older than update_retention are forgotten, except that the ten most recent entries and the last
+       valid update are always kept.  Buffered updates older than update_retention are discarded without processing.
+       """
+
+       # imported inside the function rather than at module level (presumably to avoid a circular import; TODO confirm)
+       from strikebot.reddit_api import CorrectionUpdateRequest, DeleteRequest, ReportUpdateRequest, StrikeRequest
+
+       enforcing = enforcing and not _READ_ONLY
+
+       buffer_ = []  # _BufferedUpdate instances, kept sorted by update timestamp
+       timeline = []  # _TimelineUpdate instances, kept sorted by update timestamp
+       last_valid: Optional[_Update] = None
+       pending_strikes: set[str] = set()  # names of updates to mark stricken on arrival
+       forgotten: bool = False  # whether the retention period for an update has passed during the run
+
+       def handle_update(update):
+               """Place an update on the timeline, judge it, and re-judge the updates that follow it if it was accepted."""
+               nonlocal last_valid
+
+               tu = _TimelineUpdate(update, accepted = False)
+
+               pos = bisect.bisect(timeline, tu)
+               if pos != len(timeline):
+                       # the update is older than something already on the timeline
+                       logger.warning(f"long transpo: {update.id}")
+
+               if pos > 0 and timeline[pos - 1].update.id == update.id:
+                       # The pool sent this update message multiple times.  This could be because the message reached parity and
+                       # its bucket was dropped just before a new connection was opened and sent its own copy, but connections also
+                       # seem to send duplicate messages.
+                       return
+
+               # nearest accepted update preceding the insertion point, if any
+               pred = next(
+                       (timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted),
+                       None
+               )
+               contender = update.command is Command.RESET or update.number is not None
+               if contender and pred is None and last_valid and forgotten:
+                       # This is a really long transpo which we can't judge valid or invalid since we've forgotten the surrounding
+                       # updates. The best we can do is just ignore it.
+                       logger.warning(f"ignoring {update.id}: no valid prior count on record")
+               else:
+                       timeline.insert(pos, tu)
+                       # A reset is always accepted.  A count is accepted if there's no numbered predecessor to check it against
+                       # or if it properly follows the predecessor.
+                       tu.accepted = (
+                               update.command is Command.RESET
+                               or (
+                                       update.number is not None
+                                       and (pred is None or pred.number is None or update.can_follow(pred))
+                               )
+                       )
+                       logger.debug("{} {} -> {}".format("accepted" if tu.accepted else "rejected", pred, update))
+                       if tu.accepted:
+                               # resync subsequent updates
+                               newly_valid = []
+                               newly_invalid = []
+                               resync_last_valid = update
+                               converged = False
+                               for scan_tu in islice(timeline, pos + 1, None):
+                                       if scan_tu.update.command is Command.RESET:
+                                               resync_last_valid = scan_tu.update
+                                               if last_valid:
+                                                       converged = True
+                                       elif scan_tu.update.number is not None:
+                                               accept = resync_last_valid.number is None or scan_tu.update.can_follow(resync_last_valid)
+                                               if accept and scan_tu.accepted:
+                                                       # resync would have no effect past this point
+                                                       if last_valid:
+                                                               converged = True
+                                               elif accept:
+                                                       newly_valid.append(scan_tu)
+                                                       resync_last_valid = scan_tu.update
+                                               elif scan_tu.accepted:
+                                                       newly_invalid.append(scan_tu)
+                                                       if scan_tu.update is last_valid:
+                                                               last_valid = None
+                                               scan_tu.accepted = accept
+                                       if converged:
+                                               break
+
+                               if converged:
+                                       assert last_valid.ts >= resync_last_valid.ts
+                               else:
+                                       last_valid = resync_last_valid
+
+                               # assemble a single correction update covering everything that changed
+                               parts = []
+                               if newly_valid:
+                                       parts.append(
+                                               "The following counts are valid:\n\n"
+                                               + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in newly_valid)
+                                       )
+
+                               # newly invalid counts that are already stricken can only be announced, not stricken again
+                               unstrikable = [tu for tu in newly_invalid if tu.update.stricken]
+                               if unstrikable:
+                                       parts.append(
+                                               "The following counts are invalid:\n\n"
+                                               + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in unstrikable)
+                                       )
+
+                               if update.stricken:
+                                       logger.info(f"bad strike of {update.id}")
+                                       parts.append(_format_bad_strike_alert(update, thread_id))
+
+                               if parts:
+                                       parts.append(_format_curr_count(last_valid))
+                                       if not _READ_ONLY:
+                                               api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts)))
+
+                               for invalid_tu in newly_invalid:
+                                       if not invalid_tu.update.stricken:
+                                               if enforcing:
+                                                       api_pool.enqueue_request(StrikeRequest(thread_id, invalid_tu.update.name, invalid_tu.update.ts))
+                                               invalid_tu.update.stricken = True
+                       elif update.count_attempt:
+                               if enforcing:
+                                       api_pool.enqueue_request(StrikeRequest(thread_id, update.name, update.ts))
+                               update.stricken = True
+
+               if not _READ_ONLY and update.command is Command.REPORT:
+                       api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
+
+       with message_rx:
+               while True:
+                       if buffer_:
+                               # stop waiting for a message once the earliest buffered update is due for release
+                               deadline = min(bu.release_at for bu in buffer_)
+                               cancel_scope = CancelScope(deadline = deadline)
+                       else:
+                               cancel_scope = nullcontext()
+
+                       msg = None  # stays None if the wait is cut short by the deadline
+                       with cancel_scope:
+                               try:
+                                       msg = await message_rx.receive()
+                               except EndOfChannel:
+                                       break
+
+                       if msg:
+                               if msg.data["type"] == "update":
+                                       payload_data = msg.data["payload"]["data"]
+                                       stricken = payload_data["name"] in pending_strikes
+                                       pending_strikes.discard(payload_data["name"])
+                                       if payload_data["author"] != bot_user:
+                                               if last_valid and last_valid.number is not None:
+                                                       next_up = last_valid.number + 1
+                                               else:
+                                                       next_up = None
+                                               pu = parse_update(payload_data, next_up, bot_user)
+                                               rfc_ts = UUID(payload_data["id"]).time
+                                               update = _Update(
+                                                       id = payload_data["id"],
+                                                       name = payload_data["name"],
+                                                       author = payload_data["author"],
+                                                       number = pu.number,
+                                                       command = pu.command,
+                                                       ts = _parse_rfc_4122_ts(rfc_ts),
+                                                       count_attempt = pu.count_attempt,
+                                                       deletable = pu.deletable,
+                                                       stricken = stricken or payload_data["stricken"],
+                                               )
+                                               release_at = msg.timestamp + reorder_buffer_time.total_seconds()
+                                               bisect.insort(buffer_, _BufferedUpdate(update, release_at))
+                               elif msg.data["type"] == "strike":
+                                       # look for the stricken update in the buffer first, then on the timeline from newest to oldest
+                                       slot = next(
+                                               (
+                                                       slot for slot in itertools.chain(buffer_, reversed(timeline))
+                                                       if slot.update.name == msg.data["payload"]
+                                               ),
+                                               None
+                                       )
+                                       if slot:
+                                               if not slot.update.stricken:
+                                                       slot.update.stricken = True
+                                                       if isinstance(slot, _TimelineUpdate) and slot.accepted:
+                                                               logger.info(f"bad strike of {slot.update.id}")
+                                                               if not _READ_ONLY:
+                                                                       body = _format_bad_strike_alert(slot.update, thread_id)
+                                                                       api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body))
+                                       else:
+                                               # no such update on record yet; apply the strike when the update arrives
+                                               pending_strikes.add(msg.data["payload"])
+
+                       threshold = dt.datetime.now(dt.timezone.utc) - update_retention
+
+                       # pull updates from the reorder buffer and process
+                       new_buffer = []
+                       for bu in buffer_:
+                               process = (
+                                       # not a count
+                                       bu.update.number is None
+
+                                       # valid count
+                                       or (last_valid is None or last_valid.number is None or bu.update.can_follow(last_valid))
+
+                                       # reset command
+                                       or bu.update.command is Command.RESET
+
+                                       # hold time is up
+                                       or trio.current_time() >= bu.release_at
+                               )
+
+                               if process:
+                                       if bu.update.ts < threshold:
+                                               logger.warning(f"ignoring {bu.update}: arrived past retention window")
+                                       else:
+                                               logger.debug(f"processing {bu.update}")
+                                               handle_update(bu.update)
+                               else:
+                                       logger.debug(f"holding {bu.update}, checked against {last_valid}")
+                                       new_buffer.append(bu)
+                       buffer_ = new_buffer
+                       logger.debug("last count {}".format(last_valid))
+
+                       # delete/forget old updates
+                       new_timeline = []
+                       i = 0
+                       for (i, tu) in enumerate(timeline):
+                               # stop at the first entry that's among the last ten or still within the retention period
+                               if i >= len(timeline) - 10 or tu.update.ts >= threshold:
+                                       break
+                               elif tu.rejected() and tu.update.deletable:
+                                       if enforcing:
+                                               api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts))
+                               elif tu.update is last_valid:
+                                       # never forget the last valid update
+                                       new_timeline.append(tu)
+
+                       if i:
+                               forgotten = True
+                       new_timeline.extend(islice(timeline, i, None))
+                       timeline = new_timeline
diff --git a/strikebot/src/strikebot/__main__.py b/strikebot/src/strikebot/__main__.py
new file mode 100644 (file)
index 0000000..4c2b1d9
--- /dev/null
@@ -0,0 +1,191 @@
+from enum import Enum
+from inspect import getmodule
+from logging import FileHandler, getLogger, StreamHandler
+from pathlib import Path
+from signal import SIGUSR1
+from sys import stdout
+from traceback import StackSummary
+from typing import Optional
+import argparse
+import configparser
+import datetime as dt
+import logging
+
+from trio import open_memory_channel, open_nursery, open_signal_receiver
+from trio.lowlevel import current_root_task, Task
+import trio_asyncio
+import triopg
+
+from strikebot import count_tracker_impl
+from strikebot.db import Client, Messenger
+from strikebot.live_ws import HealingReadPool, PoolMerger
+from strikebot.reddit_api import ApiClientPool
+
+
+_DEBUG_LOG_PATH: Optional[Path] = None  # file to write debug logs to, if any
+
+
+_TaskKind = Enum(
+       "_TaskKind",
+       [
+               "API_WORKER",
+               "DB_CLIENT",
+               "TOKEN_UPDATE",
+               "TRACKER",
+               "WS_MERGE_READER",
+               "WS_MERGE_TIMER",
+               "WS_REFRESHER",
+               "WS_WORKER",
+       ]
+)
+
+
+def _classify_task(task: Task) -> _TaskKind:
+       mod_name = getmodule(task.coro.cr_code).__name__
+       if mod_name.startswith("trio_asyncio."):
+               return None
+       elif task.coro.__name__ == "_signal_handler_impl":
+               return None
+       else:
+               by_coro_name = {
+                       "db_client_impl": _TaskKind.DB_CLIENT,
+                       "token_updater_impl": _TaskKind.TOKEN_UPDATE,
+                       "event_reader_impl": _TaskKind.WS_MERGE_READER,
+                       "timeout_handler_impl": _TaskKind.WS_MERGE_TIMER,
+                       "conn_refresher_impl": _TaskKind.WS_REFRESHER,
+                       "count_tracker_impl": _TaskKind.TRACKER,
+                       "_reader_impl": _TaskKind.WS_WORKER,
+                       "worker_impl": _TaskKind.API_WORKER,
+               }
+               return by_coro_name[task.coro.__name__]
+
+
+def _get_all_tasks():
+       [ta_loop] = [
+               task
+               for nursery in current_root_task().child_nurseries
+               for task in nursery.child_tasks
+               if getmodule(task.coro.cr_code).__name__.startswith("trio_asyncio.")
+       ]
+       for nursery in ta_loop.child_nurseries:
+               yield from nursery.child_tasks
+
+
+def _print_task_info():
+       groups = {kind: [] for kind in _TaskKind}
+       for task in _get_all_tasks():
+               kind = _classify_task(task)
+               if kind:
+                       coro = task.coro
+                       frame_tuples = []
+                       while coro is not None:
+                               if hasattr(coro, "cr_frame"):
+                                       frame_tuples.append((coro.cr_frame, coro.cr_frame.f_lineno))
+                                       coro = coro.cr_await
+                               else:
+                                       frame_tuples.append((coro.gi_frame, coro.gi_frame.f_lineno))
+                                       coro = coro.gi_yieldfrom
+                       groups[kind].append(StackSummary.extract(iter(frame_tuples)))
+       for kind in _TaskKind:
+               print(kind.name)
+               for ss in groups[kind]:
+                       print("  task")
+                       for line in ss.format():
+                               print("    " + line, end = "")
+
+
+async def _signal_handler_impl():
+       with open_signal_receiver(SIGUSR1) as sig_src:
+               async for _ in sig_src:
+                       _print_task_info()
+
+
+# Command line: a single positional argument giving the path of the INI-style configuration file.
+ap = argparse.ArgumentParser(__package__)
+ap.add_argument("config_path")
+args = ap.parse_args()
+
+
+parser = configparser.ConfigParser()
+with open(args.config_path) as config_file:
+       parser.read_file(config_file)
+
+# General settings.  Durations are written in the config file as (possibly fractional) numbers of seconds.
+main_cfg = parser["config"]
+
+api_pool_error_delay = dt.timedelta(seconds = main_cfg.getfloat("API pool error delay"))
+api_pool_error_window = dt.timedelta(seconds = main_cfg.getfloat("API pool error window"))
+auth_ids = set(map(int, main_cfg["auth IDs"].split()))  # whitespace-separated integers
+bot_user = main_cfg["bot user"]
+enforcing = main_cfg.getboolean("enforcing")
+reorder_buffer_time = dt.timedelta(seconds = main_cfg.getfloat("reorder buffer time"))
+request_queue_limit = main_cfg.getint("request queue limit")
+thread_id = main_cfg["thread ID"]
+update_retention = dt.timedelta(seconds = main_cfg.getfloat("update retention time"))
+ws_parity_time = dt.timedelta(seconds = main_cfg.getfloat("WS parity time"))
+ws_pool_size = main_cfg.getint("WS pool size")
+ws_silent_limit = dt.timedelta(seconds = main_cfg.getfloat("WS silent limit"))
+ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup seconds"))
+
+# Every key of this section becomes a keyword argument for the DB connection.  Values stay strings unless a converter
+# is listed for the key in `getters`.
+db_cfg = parser["db connect params"]
+getters = {
+       "port": db_cfg.getint,
+}
+db_connect_params = {k: getters.get(k, db_cfg.get)(k) for k in db_cfg}
+
+
+# The package logger itself lets everything through; each handler below applies its own threshold.
+logger = getLogger(__package__)
+logger.setLevel(logging.DEBUG)
+
+# Warnings and above go to standard output.
+handler = StreamHandler(stdout)
+handler.setLevel(logging.WARNING)
+handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{"))
+logger.addHandler(handler)
+
+# Optionally capture the full debug log in a file, which is truncated on startup (mode "w").
+if _DEBUG_LOG_PATH:
+       debug_handler = FileHandler(_DEBUG_LOG_PATH, mode = "w")
+       debug_handler.setLevel(logging.DEBUG)
+       debug_handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{"))
+       logger.addHandler(debug_handler)
+
+
+async def main():
+       """
+       Connect to the database, start every long-running task and run until the nurseries finish.
+
+       The context managers below are exited in reverse order, so each nursery waits for its own tasks before the
+       nurseries opened ahead of it (and finally the DB connection) are torn down.
+       """
+       async with (
+               triopg.connect(**db_connect_params) as db_conn,
+               open_nursery() as nursery_a,
+               open_nursery() as nursery_b,
+               open_nursery() as ws_pool_nursery,
+               open_nursery() as nursery_c,
+       ):
+               nursery_a.start_soon(_signal_handler_impl)
+
+               # DB access goes through a single client task fed by the messenger
+               client = Client(db_conn)
+               db_messenger = Messenger(client)
+               nursery_a.start_soon(db_messenger.db_client_impl)
+
+               api_pool = ApiClientPool(
+                       auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay,
+                       logger.getChild("api")
+               )
+               nursery_a.start_soon(api_pool.token_updater_impl)
+               # one API worker task per configured auth ID; start() waits for each worker to report that it has started
+               for _ in auth_ids:
+                       await nursery_a.start(api_pool.worker_impl)
+
+               # Zero-capacity channels: WS pool -> merger (pool events) and merger -> tracker (deduplicated messages)
+               (message_tx, message_rx) = open_memory_channel(0)
+               (pool_event_tx, pool_event_rx) = open_memory_channel(0)
+
+               merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger.getChild("merge"))
+               nursery_b.start_soon(merger.event_reader_impl)
+               nursery_b.start_soon(merger.timeout_handler_impl)
+
+               # The pool spawns its reader tasks into the dedicated nursery passed in here
+               ws_pool = HealingReadPool(
+                       ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger.getChild("live_ws"),
+                       ws_silent_limit
+               )
+               nursery_c.start_soon(ws_pool.conn_refresher_impl)
+
+               nursery_c.start_soon(
+                       count_tracker_impl, message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing,
+                       update_retention, logger.getChild("track")
+               )
+               # Open the initial WS connections only once every consumer above has been scheduled
+               await ws_pool.init_workers()
+
+trio_asyncio.run(main)
diff --git a/strikebot/src/strikebot/common.py b/strikebot/src/strikebot/common.py
new file mode 100644 (file)
index 0000000..519037a
--- /dev/null
@@ -0,0 +1,10 @@
+from typing import Any
+
+
+def int_digest(val: int) -> str:
+       return "{:04x}".format(val & 0xffff)
+
+
+def obj_digest(obj: Any) -> str:
+       """Makes a short digest of the identity of an object."""
+       return int_digest(id(obj))
diff --git a/strikebot/src/strikebot/db.py b/strikebot/src/strikebot/db.py
new file mode 100644 (file)
index 0000000..c9a060a
--- /dev/null
@@ -0,0 +1,58 @@
+"""Single-connection Postgres client with high-level wrappers for operations."""
+
+from dataclasses import dataclass
+from typing import Any, Iterable
+
+import trio
+
+
+def _channel_sender(method):
+       """
+       Decorate an async method so that its result is delivered over a channel instead of being returned.
+
+       The wrapped method takes a response channel as an extra first argument (after ``self``).  Once the original method
+       completes, its return value is sent on that channel and the channel is closed; the wrapper itself returns ``None``.
+       If the original method raises, the exception propagates and nothing is sent.
+       """
+       async def wrapped(self, resp_channel, *args, **kwargs):
+               ret = await method(self, *args, **kwargs)
+               # closing the channel after the single send marks the response as complete
+               with resp_channel:
+                       await resp_channel.send(ret)
+
+       return wrapped
+
+
+@dataclass
+class Client:
+       """High-level wrappers for DB operations."""
+       conn: Any
+
+       @_channel_sender
+       async def get_auth_tokens(self, auth_ids: set[int]) -> dict[int, str]:
+               results = await self.conn.fetch(
+                       """
+                               select
+                                       distinct on (id)
+                                       id, access_token
+                               from
+                                       public.reddit_app_authorization
+                                       join unnest($1::integer[]) as request_id on id = request_id
+                               where expires > current_timestamp
+                       """,
+                       list(auth_ids),
+               )
+               return {r["id"]: r["access_token"] for r in results}
+
+
+class Messenger:
+       def __init__(self, client: Client):
+               self._client = client
+               (self._request_tx, self._request_rx) = trio.open_memory_channel(0)
+
+       async def do(self, method_name: str, args: Iterable) -> Any:
+               """This is run by consumers of the DB wrapper."""
+               method = getattr(self._client, method_name)
+               (resp_tx, resp_rx) = trio.open_memory_channel(0)
+               coro = method(resp_tx, *args)
+               await self._request_tx.send(coro)
+               async with resp_rx:
+                       return await resp_rx.receive()
+
+       async def db_client_impl(self):
+               """This is the DB client Trio task."""
+               async with self._client.conn, self._request_rx:
+                       async for coro in self._request_rx:
+                               await coro
diff --git a/strikebot/src/strikebot/live_ws.py b/strikebot/src/strikebot/live_ws.py
new file mode 100644 (file)
index 0000000..cb8b3d0
--- /dev/null
@@ -0,0 +1,266 @@
+from collections import deque
+from contextlib import suppress
+from dataclasses import dataclass
+from typing import Any, AsyncContextManager, Optional
+import datetime as dt
+import json
+import logging
+import math
+
+from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
+import trio
+
+from strikebot.common import int_digest, obj_digest
+from strikebot.reddit_api import AboutLiveThreadRequest
+
+
+@dataclass
+class _PoolEvent:
+       """Base class for the events the WS connection pool reports to the merger."""
+       timestamp: float  # Trio time
+
+
+@dataclass
+class _Message(_PoolEvent):
+       data: Any
+       scope: Any
+
+       def dedup_tag(self):
+               """A hashable essence of the contents of this message."""
+               if self.data["type"] == "update":
+                       return ("update", self.data["payload"]["data"]["id"])
+               elif self.data["type"] in {"strike", "delete"}:
+                       return (self.data["type"], self.data["payload"])
+               else:
+                       raise ValueError(self.data["type"])
+
+
+@dataclass
+class _ConnectionDown(_PoolEvent):
+       """Reports that the reader identified by ``scope`` has stopped and is being replaced."""
+       scope: trio.CancelScope
+
+
+@dataclass
+class _ConnectionUp(_PoolEvent):
+       """Announces a new reader, identified by ``scope``, just before its task is started."""
+       scope: trio.CancelScope
+
+
+class HealingReadPool:
+       """
+       Pool of WebSocket read connections to one live thread which replaces connections that close, go silent or get
+       canceled.
+
+       Reader tasks run in the nursery supplied by the caller.  Everything they receive is forwarded on ``pool_event_tx``
+       as pool events tagged with the cancel scope of the originating connection.
+       """
+
+       def __init__(
+               self, pool_nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger,
+               silent_limit: dt.timedelta
+       ):
+               """
+               :param pool_nursery: nursery in which reader tasks are started; its task count is used as the pool's size
+               :param size: target number of concurrent connections (at least 2)
+               :param live_thread_id: ID of the live thread to read
+               :param api_client_pool: used to request the thread's current WebSocket URL
+               :param pool_event_tx: channel on which connection and message events are sent
+               :param silent_limit: how long a connection may deliver nothing before it's considered for replacement
+               """
+               assert size >= 2
+               self._nursery = pool_nursery
+               self._size = size
+               self._live_thread_id = live_thread_id
+               self._api_client_pool = api_client_pool
+               self._pool_event_tx = pool_event_tx
+               self._logger = logger
+               self._silent_limit = silent_limit
+
+               # cancel scopes of stopped readers, consumed by the refresher task; readers send without blocking
+               (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size)
+
+               # number of workers who have stopped receiving updates but whose tasks aren't yet stopped
+               self._closing = 0
+
+       async def init_workers(self):
+               """Try to start the initial ``size`` readers; raise ``RuntimeError`` if not a single one was started."""
+               for _ in range(self._size):
+                       await self._spawn_reader()
+               if not self._nursery.child_tasks:
+                       raise RuntimeError("Unable to create any WS connections")
+
+       async def _reader_impl(self, conn_ctx: AsyncContextManager, cancel_scope, refresh_tx):
+               """
+               Reader task: forward messages from one connection until it closes, is canceled or goes silent, then request
+               a replacement by sending ``cancel_scope`` on ``refresh_tx``.
+
+               A handshake failure while opening the connection is logged and ends the task without requesting a replacement.
+               """
+               # TODO could use task's implicit cancel scope
+               try:
+                       async with conn_ctx as conn:
+                               self._logger.debug("scope up: {}".format(obj_digest(cancel_scope)))
+                               with refresh_tx:
+                                       silent_timeout = False
+                                       # This block ends in one of three ways: the silent-timeout `break`, cancellation of
+                                       # `cancel_scope` from outside, or ConnectionClosed (suppressed) when the connection closes.
+                                       with cancel_scope, suppress(ConnectionClosed):
+                                               while True:
+                                                       with trio.move_on_after(self._silent_limit.total_seconds()) as timeout_scope:
+                                                               message = await conn.get_message()
+
+                                                       if timeout_scope.cancelled_caught:
+                                                               # Only give up on a quiet connection while the pool is at its target size;
+                                                               # otherwise keep listening on it.
+                                                               if len(self._nursery.child_tasks) - self._closing == self._size:
+                                                                       silent_timeout = True
+                                                                       break
+                                                               else:
+                                                                       self._logger.debug("not replacing connection {}; {} tasks, {} closing".format(
+                                                                               obj_digest(cancel_scope),
+                                                                               len(self._nursery.child_tasks),
+                                                                               self._closing,
+                                                                       ))
+                                                       else:
+                                                               event = _Message(trio.current_time(), json.loads(message), cancel_scope)
+                                                               await self._pool_event_tx.send(event)
+
+                                       # From here until the decrement below this task still exists but no longer reads updates.
+                                       self._closing += 1
+                                       if silent_timeout:
+                                               await conn.aclose()
+                                               self._logger.debug("replacing WS connection {} due to silent timeout".format(obj_digest(cancel_scope)))
+                                       elif cancel_scope.cancelled_caught:
+                                               await conn.aclose(1008, "Server unexpectedly stopped sending messages")
+                                               self._logger.warning("replacing WS connection {} due to missed update".format(obj_digest(cancel_scope)))
+                                       else:
+                                               self._logger.warning("replacing WS connection {} closed by server".format(obj_digest(cancel_scope)))
+
+                                       refresh_tx.send_nowait(cancel_scope)
+                                       self._logger.debug("scope down: {} ({} tasks, {} closing)".format(
+                                               obj_digest(cancel_scope),
+                                               len(self._nursery.child_tasks),
+                                               self._closing,
+                                       ))
+                                       self._closing -= 1
+               except HandshakeError:
+                       self._logger.error("handshake error while opening WS connection")
+
+       async def _new_connection(self) -> Optional[AsyncContextManager[WebSocketConnection]]:
+               """
+               Request the thread's current WebSocket URL through the API pool and return a connection context manager for
+               it, or ``None`` (after logging an error) if the request doesn't succeed with HTTP 200.
+               """
+               request = AboutLiveThreadRequest(self._live_thread_id)
+               resp = await self._api_client_pool.make_request(request)
+               if resp.status_code == 200:
+                       url = json.loads(resp.text)["data"]["websocket_url"]
+                       return open_websocket_url(url)
+               else:
+                       self._logger.error(f"thread info request failed (HTTP {resp.status_code}); can't open WS connection")
+                       return None
+
+       async def _spawn_reader(self) -> None:
+               """
+               Start one reader task on a new connection, announcing its cancel scope with a ``_ConnectionUp`` event first.
+               Does nothing if no connection could be obtained.
+               """
+               conn = await self._new_connection()
+               if conn:
+                       new_scope = trio.CancelScope()
+                       await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope))
+                       # each reader gets (and closes) its own clone of the sending end of the refresh queue
+                       refresh_tx = self._refresh_queue_tx.clone()
+                       self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx)
+
+       async def conn_refresher_impl(self):
+               """Task to monitor and replace WS connections as they disconnect or go silent."""
+               with self._refresh_queue_rx:
+                       async for old_scope in self._refresh_queue_rx:
+                               await self._pool_event_tx.send(_ConnectionDown(trio.current_time(), old_scope))
+                               await self._spawn_reader()
+                               if not self._nursery.child_tasks:
+                                       raise RuntimeError("WS pool depleted")
+
+
+def _tag_digest(tag):
+       return int_digest(hash(tag))
+
+
+def _format_scope_list(scopes):
+       return ", ".join(sorted(map(obj_digest, scopes)))
+
+
+class PoolMerger:
+       """
+       Merges the event streams of all pooled WS connections into one deduplicated message stream, and cancels the scope
+       of any connection that fails to deliver a message the others received.
+
+       Two tasks cooperate on shared state: ``event_reader_impl`` consumes pool events and ``timeout_handler_impl``
+       expires message buckets.
+       """
+
+       @dataclass
+       class _Bucket:
+               start: float  # timestamp of the first copy of the message
+               recipients: set  # scopes of the connections that have delivered the message so far
+
+       def __init__(self, pool_event_rx, message_tx, parity_timeout: dt.timedelta, conn_warmup: dt.timedelta, logger: logging.Logger):
+               """
+               :param pool_event_rx: channel of pool events coming from the connection pool
+               :param message_tx: channel on which the first copy of each message is forwarded
+               :param parity_timeout: max delay between message receipt on different connections
+               :param conn_warmup: max interval after connection within which missed updates are allowed
+               """
+               self._pool_event_rx = pool_event_rx
+               self._message_tx = message_tx
+               self._parity_timeout = parity_timeout
+               self._conn_warmup = conn_warmup
+               self._logger = logger
+
+               self._buckets = {}  # dedup tag -> _Bucket, for messages not yet expired by the timer
+               self._scope_activations = {}  # scope -> timestamp of its _ConnectionUp, once merged in by the timer
+               self._pending = deque()  # _ConnectionUp events not yet merged into _scope_activations
+               self._outgoing_scopes = set()  # scopes canceled by the timer whose _ConnectionDown hasn't been seen yet
+               # unbounded so that the event reader can always poke the timer without blocking
+               (self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf)
+
+       def _log_current_buckets(self):
+               self._logger.debug("current buckets: {}".format(", ".join(map(_tag_digest, self._buckets.keys()))))
+
+       async def event_reader_impl(self):
+               """Drop unused messages, deduplicate useful ones, and communicate with the timeout handler."""
+               with self._pool_event_rx, self._timer_poke_tx:
+                       async for event in self._pool_event_rx:
+                               if isinstance(event, _ConnectionUp):
+                                       # An early add of an active scope could mean it's expected on a message that fired before it opened,
+                                       # resulting in a false positive for replacement. The timer task merges these in among the buckets by
+                                       # timestamp to avoid that.
+                                       self._pending.append(event)
+                               elif isinstance(event, _Message):
+                                       if event.data["type"] in ["update", "strike", "delete"]:
+                                               tag = event.dedup_tag()
+                                               self._logger.debug("recv {} from {}".format(_tag_digest(tag), obj_digest(event.scope)))
+                                               if tag in self._buckets:
+                                                       b = self._buckets[tag]
+                                                       b.recipients.add(event.scope)
+                                                       # If this scope is the last one for this bucket we could clear the bucket here, but since
+                                                       # connections sometimes get repeat messages and second copies can arrive before the first
+                                                       # copy has arrived on all connections, leaving the bucket open to absorb repeats can reduce
+                                                       # the likelihood of a second bucket being allocated late for the second copy of a message
+                                                       # and causing unnecessary connection replacement.
+                                               elif event.scope not in self._outgoing_scopes:
+                                                       # First copy of this message: it must come from a connection that was announced earlier.
+                                                       sane = (
+                                                               event.scope in self._scope_activations
+                                                               or any(e.scope == event.scope for e in self._pending)
+                                                       )
+                                                       if sane:
+                                                               self._logger.debug("new bucket {}: {}".format(_tag_digest(tag), tag))
+                                                               self._buckets[tag] = self._Bucket(event.timestamp, {event.scope})
+                                                               self._log_current_buckets()
+                                                               await self._message_tx.send(event)
+                                                       else:
+                                                               raise RuntimeError("recieved message from unrecognized WS connection")
+                                       else:
+                                               self._logger.debug("recv type {!r} from {} (discarding)".format(event.data["type"], obj_digest(event.scope)))
+                               elif isinstance(event, _ConnectionDown):
+                                       # We don't need to worry about canceling this scope at all, so no need to require it for parity for
+                                       # any message, even older ones.  The scope may be gone already, if we canceled it previously.
+                                       self._scope_activations.pop(event.scope, None)
+                                       self._pending = deque(ev for ev in self._pending if ev.scope is not event.scope)
+                                       self._outgoing_scopes.discard(event.scope)
+                               else:
+                                       raise TypeError(f"Expected pool event, found {event!r}")
+                               self._timer_poke_tx.send_nowait(None)  # may have new work for the timer
+
+       async def timeout_handler_impl(self):
+               """When connections have had enough time to reach parity on a message, signal replacement of any slackers."""
+               with self._timer_poke_rx:
+                       while True:
+                               if self._buckets:
+                                       now = trio.current_time()
+                                       # always work on the oldest open bucket
+                                       (tag, bucket) = min(self._buckets.items(), key = lambda i: i[1].start)
+
+                                       # Make sure the right scope set is ready for the next bucket up
+                                       while self._pending and self._pending[0].timestamp < bucket.start:
+                                               ev = self._pending.popleft()
+                                               self._scope_activations[ev.scope] = ev.timestamp
+
+                                       # connections still within their warmup interval aren't held to parity
+                                       target_scopes = {
+                                               scope for (scope, active) in self._scope_activations.items()
+                                               if active + self._conn_warmup.total_seconds() < now
+                                       }
+                                       if now > bucket.start + self._parity_timeout.total_seconds():
+                                               filled = target_scopes & bucket.recipients
+                                               missing = target_scopes - bucket.recipients
+                                               extra = bucket.recipients - target_scopes
+
+                                               self._logger.debug("expiring bucket {}".format(_tag_digest(tag)))
+                                               if filled:
+                                                       self._logger.debug("  filled {}: {}".format(len(filled), _format_scope_list(filled)))
+                                               if missing:
+                                                       self._logger.debug("  missing {}: {}".format(len(missing), _format_scope_list(missing)))
+                                               if extra:
+                                                       self._logger.debug("  extra {}: {}".format(len(extra), _format_scope_list(extra)))
+
+                                               # Canceling a connection's scope is what signals that it should be replaced.
+                                               for scope in missing:
+                                                       self._outgoing_scopes.add(scope)
+                                                       scope.cancel()
+                                               del self._buckets[tag]
+                                               self._log_current_buckets()
+                                       else:
+                                               await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds())
+                               else:
+                                       # Nothing to time out; wait for the event reader to report activity.  The channel ends when
+                                       # the event reader exits.
+                                       try:
+                                               await self._timer_poke_rx.receive()
+                                       except trio.EndOfChannel:
+                                               break
diff --git a/strikebot/src/strikebot/queue.py b/strikebot/src/strikebot/queue.py
new file mode 100644 (file)
index 0000000..adfbbe6
--- /dev/null
@@ -0,0 +1,58 @@
+"""Unbounded blocking queues for Trio"""
+
+from collections import deque
+from dataclasses import dataclass
+from functools import total_ordering
+from typing import Any, Iterable
+import heapq
+
+from trio.lowlevel import ParkingLot
+
+
+class Queue:
+       def __init__(self):
+               self._deque = deque()
+               self._empty_wait = ParkingLot()
+
+       def push(self, el: Any) -> None:
+               self._deque.append(el)
+               self._empty_wait.unpark()
+
+       def extend(self, els: Iterable[Any]) -> None:
+               for el in els:
+                       self.push(el)
+
+       async def pop(self) -> Any:
+               if not self._deque:
+                       await self._empty_wait.park()
+               return self._deque.popleft()
+
+       def __len__(self):
+               return len(self._deque)
+
+
+@total_ordering
+@dataclass
+class _ReverseOrdWrapper:
+       """Wraps a value so that it sorts in the opposite order, letting the min-heap in ``heapq`` act as a max-heap."""
+       inner: Any
+
+       def __lt__(self, other):
+               # deliberately swapped: this wrapper is "less" when its inner value is greater
+               return other.inner < self.inner
+
+
+class MaxHeap:
+       def __init__(self):
+               self._heap = []
+               self._empty_wait = ParkingLot()
+
+       def push(self, item):
+               heapq.heappush(self._heap, _ReverseOrdWrapper(item))
+               self._empty_wait.unpark()
+
+       async def pop(self):
+               if not self._heap:
+                       await self._empty_wait.park()
+               return heapq.heappop(self._heap).inner
+
+       def __len__(self) -> int:
+               return len(self._heap)
diff --git a/strikebot/src/strikebot/reddit_api.py b/strikebot/src/strikebot/reddit_api.py
new file mode 100644 (file)
index 0000000..3e69e6c
--- /dev/null
@@ -0,0 +1,291 @@
+"""Multi-auth HTTP connection pool with Reddit API wrappers and rate limiting."""
+
+from abc import ABCMeta, abstractmethod
+from dataclasses import dataclass, field
+from functools import total_ordering
+from socket import EAI_AGAIN, EAI_FAIL, gaierror
+import datetime as dt
+import logging
+
+from asks.response_objects import Response
+import asks
+import trio
+
+from strikebot import __version__ as VERSION
+from strikebot.common import obj_digest
+from strikebot.queue import MaxHeap, Queue
+
+
+# cooldown applied to an app after each request it sends (see `ApiClientPool.worker_impl`)
+REQUEST_DELAY = dt.timedelta(seconds = 1)
+
+# interval between refreshes of the API tokens (see `ApiClientPool.token_updater_impl`)
+TOKEN_UPDATE_DELAY = dt.timedelta(minutes = 15)
+
+# The version is interpolated here; `{auth_id}` is left as a placeholder to be filled in per request with
+# `str.format`.
+USER_AGENT_FMT = f"any:net.jcornell.strikebot.{{auth_id}}:v{VERSION} (by /u/jaxklax)"
+
+# base location of the `asks` session; the request classes below give paths relative to it
+API_BASE_URL = "https://oauth.reddit.com"
+
+
+@total_ordering
+class _Request(metaclass = ABCMeta):
+       """
+       Base class for queued API requests, ordered by priority (a greater request is more urgent).
+
+       Requests of different types are ranked by their types' positions in `_SUBTYPE_PRECEDENCE`; requests of
+       the same type are ranked by `_subtype_cmp_key`.
+       """
+
+       _SUBTYPE_PRECEDENCE = None  # assigned later
+
+       def __lt__(self, other):
+               if type(self) is not type(other):
+                       # an earlier position in the precedence list means a higher priority
+                       ranking = self._SUBTYPE_PRECEDENCE
+                       return ranking.index(type(other)) < ranking.index(type(self))
+               return self._subtype_cmp_key() < other._subtype_cmp_key()
+
+       def __eq__(self, other):
+               same_type = type(self) is type(other)
+               return same_type and self._subtype_cmp_key() == other._subtype_cmp_key()
+
+       @abstractmethod
+       def _subtype_cmp_key(self):
+               """Return the sort key among requests of the same type; a large key corresponds to a high priority."""
+               raise NotImplementedError()
+
+       @abstractmethod
+       def to_asks_kwargs(self):
+               """Return the keyword arguments that describe this request to the `asks` session."""
+               raise NotImplementedError()
+
+
+@dataclass(eq = False)
+class StrikeRequest(_Request):
+       """POST to a live thread's `strike_update` endpoint; among these, a later `update_ts` ranks higher."""
+
+       thread_id: str
+       update_name: str
+       update_ts: dt.datetime
+
+       def _subtype_cmp_key(self):
+               return self.update_ts
+
+       def to_asks_kwargs(self):
+               form = {
+                       "api_type": "json",
+                       "id": self.update_name,
+               }
+               return {
+                       "method": "POST",
+                       "path": f"/api/live/{self.thread_id}/strike_update",
+                       "data": form,
+               }
+
+
+@dataclass(eq = False)
+class DeleteRequest(_Request):
+       """POST to a live thread's `delete_update` endpoint; among these, an earlier `update_ts` ranks higher."""
+
+       thread_id: str
+       update_name: str
+       update_ts: dt.datetime
+
+       def to_asks_kwargs(self):
+               return {
+                       "method": "POST",
+                       "path": f"/api/live/{self.thread_id}/delete_update",
+                       "data": {
+                               "api_type": "json",
+                               "id": self.update_name,
+                       },
+               }
+
+       def _subtype_cmp_key(self):
+               # `datetime` objects don't support unary minus (it raises TypeError), so negate the POSIX timestamp
+               # instead.  Values that can be negated directly are still accepted as they are.
+               ts = self.update_ts
+               if isinstance(ts, dt.datetime):
+                       ts = ts.timestamp()
+               return -ts
+
+
+@dataclass(eq = False)
+class AboutLiveThreadRequest(_Request):
+       """GET of a live thread's `about` endpoint; among these, an earlier `created` ranks higher."""
+
+       thread_id: str
+       created: float = field(default_factory = trio.current_time)
+
+       def _subtype_cmp_key(self):
+               return -self.created
+
+       def to_asks_kwargs(self):
+               query = {
+                       "raw_json": "1",
+               }
+               return {
+                       "method": "GET",
+                       "path": f"/live/{self.thread_id}/about",
+                       "params": query,
+               }
+
+
+@dataclass(eq = False)
+class _BaseUpdatePostRequest(_Request, metaclass = ABCMeta):
+       """Common part of the requests that POST `body` to a live thread's `update` endpoint."""
+
+       thread_id: str
+       body: str
+       created: float = field(default_factory = trio.current_time)
+
+       def to_asks_kwargs(self):
+               form = {
+                       "api_type": "json",
+                       "body": self.body,
+               }
+               return {
+                       "method": "POST",
+                       "path": f"/live/{self.thread_id}/update",
+                       "data": form,
+               }
+
+
+# Update-posting request type.  Among requests of this type, the one created earliest has the highest priority.
+@dataclass(eq = False)
+class ReportUpdateRequest(_BaseUpdatePostRequest):
+       def _subtype_cmp_key(self):
+               # negated so that a smaller (earlier) creation time gives a larger key
+               return -self.created
+
+
+# Update-posting request type.  Among requests of this type, the one created earliest has the highest priority.
+@dataclass(eq = False)
+class CorrectionUpdateRequest(_BaseUpdatePostRequest):
+       def _subtype_cmp_key(self):
+               # negated so that a smaller (earlier) creation time gives a larger key
+               return -self.created
+
+
+# Highest priority first.  `_Request.__lt__` compares positions in this list to order requests of different
+# types, so every concrete request type must be listed (a missing type makes `list.index` raise ValueError).
+_Request._SUBTYPE_PRECEDENCE = [
+       AboutLiveThreadRequest,
+       StrikeRequest,
+       CorrectionUpdateRequest,
+       ReportUpdateRequest,
+       DeleteRequest,
+]
+
+
+# Records when the app identified by `auth_id` may next be used: `ApiClientPool.worker_impl` sleeps until
+# `ready_at` before sending a request with that app's token.
+@dataclass
+class AppCooldown:
+       auth_id: int
+       ready_at: float  # Trio time
+
+
+class ApiClientPool:
+       """
+       Sends queued requests to the Reddit API in priority order, using a pool of authorizations ("apps").
+
+       Each app waits `REQUEST_DELAY` between its requests.  Requests that fail with a retryable error are put
+       back in the queue, and when two such failures happen no more than `error_window` apart, all workers
+       hold off until `error_delay` after the later one.
+       """
+
+       def __init__(
+               self,
+               auth_ids: set[int],
+               db_messenger: "strikebot.db.Messenger",
+               request_queue_limit: int,
+               error_window: dt.timedelta,
+               error_delay: dt.timedelta,
+               logger: logging.Logger,
+       ):
+               self._auth_ids = auth_ids
+               self._db_messenger = db_messenger
+               self._request_queue_limit = request_queue_limit
+               self._error_window = error_window
+               self._error_delay = error_delay
+               self._logger = logger
+
+               now = trio.current_time()
+               self._tokens = {}
+               self._app_queue = Queue()  # apps that have a token and can take a request
+               # apps without a usable token; all apps start here until the first token update
+               self._waiting = {id_: AppCooldown(id_, now) for id_ in auth_ids}
+               self._request_queue = MaxHeap()
+
+               # pool-wide API error backoff
+               self._last_error = None
+               self._global_resume = None
+
+               self._session = asks.Session(connections = len(auth_ids))
+               self._session.base_location = API_BASE_URL
+
+       async def _update_tokens(self):
+               """Fetch tokens through the DB messenger and release the waiting apps that got one."""
+               tokens = await self._db_messenger.do("get_auth_tokens", (self._auth_ids,))
+               self._tokens.update(tokens)
+
+               awaken_auths = self._waiting.keys() & tokens.keys()
+               self._app_queue.extend(self._waiting.pop(auth_id) for auth_id in awaken_auths)
+               self._logger.debug("updated API tokens")
+
+       async def token_updater_impl(self):
+               """Update the tokens immediately and then every `TOKEN_UPDATE_DELAY`, forever."""
+               last_update = None
+               while True:
+                       if last_update is not None:
+                               await trio.sleep_until(last_update + TOKEN_UPDATE_DELAY.total_seconds())
+
+                       # Advance by the nominal delay rather than re-reading the clock, so the time spent updating
+                       # doesn't accumulate as drift.
+                       if last_update is None:
+                               last_update = trio.current_time()
+                       else:
+                               last_update += TOKEN_UPDATE_DELAY.total_seconds()
+
+                       await self._update_tokens()
+
+       def _check_queue_size(self) -> None:
+               """Raise RuntimeError if the request queue holds more than `request_queue_limit` requests."""
+               if len(self._request_queue) > self._request_queue_limit:
+                       raise RuntimeError("request queue size exceeded limit")
+
+       async def make_request(self, request: _Request) -> Response:
+               """Queue `request` and wait for the response to its first successful attempt."""
+               (resp_tx, resp_rx) = trio.open_memory_channel(0)
+               self.enqueue_request(request, resp_tx)
+               async with resp_rx:
+                       return await resp_rx.receive()
+
+       def enqueue_request(self, request: _Request, resp_tx = None) -> None:
+               """
+               Queue `request` without waiting for it to be sent.
+
+               If `resp_tx` is given, the successful response is sent to it.  Raises RuntimeError if the queue
+               exceeds its size limit as a result.
+               """
+               self._logger.debug("request {}: {}".format(obj_digest(request), type(request).__name__))
+               self._request_queue.push((request, resp_tx))
+               self._check_queue_size()
+               self._logger.debug(f"{len(self._request_queue)} requests in queue")
+
+       async def worker_impl(self, task_status):
+               """
+               Worker loop: repeatedly take the most urgent request and the next available app and send the request.
+
+               Calls `task_status.started()` before entering the loop.  Raises RuntimeError on responses that aren't
+               expected (unhandled 4xx codes, or a non-error status other than 200) and re-raises name resolution
+               errors that don't look temporary.
+               """
+               task_status.started()
+               while True:
+                       (request, resp_tx) = await self._request_queue.pop()
+                       cooldown = await self._app_queue.pop()
+                       await trio.sleep_until(cooldown.ready_at)
+                       # compare with None: a Trio timestamp of 0.0 is a valid time
+                       if self._global_resume is not None:
+                               await trio.sleep_until(self._global_resume)
+
+                       asks_kwargs = request.to_asks_kwargs()
+                       headers = asks_kwargs.setdefault("headers", {})
+                       headers.update({
+                               "Authorization": "Bearer {}".format(self._tokens[cooldown.auth_id]),
+                               "User-Agent": USER_AGENT_FMT.format(auth_id = cooldown.auth_id),
+                       })
+
+                       # Reset both flags for every attempt.  The DNS failure path below sets only `error`, so without
+                       # this `wait_for_token` would be unbound there (or left over from an earlier attempt).
+                       error = False
+                       wait_for_token = False
+                       log_suffix = " (request {})".format(obj_digest(request))
+
+                       request_time = trio.current_time()
+                       try:
+                               resp = await self._session.request(**asks_kwargs)
+                       except gaierror as e:
+                               if e.errno in [EAI_FAIL, EAI_AGAIN]:
+                                       # DNS failure, probably temporary
+                                       self._logger.warning("DNS lookup failed, retrying" + log_suffix)
+                                       error = True
+                               else:
+                                       raise
+                       else:
+                               resp.body  # read response
+                               if resp.status_code == 429:
+                                       # We disagreed about the rate limit state; just try again later.
+                                       self._logger.warning("rate limited by Reddit API" + log_suffix)
+                                       error = True
+                               elif resp.status_code == 401:
+                                       self._logger.warning("got HTTP 401 from Reddit API" + log_suffix)
+                                       error = True
+                                       wait_for_token = True
+                               elif resp.status_code in [404, 500, 503]:
+                                       self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying" + log_suffix)
+                                       error = True
+                               elif 400 <= resp.status_code < 500:
+                                       # If we're doing something wrong, let's catch it right away.
+                                       raise RuntimeError(f"unexpected client error response: {resp.status_code}" + log_suffix)
+                               else:
+                                       if resp.status_code != 200:
+                                               raise RuntimeError(f"unexpected status code {resp.status_code}")
+                                       self._logger.debug("success" + log_suffix)
+                                       if resp_tx:
+                                               await resp_tx.send(resp)
+
+                       if error:
+                               # retry later, and back off pool-wide if errors are coming in quick succession
+                               self._request_queue.push((request, resp_tx))
+                               self._check_queue_size()
+                               if self._last_error is not None:
+                                       spread = dt.timedelta(seconds = request_time - self._last_error)
+                                       if spread <= self._error_window:
+                                               self._global_resume = request_time + self._error_delay.total_seconds()
+                               self._last_error = request_time
+
+                       cooldown.ready_at = request_time + REQUEST_DELAY.total_seconds()
+                       if wait_for_token:
+                               # keep the app out of rotation until a token update brings it back
+                               self._waiting[cooldown.auth_id] = cooldown
+                               if not self._app_queue:
+                                       self._logger.error("all workers waiting for API tokens")
+                       else:
+                               self._app_queue.push(cooldown)
diff --git a/strikebot/src/strikebot/tests.py b/strikebot/src/strikebot/tests.py
new file mode 100644 (file)
index 0000000..7bbf226
--- /dev/null
@@ -0,0 +1,50 @@
+from unittest import TestCase
+
+from strikebot.updates import parse_update
+
+
+def _build_payload(body_html: str) -> dict[str, str]:
+       """Wrap `body_html` in a payload dict of the form read by `parse_update`."""
+       return dict(body_html = body_html)
+
+
+class UpdateParsingTests(TestCase):
+       """Checks `parse_update` on well-formed counts, non-counts and typos."""
+
+       @staticmethod
+       def _parse(body_html, curr_count):
+               # every case here uses an empty bot user name
+               return parse_update(_build_payload(body_html), curr_count, "")
+
+       def test_successful_counts(self):
+               # count followed by other text on the same line
+               result = self._parse("<div>12,345,678 spaghetti</div>", None)
+               self.assertEqual(result.number, 12345678)
+               self.assertFalse(result.deletable)
+
+               # count followed by another line
+               result = self._parse("<div><p>0</p><br/>oz</div>", None)
+               self.assertEqual(result.number, 0)
+               self.assertFalse(result.deletable)
+
+               # extra space-separated digit groups after the expected count
+               result = self._parse("<div>121 345 621</div>", 121)
+               self.assertEqual(result.number, 121)
+               self.assertFalse(result.deletable)
+
+               # lone count with space separators
+               result = self._parse("28 336 816", 28336816)
+               self.assertEqual(result.number, 28336816)
+               self.assertTrue(result.deletable)
+
+       def test_non_counts(self):
+               result = self._parse("<div><table><tbody><tr><td>zoo</td></tr></tbody></table></div>", None)
+               self.assertFalse(result.count_attempt)
+               self.assertFalse(result.deletable)
+
+       def test_typos(self):
+               result = self._parse("<span>v9</span>", 888)
+               self.assertIsNone(result.number)
+               self.assertTrue(result.count_attempt)
+
+               result = self._parse("<div>v11.585 Empire</div>", None)
+               self.assertIsNone(result.number)
+               self.assertTrue(result.count_attempt)
+               self.assertFalse(result.deletable)
+
+               result = self._parse("<div>11, 585, 22 </div>", 11_585_202)
+               self.assertIsNone(result.number)
+               self.assertTrue(result.count_attempt)
+               self.assertTrue(result.deletable)
+
+               result = self._parse("<span>0490499</span>", 4999)
+               self.assertIsNone(result.number)
+               self.assertTrue(result.count_attempt)
diff --git a/strikebot/src/strikebot/updates.py b/strikebot/src/strikebot/updates.py
new file mode 100644 (file)
index 0000000..e3f58cd
--- /dev/null
@@ -0,0 +1,216 @@
+from __future__ import annotations
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional
+import re
+
+from bs4 import BeautifulSoup
+
+
+# commands that can be given in an update; `_parse_command` defines the text that triggers each one
+Command = Enum("Command", ["RESET", "REPORT"])
+
+
+# Result of parsing one update; built by `_parse_from_lines`.
+@dataclass
+class ParsedUpdate:
+       number: Optional[int]  # the count, or None if the update isn't accepted as a well-formed count
+       command: Optional[Command]  # the first command found on any line of the update, if any
+       count_attempt: bool  # either well-formed or typo
+       deletable: bool
+
+
+def _parse_command(line: str, bot_user: str) -> Optional[Command]:
+       """Return the command that `line` spells out (ignoring case), or None if it isn't a command."""
+       lowered = line.lower()
+       if lowered == f"/u/{bot_user} reset".lower():
+               return Command.RESET
+       if lowered in ["sidebar count", "current count"]:
+               return Command.REPORT
+       return None
+
+
+def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
+       """
+       Flatten an update's HTML body to lines of plain text and parse them for a count and/or a command.
+
+       `payload_data` holds the update's HTML under "body_html".  `curr_count` is the next number up, one more
+       than the last count (None if unknown).  `bot_user` is the user name that the reset command must be
+       addressed to.  Raises RuntimeError on a tag this function doesn't know how to flatten.
+       """
+
+       NEW_LINE = object()
+       SPACE = object()
+
+       # flatten the update content to plain text
+       tree = BeautifulSoup(payload_data["body_html"], "html.parser")
+
+       # The worklist is a stack, so nodes are always pushed in reverse document order.  Building a new list
+       # also leaves the tree's own list of children intact.
+       worklist = list(reversed(tree.contents))
+       out = [[]]
+       while worklist:
+               el = worklist.pop()
+               if isinstance(el, str):
+                       out[-1].append(el)
+               elif el is SPACE:
+                       out[-1].append(el)
+               elif el is NEW_LINE or el.name == "br":
+                       # start a new line, unless the current one is still empty
+                       if out[-1]:
+                               out.append([])
+               elif el.name in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]:
+                       worklist.extend(reversed(el.contents))
+               elif el.name in ["ul", "ol", "table", "thead", "tbody"]:
+                       worklist.extend(reversed(el.contents))
+               elif el.name in ["li", "p", "div", "h1", "h2", "blockquote"]:
+                       worklist.append(NEW_LINE)
+                       worklist.extend(reversed(el.contents))
+                       worklist.append(NEW_LINE)
+               elif el.name == "pre":
+                       # each line of the preformatted text becomes a line of its own
+                       worklist.append(NEW_LINE)
+                       for pre_line in reversed(el.text.splitlines()):
+                               worklist.append(pre_line)
+                               worklist.append(NEW_LINE)
+               elif el.name == "tr":
+                       # cells of a table row end up on one line, separated by spaces
+                       worklist.append(NEW_LINE)
+                       for (i, cell) in enumerate(reversed(el.contents)):
+                               worklist.append(cell)
+                               if i != len(el.contents) - 1:
+                                       worklist.append(SPACE)
+                       worklist.append(NEW_LINE)
+               else:
+                       raise RuntimeError(f"can't parse tag {el.name}")
+
+       tmp_lines = (
+               "".join(" " if part is SPACE else part for part in parts).strip()
+               for parts in out
+       )
+       pre_strip_lines = list(filter(None, tmp_lines))
+
+       # normalize whitespace according to HTML rendering rules
+       # https://developer.mozilla.org/en-US/docs/Web/API/Document_Object_Model/Whitespace#explanation
+       stripped_lines = [
+               re.sub(" +", " ", line.replace("\t", " ").replace("\n", " ")).strip(" ")
+               for line in pre_strip_lines
+       ]
+
+       return _parse_from_lines(stripped_lines, curr_count, bot_user)
+
+
+# Parse the plain-text lines of an update.
+#
+# A command may be on any line; a count is only looked for at the start of the first line.  `curr_count` is the
+# expected count (None if unknown) and is used to recognize typos and to decide how many space-separated digit
+# groups belong to the count.
+def _parse_from_lines(lines: list[str], curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
+       # first command found on any line, or None
+       command = next(
+               filter(None, (_parse_command(l, bot_user) for l in lines)),
+               None
+       )
+       if lines:
+               # look for groups of digits (as many as possible) separated by a uniform separator from the valid set
+               first = lines[0]
+               match = re.match(
+                       "(?P<v>v)?(?P<neg>-)?(?P<num>\\d+((?P<sep>[,. \u2009]|, )\\d+((?P=sep)\\d+)*)?)",
+                       first,
+                       re.ASCII,  # only recognize ASCII digits
+               )
+               if match:
+                       raw_digits = match["num"]
+                       sep = match["sep"]
+                       post = first[match.end() :]  # rest of the first line, after the digits
+
+                       # strip leading zeros (and a separator directly after one), remembering that there were any
+                       zeros = False
+                       while len(raw_digits) > 1 and raw_digits[0] == "0":
+                               zeros = True
+                               raw_digits = raw_digits.removeprefix("0").removeprefix(sep or "")
+
+                       parts = raw_digits.split(sep) if sep else [raw_digits]
+                       # whether the digits are the only content of the update
+                       lone = len(lines) == 1 and (not post or post.isspace())
+                       typo = False
+                       if lone:
+                               # with a separator, groups must look like thousands groups: 1 to 3 digits, then 3 digits each
+                               all_parts_valid = (
+                                       sep is None
+                                       or (
+                                               1 <= len(parts[0]) <= 3
+                                               and all(len(p) == 3 for p in parts[1:])
+                                       )
+                               )
+                               if match["v"] and len(parts) == 1 and len(parts[0]) <= 2:
+                                       # failed paste of leading digits
+                                       typo = True
+                               elif match["v"] and all_parts_valid:
+                                       # v followed by count
+                                       typo = True
+                               elif curr_count is not None and abs(curr_count) >= 100 and bool(match["neg"]) == (curr_count < 0):
+                                       # compare against variations of the expected count (sign already agrees)
+                                       goal_parts = _separate(str(abs(curr_count)))
+                                       partials = [
+                                               goal_parts[: -1] + [goal_parts[-1][: -1]],  # missing last digit
+                                               goal_parts[: -1] + [goal_parts[-1][: -2]],  # missing last two digits
+                                               goal_parts[: -1] + [goal_parts[-1][: -2] + goal_parts[-1][-1]],  # missing second-last digit
+                                       ]
+                                       if parts in partials:
+                                               # missing any of last two digits
+                                               typo = True
+                                       elif parts in [p[: -1] + [p[-1] + goal_parts[0]] + goal_parts[1 :] for p in partials]:
+                                               # double paste
+                                               typo = True
+
+                       # a "v" prefix, leading zeros, a recognized typo or "-0": an attempt, but not a usable number
+                       if match["v"] or zeros or typo or (parts == ["0"] and match["neg"]):
+                               number = None
+                               count_attempt = True
+                               deletable = lone
+                       else:
+                               if curr_count is not None and sep and sep.isspace():
+                                       # Presume that the intended count consists of as many valid digit groups as necessary to match the
+                                       # number of digits in the expected count, if possible.
+                                       digit_count = len(str(abs(curr_count)))
+                                       use_parts = []
+                                       accum = 0
+                                       for (i, part) in enumerate(parts):
+                                               part_valid = len(part) <= 3 if i == 0 else len(part) == 3
+                                               if part_valid and accum < digit_count:
+                                                       use_parts.append(part)
+                                                       accum += len(part)
+                                               else:
+                                                       break
+
+                                       # could still be a no-separator count with some extra digit groups on the same line
+                                       if not use_parts:
+                                               use_parts = [parts[0]]
+
+                                       # digit groups left over after the count mean the count isn't alone
+                                       lone = lone and len(use_parts) == len(parts)
+                               else:
+                                       # current count is unknown or no separator was used
+                                       use_parts = parts
+
+                               digits = "".join(use_parts)
+                               number = -int(digits) if match["neg"] else int(digits)
+                               # special numbers within 25 of the expected count are never deletable
+                               special = (
+                                       curr_count is not None
+                                       and abs(number - curr_count) <= 25
+                                       and _is_special_number(number)
+                               )
+                               deletable = lone and not special
+                               if len(use_parts) == len(parts) and post and not post[0].isspace():
+                                       # The digits run straight into other text, so the number is discarded; it still counts as
+                                       # an attempt if it is within 25 of the expected count.
+                                       count_attempt = curr_count is not None and abs(number - curr_count) <= 25
+                                       number = None
+                               else:
+                                       count_attempt = True
+               else:
+                       # no count attempt found
+                       number = None
+                       count_attempt = False
+                       deletable = False
+       else:
+               # no lines in update
+               number = None
+               count_attempt = False
+               deletable = True
+
+       return ParsedUpdate(
+               number = number,
+               command = command,
+               count_attempt = count_attempt,
+               deletable = deletable,
+       )
+
+
+def _separate(digits: str) -> list[str]:
+       """Split a digit string into thousands groups; only the first group may have fewer than 3 digits."""
+       head_len = len(digits) % 3
+       groups = [digits[: head_len]] if head_len else []
+       for start in range(head_len, len(digits), 3):
+               groups.append(digits[start : start + 3])
+       return groups
+
+
+def _is_special_number(num: int) -> bool:
+       """
+       Tell whether `num` ends in 000, 001, 333 or 999, is a palindrome above 10,000,000, or consists of one
+       sequence of characters repeated two or more times.
+       """
+       text = str(num)
+       if num % 1000 in [0, 1, 333, 999]:
+               return True
+       if num > 10_000_000 and text[::-1] == text:
+               return True
+       return re.match(r"(.+)\1+$", text) is not None  # repeated sequence