Rework DJ verify v1 to handle weird piece lengths
author: Jakob Cornell <jakob+gpg@jcornell.net>
Thu, 30 Jun 2022 02:51:55 +0000 (21:51 -0500)
committer: Jakob Cornell <jakob+gpg@jcornell.net>
Thu, 30 Jun 2022 02:51:55 +0000 (21:51 -0500)
disk_jumble/src/disk_jumble/verify.py

index 2ed32f245d2bc475231c590aef185cdd38333398..6b309bc903caed250dad092afcc5d763b679d77e 100644 (file)
@@ -8,7 +8,6 @@ import hashlib
 import heapq
 import io
 import itertools
-import math
 
 from psycopg2.extras import NumericRange
 import psycopg2
@@ -39,12 +38,18 @@ class _TorrentInfo:
                return cls(torrent_id, length, info[b"piece length"], hashes)
 
 
+@dataclass
+class _BlockSpan:
+       blocks: range
+       initial_offset: int
+
+
 @dataclass
 class _Run:
        entity_id: bytes
        entity_length: int
        piece_num: int
-       block_ranges: List[range]
+       block_spans: List[_BlockSpan]
        hash: bytes
 
 
@@ -62,7 +67,7 @@ class _BadSector(Exception):
 
 
 def _run_sort_key(run: _Run):
-       return run.block_ranges[0].start
+       return run.block_spans[0].blocks.start
 
 
 def _get_target_ranges(conn, disk_id: int, limit: Optional[int]) -> List[range]:
@@ -92,20 +97,12 @@ def _get_target_ranges(conn, disk_id: int, limit: Optional[int]) -> List[range]:
        return ranges
 
 
-def _get_v1_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> List[_V1Run]:
+def _get_v1_worklist(conn, disk_id: int, target_ranges: List[NumericRange], block_size: int) -> List[_V1Run]:
        """
        How this works: First, we fetch some info about each torrent on the disk that has data within the requested blocks.
        Then, we make one query per torrent (because each one may have a different piece size) to determine which disk
-       blocks to verify.  Here's how that works for a given torrent:
-
-       1. Select slabs on the current disk that have data from the appropriate torrent.
-       2. Join with the list of requested block ranges to obtain which slabs and what portions of each are to be verified.
-       3. Convert each requested range to a set of piece numbers and deduplicate.
-       4. Join again with the slab table to get the complete block ranges required to cover each requested piece.
-       5. Order by piece number (to facilitate grouping by piece), then entity offset (to produce an ordered run for each
-       piece).
-
-       The runs are then reordered by the number of their first disk block.
+       blocks to verify. The runs are reordered in Python by the number of their first disk block, so the progress of the
+       run can be easily monitored.
        """
        cursor = conn.cursor("v1_worklist")
        cursor.execute(
@@ -126,7 +123,10 @@ def _get_v1_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> L
        ]
 
        for i in infos:
-               if i.piece_length % BLOCK_SIZE != 0:
+               # precondition for disk read logic below
+               assert i.piece_length >= block_size
+
+               if i.piece_length % block_size != 0:
                        warn(f"entity {i.id.hex()} has invalid piece length")
 
        runs = []
@@ -136,10 +136,12 @@ def _get_v1_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> L
                        """
                                with
                                        relevant_slab as (
+                                               -- select slabs on this disk
                                                select * from diskjumble.slab
                                                where disk_id = %(disk_id)s and entity_id = %(entity_id)s
                                        ),
                                        targeted_slab as (
+                                               -- select slabs that cover our requested block range
                                                select relevant_slab.*, target_range * disk_blocks as target_span
                                                from
                                                        relevant_slab
@@ -147,40 +149,54 @@ def _get_v1_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> L
                                                                on target_range && disk_blocks
                                        ),
                                        targeted_piece as (
+                                               -- select piece numbers that cover those slabs
                                                select distinct
-                                                       (
-                                                               entity_offset
-                                                               (
-                                                                       (generate_series(lower(target_span), upper(target_span) - 1) - lower(disk_blocks))
-                                                                       * %(block_size)s
-                                                               )
-                                                       ) / %(piece_length)s
+                                                       generate_series(
+                                                               (entity_offset + (lower(target_span) - lower(disk_blocks)) * %(block_size)s) / %(piece_length)s,
+                                                               ceil(
+                                                                       (entity_offset + (upper(target_span) - lower(disk_blocks)) * %(block_size)s)::numeric(38, 19)
+                                                                       / %(piece_length)s
+                                                               )::bigint - 1
+                                                       )
                                                                as piece_num
                                                from targeted_slab
                                        ),
-                                       out as (
+                                       out_a as (
+                                               -- select slab info for each piece we're checking
                                                select
                                                        entity_offset,
                                                        disk_blocks,
                                                        piece_num,
-                                                       disk_blocks * int8range(
-                                                               (piece_num * %(piece_length)s - entity_offset) / %(block_size)s + lower(disk_blocks),
-                                                               ((piece_num + 1) * %(piece_length)s - entity_offset) / %(block_size)s + lower(disk_blocks)
-                                                       ) as use_blocks,
+                                                       int8range(
+                                                               piece_num * %(piece_length)s - entity_offset,
+                                                               (piece_num + 1) * %(piece_length)s - entity_offset
+                                                       ) as range_in_slab,  -- byte range of piece within slab
                                                        crypt_key
                                                from relevant_slab cross join targeted_piece
+                                               -- the join condition is the 'where' of the main query
+                                       ),
+                                       out_b as (
+                                               select
+                                                       entity_offset, disk_blocks, piece_num, crypt_key,
+                                                       disk_blocks * int8range(
+                                                               lower(range_in_slab) / %(block_size)s + lower(disk_blocks),
+                                                               ceil(upper(range_in_slab)::numeric(38, 19) / %(block_size)s)::bigint + lower(disk_blocks)
+                                                       ) as use_blocks,
+                                                       case when lower(range_in_slab) >= 0 then lower(range_in_slab) %% %(block_size)s else 0 end
+                                                               as initial_offset
+                                               from out_a
                                        )
-                               select piece_num, use_blocks, crypt_key from out
+                               select piece_num, use_blocks, initial_offset, crypt_key from out_b
                                where not isempty(use_blocks)
                                order by
-                                       piece_num,
-                                       entity_offset + (lower(use_blocks) - lower(disk_blocks)) * %(block_size)s
+                                       piece_num,  -- facilitate grouping by piece
+                                       entity_offset + (lower(use_blocks) - lower(disk_blocks)) * %(block_size)s  -- ordered run within each piece
                        """,
                        {
                                "disk_id": disk_id,
                                "entity_id": info.id,
                                "target_ranges": target_ranges,
-                               "block_size": BLOCK_SIZE,
+                               "block_size": block_size,
                                "piece_length": info.piece_length,
                        }
                )
@@ -190,27 +206,30 @@ def _get_v1_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> L
                        raise NotImplementedError("verify of encrypted data")
 
                for (piece_num, piece_rows) in itertools.groupby(rows, lambda t: t[0]):
-                       block_ranges = [range(r.lower, r.upper) for (_, r, _) in piece_rows]
+                       block_spans = [
+                               _BlockSpan(range(r.lower, r.upper), off)
+                               for (_, r, off, _) in piece_rows
+                       ]
                        run = _V1Run(
                                entity_id = info.id,
                                entity_length = info.length,
                                piece_length = info.piece_length,
                                piece_num = piece_num,
-                               block_ranges = block_ranges,
+                               block_spans = block_spans,
                                hash = info.hashes[piece_num],
                        )
 
                        # Make sure we ended up with all the block ranges for the piece.  If not, assume the piece isn't fully
                        # written yet and skip.
                        run_length = min(info.piece_length, info.length - piece_num * info.piece_length)
-                       if sum(map(len, run.block_ranges)) == math.ceil(run_length / BLOCK_SIZE):
+                       if sum(len(s.blocks) * block_size - s.initial_offset for s in run.block_spans) >= run_length:
                                runs.append(run)
 
        runs.sort(key = _run_sort_key)
        return runs
 
 
-def _get_v2_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> Iterator[_V2Run]:
+def _get_v2_worklist(conn, disk_id: int, target_ranges: List[NumericRange], block_size: int) -> Iterator[_V2Run]:
        cursor = conn.cursor("v2_worklist")
        cursor.execute(
                """
@@ -258,7 +277,7 @@ def _get_v2_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> I
                        order by block
                """,
                {
-                       "block_size": BLOCK_SIZE,
+                       "block_size": block_size,
                        "disk_id": disk_id,
                        "target_ranges": target_ranges,
                }
@@ -270,21 +289,24 @@ def _get_v2_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> I
                                entity_id = bytes(entity_id),
                                entity_length = entity_length,
                                piece_num = piece_num,
-                               block_ranges = [range(block, block + 1)],
+                               block_spans = [_BlockSpan(range(block, block + 1), 0)],
                                hash = bytes(hash_),
                        )
                else:
                        raise NotImplementedError("verify of encrypted data")
 
 
-def _do_verify(conn, disk_id: int, target_ranges: List[range], disk_file: io.BufferedIOBase, read_size: int, read_tries: int):
+def _do_verify(
+       conn, disk_id: int, target_ranges: List[range], disk_file: io.BufferedIOBase, read_size: int, read_tries: int,
+       block_size: int
+):
        pg_target_ranges = [NumericRange(r.start, r.stop) for r in target_ranges]
 
        # transaction is required for named cursors in worklist generation
        with conn:
                worklist = heapq.merge(
-                       _get_v1_worklist(conn, disk_id, pg_target_ranges),
-                       _get_v2_worklist(conn, disk_id, pg_target_ranges),
+                       _get_v1_worklist(conn, disk_id, pg_target_ranges, block_size),
+                       _get_v2_worklist(conn, disk_id, pg_target_ranges, block_size),
                        key = _run_sort_key,
                )
 
@@ -292,18 +314,23 @@ def _do_verify(conn, disk_id: int, target_ranges: List[range], disk_file: io.Buf
                fails = []
                for run in worklist:
                        if isinstance(run, _V1Run):
+                               piece_length = run.piece_length
                                hasher = hashlib.sha1()
-                               entity_off = run.piece_num * run.piece_length
                        else:
+                               piece_length = block_size
                                hasher = hashlib.sha256()
-                               entity_off = run.piece_num * BLOCK_SIZE
                        assert len(run.hash) == hasher.digest_size, "incorrect validation hash length"
 
+                       entity_off = run.piece_num * piece_length
+                       piece_end_ent = min(entity_off + piece_length, run.entity_length)
+                       block_ranges = [s.blocks for s in run.block_spans]
                        try:
-                               for range_ in run.block_ranges:
-                                       for block in range_:
-                                               read_start = block * BLOCK_SIZE
-                                               read_end = read_start + min(BLOCK_SIZE, run.entity_length - entity_off)
+                               for span in run.block_spans:
+                                       for (i, block) in enumerate(span.blocks):
+                                               block_offset = span.initial_offset if i == 0 else 0
+                                               read_start = block * block_size + block_offset
+                                               read_size = min(block_size - block_offset, piece_end_ent - entity_off)
+                                               read_end = read_start + read_size
                                                disk_file.seek(read_start)
                                                while disk_file.tell() < read_end:
                                                        pos = disk_file.tell()
@@ -319,14 +346,14 @@ def _do_verify(conn, disk_id: int, target_ranges: List[range], disk_file: io.Buf
 
                                                        assert data
                                                        hasher.update(data)
-                                               entity_off += BLOCK_SIZE
+                                               entity_off += read_size
                        except _BadSector:
-                               fails.extend(run.block_ranges)
+                               fails.extend(block_ranges)
                        else:
                                if hasher.digest() == run.hash:
-                                       passes.extend(run.block_ranges)
+                                       passes.extend(block_ranges)
                                else:
-                                       fails.extend(run.block_ranges)
+                                       fails.extend(block_ranges)
 
        def clean_up(ranges):
                out = []
@@ -401,4 +428,4 @@ if __name__ == "__main__":
        with contextlib.closing(psycopg2.connect("")) as conn:
                target_ranges = _get_target_ranges(conn, args.disk_id, args.block_limit)
                with open(path, "rb", buffering = _READ_BUFFER_SIZE) as disk_file:
-                       _do_verify(conn, args.disk_id, target_ranges, disk_file, _READ_BUFFER_SIZE, args.read_tries)
+                       _do_verify(conn, args.disk_id, target_ranges, disk_file, _READ_BUFFER_SIZE, args.read_tries, BLOCK_SIZE)