--- /dev/null
+# see Python `configparser' docs for precise file syntax
+
+[config]
+
+# authorization ID for Reddit API token lookup
+auth ID = 3
+
+# (int, minutes)
+token update period = 15
+
+thread ID = abc123
+
+# (int) maximum number of mentions within an update for which to send messages
+mention limit = 10
+
+# (float, seconds) wait this much time before sending notifications, in case of update deletion
+message buffer time = 10
+
+# (float, seconds) wait this much time between API polls (and WS connection checks)
+API poll period = 90
+
+# (float, seconds) replace the WS connection if it falls this far behind the API in update timestamps
+max WS delay = 2
+
+
+# Postgres database configuration; same options as in a connect string.
+[db connect params]
+host = example.org
--- /dev/null
+[build-system]
+requires = ["setuptools>=42", "wheel"]
+build-backend = "setuptools.build_meta"
--- /dev/null
+[metadata]
+name = mentionbot
+version = 0.0.0
+
+[options]
+package_dir =
+ = src
+packages = mentionbot
+python_requires = ~= 3.9
+install_requires =
+ beautifulsoup4 ~= 4.9
+ #psycopg2 ~= 2.8
+ psycopg2-binary ~= 2.8
+ trio == 0.19
+ trio-websocket == 0.9.2
--- /dev/null
+import setuptools
+
+
+setuptools.setup()
--- /dev/null
+from dataclasses import dataclass
+from datetime import timedelta
+from http.client import HTTPResponse
+from logging import Logger
+from socket import EAI_AGAIN, EAI_FAIL, gaierror
+from typing import AsyncContextManager, Callable, Iterable, Optional, Type
+from urllib.error import HTTPError
+from urllib.request import Request, urlopen
+from uuid import UUID
+import importlib.metadata
+import re
+
+from bs4 import BeautifulSoup
+from trio import CancelScope, MemoryReceiveChannel, MemorySendChannel, move_on_at
+import trio
+
+from .abc import DeleteEvent, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
+
+
+_USER_AGENT = "any:net.jcornell.mentionbot:v{} (by /u/jaxklax)".format(
+ importlib.metadata.version(__package__)
+)
+
+
+@dataclass
+class TokenRef:
+ """Mutable container for an access token."""
+
+ token: str
+
+ def get_common_headers(self) -> dict[str, str]:
+ return {
+ "User-Agent": _USER_AGENT,
+ "Authorization": "Bearer {}".format(self.token)
+ }
+
+
+def try_request(
+ request: Request,
+ logger: Logger,
+ suppress_codes: list[int] = [404, 429, 500, 503],
+) -> Optional[HTTPResponse]:
+
+ try:
+ return urlopen(request)
+ except gaierror as e:
+ if e.errno in [EAI_FAIL, EAI_AGAIN]:
+ logger.error(f"DNS failure: {e}")
+ return None
+ else:
+ raise
+ except HTTPError as e:
+ e.read()
+ e.close()
+ if e.code in suppress_codes:
+ logger.error(f"HTTP {e.code}")
+ return None
+ else:
+ raise
+
+
+@dataclass
+class WsListenerState:
+ """The WebSockets listener uses this object to communicate with the API poll and message sender tasks."""
+
+ # The API poll task uses this scope to force a refresh of the WS connection if it's been too quiet.
+ scope: Optional[CancelScope] = None
+
+ # The WS task uses this to track the UUID timestamp of the most recent update message update, so the API poll task
+ # can detect when the connection goes silent.
+ last_update_ts: Optional[int] = None
+
+ # The WS task sets this when the WS connection goes down and clears it when it's recovered.
+ down: bool = True
+
+ # The WS task sets this to the UUID timestamp of the first update recieved on connection recovery so the message
+ # sender task knows when to stop checking for deletion.
+ resync_ts: Optional[int] = None
+
+
+def _parse_mentions(body: BeautifulSoup) -> set[str]:
+ """Extract the names of users tagged in an update."""
+ return {
+ el["href"].removeprefix("/u/")
+ for el in body.find_all("a", href = re.compile(r"^/u/[\w-]{3,20}$", re.ASCII))
+ }
+
+
+async def ws_listener_impl(
+ event_stream_factory: Callable[[], AsyncContextManager[EventStream]],
+ event_stream_catch: Iterable[Type[BaseException]],
+ state: WsListenerState,
+ event_tx: MemorySendChannel, # message type `Event'
+ logger: Logger,
+) -> None:
+
+ while True:
+ state.scope = CancelScope()
+ with state.scope:
+ logger.debug("opening a new connection")
+ try:
+ async with event_stream_factory() as event_stream:
+ first_update = True
+ while True:
+ event = await event_stream.get_event()
+ if isinstance(event, UpdateEvent):
+ update_id = UUID(event.payload_data["id"])
+ state.last_update_ts = update_id.time
+ if first_update:
+ # first update after discontinuity
+ state.down = False
+ state.resync_ts = update_id.time
+ first_update = False
+ elif isinstance(event, DeleteEvent):
+ logger.debug("delete event for " + event.payload)
+ else:
+ raise RuntimeError(f"expected Event, got {event}")
+
+ await event_tx.send(event)
+ except tuple(event_stream_catch) as e:
+ logger.warning(f"WebSockets error: {e}")
+
+ state.down = True
+
+
+async def message_sender_impl(
+ notifier: Notifier,
+ title_provider: ThreadTitleProvider,
+ update_checker: UpdateChecker,
+ ws_state: WsListenerState,
+ thread_id: str,
+ event_rx: MemoryReceiveChannel, # message type `Event'
+ buffer_time: timedelta,
+ mention_limit: int,
+ logger: Logger,
+) -> None:
+
+ @dataclass
+ class QueueEntry:
+ payload_data: dict
+ mentions: set[str]
+ arrival: float # Trio time
+
+ queue: list[QueueEntry] = []
+
+ while True:
+ logger.debug("queue: {}".format([e.payload_data["id"] for e in queue]))
+ if queue:
+ scope = move_on_at(queue[0].arrival + buffer_time.total_seconds())
+ else:
+ scope = CancelScope() # no timeout
+
+ with scope:
+ event = await event_rx.receive()
+
+ if scope.cancelled_caught:
+ # timed out
+ entry = queue.pop(0)
+ update_id = entry.payload_data["id"]
+ if ws_state.down or ws_state.resync_ts and UUID(update_id).time < ws_state.resync_ts:
+ notify = update_checker.exists(thread_id, update_id)
+ else:
+ notify = True
+
+ if notify:
+ thread_title = title_provider.title_for(thread_id)
+ for recipient in entry.mentions:
+ notifier.notify(NotifyInfo(
+ recipient, thread_id, thread_title, update_id, entry.payload_data["author"],
+ entry.payload_data["body"]
+ ))
+ else:
+ # got an event
+ if isinstance(event, UpdateEvent):
+ update_id = event.payload_data["id"]
+ if not any(entry.payload_data["id"] == update_id for entry in queue):
+ mentions = _parse_mentions(BeautifulSoup(event.payload_data["body_html"], "html.parser"))
+ if mentions:
+ if len(mentions) > mention_limit:
+ logger.info(f"ignoring {update_id} due to too many mentions ({len(mentions)})")
+ else:
+ queue.append(QueueEntry(event.payload_data, mentions, trio.current_time()))
+ elif isinstance(event, DeleteEvent):
+ logger.debug("delete event for " + event.payload)
+ queue = [entry for entry in queue if entry.payload_data["name"] != event.payload]
+ else:
+ raise RuntimeError(f"invalid event type {type(event)}")
--- /dev/null
+"""
+Mentionbot: Notify users by private message when they're /u/-mentioned in a live thread.
+"""
+
+from argparse import ArgumentParser
+from configparser import ConfigParser
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from datetime import timedelta
+from logging import Formatter, getLogger, Logger, NOTSET, StreamHandler
+from sys import stdout
+from typing import ClassVar, Iterator, Optional, Type
+from urllib.error import HTTPError
+from urllib.parse import urlencode
+from urllib.request import Request, urlopen
+from uuid import UUID
+import json
+import logging
+
+from trio import MemorySendChannel, open_memory_channel, open_nursery
+from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
+import psycopg2
+import trio
+
+from . import TokenRef, message_sender_impl, try_request, ws_listener_impl, WsListenerState
+from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
+
+
+async def _token_update_impl(period: timedelta, db_conn, auth_id: int, task_status) -> None:
+ def fetch_token() -> str:
+ cursor = db_conn.cursor()
+ cursor.execute(
+ """
+ select access_token from public.reddit_app_authorization
+ where id = %s
+ """,
+ (auth_id,)
+ )
+ [(token,)] = cursor.fetchall()
+ return token
+
+ ref = TokenRef(fetch_token())
+ task_status.started(ref)
+ while True:
+ await trio.sleep(period.total_seconds())
+ ref.token = fetch_token()
+
+
+async def _api_poll_impl(
+ ws_state: WsListenerState,
+ token_ref: TokenRef,
+ thread_id: str,
+ event_tx: MemorySendChannel, # message type `Event'
+ poll_period: timedelta,
+ max_ws_delay: timedelta,
+ logger: Logger,
+) -> None:
+
+ base_url = f"https://oauth.reddit.com/live/{thread_id}"
+
+ latest_payload: Optional[dict] = None
+ while True:
+ # fetch updates and emit events
+ params = {
+ "raw_json": "1",
+ }
+ if latest_payload:
+ updates = []
+ while True:
+ params["before"] = latest_payload["data"]["name"]
+ url = base_url + "?" + urlencode(params)
+ resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger)
+ if resp:
+ with resp:
+ children = json.load(resp)["data"]["children"]
+ if children:
+ updates.extend(reversed(children))
+ latest_payload = children[0]
+ else:
+ break
+ else:
+ break
+ else:
+ params["limit"] = "1"
+ url = base_url + "?" + urlencode(params)
+ resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger)
+ if resp:
+ with resp:
+ updates = json.load(resp)["data"]["children"]
+ if updates:
+ latest_payload = updates[0]
+ else:
+ updates = []
+
+ for update in updates:
+ await event_tx.send(UpdateEvent(update["data"]))
+
+ # check if WS connection is behind
+ if not ws_state.down and latest_payload:
+ max_delta = max_ws_delay.total_seconds() * 10 ** 7 # in 100*nanosecond
+ if UUID(latest_payload["data"]["id"]).time - ws_state.last_update_ts > max_delta:
+ ws_state.scope.cancel()
+
+ await trio.sleep(poll_period.total_seconds())
+
+
+def _build_notification_md(info: NotifyInfo) -> str:
+ # TODO escape thread title?
+ return "\n\n".join([
+ f"from [{info.author}](/user/{info.author}) in [{info.thread_title}](/live/{info.thread_id})",
+ "\n".join(">" + line for line in info.update_body.splitlines()),
+ (
+ f"[^^permalink](/live/{info.thread_id}/updates/{info.update_id})"
+ + f" [^^context](/live/{info.thread_id}?after=LiveUpdate_{info.update_id})"
+ + " [^^about ^^these ^^notifications](/r/live_mentions/wiki)"
+ ),
+ ])
+
+
+@dataclass
+class RedditNotifier(Notifier):
+ token_ref: TokenRef
+ logger: Logger
+
+ def notify(self, info: NotifyInfo) -> None:
+ params = {
+ "raw_json": "1",
+ "api_type": "json",
+ "subject": "username mention",
+ "text": _build_notification_md(info),
+ }
+ req = Request(
+ "https://oauth.reddit.com/api/compose",
+ method = "POST",
+ data = urlencode({**params, "to": info.recipient}).encode("ascii"),
+ headers = self.token_ref.get_common_headers(),
+ )
+ resp = try_request(req, self.logger)
+ if resp:
+ with resp:
+ data = json.load(resp)["json"]
+ if data["errors"]:
+ [(code, _, _)] = data["errors"]
+ if code != "USER_DOESNT_EXIST":
+ raise RuntimeError("PM send error: {code}")
+
+
+@dataclass
+class RedditThreadTitleProvider(ThreadTitleProvider):
+ token_ref: TokenRef
+
+ def title_for(self, thread_id: str) -> str:
+ req = Request(
+ f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1",
+ headers = self.token_ref.get_common_headers()
+ )
+ with urlopen(req) as resp:
+ return json.load(resp)["data"]["title"]
+
+
+@dataclass
+class RedditUpdateChecker(UpdateChecker):
+ token_ref: TokenRef
+ logger: Logger
+
+ def exists(self, thread_id: str, update_id: str) -> bool:
+ req = Request(
+ f"https://oauth.reddit.com/live/{thread_id}/updates/{update_id}?raw_json=1",
+ headers = self.token_ref.get_common_headers(),
+ )
+ try:
+ resp = try_request(req, self.logger, suppress_codes = [429, 500, 503])
+ except HTTPError as e:
+ e.read()
+ e.close()
+ if e.code == 404:
+ return False
+ else:
+ raise
+ else:
+ if resp:
+ resp.read()
+ resp.close()
+ return True
+
+
+@dataclass
+class WsConnectionEventStream(EventStream):
+ EXCEPTION_TYPES: ClassVar[set[Type[BaseException]]] = {HandshakeError, ConnectionClosed}
+
+ ws_conn: WebSocketConnection
+
+ async def get_event(self) -> Event:
+ while True:
+ message = json.loads(await self.ws_conn.get_message())
+ if message["type"] == "update":
+ return UpdateEvent(message["payload"]["data"])
+ elif message["type"] == "delete":
+ return DeleteEvent(message["payload"])
+
+
+async def _main(
+ auth_id: int,
+ buffer_time: timedelta,
+ db_conn,
+ logger: Logger,
+ max_ws_delay: timedelta,
+ mention_limit: int,
+ poll_period: timedelta,
+ thread_id: str,
+ token_period: timedelta,
+) -> None:
+
+ @asynccontextmanager
+ async def event_stream_factory() -> Iterator[WsConnectionEventStream]:
+ resp = urlopen(Request(
+ f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1",
+ headers = token_ref.get_common_headers(),
+ ))
+ with resp:
+ ws_url = json.load(resp)["data"]["websocket_url"]
+
+ async with open_websocket_url(ws_url) as ws_conn:
+ yield WsConnectionEventStream(ws_conn)
+
+ async with open_nursery() as nursery:
+ token_ref = await nursery.start(_token_update_impl, token_period, db_conn, auth_id)
+
+ ws_state = WsListenerState()
+ (event_tx, event_rx) = open_memory_channel(0)
+ nursery.start_soon(
+ ws_listener_impl,
+ event_stream_factory, WsConnectionEventStream.EXCEPTION_TYPES, ws_state, event_tx, logger.getChild("WS")
+ )
+ nursery.start_soon(
+ _api_poll_impl,
+ ws_state, token_ref, thread_id, event_tx, poll_period, max_ws_delay, logger.getChild("poll")
+ )
+
+ notifier = RedditNotifier(token_ref, logger)
+ title_provider = RedditThreadTitleProvider(token_ref)
+ deletion_checker = RedditUpdateChecker(token_ref, logger)
+ nursery.start_soon(
+ message_sender_impl,
+ notifier, title_provider, deletion_checker, ws_state, thread_id, event_rx, buffer_time, mention_limit,
+ logger.getChild("sender")
+ )
+
+
+if __name__ == "__main__":
+ arg_parser = ArgumentParser(__package__)
+ arg_parser.add_argument("config_path")
+ args = arg_parser.parse_args()
+
+ config_parser = ConfigParser()
+ with open(args.config_path) as config_file:
+ config_parser.read_file(config_file)
+
+ main_cfg = config_parser["config"]
+
+ auth_id = main_cfg.getint("auth ID")
+ assert auth_id is not None
+
+ token_period = timedelta(minutes = main_cfg.getint("token update period"))
+
+ thread_id = main_cfg["thread ID"]
+
+ mention_limit = main_cfg.getint("mention limit")
+ assert mention_limit is not None and mention_limit >= 0
+
+ buffer_time = timedelta(seconds = main_cfg.getfloat("message buffer time"))
+
+ poll_period = timedelta(seconds = main_cfg.getfloat("API poll period"))
+
+ max_ws_delay = timedelta(seconds = main_cfg.getfloat("max WS delay"))
+
+ db_conn = psycopg2.connect(**config_parser["db connect params"])
+ db_conn.autocommit = True
+
+ logger = getLogger(__package__)
+ logger.setLevel(NOTSET - 1) # filter in handler(s) instead
+
+ handler = StreamHandler(stdout)
+ handler.setLevel(logging.DEBUG)
+ handler.setFormatter(Formatter("{asctime:23}: {name:17}: {levelname:8}: {message}", style = "{"))
+ logger.addHandler(handler)
+
+ trio.run(
+ _main,
+ auth_id, buffer_time, db_conn, logger, max_ws_delay, mention_limit, poll_period, thread_id, token_period
+ )
--- /dev/null
+from abc import ABCMeta, abstractmethod
+from dataclasses import dataclass
+
+
+class Event(metaclass = ABCMeta):
+ pass
+
+
+@dataclass
+class UpdateEvent(Event):
+ payload_data: dict # object at "payload" -> "data" in the message JSON
+
+
+@dataclass
+class DeleteEvent(Event):
+ payload: str # update name
+
+
+class EventStream(metaclass = ABCMeta):
+ @abstractmethod
+ async def get_event(self) -> Event:
+ raise NotImplementedError()
+
+
+@dataclass
+class NotifyInfo:
+ recipient: str
+ thread_id: str
+ thread_title: str
+ update_id: str
+ author: str
+ update_body: str
+
+
+class Notifier(metaclass = ABCMeta):
+ @abstractmethod
+ def notify(self, info: NotifyInfo) -> None:
+ raise NotImplementedError()
+
+
+class ThreadTitleProvider(metaclass = ABCMeta):
+ @abstractmethod
+ def title_for(self, thread_id: str) -> str:
+ raise NotImplementedError()
+
+
+class UpdateChecker(metaclass = ABCMeta):
+ @abstractmethod
+ def exists(self, thread_id: str, update_id: str) -> bool:
+ raise NotImplementedError()
--- /dev/null
+from contextlib import asynccontextmanager, contextmanager
+from dataclasses import dataclass, field
+from datetime import timedelta
+from logging import getLogger
+from numbers import Real
+from os import environ
+from typing import Iterator, TypeVar
+from unittest import TestCase
+from urllib.request import Request
+from uuid import UUID, uuid1
+import logging
+
+from trio import CancelScope, MemoryReceiveChannel, move_on_after, open_memory_channel, open_nursery, sleep_until
+from trio.testing import MockClock, wait_all_tasks_blocked
+import trio
+
+from . import message_sender_impl, TokenRef, try_request, WsListenerState, ws_listener_impl
+from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent
+
+
+T = TypeVar("T")
+
+
+@dataclass
+class MockNotifier(Notifier):
+ notifies: list[NotifyInfo] = field(default_factory = list)
+
+ def notify(self, info: NotifyInfo) -> None:
+ self.notifies.append(info)
+
+
+class MockTitleProvider(ThreadTitleProvider):
+ def title_for(self, thread_id: str) -> str:
+ return thread_id
+
+
+class UnusedUpdateChecker(UpdateChecker):
+ def exists(self, thread_id: str, update_id: str) -> bool:
+ raise AssertionError("unexpected deletion check")
+
+
+@dataclass
+class DictBackedUpdateChecker(UpdateChecker):
+ exists_by_id: dict[str, bool] = field(default_factory = dict)
+ queries: list[tuple[str, str]] = field(default_factory = list)
+
+ def exists(self, thread_id: str, update_id: str) -> bool:
+ self.queries.append((thread_id, update_id))
+ return self.exists_by_id[update_id]
+
+
+@dataclass
+class ChannelBackedEventStream(EventStream):
+ event_rx: MemoryReceiveChannel
+
+ async def get_event(self) -> Event:
+ obj = await self.event_rx.receive()
+ if isinstance(obj, BaseException):
+ raise obj
+ else:
+ return obj
+
+
+@asynccontextmanager
+async def null_async_context(val: T) -> Iterator[T]:
+ yield val
+
+
+def _update_event(author: str, mentions: list[str]) -> UpdateEvent:
+ id_ = uuid1()
+ body = str(id_) + " " + " ".join("/u/" + user for user in mentions)
+ body_html = str(id_) + " " + " ".join(
+ f'<a href="/u/{user}">/u/{user}</a>'
+ for user in mentions
+ )
+ return UpdateEvent(
+ {
+ "id": str(id_),
+ "name": "LiveUpdate_" + str(id_),
+ "author": author,
+ "body": body,
+ "body_html": body_html,
+ }
+ )
+
+
+def setUpModule() -> None:
+ getLogger().setLevel(logging.CRITICAL + 1) # disable logging
+
+
+class TaskTests(TestCase):
+ @contextmanager
+ def _assert_completes_in(self, duration_seconds: Real) -> Iterator[None]:
+ with move_on_after(duration_seconds) as scope:
+ yield
+ if scope.cancelled_caught:
+ self.fail("operation timed out")
+
+ def test_buffering(self) -> None:
+ """
+ No API poll task, just a custom task mocking the WS listener by playing events over a channel, and a real
+ notifier task injected with mocks to capture output. The WS state is hard-coded as up, so no testing of recovery
+ from WS connection issues is included in this test.
+
+ Primarily it makes sure that updates are passed through the buffer and processed if and only if a delete event
+ doesn't arrive while they're buffered.
+ """
+
+ buffer_time = timedelta(seconds = 1)
+ notifier = MockNotifier()
+ (event_tx, event_rx) = open_memory_channel(0)
+
+ ign_event = _update_event("sender", [])
+ event_a = _update_event("sender", ["rec1", "rec2"])
+ event_b = _update_event("sender", ["rec1", "rec2"])
+
+ async def event_stream_impl(scope: CancelScope) -> None:
+ await sleep_until(1.0)
+ await event_tx.send(ign_event)
+ await event_tx.send(event_a)
+
+ await sleep_until(1.5)
+ self.assertFalse(notifier.notifies)
+ await event_tx.send(DeleteEvent(event_a.payload_data["id"]))
+
+ await sleep_until(2.0)
+ await event_tx.send(event_b)
+
+ # let things settle and terminate both tasks
+ await sleep_until(4.0)
+ scope.cancel()
+
+ async def main() -> None:
+ async with open_nursery() as nursery:
+ nursery.start_soon(event_stream_impl, nursery.cancel_scope)
+ nursery.start_soon(
+ message_sender_impl,
+ notifier, MockTitleProvider(), UnusedUpdateChecker(), WsListenerState(down = False), "ThreadId",
+ event_rx, buffer_time, getLogger()
+ )
+
+ trio.run(main, clock = MockClock(autojump_threshold = 0))
+
+ # only notifications sent are for event B
+ self.assertEqual(len(notifier.notifies), 2)
+ self.assertEqual({i.recipient for i in notifier.notifies}, {"rec1", "rec2"})
+ self.assertTrue(all(i.author == "sender" for i in notifier.notifies))
+ self.assertTrue(all(i.update_id == event_b.payload_data["id"] for i in notifier.notifies))
+
+ def test_ws_trouble(self) -> None:
+ """
+ Ensure that the notifier task responds correctly to the WS connection going down and coming back up.
+ """
+
+ thread_id = "ThreadId"
+ buffer_time = timedelta(seconds = 3)
+ update_checker = DictBackedUpdateChecker()
+ notifier = MockNotifier()
+ (event_tx, event_rx) = open_memory_channel(0)
+
+ # sent while WS is up, checked and notifies sent
+ event_a = _update_event("sender", ["rec"])
+
+ # sent and deleted while WS is down, checked and discarded from buffer
+ event_b = _update_event("sender", ["rec"])
+
+ # sent on WS reconnect, not checked and notifies sent
+ event_c = _update_event("sender", ["rec"])
+
+ # sent after WS reconnect, not checked and notifies sent
+ event_d = _update_event("sender", ["rec"])
+
+ ws_state = WsListenerState(down = False)
+
+ async def ws_simulator_impl() -> None:
+ update_checker.exists_by_id[event_a.payload_data["id"]] = True
+ await event_tx.send(event_a)
+
+ await sleep_until(1.0)
+ ws_state.down = True
+ update_checker.exists_by_id[event_b.payload_data["id"]] = True
+ await event_tx.send(event_b)
+
+ await sleep_until(2.0)
+ update_checker.exists_by_id[event_b.payload_data["id"]] = False
+ ws_state.down = False
+ ws_state.resync_ts = UUID(event_c.payload_data["id"]).time
+ await event_tx.send(event_c)
+
+ await sleep_until(3.0)
+ await event_tx.send(event_d)
+
+ async def main() -> None:
+ async with open_nursery() as nursery:
+ nursery.start_soon(ws_simulator_impl)
+ nursery.start_soon(
+ message_sender_impl,
+ notifier, MockTitleProvider(), update_checker, ws_state, "ThreadId", event_rx, buffer_time,
+ getLogger()
+ )
+
+ await sleep_until(3.5)
+ self.assertEqual([i.update_id for i in notifier.notifies], [event_a.payload_data["id"]])
+ notifier.notifies.clear()
+ self.assertEqual(update_checker.queries, [(thread_id, event_a.payload_data["id"])])
+ update_checker.queries.clear()
+
+ await sleep_until(4.5)
+ self.assertFalse(notifier.notifies)
+ self.assertEqual(update_checker.queries, [(thread_id, event_b.payload_data["id"])])
+ update_checker.queries.clear()
+
+ await sleep_until(5.5)
+ self.assertEqual([i.update_id for i in notifier.notifies], [event_c.payload_data["id"]])
+ notifier.notifies.clear()
+ self.assertFalse(update_checker.queries)
+
+ await sleep_until(6.5)
+ self.assertEqual([i.update_id for i in notifier.notifies], [event_d.payload_data["id"]])
+ notifier.notifies.clear()
+ self.assertFalse(update_checker.queries)
+
+ nursery.cancel_scope.cancel()
+
+ trio.run(main, clock = MockClock(autojump_threshold = 0))
+
+ def test_ws_listener_cancel(self) -> None:
+ ws_state = WsListenerState()
+
+ (event_in_tx, event_in_rx) = open_memory_channel(0)
+ event_stream = ChannelBackedEventStream(event_in_rx)
+ event_stream_factory = lambda: null_async_context(event_stream)
+
+ (event_out_tx, event_out_rx) = open_memory_channel(0)
+
+ async def main():
+ async with open_nursery() as nursery:
+ nursery.start_soon(
+ ws_listener_impl,
+ event_stream_factory, set(), ws_state, event_out_tx, getLogger()
+ )
+
+ event_1 = _update_event("author", [])
+ await event_in_tx.send(event_1)
+ with self._assert_completes_in(1):
+ event_out = await event_out_rx.receive()
+ self.assertIs(event_out, event_1)
+
+ event_2 = _update_event("author", ["recipient"])
+ await event_in_tx.send(event_2)
+ with self._assert_completes_in(1):
+ event_out = await event_out_rx.receive()
+ self.assertIs(event_out, event_2)
+
+ initial_scope = ws_state.scope
+ initial_scope.cancel()
+ await wait_all_tasks_blocked()
+ self.assertIsNotNone(ws_state.scope)
+ self.assertIsNot(ws_state.scope, initial_scope)
+ self.assertTrue(ws_state.down)
+
+ event_3 = _update_event("author", [])
+ ts = UUID(event_3.payload_data["id"]).time
+ await event_in_tx.send(event_3)
+ with self._assert_completes_in(1):
+ event_out = await event_out_rx.receive()
+ self.assertIs(event_out, event_3)
+ self.assertEqual(ws_state.last_update_ts, ts)
+ self.assertFalse(ws_state.down)
+ self.assertEqual(ws_state.resync_ts, ts)
+
+ nursery.cancel_scope.cancel()
+
+ trio.run(main, clock = MockClock(autojump_threshold = 0))
+
+ def test_ws_listener_error(self) -> None:
+ """Test that the WS listener responds correctly to errors from the input event stream."""
+
+ class MockStreamException(Exception):
+ pass
+
+ ws_state = WsListenerState()
+
+ (event_in_tx, event_in_rx) = open_memory_channel(0)
+ event_stream = ChannelBackedEventStream(event_in_rx)
+ event_stream_factory = lambda: null_async_context(event_stream)
+
+ (event_out_tx, event_out_rx) = open_memory_channel(0)
+
+ async def main():
+ async with open_nursery() as nursery:
+ nursery.start_soon(
+ ws_listener_impl,
+ event_stream_factory, {MockStreamException}, ws_state, event_out_tx, getLogger()
+ )
+
+ event_1 = _update_event("author", [])
+ await event_in_tx.send(event_1)
+ with self._assert_completes_in(1):
+ event_out = await event_out_rx.receive()
+ self.assertIs(event_out, event_1)
+
+ initial_scope = ws_state.scope
+ await event_in_tx.send(MockStreamException())
+ await wait_all_tasks_blocked()
+ self.assertIsNotNone(ws_state.scope)
+ self.assertIsNot(ws_state.scope, initial_scope)
+ self.assertTrue(ws_state.down)
+
+ event_2 = _update_event("author", ["recipient"])
+ ts = UUID(event_2.payload_data["id"]).time
+ await event_in_tx.send(event_2)
+ with self._assert_completes_in(1):
+ event_out = await event_out_rx.receive()
+ self.assertIs(event_out, event_2)
+ self.assertEqual(ws_state.last_update_ts, ts)
+ self.assertFalse(ws_state.down)
+ self.assertEqual(ws_state.resync_ts, ts)
+
+ nursery.cancel_scope.cancel()
+
+ trio.run(main, clock = MockClock(autojump_threshold = 0))
+
+
+class RedditTests(TestCase):
+ """Tests for interactions with Reddit."""
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ cls._TOKEN_REF = TokenRef(environ["API_TOKEN"])
+
+ def test_try_request(self) -> None:
+ self.assertIsNotNone(
+ try_request(
+ Request(
+ "https://oauth.reddit.com/api/live/happening_now",
+ headers = self._TOKEN_REF.get_common_headers(),
+ ),
+ getLogger(),
+ )
+ )
+
+ def test_try_request_http_error(self) -> None:
+ self.assertIsNone(
+ try_request(
+ Request("https://oauth.reddit.com/api/live/happening_now"),
+ getLogger(),
+ suppress_codes = [403],
+ )
+ )