From 53f00d848ed2f0acb2530a0e54b3bc794082a15d Mon Sep 17 00:00:00 2001
From: Jakob Cornell
Date: Mon, 21 Feb 2022 14:21:36 -0600
Subject: [PATCH] Make verify v1 compatible with mixed v1/v2 disks

---
 disk_jumble/src/disk_jumble/db.py     | 34 +++++++--------------------
 disk_jumble/src/disk_jumble/verify.py | 25 ++++++++++----------
 2 files changed, 20 insertions(+), 39 deletions(-)

diff --git a/disk_jumble/src/disk_jumble/db.py b/disk_jumble/src/disk_jumble/db.py
index 66e7562..278d12d 100644
--- a/disk_jumble/src/disk_jumble/db.py
+++ b/disk_jumble/src/disk_jumble/db.py
@@ -69,12 +69,11 @@ class Wrapper:
 			]
 			execute_batch(cursor, stmt, param_sets)
 
-	def get_slabs_and_hashers(self, disk_id: int, sector_size: int) -> Iterable[tuple[Slab, Optional[HasherRef]]]:
+	def get_v1_worklist(self, disk_id: int, sector_size: int) -> Iterable[tuple[Slab, bytes, Optional[HasherRef]]]:
 		"""
-		Find slabs on the specified disk, and for each also return any lowest-offset saved hasher that left off directly
-		before or within the slab's entity data.
+		Find slabs on the specified disk along with corresponding bencoded info dicts, and for each also return any
+		lowest-offset saved hasher that left off directly before or within the slab's entity data.
 		"""
-
 		stmt = """
 			with
 				incomplete_edge as (
@@ -94,10 +93,11 @@ class Wrapper:
 					where seq >= all (select seq from diskjumble.verify_piece_content where verify_id = p.verify_id)
 				)
 			select
-				slab_id, disk_id, disk_blocks, slab.entity_id, entity_offset, crypt_key, verify_id, seq, end_off,
-				hasher_state
+				slab_id, disk_id, disk_blocks, slab.entity_id, torrent_info.info, entity_offset, crypt_key, verify_id,
+				seq, end_off, hasher_state
 			from
 				diskjumble.slab
+				join bittorrent.torrent_info on digest(info, 'sha1') = slab.entity_id
 				natural left join diskjumble.disk
 				left join incomplete_edge
 					on incomplete_edge.entity_id = slab.entity_id
@@ -114,7 +114,7 @@ class Wrapper:
 			cursor.execute(stmt, {"disk_id": disk_id, "sector_size": sector_size})
 			for (_, rows_iter) in itertools.groupby(cursor, lambda r: r[0]):
 				rows = list(rows_iter)
-				[(slab_id, disk_id, sectors_pg, entity_id, entity_off, key_mem)] = {r[:6] for r in rows}
+				[(slab_id, disk_id, sectors_pg, entity_id, info_mem, entity_off, key_mem)] = {r[:7] for r in rows}
 				sectors = range(sectors_pg.lower, sectors_pg.upper)
 				key = None if key_mem is None else bytes(key_mem)
 				slab = Slab(slab_id, disk_id, sectors, bytes(entity_id), entity_off, key)
@@ -123,25 +123,7 @@ class Wrapper:
 					(*_, id_, seq, end_off, state) = min(rows, key = lambda r: r[-2])
 					hasher_ref = None if id_ is None else HasherRef(id_, seq, end_off, state)
 
-				yield (slab, hasher_ref)
-
-	def get_torrent_info(self, disk_id: int) -> Iterable[tuple[bytes, bytes]]:
-		stmt = """
-			with hashed as (
-				select digest(info, 'sha1') as info_hash, info
-				from bittorrent.torrent_info
-			)
-			select
-				distinct on (info_hash)
-				info_hash, info
-			from diskjumble.slab left outer join hashed on entity_id = info_hash
-			where disk_id = %s
-			;
-		"""
-		with self.conn.cursor() as cursor:
-			cursor.execute(stmt, (disk_id,))
-			for (info_hash, info) in cursor:
-				yield (bytes(info_hash), bytes(info))
+				yield (slab, bytes(info_mem), hasher_ref)
 
 	def insert_verify_piece(self, ts: dt.datetime, entity_id: bytes, piece_num: int) -> int:
 		"""Insert new verify piece, returning the ID of the inserted row."""
diff --git a/disk_jumble/src/disk_jumble/verify.py b/disk_jumble/src/disk_jumble/verify.py
index c63f876..63f9ffb 100644
--- a/disk_jumble/src/disk_jumble/verify.py
+++ b/disk_jumble/src/disk_jumble/verify.py
@@ -33,6 +33,7 @@ class _SlabChunk:
 class _PieceTask:
 	"""The chunks needed to hash as fully as possible an entity piece."""
 	entity_id: bytes
+	info_dict: bencode.Bdict
 	piece_num: int
 	hasher_ref: Optional[HasherRef]
 	chunks: list[_SlabChunk]
@@ -46,15 +47,13 @@ class _BadSector(Exception):
 
 def do_verify(conn, disk_id: int, sector_size: int, disk_file: io.BufferedIOBase, read_size: int, read_tries: int) -> None:
 	db = DbWrapper(conn)
-	info_dicts = {
-		info_hash: bencode.decode(info)
-		for (info_hash, info) in db.get_torrent_info(disk_id)
-	}
-
 	tasks = []
-	slabs_and_hashers = db.get_slabs_and_hashers(disk_id, sector_size)
-	for (entity_id, group) in itertools.groupby(slabs_and_hashers, lambda t: t[0].entity_id):
-		info = info_dicts[entity_id]
+	worklist = db.get_v1_worklist(disk_id, sector_size)
+	for (entity_id, group_iter) in itertools.groupby(worklist, lambda t: t[0].entity_id):
+		group = list(group_iter)
+
+		[info_bytes] = {v for (_, v, _) in group}
+		info = bencode.decode(info_bytes)
 		piece_len = info[b"piece length"]
 		assert piece_len % sector_size == 0
 		if b"length" in info:
@@ -65,7 +64,7 @@ def do_verify(conn, disk_id: int, sector_size: int, disk_file: io.BufferedIOBase
 		offset = None
 		use_hasher = None
 		chunks = []
-		for (slab, hasher_ref) in group:
+		for (slab, _, hasher_ref) in group:
 			if slab.crypt_key is not None:
 				raise NotImplementedError("verify of encrypted data")
 
@@ -73,7 +72,7 @@ def do_verify(conn, disk_id: int, sector_size: int, disk_file: io.BufferedIOBase
 			while offset is None or offset < slab_end:
 				if offset is not None and slab.entity_offset > offset:
 					if chunks:
-						tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, False))
+						tasks.append(_PieceTask(entity_id, info, offset // piece_len, use_hasher, chunks, False))
 					offset = None
 					use_hasher = None
 					chunks = []
@@ -94,13 +93,13 @@ def do_verify(conn, disk_id: int, sector_size: int, disk_file: io.BufferedIOBase
 					chunk_end = min(piece_end, slab_end)
 					chunks.append(_SlabChunk(slab, slice(offset - slab.entity_offset, chunk_end - slab.entity_offset)))
 					if chunk_end == piece_end:
-						tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, True))
+						tasks.append(_PieceTask(entity_id, info, offset // piece_len, use_hasher, chunks, True))
 						use_hasher = None
 						chunks = []
 					offset = chunk_end
 
 		if chunks:
-			tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, False))
+			tasks.append(_PieceTask(entity_id, info, offset // piece_len, use_hasher, chunks, False))
 
 	@dataclass
 	class NewVerifyPiece:
@@ -162,7 +161,7 @@ def do_verify(conn, disk_id: int, sector_size: int, disk_file: io.BufferedIOBase
 			hasher_state = hasher.ctx.serialize()
 			if task.complete:
 				s = slice(task.piece_num * 20, task.piece_num * 20 + 20)
-				expected_hash = info_dicts[task.entity_id][b"pieces"][s]
+				expected_hash = task.info_dict[b"pieces"][s]
 				if hasher.digest() == expected_hash:
 					write_piece_data = False
 					new_pass_ranges.extend(sector_ranges)
-- 
2.30.2
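
Note on the new worklist shape, for reference (this trailer sits below the signature
delimiter, so git am ignores it): get_v1_worklist() now yields (slab, info_bytes,
hasher_ref) tuples, ordered so itertools.groupby can bucket rows per entity, with every
row for an entity carrying the same bencoded info dict. A minimal sketch of that
consumer-side contract follows; the Slab stub and the consume_worklist name are
hypothetical, only the tuple shape, the ordering assumption, and the single-element
unpack come from the patch.

import itertools
from dataclasses import dataclass

@dataclass
class Slab:
	# Hypothetical stub; the real Slab in disk_jumble carries more fields.
	slab_id: int
	entity_id: bytes

def consume_worklist(worklist):
	"""Group (slab, info_bytes, hasher_ref) rows per entity, as do_verify does."""
	for (entity_id, group_iter) in itertools.groupby(worklist, lambda t: t[0].entity_id):
		group = list(group_iter)
		# All rows for one entity must agree on the info dict; the
		# single-element unpack raises ValueError on any mismatch.
		[info_bytes] = {v for (_, v, _) in group}
		yield (entity_id, info_bytes, [hasher_ref for (_, _, hasher_ref) in group])

For example, consume_worklist([(Slab(1, b"e1"), b"info", None)]) yields one
(b"e1", b"info", [None]) group, while rows carrying conflicting info blobs for the
same entity fail fast rather than being verified against the wrong torrent.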