Implement Mentionbot and tests
author     Jakob Cornell <jakob+gpg@jcornell.net>
Mon, 19 Dec 2022 04:34:34 +0000 (22:34 -0600)
committer  Jakob Cornell <jakob+gpg@jcornell.net>
Mon, 19 Dec 2022 04:34:34 +0000 (22:34 -0600)
mentionbot/docs/sample_config.ini [new file with mode: 0644]
mentionbot/pyproject.toml [new file with mode: 0644]
mentionbot/setup.cfg [new file with mode: 0644]
mentionbot/setup.py [new file with mode: 0644]
mentionbot/src/mentionbot/__init__.py [new file with mode: 0644]
mentionbot/src/mentionbot/__main__.py [new file with mode: 0644]
mentionbot/src/mentionbot/abc.py [new file with mode: 0644]
mentionbot/src/mentionbot/tests.py [new file with mode: 0644]

diff --git a/mentionbot/docs/sample_config.ini b/mentionbot/docs/sample_config.ini
new file mode 100644 (file)
index 0000000..4228b6a
--- /dev/null
@@ -0,0 +1,28 @@
+# see Python `configparser' docs for precise file syntax
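+#
+# The bot reads a file in this format at startup; e.g. (assuming the package is installed):
+#   python -m mentionbot path/to/config.ini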
+
+[config]
+
+# authorization ID for Reddit API token lookup
+auth ID = 3
+
+# (int, minutes) how often to refresh the Reddit API access token from the database
+token update period = 15
+
+thread ID = abc123
+
+# (int) maximum number of mentions within an update for which to send messages
+mention limit = 10
+
+# (float, seconds) wait this much time before sending notifications, in case of update deletion
+message buffer time = 10
+
+# (float, seconds) wait this much time between API polls (and WS connection checks)
+API poll period = 90
+
+# (float, seconds) replace the WS connection if it falls this far behind the API in update timestamps
+max WS delay = 2
+
+
+# Postgres database configuration; options are the same as in a libpq connection string.
+[db connect params]
+host = example.org
diff --git a/mentionbot/pyproject.toml b/mentionbot/pyproject.toml
new file mode 100644 (file)
index 0000000..8fe2f47
--- /dev/null
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools>=42", "wheel"]
+build-backend = "setuptools.build_meta"
diff --git a/mentionbot/setup.cfg b/mentionbot/setup.cfg
new file mode 100644 (file)
index 0000000..b3f9866
--- /dev/null
@@ -0,0 +1,15 @@
+[metadata]
+name = mentionbot
+version = 0.0.0
+
+[options]
+package_dir =
+       = src
+packages = mentionbot
+python_requires = ~= 3.9
+install_requires =
+       beautifulsoup4 ~= 4.9
+       #psycopg2 ~= 2.8
+       psycopg2-binary ~= 2.8
+       trio == 0.19
+       trio-websocket == 0.9.2
diff --git a/mentionbot/setup.py b/mentionbot/setup.py
new file mode 100644 (file)
index 0000000..056ba45
--- /dev/null
@@ -0,0 +1,4 @@
+import setuptools
+
+
+setuptools.setup()
diff --git a/mentionbot/src/mentionbot/__init__.py b/mentionbot/src/mentionbot/__init__.py
new file mode 100644 (file)
index 0000000..76fd7e2
--- /dev/null
@@ -0,0 +1,187 @@
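+"""Task implementations and helpers for mentionbot, shared by the `__main__' entry point and the tests."""
+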
+from dataclasses import dataclass
+from datetime import timedelta
+from http.client import HTTPResponse
+from logging import Logger
+from socket import EAI_AGAIN, EAI_FAIL, gaierror
+from typing import AsyncContextManager, Callable, Iterable, Optional, Type
+from urllib.error import HTTPError
+from urllib.request import Request, urlopen
+from uuid import UUID
+import importlib.metadata
+import re
+
+from bs4 import BeautifulSoup
+from trio import CancelScope, MemoryReceiveChannel, MemorySendChannel, move_on_at
+import trio
+
+from .abc import DeleteEvent, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
+
+
+_USER_AGENT = "any:net.jcornell.mentionbot:v{} (by /u/jaxklax)".format(
+       importlib.metadata.version(__package__)
+)
+
+
+@dataclass
+class TokenRef:
+       """Mutable container for an access token."""
+
+       token: str
+
+       def get_common_headers(self) -> dict[str, str]:
+               return {
+                       "User-Agent": _USER_AGENT,
+                       "Authorization": "Bearer {}".format(self.token)
+               }
+
+
+def try_request(
+       request: Request,
+       logger: Logger,
+       suppress_codes: list[int] = [404, 429, 500, 503],
+) -> Optional[HTTPResponse]:
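+	"""
+	Perform `request', returning the response on success, or None on a transient DNS failure or an HTTP error whose
+	status code is in `suppress_codes'. Any other exception propagates to the caller.
+	"""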
+
+       try:
+               return urlopen(request)
+       except gaierror as e:
+               if e.errno in [EAI_FAIL, EAI_AGAIN]:
+                       logger.error(f"DNS failure: {e}")
+                       return None
+               else:
+                       raise
+       except HTTPError as e:
+               e.read()
+               e.close()
+               if e.code in suppress_codes:
+                       logger.error(f"HTTP {e.code}")
+                       return None
+               else:
+                       raise
+
+
+@dataclass
+class WsListenerState:
+       """The WebSockets listener uses this object to communicate with the API poll and message sender tasks."""
+
+       # The API poll task uses this scope to force a refresh of the WS connection if it's been too quiet.
+       scope: Optional[CancelScope] = None
+
+	# The WS task uses this to track the UUID timestamp of the most recent update message, so the API poll task
+       # can detect when the connection goes silent.
+       last_update_ts: Optional[int] = None
+
+       # The WS task sets this when the WS connection goes down and clears it when it's recovered.
+       down: bool = True
+
+	# The WS task sets this to the UUID timestamp of the first update received after connection recovery, so the
+	# message sender task knows which buffered updates still need an explicit deletion check.
+       resync_ts: Optional[int] = None
+
+
+def _parse_mentions(body: BeautifulSoup) -> set[str]:
+       """Extract the names of users tagged in an update."""
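+	# e.g. an update body containing '<a href="/u/spez">/u/spez</a>' yields {"spez"}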
+       return {
+               el["href"].removeprefix("/u/")
+               for el in body.find_all("a", href = re.compile(r"^/u/[\w-]{3,20}$", re.ASCII))
+       }
+
+
+async def ws_listener_impl(
+       event_stream_factory: Callable[[], AsyncContextManager[EventStream]],
+       event_stream_catch: Iterable[Type[BaseException]],
+       state: WsListenerState,
+       event_tx: MemorySendChannel,  # message type `Event'
+       logger: Logger,
+) -> None:
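+	"""
+	Maintain a WebSocket event stream from `event_stream_factory', forwarding update and delete events to `event_tx'
+	and reporting connection state through `state'. The connection is reopened whenever it raises one of
+	`event_stream_catch' or its cancel scope is cancelled by the API poll task.
+	"""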
+
+       while True:
+               state.scope = CancelScope()
+               with state.scope:
+                       logger.debug("opening a new connection")
+                       try:
+                               async with event_stream_factory() as event_stream:
+                                       first_update = True
+                                       while True:
+                                               event = await event_stream.get_event()
+                                               if isinstance(event, UpdateEvent):
+                                                       update_id = UUID(event.payload_data["id"])
+                                                       state.last_update_ts = update_id.time
+                                                       if first_update:
+                                                               # first update after discontinuity
+                                                               state.down = False
+                                                               state.resync_ts = update_id.time
+                                                               first_update = False
+                                               elif isinstance(event, DeleteEvent):
+                                                       logger.debug("delete event for " + event.payload)
+                                               else:
+                                                       raise RuntimeError(f"expected Event, got {event}")
+
+                                               await event_tx.send(event)
+                       except tuple(event_stream_catch) as e:
+                               logger.warning(f"WebSockets error: {e}")
+
+               state.down = True
+
+
+async def message_sender_impl(
+       notifier: Notifier,
+       title_provider: ThreadTitleProvider,
+       update_checker: UpdateChecker,
+       ws_state: WsListenerState,
+       thread_id: str,
+       event_rx: MemoryReceiveChannel,  # message type `Event'
+       buffer_time: timedelta,
+       mention_limit: int,
+       logger: Logger,
+) -> None:
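+	"""
+	Buffer update events for `buffer_time' before notifying each mentioned user, dropping an update if a delete event
+	for it arrives while it is buffered. For updates received while the WS connection is down, or older than the WS
+	resync point, existence is re-checked via `update_checker' before notifying. Updates mentioning more than
+	`mention_limit' users are ignored.
+	"""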
+
+       @dataclass
+       class QueueEntry:
+               payload_data: dict
+               mentions: set[str]
+               arrival: float  # Trio time
+
+       queue: list[QueueEntry] = []
+
+       while True:
+               logger.debug("queue: {}".format([e.payload_data["id"] for e in queue]))
+               if queue:
+                       scope = move_on_at(queue[0].arrival + buffer_time.total_seconds())
+               else:
+                       scope = CancelScope()  # no timeout
+
+               with scope:
+                       event = await event_rx.receive()
+
+               if scope.cancelled_caught:
+                       # timed out
+                       entry = queue.pop(0)
+                       update_id = entry.payload_data["id"]
+			if ws_state.down or (ws_state.resync_ts and UUID(update_id).time < ws_state.resync_ts):
+                               notify = update_checker.exists(thread_id, update_id)
+                       else:
+                               notify = True
+
+                       if notify:
+                               thread_title = title_provider.title_for(thread_id)
+                               for recipient in entry.mentions:
+                                       notifier.notify(NotifyInfo(
+                                               recipient, thread_id, thread_title, update_id, entry.payload_data["author"],
+                                               entry.payload_data["body"]
+                                       ))
+               else:
+                       # got an event
+                       if isinstance(event, UpdateEvent):
+                               update_id = event.payload_data["id"]
+                               if not any(entry.payload_data["id"] == update_id for entry in queue):
+                                       mentions = _parse_mentions(BeautifulSoup(event.payload_data["body_html"], "html.parser"))
+                                       if mentions:
+                                               if len(mentions) > mention_limit:
+                                                       logger.info(f"ignoring {update_id} due to too many mentions ({len(mentions)})")
+                                               else:
+                                                       queue.append(QueueEntry(event.payload_data, mentions, trio.current_time()))
+                       elif isinstance(event, DeleteEvent):
+                               logger.debug("delete event for " + event.payload)
+                               queue = [entry for entry in queue if entry.payload_data["name"] != event.payload]
+                       else:
+                               raise RuntimeError(f"invalid event type {type(event)}")
diff --git a/mentionbot/src/mentionbot/__main__.py b/mentionbot/src/mentionbot/__main__.py
new file mode 100644 (file)
index 0000000..960a228
--- /dev/null
@@ -0,0 +1,291 @@
+"""
+Mentionbot: Notify users by private message when they're /u/-mentioned in a live thread.
+"""
+
+from argparse import ArgumentParser
+from configparser import ConfigParser
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from datetime import timedelta
+from logging import Formatter, getLogger, Logger, NOTSET, StreamHandler
+from sys import stdout
+from typing import AsyncIterator, ClassVar, Optional, Type
+from urllib.error import HTTPError
+from urllib.parse import urlencode
+from urllib.request import Request, urlopen
+from uuid import UUID
+import json
+import logging
+
+from trio import MemorySendChannel, open_memory_channel, open_nursery
+from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
+import psycopg2
+import trio
+
+from . import TokenRef, message_sender_impl, try_request, ws_listener_impl, WsListenerState
+from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
+
+
+async def _token_update_impl(period: timedelta, db_conn, auth_id: int, task_status) -> None:
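+	"""Periodically re-read the Reddit API access token from the database and expose it to other tasks via a `TokenRef'."""
+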
+       def fetch_token() -> str:
+               cursor = db_conn.cursor()
+               cursor.execute(
+                       """
+                               select access_token from public.reddit_app_authorization
+                               where id = %s
+                       """,
+                       (auth_id,)
+               )
+               [(token,)] = cursor.fetchall()
+               return token
+
+       ref = TokenRef(fetch_token())
+       task_status.started(ref)
+       while True:
+               await trio.sleep(period.total_seconds())
+               ref.token = fetch_token()
+
+
+async def _api_poll_impl(
+       ws_state: WsListenerState,
+       token_ref: TokenRef,
+       thread_id: str,
+       event_tx: MemorySendChannel,  # message type `Event'
+       poll_period: timedelta,
+       max_ws_delay: timedelta,
+       logger: Logger,
+) -> None:
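+	"""
+	Periodically poll the live thread listing for updates and emit them on `event_tx'. If the WS listener's latest
+	update timestamp falls more than `max_ws_delay' behind the API, cancel its scope to force a reconnection.
+	"""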
+
+       base_url = f"https://oauth.reddit.com/live/{thread_id}"
+
+       latest_payload: Optional[dict] = None
+       while True:
+               # fetch updates and emit events
+               params = {
+                       "raw_json": "1",
+               }
+               if latest_payload:
+                       updates = []
+                       while True:
+                               params["before"] = latest_payload["data"]["name"]
+                               url = base_url + "?" + urlencode(params)
+                               resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger)
+                               if resp:
+                                       with resp:
+                                               children = json.load(resp)["data"]["children"]
+                                       if children:
+                                               updates.extend(reversed(children))
+                                               latest_payload = children[0]
+                                       else:
+                                               break
+                               else:
+                                       break
+               else:
+                       params["limit"] = "1"
+                       url = base_url + "?" + urlencode(params)
+                       resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger)
+                       if resp:
+                               with resp:
+                                       updates = json.load(resp)["data"]["children"]
+                               if updates:
+                                       latest_payload = updates[0]
+                       else:
+                               updates = []
+
+               for update in updates:
+                       await event_tx.send(UpdateEvent(update["data"]))
+
+               # check if WS connection is behind
+               if not ws_state.down and latest_payload:
+			max_delta = max_ws_delay.total_seconds() * 10 ** 7  # convert seconds to the 100 ns units of UUID timestamps
+                       if UUID(latest_payload["data"]["id"]).time - ws_state.last_update_ts > max_delta:
+                               ws_state.scope.cancel()
+
+               await trio.sleep(poll_period.total_seconds())
+
+
+def _build_notification_md(info: NotifyInfo) -> str:
+       # TODO escape thread title?
+       return "\n\n".join([
+               f"from [{info.author}](/user/{info.author}) in [{info.thread_title}](/live/{info.thread_id})",
+               "\n".join(">" + line for line in info.update_body.splitlines()),
+               (
+                       f"[^^permalink](/live/{info.thread_id}/updates/{info.update_id})"
+                       + f" [^^context](/live/{info.thread_id}?after=LiveUpdate_{info.update_id})"
+                       + " [^^about ^^these ^^notifications](/r/live_mentions/wiki)"
+               ),
+       ])
+
+
+@dataclass
+class RedditNotifier(Notifier):
+       token_ref: TokenRef
+       logger: Logger
+
+       def notify(self, info: NotifyInfo) -> None:
+               params = {
+                       "raw_json": "1",
+                       "api_type": "json",
+                       "subject": "username mention",
+                       "text": _build_notification_md(info),
+               }
+               req = Request(
+                       "https://oauth.reddit.com/api/compose",
+                       method = "POST",
+                       data = urlencode({**params, "to": info.recipient}).encode("ascii"),
+                       headers = self.token_ref.get_common_headers(),
+               )
+               resp = try_request(req, self.logger)
+               if resp:
+                       with resp:
+                               data = json.load(resp)["json"]
+                       if data["errors"]:
+                               [(code, _, _)] = data["errors"]
+                               if code != "USER_DOESNT_EXIST":
+					raise RuntimeError(f"PM send error: {code}")
+
+
+@dataclass
+class RedditThreadTitleProvider(ThreadTitleProvider):
+       token_ref: TokenRef
+
+       def title_for(self, thread_id: str) -> str:
+               req = Request(
+                       f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1",
+                       headers = self.token_ref.get_common_headers()
+               )
+               with urlopen(req) as resp:
+                       return json.load(resp)["data"]["title"]
+
+
+@dataclass
+class RedditUpdateChecker(UpdateChecker):
+       token_ref: TokenRef
+       logger: Logger
+
+       def exists(self, thread_id: str, update_id: str) -> bool:
+               req = Request(
+                       f"https://oauth.reddit.com/live/{thread_id}/updates/{update_id}?raw_json=1",
+                       headers = self.token_ref.get_common_headers(),
+               )
+               try:
+                       resp = try_request(req, self.logger, suppress_codes = [429, 500, 503])
+               except HTTPError as e:
+                       e.read()
+                       e.close()
+                       if e.code == 404:
+                               return False
+                       else:
+                               raise
+               else:
+                       if resp:
+                               resp.read()
+                               resp.close()
+                       return True
+
+
+@dataclass
+class WsConnectionEventStream(EventStream):
+       EXCEPTION_TYPES: ClassVar[set[Type[BaseException]]] = {HandshakeError, ConnectionClosed}
+
+       ws_conn: WebSocketConnection
+
+       async def get_event(self) -> Event:
+               while True:
+                       message = json.loads(await self.ws_conn.get_message())
+                       if message["type"] == "update":
+                               return UpdateEvent(message["payload"]["data"])
+                       elif message["type"] == "delete":
+                               return DeleteEvent(message["payload"])
+
+
+async def _main(
+       auth_id: int,
+       buffer_time: timedelta,
+       db_conn,
+       logger: Logger,
+       max_ws_delay: timedelta,
+       mention_limit: int,
+       poll_period: timedelta,
+       thread_id: str,
+       token_period: timedelta,
+) -> None:
+
+       @asynccontextmanager
+	async def event_stream_factory() -> AsyncIterator[WsConnectionEventStream]:
+               resp = urlopen(Request(
+                       f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1",
+                       headers = token_ref.get_common_headers(),
+               ))
+               with resp:
+                       ws_url = json.load(resp)["data"]["websocket_url"]
+
+               async with open_websocket_url(ws_url) as ws_conn:
+                       yield WsConnectionEventStream(ws_conn)
+
+       async with open_nursery() as nursery:
+               token_ref = await nursery.start(_token_update_impl, token_period, db_conn, auth_id)
+
+               ws_state = WsListenerState()
+               (event_tx, event_rx) = open_memory_channel(0)
+               nursery.start_soon(
+                       ws_listener_impl,
+                       event_stream_factory, WsConnectionEventStream.EXCEPTION_TYPES, ws_state, event_tx, logger.getChild("WS")
+               )
+               nursery.start_soon(
+                       _api_poll_impl,
+                       ws_state, token_ref, thread_id, event_tx, poll_period, max_ws_delay, logger.getChild("poll")
+               )
+
+               notifier = RedditNotifier(token_ref, logger)
+               title_provider = RedditThreadTitleProvider(token_ref)
+               deletion_checker = RedditUpdateChecker(token_ref, logger)
+               nursery.start_soon(
+                       message_sender_impl,
+                       notifier, title_provider, deletion_checker, ws_state, thread_id, event_rx, buffer_time, mention_limit,
+                       logger.getChild("sender")
+               )
+
+
+if __name__ == "__main__":
+       arg_parser = ArgumentParser(__package__)
+       arg_parser.add_argument("config_path")
+       args = arg_parser.parse_args()
+
+       config_parser = ConfigParser()
+       with open(args.config_path) as config_file:
+               config_parser.read_file(config_file)
+
+       main_cfg = config_parser["config"]
+
+       auth_id = main_cfg.getint("auth ID")
+       assert auth_id is not None
+
+       token_period = timedelta(minutes = main_cfg.getint("token update period"))
+
+       thread_id = main_cfg["thread ID"]
+
+       mention_limit = main_cfg.getint("mention limit")
+       assert mention_limit is not None and mention_limit >= 0
+
+       buffer_time = timedelta(seconds = main_cfg.getfloat("message buffer time"))
+
+       poll_period = timedelta(seconds = main_cfg.getfloat("API poll period"))
+
+       max_ws_delay = timedelta(seconds = main_cfg.getfloat("max WS delay"))
+
+       db_conn = psycopg2.connect(**config_parser["db connect params"])
+       db_conn.autocommit = True
+
+       logger = getLogger(__package__)
+       logger.setLevel(NOTSET - 1)  # filter in handler(s) instead
+
+       handler = StreamHandler(stdout)
+       handler.setLevel(logging.DEBUG)
+       handler.setFormatter(Formatter("{asctime:23}: {name:17}: {levelname:8}: {message}", style = "{"))
+       logger.addHandler(handler)
+
+       trio.run(
+               _main,
+               auth_id, buffer_time, db_conn, logger, max_ws_delay, mention_limit, poll_period, thread_id, token_period
+       )
diff --git a/mentionbot/src/mentionbot/abc.py b/mentionbot/src/mentionbot/abc.py
new file mode 100644 (file)
index 0000000..3c1d7fd
--- /dev/null
@@ -0,0 +1,50 @@
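+"""Abstract interfaces that decouple the bot's tasks from Reddit and allow them to be mocked in tests."""
+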
+from abc import ABCMeta, abstractmethod
+from dataclasses import dataclass
+
+
+class Event(metaclass = ABCMeta):
+       pass
+
+
+@dataclass
+class UpdateEvent(Event):
+       payload_data: dict  # object at "payload" -> "data" in the message JSON
+
+
+@dataclass
+class DeleteEvent(Event):
+       payload: str  # update name
+
+
+class EventStream(metaclass = ABCMeta):
+       @abstractmethod
+       async def get_event(self) -> Event:
+               raise NotImplementedError()
+
+
+@dataclass
+class NotifyInfo:
+       recipient: str
+       thread_id: str
+       thread_title: str
+       update_id: str
+       author: str
+       update_body: str
+
+
+class Notifier(metaclass = ABCMeta):
+       @abstractmethod
+       def notify(self, info: NotifyInfo) -> None:
+               raise NotImplementedError()
+
+
+class ThreadTitleProvider(metaclass = ABCMeta):
+       @abstractmethod
+       def title_for(self, thread_id: str) -> str:
+               raise NotImplementedError()
+
+
+class UpdateChecker(metaclass = ABCMeta):
+       @abstractmethod
+       def exists(self, thread_id: str, update_id: str) -> bool:
+               raise NotImplementedError()
diff --git a/mentionbot/src/mentionbot/tests.py b/mentionbot/src/mentionbot/tests.py
new file mode 100644 (file)
index 0000000..0617c90
--- /dev/null
@@ -0,0 +1,350 @@
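+"""
+Tests for the mentionbot tasks. `TaskTests' runs entirely against mocks on a Trio mock clock; `RedditTests' makes real
+API requests and expects a valid token in the `API_TOKEN' environment variable.
+"""
+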
+from contextlib import asynccontextmanager, contextmanager
+from dataclasses import dataclass, field
+from datetime import timedelta
+from logging import getLogger
+from numbers import Real
+from os import environ
+from typing import AsyncIterator, Iterator, TypeVar
+from unittest import TestCase
+from urllib.request import Request
+from uuid import UUID, uuid1
+import logging
+
+from trio import CancelScope, MemoryReceiveChannel, move_on_after, open_memory_channel, open_nursery, sleep_until
+from trio.testing import MockClock, wait_all_tasks_blocked
+import trio
+
+from . import message_sender_impl, TokenRef, try_request, WsListenerState, ws_listener_impl
+from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
+
+
+T = TypeVar("T")
+
+
+@dataclass
+class MockNotifier(Notifier):
+       notifies: list[NotifyInfo] = field(default_factory = list)
+
+       def notify(self, info: NotifyInfo) -> None:
+               self.notifies.append(info)
+
+
+class MockTitleProvider(ThreadTitleProvider):
+       def title_for(self, thread_id: str) -> str:
+               return thread_id
+
+
+class UnusedUpdateChecker(UpdateChecker):
+       def exists(self, thread_id: str, update_id: str) -> bool:
+               raise AssertionError("unexpected deletion check")
+
+
+@dataclass
+class DictBackedUpdateChecker(UpdateChecker):
+       exists_by_id: dict[str, bool] = field(default_factory = dict)
+       queries: list[tuple[str, str]] = field(default_factory = list)
+
+       def exists(self, thread_id: str, update_id: str) -> bool:
+               self.queries.append((thread_id, update_id))
+               return self.exists_by_id[update_id]
+
+
+@dataclass
+class ChannelBackedEventStream(EventStream):
+       event_rx: MemoryReceiveChannel
+
+       async def get_event(self) -> Event:
+               obj = await self.event_rx.receive()
+               if isinstance(obj, BaseException):
+                       raise obj
+               else:
+                       return obj
+
+
+@asynccontextmanager
+async def null_async_context(val: T) -> AsyncIterator[T]:
+       yield val
+
+
+def _update_event(author: str, mentions: list[str]) -> UpdateEvent:
+       id_ = uuid1()
+       body = str(id_) + " " + " ".join("/u/" + user for user in mentions)
+       body_html = str(id_) + " " + " ".join(
+               f'<a href="/u/{user}">/u/{user}</a>'
+               for user in mentions
+       )
+       return UpdateEvent(
+               {
+                       "id": str(id_),
+                       "name": "LiveUpdate_" + str(id_),
+                       "author": author,
+                       "body": body,
+                       "body_html": body_html,
+               }
+       )
+
+
+def setUpModule() -> None:
+       getLogger().setLevel(logging.CRITICAL + 1)  # disable logging
+
+
+class TaskTests(TestCase):
+       @contextmanager
+       def _assert_completes_in(self, duration_seconds: Real) -> Iterator[None]:
+               with move_on_after(duration_seconds) as scope:
+                       yield
+               if scope.cancelled_caught:
+                       self.fail("operation timed out")
+
+       def test_buffering(self) -> None:
+               """
+		There is no API poll task: a custom task stands in for the WS listener by playing events over a channel, and the
+		real message sender task runs with mock collaborators that capture its output. The WS state is hard-coded as up,
+		so recovery from WS connection problems is not exercised here.
+
+               Primarily it makes sure that updates are passed through the buffer and processed if and only if a delete event
+               doesn't arrive while they're buffered.
+               """
+
+               buffer_time = timedelta(seconds = 1)
+               notifier = MockNotifier()
+               (event_tx, event_rx) = open_memory_channel(0)
+
+               ign_event = _update_event("sender", [])
+               event_a = _update_event("sender", ["rec1", "rec2"])
+               event_b = _update_event("sender", ["rec1", "rec2"])
+
+               async def event_stream_impl(scope: CancelScope) -> None:
+                       await sleep_until(1.0)
+                       await event_tx.send(ign_event)
+                       await event_tx.send(event_a)
+
+                       await sleep_until(1.5)
+                       self.assertFalse(notifier.notifies)
+			await event_tx.send(DeleteEvent(event_a.payload_data["name"]))
+
+                       await sleep_until(2.0)
+                       await event_tx.send(event_b)
+
+                       # let things settle and terminate both tasks
+                       await sleep_until(4.0)
+                       scope.cancel()
+
+               async def main() -> None:
+                       async with open_nursery() as nursery:
+                               nursery.start_soon(event_stream_impl, nursery.cancel_scope)
+                               nursery.start_soon(
+                                       message_sender_impl,
+					notifier, MockTitleProvider(), UnusedUpdateChecker(), WsListenerState(down = False), "ThreadId",
+					event_rx, buffer_time, 10, getLogger()
+                               )
+
+               trio.run(main, clock = MockClock(autojump_threshold = 0))
+
+               # only notifications sent are for event B
+               self.assertEqual(len(notifier.notifies), 2)
+               self.assertEqual({i.recipient for i in notifier.notifies}, {"rec1", "rec2"})
+               self.assertTrue(all(i.author == "sender" for i in notifier.notifies))
+               self.assertTrue(all(i.update_id == event_b.payload_data["id"] for i in notifier.notifies))
+
+       def test_ws_trouble(self) -> None:
+               """
+		Ensure that the message sender task responds correctly to the WS connection going down and coming back up.
+               """
+
+               thread_id = "ThreadId"
+               buffer_time = timedelta(seconds = 3)
+               update_checker = DictBackedUpdateChecker()
+               notifier = MockNotifier()
+               (event_tx, event_rx) = open_memory_channel(0)
+
+               # sent while WS is up, checked and notifies sent
+               event_a = _update_event("sender", ["rec"])
+
+               # sent and deleted while WS is down, checked and discarded from buffer
+               event_b = _update_event("sender", ["rec"])
+
+               # sent on WS reconnect, not checked and notifies sent
+               event_c = _update_event("sender", ["rec"])
+
+               # sent after WS reconnect, not checked and notifies sent
+               event_d = _update_event("sender", ["rec"])
+
+               ws_state = WsListenerState(down = False)
+
+               async def ws_simulator_impl() -> None:
+                       update_checker.exists_by_id[event_a.payload_data["id"]] = True
+                       await event_tx.send(event_a)
+
+                       await sleep_until(1.0)
+                       ws_state.down = True
+                       update_checker.exists_by_id[event_b.payload_data["id"]] = True
+                       await event_tx.send(event_b)
+
+                       await sleep_until(2.0)
+                       update_checker.exists_by_id[event_b.payload_data["id"]] = False
+                       ws_state.down = False
+                       ws_state.resync_ts = UUID(event_c.payload_data["id"]).time
+                       await event_tx.send(event_c)
+
+                       await sleep_until(3.0)
+                       await event_tx.send(event_d)
+
+               async def main() -> None:
+                       async with open_nursery() as nursery:
+                               nursery.start_soon(ws_simulator_impl)
+                               nursery.start_soon(
+                                       message_sender_impl,
+					notifier, MockTitleProvider(), update_checker, ws_state, thread_id, event_rx, buffer_time, 10,
+					getLogger()
+                               )
+
+                               await sleep_until(3.5)
+                               self.assertEqual([i.update_id for i in notifier.notifies], [event_a.payload_data["id"]])
+                               notifier.notifies.clear()
+                               self.assertEqual(update_checker.queries, [(thread_id, event_a.payload_data["id"])])
+                               update_checker.queries.clear()
+
+                               await sleep_until(4.5)
+                               self.assertFalse(notifier.notifies)
+                               self.assertEqual(update_checker.queries, [(thread_id, event_b.payload_data["id"])])
+                               update_checker.queries.clear()
+
+                               await sleep_until(5.5)
+                               self.assertEqual([i.update_id for i in notifier.notifies], [event_c.payload_data["id"]])
+                               notifier.notifies.clear()
+                               self.assertFalse(update_checker.queries)
+
+                               await sleep_until(6.5)
+                               self.assertEqual([i.update_id for i in notifier.notifies], [event_d.payload_data["id"]])
+                               notifier.notifies.clear()
+                               self.assertFalse(update_checker.queries)
+
+                               nursery.cancel_scope.cancel()
+
+               trio.run(main, clock = MockClock(autojump_threshold = 0))
+
+       def test_ws_listener_cancel(self) -> None:
+               ws_state = WsListenerState()
+
+               (event_in_tx, event_in_rx) = open_memory_channel(0)
+               event_stream = ChannelBackedEventStream(event_in_rx)
+               event_stream_factory = lambda: null_async_context(event_stream)
+
+               (event_out_tx, event_out_rx) = open_memory_channel(0)
+
+               async def main():
+                       async with open_nursery() as nursery:
+                               nursery.start_soon(
+                                       ws_listener_impl,
+                                       event_stream_factory, set(), ws_state, event_out_tx, getLogger()
+                               )
+
+                               event_1 = _update_event("author", [])
+                               await event_in_tx.send(event_1)
+                               with self._assert_completes_in(1):
+                                       event_out = await event_out_rx.receive()
+                               self.assertIs(event_out, event_1)
+
+                               event_2 = _update_event("author", ["recipient"])
+                               await event_in_tx.send(event_2)
+                               with self._assert_completes_in(1):
+                                       event_out = await event_out_rx.receive()
+                               self.assertIs(event_out, event_2)
+
+                               initial_scope = ws_state.scope
+                               initial_scope.cancel()
+                               await wait_all_tasks_blocked()
+                               self.assertIsNotNone(ws_state.scope)
+                               self.assertIsNot(ws_state.scope, initial_scope)
+                               self.assertTrue(ws_state.down)
+
+                               event_3 = _update_event("author", [])
+                               ts = UUID(event_3.payload_data["id"]).time
+                               await event_in_tx.send(event_3)
+                               with self._assert_completes_in(1):
+                                       event_out = await event_out_rx.receive()
+                               self.assertIs(event_out, event_3)
+                               self.assertEqual(ws_state.last_update_ts, ts)
+                               self.assertFalse(ws_state.down)
+                               self.assertEqual(ws_state.resync_ts, ts)
+
+                               nursery.cancel_scope.cancel()
+
+               trio.run(main, clock = MockClock(autojump_threshold = 0))
+
+       def test_ws_listener_error(self) -> None:
+               """Test that the WS listener responds correctly to errors from the input event stream."""
+
+               class MockStreamException(Exception):
+                       pass
+
+               ws_state = WsListenerState()
+
+               (event_in_tx, event_in_rx) = open_memory_channel(0)
+               event_stream = ChannelBackedEventStream(event_in_rx)
+               event_stream_factory = lambda: null_async_context(event_stream)
+
+               (event_out_tx, event_out_rx) = open_memory_channel(0)
+
+               async def main():
+                       async with open_nursery() as nursery:
+                               nursery.start_soon(
+                                       ws_listener_impl,
+                                       event_stream_factory, {MockStreamException}, ws_state, event_out_tx, getLogger()
+                               )
+
+                               event_1 = _update_event("author", [])
+                               await event_in_tx.send(event_1)
+                               with self._assert_completes_in(1):
+                                       event_out = await event_out_rx.receive()
+                               self.assertIs(event_out, event_1)
+
+                               initial_scope = ws_state.scope
+                               await event_in_tx.send(MockStreamException())
+                               await wait_all_tasks_blocked()
+                               self.assertIsNotNone(ws_state.scope)
+                               self.assertIsNot(ws_state.scope, initial_scope)
+                               self.assertTrue(ws_state.down)
+
+                               event_2 = _update_event("author", ["recipient"])
+                               ts = UUID(event_2.payload_data["id"]).time
+                               await event_in_tx.send(event_2)
+                               with self._assert_completes_in(1):
+                                       event_out = await event_out_rx.receive()
+                               self.assertIs(event_out, event_2)
+                               self.assertEqual(ws_state.last_update_ts, ts)
+                               self.assertFalse(ws_state.down)
+                               self.assertEqual(ws_state.resync_ts, ts)
+
+                               nursery.cancel_scope.cancel()
+
+               trio.run(main, clock = MockClock(autojump_threshold = 0))
+
+
+class RedditTests(TestCase):
+       """Tests for interactions with Reddit."""
+
+       @classmethod
+       def setUpClass(cls) -> None:
+               cls._TOKEN_REF = TokenRef(environ["API_TOKEN"])
+
+       def test_try_request(self) -> None:
+               self.assertIsNotNone(
+                       try_request(
+                               Request(
+                                       "https://oauth.reddit.com/api/live/happening_now",
+                                       headers = self._TOKEN_REF.get_common_headers(),
+                               ),
+                               getLogger(),
+                       )
+               )
+
+       def test_try_request_http_error(self) -> None:
+               self.assertIsNone(
+                       try_request(
+                               Request("https://oauth.reddit.com/api/live/happening_now"),
+                               getLogger(),
+                               suppress_codes = [403],
+                       )
+               )