Make verify v1 compatible with mixed v1/v2 disks
author	Jakob Cornell <jakob+gpg@jcornell.net>
Mon, 21 Feb 2022 20:21:36 +0000 (14:21 -0600)
committer	Jakob Cornell <jakob+gpg@jcornell.net>
Mon, 21 Feb 2022 20:29:58 +0000 (14:29 -0600)
disk_jumble/src/disk_jumble/db.py
disk_jumble/src/disk_jumble/verify.py

index 66e75627aba227d5e2d48fb109337b38225636b6..278d12dc6444eb823fa44757d1b0e1a3fbffeda9 100644
--- a/disk_jumble/src/disk_jumble/db.py
+++ b/disk_jumble/src/disk_jumble/db.py
@@ -69,12 +69,11 @@ class Wrapper:
                        ]
                        execute_batch(cursor, stmt, param_sets)
 
-       def get_slabs_and_hashers(self, disk_id: int, sector_size: int) -> Iterable[tuple[Slab, Optional[HasherRef]]]:
+       def get_v1_worklist(self, disk_id: int, sector_size: int) -> Iterable[tuple[Slab, bytes, Optional[HasherRef]]]:
                """
-               Find slabs on the specified disk, and for each also return any lowest-offset saved hasher that left off directly
-               before or within the slab's entity data.
+               Find slabs on the specified disk along with corresponding bencoded info dicts, and for each also return any
+               lowest-offset saved hasher that left off directly before or within the slab's entity data.
                """
-
                stmt = """
                        with
                                incomplete_edge as (
@@ -94,10 +93,11 @@ class Wrapper:
                                        where seq >= all (select seq from diskjumble.verify_piece_content where verify_id = p.verify_id)
                                )
                        select
-                               slab_id, disk_id, disk_blocks, slab.entity_id, entity_offset, crypt_key, verify_id, seq, end_off,
-                               hasher_state
+                               slab_id, disk_id, disk_blocks, slab.entity_id, torrent_info.info, entity_offset, crypt_key, verify_id,
+                               seq, end_off, hasher_state
                        from
                                diskjumble.slab
+                               join bittorrent.torrent_info on digest(info, 'sha1') = slab.entity_id
                                natural left join diskjumble.disk
                                left join incomplete_edge on
                                        incomplete_edge.entity_id = slab.entity_id
@@ -114,7 +114,7 @@ class Wrapper:
                        cursor.execute(stmt, {"disk_id": disk_id, "sector_size": sector_size})
                        for (_, rows_iter) in itertools.groupby(cursor, lambda r: r[0]):
                                rows = list(rows_iter)
-                               [(slab_id, disk_id, sectors_pg, entity_id, entity_off, key_mem)] = {r[:6] for r in rows}
+                               [(slab_id, disk_id, sectors_pg, entity_id, info_mem, entity_off, key_mem)] = {r[:7] for r in rows}
                                sectors = range(sectors_pg.lower, sectors_pg.upper)
                                key = None if key_mem is None else bytes(key_mem)
                                slab = Slab(slab_id, disk_id, sectors, bytes(entity_id), entity_off, key)
@@ -123,25 +123,7 @@ class Wrapper:
                                (*_, id_, seq, end_off, state) = min(rows, key = lambda r: r[-2])
                                hasher_ref = None if id_ is None else HasherRef(id_, seq, end_off, state)
 
-                               yield (slab, hasher_ref)
-
-       def get_torrent_info(self, disk_id: int) -> Iterable[tuple[bytes, bytes]]:
-               stmt = """
-                       with hashed as (
-                               select digest(info, 'sha1') as info_hash, info
-                               from bittorrent.torrent_info
-                       )
-                       select
-                               distinct on (info_hash)
-                               info_hash, info
-                               from diskjumble.slab left outer join hashed on entity_id = info_hash
-                               where disk_id = %s
-                       ;
-               """
-               with self.conn.cursor() as cursor:
-                       cursor.execute(stmt, (disk_id,))
-                       for (info_hash, info) in cursor:
-                               yield (bytes(info_hash), bytes(info))
+                               yield (slab, bytes(info_mem), hasher_ref)
 
        def insert_verify_piece(self, ts: dt.datetime, entity_id: bytes, piece_num: int) -> int:
                """Insert new verify piece, returning the ID of the inserted row."""
index c63f8768b60c92293a3541825d33dc8f3cd351b5..63f9ffb0d227d5ad59d92574831e705c6439bccc 100644
--- a/disk_jumble/src/disk_jumble/verify.py
+++ b/disk_jumble/src/disk_jumble/verify.py
@@ -33,6 +33,7 @@ class _SlabChunk:
 class _PieceTask:
        """The chunks needed to hash as fully as possible an entity piece."""
        entity_id: bytes
+       info_dict: bencode.Bdict
        piece_num: int
        hasher_ref: Optional[HasherRef]
        chunks: list[_SlabChunk]
@@ -46,15 +47,13 @@ class _BadSector(Exception):
 def do_verify(conn, disk_id: int, sector_size: int, disk_file: io.BufferedIOBase, read_size: int, read_tries: int) -> None:
        db = DbWrapper(conn)
 
-       info_dicts = {
-               info_hash: bencode.decode(info)
-               for (info_hash, info) in db.get_torrent_info(disk_id)
-       }
-
        tasks = []
-       slabs_and_hashers = db.get_slabs_and_hashers(disk_id, sector_size)
-       for (entity_id, group) in itertools.groupby(slabs_and_hashers, lambda t: t[0].entity_id):
-               info = info_dicts[entity_id]
+       worklist = db.get_v1_worklist(disk_id, sector_size)
+       for (entity_id, group_iter) in itertools.groupby(worklist, lambda t: t[0].entity_id):
+               group = list(group_iter)
+
+               [info_bytes] = {v for (_, v, _) in group}
+               info = bencode.decode(info_bytes)
                piece_len = info[b"piece length"]
                assert piece_len % sector_size == 0
                if b"length" in info:
@@ -65,7 +64,7 @@ def do_verify(conn, disk_id: int, sector_size: int, disk_file: io.BufferedIOBase
                offset = None
                use_hasher = None
                chunks = []
-               for (slab, hasher_ref) in group:
+               for (slab, _, hasher_ref) in group:
                        if slab.crypt_key is not None:
                                raise NotImplementedError("verify of encrypted data")
 
@@ -73,7 +72,7 @@ def do_verify(conn, disk_id: int, sector_size: int, disk_file: io.BufferedIOBase
                        while offset is None or offset < slab_end:
                                if offset is not None and slab.entity_offset > offset:
                                        if chunks:
-                                               tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, False))
+                                               tasks.append(_PieceTask(entity_id, info, offset // piece_len, use_hasher, chunks, False))
                                        offset = None
                                        use_hasher = None
                                        chunks = []
@@ -94,13 +93,13 @@ def do_verify(conn, disk_id: int, sector_size: int, disk_file: io.BufferedIOBase
                                        chunk_end = min(piece_end, slab_end)
                                        chunks.append(_SlabChunk(slab, slice(offset - slab.entity_offset, chunk_end - slab.entity_offset)))
                                        if chunk_end == piece_end:
-                                               tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, True))
+                                               tasks.append(_PieceTask(entity_id, info, offset // piece_len, use_hasher, chunks, True))
                                                use_hasher = None
                                                chunks = []
                                        offset = chunk_end
 
                if chunks:
-                       tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, False))
+                       tasks.append(_PieceTask(entity_id, info, offset // piece_len, use_hasher, chunks, False))
 
        @dataclass
        class NewVerifyPiece:
@@ -162,7 +161,7 @@ def do_verify(conn, disk_id: int, sector_size: int, disk_file: io.BufferedIOBase
                        hasher_state = hasher.ctx.serialize()
                        if task.complete:
                                s = slice(task.piece_num * 20, task.piece_num * 20 + 20)
-                               expected_hash = info_dicts[task.entity_id][b"pieces"][s]
+                               expected_hash = task.info_dict[b"pieces"][s]
                                if hasher.digest() == expected_hash:
                                        write_piece_data = False
                                        new_pass_ranges.extend(sector_ranges)
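
With the info dict now carried on each _PieceTask (and each worklist group asserting exactly one info blob per entity via the single-element set destructuring), the final comparison slices the expected digest out of the task's own b"pieces" blob instead of consulting the removed info_dicts map. A hedged sketch of that v1 check in isolation, using hashlib.sha1 in place of verify.py's resumable hasher; expected_piece_hash and piece_matches are illustrative stand-ins, not functions from the module:

import hashlib

def expected_piece_hash(info_dict: dict, piece_num: int) -> bytes:
    # b"pieces" in a v1 info dict is a flat blob of 20-byte SHA-1 digests,
    # one per piece, so piece N's hash lives at byte offset N * 20.
    return bytes(info_dict[b"pieces"][piece_num * 20 : piece_num * 20 + 20])

def piece_matches(info_dict: dict, piece_num: int, piece_data: bytes) -> bool:
    # Same comparison as the hunk above, minus hasher state serialization.
    return hashlib.sha1(piece_data).digest() == expected_piece_hash(info_dict, piece_num)
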