Cut down on Bencode decode memory usage
author: Jakob Cornell <jakob+gpg@jcornell.net>
Thu, 21 Apr 2022 16:39:38 +0000 (11:39 -0500)
committer: Jakob Cornell <jakob+gpg@jcornell.net>
Thu, 21 Apr 2022 16:39:38 +0000 (11:39 -0500)
disk_jumble/src/disk_jumble/bencode.py

index 825ad527a38c7699193a41a770657731507f1014..1a10162445d437fe0cf23ec68605f50ccb306980 100644 (file)
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Dict, List, Union
+from typing import Dict, List, Tuple, Union
 import itertools
 
 
@@ -11,86 +11,85 @@ class CodecError(Exception):
        pass
 
 
-def _pop_bytes(vals: list[bytes]) -> bytes:
+def _pop_bytes(src: bytes, start: int) -> Tuple[bytes, int]:
        len_parts = []
-       while vals and vals[0].isdigit():
-               len_parts.append(vals.pop(0))
+       off = start
+       while off < len(src) and chr(src[off]).isdigit():
+               len_parts.append(src[off])
+               off += 1
 
        try:
-               length = int(b"".join(len_parts).decode("ascii"))
+               length = int(bytes(len_parts).decode("ascii"))
        except ValueError:
                raise CodecError()
 
-       try:
-               if vals.pop(0) != b":":
-                       raise CodecError()
-       except IndexError:
+       if off == len(src) or src[off] != ord(":"):
                raise CodecError()
 
-       if length > len(vals):
+       off += 1
+       if off + length > len(src):
                raise CodecError()
 
-       out = b"".join(vals[:length])
-       del vals[:length]
-       return out
+       return (src[off : off + length], off + length)
 
 
-def _pop_int(vals: list[bytes]) -> int:
-       assert vals.pop(0) == b"i"
+def _pop_int(src: bytes, start: int) -> Tuple[int, int]:
+       assert src[start] == ord("i")
 
        try:
-               end = vals.index(b"e")
-       except ValueError:
+               end = next(i for i in range(start, len(src)) if src[i] == ord("e"))
+       except StopIteration:
                raise CodecError()
 
        try:
-               out = int(b"".join(vals[:end]).decode("ascii"))
+               out = int(src[start + 1 : end].decode("ascii"))
        except ValueError:
                raise CodecError()
 
-       del vals[slice(end + 1)]
-       return out
+       return (out, end + 1)
 
 
-def _pop_list(vals: list[bytes]) -> list[Type]:
-       assert vals.pop(0) == b"l"
+def _pop_list(src: bytes, start: int) -> Tuple[list[Type], int]:
+       assert src[start] == ord("l")
 
+       off = start + 1
        out = []
-       while vals and vals[0] != b"e":
-               out.append(_pop_value(vals))
+       while off < len(src) and src[off] != ord("e"):
+               (el, off) = _pop_value(src, off)
+               out.append(el)
 
-       if vals:
-               del vals[0]
-               return out
-       else:
+       if off == len(src):
                raise CodecError()
+       else:
+               return (out, off + 1)
 
 
-def _pop_dict(vals: list[bytes]) -> Bdict:
-       assert vals.pop(0) == b"d"
+def _pop_dict(src: bytes, start: int) -> Tuple[Bdict, int]:
+       assert src[start] == ord("d")
 
+       off = start + 1
        out = {}
-       while vals and vals[0] != b"e":
-               key = _pop_bytes(vals)
-               out[key] = _pop_value(vals)
+       while off < len(src) and src[off] != ord("e"):
+               (key, off) = _pop_bytes(src, off)
+               (out[key], off) = _pop_value(src, off)
 
-       if vals:
-               del vals[0]
-               return out
-       else:
+       if off == len(src):
                raise CodecError()
-
-
-def _pop_value(vals: list[bytes]) -> Type:
-       if vals:
-               if vals[0].isdigit():
-                       return _pop_bytes(vals)
-               elif vals[0] == b"i":
-                       return _pop_int(vals)
-               elif vals[0] == b"l":
-                       return _pop_list(vals)
-               elif vals[0] == b"d":
-                       return _pop_dict(vals)
+       else:
+               return (out, off + 1)
+
+
+def _pop_value(src: bytes, start: int) -> Tuple[Type, int]:
+       if start < len(src):
+               first = chr(src[start])
+               if first.isdigit():
+                       return _pop_bytes(src, start)
+               elif first == "i":
+                       return _pop_int(src, start)
+               elif first == "l":
+                       return _pop_list(src, start)
+               elif first == "d":
+                       return _pop_dict(src, start)
                else:
                        raise CodecError()
        else:
@@ -98,12 +97,11 @@ def _pop_value(vals: list[bytes]) -> Type:
 
 
 def decode(data: bytes) -> Type:
-       vals = [bytes([v]) for v in data]
-       out = _pop_value(vals)
-       if vals:
-               raise CodecError()
-       else:
+       (out, off) = _pop_value(data, 0)
+       if off == len(data):
                return out
+       else:
+               raise CodecError()
 
 
 def _encode_helper(data: Type) -> list[bytes]: