--- /dev/null
+++ b/.gitignore
+*.egg-info
+__pycache__
+.mypy_cache
--- /dev/null
+[config]
+bot user = count_better
+
+# (seconds) maximum allowable spread in WebSocket message arrival times
+WS parity time = 1.0
+
+# goal size of WebSocket connection pool
+WS pool size = 5
+
+# (seconds) time after WebSocket handshake completion during which missed updates are excused
+WS warmup time = 0.3
+
+# (seconds) maximum time to hold updates for reordering
+reorder buffer time = 0.25
+
+enforcing = true
+
+# (seconds) minimum time to retain updates to enable future resyncs
+update retention time = 120.0
+
+thread ID = abc123
+
+# space-separated authorization IDs for Reddit API token lookup
+auth IDs = 13 15
+
+# Postgres database configuration; same options as in a connect string
+[db connect params]
+host = example.org
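+# Other asyncpg-style connection parameters may be added here; illustrative values:
+# port = 5432
+# user = strikebot
+# database = strikebot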
--- /dev/null
+++ b/pyproject.toml
+[build-system]
+requires = ["setuptools>=42", "wheel"]
+build-backend = "setuptools.build_meta"
--- /dev/null
+++ b/setup.cfg
+[metadata]
+name = strikebot
+version = 0.0.0
+
+[options]
+package_dir =
+ = src
+packages = strikebot
+python_requires = ~= 3.9
+install_requires =
+ trio ~= 0.19
+ triopg ~= 0.6
+ trio-websocket ~= 0.9
+ asks ~= 2.4
--- /dev/null
+++ b/setup.py
+import setuptools
+
+
+setuptools.setup()
--- /dev/null
+++ b/src/strikebot/__init__.py
+from contextlib import nullcontext
+from dataclasses import dataclass
+from functools import total_ordering
+from typing import Optional, Set
+from uuid import UUID
+import bisect
+import datetime as dt
+import importlib.metadata
+import itertools
+import logging
+import re
+
+from trio import CancelScope, EndOfChannel
+import trio
+
+from strikebot.updates import Command, parse_update
+
+
+__version__ = importlib.metadata.version(__package__)
+
+
+@dataclass
+class _Update:
+ id: str
+ name: str
+ author: str
+ number: Optional[int]
+ command: Optional[Command]
+ ts: dt.datetime
+ count_attempt: bool
+ deletable: bool
+ stricken: bool
+
+ def can_follow(self, prior: "_Update") -> bool:
+ """Determine whether this count update can follow another count update."""
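+		# (call sites ensure prior.number is not None before this is evaluated)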
+ return self.number == prior.number + 1 and self.author != prior.author
+
+
+@total_ordering
+@dataclass
+class _BufferedUpdate:
+ update: _Update
+ release_at: float # Trio time
+
+ def __lt__(self, other):
+ return self.update.ts < other.update.ts
+
+
+@total_ordering
+@dataclass
+class _TimelineUpdate:
+ update: _Update
+ accepted: bool
+ bad_count: bool
+
+ def __lt__(self, other):
+ return self.update.ts < other.update.ts
+
+
+def _parse_rfc_4122_ts(ts: int) -> dt.datetime:
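+	# RFC 4122 timestamps count 100 ns intervals since the Gregorian reform
+	# (1582-10-15); dividing by 10 converts to microseconds.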
+ epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc)
+ return epoch + dt.timedelta(microseconds = ts / 10)
+
+
+def _format_update_ref(update: _Update, thread_id: str) -> str:
+ url = f"https://www.reddit.com/live/{thread_id}/updates/{update.id}"
+ author_md = update.author.replace("_", "\\_") # Reddit allows letters, digits, underscore, and hyphen
+ return f"[{update.number:,} by {author_md}]({url})"
+
+
+def _format_bad_strike_alert(update: _Update, thread_id: str) -> str:
+ return _format_update_ref(update, thread_id) + " was incorrectly stricken."
+
+
+def _format_curr_count(last_valid: Optional[_Update]) -> str:
+ if last_valid and last_valid.number is not None:
+ count_str = f"{last_valid.number + 1:,}"
+ else:
+ count_str = "unknown"
+ return f"_current count:_ {count_str}"
+
+
+async def count_tracker_impl(
+ message_rx,
+ api_pool: "strikebot.reddit_api.ApiClientPool",
+ reorder_buffer_time: dt.timedelta, # duration to hold some updates for reordering
+ thread_id: str,
+ bot_user: str,
+ enforcing: bool,
+ update_retention: dt.timedelta,
+ logger: logging.Logger,
+) -> None:
+ from strikebot.reddit_api import CorrectionUpdateRequest, DeleteRequest, ReportUpdateRequest, StrikeRequest
+
+ buffer_ = []
+ timeline = []
+ last_valid: Optional[_Update] = None
+ pending_strikes: Set[str] = set() # names of updates to mark stricken on arrival
+
+ def handle_update(update):
+		nonlocal last_valid
+
+ tu = _TimelineUpdate(update, accepted = False, bad_count = False)
+
+ pos = bisect.bisect(timeline, tu)
+ if pos != len(timeline):
+ logger.warning(f"long transpo: {update.name}")
+
+ pred = next(
+ (timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted),
+ None
+ )
+ if update.command is not Command.RESET and update.number is not None and last_valid and not pred:
+			logger.warning(f"ignoring {update.name}: no valid prior count on record")
+ else:
+ tu.accepted = (
+ update.command is Command.RESET
+ or (
+ update.number is not None
+ and (pred is None or pred.number is None or update.can_follow(pred))
+ )
+ )
+ timeline.insert(pos, tu)
+ if tu.accepted:
+ # resync subsequent updates already processed
+ newly_valid = []
+ newly_invalid = []
+ last_valid = update
+ for scan_tu in timeline[pos + 1:]:
+ if scan_tu.update.command is Command.RESET:
+ last_valid = scan_tu.update
+ elif scan_tu.update.number is not None:
+ accept = last_valid.number is None or scan_tu.update.can_follow(last_valid)
+ if accept and not scan_tu.accepted:
+ newly_valid.append(scan_tu)
+ elif not accept and scan_tu.accepted:
+ newly_invalid.append(scan_tu)
+ scan_tu.accepted = accept
+ if accept:
+ last_valid = scan_tu.update
+
+ parts = []
+ if newly_valid:
+ parts.append(
+ "The following counts are valid:\n\n"
+ + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in newly_valid)
+ )
+ for tu in newly_valid:
+ tu.bad_count = False
+
+ unstrikable = [tu for tu in newly_invalid if tu.update.stricken]
+ if unstrikable:
+ parts.append(
+ "The following counts are invalid:\n\n"
+ + "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in unstrikable)
+ )
+ for tu in unstrikable:
+ tu.bad_count = True
+
+ if update.stricken:
+ logger.info(f"bad strike of {update.name}")
+ parts.append(_format_bad_strike_alert(update, thread_id))
+
+ for invalid_tu in newly_invalid:
+ if not invalid_tu.update.stricken:
+ if enforcing:
+ api_pool.enqueue_request(StrikeRequest(thread_id, invalid_tu.update.name, invalid_tu.update.ts))
+ invalid_tu.bad_count = True
+ invalid_tu.update.stricken = True
+ if parts:
+ parts.append(_format_curr_count(last_valid))
+ api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts)))
+ elif update.number is not None or update.count_attempt:
+ if enforcing:
+ api_pool.enqueue_request(StrikeRequest(thread_id, update.name, update.ts))
+ tu.bad_count = True
+ update.stricken = True
+
+ if update.command is Command.REPORT:
+ api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
+
+ with message_rx:
+ while True:
+ if buffer_:
+ deadline = min(bu.release_at for bu in buffer_)
+ cancel_scope = CancelScope(deadline = deadline)
+ else:
+ cancel_scope = nullcontext()
+
+ msg = None
+ with cancel_scope:
+ try:
+ msg = await message_rx.receive()
+ except EndOfChannel:
+ break
+
+ if msg:
+ if msg.data["type"] == "update":
+ payload_data = msg.data["payload"]["data"]
+ stricken = payload_data["name"] in pending_strikes
+ pending_strikes.discard(payload_data["name"])
+ if payload_data["author"] != bot_user:
+						curr_count = last_valid.number if last_valid else None
+						pu = parse_update(payload_data, curr_count, bot_user)
+ rfc_ts = UUID(payload_data["id"]).time
+ update = _Update(
+ id = payload_data["id"],
+ name = payload_data["name"],
+ author = payload_data["author"],
+ number = pu.number,
+ command = pu.command,
+ ts = _parse_rfc_4122_ts(rfc_ts),
+ count_attempt = pu.count_attempt,
+ deletable = pu.deletable,
+ stricken = stricken or payload_data["stricken"],
+ )
+ release_at = msg.timestamp + reorder_buffer_time.total_seconds()
+ bisect.insort(buffer_, _BufferedUpdate(update, release_at))
+ elif msg.data["type"] == "strike":
+ UUID(re.match("LiveUpdate_(.+)$", msg.data["payload"])[1]) # sanity check payload
+ slot = next(
+ (
+ slot for slot in itertools.chain(buffer_, reversed(timeline))
+ if slot.update.name == msg.data["payload"]
+ ),
+ None
+ )
+ if slot:
+ slot.update.stricken = True
+ if isinstance(slot, _TimelineUpdate) and slot.accepted:
+ logger.info(f"bad strike of {slot.update.name}")
+ body = _format_bad_strike_alert(slot.update, thread_id)
+ api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body))
+ else:
+ pending_strikes.add(msg.data["payload"])
+
+ threshold = dt.datetime.now(dt.timezone.utc) - update_retention
+
+ # pull updates from the reorder buffer and process
+ new_buffer = []
+ for bu in buffer_:
+ process = (
+ # not a count
+ bu.update.number is None
+
+ # valid count
+ or (last_valid is None or last_valid.number is None or bu.update.can_follow(last_valid))
+
+ # invalid and not likely to become valid by transpo
+ or bu.update.number > last_valid.number + 3
+
+ or trio.current_time() >= bu.release_at
+ )
+
+ if process:
+ if bu.update.ts < threshold:
+ logger.warning(f"ignoring {bu.update.name}: arrived past retention window")
+ else:
+ handle_update(bu.update)
+ else:
+ new_buffer.append(bu)
+ buffer_ = new_buffer
+
+			# delete/forget old updates
+			i = 0  # guards the slice below when the timeline is empty
+			for (i, tu) in enumerate(timeline):
+				if i >= len(timeline) - 10 or tu.update.ts >= threshold:
+					break
+				elif tu.bad_count and tu.update.deletable:
+					if enforcing:
+						api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts))
+			del timeline[: i]
--- /dev/null
+++ b/src/strikebot/__main__.py
+import argparse
+import configparser
+import datetime as dt
+import logging
+import sys
+
+import trio
+import trio_asyncio
+import triopg
+
+from strikebot import count_tracker_impl
+from strikebot.db import Client, Messenger
+from strikebot.live_ws import HealingReadPool, PoolMerger
+from strikebot.reddit_api import ApiClientPool
+
+
+ap = argparse.ArgumentParser(__package__)
+ap.add_argument("config_path")
+args = ap.parse_args()
+
+
+parser = configparser.ConfigParser()
+with open(args.config_path) as config_file:
+ parser.read_file(config_file)
+
+main_cfg = parser["config"]
+
+auth_ids = set(map(int, main_cfg["auth IDs"].split()))
+bot_user = main_cfg["bot user"]
+enforcing = main_cfg.getboolean("enforcing")
+reorder_buffer_time = dt.timedelta(seconds = main_cfg.getfloat("reorder buffer time"))
+thread_id = main_cfg["thread ID"]
+update_retention = dt.timedelta(seconds = main_cfg.getfloat("update retention time"))
+ws_parity_time = dt.timedelta(seconds = main_cfg.getfloat("WS parity time"))
+ws_pool_size = main_cfg.getint("WS pool size")
+ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup time"))
+
+db_connect_params = dict(parser["db connect params"])
+
+
+logger = logging.getLogger(__package__)
+logger.setLevel(logging.DEBUG)
+
+handler = logging.StreamHandler(sys.stdout)
+handler.setFormatter(logging.Formatter("{asctime:23}: {levelname:8}: {message}", style = "{"))
+logger.addHandler(handler)
+
+
+async def main():
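+	# nursery_a runs the DB client and Reddit API workers; nursery_b the WS pool
+	# event merger; nursery_c the connection refresher and the count tracker.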
+ async with trio.open_nursery() as nursery_a, triopg.connect(**db_connect_params) as db_conn:
+ client = Client(db_conn)
+ db_messenger = Messenger(client)
+
+ nursery_a.start_soon(db_messenger.db_client_impl)
+
+ api_pool = ApiClientPool(auth_ids, db_messenger, logger)
+ nursery_a.start_soon(api_pool.token_updater_impl)
+ for _ in auth_ids:
+ nursery_a.start_soon(api_pool.worker_impl)
+
+ (message_tx, message_rx) = trio.open_memory_channel(0)
+ (pool_event_tx, pool_event_rx) = trio.open_memory_channel(0)
+ merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger)
+ async with trio.open_nursery() as nursery_b:
+ nursery_b.start_soon(merger.event_reader_impl)
+ nursery_b.start_soon(merger.timeout_handler_impl)
+ pool = HealingReadPool(nursery_b, ws_pool_size, thread_id, api_pool, pool_event_tx, logger)
+ async with pool, trio.open_nursery() as nursery_c:
+ nursery_c.start_soon(pool.conn_refresher_impl)
+ nursery_c.start_soon(
+ count_tracker_impl,
+ message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing, update_retention, logger
+ )
+
+
+trio_asyncio.run(main)
--- /dev/null
+++ b/src/strikebot/db.py
+"""Single-connection Postgres client with high-level wrappers for operations."""
+
+from dataclasses import dataclass
+from typing import Any, Iterable
+
+import trio
+
+
+def _channel_sender(method):
+ async def wrapped(self, resp_channel, *args, **kwargs):
+ with resp_channel:
+ await resp_channel.send(await method(self, *args, **kwargs))
+
+ return wrapped
+
+
+@dataclass
+class Client:
+ """High-level wrappers for DB operations."""
+ conn: Any
+
+ @_channel_sender
+ async def get_auth_tokens(self, auth_ids: set[int]):
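+		# Expected to return a dict mapping each auth ID to its current OAuth token
+		# (see ApiClientPool._update_tokens); the actual query is deployment-specific.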
+ raise NotImplementedError()
+
+
+class Messenger:
+ def __init__(self, client: Client):
+ self._client = client
+ (self._request_tx, self._request_rx) = trio.open_memory_channel(0)
+
+ async def do(self, method_name: str, args: Iterable) -> Any:
+ """This is run by consumers of the DB wrapper."""
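+		# e.g. tokens = await messenger.do("get_auth_tokens", (auth_ids,))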
+ method = getattr(self._client, method_name)
+ (resp_tx, resp_rx) = trio.open_memory_channel(0)
+ await self._request_tx.send(method(resp_tx, *args))
+ async with resp_rx:
+ return await resp_rx.receive()
+
+ async def db_client_impl(self):
+ """This is the DB client Trio task."""
+ async with self._client.conn, self._request_rx:
+ async for coro in self._request_rx:
+ await coro
--- /dev/null
+++ b/src/strikebot/live_ws.py
+from collections import deque
+from contextlib import suppress
+from dataclasses import dataclass
+from typing import Any, AsyncContextManager, Optional
+import datetime as dt
+import json
+import logging
+import math
+
+from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
+import trio
+
+from strikebot.reddit_api import AboutLiveThreadRequest
+
+
+@dataclass
+class _PoolEvent:
+ timestamp: float # Trio time
+
+
+@dataclass
+class _Message(_PoolEvent):
+ data: Any
+ scope: Any
+
+ def dedup_tag(self):
+ """A hashable essence of the contents of this message."""
+ if self.data["type"] == "update":
+ return ("update", self.data["payload"]["data"]["id"])
+ elif self.data["type"] in {"strike", "delete"}:
+ return (self.data["type"], self.data["payload"])
+ else:
+ raise ValueError(self.data["type"])
+
+
+@dataclass
+class _ConnectionDown(_PoolEvent):
+ scope: trio.CancelScope
+
+
+@dataclass
+class _ConnectionUp(_PoolEvent):
+ scope: trio.CancelScope
+
+
+class HealingReadPool:
+ def __init__(self, nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger):
+ assert size >= 2
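+		# With a single connection there would be no second opinion for spotting
+		# missed updates, so the merger's parity check needs at least two.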
+ self._nursery = nursery
+ self._size = size
+ self._live_thread_id = live_thread_id
+ self._api_client_pool = api_client_pool
+ self._pool_event_tx = pool_event_tx
+ self._logger = logger
+ (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size)
+ self._active_count = 0
+
+ async def __aenter__(self):
+ for _ in range(self._size):
+ await self._spawn_reader()
+ if not self._active_count:
+ raise RuntimeError("Unable to create any WS connections")
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ self._refresh_queue_tx.close()
+
+ async def _reader_impl(self, conn_ctx: AsyncContextManager, cancel_scope, refresh_tx):
+ try:
+ async with conn_ctx as conn:
+ self._active_count += 1
+ with refresh_tx:
+ with cancel_scope, suppress(ConnectionClosed):
+ while True:
+ message = await conn.get_message()
+ event = _Message(trio.current_time(), json.loads(message), cancel_scope)
+ await self._pool_event_tx.send(event)
+
+ if cancel_scope.cancelled_caught:
+ await conn.aclose(4058, "Server unexpectedly stopped sending messages")
+ self._logger.warning("replacing WS connection due to missed updates")
+ else:
+ self._logger.warning("replacing WS connection closed by server")
+ refresh_tx.send_nowait(cancel_scope)
+ self._active_count -= 1
+ except HandshakeError:
+ self._logger.error("handshake error while opening WS connection")
+
+ async def _new_connection(self) -> Optional[AsyncContextManager[WebSocketConnection]]:
+ request = AboutLiveThreadRequest(self._live_thread_id)
+ resp = await self._api_client_pool.make_request(request)
+ if resp.status_code == 200:
+ url = json.loads(resp.text)["data"]["websocket_url"]
+ return open_websocket_url(url)
+ else:
+ self._logger.error(f"thread info request failed (HTTP {resp.status_code}); can't open WS connection")
+ return None
+
+ async def _spawn_reader(self) -> None:
+ conn = await self._new_connection()
+ if conn:
+ new_scope = trio.CancelScope()
+ await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope))
+ refresh_tx = self._refresh_queue_tx.clone()
+ self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx)
+
+ async def conn_refresher_impl(self):
+ """Task to monitor and replace WS connections as they disconnect or go silent."""
+ with self._refresh_queue_rx:
+ async for old_scope in self._refresh_queue_rx:
+ await self._pool_event_tx.send(_ConnectionDown(trio.current_time(), old_scope))
+ await self._spawn_reader()
+ if not self._active_count:
+ raise RuntimeError("WS pool depleted")
+
+
+class PoolMerger:
+ @dataclass
+ class _Bucket:
+ start: float
+ recipients: set
+
+ def __init__(self, pool_event_rx, message_tx, parity_timeout: dt.timedelta, conn_warmup: dt.timedelta, logger: logging.Logger):
+ """
+ :param parity_timeout: max delay between message receipt on different connections
+ :param conn_warmup: max interval after connection within which missed updates are allowed
+ """
+ self._pool_event_rx = pool_event_rx
+ self._message_tx = message_tx
+ self._parity_timeout = parity_timeout
+ self._conn_warmup = conn_warmup
+ self._logger = logger
+
+ self._buckets = {}
+ self._scope_activations = {}
+ self._pending = deque()
+ (self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf)
+
+ async def event_reader_impl(self):
+ """Drop unused messages, deduplicate useful ones, and install info needed by the timeout handler."""
+ with self._pool_event_rx, self._timer_poke_tx:
+ async for event in self._pool_event_rx:
+ if isinstance(event, _ConnectionUp):
+					# Activating a scope as soon as its connection opens could make it "expected"
+					# on a message that fired before the connection existed, producing a false
+					# positive for replacement. To avoid that, the timer task merges these
+					# activations in among the buckets by timestamp.
+ self._pending.append(event)
+ elif isinstance(event, _Message):
+ if event.data["type"] in {"update", "strike", "delete"}:
+ tag = event.dedup_tag()
+ if tag in self._buckets:
+ b = self._buckets[tag]
+ b.recipients.add(event.scope)
+ if b.recipients >= self._scope_activations.keys():
+ del self._buckets[tag]
+ else:
+ sane = (
+ event.scope in self._scope_activations
+ or any(e.scope == event.scope for e in self._pending)
+ )
+ if sane:
+ self._buckets[tag] = self._Bucket(event.timestamp, {event.scope})
+ self._message_tx.send_nowait(event)
+ else:
+								raise RuntimeError("received message from unrecognized WS connection")
+ elif isinstance(event, _ConnectionDown):
+ # We don't need to worry about canceling this scope at all, so no need to require it for parity for
+ # any message, even older ones. The scope may be gone already, if we canceled it previously.
+ self._scope_activations.pop(event.scope, None)
+ else:
+ raise TypeError(f"Expected pool event, found {event!r}")
+ self._timer_poke_tx.send_nowait(None) # may have new work for the timer
+
+ async def timeout_handler_impl(self):
+ """When connections have had enough time to reach parity on a message, signal replacement of any slackers."""
+ with self._timer_poke_rx:
+ while True:
+ if self._buckets:
+ now = trio.current_time()
+ (tag, bucket) = min(self._buckets.items(), key = lambda i: i[1].start)
+
+ # Make sure the right scope set is ready for the next bucket up
+ while self._pending and self._pending[0].timestamp < bucket.start:
+ ev = self._pending.popleft()
+ self._scope_activations[ev.scope] = ev.timestamp
+
+ target_scopes = {
+ scope for (scope, active) in self._scope_activations.items()
+ if active + self._conn_warmup.total_seconds() < now
+ }
+ if bucket.recipients >= target_scopes:
+ del self._buckets[tag]
+ elif now > bucket.start + self._parity_timeout.total_seconds():
+ for scope in target_scopes - bucket.recipients:
+ scope.cancel()
+ del self._scope_activations[scope]
+ del self._buckets[tag]
+ else:
+ await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds())
+ else:
+ try:
+ await self._timer_poke_rx.receive()
+ except trio.EndOfChannel:
+ break
--- /dev/null
+++ b/src/strikebot/queue.py
+"""Unbounded blocking priority queue for Trio"""
+
+from dataclasses import dataclass
+from functools import total_ordering
+from typing import Any
+import heapq
+
+from trio.lowlevel import ParkingLot
+
+
+@total_ordering
+@dataclass
+class _ReverseOrdWrapper:
+ inner: Any
+
+ def __lt__(self, other):
+ return other.inner < self.inner
+
+
+class MaxHeap:
+ def __init__(self):
+ self._heap = []
+ self._empty_wait = ParkingLot()
+
+ def push(self, item):
+ heapq.heappush(self._heap, _ReverseOrdWrapper(item))
+ if len(self._empty_wait):
+ self._empty_wait.unpark()
+
+ async def pop(self):
+ if not self._heap:
+ await self._empty_wait.park()
+ return heapq.heappop(self._heap).inner
+
+ def __len__(self) -> int:
+ return len(self._heap)
--- /dev/null
+++ b/src/strikebot/reddit_api.py
+"""Multi-auth HTTP connection pool with Reddit API wrappers and rate limiting."""
+
+from abc import ABCMeta, abstractmethod
+from collections import deque
+from dataclasses import dataclass, field
+from functools import total_ordering
+import datetime as dt
+import logging
+
+from asks.response_objects import Response
+import asks
+import trio
+
+from strikebot import __version__ as VERSION
+from strikebot.queue import MaxHeap
+
+
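+# Spacing each app's requests one second apart keeps it within Reddit's
+# per-client API rate limit (roughly one request per second, averaged).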
+REQUEST_DELAY = dt.timedelta(seconds = 1)
+
+TOKEN_UPDATE_DELAY = dt.timedelta(minutes = 15)
+
+USER_AGENT_FMT = f"any:net.jcornell.strikebot.{{auth_id}}:v{VERSION} (by /u/jaxklax)"
+
+API_BASE_URL = "https://oauth.reddit.com"
+
+
+@total_ordering
+class _Request(metaclass = ABCMeta):
+ _SUBTYPE_PRECEDENCE = None # assigned later
+
+ def __lt__(self, other):
+ if type(self) is type(other):
+			return self._subtype_cmp_key() < other._subtype_cmp_key()
+ else:
+ prec = self._SUBTYPE_PRECEDENCE
+ return prec.index(type(other)) < prec.index(type(self))
+
+ def __eq__(self, other):
+ if type(self) is type(other):
+			return self._subtype_cmp_key() == other._subtype_cmp_key()
+ else:
+ return False
+
+ @abstractmethod
+ def _subtype_cmp_key(self):
+ # a large key corresponds to a high priority
+ raise NotImplementedError()
+
+ @abstractmethod
+ def to_asks_kwargs(self):
+ raise NotImplementedError()
+
+
+@dataclass(eq = False)
+class StrikeRequest(_Request):
+ thread_id: str
+ update_name: str
+ update_ts: dt.datetime
+
+ def to_asks_kwargs(self):
+ return {
+ "method": "POST",
+ "path": f"/api/live/{self.thread_id}/strike_update",
+ "data": {
+ "api_type": "json",
+ "id": self.update_name,
+ },
+ }
+
+ def _subtype_cmp_key(self):
+ return self.update_ts
+
+
+@dataclass(eq = False)
+class DeleteRequest(_Request):
+ thread_id: str
+ update_name: str
+ update_ts: dt.datetime
+
+ def to_asks_kwargs(self):
+ return {
+ "method": "POST",
+ "path": f"/api/live/{self.thread_id}/delete_update",
+ "data": {
+ "api_type": "json",
+ "id": self.update_name,
+ },
+ }
+
+ def _subtype_cmp_key(self):
+		return -self.update_ts.timestamp()  # oldest updates get deleted first
+
+
+@dataclass(eq = False)
+class AboutLiveThreadRequest(_Request):
+ thread_id: str
+ created: float = field(default_factory = trio.current_time)
+
+ def to_asks_kwargs(self):
+ return {
+ "method": "GET",
+ "path": f"/live/{self.thread_id}/about",
+ "params": {
+ "raw_json": "1",
+ },
+ }
+
+ def _subtype_cmp_key(self):
+ return -self.created
+
+
+@dataclass(eq = False)
+class _BaseUpdatePostRequest(_Request, metaclass = ABCMeta):
+ thread_id: str
+ body: str
+ created: float = field(default_factory = trio.current_time)
+
+ def to_asks_kwargs(self):
+ return {
+ "method": "POST",
+ "path": f"/live/{self.thread_id}/update",
+ "data": {
+ "api_type": "json",
+ "body": self.body,
+ },
+ }
+
+
+@dataclass(eq = False)
+class ReportUpdateRequest(_BaseUpdatePostRequest):
+ def _subtype_cmp_key(self):
+ return -self.created
+
+
+@dataclass(eq = False)
+class CorrectionUpdateRequest(_BaseUpdatePostRequest):
+ def _subtype_cmp_key(self):
+ return -self.created
+
+
+# highest priority first
+_Request._SUBTYPE_PRECEDENCE = [
+ AboutLiveThreadRequest,
+ StrikeRequest,
+ CorrectionUpdateRequest,
+ ReportUpdateRequest,
+ DeleteRequest,
+]
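+# e.g. any queued AboutLiveThreadRequest outranks any StrikeRequest, which in
+# turn outranks correction posts, reports, and deletes.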
+
+
+@dataclass
+class AppCooldown:
+ auth_id: int
+ ready_at: float # Trio time
+
+
+class ApiClientPool:
+ def __init__(self, auth_ids: set[int], db_messenger, logger: logging.Logger):
+ self._db_messenger = db_messenger
+ self._logger = logger
+
+ self._tokens = {id_: None for id_ in auth_ids}
+ self._tokens_installed = trio.Event()
+ now = trio.current_time()
+ self._app_queue = deque(AppCooldown(id_, now) for id_ in auth_ids)
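+		# Workers pop an app from the left and append it back on the right, so apps
+		# rotate and each one rests for REQUEST_DELAY between its own requests.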
+ self._request_queue = MaxHeap()
+
+ self._session = asks.Session(connections = len(auth_ids))
+ self._session.base_location = API_BASE_URL
+
+ async def _update_tokens(self):
+ tokens = await self._db_messenger.do("get_auth_tokens", (self._tokens.keys(),))
+ assert tokens.keys() == self._tokens.keys()
+ self._tokens = tokens
+ self._tokens_installed.set()
+
+ async def token_updater_impl(self):
+ last_update = None
+ while True:
+ if last_update is not None:
+ await trio.sleep_until(last_update + TOKEN_UPDATE_DELAY.total_seconds())
+ last_update = trio.current_time()
+ await self._update_tokens()
+
+ def _check_queue_size(self) -> None:
+ if len(self._request_queue) > 5:
+ self._logger.warning(f"API workers may be saturated; {len(self._request_queue)} requests in queue")
+
+	async def make_request(self, request: _Request) -> Response:
+		(resp_tx, resp_rx) = trio.open_memory_channel(0)
+		self._request_queue.push((request, resp_tx))
+		self._check_queue_size()
+		async with resp_rx:
+			return await resp_rx.receive()
+
+	def enqueue_request(self, request: _Request) -> None:
+		self._request_queue.push((request, None))
+		self._check_queue_size()
+
+ async def worker_impl(self):
+ await self._tokens_installed.wait()
+ while True:
+ (request, resp_tx) = await self._request_queue.pop()
+ while trio.current_time() < self._app_queue[0].ready_at:
+ await trio.sleep_until(self._app_queue[0].ready_at)
+
+ cooldown = self._app_queue.popleft()
+ asks_kwargs = request.to_asks_kwargs()
+ asks_kwargs.setdefault("headers", {}).update({
+ "Authorization": "Bearer {}".format(self._tokens[cooldown.auth_id]),
+ "User-Agent": USER_AGENT_FMT.format(auth_id = cooldown.auth_id),
+ })
+ cooldown.ready_at = trio.current_time() + REQUEST_DELAY.total_seconds()
+ resp = await self._session.request(**asks_kwargs)
+
+ if resp.status_code == 429:
+ # We disagreed about the rate limit state; just try again later.
+				self._request_queue.push((request, resp_tx))
+ elif 400 <= resp.status_code < 500:
+ # If we're doing something wrong, let's catch it right away.
+ if resp_tx:
+ resp_tx.close()
+ raise RuntimeError("Unexpected client error response: {}".format(resp.status_code))
+ else:
+ if resp.status_code != 200:
+ self._logger.warning(f"got HTTP {resp.status_code} from Reddit API")
+ if resp_tx:
+ await resp_tx.send(resp)
+
+ self._app_queue.append(cooldown)
--- /dev/null
+from unittest import TestCase
+
+from strikebot.updates import parse_update
+
+
+def _build_payload(body_html: str) -> dict[str, str]:
+ return {"body_html": body_html}
+
+
+class UpdateParsingTests(TestCase):
+ def test_successful_counts(self):
+		pu = parse_update(_build_payload("<div>12,345,678 spaghetti</div>"), None, "")
+ self.assertEqual(pu.number, 12345678)
+ self.assertFalse(pu.deletable)
+
+		pu = parse_update(_build_payload("<div><p>0</p><br/>oz</div>"), None, "")
+ self.assertEqual(pu.number, 0)
+ self.assertFalse(pu.deletable)
+
+ def test_non_counts(self):
+		pu = parse_update(_build_payload("<div><table><tbody><tr><td>zoo</td></tr></tbody></table></div>"), None, "")
+ self.assertFalse(pu.count_attempt)
+ self.assertFalse(pu.deletable)
+
+ def test_typos(self):
+		pu = parse_update(_build_payload("<span>v9</span>"), 888, "")
+ self.assertIsNone(pu.number)
+ self.assertTrue(pu.count_attempt)
+
+		pu = parse_update(_build_payload("<div>v11.585 Empire</div>"), None, "")
+ self.assertIsNone(pu.number)
+ self.assertTrue(pu.count_attempt)
+ self.assertFalse(pu.deletable)
+
+		pu = parse_update(_build_payload("<div>11, 585, 22 </div>"), 11_585_202, "")
+ self.assertIsNone(pu.number)
+ self.assertTrue(pu.count_attempt)
+ self.assertTrue(pu.deletable)
+
+		pu = parse_update(_build_payload("<span>0490499</span>"), 4999, "")
+ self.assertIsNone(pu.number)
+ self.assertTrue(pu.count_attempt)
--- /dev/null
+++ b/src/strikebot/updates.py
+from __future__ import annotations
+from collections import deque
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional
+from xml.etree import ElementTree
+import re
+
+
+Command = Enum("Command", ["RESET", "REPORT"])
+
+
+@dataclass
+class ParsedUpdate:
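+	# number: the parsed count, or None if the update isn't a valid count
+	# count_attempt: the update appears to be an attempt at a count
+	# deletable: the update carries nothing worth preserving (e.g. a lone miscount)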
+ number: Optional[int]
+ command: Optional[Command]
+ count_attempt: bool
+ deletable: bool
+
+
+def _parse_command(line: str, bot_user: str) -> Optional[Command]:
+ if line.lower() == f"/u/{bot_user} reset".lower():
+ return Command.RESET
+ elif line.lower() in ["sidebar count", "current count"]:
+ return Command.REPORT
+ else:
+ return None
+
+
+def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
+ NEW_LINE = object()
+ SPACE = object()
+
+ # flatten the update content to plain text
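+	# e.g. "<div><p>123</p>456</div>" flattens to the lines ["123", "456"]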
+ doc = ElementTree.fromstring(payload_data["body_html"])
+ worklist = deque([doc])
+ out = [[]]
+ while worklist:
+ el = worklist.popleft()
+ if el is NEW_LINE:
+ if out[-1]:
+ out.append([])
+ elif el is SPACE:
+ out[-1].append(el)
+ elif isinstance(el, str):
+ out[-1].append(el.replace("\n", " "))
+ elif el.tag in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]:
+ if el.tail:
+ worklist.appendleft(el.tail)
+ for sub in reversed(el):
+ worklist.appendleft(sub)
+ if el.text:
+ worklist.appendleft(el.text)
+		elif el.tag == "li":
+			assert not el.tail
+			worklist.appendleft(NEW_LINE)
+			for sub in reversed(el):
+				worklist.appendleft(sub)
+			if el.text:
+				worklist.appendleft(el.text)
+ elif el.tag in ["p", "div", "h1", "h2", "blockquote"]:
+ if el.tail:
+ worklist.appendleft(el.tail)
+ worklist.appendleft(NEW_LINE)
+ for sub in reversed(el):
+ worklist.appendleft(sub)
+ if el.text:
+ worklist.appendleft(el.text)
+ elif el.tag in ["ul", "ol"]:
+ if el.tail:
+ worklist.appendleft(el.tail)
+ assert not el.text
+ for sub in reversed(el):
+ worklist.appendleft(sub)
+ worklist.appendleft(NEW_LINE)
+ elif el.tag == "pre":
+ out.extend([l] for l in el.text.splitlines())
+ if el.tail:
+ worklist.appendleft(el.tail)
+ worklist.appendleft(NEW_LINE)
+ elif el.tag == "table":
+ if el.tail:
+ worklist.appendleft(el.tail)
+ worklist.appendleft(NEW_LINE)
+ assert not el.text
+ for sub in reversed(el):
+ assert sub.tag in ["thead", "tbody"]
+ assert not sub.text
+ for row in reversed(sub):
+ assert row.tag == "tr"
+					assert not (row.text or row.tail)
+ for (i, cell) in enumerate(reversed(row)):
+ assert not cell.tail
+ worklist.appendleft(cell)
+ if i != len(row) - 1:
+ worklist.appendleft(SPACE)
+ worklist.appendleft(NEW_LINE)
+ elif el.tag == "br":
+ if el.tail:
+ worklist.appendleft(el.tail)
+ worklist.appendleft(NEW_LINE)
+ else:
+ raise RuntimeError(f"can't parse tag {el.tag}")
+
+ lines = list(filter(
+ None,
+ (
+ "".join(" " if part is SPACE else part for part in parts).strip()
+ for parts in out
+ )
+ ))
+
+ command = next(
+ filter(None, (_parse_command(l, bot_user) for l in lines)),
+ None
+ )
+
+ if lines:
+ first = lines[0]
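+		# A count may lead the update, optionally prefixed with "v" or "-"; digit
+		# groups can be separated by ",", ".", " ", a thin space (U+2009), or ", ",
+		# e.g. "12,345,678", "12.345.678", "12 345 678".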
+ match = re.match(
+ "(?P<v>v)?(?P<neg>-)?(?P<num>\\d+((?P<sep>[,. \u2009]|, )\\d+((?P=sep)\\d+)*)?)( |$)",
+ first,
+ re.ASCII, # only recognize ASCII digits
+ )
+ if match:
+ ct_str = match["num"]
+ sep = match["sep"]
+
+ zeros = False
+ while len(ct_str) > 1 and ct_str[0] == "0":
+ zeros = True
+ ct_str = ct_str.removeprefix("0").removeprefix(sep or "")
+
+ parts = ct_str.split(sep) if sep else [ct_str]
+ parts_valid = (
+				len(parts[0]) in range(1, 4)
+ and all(len(p) == 3 for p in parts[1:])
+ )
+ digits = "".join(parts)
+ lone = not first[match.end() :].strip() and len(lines) == 1
+ typo = False
+ if lone:
+ if match["v"] and len(ct_str) <= 2:
+ # failed paste of leading digits
+ typo = True
+ elif match["v"] and parts_valid:
+ # v followed by count
+ typo = True
+ elif curr_count and curr_count >= 100:
+ goal = (sep or "").join(_separate(str(curr_count)))
+ partials = [goal[: -2], goal[: -1], goal[: -2] + goal[-1]]
+ if ct_str in partials:
+ # missing any of last two digits
+ typo = True
+ elif ct_str in [p + goal for p in partials]:
+ # double paste
+ typo = True
+ if match["v"] or zeros or typo or (digits == "0" and match["neg"]):
+ number = None
+ count_attempt = True
+ deletable = lone
+ elif parts_valid:
+ number = -int(digits) if match["neg"] else int(digits)
+ count_attempt = True
+ special = curr_count is not None and abs(number - curr_count) <= 25 and _is_special_number(number)
+ deletable = lone and not special
+ else:
+ number = None
+ count_attempt = False
+ deletable = lone
+ else:
+ # no count attempt found
+ number = None
+ count_attempt = False
+ deletable = False
+ else:
+ # no lines in update
+ number = None
+ count_attempt = False
+ deletable = True
+
+ return ParsedUpdate(
+ number = number,
+ command = command,
+ count_attempt = count_attempt,
+ deletable = deletable,
+ )
+
+
+def _separate(digits: str) -> list[str]:
+ mod = len(digits) % 3
+ out = []
+ if mod:
+ out.append(digits[: mod])
+ out.extend(digits[i : i + 3] for i in range(mod, len(digits), 3))
+ return out
+
+
+def _is_special_number(num: int) -> bool:
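+	# e.g. 12000, 12001, 12333, 12999 (x000/x001/x333/x999 endings), large
+	# palindromes like 12344321, and repeated blocks like 123123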
+ num_str = str(num)
+ return bool(
+ num % 1000 in [0, 1, 333, 999]
+ or (num > 10_000_000 and "".join(reversed(num_str)) == num_str)
+ or re.match(r"(.+)\1+$", num_str) # repeated sequence
+ )