From: Jakob Cornell Date: Sat, 16 Apr 2022 21:43:29 +0000 (-0500) Subject: Verify tool v1/v2 unification refactor X-Git-Url: https://jcornell.net/gitweb/gitweb.cgi?a=commitdiff_plain;h=a4f0ed09ae34f1dbba14ef7a4578c1458cda3a9e;p=eros.git Verify tool v1/v2 unification refactor - Process pieces in disk order by first block number - Add command line argument to target specific disk blocks - Ignore pieces that aren't yet fully written to disk - Fix bug in v2 worklist generation validation hash computation - Wrap verification in a database transaction - Partial work to restore tests to an operational state --- diff --git a/disk_jumble/src/disk_jumble/__init__.py b/disk_jumble/src/disk_jumble/__init__.py index 4b7b490..8cf6f46 100644 --- a/disk_jumble/src/disk_jumble/__init__.py +++ b/disk_jumble/src/disk_jumble/__init__.py @@ -1 +1 @@ -SECTOR_SIZE = 16 * 1024 # in bytes +BLOCK_SIZE = 16 * 1024 # in bytes diff --git a/disk_jumble/src/disk_jumble/db.py b/disk_jumble/src/disk_jumble/db.py index 4a82eba..950f81d 100644 --- a/disk_jumble/src/disk_jumble/db.py +++ b/disk_jumble/src/disk_jumble/db.py @@ -1,28 +1,8 @@ from __future__ import annotations from dataclasses import dataclass from typing import Any, Iterable, Optional -import datetime as dt -import itertools -from psycopg2.extras import execute_batch, Json, NumericRange - - -@dataclass -class Slab: - id: int - disk_id: int - sectors: range - entity_id: bytes - entity_offset: int - crypt_key: Optional[bytes] - - -@dataclass -class HasherRef: - id: int - seq: int - entity_offset: int - state: dict +from psycopg2.extras import execute_batch @dataclass @@ -68,127 +48,3 @@ class Wrapper: for i in infos ] execute_batch(cursor, stmt, param_sets) - - def get_v1_worklist(self, disk_id: int, sector_size: int) -> Iterable[tuple[Slab, bytes, Optional[HasherRef]]]: - """ - Find slabs on the specified disk along with corresponding bencoded info dicts, and for each also return any - lowest-offset saved hasher that left off directly before or within the slab's entity data. 
- """ - stmt = """ - with - incomplete_edge as ( - -- join up incomplete piece info and precompute where the hasher left off within the entity - select - verify_id, seq, slab.entity_id, hasher_state, - entity_offset + (upper(c.disk_sectors) - lower(slab.disk_blocks)) * %(sector_size)s as end_off - from - diskjumble.verify_piece_incomplete - natural left join diskjumble.verify_piece p - natural join diskjumble.verify_piece_content c - natural left join diskjumble.disk - left join diskjumble.slab on ( - c.disk_id = slab.disk_id - and upper(c.disk_sectors) <@ int8range(lower(slab.disk_blocks), upper(slab.disk_blocks), '[]') - ) - where seq >= all (select seq from diskjumble.verify_piece_content where verify_id = p.verify_id) - ) - select - slab_id, disk_id, disk_blocks, slab.entity_id, torrent_info.info, entity_offset, crypt_key, verify_id, - seq, end_off, hasher_state - from - diskjumble.slab - join bittorrent.torrent_info on digest(info, 'sha1') = slab.entity_id - natural left join diskjumble.disk - left join incomplete_edge on - incomplete_edge.entity_id = slab.entity_id - and incomplete_edge.end_off <@ int8range( - slab.entity_offset, - slab.entity_offset + (upper(disk_blocks) - lower(disk_blocks)) * %(sector_size)s - ) - and (incomplete_edge.end_off - slab.entity_offset) %% %(sector_size)s = 0 - where disk_id = %(disk_id)s - order by disk_blocks - ; - """ - with self.conn.cursor() as cursor: - cursor.execute(stmt, {"disk_id": disk_id, "sector_size": sector_size}) - for (_, rows_iter) in itertools.groupby(cursor, lambda r: r[0]): - rows = list(rows_iter) - [(slab_id, disk_id, sectors_pg, entity_id, info_mem, entity_off, key_mem)] = {r[:7] for r in rows} - sectors = range(sectors_pg.lower, sectors_pg.upper) - key = None if key_mem is None else bytes(key_mem) - slab = Slab(slab_id, disk_id, sectors, bytes(entity_id), entity_off, key) - - # `None' if no hasher match in outer join, otherwise earliest match - (*_, id_, seq, end_off, state) = min(rows, key = lambda r: r[-2]) - hasher_ref = None if id_ is None else HasherRef(id_, seq, end_off, state) - - yield (slab, bytes(info_mem), hasher_ref) - - def insert_verify_piece(self, ts: dt.datetime, entity_id: bytes, piece_num: int) -> int: - """Insert new verify piece, returning the ID of the inserted row.""" - - with self.conn.cursor() as cursor: - stmt = "insert into diskjumble.verify_piece values (default, %s, %s, %s) returning verify_id;" - cursor.execute(stmt, (ts, entity_id, piece_num)) - [(row_id,)] = cursor.fetchall() - return row_id - - def insert_verify_piece_content(self, verify_id: int, seq_start: int, disk_id: int, ranges: Iterable[range]) -> None: - with self.conn.cursor() as cursor: - execute_batch( - cursor, - "insert into diskjumble.verify_piece_content values (%s, %s, %s, %s);", - [ - (verify_id, seq, disk_id, NumericRange(r.start, r.stop)) - for (seq, r) in enumerate(ranges, start = seq_start) - ] - ) - - def mark_verify_piece_failed(self, verify_id: int) -> None: - with self.conn.cursor() as cursor: - cursor.execute("insert into diskjumble.verify_piece_fail values (%s);", (verify_id,)) - - def upsert_hasher_state(self, verify_id: int, state: dict) -> None: - stmt = """ - insert into diskjumble.verify_piece_incomplete values (%s, %s) - on conflict (verify_id) do update set hasher_state = excluded.hasher_state - ; - """ - with self.conn.cursor() as cursor: - cursor.execute(stmt, (verify_id, Json(state))) - - def delete_verify_piece(self, verify_id: int) -> None: - with self.conn.cursor() as cursor: - cursor.execute("delete from 
diskjumble.verify_piece_incomplete where verify_id = %s;", (verify_id,)) - cursor.execute("delete from diskjumble.verify_piece_content where verify_id = %s;", (verify_id,)) - cursor.execute("delete from diskjumble.verify_piece where verify_id = %s", (verify_id,)) - - def move_piece_content_for_pass(self, verify_id: int) -> None: - stmt = """ - with content_out as ( - delete from diskjumble.verify_piece_content c - using diskjumble.verify_piece p - where ( - c.verify_id = p.verify_id - and p.verify_id = %s - ) - returning at, disk_id, disk_sectors - ) - insert into diskjumble.verify_pass (at, disk_id, disk_sectors) - select at, disk_id, disk_sectors from content_out - ; - """ - with self.conn.cursor() as cursor: - cursor.execute(stmt, (verify_id,)) - - def insert_pass_data(self, ts: dt.datetime, disk_id: int, sectors: range) -> None: - with self.conn.cursor() as cursor: - cursor.execute( - "insert into diskjumble.verify_pass values (default, %s, %s, %s);", - (ts, disk_id, NumericRange(sectors.start, sectors.stop)) - ) - - def clear_incomplete(self, verify_id: int) -> None: - with self.conn.cursor() as cursor: - cursor.execute("delete from diskjumble.verify_piece_incomplete where verify_id = %s;", (verify_id,)) diff --git a/disk_jumble/src/disk_jumble/nettle.py b/disk_jumble/src/disk_jumble/nettle.py deleted file mode 100644 index dfabbca..0000000 --- a/disk_jumble/src/disk_jumble/nettle.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Python wrappers for some of GnuTLS Nettle.""" - -from ctypes.util import find_library -from typing import Optional -import ctypes - - -_LIB = ctypes.CDLL(find_library("nettle")) - - -class _Sha1Defs: - _DIGEST_SIZE = 20 # in bytes - _BLOCK_SIZE = 64 # in bytes - _DIGEST_LENGTH = 5 - - _StateArr = ctypes.c_uint32 * _DIGEST_LENGTH - - _BlockArr = ctypes.c_uint8 * _BLOCK_SIZE - - -class Sha1Hasher(_Sha1Defs): - class Context(ctypes.Structure): - _fields_ = [ - ("state", _Sha1Defs._StateArr), - ("count", ctypes.c_uint64), - ("index", ctypes.c_uint), - ("block", _Sha1Defs._BlockArr), - ] - - @classmethod - def deserialize(cls, data): - return cls( - _Sha1Defs._StateArr(*data["state"]), - data["count"], - data["index"], - _Sha1Defs._BlockArr(*data["block"]), - ) - - def serialize(self): - return { - "state": list(self.state), - "count": self.count, - "index": self.index, - "block": list(self.block), - } - - @classmethod - def _new_context(cls): - ctx = cls.Context() - _LIB.nettle_sha1_init(ctypes.byref(ctx)) - return ctx - - def __init__(self, ctx_dict: Optional[dict]): - if ctx_dict: - self.ctx = self.Context.deserialize(ctx_dict) - else: - self.ctx = self._new_context() - - def update(self, data): - _LIB.nettle_sha1_update(ctypes.byref(self.ctx), len(data), data) - - def digest(self): - """Return the current digest and reset the hasher state.""" - out = (ctypes.c_uint8 * self._DIGEST_SIZE)() - _LIB.nettle_sha1_digest(ctypes.byref(self.ctx), self._DIGEST_SIZE, out) - return bytes(out) diff --git a/disk_jumble/src/disk_jumble/tests/test_verify_v1.py b/disk_jumble/src/disk_jumble/tests/test_verify_v1.py index bc18e28..10a468a 100644 --- a/disk_jumble/src/disk_jumble/tests/test_verify_v1.py +++ b/disk_jumble/src/disk_jumble/tests/test_verify_v1.py @@ -25,8 +25,7 @@ import psycopg2 import psycopg2.extras from disk_jumble import bencode -from disk_jumble.nettle import Sha1Hasher -from disk_jumble.verify import do_verify +from disk_jumble.verify import _do_verify _BUF_SIZE = 16 * 1024 ** 2 # in bytes @@ -64,117 +63,6 @@ class Tests(unittest.TestCase): def 
test_basic_fresh_verify_large_read_size(self): self._basic_fresh_verify_helper(128) - def test_resume_fragmentation_unaligned_end(self): - """ - Test a run where a cached hash state is used, a piece is split on disk, and the end of the torrent isn't - sector-aligned. - """ - sector_size = 16 - read_size = 8 - - other_disk = self._write_disk(1) - disk = self._write_disk(5) - with _random_file(60, Random(0), on_disk = False) as torrent_file: - torrent = _Torrent(torrent_file, piece_size = 64) - self._write_torrent(torrent) - with self._conn.cursor() as cursor: - cursor.executemany( - "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);", - [ - (other_disk.id, NumericRange(0, 1), torrent.info_hash, 0), - (disk.id, NumericRange(0, 2), torrent.info_hash, 16), - (disk.id, NumericRange(4, 5), torrent.info_hash, 48), - ] - ) - - # Prepare the saved hasher state by running a verify - do_verify(self._conn, other_disk.id, sector_size, torrent_file, read_size, read_tries = 1) - torrent_file.seek(0) - - cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;") - [(row_count,)] = cursor.fetchall() - self.assertEqual(row_count, 1) - - disk_file = io.BytesIO() - torrent_file.seek(sector_size) - disk_file.write(torrent_file.read(sector_size * 2)) - disk_file.seek(disk_file.tell() + sector_size * 2) - disk_file.write(torrent_file.read()) - disk_file.seek(0) - do_verify(self._conn, disk.id, sector_size, disk_file, read_size, read_tries = 1) - - # Check that there are no verify pieces in the database. Because of integrity constraints, this also - # guarantees there aren't any stray saved hasher states, failed verifies, or piece contents. - cursor.execute("select count(*) from diskjumble.verify_piece;") - [(row_count,)] = cursor.fetchall() - self.assertEqual(row_count, 0) - - cursor.execute("select disk_id, disk_sectors from diskjumble.verify_pass;") - self.assertEqual( - cursor.fetchall(), - [(other_disk.id, NumericRange(0, 1)), (disk.id, NumericRange(0, 2)), (disk.id, NumericRange(4, 5))] - ) - - def test_resume_no_completion(self): - """ - Test a run where a saved hasher state is used and the target disk has subsequent entity data but not the full - remainder of the piece. 
- """ - sector_size = 16 - read_size = 7 - piece_size = 64 - - other_disk = self._write_disk(1) - disk = self._write_disk(2) - with _random_file(piece_size, Random(0), on_disk = False) as torrent_file: - torrent = _Torrent(torrent_file, piece_size) - self._write_torrent(torrent) - with self._conn.cursor() as cursor: - cursor.executemany( - "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);", - [ - (other_disk.id, NumericRange(0, 1), torrent.info_hash, 0), - (disk.id, NumericRange(0, 2), torrent.info_hash, sector_size), - ] - ) - - do_verify(self._conn, other_disk.id, sector_size, torrent_file, read_size, read_tries = 1) - - cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;") - [(row_count,)] = cursor.fetchall() - self.assertEqual(row_count, 1) - - disk_file = io.BytesIO() - torrent_file.seek(sector_size) - disk_file.write(torrent_file.read(sector_size * 2)) - disk_file.seek(0) - do_verify(self._conn, disk.id, sector_size, disk_file, read_size, read_tries = 1) - - cursor.execute("select count(*) from diskjumble.verify_pass;") - [(row_count,)] = cursor.fetchall() - self.assertEqual(row_count, 0) - - cursor.execute("select entity_id, piece from diskjumble.verify_piece;") - [(entity_id, piece_num)] = cursor.fetchall() - self.assertEqual(bytes(entity_id), torrent.info_hash) - self.assertEqual(piece_num, 0) - - cursor.execute("select disk_id, disk_sectors from diskjumble.verify_piece_content;") - self.assertCountEqual( - cursor.fetchall(), - [(other_disk.id, NumericRange(0, 1)), (disk.id, NumericRange(0, 2))] - ) - - cursor.execute("select count(*) from diskjumble.verify_piece_fail;") - [(row_count,)] = cursor.fetchall() - self.assertEqual(row_count, 0) - - hasher = Sha1Hasher(None) - torrent_file.seek(0) - hasher.update(torrent_file.read(sector_size * 3)) - cursor.execute("select hasher_state from diskjumble.verify_piece_incomplete;") - self.assertEqual(cursor.fetchall(), [(hasher.ctx.serialize(),)]) - def test_ignore_hasher_beginning_on_disk(self): """ Test a run where a saved hasher state is available for use but isn't used due to the beginning of the piece @@ -216,59 +104,6 @@ class Tests(unittest.TestCase): cursor.execute("select disk_id, disk_sectors from diskjumble.verify_pass;") self.assertEqual(cursor.fetchall(), [(disk.id, NumericRange(0, disk.sector_count))]) - def test_ignore_hasher_unaligned(self): - """ - Test a run where a saved hasher isn't used because its entity data offset isn't sector-aligned on the target - disk. 
- - 0 16 32 48 64 80 96 112 128 - pieces: [-------------- 0 -------------] - other disk: [--][--][--][--][--] - disk: [------][------] - """ - piece_size = 128 - - other_disk = self._write_disk(5) - od_ss = 16 - disk = self._write_disk(2) - d_ss = 32 - with _random_file(piece_size, Random(0), on_disk = False) as torrent_file: - torrent = _Torrent(torrent_file, piece_size) - self._write_torrent(torrent) - with self._conn.cursor() as cursor: - cursor.executemany( - "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);", - [ - (other_disk.id, NumericRange(0, 5), torrent.info_hash, 0), - (disk.id, NumericRange(0, 2), torrent.info_hash, 64), - ] - ) - - do_verify(self._conn, other_disk.id, od_ss, torrent_file, read_size = 16, read_tries = 1) - cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;") - [(row_count,)] = cursor.fetchall() - self.assertEqual(row_count, 1) - - disk_file = io.BytesIO(torrent_file.getvalue()[64:]) - do_verify(self._conn, disk.id, d_ss, disk_file, read_size = 16, read_tries = 1) - - cursor.execute(""" - select disk_id, disk_sectors - from diskjumble.verify_piece_incomplete natural join diskjumble.verify_piece_content; - """) - self.assertEqual( - cursor.fetchall(), - [(other_disk.id, NumericRange(0, 5))] - ) - - cursor.execute("select count(*) from diskjumble.verify_pass;") - [(row_count,)] = cursor.fetchall() - self.assertEqual(row_count, 0) - - cursor.execute("select count(*) from diskjumble.verify_piece_fail;") - [(row_count,)] = cursor.fetchall() - self.assertEqual(row_count, 0) - def test_transient_read_errors(self): """ Test a run where a read to the disk fails but fewer times than needed to mark the sector bad. diff --git a/disk_jumble/src/disk_jumble/verify.py b/disk_jumble/src/disk_jumble/verify.py index fb2b6d6..a11a789 100644 --- a/disk_jumble/src/disk_jumble/verify.py +++ b/disk_jumble/src/disk_jumble/verify.py @@ -1,10 +1,11 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Iterable, Optional +from typing import List, Optional import argparse import contextlib import datetime as dt import hashlib +import heapq import io import itertools import math @@ -12,375 +13,419 @@ import math from psycopg2.extras import NumericRange import psycopg2 -from disk_jumble import bencode, SECTOR_SIZE -from disk_jumble.db import HasherRef, Slab, Wrapper as DbWrapper -from disk_jumble.nettle import Sha1Hasher +from disk_jumble import bencode, BLOCK_SIZE -_V2_BLOCK_SIZE = 16 * 1024 # in bytes - _READ_BUFFER_SIZE = 16 * 1024 ** 2 # in bytes @dataclass -class _SlabChunk: - """A slice of a slab; comprising all or part of a piece to be hashed.""" - slab: Slab - slice: slice - - -@dataclass -class _PieceTask: - """The chunks needed to hash as fully as possible an entity piece.""" - entity_id: bytes - info_dict: bencode.Bdict - piece_num: int - hasher_ref: Optional[HasherRef] - chunks: list[_SlabChunk] - complete: bool # do these chunks complete the piece? 
- - -class _BadSector(Exception): - pass - - -def do_verify_v1(conn, disk_id: int, sector_size: int, disk_file: io.BufferedIOBase, read_size: int, read_tries: int) -> None: - db = DbWrapper(conn) - - tasks = [] - worklist = db.get_v1_worklist(disk_id, sector_size) - for (entity_id, group_iter) in itertools.groupby(worklist, lambda t: t[0].entity_id): - group = list(group_iter) - - [info_bytes] = {v for (_, v, _) in group} - info = bencode.decode(info_bytes) - piece_len = info[b"piece length"] - assert piece_len % sector_size == 0 +class _TorrentInfo: + id: bytes + length: int + piece_length: int + hashes: List[bytes] + + @classmethod + def build(cls, torrent_id: bytes, info: bencode.Bdict): if b"length" in info: - torrent_len = info[b"length"] - else: - torrent_len = sum(d[b"length"] for d in info[b"files"]) - - offset = None - use_hasher = None - chunks = [] - for (slab, _, hasher_ref) in group: - if slab.crypt_key is not None: - raise NotImplementedError("verify of encrypted data") - - slab_end = min(slab.entity_offset + len(slab.sectors) * sector_size, torrent_len) - while offset is None or offset < slab_end: - if offset is not None and slab.entity_offset > offset: - if chunks: - tasks.append(_PieceTask(entity_id, info, offset // piece_len, use_hasher, chunks, False)) - offset = None - use_hasher = None - chunks = [] - - if offset is None: - aligned = math.ceil(slab.entity_offset / piece_len) * piece_len - if hasher_ref and hasher_ref.entity_offset < aligned: - assert hasher_ref.entity_offset < torrent_len - use_hasher = hasher_ref - offset = hasher_ref.entity_offset - elif aligned < slab_end: - offset = aligned - else: - break # no usable data in this slab - - if offset is not None: - piece_end = min(offset + piece_len - offset % piece_len, torrent_len) - chunk_end = min(piece_end, slab_end) - chunks.append(_SlabChunk(slab, slice(offset - slab.entity_offset, chunk_end - slab.entity_offset))) - if chunk_end == piece_end: - tasks.append(_PieceTask(entity_id, info, offset // piece_len, use_hasher, chunks, True)) - use_hasher = None - chunks = [] - offset = chunk_end - - if chunks: - tasks.append(_PieceTask(entity_id, info, offset // piece_len, use_hasher, chunks, False)) - - @dataclass - class NewVerifyPiece: - entity_id: bytes - piece_num: int - sector_ranges: list[range] - hasher_state: Optional[dict] - failed: bool - - @dataclass - class VerifyUpdate: - seq_start: int - new_sector_ranges: list[range] - hasher_state: Optional[dict] - - passed_verifies = set() - failed_verifies = set() - new_pass_ranges = [] - vp_updates = {} - new_vps = [] - - run_ts = dt.datetime.now(dt.timezone.utc) - for task in tasks: - hasher = Sha1Hasher(task.hasher_ref.state if task.hasher_ref else None) - sector_ranges = [ - range( - chunk.slab.sectors.start + chunk.slice.start // sector_size, - chunk.slab.sectors.start + math.ceil(chunk.slice.stop / sector_size) - ) - for chunk in task.chunks - ] - - try: - for chunk in task.chunks: - slab_off = chunk.slab.sectors.start * sector_size - disk_file.seek(slab_off + chunk.slice.start) - end_pos = slab_off + chunk.slice.stop - while disk_file.tell() < end_pos: - pos = disk_file.tell() - for _ in range(read_tries): - try: - data = disk_file.read(min(end_pos - pos, read_size)) - except OSError: - disk_file.seek(pos) - else: - break - else: - raise _BadSector() - - assert data - hasher.update(data) - except _BadSector: - if task.hasher_ref: - failed_verifies.add(task.hasher_ref.id) - vp_updates[task.hasher_ref.id] = VerifyUpdate(task.hasher_ref.seq + 1, sector_ranges, 
None) - else: - new_vps.append(NewVerifyPiece(task.entity_id, task.piece_num, sector_ranges, None, True)) + length = info[b"length"] else: - hasher_state = hasher.ctx.serialize() - if task.complete: - s = slice(task.piece_num * 20, task.piece_num * 20 + 20) - expected_hash = task.info_dict[b"pieces"][s] - if hasher.digest() == expected_hash: - write_piece_data = False - new_pass_ranges.extend(sector_ranges) - if task.hasher_ref: - passed_verifies.add(task.hasher_ref.id) - else: - failed = True - write_piece_data = True - if task.hasher_ref: - failed_verifies.add(task.hasher_ref.id) - else: - failed = False - write_piece_data = True - - if write_piece_data: - if task.hasher_ref: - assert task.hasher_ref.id not in vp_updates - vp_updates[task.hasher_ref.id] = VerifyUpdate(task.hasher_ref.seq + 1, sector_ranges, hasher_state) - else: - new_vps.append(NewVerifyPiece(task.entity_id, task.piece_num, sector_ranges, hasher_state, failed)) + length = sum(d[b"length"] for d in info[b"files"]) - new_pass_ranges.sort(key = lambda r: r.start) - merged_ranges = [] - for r in new_pass_ranges: - if merged_ranges and r.start == merged_ranges[-1].stop: - merged_ranges[-1] = range(merged_ranges[-1].start, r.stop) - else: - merged_ranges.append(r) + hash_blob = info[b"pieces"] + hashes = [hash_blob[s : s + 20] for s in range(0, len(hash_blob), 20)] - for vp in new_vps: - verify_id = db.insert_verify_piece(run_ts, vp.entity_id, vp.piece_num) - db.insert_verify_piece_content(verify_id, 0, disk_id, vp.sector_ranges) - if vp.failed: - db.mark_verify_piece_failed(verify_id) - else: - db.upsert_hasher_state(verify_id, vp.hasher_state) + return cls(torrent_id, length, info[b"piece length"], hashes) - for (verify_id, update) in vp_updates.items(): - db.insert_verify_piece_content(verify_id, update.seq_start, disk_id, update.new_sector_ranges) - if update.hasher_state: - db.upsert_hasher_state(verify_id, update.hasher_state) - for verify_id in passed_verifies: - db.move_piece_content_for_pass(verify_id) - db.delete_verify_piece(verify_id) +@dataclass +class _Run: + entity_id: bytes + entity_length: int + piece_num: int + block_ranges: List[range] + hash: bytes - for r in merged_ranges: - db.insert_pass_data(run_ts, disk_id, r) - for verify_id in failed_verifies: - db.clear_incomplete(verify_id) - db.mark_verify_piece_failed(verify_id) +@dataclass +class _V1Run(_Run): + piece_length: int # for the entity overall -@dataclass -class _VerifyResult: - sector: int +class _V2Run(_Run): + pass -@dataclass -class _VerifyPass(_VerifyResult): +class _BadSector(Exception): pass -@dataclass -class _VerifyFail(_VerifyResult): - entity_id: bytes - piece_num: int +def _get_v1_worklist(conn, disk_id: int, block_ranges: List[NumericRange]) -> List[_V1Run]: + """ + How this works: First, we fetch some info about each torrent on the disk that has data within the requested blocks. + Then, we make one query per torrent (because each one may have a different piece size) to determine which disk + blocks to verify. Here's how that works for a given torrent: + + 1. Select slabs on the current disk that have data from the appropriate torrent. + 2. Join with the list of requested block ranges to obtain which slabs and what portions of each are to be verified. + 3. Convert each requested range to a set of piece numbers and deduplicate. + 4. Join again with the slab table to get the complete block ranges required to cover each requested piece. + 5. 
Order by piece number (to facilitate grouping by piece), then entity offset (to produce an ordered run for each + piece). + + The runs are then reordered by the number of their first disk block. + """ + cursor = conn.cursor() + cursor.execute( + """ + select distinct on (entity_id) + entity_id, info + from ( + diskjumble.slab + left join bittorrent.torrent_info on digest(info, 'sha1') = entity_id + ) + where disk_id = %s and disk_blocks && any (%s) + """, + (disk_id, block_ranges) + ) + infos = [_TorrentInfo.build(id_, info) for (id_, info) in cursor] + for i in infos: + assert i.piece_length % BLOCK_SIZE == 0, f"entity {i.id.hex()} has invalid piece length" -def _gen_verify_results(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int, read_tries: int) -> Iterable[_VerifyResult]: - with conn.cursor() as cursor: + runs = [] + for info in infos: cursor.execute( """ with - slab_plus as ( + relevant_slab as ( + select * from diskjumble.slab + where disk_id = %(disk_id)s and entity_id = %(entity_id)s + ), + targeted_slab as ( + select relevant_slab.*, target_range * disk_blocks as target_span + from + relevant_slab + join unnest(%(block_ranges)s) as target_range + on target_range && disk_blocks + ), + targeted_piece as ( + select distinct + ( + entity_offset + + ( + (generate_series(lower(target_span), upper(target_span) - 1) - lower(disk_blocks)) + * %(block_size)s + ) + ) / %(piece_length)s + as piece_num + from targeted_slab + ), + out as ( select - *, - int8range( - slab.entity_offset / %(block_size)s, - slab.entity_offset / %(block_size)s + upper(slab.disk_blocks) - lower(slab.disk_blocks) - ) as entity_blocks - from diskjumble.slab + entity_offset, + disk_blocks, + piece_num, + disk_blocks * int8range( + (piece_num * %(piece_length)s - entity_offset) / %(block_size)s + lower(disk_blocks), + ((piece_num + 1) * %(piece_length)s - entity_offset) / %(block_size)s + lower(disk_blocks) + ) as use_blocks, + crypt_key + from relevant_slab join targeted_piece ) - select - elh.entity_id, - generate_series( - lower(slab_plus.entity_blocks * elh.block_range), - upper(slab_plus.entity_blocks * elh.block_range) - 1 - ) as piece_num, - generate_series( - ( - lower(slab_plus.entity_blocks * elh.block_range) - - lower(slab_plus.entity_blocks) - + lower(slab_plus.disk_blocks) - ), + select piece_num, use_blocks, crypt_key from out + where not isempty(use_blocks) + order by + piece_num, + entity_offset + (lower(use_blocks) - lower(disk_blocks)) * %(block_size)s + """, + { + "disk_id": disk_id, + "entity_id": info.id, + "block_ranges": block_ranges, + "block_size": BLOCK_SIZE, + "piece_length": info.piece_length, + } + ) + rows = cursor.fetchall() + + if any(crypt_key is not None for (*_, crypt_key) in rows): + raise NotImplementedError("verify of encrypted data") + + for (piece_num, ranges) in itertools.groupby(rows, lambda t: t[0]): + run = _V1Run( + entity_id = info.id, + entity_length = info.length, + piece_length = info.piece_length, + piece_num = piece_num, + block_ranges = [range(r.lower, r.upper) for r in ranges], + hash = info.hashes[piece_num], + ) + + # Make sure we ended up with all the block ranges for the piece. If not, assume the piece isn't fully + # written yet and skip. 
+ run_length = min(info.piece_length, info.length - piece_num * info.piece_length) + if sum(map(len, run.block_ranges)) == math.ceil(run_length / BLOCK_SIZE): + runs.append(run) + + runs.sort(lambda r: r.block_ranges) + return runs + + +def _get_v2_worklist(conn, disk_id: int, block_ranges: List[NumericRange]) -> List[_V2Run]: + cursor = conn.cursor() + cursor.execute( + """ + with + slab_plus as ( + select + *, + int8range( + slab.entity_offset / %(block_size)s, + slab.entity_offset / %(block_size)s + upper(slab.disk_blocks) - lower(slab.disk_blocks) + ) as entity_blocks + from diskjumble.slab + where disk_id = %(disk_id)s + ), + joined as ( + select + entity_id, + entity.length, + entity_blocks, + slab_plus.disk_blocks + entity_blocks * elh.block_range as check_erange, + from + entityv2_leaf_hashes elh + join slab_plus using (entity_id) + left outer join public.entity using (entity_id) + ), + filtered as ( + select * from joined where not isempty(check_erange) + ), + exploded as ( + select + entity_id, + length, + generate_series(lower(check_erange), upper(check_erange) - 1) as piece_num, ( - upper(slab_plus.entity_blocks * elh.block_range) - - lower(slab_plus.entity_blocks) - + lower(slab_plus.disk_blocks) - - 1 - ) - ) as sector, - entity.length as entity_length, - substring(hashes, generate_series(0, octet_length(hashes) / 32 - 1, 32), 32) as hash, - crypt_key - from ( - entityv2_leaf_hashes elh - join slab_plus on ( - slab_plus.entity_id = elh.entity_id - and slab_plus.entity_blocks && elh.block_range - ) - left outer join public.entity on elh.entity_id = entity.entity_id + generate_series(lower(check_erange), upper(check_erange) - 1) + - lower(entity_blocks) + lower(disk_blocks) + ) as block, + substring( + hashes, + generate_series(lower(check_erange), upper(check_erange) - 1) * 32, + 32 + ) as hash, + crypt_key + from filtered ) - where slab_plus.disk_id = %(disk_id)s - order by sector - """, - {"block_size": _V2_BLOCK_SIZE, "disk_id": disk_id} + select * from exploded + where block <@ any (%(block_ranges)s) + order by block + """, + { + "block_size": BLOCK_SIZE, + "disk_id": disk_id, + "block_ranges": block_ranges, + } + ) + rows = cursor.fetchall() + + if any(crypt_key is not None for (*_, crypt_key) in rows): + raise NotImplementedError("verify of encrypted data") + + return [ + _V2Run( + entity_id = entity_id, + entity_length = entity_length, + piece_num = piece_num, + block_ranges = [range(block, block + 1)], + hash = hash_, ) - for (entity_id, piece_num, sector, entity_len, hash_, crypt_key) in cursor: - if crypt_key is not None: - raise NotImplementedError("verify of encrypted data") - - read_start = sector * _V2_BLOCK_SIZE - read_end = read_start + min(_V2_BLOCK_SIZE, entity_len - piece_num * _V2_BLOCK_SIZE) - disk_file.seek(read_start) - hasher = hashlib.sha256() - try: - while disk_file.tell() < read_end: - pos = disk_file.tell() - for _ in range(read_tries): - try: - data = disk_file.read(min(read_end - pos, read_size)) - except OSError: - disk_file.seek(pos) - else: - break - else: - raise _BadSector() - - assert data - hasher.update(data) - except _BadSector: - pass_ = False - else: - pass_ = hasher.digest() == hash_ + for (entity_id, entity_length, piece_num, block, entity_length, hash_, crypt_key) in rows + ] + + +def _do_verify(conn, disk_id: int, block_ranges: Optional[List[range]], disk_file: io.BufferedIOBase, read_size: int, read_tries: int): + @dataclass + class Pass: + blocks: range - if pass_: - yield _VerifyPass(sector) + @dataclass + class Fail: + run: 
_Run + + def merge_results(results): + curr_pass = None + for r in results: + if isinstance(r, Pass) and curr_pass and r.blocks.start <= curr_pass.stop: + curr_pass = range(curr_pass.start, max(r.blocks.stop, curr_pass.stop)) else: - yield _VerifyFail(sector, entity_id, piece_num) + if curr_pass: + yield Pass(curr_pass) + curr_pass = None + if isinstance(r, Pass): + curr_pass = r.blocks + else: + yield r + if curr_pass: + yield Pass(curr_pass) + + if block_ranges is None: + pg_block_ranges = [NumericRange()] + else: + pg_block_ranges = [NumericRange(r.start, r.stop) for r in block_ranges] + + worklist = list(heapq.merge( + _get_v1_worklist(conn, disk_id, pg_block_ranges), + _get_v2_worklist(conn, disk_id, pg_block_ranges), + key = lambda run: run.block_ranges, + )) + + if block_ranges is not None: + requested_blocks = { + block + for r in block_ranges + for block in r + } + covered_blocks = { + block + for run in worklist + for block_range in run.block_ranges + for block in block_range + } + missing = requested_blocks - covered_blocks + if missing: + raise RuntimeError(f"unable to locate blocks: {len(missing)} in the range {min(missing)} to {max(missing)}") + + def generate_results(): + for run in worklist: + if isinstance(run, _V1Run): + hasher = hashlib.sha1() + entity_off = run.piece_num * run.piece_length + else: + hasher = hashlib.sha256() + entity_off = run.piece_num * BLOCK_SIZE + try: + for range_ in run.block_ranges: + for block in range_: + read_start = block * BLOCK_SIZE + read_end = read_start + min(BLOCK_SIZE, run.entity_length - entity_off) + disk_file.seek(read_start) + while disk_file.tell() < read_end: + pos = disk_file.tell() + for _ in range(read_tries): + try: + data = disk_file.read(min(read_end - pos, read_size)) + except OSError: + disk_file.seek(pos) + else: + break + else: + raise _BadSector() + + assert data + hasher.update(data) + entity_off += BLOCK_SIZE + except _BadSector: + yield Fail(run) + else: + if hasher.digest() == run.hash: + yield from (Pass(r) for r in run.block_ranges) + else: + yield Fail(run) -def do_verify_v2(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int, read_tries: int) -> None: + cursor = conn.cursor() ts = dt.datetime.now(dt.timezone.utc) - with conn.cursor() as cursor: - def save_pass_range(r): + for result in merge_results(generate_results()): + if isinstance(result, Pass): cursor.execute( - "insert into diskjumble.verify_pass values (default, %s, %s, %s);", - (ts, disk_id, NumericRange(r.start, r.stop)) + "insert into diskjumble.verify_pass values (default, %s, %s, %s)", + (ts, disk_id, result.blocks) + ) + else: + assert isinstance(result, Fail) + run = result.run + cursor.execute( + """ + with + new_piece as ( + insert into diskjumble.verify_piece + values (default, %(ts)s, %(entity_id)s, %(piece_num)s) + returning verify_id + ), + _ as ( + insert into diskjumble.verify_piece_content + select verify_id, ordinality - 1, %(disk_id)s, block_range + from + new_piece, + unnest(%(ranges)) with ordinality as block_range + ) + insert into diskjumble.verify_piece_fail + select verify_id from new_piece + """, + { + "ts": ts, + "entity_id": run.entity_id, + "piece_num": run.piece_num, + "disk_id": disk_id, + "ranges": [NumericRange(r.start, r.stop) for r in run.block_ranges], + } ) - - pass_sectors = None - for result in _gen_verify_results(conn, disk_id, disk_file, read_size, read_tries): - if isinstance(result, _VerifyPass): - if pass_sectors is None: - pass_sectors = range(result.sector, result.sector + 1) - elif result.sector 
== pass_sectors.stop: - pass_sectors = range(pass_sectors.start, result.sector + 1) - else: - save_pass_range(pass_sectors) - pass_sectors = range(result.sector, result.sector + 1) - else: - assert isinstance(result, _VerifyFail) - if pass_sectors: - save_pass_range(pass_sectors) - pass_sectors = None - - cursor.execute( - """ - with - new_piece as ( - insert into diskjumble.verify_piece - values (default, %s, %s, %s) - returning verify_id - ), - _ as ( - insert into diskjumble.verify_piece_content - select verify_id, 0, %s, %s from new_piece - ) - insert into diskjumble.verify_piece_fail - select verify_id from new_piece - """, - (ts, result.entity_id, result.piece_num, disk_id, NumericRange(result.sector, result.sector + 1)) - ) - if pass_sectors: - save_pass_range(pass_sectors) if __name__ == "__main__": - def read_tries(raw_arg): - val = int(raw_arg) + def read_tries(raw_val): + val = int(raw_val) if val > 0: return val else: raise ValueError() + def block_ranges(raw_val): + def parse_one(part): + if "-" in part: + (s, e) = map(int, part.split("-")) + if e <= s: + raise ValueError() + else: + return range(s, e) + else: + s = int(part) + return range(s, s + 1) + + return list(map(parse_one, raw_val.split(","))) + parser = argparse.ArgumentParser() parser.add_argument("disk_id", type = int) parser.add_argument( "read_tries", type = read_tries, - help = "number of times to attempt a particular disk read before giving up on the sector", + help = "number of times to attempt a particular disk read before giving up on the block", + ) + parser.add_argument( + "block_ranges", + type = block_ranges, + nargs = "?", + help = "if specified, only verify what's needed to cover these disk blocks (\"0,2-4\" means 0, 2, and 3)", ) args = parser.parse_args() + if args.block_ranges is None: + block_ranges = None + else: + block_ranges = [] + for r in sorted(args.block_ranges, key = lambda r: r.start): + if block_ranges and r.start <= block_ranges[-1].stop: + prev = block_ranges.pop() + block_ranges.append(range(prev.start, max(prev.stop, r.stop))) + else: + block_ranges.append(r) + with contextlib.closing(psycopg2.connect("")) as conn: - conn.autocommit = True path = f"/dev/mapper/diskjumble-{args.disk_id}" with open(path, "rb", buffering = _READ_BUFFER_SIZE) as disk_file: - do_verify_v1(conn, args.disk_id, SECTOR_SIZE, disk_file, _READ_BUFFER_SIZE, args.read_tries) - do_verify_v2(conn, args.disk_id, disk_file, _READ_BUFFER_SIZE, args.read_tries) + with conn: + _do_verify(conn, args.disk_id, block_ranges, disk_file, _READ_BUFFER_SIZE, args.read_tries)
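
A note on the v1 worklist query: the `use_blocks` expression maps a piece's byte span into block numbers relative to the slab's start and clips the result to the slab's own block range. Below is a rough Python equivalent of that arithmetic, for orientation only (not part of the patch; the helper name is invented, and it assumes the slab's entity offset is block-aligned, which the query itself does not require):

BLOCK_SIZE = 16 * 1024  # bytes, matching disk_jumble.BLOCK_SIZE

def piece_blocks(piece_num, piece_length, entity_offset, disk_blocks):
    """Blocks of the slab `disk_blocks` (a range of disk block numbers) holding piece `piece_num`."""
    # Translate the piece's byte span [piece_num * piece_length, (piece_num + 1) * piece_length)
    # into block numbers relative to the slab's first block, then clip to the slab itself.
    start = (piece_num * piece_length - entity_offset) // BLOCK_SIZE + disk_blocks.start
    stop = ((piece_num + 1) * piece_length - entity_offset) // BLOCK_SIZE + disk_blocks.start
    return range(max(start, disk_blocks.start), min(stop, disk_blocks.stop))

# 64 KiB pieces; a slab starting 16 KiB into the entity and occupying disk blocks
# 100-109 holds the tail of piece 0 in blocks 100-102 and all of piece 1 in blocks 103-106.
assert piece_blocks(0, 64 * 1024, 16 * 1024, range(100, 110)) == range(100, 103)
assert piece_blocks(1, 64 * 1024, 16 * 1024, range(100, 110)) == range(103, 107)

The query groups the resulting block ranges by piece number, and the completeness check after it (the ceil over run_length) drops any piece whose ranges don't add up to its full length.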
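
For the new block-targeting argument, here is a standalone sketch of how the spec is parsed and then coalesced, mirroring the `block_ranges` argparse type and the merge loop in the `__main__` block (the function names are made up for illustration; range ends are exclusive, so "0,2-4" selects blocks 0, 2, and 3):

def parse_block_ranges(raw_val):
    """Parse a spec like "0,2-4" into [range(0, 1), range(2, 4)]."""
    def parse_one(part):
        if "-" in part:
            (start, end) = map(int, part.split("-"))
            if end <= start:
                raise ValueError(f"empty range: {part}")
            return range(start, end)
        else:
            start = int(part)
            return range(start, start + 1)

    return [parse_one(part) for part in raw_val.split(",")]

def merge_block_ranges(ranges):
    """Coalesce overlapping or adjacent ranges, as done before calling _do_verify."""
    merged = []
    for r in sorted(ranges, key = lambda r: r.start):
        if merged and r.start <= merged[-1].stop:
            prev = merged.pop()
            merged.append(range(prev.start, max(prev.stop, r.stop)))
        else:
            merged.append(r)
    return merged

# "0,2-4,3-6" covers blocks 0, 2, 3, 4, and 5; the overlapping ranges collapse to one.
assert merge_block_ranges(parse_block_ranges("0,2-4,3-6")) == [range(0, 1), range(2, 6)]

Merging up front means _do_verify always receives a sorted, non-overlapping list, which keeps the missing-block check and the NumericRange conversion straightforward.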