+# see Python `configparser` docs for precise file syntax
+
[config]
+
+######## basic operating parameters
+
+# space-separated authorization IDs for Reddit API token lookup
+auth IDs = 13 15
+
bot user = count_better
+enforcing = true
+thread ID = abc123
-# (seconds) maximum allowable spread in WebSocket message arrival times
-WS parity time = 1.0
-# goal size of WebSocket connection pool
-WS pool size = 5
+######## API pool error handling
+
+# (seconds) how long to pause the API pool when backing off after errors
+API pool error delay = 5.5
+
+# (seconds) pause the pool when the Reddit API returns two errors within any time span of this length
+API pool error window = 5.5
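+# e.g. with the values above, a second API error within 5.5 s of the previous one pauses all API workers for about 5.5 s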
+
+# terminate if the Reddit API request queue exceeds this size
+request queue limit = 100
-# (seconds) time after WebSocket handshake completion during which missed updates are excused
-WS warmup time = 0.3
+
+######## live thread update handling
+
# (seconds) maximum time to hold updates for reordering
reorder buffer time = 0.25
-enforcing = true
-
# (seconds) minimum time to retain updates to enable future resyncs
update retention time = 120.0
-thread ID = abc123
-# space-separated authorization IDs for Reddit API token lookup
-auth IDs = 13 15
+######## WebSocket pool
+
+# (seconds) maximum allowable spread in WebSocket message arrival times
+WS parity time = 1.0
+
+# goal size of WebSocket connection pool
+WS pool size = 5
+
+# (seconds) after receiving no messages for this much time, attempt to replace the connection if the pool isn't depleted
+WS silent limit = 30
+
+# (seconds) time after WebSocket handshake completion during which missed updates are excused
+WS warmup time = 0.3
+
# Postgres database configuration; same options as in a connect string
[db connect params]
from contextlib import nullcontext as nullcontext
from dataclasses import dataclass
from functools import total_ordering
-from typing import Optional, Set
+from itertools import islice
+from typing import Optional
from uuid import UUID
import bisect
import datetime as dt
import importlib.metadata
import itertools
import logging
-import re
from trio import CancelScope, EndOfChannel
import trio
__version__ = importlib.metadata.version(__package__)
+QUIET = True # TODO remove
@dataclass
class _TimelineUpdate:
update: _Update
accepted: bool
- bad_count: bool
def __lt__(self, other):
return self.update.ts < other.update.ts
+ def rejected(self) -> bool:
+ """Whether the update should be stricken."""
+ return self.update.count_attempt and not self.accepted
+
def _parse_rfc_4122_ts(ts: int) -> dt.datetime:
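+	# UUID version 1 timestamps count 100 ns ticks from the Gregorian reform date (1582-10-15)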
epoch = dt.datetime(1582, 10, 15, tzinfo = dt.timezone.utc)
buffer_ = []
timeline = []
last_valid: Optional[_Update] = None
- pending_strikes: Set[str] = set() # names of updates to mark stricken on arrival
+ pending_strikes: set[str] = set() # names of updates to mark stricken on arrival
def handle_update(update):
- nonlocal delete_start, last_valid
+ nonlocal last_valid
- tu = _TimelineUpdate(update, accepted = False, bad_count = False)
+ tu = _TimelineUpdate(update, accepted = False)
pos = bisect.bisect(timeline, tu)
if pos != len(timeline):
logger.warning(f"long transpo: {update.name}")
+ if pos > 0 and timeline[pos - 1].update.id == update.id:
+ # The pool sent this update message multiple times. This could be because the message reached parity and
+ # its bucket was dropped just before a new connection was opened and sent its own copy, but connections also
+ # seem to send duplicate messages.
+ return
pred = next(
(timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted),
if update.command is not Command.RESET and update.number is not None and last_valid and not pred:
logger.warning("ignoring {update.name}: no valid prior count on record")
else:
+ timeline.insert(pos, tu)
+ assert len({ctu.update.id for ctu in timeline}) == len(timeline) # TODO remove
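+ # accept resets unconditionally; otherwise require the count to be consistent with the nearest accepted predecessor, if any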
tu.accepted = (
update.command is Command.RESET
or (
and (pred is None or pred.number is None or update.can_follow(pred))
)
)
- timeline.insert(pos, tu)
+ logger.debug(f"accepted: {tu.accepted}")
+ logger.debug(" pred: {}".format(pred and (pred.id, pred.number, pred.author)))
if tu.accepted:
- # resync subsequent updates already processed
+ # resync subsequent updates
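+ # rescan later timeline entries, revalidating each count against the most recent accepted update before it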
newly_valid = []
newly_invalid = []
- last_valid = update
- for scan_tu in timeline[pos + 1:]:
+ resync_last_valid = update
+ for scan_tu in islice(timeline, pos + 1, None):
if scan_tu.update.command is Command.RESET:
- last_valid = scan_tu.update
+ break
elif scan_tu.update.number is not None:
- accept = last_valid.number is None or scan_tu.update.can_follow(last_valid)
+ accept = resync_last_valid.number is None or scan_tu.update.can_follow(resync_last_valid)
- if accept and not scan_tu.accepted:
+ if accept and scan_tu.accepted:
+ # resync would have no effect past this point
+ break
+ elif accept:
newly_valid.append(scan_tu)
- elif not accept and scan_tu.accepted:
+ resync_last_valid = scan_tu.update
+ elif scan_tu.accepted:
newly_invalid.append(scan_tu)
scan_tu.accepted = accept
- if accept:
- last_valid = scan_tu.update
+
+ if last_valid:
+ last_valid = max(last_valid, resync_last_valid, key = lambda u: u.ts)
+ else:
+ last_valid = resync_last_valid
parts = []
if newly_valid:
"The following counts are valid:\n\n"
+ "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in newly_valid)
)
- for tu in newly_valid:
- tu.bad_count = False
unstrikable = [tu for tu in newly_invalid if tu.update.stricken]
if unstrikable:
"The following counts are invalid:\n\n"
+ "\n".join(" - " + _format_update_ref(tu.update, thread_id) for tu in unstrikable)
)
- for tu in unstrikable:
- tu.bad_count = True
if update.stricken:
logger.info(f"bad strike of {update.name}")
parts.append(_format_bad_strike_alert(update, thread_id))
+ if parts:
+ parts.append(_format_curr_count(last_valid))
+ if not QUIET:
+ api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts)))
+
for invalid_tu in newly_invalid:
if not invalid_tu.update.stricken:
if enforcing:
api_pool.enqueue_request(StrikeRequest(thread_id, invalid_tu.update.name, invalid_tu.update.ts))
- invalid_tu.bad_count = True
invalid_tu.update.stricken = True
- if parts:
- parts.append(_format_curr_count(last_valid))
- api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts)))
- elif update.number is not None or update.count_attempt:
+ elif update.count_attempt:
if enforcing:
api_pool.enqueue_request(StrikeRequest(thread_id, update.name, update.ts))
- tu.bad_count = True
update.stricken = True
if update.command is Command.REPORT:
- api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
+ if not QUIET:
+ api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
with message_rx:
while True:
stricken = payload_data["name"] in pending_strikes
pending_strikes.discard(payload_data["name"])
if payload_data["author"] != bot_user:
- next_up = last_valid.number if last_valid else None
+ if last_valid and last_valid.number is not None:
+ next_up = last_valid.number + 1
+ else:
+ next_up = None
pu = parse_update(payload_data, next_up, bot_user)
rfc_ts = UUID(payload_data["id"]).time
update = _Update(
release_at = msg.timestamp + reorder_buffer_time.total_seconds()
bisect.insort(buffer_, _BufferedUpdate(update, release_at))
elif msg.data["type"] == "strike":
- UUID(re.match("LiveUpdate_(.+)$", msg.data["payload"])[1]) # sanity check payload
slot = next(
(
slot for slot in itertools.chain(buffer_, reversed(timeline))
None
)
if slot:
- slot.update.stricken = True
- if isinstance(slot, _TimelineUpdate) and slot.accepted:
- logger.info(f"bad strike of {slot.update.name}")
- body = _format_bad_strike_alert(slot.update, thread_id)
- api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body))
+ if not slot.update.stricken:
+ slot.update.stricken = True
+ if not QUIET and isinstance(slot, _TimelineUpdate) and slot.accepted:
+ logger.info(f"bad strike of {slot.update.name}")
+ body = _format_bad_strike_alert(slot.update, thread_id)
+ api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body))
else:
pending_strikes.add(msg.data["payload"])
# valid count
or (last_valid is None or last_valid.number is None or bu.update.can_follow(last_valid))
-
- # invalid and not likely to become valid by transpo
- or bu.update.number > last_valid.number + 3
+ or bu.update.command is Command.RESET
or trio.current_time() >= bu.release_at
)
if process:
if bu.update.ts < threshold:
logger.warning(f"ignoring {bu.update.name}: arrived past retention window")
else:
+ logger.debug(f"processing update {bu.update.id} ({bu.update.number} by {bu.update.author})")
handle_update(bu.update)
else:
+ logger.debug(f"holding {bu.update.id} ({bu.update.number}, checked against {last_valid.number})")
new_buffer.append(bu)
buffer_ = new_buffer
+ logger.debug("last count {}".format(last_valid and (last_valid.id, last_valid.number))) # TODO remove
# delete/forget old updates
+ new_timeline = []
+ i = 0
for (i, tu) in enumerate(timeline):
if i >= len(timeline) - 10 or tu.update.ts >= threshold:
break
- elif tu.bad_count and tu.update.deletable:
+ elif tu.rejected() and tu.update.deletable:
if enforcing:
api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts))
- del timeline[: i]
+ elif tu.update is last_valid:
+ new_timeline.append(tu)
+ new_timeline.extend(islice(timeline, i, None))
+ timeline = new_timeline
main_cfg = parser["config"]
+api_pool_error_delay = dt.timedelta(seconds = main_cfg.getfloat("API pool error delay"))
+api_pool_error_window = dt.timedelta(seconds = main_cfg.getfloat("API pool error window"))
auth_ids = set(map(int, main_cfg["auth IDs"].split()))
bot_user = main_cfg["bot user"]
enforcing = main_cfg.getboolean("enforcing")
reorder_buffer_time = dt.timedelta(seconds = main_cfg.getfloat("reorder buffer time"))
+request_queue_limit = main_cfg.getint("request queue limit")
thread_id = main_cfg["thread ID"]
update_retention = dt.timedelta(seconds = main_cfg.getfloat("update retention time"))
ws_parity_time = dt.timedelta(seconds = main_cfg.getfloat("WS parity time"))
ws_pool_size = main_cfg.getint("WS pool size")
+ws_silent_limit = dt.timedelta(seconds = main_cfg.getfloat("WS silent limit"))
ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup time"))
db_connect_params = dict(parser["db connect params"])
nursery_a.start_soon(db_messenger.db_client_impl)
- api_pool = ApiClientPool(auth_ids, db_messenger, logger)
+ api_pool = ApiClientPool(
+ auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay, logger
+ )
nursery_a.start_soon(api_pool.token_updater_impl)
for _ in auth_ids:
nursery_a.start_soon(api_pool.worker_impl)
(message_tx, message_rx) = trio.open_memory_channel(0)
(pool_event_tx, pool_event_rx) = trio.open_memory_channel(0)
merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger)
- async with trio.open_nursery() as nursery_b:
+ async with trio.open_nursery() as nursery_b, trio.open_nursery() as ws_pool_nursery:
nursery_b.start_soon(merger.event_reader_impl)
nursery_b.start_soon(merger.timeout_handler_impl)
- pool = HealingReadPool(nursery_b, ws_pool_size, thread_id, api_pool, pool_event_tx, logger)
- async with pool, trio.open_nursery() as nursery_c:
- nursery_c.start_soon(pool.conn_refresher_impl)
+ ws_pool = HealingReadPool(ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger, ws_silent_limit)
+ async with ws_pool, trio.open_nursery() as nursery_c:
+ nursery_c.start_soon(ws_pool.conn_refresher_impl)
nursery_c.start_soon(
count_tracker_impl,
message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing, update_retention, logger
)
+ await ws_pool.init_workers()
trio_asyncio.run(main)
conn: Any
@_channel_sender
- async def get_auth_tokens(self, auth_ids: set[int]):
- raise NotImplementedError()
+ async def get_auth_tokens(self, auth_ids: set[int]) -> dict[int, str]:
+ results = await self.conn.fetch(
+ """
+ select
+ distinct on (id)
+ id, access_token
+ from
+ public.reddit_app_authorization
+ join unnest($1::integer[]) as request_id on id = request_id
+ where expires > current_timestamp
+ """,
+ list(auth_ids),
+ )
+ return {r["id"]: r["access_token"] for r in results}
class Messenger:
import logging
import math
+from trio.lowlevel import checkpoint_if_cancelled
from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
import trio
class HealingReadPool:
- def __init__(self, nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger):
+ def __init__(
+ self, pool_nursery: trio.Nursery, size, live_thread_id, api_client_pool, pool_event_tx, logger: logging.Logger,
+ silent_limit: dt.timedelta
+ ):
assert size >= 2
- self._nursery = nursery
+ self._nursery = pool_nursery
self._size = size
self._live_thread_id = live_thread_id
self._api_client_pool = api_client_pool
self._pool_event_tx = pool_event_tx
self._logger = logger
+ self._silent_limit = silent_limit
(self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size)
- self._active_count = 0
async def __aenter__(self):
- for _ in range(self._size):
- await self._spawn_reader()
- if not self._active_count:
- raise RuntimeError("Unable to create any WS connections")
+ pass
async def __aexit__(self, exc_type, exc_value, traceback):
self._refresh_queue_tx.close()
+ async def init_workers(self):
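+ """Open the initial set of WebSocket connections; fails if none can be established."""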
+ for _ in range(self._size):
+ await self._spawn_reader()
+ if not self._nursery.child_tasks:
+ raise RuntimeError("Unable to create any WS connections")
+
async def _reader_impl(self, conn_ctx: AsyncContextManager, cancel_scope, refresh_tx):
+ # TODO could use task's implicit cancel scope
try:
async with conn_ctx as conn:
- self._active_count += 1
with refresh_tx:
+ _tsc = False # TODO remove
with cancel_scope, suppress(ConnectionClosed):
while True:
- message = await conn.get_message()
- event = _Message(trio.current_time(), json.loads(message), cancel_scope)
- await self._pool_event_tx.send(event)
-
- if cancel_scope.cancelled_caught:
- await conn.aclose(4058, "Server unexpectedly stopped sending messages")
- self._logger.warning("replacing WS connection due to missed updates")
+ with trio.move_on_after(self._silent_limit.total_seconds()) as timeout_scope:
+ message = await conn.get_message()
+
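+ # no messages within the silent limit: drop this connection only if the pool is at full size (see "WS silent limit")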
+ if timeout_scope.cancelled_caught:
+ if len(self._nursery.child_tasks) == self._size:
+ _tsc = True
+ cancel_scope.cancel()
+ await checkpoint_if_cancelled()
+ else:
+ event = _Message(trio.current_time(), json.loads(message), cancel_scope)
+ await self._pool_event_tx.send(event)
+
+ assert _tsc == timeout_scope.cancelled_caught
+ if timeout_scope.cancelled_caught:
+ self._logger.debug("replacing WS connection due to silent timeout")
+ await conn.aclose()
+ elif cancel_scope.cancelled_caught:
+ await conn.aclose(1008, "Server unexpectedly stopped sending messages")
+ self._logger.warning("replacing WS connection due to missed update")
else:
self._logger.warning("replacing WS connection closed by server")
+ self._logger.debug(f"post-kill tasks {len(self._nursery.child_tasks)}")
refresh_tx.send_nowait(cancel_scope)
- self._active_count -= 1
except HandshakeError:
self._logger.error("handshake error while opening WS connection")
await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope))
refresh_tx = self._refresh_queue_tx.clone()
self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx)
+ self._logger.debug(f"post-spawn tasks {len(self._nursery.child_tasks)}")
async def conn_refresher_impl(self):
"""Task to monitor and replace WS connections as they disconnect or go silent."""
async for old_scope in self._refresh_queue_rx:
await self._pool_event_tx.send(_ConnectionDown(trio.current_time(), old_scope))
await self._spawn_reader()
- if not self._active_count:
+ if not self._nursery.child_tasks:
raise RuntimeError("WS pool depleted")
self._buckets = {}
self._scope_activations = {}
self._pending = deque()
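+ # scopes canceled for missed parity whose _ConnectionDown hasn't been processed yet; their messages are ignored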
+ self._outgoing_scopes = set()
(self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf)
async def event_reader_impl(self):
# by timestamp to avoid that.
self._pending.append(event)
elif isinstance(event, _Message):
- if event.data["type"] in {"update", "strike", "delete"}:
+ if event.data["type"] in ["update", "strike", "delete"]:
tag = event.dedup_tag()
if tag in self._buckets:
b = self._buckets[tag]
b.recipients.add(event.scope)
if b.recipients >= self._scope_activations.keys():
del self._buckets[tag]
- else:
+ elif event.scope not in self._outgoing_scopes:
sane = (
event.scope in self._scope_activations
or any(e.scope == event.scope for e in self._pending)
)
if sane:
self._buckets[tag] = self._Bucket(event.timestamp, {event.scope})
- self._message_tx.send_nowait(event)
+ await self._message_tx.send(event)
else:
raise RuntimeError("recieved message from unrecognized WS connection")
elif isinstance(event, _ConnectionDown):
# We don't need to worry about canceling this scope at all, so no need to require it for parity for
# any message, even older ones. The scope may be gone already, if we canceled it previously.
self._scope_activations.pop(event.scope, None)
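+ # discard any pending events from the connection that just went down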
+ self._pending = deque(ev for ev in self._pending if ev.scope is not event.scope)
+ self._outgoing_scopes.discard(event.scope)
else:
raise TypeError(f"Expected pool event, found {event!r}")
self._timer_poke_tx.send_nowait(None) # may have new work for the timer
scope for (scope, active) in self._scope_activations.items()
if active + self._conn_warmup.total_seconds() < now
}
- if bucket.recipients >= target_scopes:
- del self._buckets[tag]
- elif now > bucket.start + self._parity_timeout.total_seconds():
+ if now > bucket.start + self._parity_timeout.total_seconds():
for scope in target_scopes - bucket.recipients:
+ self._outgoing_scopes.add(scope)
scope.cancel()
- del self._scope_activations[scope]
del self._buckets[tag]
else:
await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds())
-"""Unbounded blocking priority queue for Trio"""
+"""Unbounded blocking queues for Trio"""
+from collections import deque
from dataclasses import dataclass
from functools import total_ordering
-from typing import Any
+from typing import Any, Iterable
import heapq
from trio.lowlevel import ParkingLot
+class Queue:
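+ """Unbounded FIFO queue whose async pop waits until an item is available."""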
+ def __init__(self):
+ self._deque = deque()
+ self._empty_wait = ParkingLot()
+
+ def push(self, el: Any) -> None:
+ self._deque.append(el)
+ self._empty_wait.unpark()
+
+ def extend(self, els: Iterable[Any]) -> None:
+ for el in els:
+ self.push(el)
+
+ async def pop(self) -> Any:
+ if not self._deque:
+ await self._empty_wait.park()
+ return self._deque.popleft()
+
+
@total_ordering
@dataclass
class _ReverseOrdWrapper:
def push(self, item):
heapq.heappush(self._heap, _ReverseOrdWrapper(item))
- if len(self._empty_wait):
- self._empty_wait.unpark()
+ self._empty_wait.unpark()
async def pop(self):
if not self._heap:
"""Multi-auth HTTP connection pool with Reddit API wrappers and rate limiting."""
from abc import ABCMeta, abstractmethod
-from collections import deque
from dataclasses import dataclass, field
from functools import total_ordering
+from socket import EAI_AGAIN, EAI_FAIL, gaierror
import datetime as dt
import logging
import trio
from strikebot import __version__ as VERSION
-from strikebot.queue import MaxHeap
+from strikebot.queue import MaxHeap, Queue
REQUEST_DELAY = dt.timedelta(seconds = 1)
def __lt__(self, other):
if type(self) is type(other):
- return self._subtype_key() < other._subtype_key()
+ return self._subtype_cmp_key() < other._subtype_cmp_key()
else:
prec = self._SUBTYPE_PRECEDENCE
return prec.index(type(other)) < prec.index(type(self))
def __eq__(self, other):
if type(self) is type(other):
- return self._subtype_key() == other._subtype_key()
+ return self._subtype_cmp_key() == other._subtype_cmp_key()
else:
return False
class ApiClientPool:
- def __init__(self, auth_ids: set[int], db_messenger, logger: logging.Logger):
+ def __init__(
+ self,
+ auth_ids: set[int],
+ db_messenger: "strikebot.db.Messenger",
+ request_queue_limit: int,
+ error_window: dt.timedelta,
+ error_delay: dt.timedelta,
+ logger: logging.Logger,
+ ):
+ self._auth_ids = auth_ids
self._db_messenger = db_messenger
+ self._request_queue_limit = request_queue_limit
+ self._error_window = error_window
+ self._error_delay = error_delay
self._logger = logger
- self._tokens = {id_: None for id_ in auth_ids}
- self._tokens_installed = trio.Event()
now = trio.current_time()
- self._app_queue = deque(AppCooldown(id_, now) for id_ in auth_ids)
+ self._tokens = {}
+ self._app_queue = Queue()
+ self._waiting = {id_: AppCooldown(id_, now) for id_ in auth_ids}
self._request_queue = MaxHeap()
+ # pool-wide API error backoff
+ self._last_error = None
+ self._global_resume = None
+
self._session = asks.Session(connections = len(auth_ids))
self._session.base_location = API_BASE_URL
async def _update_tokens(self):
- tokens = await self._db_messenger.do("get_auth_tokens", (self._tokens.keys(),))
- assert tokens.keys() == self._tokens.keys()
- self._tokens = tokens
- self._tokens_installed.set()
+ tokens = await self._db_messenger.do("get_auth_tokens", (self._auth_ids,))
+ self._tokens.update(tokens)
+
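+ # wake any apps that were parked awaiting a fresh token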
+ awaken_auths = self._waiting.keys() & tokens.keys()
+ self._app_queue.extend(self._waiting.pop(auth_id) for auth_id in awaken_auths)
async def token_updater_impl(self):
last_update = None
while True:
if last_update is not None:
await trio.sleep_until(last_update + TOKEN_UPDATE_DELAY.total_seconds())
- last_update = trio.current_time()
+
+ if last_update is None:
+ last_update = trio.current_time()
+ else:
+ last_update += TOKEN_UPDATE_DELAY.total_seconds()
+
await self._update_tokens()
def _check_queue_size(self) -> None:
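+ """Fail fast if requests accumulate faster than the API workers can drain them."""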
- if len(self._request_queue) > 5:
- self._logger.warning(f"API workers may be saturated; {len(self._request_queue)} requests in queue")
+ if len(self._request_queue) > self._request_queue_limit:
+ raise RuntimeError("request queue size exceeded limit")
async def make_request(self, request: _Request) -> Response:
(resp_tx, resp_rx) = trio.open_memory_channel(0)
self._request_queue.push((request, resp_tx))
+ self._check_queue_size()
async with resp_rx:
return await resp_rx.receive()
def enqueue_request(self, request: _Request) -> None:
self._request_queue.push((request, None))
+ self._check_queue_size()
async def worker_impl(self):
- await self._tokens_installed.wait()
while True:
(request, resp_tx) = await self._request_queue.pop()
- while trio.current_time() < self._app_queue[0].ready_at:
- await trio.sleep_until(self._app_queue[0].ready_at)
+ cooldown = await self._app_queue.pop()
+ await trio.sleep_until(cooldown.ready_at)
+ if self._global_resume:
+ await trio.sleep_until(self._global_resume)
- cooldown = self._app_queue.popleft()
asks_kwargs = request.to_asks_kwargs()
- asks_kwargs.setdefault("headers", {}).update({
+ headers = asks_kwargs.setdefault("headers", {})
+ headers.update({
"Authorization": "Bearer {}".format(self._tokens[cooldown.auth_id]),
"User-Agent": USER_AGENT_FMT.format(auth_id = cooldown.auth_id),
})
- cooldown.ready_at = trio.current_time() + REQUEST_DELAY.total_seconds()
- resp = await self._session.request(**asks_kwargs)
-
- if resp.status_code == 429:
- # We disagreed about the rate limit state; just try again later.
- self._request_queue.put((request, resp_tx))
- elif 400 <= resp.status_code < 500:
- # If we're doing something wrong, let's catch it right away.
- if resp_tx:
- resp_tx.close()
- raise RuntimeError("Unexpected client error response: {}".format(resp.status_code))
- else:
- if resp.status_code != 200:
- self._logger.warning(f"got HTTP {resp.status_code} from Reddit API")
- if resp_tx:
- await resp_tx.send(resp)
- self._app_queue.append(cooldown)
+ request_time = trio.current_time()
+ try:
+ resp = await self._session.request(**asks_kwargs)
+ except gaierror as e:
+ if e.errno in [EAI_FAIL, EAI_AGAIN]:
+ # DNS failure, probably temporary
+ error = True
+ wait_for_token = False
+ else:
+ raise
+ else:
+ resp.body # read response
+ error = False
+ wait_for_token = False
+ if resp.status_code == 429:
+ # We disagreed about the rate limit state; just try again later.
+ self._logger.warning("rate limited by Reddit API")
+ error = True
+ elif resp.status_code == 401:
+ self._logger.warning("got HTTP 401 from Reddit API")
+ error = True
+ wait_for_token = True
+ elif resp.status_code in [404, 500]:
+ self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying")
+ error = True
+ elif 400 <= resp.status_code < 500:
+ # If we're doing something wrong, let's catch it right away.
+ raise RuntimeError(f"unexpected client error response: {resp.status_code}")
+ else:
+ if resp.status_code != 200:
+ raise RuntimeError(f"unexpected status code {resp.status_code}")
+ if resp_tx:
+ await resp_tx.send(resp)
+
+ if error:
+ self._request_queue.push((request, resp_tx))
+ self._check_queue_size()
+ if self._last_error:
+ spread = dt.timedelta(seconds = request_time - self._last_error)
+ if spread <= self._error_window:
+ self._global_resume = request_time + self._error_delay.total_seconds()
+ self._last_error = request_time
+
+ cooldown.ready_at = request_time + REQUEST_DELAY.total_seconds()
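+ # on HTTP 401 the token is stale; park this app until the token updater re-queues it with a fresh token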
+ if wait_for_token:
+ self._waiting[cooldown.auth_id] = cooldown
+ else:
+ self._app_queue.push(cooldown)
class UpdateParsingTests(TestCase):
def test_successful_counts(self):
- pu = parse_update(_build_payload("<div>12,345,678 spaghetti</div>"), None)
+ pu = parse_update(_build_payload("<div>12,345,678 spaghetti</div>"), None, "")
self.assertEqual(pu.number, 12345678)
self.assertFalse(pu.deletable)
- pu = parse_update(_build_payload("<div><p>0</p><br/>oz</div>"), None)
+ pu = parse_update(_build_payload("<div><p>0</p><br/>oz</div>"), None, "")
self.assertEqual(pu.number, 0)
self.assertFalse(pu.deletable)
def test_non_counts(self):
- pu = parse_update(_build_payload("<div><table><tbody><tr><td>zoo</td></tr></tbody></table></div>"), None)
+ pu = parse_update(_build_payload("<div><table><tbody><tr><td>zoo</td></tr></tbody></table></div>"), None, "")
self.assertFalse(pu.count_attempt)
self.assertFalse(pu.deletable)
def test_typos(self):
- pu = parse_update(_build_payload("<span>v9</span>"), 888)
+ pu = parse_update(_build_payload("<span>v9</span>"), 888, "")
self.assertIsNone(pu.number)
self.assertTrue(pu.count_attempt)
- pu = parse_update(_build_payload("<div>v11.585 Empire</div>"), None)
+ pu = parse_update(_build_payload("<div>v11.585 Empire</div>"), None, "")
self.assertIsNone(pu.number)
self.assertTrue(pu.count_attempt)
self.assertFalse(pu.deletable)
- pu = parse_update(_build_payload("<div>11, 585, 22 </div>"), 11_585_202)
+ pu = parse_update(_build_payload("<div>11, 585, 22 </div>"), 11_585_202, "")
self.assertIsNone(pu.number)
self.assertTrue(pu.count_attempt)
self.assertTrue(pu.deletable)
- pu = parse_update(_build_payload("<span>0490499</span>"), 4999)
+ pu = parse_update(_build_payload("<span>0490499</span>"), 4999, "")
self.assertIsNone(pu.number)
self.assertTrue(pu.count_attempt)
class ParsedUpdate:
number: Optional[int]
command: Optional[Command]
- count_attempt: bool
+ count_attempt: bool # either well-formed or typo
deletable: bool
def parse_update(payload_data: dict, curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
+ # curr_count is the next number up, one more than the last count
+
NEW_LINE = object()
SPACE = object()
if el.text:
worklist.appendleft(el.text)
elif el.tag == "li":
- assert not el.tail
worklist.appendleft(NEW_LINE)
worklist.appendleft(el.text)
elif el.tag in ["p", "div", "h1", "h2", "blockquote"]:
elif el.tag in ["ul", "ol"]:
if el.tail:
worklist.appendleft(el.tail)
- assert not el.text
for sub in reversed(el):
worklist.appendleft(sub)
worklist.appendleft(NEW_LINE)
elif el.tag == "pre":
- out.extend([l] for l in el.text.splitlines())
+ if el.text:
+ out.extend([l] for l in el.text.splitlines())
+ worklist.appendleft(NEW_LINE)
if el.tail:
worklist.appendleft(el.tail)
- worklist.appendleft(NEW_LINE)
+ for sub in reversed(el):
+ worklist.appendleft(sub)
elif el.tag == "table":
if el.tail:
worklist.appendleft(el.tail)
worklist.appendleft(NEW_LINE)
- assert not el.text
for sub in reversed(el):
assert sub.tag in ["thead", "tbody"]
- assert not sub.text
for row in reversed(sub):
assert row.tag == "tr"
- assert not row.text or row.tail
for (i, cell) in enumerate(reversed(row)):
- assert not cell.tail
worklist.appendleft(cell)
if i != len(row) - 1:
worklist.appendleft(SPACE)
if lines:
first = lines[0]
match = re.match(
- "(?P<v>v)?(?P<neg>-)?(?P<num>\\d+((?P<sep>[,. \u2009]|, )\\d+((?P=sep)\\d+)*)?)( |$)",
+ "(?P<v>v)?(?P<neg>-)?(?P<num>\\d+((?P<sep>[,. \u2009]|, )\\d+((?P=sep)\\d+)*)?)",
first,
re.ASCII, # only recognize ASCII digits
)
if match:
ct_str = match["num"]
sep = match["sep"]
+ post = first[match.end() :]
zeros = False
while len(ct_str) > 1 and ct_str[0] == "0":
parts = ct_str.split(sep) if sep else [ct_str]
parts_valid = (
- len(parts[0]) in range(1, 3)
- and all(len(p) == 3 for p in parts[1:])
+ sep is None
+ or (
+ len(parts[0]) in range(1, 4)
+ and all(len(p) == 3 for p in parts[1:])
+ )
)
digits = "".join(parts)
- lone = not first[match.end() :].strip() and len(lines) == 1
+ lone = len(lines) == 1 and (not post or post.isspace())
typo = False
if lone:
if match["v"] and len(ct_str) <= 2:
elif match["v"] and parts_valid:
# v followed by count
typo = True
- elif curr_count and curr_count >= 100:
+ elif curr_count and curr_count >= 100 and bool(match["neg"]) == (curr_count < 0):
goal = (sep or "").join(_separate(str(curr_count)))
partials = [goal[: -2], goal[: -1], goal[: -2] + goal[-1]]
if ct_str in partials:
elif ct_str in [p + goal for p in partials]:
# double paste
typo = True
+
if match["v"] or zeros or typo or (digits == "0" and match["neg"]):
number = None
count_attempt = True
deletable = lone
elif parts_valid:
number = -int(digits) if match["neg"] else int(digits)
- count_attempt = True
- special = curr_count is not None and abs(number - curr_count) <= 25 and _is_special_number(number)
+ special = (
+ curr_count is not None
+ and abs(number - curr_count) <= 25
+ and _is_special_number(number)
+ )
deletable = lone and not special
+ if post and not post[0].isspace():
+ count_attempt = curr_count is not None and abs(number - curr_count) <= 25
+ number = None
+ else:
+ count_attempt = True
else:
number = None
count_attempt = False