From: Jakob Cornell Date: Mon, 19 Dec 2022 04:34:34 +0000 (-0600) Subject: Implement Mentionbot and tests X-Git-Tag: strikebot-0.0.8~2 X-Git-Url: https://jcornell.net/gitweb/gitweb.cgi?a=commitdiff_plain;h=9dafaec88b414f805d134d44e009e88698a85647;p=counting.git Implement Mentionbot and tests --- diff --git a/mentionbot/docs/sample_config.ini b/mentionbot/docs/sample_config.ini new file mode 100644 index 0000000..4228b6a --- /dev/null +++ b/mentionbot/docs/sample_config.ini @@ -0,0 +1,28 @@ +# see Python `configparser' docs for precise file syntax + +[config] + +# authorization ID for Reddit API token lookup +auth ID = 3 + +# (int, minutes) +token update period = 15 + +thread ID = abc123 + +# (int) maximum number of mentions within an update for which to send messages +mention limit = 10 + +# (float, seconds) wait this much time before sending notifications, in case of update deletion +message buffer time = 10 + +# (float, seconds) wait this much time between API polls (and WS connection checks) +API poll period = 90 + +# (float, seconds) replace the WS connection if it falls this far behind the API in update timestamps +max WS delay = 2 + + +# Postgres database configuration; same options as in a connect string. +[db connect params] +host = example.org diff --git a/mentionbot/pyproject.toml b/mentionbot/pyproject.toml new file mode 100644 index 0000000..8fe2f47 --- /dev/null +++ b/mentionbot/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/mentionbot/setup.cfg b/mentionbot/setup.cfg new file mode 100644 index 0000000..b3f9866 --- /dev/null +++ b/mentionbot/setup.cfg @@ -0,0 +1,15 @@ +[metadata] +name = mentionbot +version = 0.0.0 + +[options] +package_dir = + = src +packages = mentionbot +python_requires = ~= 3.9 +install_requires = + beautifulsoup4 ~= 4.9 + #psycopg2 ~= 2.8 + psycopg2-binary ~= 2.8 + trio == 0.19 + trio-websocket == 0.9.2 diff --git a/mentionbot/setup.py b/mentionbot/setup.py new file mode 100644 index 0000000..056ba45 --- /dev/null +++ b/mentionbot/setup.py @@ -0,0 +1,4 @@ +import setuptools + + +setuptools.setup() diff --git a/mentionbot/src/mentionbot/__init__.py b/mentionbot/src/mentionbot/__init__.py new file mode 100644 index 0000000..76fd7e2 --- /dev/null +++ b/mentionbot/src/mentionbot/__init__.py @@ -0,0 +1,187 @@ +from dataclasses import dataclass +from datetime import timedelta +from http.client import HTTPResponse +from logging import Logger +from socket import EAI_AGAIN, EAI_FAIL, gaierror +from typing import AsyncContextManager, Callable, Iterable, Optional, Type +from urllib.error import HTTPError +from urllib.request import Request, urlopen +from uuid import UUID +import importlib.metadata +import re + +from bs4 import BeautifulSoup +from trio import CancelScope, MemoryReceiveChannel, MemorySendChannel, move_on_at +import trio + +from .abc import DeleteEvent, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent + + +_USER_AGENT = "any:net.jcornell.mentionbot:v{} (by /u/jaxklax)".format( + importlib.metadata.version(__package__) +) + + +@dataclass +class TokenRef: + """Mutable container for an access token.""" + + token: str + + def get_common_headers(self) -> dict[str, str]: + return { + "User-Agent": _USER_AGENT, + "Authorization": "Bearer {}".format(self.token) + } + + +def try_request( + request: Request, + logger: Logger, + suppress_codes: list[int] = [404, 429, 500, 503], +) -> Optional[HTTPResponse]: + + try: + return urlopen(request) + except gaierror as e: + if e.errno in [EAI_FAIL, EAI_AGAIN]: + logger.error(f"DNS failure: {e}") + return None + else: + raise + except HTTPError as e: + e.read() + e.close() + if e.code in suppress_codes: + logger.error(f"HTTP {e.code}") + return None + else: + raise + + +@dataclass +class WsListenerState: + """The WebSockets listener uses this object to communicate with the API poll and message sender tasks.""" + + # The API poll task uses this scope to force a refresh of the WS connection if it's been too quiet. + scope: Optional[CancelScope] = None + + # The WS task uses this to track the UUID timestamp of the most recent update message update, so the API poll task + # can detect when the connection goes silent. + last_update_ts: Optional[int] = None + + # The WS task sets this when the WS connection goes down and clears it when it's recovered. + down: bool = True + + # The WS task sets this to the UUID timestamp of the first update recieved on connection recovery so the message + # sender task knows when to stop checking for deletion. + resync_ts: Optional[int] = None + + +def _parse_mentions(body: BeautifulSoup) -> set[str]: + """Extract the names of users tagged in an update.""" + return { + el["href"].removeprefix("/u/") + for el in body.find_all("a", href = re.compile(r"^/u/[\w-]{3,20}$", re.ASCII)) + } + + +async def ws_listener_impl( + event_stream_factory: Callable[[], AsyncContextManager[EventStream]], + event_stream_catch: Iterable[Type[BaseException]], + state: WsListenerState, + event_tx: MemorySendChannel, # message type `Event' + logger: Logger, +) -> None: + + while True: + state.scope = CancelScope() + with state.scope: + logger.debug("opening a new connection") + try: + async with event_stream_factory() as event_stream: + first_update = True + while True: + event = await event_stream.get_event() + if isinstance(event, UpdateEvent): + update_id = UUID(event.payload_data["id"]) + state.last_update_ts = update_id.time + if first_update: + # first update after discontinuity + state.down = False + state.resync_ts = update_id.time + first_update = False + elif isinstance(event, DeleteEvent): + logger.debug("delete event for " + event.payload) + else: + raise RuntimeError(f"expected Event, got {event}") + + await event_tx.send(event) + except tuple(event_stream_catch) as e: + logger.warning(f"WebSockets error: {e}") + + state.down = True + + +async def message_sender_impl( + notifier: Notifier, + title_provider: ThreadTitleProvider, + update_checker: UpdateChecker, + ws_state: WsListenerState, + thread_id: str, + event_rx: MemoryReceiveChannel, # message type `Event' + buffer_time: timedelta, + mention_limit: int, + logger: Logger, +) -> None: + + @dataclass + class QueueEntry: + payload_data: dict + mentions: set[str] + arrival: float # Trio time + + queue: list[QueueEntry] = [] + + while True: + logger.debug("queue: {}".format([e.payload_data["id"] for e in queue])) + if queue: + scope = move_on_at(queue[0].arrival + buffer_time.total_seconds()) + else: + scope = CancelScope() # no timeout + + with scope: + event = await event_rx.receive() + + if scope.cancelled_caught: + # timed out + entry = queue.pop(0) + update_id = entry.payload_data["id"] + if ws_state.down or ws_state.resync_ts and UUID(update_id).time < ws_state.resync_ts: + notify = update_checker.exists(thread_id, update_id) + else: + notify = True + + if notify: + thread_title = title_provider.title_for(thread_id) + for recipient in entry.mentions: + notifier.notify(NotifyInfo( + recipient, thread_id, thread_title, update_id, entry.payload_data["author"], + entry.payload_data["body"] + )) + else: + # got an event + if isinstance(event, UpdateEvent): + update_id = event.payload_data["id"] + if not any(entry.payload_data["id"] == update_id for entry in queue): + mentions = _parse_mentions(BeautifulSoup(event.payload_data["body_html"], "html.parser")) + if mentions: + if len(mentions) > mention_limit: + logger.info(f"ignoring {update_id} due to too many mentions ({len(mentions)})") + else: + queue.append(QueueEntry(event.payload_data, mentions, trio.current_time())) + elif isinstance(event, DeleteEvent): + logger.debug("delete event for " + event.payload) + queue = [entry for entry in queue if entry.payload_data["name"] != event.payload] + else: + raise RuntimeError(f"invalid event type {type(event)}") diff --git a/mentionbot/src/mentionbot/__main__.py b/mentionbot/src/mentionbot/__main__.py new file mode 100644 index 0000000..960a228 --- /dev/null +++ b/mentionbot/src/mentionbot/__main__.py @@ -0,0 +1,291 @@ +""" +Mentionbot: Notify users by private message when they're /u/-mentioned in a live thread. +""" + +from argparse import ArgumentParser +from configparser import ConfigParser +from contextlib import asynccontextmanager +from dataclasses import dataclass +from datetime import timedelta +from logging import Formatter, getLogger, Logger, NOTSET, StreamHandler +from sys import stdout +from typing import ClassVar, Iterator, Optional, Type +from urllib.error import HTTPError +from urllib.parse import urlencode +from urllib.request import Request, urlopen +from uuid import UUID +import json +import logging + +from trio import MemorySendChannel, open_memory_channel, open_nursery +from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection +import psycopg2 +import trio + +from . import TokenRef, message_sender_impl, try_request, ws_listener_impl, WsListenerState +from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent + + +async def _token_update_impl(period: timedelta, db_conn, auth_id: int, task_status) -> None: + def fetch_token() -> str: + cursor = db_conn.cursor() + cursor.execute( + """ + select access_token from public.reddit_app_authorization + where id = %s + """, + (auth_id,) + ) + [(token,)] = cursor.fetchall() + return token + + ref = TokenRef(fetch_token()) + task_status.started(ref) + while True: + await trio.sleep(period.total_seconds()) + ref.token = fetch_token() + + +async def _api_poll_impl( + ws_state: WsListenerState, + token_ref: TokenRef, + thread_id: str, + event_tx: MemorySendChannel, # message type `Event' + poll_period: timedelta, + max_ws_delay: timedelta, + logger: Logger, +) -> None: + + base_url = f"https://oauth.reddit.com/live/{thread_id}" + + latest_payload: Optional[dict] = None + while True: + # fetch updates and emit events + params = { + "raw_json": "1", + } + if latest_payload: + updates = [] + while True: + params["before"] = latest_payload["data"]["name"] + url = base_url + "?" + urlencode(params) + resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger) + if resp: + with resp: + children = json.load(resp)["data"]["children"] + if children: + updates.extend(reversed(children)) + latest_payload = children[0] + else: + break + else: + break + else: + params["limit"] = "1" + url = base_url + "?" + urlencode(params) + resp = try_request(Request(url, headers = token_ref.get_common_headers()), logger) + if resp: + with resp: + updates = json.load(resp)["data"]["children"] + if updates: + latest_payload = updates[0] + else: + updates = [] + + for update in updates: + await event_tx.send(UpdateEvent(update["data"])) + + # check if WS connection is behind + if not ws_state.down and latest_payload: + max_delta = max_ws_delay.total_seconds() * 10 ** 7 # in 100*nanosecond + if UUID(latest_payload["data"]["id"]).time - ws_state.last_update_ts > max_delta: + ws_state.scope.cancel() + + await trio.sleep(poll_period.total_seconds()) + + +def _build_notification_md(info: NotifyInfo) -> str: + # TODO escape thread title? + return "\n\n".join([ + f"from [{info.author}](/user/{info.author}) in [{info.thread_title}](/live/{info.thread_id})", + "\n".join(">" + line for line in info.update_body.splitlines()), + ( + f"[^^permalink](/live/{info.thread_id}/updates/{info.update_id})" + + f" [^^context](/live/{info.thread_id}?after=LiveUpdate_{info.update_id})" + + " [^^about ^^these ^^notifications](/r/live_mentions/wiki)" + ), + ]) + + +@dataclass +class RedditNotifier(Notifier): + token_ref: TokenRef + logger: Logger + + def notify(self, info: NotifyInfo) -> None: + params = { + "raw_json": "1", + "api_type": "json", + "subject": "username mention", + "text": _build_notification_md(info), + } + req = Request( + "https://oauth.reddit.com/api/compose", + method = "POST", + data = urlencode({**params, "to": info.recipient}).encode("ascii"), + headers = self.token_ref.get_common_headers(), + ) + resp = try_request(req, self.logger) + if resp: + with resp: + data = json.load(resp)["json"] + if data["errors"]: + [(code, _, _)] = data["errors"] + if code != "USER_DOESNT_EXIST": + raise RuntimeError("PM send error: {code}") + + +@dataclass +class RedditThreadTitleProvider(ThreadTitleProvider): + token_ref: TokenRef + + def title_for(self, thread_id: str) -> str: + req = Request( + f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1", + headers = self.token_ref.get_common_headers() + ) + with urlopen(req) as resp: + return json.load(resp)["data"]["title"] + + +@dataclass +class RedditUpdateChecker(UpdateChecker): + token_ref: TokenRef + logger: Logger + + def exists(self, thread_id: str, update_id: str) -> bool: + req = Request( + f"https://oauth.reddit.com/live/{thread_id}/updates/{update_id}?raw_json=1", + headers = self.token_ref.get_common_headers(), + ) + try: + resp = try_request(req, self.logger, suppress_codes = [429, 500, 503]) + except HTTPError as e: + e.read() + e.close() + if e.code == 404: + return False + else: + raise + else: + if resp: + resp.read() + resp.close() + return True + + +@dataclass +class WsConnectionEventStream(EventStream): + EXCEPTION_TYPES: ClassVar[set[Type[BaseException]]] = {HandshakeError, ConnectionClosed} + + ws_conn: WebSocketConnection + + async def get_event(self) -> Event: + while True: + message = json.loads(await self.ws_conn.get_message()) + if message["type"] == "update": + return UpdateEvent(message["payload"]["data"]) + elif message["type"] == "delete": + return DeleteEvent(message["payload"]) + + +async def _main( + auth_id: int, + buffer_time: timedelta, + db_conn, + logger: Logger, + max_ws_delay: timedelta, + mention_limit: int, + poll_period: timedelta, + thread_id: str, + token_period: timedelta, +) -> None: + + @asynccontextmanager + async def event_stream_factory() -> Iterator[WsConnectionEventStream]: + resp = urlopen(Request( + f"https://oauth.reddit.com/live/{thread_id}/about?raw_json=1", + headers = token_ref.get_common_headers(), + )) + with resp: + ws_url = json.load(resp)["data"]["websocket_url"] + + async with open_websocket_url(ws_url) as ws_conn: + yield WsConnectionEventStream(ws_conn) + + async with open_nursery() as nursery: + token_ref = await nursery.start(_token_update_impl, token_period, db_conn, auth_id) + + ws_state = WsListenerState() + (event_tx, event_rx) = open_memory_channel(0) + nursery.start_soon( + ws_listener_impl, + event_stream_factory, WsConnectionEventStream.EXCEPTION_TYPES, ws_state, event_tx, logger.getChild("WS") + ) + nursery.start_soon( + _api_poll_impl, + ws_state, token_ref, thread_id, event_tx, poll_period, max_ws_delay, logger.getChild("poll") + ) + + notifier = RedditNotifier(token_ref, logger) + title_provider = RedditThreadTitleProvider(token_ref) + deletion_checker = RedditUpdateChecker(token_ref, logger) + nursery.start_soon( + message_sender_impl, + notifier, title_provider, deletion_checker, ws_state, thread_id, event_rx, buffer_time, mention_limit, + logger.getChild("sender") + ) + + +if __name__ == "__main__": + arg_parser = ArgumentParser(__package__) + arg_parser.add_argument("config_path") + args = arg_parser.parse_args() + + config_parser = ConfigParser() + with open(args.config_path) as config_file: + config_parser.read_file(config_file) + + main_cfg = config_parser["config"] + + auth_id = main_cfg.getint("auth ID") + assert auth_id is not None + + token_period = timedelta(minutes = main_cfg.getint("token update period")) + + thread_id = main_cfg["thread ID"] + + mention_limit = main_cfg.getint("mention limit") + assert mention_limit is not None and mention_limit >= 0 + + buffer_time = timedelta(seconds = main_cfg.getfloat("message buffer time")) + + poll_period = timedelta(seconds = main_cfg.getfloat("API poll period")) + + max_ws_delay = timedelta(seconds = main_cfg.getfloat("max WS delay")) + + db_conn = psycopg2.connect(**config_parser["db connect params"]) + db_conn.autocommit = True + + logger = getLogger(__package__) + logger.setLevel(NOTSET - 1) # filter in handler(s) instead + + handler = StreamHandler(stdout) + handler.setLevel(logging.DEBUG) + handler.setFormatter(Formatter("{asctime:23}: {name:17}: {levelname:8}: {message}", style = "{")) + logger.addHandler(handler) + + trio.run( + _main, + auth_id, buffer_time, db_conn, logger, max_ws_delay, mention_limit, poll_period, thread_id, token_period + ) diff --git a/mentionbot/src/mentionbot/abc.py b/mentionbot/src/mentionbot/abc.py new file mode 100644 index 0000000..3c1d7fd --- /dev/null +++ b/mentionbot/src/mentionbot/abc.py @@ -0,0 +1,50 @@ +from abc import ABCMeta, abstractmethod +from dataclasses import dataclass + + +class Event(metaclass = ABCMeta): + pass + + +@dataclass +class UpdateEvent(Event): + payload_data: dict # object at "payload" -> "data" in the message JSON + + +@dataclass +class DeleteEvent(Event): + payload: str # update name + + +class EventStream(metaclass = ABCMeta): + @abstractmethod + async def get_event(self) -> Event: + raise NotImplementedError() + + +@dataclass +class NotifyInfo: + recipient: str + thread_id: str + thread_title: str + update_id: str + author: str + update_body: str + + +class Notifier(metaclass = ABCMeta): + @abstractmethod + def notify(self, info: NotifyInfo) -> None: + raise NotImplementedError() + + +class ThreadTitleProvider(metaclass = ABCMeta): + @abstractmethod + def title_for(self, thread_id: str) -> str: + raise NotImplementedError() + + +class UpdateChecker(metaclass = ABCMeta): + @abstractmethod + def exists(self, thread_id: str, update_id: str) -> bool: + raise NotImplementedError() diff --git a/mentionbot/src/mentionbot/tests.py b/mentionbot/src/mentionbot/tests.py new file mode 100644 index 0000000..0617c90 --- /dev/null +++ b/mentionbot/src/mentionbot/tests.py @@ -0,0 +1,350 @@ +from contextlib import asynccontextmanager, contextmanager +from dataclasses import dataclass, field +from datetime import timedelta +from logging import getLogger +from numbers import Real +from os import environ +from typing import Iterator, TypeVar +from unittest import TestCase +from urllib.request import Request +from uuid import UUID, uuid1 +import logging + +from trio import CancelScope, MemoryReceiveChannel, move_on_after, open_memory_channel, open_nursery, sleep_until +from trio.testing import MockClock, wait_all_tasks_blocked +import trio + +from . import message_sender_impl, TokenRef, try_request, WsListenerState, ws_listener_impl +from .abc import DeleteEvent, Event, EventStream, Notifier, NotifyInfo, ThreadTitleProvider, UpdateChecker, UpdateEvent + + +T = TypeVar("T") + + +@dataclass +class MockNotifier(Notifier): + notifies: list[NotifyInfo] = field(default_factory = list) + + def notify(self, info: NotifyInfo) -> None: + self.notifies.append(info) + + +class MockTitleProvider(ThreadTitleProvider): + def title_for(self, thread_id: str) -> str: + return thread_id + + +class UnusedUpdateChecker(UpdateChecker): + def exists(self, thread_id: str, update_id: str) -> bool: + raise AssertionError("unexpected deletion check") + + +@dataclass +class DictBackedUpdateChecker(UpdateChecker): + exists_by_id: dict[str, bool] = field(default_factory = dict) + queries: list[tuple[str, str]] = field(default_factory = list) + + def exists(self, thread_id: str, update_id: str) -> bool: + self.queries.append((thread_id, update_id)) + return self.exists_by_id[update_id] + + +@dataclass +class ChannelBackedEventStream(EventStream): + event_rx: MemoryReceiveChannel + + async def get_event(self) -> Event: + obj = await self.event_rx.receive() + if isinstance(obj, BaseException): + raise obj + else: + return obj + + +@asynccontextmanager +async def null_async_context(val: T) -> Iterator[T]: + yield val + + +def _update_event(author: str, mentions: list[str]) -> UpdateEvent: + id_ = uuid1() + body = str(id_) + " " + " ".join("/u/" + user for user in mentions) + body_html = str(id_) + " " + " ".join( + f'/u/{user}' + for user in mentions + ) + return UpdateEvent( + { + "id": str(id_), + "name": "LiveUpdate_" + str(id_), + "author": author, + "body": body, + "body_html": body_html, + } + ) + + +def setUpModule() -> None: + getLogger().setLevel(logging.CRITICAL + 1) # disable logging + + +class TaskTests(TestCase): + @contextmanager + def _assert_completes_in(self, duration_seconds: Real) -> Iterator[None]: + with move_on_after(duration_seconds) as scope: + yield + if scope.cancelled_caught: + self.fail("operation timed out") + + def test_buffering(self) -> None: + """ + No API poll task, just a custom task mocking the WS listener by playing events over a channel, and a real + notifier task injected with mocks to capture output. The WS state is hard-coded as up, so no testing of recovery + from WS connection issues is included in this test. + + Primarily it makes sure that updates are passed through the buffer and processed if and only if a delete event + doesn't arrive while they're buffered. + """ + + buffer_time = timedelta(seconds = 1) + notifier = MockNotifier() + (event_tx, event_rx) = open_memory_channel(0) + + ign_event = _update_event("sender", []) + event_a = _update_event("sender", ["rec1", "rec2"]) + event_b = _update_event("sender", ["rec1", "rec2"]) + + async def event_stream_impl(scope: CancelScope) -> None: + await sleep_until(1.0) + await event_tx.send(ign_event) + await event_tx.send(event_a) + + await sleep_until(1.5) + self.assertFalse(notifier.notifies) + await event_tx.send(DeleteEvent(event_a.payload_data["id"])) + + await sleep_until(2.0) + await event_tx.send(event_b) + + # let things settle and terminate both tasks + await sleep_until(4.0) + scope.cancel() + + async def main() -> None: + async with open_nursery() as nursery: + nursery.start_soon(event_stream_impl, nursery.cancel_scope) + nursery.start_soon( + message_sender_impl, + notifier, MockTitleProvider(), UnusedUpdateChecker(), WsListenerState(down = False), "ThreadId", + event_rx, buffer_time, getLogger() + ) + + trio.run(main, clock = MockClock(autojump_threshold = 0)) + + # only notifications sent are for event B + self.assertEqual(len(notifier.notifies), 2) + self.assertEqual({i.recipient for i in notifier.notifies}, {"rec1", "rec2"}) + self.assertTrue(all(i.author == "sender" for i in notifier.notifies)) + self.assertTrue(all(i.update_id == event_b.payload_data["id"] for i in notifier.notifies)) + + def test_ws_trouble(self) -> None: + """ + Ensure that the notifier task responds correctly to the WS connection going down and coming back up. + """ + + thread_id = "ThreadId" + buffer_time = timedelta(seconds = 3) + update_checker = DictBackedUpdateChecker() + notifier = MockNotifier() + (event_tx, event_rx) = open_memory_channel(0) + + # sent while WS is up, checked and notifies sent + event_a = _update_event("sender", ["rec"]) + + # sent and deleted while WS is down, checked and discarded from buffer + event_b = _update_event("sender", ["rec"]) + + # sent on WS reconnect, not checked and notifies sent + event_c = _update_event("sender", ["rec"]) + + # sent after WS reconnect, not checked and notifies sent + event_d = _update_event("sender", ["rec"]) + + ws_state = WsListenerState(down = False) + + async def ws_simulator_impl() -> None: + update_checker.exists_by_id[event_a.payload_data["id"]] = True + await event_tx.send(event_a) + + await sleep_until(1.0) + ws_state.down = True + update_checker.exists_by_id[event_b.payload_data["id"]] = True + await event_tx.send(event_b) + + await sleep_until(2.0) + update_checker.exists_by_id[event_b.payload_data["id"]] = False + ws_state.down = False + ws_state.resync_ts = UUID(event_c.payload_data["id"]).time + await event_tx.send(event_c) + + await sleep_until(3.0) + await event_tx.send(event_d) + + async def main() -> None: + async with open_nursery() as nursery: + nursery.start_soon(ws_simulator_impl) + nursery.start_soon( + message_sender_impl, + notifier, MockTitleProvider(), update_checker, ws_state, "ThreadId", event_rx, buffer_time, + getLogger() + ) + + await sleep_until(3.5) + self.assertEqual([i.update_id for i in notifier.notifies], [event_a.payload_data["id"]]) + notifier.notifies.clear() + self.assertEqual(update_checker.queries, [(thread_id, event_a.payload_data["id"])]) + update_checker.queries.clear() + + await sleep_until(4.5) + self.assertFalse(notifier.notifies) + self.assertEqual(update_checker.queries, [(thread_id, event_b.payload_data["id"])]) + update_checker.queries.clear() + + await sleep_until(5.5) + self.assertEqual([i.update_id for i in notifier.notifies], [event_c.payload_data["id"]]) + notifier.notifies.clear() + self.assertFalse(update_checker.queries) + + await sleep_until(6.5) + self.assertEqual([i.update_id for i in notifier.notifies], [event_d.payload_data["id"]]) + notifier.notifies.clear() + self.assertFalse(update_checker.queries) + + nursery.cancel_scope.cancel() + + trio.run(main, clock = MockClock(autojump_threshold = 0)) + + def test_ws_listener_cancel(self) -> None: + ws_state = WsListenerState() + + (event_in_tx, event_in_rx) = open_memory_channel(0) + event_stream = ChannelBackedEventStream(event_in_rx) + event_stream_factory = lambda: null_async_context(event_stream) + + (event_out_tx, event_out_rx) = open_memory_channel(0) + + async def main(): + async with open_nursery() as nursery: + nursery.start_soon( + ws_listener_impl, + event_stream_factory, set(), ws_state, event_out_tx, getLogger() + ) + + event_1 = _update_event("author", []) + await event_in_tx.send(event_1) + with self._assert_completes_in(1): + event_out = await event_out_rx.receive() + self.assertIs(event_out, event_1) + + event_2 = _update_event("author", ["recipient"]) + await event_in_tx.send(event_2) + with self._assert_completes_in(1): + event_out = await event_out_rx.receive() + self.assertIs(event_out, event_2) + + initial_scope = ws_state.scope + initial_scope.cancel() + await wait_all_tasks_blocked() + self.assertIsNotNone(ws_state.scope) + self.assertIsNot(ws_state.scope, initial_scope) + self.assertTrue(ws_state.down) + + event_3 = _update_event("author", []) + ts = UUID(event_3.payload_data["id"]).time + await event_in_tx.send(event_3) + with self._assert_completes_in(1): + event_out = await event_out_rx.receive() + self.assertIs(event_out, event_3) + self.assertEqual(ws_state.last_update_ts, ts) + self.assertFalse(ws_state.down) + self.assertEqual(ws_state.resync_ts, ts) + + nursery.cancel_scope.cancel() + + trio.run(main, clock = MockClock(autojump_threshold = 0)) + + def test_ws_listener_error(self) -> None: + """Test that the WS listener responds correctly to errors from the input event stream.""" + + class MockStreamException(Exception): + pass + + ws_state = WsListenerState() + + (event_in_tx, event_in_rx) = open_memory_channel(0) + event_stream = ChannelBackedEventStream(event_in_rx) + event_stream_factory = lambda: null_async_context(event_stream) + + (event_out_tx, event_out_rx) = open_memory_channel(0) + + async def main(): + async with open_nursery() as nursery: + nursery.start_soon( + ws_listener_impl, + event_stream_factory, {MockStreamException}, ws_state, event_out_tx, getLogger() + ) + + event_1 = _update_event("author", []) + await event_in_tx.send(event_1) + with self._assert_completes_in(1): + event_out = await event_out_rx.receive() + self.assertIs(event_out, event_1) + + initial_scope = ws_state.scope + await event_in_tx.send(MockStreamException()) + await wait_all_tasks_blocked() + self.assertIsNotNone(ws_state.scope) + self.assertIsNot(ws_state.scope, initial_scope) + self.assertTrue(ws_state.down) + + event_2 = _update_event("author", ["recipient"]) + ts = UUID(event_2.payload_data["id"]).time + await event_in_tx.send(event_2) + with self._assert_completes_in(1): + event_out = await event_out_rx.receive() + self.assertIs(event_out, event_2) + self.assertEqual(ws_state.last_update_ts, ts) + self.assertFalse(ws_state.down) + self.assertEqual(ws_state.resync_ts, ts) + + nursery.cancel_scope.cancel() + + trio.run(main, clock = MockClock(autojump_threshold = 0)) + + +class RedditTests(TestCase): + """Tests for interactions with Reddit.""" + + @classmethod + def setUpClass(cls) -> None: + cls._TOKEN_REF = TokenRef(environ["API_TOKEN"]) + + def test_try_request(self) -> None: + self.assertIsNotNone( + try_request( + Request( + "https://oauth.reddit.com/api/live/happening_now", + headers = self._TOKEN_REF.get_common_headers(), + ), + getLogger(), + ) + ) + + def test_try_request_http_error(self) -> None: + self.assertIsNone( + try_request( + Request("https://oauth.reddit.com/api/live/happening_now"), + getLogger(), + suppress_codes = [403], + ) + )