Full MVP and some unit tests
author Jakob Cornell <jakob+gpg@jcornell.net>
Thu, 28 Apr 2022 05:44:29 +0000 (00:44 -0500)
committer Jakob Cornell <jakob+gpg@jcornell.net>
Fri, 29 Apr 2022 04:31:08 +0000 (23:31 -0500)
13 files changed:
.gitignore [new file with mode: 0644]
docs/sample_config.ini [new file with mode: 0644]
pyproject.toml [new file with mode: 0644]
setup.cfg [new file with mode: 0644]
setup.py [new file with mode: 0644]
src/strikebot/__init__.py [new file with mode: 0644]
src/strikebot/__main__.py [new file with mode: 0644]
src/strikebot/db.py [new file with mode: 0644]
src/strikebot/live_ws.py [new file with mode: 0644]
src/strikebot/queue.py [new file with mode: 0644]
src/strikebot/reddit_api.py [new file with mode: 0644]
src/strikebot/tests.py [new file with mode: 0644]
src/strikebot/updates.py [new file with mode: 0644]

diff --git a/.gitignore b/.gitignore
new file mode 100644 (file)
index 0000000..fbf1bc5
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+*.egg-info
+__pycache__
+.mypy_cache
diff --git a/docs/sample_config.ini b/docs/sample_config.ini
new file mode 100644 (file)
index 0000000..66fc49e
--- /dev/null
+++ b/docs/sample_config.ini
@@ -0,0 +1,28 @@
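+# Sample configuration; run the bot by passing this file's path, e.g.:
+#
+#     python -m strikebot docs/sample_config.ini
+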
+[config]
+bot user = count_better
+
+# (seconds) maximum allowable spread in WebSocket message arrival times
+WS parity time = 1.0
+
+# goal size of WebSocket connection pool
+WS pool size = 5
+
+# (seconds) time after WebSocket handshake completion during which missed updates are excused
+WS warmup time = 0.3
+
+# (seconds) maximum time to hold updates for reordering
+reorder buffer time = 0.25
+
+enforcing = true
+
+# (seconds) minimum time to retain updates to enable future resyncs
+update retention time = 120.0
+
+thread ID = abc123
+
+# space-separated authorization IDs for Reddit API token lookup
+auth IDs = 13 15
+
+# Postgres database configuration; same options as in a connect string
+[db connect params]
+host = example.org
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644 (file)
index 0000000..8fe2f47
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools>=42", "wheel"]
+build-backend = "setuptools.build_meta"
diff --git a/setup.cfg b/setup.cfg
new file mode 100644 (file)
index 0000000..1c54cc8
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,14 @@
+[metadata]
+name = strikebot
+version = 0.0.0
+
+[options]
+package_dir =
+       = src
+packages = strikebot
+python_requires = ~= 3.9
+install_requires =
+       trio ~= 0.19
+       trio-asyncio
+       triopg ~= 0.6
+       trio-websocket ~= 0.9
+       asks ~= 2.4
diff --git a/setup.py b/setup.py
new file mode 100644 (file)
index 0000000..056ba45
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,4 @@
+import setuptools
+
+
+setuptools.setup()
diff --git a/src/strikebot/__init__.py b/src/strikebot/__init__.py
new file mode 100644 (file)
index 0000000..63fff59
--- /dev/null
+++ b/src/strikebot/__init__.py
@@ -0,0 +1,271 @@
+from contextlib import nullcontext as nullcontext
+from dataclasses import dataclass
+from functools import total_ordering
+from typing import Optional, Set
+from uuid import UUID
+import bisect
+import datetime as dt
+import importlib.metadata
+import itertools
+import logging
+import re
+
+from trio import CancelScope, EndOfChannel
+import trio
+
+from strikebot.updates import Command, parse_update
+
+
+__version__ = importlib.metadata.version(__package__)
+
+
+@dataclass
+class _Update:
+       id: str
+       name: str
+       author: str
+       number: Optional[int]
+       command: Optional[Command]
+       ts: dt.datetime
+       count_attempt: bool
+       deletable: bool
+       stricken: bool
+
+       def can_follow(self, prior: "_Update") -> bool:
+               """Determine whether this count update can follow another count update."""
+               return self.number == prior.number + 1 and self.author != prior.author
+
+
+@total_ordering
+@dataclass
+class _BufferedUpdate:
+       update: _Update
+       release_at: float  # Trio time
+
+       def __lt__(self, other):
+               return self.update.ts < other.update.ts
+
+
+@total_ordering
+@dataclass
+class _TimelineUpdate:
+       update: _Update
+       accepted: bool
+       bad_count: bool
+
+       def __lt__(self, other):
+               return self.update.ts < other.update.ts
+
+
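+# A version-1 (RFC 4122) UUID timestamp counts 100-nanosecond intervals since
+# the Gregorian reform epoch, 1582-10-15 UTC; dividing by 10 gives microseconds.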
+def _parse_rfc_4122_ts(ts: int) -> dt.datetime:
+       epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc)
+       return epoch + dt.timedelta(microseconds = ts / 10)
+
+
+def _format_update_ref(update: _Update, thread_id: str) -> str:
+       url = f"https://www.reddit.com/live/{thread_id}/updates/{update.id}"
+       author_md = update.author.replace("_", "\\_")  # Reddit allows letters, digits, underscore, and hyphen
+       return f"[{update.number:,} by {author_md}]({url})"
+
+
+def _format_bad_strike_alert(update: _Update, thread_id: str) -> str:
+       return _format_update_ref(update, thread_id) + " was incorrectly stricken."
+
+
+def _format_curr_count(last_valid: Optional[_Update]) -> str:
+       if last_valid and last_valid.number is not None:
+               count_str = f"{last_valid.number + 1:,}"
+       else:
+               count_str = "unknown"
+       return f"_current count:_ {count_str}"
+
+
+async def count_tracker_impl(
+       message_rx,
+       api_pool: "strikebot.reddit_api.ApiClientPool",
+       reorder_buffer_time: dt.timedelta,  # duration to hold some updates for reordering
+       thread_id: str,
+       bot_user: str,
+       enforcing: bool,
+       update_retention: dt.timedelta,
+       logger: logging.Logger,
+) -> None:
+       from strikebot.reddit_api import CorrectionUpdateRequest, DeleteRequest, ReportUpdateRequest, StrikeRequest
+
+       buffer_ = []
+       timeline = []
+       last_valid: Optional[_Update] = None
+       pending_strikes: Set[str] = set()  # names of updates to mark stricken on arrival
+
+       def handle_update(update):
+               nonlocal last_valid
+
+               tu = _TimelineUpdate(update, accepted = False, bad_count = False)
+
+               pos = bisect.bisect(timeline, tu)
+               if pos != len(timeline):
+                       logger.warning(f"long transpo: {update.name}")
+
+               pred = next(
+                       (timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted),
+                       None
+               )
+               if update.command is not Command.RESET and update.number is not None and last_valid and not pred:
+                       logger.warning(f"ignoring {update.name}: no valid prior count on record")
+               else:
+                       tu.accepted = (
+                               update.command is Command.RESET
+                               or (
+                                       update.number is not None
+                                       and (pred is None or pred.number is None or update.can_follow(pred))
+                               )
+                       )
+                       timeline.insert(pos, tu)
+                       if tu.accepted:
+                               # resync subsequent updates already processed
+                               newly_valid = []
+                               newly_invalid = []
+                               last_valid = update
+                               for scan_tu in timeline[pos + 1:]:
+                                       if scan_tu.update.command is Command.RESET:
+                                               last_valid = scan_tu.update
+                                       elif scan_tu.update.number is not None:
+                                               accept = last_valid.number is None or scan_tu.update.can_follow(last_valid)
+                                               if accept and not scan_tu.accepted:
+                                                       newly_valid.append(scan_tu)
+                                               elif not accept and scan_tu.accepted:
+                                                       newly_invalid.append(scan_tu)
+                                               scan_tu.accepted = accept
+                                               if accept:
+                                                       last_valid = scan_tu.update
+
+                               parts = []
+                               if newly_valid:
+                                       parts.append(
+                                               "The following counts are valid:\n\n"
+                                               + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in newly_valid)
+                                       )
+                                       for tu in newly_valid:
+                                               tu.bad_count = False
+
+                               unstrikable = [tu for tu in newly_invalid if tu.update.stricken]
+                               if unstrikable:
+                                       parts.append(
+                                               "The following counts are invalid:\n\n"
+                                               + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in unstrikable)
+                                       )
+                                       for tu in unstrikable:
+                                               tu.bad_count = True
+
+                               if update.stricken:
+                                       logger.info(f"bad strike of {update.name}")
+                                       parts.append(_format_bad_strike_alert(update, thread_id))
+
+                               for invalid_tu in newly_invalid:
+                                       if not invalid_tu.update.stricken:
+                                               if enforcing:
+                                                       api_pool.enqueue_request(StrikeRequest(thread_id, invalid_tu.update.name, invalid_tu.update.ts))
+                                               invalid_tu.bad_count = True
+                                               invalid_tu.update.stricken = True
+                               if parts:
+                                       parts.append(_format_curr_count(last_valid))
+                                       api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts)))
+                       elif update.number is not None or update.count_attempt:
+                               if enforcing:
+                                       api_pool.enqueue_request(StrikeRequest(thread_id, update.name, update.ts))
+                               tu.bad_count = True
+                               update.stricken = True
+
+               if update.command is Command.REPORT:
+                       api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
+
+       with message_rx:
+               while True:
+                       if buffer_:
+                               deadline = min(bu.release_at for bu in buffer_)
+                               cancel_scope = CancelScope(deadline = deadline)
+                       else:
+                               cancel_scope = nullcontext()
+
+                       msg = None
+                       with cancel_scope:
+                               try:
+                                       msg = await message_rx.receive()
+                               except EndOfChannel:
+                                       break
+
+                       if msg:
+                               if msg.data["type"] == "update":
+                                       payload_data = msg.data["payload"]["data"]
+                                       stricken = payload_data["name"] in pending_strikes
+                                       pending_strikes.discard(payload_data["name"])
+                                       if payload_data["author"] != bot_user:
+                                               next_up = last_valid.number if last_valid else None
+                                               pu = parse_update(payload_data, next_up, bot_user)
+                                               rfc_ts = UUID(payload_data["id"]).time
+                                               update = _Update(
+                                                       id = payload_data["id"],
+                                                       name = payload_data["name"],
+                                                       author = payload_data["author"],
+                                                       number = pu.number,
+                                                       command = pu.command,
+                                                       ts = _parse_rfc_4122_ts(rfc_ts),
+                                                       count_attempt = pu.count_attempt,
+                                                       deletable = pu.deletable,
+                                                       stricken = stricken or payload_data["stricken"],
+                                               )
+                                               release_at = msg.timestamp + reorder_buffer_time.total_seconds()
+                                               bisect.insort(buffer_, _BufferedUpdate(update, release_at))
+                               elif msg.data["type"] == "strike":
+                                       UUID(re.match("LiveUpdate_(.+)$", msg.data["payload"])[1])  # sanity check payload
+                                       slot = next(
+                                               (
+                                                       slot for slot in itertools.chain(buffer_, reversed(timeline))
+                                                       if slot.update.name == msg.data["payload"]
+                                               ),
+                                               None
+                                       )
+                                       if slot:
+                                               slot.update.stricken = True
+                                               if isinstance(slot, _TimelineUpdate) and slot.accepted:
+                                                       logger.info(f"bad strike of {slot.update.name}")
+                                                       body = _format_bad_strike_alert(slot.update, thread_id)
+                                                       api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body))
+                                       else:
+                                               pending_strikes.add(msg.data["payload"])
+
+                       threshold = dt.datetime.now(dt.timezone.utc) - update_retention
+
+                       # pull updates from the reorder buffer and process
+                       new_buffer = []
+                       for bu in buffer_:
+                               process = (
+                                       # not a count
+                                       bu.update.number is None
+
+                                       # valid count
+                                       or (last_valid is None or last_valid.number is None or bu.update.can_follow(last_valid))
+
+                                       # invalid and not likely to become valid by transpo
+                                       or bu.update.number > last_valid.number + 3
+
+                                       or trio.current_time() >= bu.release_at
+                               )
+
+                               if process:
+                                       if bu.update.ts < threshold:
+                                               logger.warning(f"ignoring {bu.update.name}: arrived past retention window")
+                                       else:
+                                               handle_update(bu.update)
+                               else:
+                                       new_buffer.append(bu)
+                       buffer_ = new_buffer
+
+                       # delete/forget old updates
+                       i = 0  # guard: timeline may be empty
+                       for (i, tu) in enumerate(timeline):
+                               if i >= len(timeline) - 10 or tu.update.ts >= threshold:
+                                       break
+                               elif tu.bad_count and tu.update.deletable:
+                                       if enforcing:
+                                               api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts))
+                       del timeline[: i]
diff --git a/src/strikebot/__main__.py b/src/strikebot/__main__.py
new file mode 100644 (file)
index 0000000..a1c6705
--- /dev/null
+++ b/src/strikebot/__main__.py
@@ -0,0 +1,77 @@
+import argparse
+import configparser
+import datetime as dt
+import logging
+import sys
+
+import trio
+import trio_asyncio
+import triopg
+
+from strikebot import count_tracker_impl
+from strikebot.db import Client, Messenger
+from strikebot.live_ws import HealingReadPool, PoolMerger
+from strikebot.reddit_api import ApiClientPool
+
+
+ap = argparse.ArgumentParser(__package__)
+ap.add_argument("config_path")
+args = ap.parse_args()
+
+
+parser = configparser.ConfigParser()
+with open(args.config_path) as config_file:
+       parser.read_file(config_file)
+
+main_cfg = parser["config"]
+
+auth_ids = set(map(int, main_cfg["auth IDs"].split()))
+bot_user = main_cfg["bot user"]
+enforcing = main_cfg.getboolean("enforcing")
+reorder_buffer_time = dt.timedelta(seconds = main_cfg.getfloat("reorder buffer time"))
+thread_id = main_cfg["thread ID"]
+update_retention = dt.timedelta(seconds = main_cfg.getfloat("update retention time"))
+ws_parity_time = dt.timedelta(seconds = main_cfg.getfloat("WS parity time"))
+ws_pool_size = main_cfg.getint("WS pool size")
+ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup time"))
+
+db_connect_params = dict(parser["db connect params"])
+
+
+logger = logging.getLogger(__package__)
+logger.setLevel(logging.DEBUG)
+
+handler = logging.StreamHandler(sys.stdout)
+handler.setFormatter(logging.Formatter("{asctime:23}: {levelname:8}: {message}", style = "{"))
+logger.addHandler(handler)
+
+
+async def main():
+       async with trio.open_nursery() as nursery_a, triopg.connect(**db_connect_params) as db_conn:
+               client = Client(db_conn)
+               db_messenger = Messenger(client)
+
+               nursery_a.start_soon(db_messenger.db_client_impl)
+
+               api_pool = ApiClientPool(auth_ids, db_messenger, logger)
+               nursery_a.start_soon(api_pool.token_updater_impl)
+               for _ in auth_ids:
+                       nursery_a.start_soon(api_pool.worker_impl)
+
+               (message_tx, message_rx) = trio.open_memory_channel(0)
+               (pool_event_tx, pool_event_rx) = trio.open_memory_channel(0)
+               merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger)
+               async with trio.open_nursery() as nursery_b:
+                       nursery_b.start_soon(merger.event_reader_impl)
+                       nursery_b.start_soon(merger.timeout_handler_impl)
+                       pool = HealingReadPool(nursery_b, ws_pool_size, thread_id, api_pool, pool_event_tx, logger)
+                       async with pool, trio.open_nursery() as nursery_c:
+                               nursery_c.start_soon(pool.conn_refresher_impl)
+                               nursery_c.start_soon(
+                                       count_tracker_impl,
+                                       message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing, update_retention, logger
+                               )
+
+
+trio_asyncio.run(main)
diff --git a/src/strikebot/db.py b/src/strikebot/db.py
new file mode 100644 (file)
index 0000000..76a19b6
--- /dev/null
+++ b/src/strikebot/db.py
@@ -0,0 +1,44 @@
+"""Single-connection Postgres client with high-level wrappers for operations."""
+
+from dataclasses import dataclass
+from typing import Any, Iterable
+
+import trio
+
+
+def _channel_sender(method):
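+       """Wrap an async method to send its result over a response channel (the wrapper's first argument), then close it."""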
+       async def wrapped(self, resp_channel, *args, **kwargs):
+               with resp_channel:
+                       await resp_channel.send(await method(self, *args, **kwargs))
+
+       return wrapped
+
+
+@dataclass
+class Client:
+       """High-level wrappers for DB operations."""
+       conn: Any
+
+       @_channel_sender
+       async def get_auth_tokens(self, auth_ids: set[int]):
+               raise NotImplementedError()
+
+
+class Messenger:
+       def __init__(self, client: Client):
+               self._client = client
+               (self._request_tx, self._request_rx) = trio.open_memory_channel(0)
+
+       async def do(self, method_name: str, args: Iterable) -> Any:
+               """This is run by consumers of the DB wrapper."""
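+               # e.g.: tokens = await messenger.do("get_auth_tokens", (auth_ids,))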
+               method = getattr(self._client, method_name)
+               (resp_tx, resp_rx) = trio.open_memory_channel(0)
+               await self._request_tx.send(method(resp_tx, *args))
+               async with resp_rx:
+                       return await resp_rx.receive()
+
+       async def db_client_impl(self):
+               """This is the DB client Trio task."""
+               async with self._client.conn, self._request_rx:
+                       async for coro in self._request_rx:
+                               await coro
diff --git a/src/strikebot/live_ws.py b/src/strikebot/live_ws.py
new file mode 100644 (file)
index 0000000..40f02dc
--- /dev/null
+++ b/src/strikebot/live_ws.py
@@ -0,0 +1,203 @@
+from collections import deque
+from contextlib import suppress
+from dataclasses import dataclass
+from typing import Any, AsyncContextManager, Optional
+import datetime as dt
+import json
+import logging
+import math
+
+from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
+import trio
+
+from strikebot.reddit_api import AboutLiveThreadRequest
+
+
+@dataclass
+class _PoolEvent:
+       timestamp: float  # Trio time
+
+
+@dataclass
+class _Message(_PoolEvent):
+       data: Any
+       scope: Any
+
+       def dedup_tag(self):
+               """A hashable essence of the contents of this message."""
+               if self.data["type"] == "update":
+                       return ("update", self.data["payload"]["data"]["id"])
+               elif self.data["type"] in {"strike", "delete"}:
+                       return (self.data["type"], self.data["payload"])
+               else:
+                       raise ValueError(self.data["type"])
+
+
+@dataclass
+class _ConnectionDown(_PoolEvent):
+       scope: trio.CancelScope
+
+
+@dataclass
+class _ConnectionUp(_PoolEvent):
+       scope: trio.CancelScope
+
+
+class HealingReadPool:
+       def __init__(self, nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger):
+               assert size >= 2
+               self._nursery = nursery
+               self._size = size
+               self._live_thread_id = live_thread_id
+               self._api_client_pool = api_client_pool
+               self._pool_event_tx = pool_event_tx
+               self._logger = logger
+               (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size)
+               self._active_count = 0
+
+       async def __aenter__(self):
+               for _ in range(self._size):
+                       await self._spawn_reader()
+               if not self._active_count:
+                       raise RuntimeError("Unable to create any WS connections")
+
+       async def __aexit__(self, exc_type, exc_value, traceback):
+               self._refresh_queue_tx.close()
+
+       async def _reader_impl(self, conn_ctx: AsyncContextManager, cancel_scope, refresh_tx):
+               try:
+                       async with conn_ctx as conn:
+                               self._active_count += 1
+                               with refresh_tx:
+                                       with cancel_scope, suppress(ConnectionClosed):
+                                               while True:
+                                                       message = await conn.get_message()
+                                                       event = _Message(trio.current_time(), json.loads(message), cancel_scope)
+                                                       await self._pool_event_tx.send(event)
+
+                                       if cancel_scope.cancelled_caught:
+                                               await conn.aclose(4058, "Server unexpectedly stopped sending messages")
+                                               self._logger.warning("replacing WS connection due to missed updates")
+                                       else:
+                                               self._logger.warning("replacing WS connection closed by server")
+                                       refresh_tx.send_nowait(cancel_scope)
+                               self._active_count -= 1
+               except HandshakeError:
+                       self._logger.error("handshake error while opening WS connection")
+
+       async def _new_connection(self) -> Optional[AsyncContextManager[WebSocketConnection]]:
+               request = AboutLiveThreadRequest(self._live_thread_id)
+               resp = await self._api_client_pool.make_request(request)
+               if resp.status_code == 200:
+                       url = json.loads(resp.text)["data"]["websocket_url"]
+                       return open_websocket_url(url)
+               else:
+                       self._logger.error(f"thread info request failed (HTTP {resp.status_code}); can't open WS connection")
+                       return None
+
+       async def _spawn_reader(self) -> None:
+               conn = await self._new_connection()
+               if conn:
+                       new_scope = trio.CancelScope()
+                       await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope))
+                       refresh_tx = self._refresh_queue_tx.clone()
+                       self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx)
+
+       async def conn_refresher_impl(self):
+               """Task to monitor and replace WS connections as they disconnect or go silent."""
+               with self._refresh_queue_rx:
+                       async for old_scope in self._refresh_queue_rx:
+                               await self._pool_event_tx.send(_ConnectionDown(trio.current_time(), old_scope))
+                               await self._spawn_reader()
+                               if not self._active_count:
+                                       raise RuntimeError("WS pool depleted")
+
+
+class PoolMerger:
+       @dataclass
+       class _Bucket:
+               start: float
+               recipients: set
+
+       def __init__(self, pool_event_rx, message_tx, parity_timeout: dt.timedelta, conn_warmup: dt.timedelta, logger: logging.Logger):
+               """
+               :param parity_timeout: max delay between message receipt on different connections
+               :param conn_warmup: max interval after connection within which missed updates are allowed
+               """
+               self._pool_event_rx = pool_event_rx
+               self._message_tx = message_tx
+               self._parity_timeout = parity_timeout
+               self._conn_warmup = conn_warmup
+               self._logger = logger
+
+               self._buckets = {}
+               self._scope_activations = {}
+               self._pending = deque()
+               (self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf)
+
+       async def event_reader_impl(self):
+               """Drop unused messages, deduplicate useful ones, and install info needed by the timeout handler."""
+               with self._pool_event_rx, self._timer_poke_tx:
+                       async for event in self._pool_event_rx:
+                               if isinstance(event, _ConnectionUp):
+                                       # An early add of an active scope could mean it's expected on a message that fired before it opened,
+                                       # resulting in a false positive for replacement.  The timer task merges these in among the buckets
+                                       # by timestamp to avoid that.
+                                       self._pending.append(event)
+                               elif isinstance(event, _Message):
+                                       if event.data["type"] in {"update", "strike", "delete"}:
+                                               tag = event.dedup_tag()
+                                               if tag in self._buckets:
+                                                       b = self._buckets[tag]
+                                                       b.recipients.add(event.scope)
+                                                       if b.recipients >= self._scope_activations.keys():
+                                                               del self._buckets[tag]
+                                               else:
+                                                       sane = (
+                                                               event.scope in self._scope_activations
+                                                               or any(e.scope == event.scope for e in self._pending)
+                                                       )
+                                                       if sane:
+                                                               self._buckets[tag] = self._Bucket(event.timestamp, {event.scope})
+                                                               self._message_tx.send_nowait(event)
+                                                       else:
+                                                               raise RuntimeError("received message from unrecognized WS connection")
+                               elif isinstance(event, _ConnectionDown):
+                                       # We don't need to worry about canceling this scope at all, so no need to require it for parity for
+                                       # any message, even older ones.  The scope may be gone already, if we canceled it previously.
+                                       self._scope_activations.pop(event.scope, None)
+                               else:
+                                       raise TypeError(f"Expected pool event, found {event!r}")
+                               self._timer_poke_tx.send_nowait(None)  # may have new work for the timer
+
+       async def timeout_handler_impl(self):
+               """When connections have had enough time to reach parity on a message, signal replacement of any slackers."""
+               with self._timer_poke_rx:
+                       while True:
+                               if self._buckets:
+                                       now = trio.current_time()
+                                       (tag, bucket) = min(self._buckets.items(), key = lambda i: i[1].start)
+
+                                       # Make sure the right scope set is ready for the next bucket up
+                                       while self._pending and self._pending[0].timestamp < bucket.start:
+                                               ev = self._pending.popleft()
+                                               self._scope_activations[ev.scope] = ev.timestamp
+
+                                       target_scopes = {
+                                               scope for (scope, active) in self._scope_activations.items()
+                                               if active + self._conn_warmup.total_seconds() < now
+                                       }
+                                       if bucket.recipients >= target_scopes:
+                                               del self._buckets[tag]
+                                       elif now > bucket.start + self._parity_timeout.total_seconds():
+                                               for scope in target_scopes - bucket.recipients:
+                                                       scope.cancel()
+                                                       del self._scope_activations[scope]
+                                               del self._buckets[tag]
+                                       else:
+                                               await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds())
+                               else:
+                                       try:
+                                               await self._timer_poke_rx.receive()
+                                       except trio.EndOfChannel:
+                                               break
diff --git a/src/strikebot/queue.py b/src/strikebot/queue.py
new file mode 100644 (file)
index 0000000..fed0757
--- /dev/null
+++ b/src/strikebot/queue.py
@@ -0,0 +1,36 @@
+"""Unbounded blocking priority queue for Trio"""
+
+from dataclasses import dataclass
+from functools import total_ordering
+from typing import Any
+import heapq
+
+from trio.lowlevel import ParkingLot
+
+
+@total_ordering
+@dataclass
+class _ReverseOrdWrapper:
+       inner: Any
+
+       def __lt__(self, other):
+               return other.inner < self.inner
+
+
+class MaxHeap:
+       def __init__(self):
+               self._heap = []
+               self._empty_wait = ParkingLot()
+
+       def push(self, item):
+               heapq.heappush(self._heap, _ReverseOrdWrapper(item))
+               if len(self._empty_wait):
+                       self._empty_wait.unpark()
+
+       async def pop(self):
+               while not self._heap:  # re-check: another task may take the item between unpark and resume
+                       await self._empty_wait.park()
+               return heapq.heappop(self._heap).inner
+
+       def __len__(self) -> int:
+               return len(self._heap)
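+
+
+# Minimal usage sketch (illustrative only):
+#
+#     heap = MaxHeap()
+#     heap.push(3)
+#     heap.push(7)
+#     assert await heap.pop() == 7  # pop blocks while the heap is empty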
diff --git a/src/strikebot/reddit_api.py b/src/strikebot/reddit_api.py
new file mode 100644 (file)
index 0000000..600b38e
--- /dev/null
+++ b/src/strikebot/reddit_api.py
@@ -0,0 +1,228 @@
+"""Multi-auth HTTP connection pool with Reddit API wrappers and rate limiting."""
+
+from abc import ABCMeta, abstractmethod
+from collections import deque
+from dataclasses import dataclass, field
+from functools import total_ordering
+import datetime as dt
+import logging
+
+from asks.response_objects import Response
+import asks
+import trio
+
+from strikebot import __version__ as VERSION
+from strikebot.queue import MaxHeap
+
+
+REQUEST_DELAY = dt.timedelta(seconds = 1)
+
+TOKEN_UPDATE_DELAY = dt.timedelta(minutes = 15)
+
+USER_AGENT_FMT = f"any:net.jcornell.strikebot.{{auth_id}}:v{VERSION} (by /u/jaxklax)"
+
+API_BASE_URL = "https://oauth.reddit.com"
+
+
+@total_ordering
+class _Request(metaclass = ABCMeta):
+       _SUBTYPE_PRECEDENCE = None  # assigned later
+
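+       # Ordering feeds a MaxHeap: a request whose type appears earlier in
+       # _SUBTYPE_PRECEDENCE compares greater; within a subtype the larger
+       # _subtype_cmp_key wins.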
+       def __lt__(self, other):
+               if type(self) is type(other):
+                       return self._subtype_cmp_key() < other._subtype_cmp_key()
+               else:
+                       prec = self._SUBTYPE_PRECEDENCE
+                       return prec.index(type(other)) < prec.index(type(self))
+
+       def __eq__(self, other):
+               if type(self) is type(other):
+                       return self._subtype_cmp_key() == other._subtype_cmp_key()
+               else:
+                       return False
+
+       @abstractmethod
+       def _subtype_cmp_key(self):
+               # a large key corresponds to a high priority
+               raise NotImplementedError()
+
+       @abstractmethod
+       def to_asks_kwargs(self):
+               raise NotImplementedError()
+
+
+@dataclass(eq = False)
+class StrikeRequest(_Request):
+       thread_id: str
+       update_name: str
+       update_ts: dt.datetime
+
+       def to_asks_kwargs(self):
+               return {
+                       "method": "POST",
+                       "path": f"/api/live/{self.thread_id}/strike_update",
+                       "data": {
+                               "api_type": "json",
+                               "id": self.update_name,
+                       },
+               }
+
+       def _subtype_cmp_key(self):
+               return self.update_ts
+
+
+@dataclass(eq = False)
+class DeleteRequest(_Request):
+       thread_id: str
+       update_name: str
+       update_ts: dt.datetime
+
+       def to_asks_kwargs(self):
+               return {
+                       "method": "POST",
+                       "path": f"/api/live/{self.thread_id}/delete_update",
+                       "data": {
+                               "api_type": "json",
+                               "id": self.update_name,
+                       },
+               }
+
+       def _subtype_cmp_key(self):
+               return -self.update_ts.timestamp()  # datetime has no unary minus; negate the POSIX timestamp for reverse order
+
+
+@dataclass(eq = False)
+class AboutLiveThreadRequest(_Request):
+       thread_id: str
+       created: float = field(default_factory = trio.current_time)
+
+       def to_asks_kwargs(self):
+               return {
+                       "method": "GET",
+                       "path": f"/live/{self.thread_id}/about",
+                       "params": {
+                               "raw_json": "1",
+                       },
+               }
+
+       def _subtype_cmp_key(self):
+               return -self.created
+
+
+@dataclass(eq = False)
+class _BaseUpdatePostRequest(_Request, metaclass = ABCMeta):
+       thread_id: str
+       body: str
+       created: float = field(default_factory = trio.current_time)
+
+       def to_asks_kwargs(self):
+               return {
+                       "method": "POST",
+                       "path": f"/live/{self.thread_id}/update",
+                       "data": {
+                               "api_type": "json",
+                               "body": self.body,
+                       },
+               }
+
+
+@dataclass(eq = False)
+class ReportUpdateRequest(_BaseUpdatePostRequest):
+       def _subtype_cmp_key(self):
+               return -self.created
+
+
+@dataclass(eq = False)
+class CorrectionUpdateRequest(_BaseUpdatePostRequest):
+       def _subtype_cmp_key(self):
+               return -self.created
+
+
+# highest priority first
+_Request._SUBTYPE_PRECEDENCE = [
+       AboutLiveThreadRequest,
+       StrikeRequest,
+       CorrectionUpdateRequest,
+       ReportUpdateRequest,
+       DeleteRequest,
+]
+
+
+@dataclass
+class AppCooldown:
+       auth_id: int
+       ready_at: float  # Trio time
+
+
+class ApiClientPool:
+       def __init__(self, auth_ids: set[int], db_messenger, logger: logging.Logger):
+               self._db_messenger = db_messenger
+               self._logger = logger
+
+               self._tokens = {id_: None for id_ in auth_ids}
+               self._tokens_installed = trio.Event()
+               now = trio.current_time()
+               self._app_queue = deque(AppCooldown(id_, now) for id_ in auth_ids)
+               self._request_queue = MaxHeap()
+
+               self._session = asks.Session(connections = len(auth_ids))
+               self._session.base_location = API_BASE_URL
+
+       async def _update_tokens(self):
+               tokens = await self._db_messenger.do("get_auth_tokens", (self._tokens.keys(),))
+               assert tokens.keys() == self._tokens.keys()
+               self._tokens = tokens
+               self._tokens_installed.set()
+
+       async def token_updater_impl(self):
+               last_update = None
+               while True:
+                       if last_update is not None:
+                               await trio.sleep_until(last_update + TOKEN_UPDATE_DELAY.total_seconds())
+                       last_update = trio.current_time()
+                       await self._update_tokens()
+
+       def _check_queue_size(self) -> None:
+               if len(self._request_queue) > 5:
+                       self._logger.warning(f"API workers may be saturated; {len(self._request_queue)} requests in queue")
+
+       async def make_request(self, request: _Request) -> Response:
+               (resp_tx, resp_rx) = trio.open_memory_channel(0)
+               self._request_queue.push((request, resp_tx))
+               async with resp_rx:
+                       return await resp_rx.receive()
+
+       def enqueue_request(self, request: _Request) -> None:
+               self._request_queue.push((request, None))
+               self._check_queue_size()
+
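+       # Each worker: pop the highest-priority request, take the app auth that
+       # has been idle longest, wait out its cooldown, stamp a new cooldown,
+       # make the request, then return the auth to the back of the rotation.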
+       async def worker_impl(self):
+               await self._tokens_installed.wait()
+               while True:
+                       (request, resp_tx) = await self._request_queue.pop()
+                       while trio.current_time() < self._app_queue[0].ready_at:
+                               await trio.sleep_until(self._app_queue[0].ready_at)
+
+                       cooldown = self._app_queue.popleft()
+                       asks_kwargs = request.to_asks_kwargs()
+                       asks_kwargs.setdefault("headers", {}).update({
+                               "Authorization": "Bearer {}".format(self._tokens[cooldown.auth_id]),
+                               "User-Agent": USER_AGENT_FMT.format(auth_id = cooldown.auth_id),
+                       })
+                       cooldown.ready_at = trio.current_time() + REQUEST_DELAY.total_seconds()
+                       resp = await self._session.request(**asks_kwargs)
+
+                       if resp.status_code == 429:
+                               # We disagreed about the rate limit state; just try again later.
+                               self._request_queue.push((request, resp_tx))
+                       elif 400 <= resp.status_code < 500:
+                               # If we're doing something wrong, let's catch it right away.
+                               if resp_tx:
+                                       resp_tx.close()
+                               raise RuntimeError("Unexpected client error response: {}".format(resp.status_code))
+                       else:
+                               if resp.status_code != 200:
+                                       self._logger.warning(f"got HTTP {resp.status_code} from Reddit API")
+                               if resp_tx:
+                                       await resp_tx.send(resp)
+
+                       self._app_queue.append(cooldown)
diff --git a/src/strikebot/tests.py b/src/strikebot/tests.py
new file mode 100644 (file)
index 0000000..9b9ddae
--- /dev/null
+++ b/src/strikebot/tests.py
@@ -0,0 +1,42 @@
+from unittest import TestCase
+
+from strikebot.updates import parse_update
+
+
+def _build_payload(body_html: str) -> dict[str, str]:
+       return {"body_html": body_html}
+
+
+# parse_update requires the bot's username for command detection; any placeholder works for these tests
+_BOT_USER = "test_bot"
+
+
+class UpdateParsingTests(TestCase):
+       def test_successful_counts(self):
+               pu = parse_update(_build_payload("<div>12,345,678 spaghetti</div>"), None, _BOT_USER)
+               self.assertEqual(pu.number, 12345678)
+               self.assertFalse(pu.deletable)
+
+               pu = parse_update(_build_payload("<div><p>0</p><br/>oz</div>"), None, _BOT_USER)
+               self.assertEqual(pu.number, 0)
+               self.assertFalse(pu.deletable)
+
+       def test_non_counts(self):
+               pu = parse_update(_build_payload("<div><table><tbody><tr><td>zoo</td></tr></tbody></table></div>"), None, _BOT_USER)
+               self.assertFalse(pu.count_attempt)
+               self.assertFalse(pu.deletable)
+
+       def test_typos(self):
+               pu = parse_update(_build_payload("<span>v9</span>"), 888, _BOT_USER)
+               self.assertIsNone(pu.number)
+               self.assertTrue(pu.count_attempt)
+
+               pu = parse_update(_build_payload("<div>v11.585 Empire</div>"), None, _BOT_USER)
+               self.assertIsNone(pu.number)
+               self.assertTrue(pu.count_attempt)
+               self.assertFalse(pu.deletable)
+
+               pu = parse_update(_build_payload("<div>11, 585, 22 </div>"), 11_585_202, _BOT_USER)
+               self.assertIsNone(pu.number)
+               self.assertTrue(pu.count_attempt)
+               self.assertTrue(pu.deletable)
+
+               pu = parse_update(_build_payload("<span>0490499</span>"), 4999, _BOT_USER)
+               self.assertIsNone(pu.number)
+               self.assertTrue(pu.count_attempt)
diff --git a/src/strikebot/updates.py b/src/strikebot/updates.py
new file mode 100644 (file)
index 0000000..cd814bf
--- /dev/null
+++ b/src/strikebot/updates.py
@@ -0,0 +1,202 @@
+from __future__ import annotations
+from collections import deque
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional
+from xml.etree import ElementTree
+import re
+
+
+Command = Enum("Command", ["RESET", "REPORT"])
+
+
+@dataclass
+class ParsedUpdate:
+       number: Optional[int]
+       command: Optional[Command]
+       count_attempt: bool
+       deletable: bool
+
+
+def _parse_command(line: str, bot_user: str) -> Optional[Command]:
+       if line.lower() == f"/u/{bot_user} reset".lower():
+               return Command.RESET
+       elif line.lower() in ["sidebar count", "current count"]:
+               return Command.REPORT
+       else:
+               return None
+
+
+def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
+       NEW_LINE = object()
+       SPACE = object()
+
+       # flatten the update content to plain text
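+       # e.g. "<div><p>12</p>rest</div>" flattens to the lines ["12", "rest"]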
+       doc = ElementTree.fromstring(payload_data["body_html"])
+       worklist = deque([doc])
+       out = [[]]
+       while worklist:
+               el = worklist.popleft()
+               if el is NEW_LINE:
+                       if out[-1]:
+                               out.append([])
+               elif el is SPACE:
+                       out[-1].append(el)
+               elif isinstance(el, str):
+                       out[-1].append(el.replace("\n", " "))
+               elif el.tag in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]:
+                       if el.tail:
+                               worklist.appendleft(el.tail)
+                       for sub in reversed(el):
+                               worklist.appendleft(sub)
+                       if el.text:
+                               worklist.appendleft(el.text)
+               elif el.tag == "li":
+                       assert not el.tail
+                       worklist.appendleft(NEW_LINE)
+                       worklist.appendleft(el.text)
+               elif el.tag in ["p", "div", "h1", "h2", "blockquote"]:
+                       if el.tail:
+                               worklist.appendleft(el.tail)
+                       worklist.appendleft(NEW_LINE)
+                       for sub in reversed(el):
+                               worklist.appendleft(sub)
+                       if el.text:
+                               worklist.appendleft(el.text)
+               elif el.tag in ["ul", "ol"]:
+                       if el.tail:
+                               worklist.appendleft(el.tail)
+                       assert not el.text
+                       for sub in reversed(el):
+                               worklist.appendleft(sub)
+                       worklist.appendleft(NEW_LINE)
+               elif el.tag == "pre":
+                       out.extend([l] for l in el.text.splitlines())
+                       if el.tail:
+                               worklist.appendleft(el.tail)
+                       worklist.appendleft(NEW_LINE)
+               elif el.tag == "table":
+                       if el.tail:
+                               worklist.appendleft(el.tail)
+                       worklist.appendleft(NEW_LINE)
+                       assert not el.text
+                       for sub in reversed(el):
+                               assert sub.tag in ["thead", "tbody"]
+                               assert not sub.text
+                               for row in reversed(sub):
+                                       assert row.tag == "tr"
+                                       assert not (row.text or row.tail)
+                                       for (i, cell) in enumerate(reversed(row)):
+                                               assert not cell.tail
+                                               worklist.appendleft(cell)
+                                               if i != len(row) - 1:
+                                                       worklist.appendleft(SPACE)
+                                       worklist.appendleft(NEW_LINE)
+               elif el.tag == "br":
+                       if el.tail:
+                               worklist.appendleft(el.tail)
+                       worklist.appendleft(NEW_LINE)
+               else:
+                       raise RuntimeError(f"can't parse tag {el.tag}")
+
+       lines = list(filter(
+               None,
+               (
+                       "".join(" " if part is SPACE else part for part in parts).strip()
+                       for parts in out
+               )
+       ))
+
+       command = next(
+               filter(None, (_parse_command(l, bot_user) for l in lines)),
+               None
+       )
+
+       if lines:
+               first = lines[0]
+               match = re.match(
+                       "(?P<v>v)?(?P<neg>-)?(?P<num>\\d+((?P<sep>[,. \u2009]|, )\\d+((?P=sep)\\d+)*)?)( |$)",
+                       first,
+                       re.ASCII,  # only recognize ASCII digits
+               )
+               if match:
+                       ct_str = match["num"]
+                       sep = match["sep"]
+
+                       zeros = False
+                       while len(ct_str) > 1 and ct_str[0] == "0":
+                               zeros = True
+                               ct_str = ct_str.removeprefix("0").removeprefix(sep or "")
+
+                       parts = ct_str.split(sep) if sep else [ct_str]
+                       parts_valid = (
+                               len(parts[0]) in range(1, 3)
+                               and all(len(p) == 3 for p in parts[1:])
+                       )
+                       digits = "".join(parts)
+                       lone = not first[match.end() :].strip() and len(lines) == 1
+                       typo = False
+                       if lone:
+                               if match["v"] and len(ct_str) <= 2:
+                                       # failed paste of leading digits
+                                       typo = True
+                               elif match["v"] and parts_valid:
+                                       # v followed by count
+                                       typo = True
+                               elif curr_count and curr_count >= 100:
+                                       goal = (sep or "").join(_separate(str(curr_count)))
+                                       partials = [goal[: -2], goal[: -1], goal[: -2] + goal[-1]]
+                                       if ct_str in partials:
+                                               # missing any of last two digits
+                                               typo = True
+                                       elif ct_str in [p + goal for p in partials]:
+                                               # double paste
+                                               typo = True
+                       if match["v"] or zeros or typo or (digits == "0" and match["neg"]):
+                               number = None
+                               count_attempt = True
+                               deletable = lone
+                       elif parts_valid:
+                               number = -int(digits) if match["neg"] else int(digits)
+                               count_attempt = True
+                               special = curr_count is not None and abs(number - curr_count) <= 25 and _is_special_number(number)
+                               deletable = lone and not special
+                       else:
+                               number = None
+                               count_attempt = False
+                               deletable = lone
+               else:
+                       # no count attempt found
+                       number = None
+                       count_attempt = False
+                       deletable = False
+       else:
+               # no lines in update
+               number = None
+               count_attempt = False
+               deletable = True
+
+       return ParsedUpdate(
+               number = number,
+               command = command,
+               count_attempt = count_attempt,
+               deletable = deletable,
+       )
+
+
+def _separate(digits: str) -> list[str]:
+       mod = len(digits) % 3
+       out = []
+       if mod:
+               out.append(digits[: mod])
+       out.extend(digits[i : i + 3] for i in range(mod, len(digits), 3))
+       return out
+
+
+def _is_special_number(num: int) -> bool:
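+       # e.g. 13000, 24001, 5333, 7999 (near multiples of 1000), large
+       # palindromes, and repeated blocks like 123123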
+       num_str = str(num)
+       return bool(
+               num % 1000 in [0, 1, 333, 999]
+               or (num > 10_000_000 and "".join(reversed(num_str)) == num_str)
+               or re.match(r"(.+)\1+$", num_str)  # repeated sequence
+       )