packages = strikebot
python_requires = ~= 3.9
install_requires =
- trio ~= 0.19
- triopg ~= 0.6
- trio-websocket ~= 0.9
asks ~= 2.4
+ beautifulsoup4 ~= 4.11
+ trio == 0.19
+ trio-websocket == 0.9.2
+ triopg == 0.6.0
__version__ = importlib.metadata.version(__package__)
-QUIET = True # TODO remove
@dataclass
"""Determine whether this count update can follow another count update."""
return self.number == prior.number + 1 and self.author != prior.author
+ def __str__(self):
+ content = str(self.number)
+ if self.command:
+ content += "+" + self.command.name
+ return "{}({} by {})".format(self.id[4 : 8], content, self.author)
+
@total_ordering
@dataclass
timeline = []
last_valid: Optional[_Update] = None
pending_strikes: set[str] = set() # names of updates to mark stricken on arrival
+ forgotten: bool = False # whether the retention period for an update has passed during the run
def handle_update(update):
nonlocal last_valid
pos = bisect.bisect(timeline, tu)
if pos != len(timeline):
- logger.warning(f"long transpo: {update.name}")
+ logger.warning(f"long transpo: {update.id}")
+
if pos > 0 and timeline[pos - 1].update.id == update.id:
# The pool sent this update message multiple times. This could be because the message reached parity and
# its bucket was dropped just before a new connection was opened and sent its own copy, but connections also
(timeline[i].update for i in reversed(range(pos)) if timeline[i].accepted),
None
)
- if update.command is not Command.RESET and update.number is not None and last_valid and not pred:
- logger.warning("ignoring {update.name}: no valid prior count on record")
+ contender = update.command is Command.RESET or update.number is not None
+ if contender and pred is None and last_valid and forgotten:
+ # This is a really long transpo which we can't judge valid or invalid since we've forgotten the surrounding
+ # updates. The best we can do is just ignore it.
+ logger.warning(f"ignoring {update.id}: no valid prior count on record")
else:
timeline.insert(pos, tu)
- assert len({ctu.update.id for ctu in timeline}) == len(timeline) # TODO remove
tu.accepted = (
update.command is Command.RESET
or (
and (pred is None or pred.number is None or update.can_follow(pred))
)
)
- logger.debug(f"accepted: {tu.accepted}")
- logger.debug(" pred: {}".format(pred and (pred.id, pred.number, pred.author)))
+ logger.debug("{} {} -> {}".format("accepted" if tu.accepted else "rejected", pred, update))
if tu.accepted:
# resync subsequent updates
newly_valid = []
newly_invalid = []
resync_last_valid = update
+ converged = False
for scan_tu in islice(timeline, pos + 1, None):
if scan_tu.update.command is Command.RESET:
- break
+ resync_last_valid = scan_tu.update
+ if last_valid:
+ converged = True
elif scan_tu.update.number is not None:
- accept = last_valid.number is None or scan_tu.update.can_follow(last_valid)
+ accept = resync_last_valid.number is None or scan_tu.update.can_follow(resync_last_valid)
if accept and scan_tu.accepted:
# resync would have no effect past this point
- break
+ if last_valid:
+ converged = True
elif accept:
newly_valid.append(scan_tu)
resync_last_valid = scan_tu.update
elif scan_tu.accepted:
newly_invalid.append(scan_tu)
+ if scan_tu.update is last_valid:
+ last_valid = None
scan_tu.accepted = accept
+ if converged:
+ break
- if last_valid:
- last_valid = max(last_valid, resync_last_valid, key = lambda u: u.ts)
+ if converged:
+ assert last_valid.ts >= resync_last_valid.ts
else:
last_valid = resync_last_valid
)
if update.stricken:
- logger.info(f"bad strike of {update.name}")
+ logger.info(f"bad strike of {update.id}")
parts.append(_format_bad_strike_alert(update, thread_id))
if parts:
parts.append(_format_curr_count(last_valid))
- if not QUIET:
- api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts)))
+ api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, "\n\n".join(parts)))
for invalid_tu in newly_invalid:
if not invalid_tu.update.stricken:
invalid_tu.update.stricken = True
if update.command is Command.REPORT:
- if not QUIET:
- api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
+ api_pool.enqueue_request(ReportUpdateRequest(thread_id, body = _format_curr_count(last_valid)))
with message_rx:
while True:
if slot:
if not slot.update.stricken:
slot.update.stricken = True
- if not QUIET and isinstance(slot, _TimelineUpdate) and slot.accepted:
- logger.info(f"bad strike of {slot.update.name}")
+ if isinstance(slot, _TimelineUpdate) and slot.accepted:
+ logger.info(f"bad strike of {slot.update.id}")
body = _format_bad_strike_alert(slot.update, thread_id)
api_pool.enqueue_request(CorrectionUpdateRequest(thread_id, body))
else:
if process:
if bu.update.ts < threshold:
- logger.warning(f"ignoring {bu.update.name}: arrived past retention window")
+ logger.warning(f"ignoring {bu.update}: arrived past retention window")
else:
- logger.debug(f"processing update {bu.update.id} ({bu.update.number} by {bu.update.author})")
+ logger.debug(f"processing {bu.update}")
handle_update(bu.update)
else:
- logger.debug(f"holding {bu.update.id} ({bu.update.number}, checked against {last_valid.number})")
+ logger.debug(f"holding {bu.update}, checked against {last_valid})")
new_buffer.append(bu)
buffer_ = new_buffer
- logger.debug("last count {}".format(last_valid and (last_valid.id, last_valid.number))) # TODO remove
+ logger.debug("last count {}".format(last_valid))
# delete/forget old updates
new_timeline = []
api_pool.enqueue_request(DeleteRequest(thread_id, tu.update.name, tu.update.ts))
elif tu.update is last_valid:
new_timeline.append(tu)
+
+ if i:
+ forgotten = True
new_timeline.extend(islice(timeline, i, None))
timeline = new_timeline
+from enum import Enum
+from inspect import getmodule
+from logging import FileHandler, getLogger, StreamHandler
+from signal import SIGUSR1
+from sys import stdout
+from traceback import StackSummary
+from typing import Optional
import argparse
import configparser
import datetime as dt
import logging
-import sys
-import trio
+from trio import open_memory_channel, open_nursery, open_signal_receiver
+from trio.lowlevel import current_root_task, Task
import trio_asyncio
import triopg
from strikebot.reddit_api import ApiClientPool
+_DEBUG_LOG_PATH: Optional[str] = None # path to file to write debug logs to, if any
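+# e.g. set _DEBUG_LOG_PATH = "strikebot_debug.log" (hypothetical path) to capture DEBUG
+# output in a file while the stream handler configured below stays at WARNING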
+
+
+_TaskKind = Enum(
+ "_TaskKind",
+ [
+ "API_WORKER",
+ "DB_CLIENT",
+ "TOKEN_UPDATE",
+ "TRACKER",
+ "WS_MERGE_READER",
+ "WS_MERGE_TIMER",
+ "WS_REFRESHER",
+ "WS_WORKER",
+ ]
+)
+
+
+def _classify_task(task: Task) -> Optional[_TaskKind]:
+ mod_name = getmodule(task.coro.cr_code).__name__
+ if mod_name.startswith("trio_asyncio."):
+ return None
+ elif task.coro.__name__ == "_signal_handler_impl":
+ return None
+ else:
+ by_coro_name = {
+ "db_client_impl": _TaskKind.DB_CLIENT,
+ "token_updater_impl": _TaskKind.TOKEN_UPDATE,
+ "event_reader_impl": _TaskKind.WS_MERGE_READER,
+ "timeout_handler_impl": _TaskKind.WS_MERGE_TIMER,
+ "conn_refresher_impl": _TaskKind.WS_REFRESHER,
+ "count_tracker_impl": _TaskKind.TRACKER,
+ "_reader_impl": _TaskKind.WS_WORKER,
+ "worker_impl": _TaskKind.API_WORKER,
+ }
+ return by_coro_name[task.coro.__name__]
+
+
+def _get_all_tasks():
+ [ta_loop] = [
+ task
+ for nursery in current_root_task().child_nurseries
+ for task in nursery.child_tasks
+ if getmodule(task.coro.cr_code).__name__.startswith("trio_asyncio.")
+ ]
+ for nursery in ta_loop.child_nurseries:
+ yield from nursery.child_tasks
+
+
+def _print_task_info():
+ groups = {kind: [] for kind in _TaskKind}
+ for task in _get_all_tasks():
+ kind = _classify_task(task)
+ if kind:
+ coro = task.coro
+ frame_tuples = []
+ while coro is not None:
+ if hasattr(coro, "cr_frame"):
+ frame_tuples.append((coro.cr_frame, coro.cr_frame.f_lineno))
+ coro = coro.cr_await
+ else:
+ frame_tuples.append((coro.gi_frame, coro.gi_frame.f_lineno))
+ coro = coro.gi_yieldfrom
+ groups[kind].append(StackSummary.extract(iter(frame_tuples)))
+ for kind in _TaskKind:
+ print(kind.name)
+ for ss in groups[kind]:
+ print(" task")
+ for line in ss.format():
+ print(" " + line, end = "")
+
+
+async def _signal_handler_impl():
+ with open_signal_receiver(SIGUSR1) as sig_src:
+ async for _ in sig_src:
+ _print_task_info()
+
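+# A hedged usage note: with the handler installed (see nursery_a.start_soon below),
+# "kill -USR1 <pid>" makes the bot print every pooled task's stack to stdout.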
+
ap = argparse.ArgumentParser(__package__)
ap.add_argument("config_path")
args = ap.parse_args()
ws_silent_limit = dt.timedelta(seconds = main_cfg.getfloat("WS silent limit"))
ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup seconds"))
-db_connect_params = dict(parser["db connect params"])
+db_cfg = parser["db connect params"]
+getters = {
+ "port": db_cfg.getint,
+}
+db_connect_params = {k: getters.get(k, db_cfg.get)(k) for k in db_cfg}
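+# For example, an INI section like (hypothetical values):
+#
+#   [db connect params]
+#   host = localhost
+#   port = 5432
+#   user = strikebot
+#
+# yields {"host": "localhost", "port": 5432, "user": "strikebot"}: "port" goes through
+# getint and everything else stays a string for triopg.connect.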
-logger = logging.getLogger(__package__)
+logger = getLogger(__package__)
logger.setLevel(logging.DEBUG)
-handler = logging.StreamHandler(sys.stdout)
-handler.setFormatter(logging.Formatter("{asctime:23}: {levelname:8}: {message}", style = "{"))
+handler = StreamHandler(stdout)
+handler.setLevel(logging.WARNING)
+handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{"))
logger.addHandler(handler)
+if _DEBUG_LOG_PATH:
+ debug_handler = FileHandler(_DEBUG_LOG_PATH, mode = "w")
+ debug_handler.setLevel(logging.DEBUG)
+ debug_handler.setFormatter(logging.Formatter("{asctime:23}: {name:17} {levelname:8}: {message}", style = "{"))
+ logger.addHandler(debug_handler)
+
async def main():
- async with trio.open_nursery() as nursery_a, triopg.connect(**db_connect_params) as db_conn:
- (req_tx, req_rx) = trio.open_memory_channel(0)
+ async with (
+ triopg.connect(**db_connect_params) as db_conn,
+ open_nursery() as nursery_a,
+ open_nursery() as nursery_b,
+ open_nursery() as ws_pool_nursery,
+ open_nursery() as nursery_c,
+ ):
+ nursery_a.start_soon(_signal_handler_impl)
+
client = Client(db_conn)
db_messenger = Messenger(client)
-
nursery_a.start_soon(db_messenger.db_client_impl)
api_pool = ApiClientPool(
- auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay, logger
+ auth_ids, db_messenger, request_queue_limit, api_pool_error_window, api_pool_error_delay,
+ logger.getChild("api")
)
nursery_a.start_soon(api_pool.token_updater_impl)
for _ in auth_ids:
- nursery_a.start_soon(api_pool.worker_impl)
-
- (message_tx, message_rx) = trio.open_memory_channel(0)
- (pool_event_tx, pool_event_rx) = trio.open_memory_channel(0)
- merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger)
- async with trio.open_nursery() as nursery_b, trio.open_nursery() as ws_pool_nursery:
- nursery_b.start_soon(merger.event_reader_impl)
- nursery_b.start_soon(merger.timeout_handler_impl)
- ws_pool = HealingReadPool(ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger, ws_silent_limit)
- async with ws_pool, trio.open_nursery() as nursery_c:
- nursery_c.start_soon(ws_pool.conn_refresher_impl)
- nursery_c.start_soon(
- count_tracker_impl,
- message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing, update_retention, logger
- )
- await ws_pool.init_workers()
+ await nursery_a.start(api_pool.worker_impl)
+ (message_tx, message_rx) = open_memory_channel(0)
+ (pool_event_tx, pool_event_rx) = open_memory_channel(0)
+
+ merger = PoolMerger(pool_event_rx, message_tx, ws_parity_time, ws_warmup, logger.getChild("merge"))
+ nursery_b.start_soon(merger.event_reader_impl)
+ nursery_b.start_soon(merger.timeout_handler_impl)
+
+ ws_pool = HealingReadPool(
+ ws_pool_nursery, ws_pool_size, thread_id, api_pool, pool_event_tx, logger.getChild("live_ws"),
+ ws_silent_limit
+ )
+ nursery_c.start_soon(ws_pool.conn_refresher_impl)
+
+ nursery_c.start_soon(
+ count_tracker_impl, message_rx, api_pool, reorder_buffer_time, thread_id, bot_user, enforcing,
+ update_retention, logger.getChild("track")
+ )
+ await ws_pool.init_workers()
trio_asyncio.run(main)
--- /dev/null
+from typing import Any
+
+
+def int_digest(val: int) -> str:
+ return "{:04x}".format(val & 0xffff)
+
+
+def obj_digest(obj: Any) -> str:
+ """Makes a short digest of the identity of an object."""
+ return int_digest(id(obj))
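+
+
+# e.g. int_digest(0x12345678) == "5678"; obj_digest(x) is stable for x's lifetime, so the
+# debug logs use it to correlate scopes and requests across messages.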
def _channel_sender(method):
async def wrapped(self, resp_channel, *args, **kwargs):
+ ret = await method(self, *args, **kwargs)
with resp_channel:
- await resp_channel.send(await method(self, *args, **kwargs))
+ await resp_channel.send(ret)
return wrapped
"""This is run by consumers of the DB wrapper."""
method = getattr(self._client, method_name)
(resp_tx, resp_rx) = trio.open_memory_channel(0)
- await self._request_tx.send(method(resp_tx, *args))
+ coro = method(resp_tx, *args)
+ await self._request_tx.send(coro)
async with resp_rx:
return await resp_rx.receive()
import logging
import math
-from trio.lowlevel import checkpoint_if_cancelled
from trio_websocket import ConnectionClosed, HandshakeError, open_websocket_url, WebSocketConnection
import trio
+from strikebot.common import int_digest, obj_digest
from strikebot.reddit_api import AboutLiveThreadRequest
self._pool_event_tx = pool_event_tx
self._logger = logger
self._silent_limit = silent_limit
- (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size)
- async def __aenter__(self):
- pass
+ (self._refresh_queue_tx, self._refresh_queue_rx) = trio.open_memory_channel(size)
- async def __aexit__(self, exc_type, exc_value, traceback):
- self._refresh_queue_tx.close()
+ # number of workers who have stopped receiving updates but whose tasks aren't yet stopped
+ self._closing = 0
async def init_workers(self):
for _ in range(self._size):
# TODO could use task's implicit cancel scope
try:
async with conn_ctx as conn:
+ self._logger.debug("scope up: {}".format(obj_digest(cancel_scope)))
with refresh_tx:
- _tsc = False # TODO remove
+ silent_timeout = False
with cancel_scope, suppress(ConnectionClosed):
while True:
with trio.move_on_after(self._silent_limit.total_seconds()) as timeout_scope:
message = await conn.get_message()
if timeout_scope.cancelled_caught:
- if len(self._nursery.child_tasks) == self._size:
- _tsc = True
- cancel_scope.cancel()
- await checkpoint_if_cancelled()
+ if len(self._nursery.child_tasks) - self._closing == self._size:
+ silent_timeout = True
+ break
+ else:
+ self._logger.debug("not replacing connection {}; {} tasks, {} closing".format(
+ obj_digest(cancel_scope),
+ len(self._nursery.child_tasks),
+ self._closing,
+ ))
else:
event = _Message(trio.current_time(), json.loads(message), cancel_scope)
await self._pool_event_tx.send(event)
- assert _tsc == timeout_scope.cancelled_caught
- if timeout_scope.cancelled_caught:
- self._logger.debug("replacing WS connection due to silent timeout")
+ self._closing += 1
+ if silent_timeout:
await conn.aclose()
+ self._logger.debug("replacing WS connection {} due to silent timeout".format(obj_digest(cancel_scope)))
elif cancel_scope.cancelled_caught:
await conn.aclose(1008, "Server unexpectedly stopped sending messages")
- self._logger.warning("replacing WS connection due to missed update")
+ self._logger.warning("replacing WS connection {} due to missed update".format(obj_digest(cancel_scope)))
else:
- self._logger.warning("replacing WS connection closed by server")
- self._logger.debug(f"post-kill tasks {len(self._nursery.child_tasks)}")
+ self._logger.warning("replacing WS connection {} closed by server".format(obj_digest(cancel_scope)))
+
refresh_tx.send_nowait(cancel_scope)
+ self._logger.debug("scope down: {} ({} tasks, {} closing)".format(
+ obj_digest(cancel_scope),
+ len(self._nursery.child_tasks),
+ self._closing,
+ ))
+ self._closing -= 1
except HandshakeError:
self._logger.error("handshake error while opening WS connection")
await self._pool_event_tx.send(_ConnectionUp(trio.current_time(), new_scope))
refresh_tx = self._refresh_queue_tx.clone()
self._nursery.start_soon(self._reader_impl, conn, new_scope, refresh_tx)
- self._logger.debug(f"post-spawn tasks {len(self._nursery.child_tasks)}")
async def conn_refresher_impl(self):
"""Task to monitor and replace WS connections as they disconnect or go silent."""
raise RuntimeError("WS pool depleted")
+def _tag_digest(tag):
+ return int_digest(hash(tag))
+
+
+def _format_scope_list(scopes):
+ return ", ".join(sorted(map(obj_digest, scopes)))
+
+
class PoolMerger:
@dataclass
class _Bucket:
self._outgoing_scopes = set()
(self._timer_poke_tx, self._timer_poke_rx) = trio.open_memory_channel(math.inf)
+ def _log_current_buckets(self):
+ self._logger.debug("current buckets: {}".format(", ".join(map(_tag_digest, self._buckets.keys()))))
+
async def event_reader_impl(self):
- """Drop unused messages, deduplicate useful ones, and install info needed by the timeout handler."""
+ """Drop unused messages, deduplicate useful ones, and communicate with the timeout handler."""
with self._pool_event_rx, self._timer_poke_tx:
async for event in self._pool_event_rx:
if isinstance(event, _ConnectionUp):
# An early add of an active scope could mean it's expected on a message that fired before it opened,
- # resulting in a false positive for replacement. The timer task merges these in among the buckets
- # by timestamp to avoid that.
+ # resulting in a false positive for replacement. The timer task merges these in among the buckets by
+ # timestamp to avoid that.
self._pending.append(event)
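+ # e.g. (hypothetical timing) a bucket opens for a message at t = 4.9 and a connection
+ # comes up at t = 5.0; activating the scope immediately would leave it "missing" from
+ # that bucket forever, so it only counts toward buckets that start after t = 5.0.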
elif isinstance(event, _Message):
if event.data["type"] in ["update", "strike", "delete"]:
tag = event.dedup_tag()
+ self._logger.debug("recv {} from {}".format(_tag_digest(tag), obj_digest(event.scope)))
if tag in self._buckets:
b = self._buckets[tag]
b.recipients.add(event.scope)
- if b.recipients >= self._scope_activations.keys():
- del self._buckets[tag]
+ # If this scope is the last one for this bucket we could clear the bucket here, but since
+ # connections sometimes get repeat messages and second copies can arrive before the first
+ # copy has arrived on all connections, leaving the bucket open to absorb repeats can reduce
+ # the likelihood of a second bucket being allocated late for the second copy of a message
+ # and causing unnecessary connection replacement.
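+ # e.g. connection A delivers a repeat of an update before connection B has sent its
+ # first copy; the surviving bucket absorbs A's repeat instead of a fresh bucket being
+ # opened that B would then appear to miss.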
elif event.scope not in self._outgoing_scopes:
sane = (
event.scope in self._scope_activations
or any(e.scope == event.scope for e in self._pending)
)
if sane:
+ self._logger.debug("new bucket {}: {}".format(_tag_digest(tag), tag))
self._buckets[tag] = self._Bucket(event.timestamp, {event.scope})
+ self._log_current_buckets()
await self._message_tx.send(event)
else:
raise RuntimeError("recieved message from unrecognized WS connection")
+ else:
+ self._logger.debug("recv type {!r} from {} (discarding)".format(event.data["type"], obj_digest(event.scope)))
elif isinstance(event, _ConnectionDown):
# We don't need to worry about canceling this scope at all, so no need to require it for parity for
# any message, even older ones. The scope may be gone already, if we canceled it previously.
if active + self._conn_warmup.total_seconds() < now
}
if now > bucket.start + self._parity_timeout.total_seconds():
- for scope in target_scopes - bucket.recipients:
+ filled = target_scopes & bucket.recipients
+ missing = target_scopes - bucket.recipients
+ extra = bucket.recipients - target_scopes
+
+ self._logger.debug("expiring bucket {}".format(_tag_digest(tag)))
+ if filled:
+ self._logger.debug(" filled {}: {}".format(len(filled), _format_scope_list(filled)))
+ if missing:
+ self._logger.debug(" missing {}: {}".format(len(missing), _format_scope_list(missing)))
+ if extra:
+ self._logger.debug(" extra {}: {}".format(len(extra), _format_scope_list(extra)))
+
+ for scope in missing:
self._outgoing_scopes.add(scope)
scope.cancel()
del self._buckets[tag]
+ self._log_current_buckets()
else:
await trio.sleep_until(bucket.start + self._parity_timeout.total_seconds())
else:
await self._empty_wait.park()
return self._deque.popleft()
+ def __len__(self):
+ return len(self._deque)
+
@total_ordering
@dataclass
import trio
from strikebot import __version__ as VERSION
+from strikebot.common import obj_digest
from strikebot.queue import MaxHeap, Queue
awaken_auths = self._waiting.keys() & tokens.keys()
self._app_queue.extend(self._waiting.pop(auth_id) for auth_id in awaken_auths)
+ self._logger.debug("updated API tokens")
async def token_updater_impl(self):
last_update = None
async def make_request(self, request: _Request) -> Response:
(resp_tx, resp_rx) = trio.open_memory_channel(0)
- self._request_queue.push((request, resp_tx))
- self._check_queue_size()
+ self.enqueue_request(request, resp_tx)
async with resp_rx:
return await resp_rx.receive()
- def enqueue_request(self, request: _Request) -> None:
- self._request_queue.push((request, None))
+ def enqueue_request(self, request: _Request, resp_tx = None) -> None:
+ self._logger.debug("request {}: {}".format(obj_digest(request), type(request).__name__))
+ self._request_queue.push((request, resp_tx))
self._check_queue_size()
+ self._logger.debug(f"{len(self._request_queue)} requests in queue")
- async def worker_impl(self):
+ async def worker_impl(self, task_status = trio.TASK_STATUS_IGNORED):
+ task_status.started()
while True:
(request, resp_tx) = await self._request_queue.pop()
cooldown = await self._app_queue.pop()
resp.body # read response
error = False
wait_for_token = False
+ log_suffix = " (request {})".format(obj_digest(request))
if resp.status_code == 429:
# We disagreed about the rate limit state; just try again later.
- self._logger.warning("rate limited by Reddit API")
+ self._logger.warning("rate limited by Reddit API" + log_suffix)
error = True
elif resp.status_code == 401:
- self._logger.warning("got HTTP 401 from Reddit API")
+ self._logger.warning("got HTTP 401 from Reddit API" + log_suffix)
error = True
wait_for_token = True
- elif resp.status_code in [404, 500]:
- self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying")
+ elif resp.status_code in [404, 500, 503]:
+ self._logger.warning(f"got HTTP {resp.status_code} from Reddit API, retrying" + log_suffix)
error = True
elif 400 <= resp.status_code < 500:
# If we're doing something wrong, let's catch it right away.
- raise RuntimeError(f"unexpected client error response: {resp.status_code}")
+ raise RuntimeError(f"unexpected client error response: {resp.status_code}" + log_suffix)
else:
if resp.status_code != 200:
raise RuntimeError(f"unexpected status code {resp.status_code}")
+ self._logger.debug("success" + log_suffix)
if resp_tx:
await resp_tx.send(resp)
cooldown.ready_at = request_time + REQUEST_DELAY.total_seconds()
if wait_for_token:
self._waiting[cooldown.auth_id] = cooldown
+ if not self._app_queue:
+ self._logger.error("all workers waiting for API tokens")
else:
self._app_queue.push(cooldown)
self.assertEqual(pu.number, 0)
self.assertFalse(pu.deletable)
+ pu = parse_update(_build_payload("<div>121 345 621</div>"), 121, "")
+ self.assertEqual(pu.number, 121)
+ self.assertFalse(pu.deletable)
+
+ pu = parse_update(_build_payload("28 336 816"), 28336816, "")
+ self.assertEqual(pu.number, 28336816)
+ self.assertTrue(pu.deletable)
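+
+ # A hedged extra case derived from the rules above: a lone "v" followed by a separated
+ # count reads as a failed strikethrough paste, so it's a deletable count attempt.
+ pu = parse_update(_build_payload("v121 345"), 121345, "")
+ self.assertIsNone(pu.number)
+ self.assertTrue(pu.count_attempt)
+ self.assertTrue(pu.deletable)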
+
def test_non_counts(self):
pu = parse_update(_build_payload("<div><table><tbody><tr><td>zoo</td></tr></tbody></table></div>"), None, "")
self.assertFalse(pu.count_attempt)
from __future__ import annotations
-from collections import deque
from dataclasses import dataclass
from enum import Enum
from typing import Optional
-from xml.etree import ElementTree
import re
+from bs4 import BeautifulSoup
+
Command = Enum("Command", ["RESET", "REPORT"])
SPACE = object()
# flatten the update content to plain text
- doc = ElementTree.fromstring(payload_data["body_html"])
- worklist = deque([doc])
+ tree = BeautifulSoup(payload_data["body_html"], "html.parser")
+ worklist = list(reversed(tree.contents)) # copy (don't mutate the soup) and reverse so pops come off in document order
out = [[]]
while worklist:
- el = worklist.popleft()
- if el is NEW_LINE:
- if out[-1]:
- out.append([])
+ el = worklist.pop()
+ if isinstance(el, str):
+ out[-1].append(el)
elif el is SPACE:
out[-1].append(el)
- elif isinstance(el, str):
- out[-1].append(el.replace("\n", " "))
- elif el.tag in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]:
- if el.tail:
- worklist.appendleft(el.tail)
- for sub in reversed(el):
- worklist.appendleft(sub)
- if el.text:
- worklist.appendleft(el.text)
- elif el.tag == "li":
- worklist.appendleft(NEW_LINE)
- worklist.appendleft(el.text)
- elif el.tag in ["p", "div", "h1", "h2", "blockquote"]:
- if el.tail:
- worklist.appendleft(el.tail)
- worklist.appendleft(NEW_LINE)
- for sub in reversed(el):
- worklist.appendleft(sub)
- if el.text:
- worklist.appendleft(el.text)
- elif el.tag in ["ul", "ol"]:
- if el.tail:
- worklist.appendleft(el.tail)
- for sub in reversed(el):
- worklist.appendleft(sub)
- worklist.appendleft(NEW_LINE)
- elif el.tag == "pre":
- if el.text:
- out.extend([l] for l in el.text.splitlines())
- worklist.appendleft(NEW_LINE)
- if el.tail:
- worklist.appendleft(el.tail)
- for sub in reversed(el):
- worklist.appendleft(sub)
- elif el.tag == "table":
- if el.tail:
- worklist.appendleft(el.tail)
- worklist.appendleft(NEW_LINE)
- for sub in reversed(el):
- assert sub.tag in ["thead", "tbody"]
- for row in reversed(sub):
- assert row.tag == "tr"
- for (i, cell) in enumerate(reversed(row)):
- worklist.appendleft(cell)
- if i != len(row) - 1:
- worklist.appendleft(SPACE)
- worklist.appendleft(NEW_LINE)
- elif el.tag == "br":
- if el.tail:
- worklist.appendleft(el.tail)
- worklist.appendleft(NEW_LINE)
+ elif el is NEW_LINE or el.name == "br":
+ if out[-1]:
+ out.append([])
+ elif el.name in ["em", "strong", "del", "span", "sup", "code", "a", "th", "td"]:
+ worklist.extend(reversed(el.contents))
+ elif el.name in ["ul", "ol", "table", "thead", "tbody"]:
+ worklist.extend(reversed(el.contents))
+ elif el.name in ["li", "p", "div", "h1", "h2", "blockquote"]:
+ worklist.append(NEW_LINE)
+ worklist.extend(reversed(el.contents))
+ worklist.append(NEW_LINE)
+ elif el.name == "pre":
+ worklist.append(NEW_LINE)
+ worklist.extend([l] for l in reversed(el.text.splitlines()))
+ worklist.append(NEW_LINE)
+ elif el.name == "tr":
+ worklist.append(NEW_LINE)
+ for (i, cell) in enumerate(reversed(el.contents)):
+ worklist.append(cell)
+ if i != len(el.contents) - 1:
+ worklist.append(SPACE)
+ worklist.append(NEW_LINE)
else:
raise RuntimeError(f"can't parse tag {el.tag}")
- lines = list(filter(
- None,
- (
- "".join(" " if part is SPACE else part for part in parts).strip()
- for parts in out
- )
- ))
+ tmp_lines = (
+ "".join(" " if part is SPACE else part for part in parts).strip()
+ for parts in out
+ )
+ pre_strip_lines = list(filter(None, tmp_lines))
+
+ # normalize whitespace according to HTML rendering rules
+ # https://developer.mozilla.org/en-US/docs/Web/API/Document_Object_Model/Whitespace#explanation
+ stripped_lines = [
+ re.sub(" +", " ", l.replace("\t", " ").replace("\n", " ")).strip(" ")
+ for l in pre_strip_lines
+ ]
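+ # e.g. a line "12\t345  678" normalizes to "12 345 678"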
+
+ return _parse_from_lines(stripped_lines, curr_count, bot_user)
+
+def _parse_from_lines(lines: list[str], curr_count: Optional[int], bot_user: str) -> ParsedUpdate:
command = next(
filter(None, (_parse_command(l, bot_user) for l in lines)),
None
)
-
if lines:
+ # look for groups of digits (as many as possible) separated by a uniform separator from the valid set
first = lines[0]
match = re.match(
"(?P<v>v)?(?P<neg>-)?(?P<num>\\d+((?P<sep>[,. \u2009]|, )\\d+((?P=sep)\\d+)*)?)",
first,
re.ASCII, # only recognize ASCII digits
)
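+ # e.g. (hypothetical inputs) "1,234,567" matches with num = "1,234,567" and sep = ",";
+ # "v12" matches with v = "v" and num = "12"; "-0" matches with neg = "-" and num = "0".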
if match:
- ct_str = match["num"]
+ raw_digits = match["num"]
sep = match["sep"]
post = first[match.end() :]
zeros = False
- while len(ct_str) > 1 and ct_str[0] == "0":
+ while len(raw_digits) > 1 and raw_digits[0] == "0":
zeros = True
- ct_str = ct_str.removeprefix("0").removeprefix(sep or "")
-
- parts = ct_str.split(sep) if sep else [ct_str]
- parts_valid = (
- sep is None
- or (
- len(parts[0]) in range(1, 4)
- and all(len(p) == 3 for p in parts[1:])
- )
- )
- digits = "".join(parts)
+ raw_digits = raw_digits.removeprefix("0").removeprefix(sep or "")
+
+ parts = raw_digits.split(sep) if sep else [raw_digits]
lone = len(lines) == 1 and (not post or post.isspace())
typo = False
if lone:
- if match["v"] and len(ct_str) <= 2:
+ all_parts_valid = (
+ sep is None
+ or (
+ 1 <= len(parts[0]) <= 3
+ and all(len(p) == 3 for p in parts[1:])
+ )
+ )
+ if match["v"] and len(parts) == 1 and len(parts[0]) <= 2:
# failed paste of leading digits
typo = True
- elif match["v"] and parts_valid:
+ elif match["v"] and all_parts_valid:
# v followed by count
typo = True
- elif curr_count and curr_count >= 100 and bool(match["neg"]) == (curr_count < 0):
- goal = (sep or "").join(_separate(str(curr_count)))
- partials = [goal[: -2], goal[: -1], goal[: -2] + goal[-1]]
- if ct_str in partials:
+ elif curr_count is not None and abs(curr_count) >= 100 and bool(match["neg"]) == (curr_count < 0):
+ goal_parts = _separate(str(abs(curr_count)))
+ partials = [
+ goal_parts[: -1] + [goal_parts[-1][: -1]], # missing last digit
+ goal_parts[: -1] + [goal_parts[-1][: -2]], # missing last two digits
+ goal_parts[: -1] + [goal_parts[-1][: -2] + goal_parts[-1][-1]], # missing second-last digit
+ ]
+ if parts in partials:
# missing any of last two digits
typo = True
- elif ct_str in [p + goal for p in partials]:
+ elif parts in [p[: -1] + [p[-1] + goal_parts[0]] + goal_parts[1 :] for p in partials]:
# double paste
typo = True
- if match["v"] or zeros or typo or (digits == "0" and match["neg"]):
+ if match["v"] or zeros or typo or (parts == ["0"] and match["neg"]):
number = None
count_attempt = True
deletable = lone
- elif parts_valid:
+ else:
+ if curr_count is not None and sep and sep.isspace():
+ # Presume that the intended count consists of as many valid digit groups as necessary to match the
+ # number of digits in the expected count, if possible.
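+ # For example (hypothetical): with curr_count = 121 and the line "121 345 621", the
+ # groups ["121", "345", "621"] yield use_parts = ["121"], so the count parses as 121.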
+ digit_count = len(str(abs(curr_count)))
+ use_parts = []
+ accum = 0
+ for (i, part) in enumerate(parts):
+ part_valid = len(part) <= 3 if i == 0 else len(part) == 3
+ if part_valid and accum < digit_count:
+ use_parts.append(part)
+ accum += len(part)
+ else:
+ break
+
+ # could still be a no-separator count with some extra digit groups on the same line
+ if not use_parts:
+ use_parts = [parts[0]]
+
+ lone = lone and len(use_parts) == len(parts)
+ else:
+ # current count is unknown or no separator was used
+ use_parts = parts
+
+ digits = "".join(use_parts)
number = -int(digits) if match["neg"] else int(digits)
special = (
curr_count is not None
and _is_special_number(number)
)
deletable = lone and not special
- if post and not post[0].isspace():
+ if len(use_parts) == len(parts) and post and not post[0].isspace():
count_attempt = curr_count is not None and abs(number - curr_count) <= 25
number = None
else:
count_attempt = True
- else:
- number = None
- count_attempt = False
- deletable = lone
else:
# no count attempt found
number = None