From b50bcaa6bb41445042c38481c0cce60208455d6b Mon Sep 17 00:00:00 2001 From: Jakob Cornell Date: Thu, 21 Apr 2022 11:39:38 -0500 Subject: [PATCH] Cut down on Bencode decode memory usage --- disk_jumble/src/disk_jumble/bencode.py | 106 ++++++++++++------------- 1 file changed, 52 insertions(+), 54 deletions(-) diff --git a/disk_jumble/src/disk_jumble/bencode.py b/disk_jumble/src/disk_jumble/bencode.py index 825ad52..1a10162 100644 --- a/disk_jumble/src/disk_jumble/bencode.py +++ b/disk_jumble/src/disk_jumble/bencode.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Dict, List, Union +from typing import Dict, List, Tuple, Union import itertools @@ -11,86 +11,85 @@ class CodecError(Exception): pass -def _pop_bytes(vals: list[bytes]) -> bytes: +def _pop_bytes(src: bytes, start: int) -> Tuple[bytes, int]: len_parts = [] - while vals and vals[0].isdigit(): - len_parts.append(vals.pop(0)) + off = start + while off < len(src) and chr(src[off]).isdigit(): + len_parts.append(src[off]) + off += 1 try: - length = int(b"".join(len_parts).decode("ascii")) + length = int(bytes(len_parts).decode("ascii")) except ValueError: raise CodecError() - try: - if vals.pop(0) != b":": - raise CodecError() - except IndexError: + if off == len(src) or src[off] != ord(":"): raise CodecError() - if length > len(vals): + off += 1 + if off + length > len(src): raise CodecError() - out = b"".join(vals[:length]) - del vals[:length] - return out + return (src[off : off + length], off + length) -def _pop_int(vals: list[bytes]) -> int: - assert vals.pop(0) == b"i" +def _pop_int(src: bytes, start: int) -> Tuple[int, int]: + assert src[start] == ord("i") try: - end = vals.index(b"e") - except ValueError: + end = next(i for i in range(start, len(src)) if src[i] == ord("e")) + except StopIteration: raise CodecError() try: - out = int(b"".join(vals[:end]).decode("ascii")) + out = int(src[start + 1 : end].decode("ascii")) except ValueError: raise CodecError() - del vals[slice(end + 1)] - return out + return (out, end + 1) -def _pop_list(vals: list[bytes]) -> list[Type]: - assert vals.pop(0) == b"l" +def _pop_list(src: bytes, start: int) -> Tuple[list[Type], int]: + assert src[start] == ord("l") + off = start + 1 out = [] - while vals and vals[0] != b"e": - out.append(_pop_value(vals)) + while off < len(src) and src[off] != ord("e"): + (el, off) = _pop_value(src, off) + out.append(el) - if vals: - del vals[0] - return out - else: + if off == len(src): raise CodecError() + else: + return (out, off + 1) -def _pop_dict(vals: list[bytes]) -> Bdict: - assert vals.pop(0) == b"d" +def _pop_dict(src: bytes, start: int) -> Tuple[Bdict, int]: + assert src[start] == ord("d") + off = start + 1 out = {} - while vals and vals[0] != b"e": - key = _pop_bytes(vals) - out[key] = _pop_value(vals) + while off < len(src) and src[off] != ord("e"): + (key, off) = _pop_bytes(src, off) + (out[key], off) = _pop_value(src, off) - if vals: - del vals[0] - return out - else: + if off == len(src): raise CodecError() - - -def _pop_value(vals: list[bytes]) -> Type: - if vals: - if vals[0].isdigit(): - return _pop_bytes(vals) - elif vals[0] == b"i": - return _pop_int(vals) - elif vals[0] == b"l": - return _pop_list(vals) - elif vals[0] == b"d": - return _pop_dict(vals) + else: + return (out, off + 1) + + +def _pop_value(src: bytes, start: int) -> Tuple[Type, int]: + if start < len(src): + first = chr(src[start]) + if first.isdigit(): + return _pop_bytes(src, start) + elif first == "i": + return _pop_int(src, start) + elif first == "l": + return _pop_list(src, start) + elif first == "d": + return _pop_dict(src, start) else: raise CodecError() else: @@ -98,12 +97,11 @@ def _pop_value(vals: list[bytes]) -> Type: def decode(data: bytes) -> Type: - vals = [bytes([v]) for v in data] - out = _pop_value(vals) - if vals: - raise CodecError() - else: + (out, off) = _pop_value(data, 0) + if off == len(data): return out + else: + raise CodecError() def _encode_helper(data: Type) -> list[bytes]: -- 2.30.2