Stream from the database to reduce memory footprint
author Jakob Cornell <jakob+gpg@jcornell.net>
Fri, 22 Apr 2022 01:21:26 +0000 (20:21 -0500)
committer Jakob Cornell <jakob+gpg@jcornell.net>
Fri, 22 Apr 2022 01:21:26 +0000 (20:21 -0500)
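
For readers unfamiliar with the pattern: the change below replaces client-side fetchall() calls with named (server-side) cursors, which psycopg2 only permits inside a transaction, and merges the resulting generators lazily with heapq.merge so rows are handled one at a time. A minimal sketch of that pattern, assuming psycopg2 as the driver and using hypothetical table and column names:

	import heapq

	import psycopg2  # assumed driver; conn.cursor(name) is psycopg2's named-cursor API

	def stream_rows(conn, name, query, params=None):
		# A named cursor is server-side: rows arrive in batches of
		# cursor.itersize instead of being materialized by fetchall().
		cursor = conn.cursor(name)
		cursor.itersize = 2000
		cursor.execute(query, params)
		yield from cursor

	def process(conn):
		# Named cursors only exist inside a transaction, so the consumer
		# wraps the whole traversal in one.
		with conn:
			merged = heapq.merge(
				stream_rows(conn, "a", "select n from table_a order by n"),
				stream_rows(conn, "b", "select n from table_b order by n"),
			)
			for (n,) in merged:
				...  # handle one row at a time; memory stays bounded

	conn = psycopg2.connect("dbname=example")  # hypothetical DSN
	process(conn)
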
disk_jumble/src/disk_jumble/verify.py

index b1df745c8e3a984734237d7a5dbf59711723474c..12c7d9dc703097669426874f9f7b9c8259e6214b 100644
@@ -1,6 +1,7 @@
 from __future__ import annotations
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Iterator, List, Optional
+from warnings import warn
 import argparse
 import contextlib
 import hashlib
@@ -67,7 +68,8 @@ def _run_sort_key(run: _Run):
 def _get_target_ranges(conn, disk_id: int, limit: Optional[int]) -> List[range]:
        ranges = []
        block_count = 0
-       with conn.cursor() as cursor:
+       with conn:
+               cursor = conn.cursor("target_ranges")
                cursor.execute(
                        """
                                select unnest(written_map - coalesce(verified_map, int8multirange())) as range
@@ -105,7 +107,7 @@ def _get_v1_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> L
 
        The runs are then reordered by the number of their first disk block.
        """
-       cursor = conn.cursor()
+       cursor = conn.cursor("v1_worklist")
        cursor.execute(
                """
                        select distinct on (entity_id)
@@ -206,8 +208,8 @@ def _get_v1_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> L
        return runs
 
 
-def _get_v2_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> List[_V2Run]:
-       cursor = conn.cursor()
+def _get_v2_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> Iterator[_V2Run]:
+       cursor = conn.cursor("v2_worklist")
        cursor.execute(
                """
                        with
@@ -259,85 +261,86 @@ def _get_v2_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> L
                        "target_ranges": target_ranges,
                }
        )
-       rows = cursor.fetchall()
-
-       if any(crypt_key is not None for (*_, crypt_key) in rows):
-               raise NotImplementedError("verify of encrypted data")
-
-       return [
-               _V2Run(
-                       entity_id = bytes(entity_id),
-                       entity_length = entity_length,
-                       piece_num = piece_num,
-                       block_ranges = [range(block, block + 1)],
-                       hash = bytes(hash_),
-               )
-               for (entity_id, entity_length, piece_num, block, hash_, _) in rows
-       ]
 
+       for (entity_id, entity_length, piece_num, block, hash_, crypt_key) in cursor:
+               if crypt_key is None:
+                       yield _V2Run(
+                               entity_id = bytes(entity_id),
+                               entity_length = entity_length,
+                               piece_num = piece_num,
+                               block_ranges = [range(block, block + 1)],
+                               hash = bytes(hash_),
+                       )
+               else:
+                       raise NotImplementedError("verify of encrypted data")
 
-def _do_verify(conn, disk_id: int, target_ranges: List[range], disk_file: io.BufferedIOBase, read_size: int, read_tries: int):
-       pg_target_ranges = [NumericRange(r.start, r.stop) for r in target_ranges]
-       worklist = list(heapq.merge(
-               _get_v1_worklist(conn, disk_id, pg_target_ranges),
-               _get_v2_worklist(conn, disk_id, pg_target_ranges),
-               key = _run_sort_key,
-       ))
 
+def _do_verify(conn, disk_id: int, target_ranges: List[range], disk_file: io.BufferedIOBase, read_size: int, read_tries: int):
        requested_blocks = {
                block
                for r in target_ranges
                for block in r
        }
-       covered_blocks = {
-               block
-               for run in worklist
-               for block_range in run.block_ranges
-               for block in block_range
-       }
-       missing = requested_blocks - covered_blocks
-       if missing:
-               raise RuntimeError(f"unable to locate blocks: {len(missing)} in the range {min(missing)} to {max(missing)}")
-
-       passes = []
-       fails = []
-       for run in worklist:
-               if isinstance(run, _V1Run):
-                       hasher = hashlib.sha1()
-                       entity_off = run.piece_num * run.piece_length
-               else:
-                       hasher = hashlib.sha256()
-                       entity_off = run.piece_num * BLOCK_SIZE
-               assert len(run.hash) == hasher.digest_size, "incorrect validation hash length"
-
-               try:
-                       for range_ in run.block_ranges:
-                               for block in range_:
-                                       read_start = block * BLOCK_SIZE
-                                       read_end = read_start + min(BLOCK_SIZE, run.entity_length - entity_off)
-                                       disk_file.seek(read_start)
-                                       while disk_file.tell() < read_end:
-                                               pos = disk_file.tell()
-                                               for _ in range(read_tries):
-                                                       try:
-                                                               data = disk_file.read(min(read_end - pos, read_size))
-                                                       except OSError:
-                                                               disk_file.seek(pos)
-                                                       else:
-                                                               break
-                                               else:
-                                                       raise _BadSector()
-
-                                               assert data
-                                               hasher.update(data)
-                                       entity_off += BLOCK_SIZE
-               except _BadSector:
-                       fails.extend(run.block_ranges)
-               else:
-                       if hasher.digest() == run.hash:
-                               passes.extend(run.block_ranges)
+       pg_target_ranges = [NumericRange(r.start, r.stop) for r in target_ranges]
+
+       # transaction is required for named cursors in worklist generation
+       with conn:
+               worklist = heapq.merge(
+                       _get_v1_worklist(conn, disk_id, pg_target_ranges),
+                       _get_v2_worklist(conn, disk_id, pg_target_ranges),
+                       key = _run_sort_key,
+               )
+
+               covered_blocks = set()
+               passes = []
+               fails = []
+               for run in worklist:
+                       if isinstance(run, _V1Run):
+                               hasher = hashlib.sha1()
+                               entity_off = run.piece_num * run.piece_length
                        else:
+                               hasher = hashlib.sha256()
+                               entity_off = run.piece_num * BLOCK_SIZE
+                       assert len(run.hash) == hasher.digest_size, "incorrect validation hash length"
+
+                       try:
+                               for range_ in run.block_ranges:
+                                       for block in range_:
+                                               read_start = block * BLOCK_SIZE
+                                               read_end = read_start + min(BLOCK_SIZE, run.entity_length - entity_off)
+                                               disk_file.seek(read_start)
+                                               while disk_file.tell() < read_end:
+                                                       pos = disk_file.tell()
+                                                       for _ in range(read_tries):
+                                                               try:
+                                                                       data = disk_file.read(min(read_end - pos, read_size))
+                                                               except OSError:
+                                                                       disk_file.seek(pos)
+                                                               else:
+                                                                       break
+                                                       else:
+                                                               raise _BadSector()
+
+                                                       assert data
+                                                       hasher.update(data)
+                                               entity_off += BLOCK_SIZE
+                       except _BadSector:
                                fails.extend(run.block_ranges)
+                       else:
+                               if hasher.digest() == run.hash:
+                                       passes.extend(run.block_ranges)
+                               else:
+                                       fails.extend(run.block_ranges)
+
+                       covered_blocks.update(
+                               block
+                               for block_range in run.block_ranges
+                               for block in block_range
+                       )
+
+       missing = requested_blocks - covered_blocks
+       if missing:
+               warn(f"unable to locate blocks: {len(missing)} in the range {min(missing)} to {max(missing)}")
 
        def clean_up(ranges):
                out = []