from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Iterable, Optional
-import datetime as dt
-import itertools
-from psycopg2.extras import execute_batch, Json, NumericRange
-
-
-@dataclass
-class Slab:
- id: int
- disk_id: int
- sectors: range
- entity_id: bytes
- entity_offset: int
- crypt_key: Optional[bytes]
-
-
-@dataclass
-class HasherRef:
- id: int
- seq: int
- entity_offset: int
- state: dict
+from psycopg2.extras import execute_batch
@dataclass
for i in infos
]
execute_batch(cursor, stmt, param_sets)
-
- def get_v1_worklist(self, disk_id: int, sector_size: int) -> Iterable[tuple[Slab, bytes, Optional[HasherRef]]]:
- """
- Find slabs on the specified disk along with corresponding bencoded info dicts, and for each also return any
- lowest-offset saved hasher that left off directly before or within the slab's entity data.
- """
- stmt = """
- with
- incomplete_edge as (
- -- join up incomplete piece info and precompute where the hasher left off within the entity
- select
- verify_id, seq, slab.entity_id, hasher_state,
- entity_offset + (upper(c.disk_sectors) - lower(slab.disk_blocks)) * %(sector_size)s as end_off
- from
- diskjumble.verify_piece_incomplete
- natural left join diskjumble.verify_piece p
- natural join diskjumble.verify_piece_content c
- natural left join diskjumble.disk
- left join diskjumble.slab on (
- c.disk_id = slab.disk_id
- and upper(c.disk_sectors) <@ int8range(lower(slab.disk_blocks), upper(slab.disk_blocks), '[]')
- )
- where seq >= all (select seq from diskjumble.verify_piece_content where verify_id = p.verify_id)
- )
- select
- slab_id, disk_id, disk_blocks, slab.entity_id, torrent_info.info, entity_offset, crypt_key, verify_id,
- seq, end_off, hasher_state
- from
- diskjumble.slab
- join bittorrent.torrent_info on digest(info, 'sha1') = slab.entity_id
- natural left join diskjumble.disk
- left join incomplete_edge on
- incomplete_edge.entity_id = slab.entity_id
- and incomplete_edge.end_off <@ int8range(
- slab.entity_offset,
- slab.entity_offset + (upper(disk_blocks) - lower(disk_blocks)) * %(sector_size)s
- )
- and (incomplete_edge.end_off - slab.entity_offset) %% %(sector_size)s = 0
- where disk_id = %(disk_id)s
- order by disk_blocks
- ;
- """
- with self.conn.cursor() as cursor:
- cursor.execute(stmt, {"disk_id": disk_id, "sector_size": sector_size})
- for (_, rows_iter) in itertools.groupby(cursor, lambda r: r[0]):
- rows = list(rows_iter)
- [(slab_id, disk_id, sectors_pg, entity_id, info_mem, entity_off, key_mem)] = {r[:7] for r in rows}
- sectors = range(sectors_pg.lower, sectors_pg.upper)
- key = None if key_mem is None else bytes(key_mem)
- slab = Slab(slab_id, disk_id, sectors, bytes(entity_id), entity_off, key)
-
- # `None' if no hasher match in outer join, otherwise earliest match
- (*_, id_, seq, end_off, state) = min(rows, key = lambda r: r[-2])
- hasher_ref = None if id_ is None else HasherRef(id_, seq, end_off, state)
-
- yield (slab, bytes(info_mem), hasher_ref)
-
- def insert_verify_piece(self, ts: dt.datetime, entity_id: bytes, piece_num: int) -> int:
- """Insert new verify piece, returning the ID of the inserted row."""
-
- with self.conn.cursor() as cursor:
- stmt = "insert into diskjumble.verify_piece values (default, %s, %s, %s) returning verify_id;"
- cursor.execute(stmt, (ts, entity_id, piece_num))
- [(row_id,)] = cursor.fetchall()
- return row_id
-
- def insert_verify_piece_content(self, verify_id: int, seq_start: int, disk_id: int, ranges: Iterable[range]) -> None:
- with self.conn.cursor() as cursor:
- execute_batch(
- cursor,
- "insert into diskjumble.verify_piece_content values (%s, %s, %s, %s);",
- [
- (verify_id, seq, disk_id, NumericRange(r.start, r.stop))
- for (seq, r) in enumerate(ranges, start = seq_start)
- ]
- )
-
- def mark_verify_piece_failed(self, verify_id: int) -> None:
- with self.conn.cursor() as cursor:
- cursor.execute("insert into diskjumble.verify_piece_fail values (%s);", (verify_id,))
-
- def upsert_hasher_state(self, verify_id: int, state: dict) -> None:
- stmt = """
- insert into diskjumble.verify_piece_incomplete values (%s, %s)
- on conflict (verify_id) do update set hasher_state = excluded.hasher_state
- ;
- """
- with self.conn.cursor() as cursor:
- cursor.execute(stmt, (verify_id, Json(state)))
-
- def delete_verify_piece(self, verify_id: int) -> None:
- with self.conn.cursor() as cursor:
- cursor.execute("delete from diskjumble.verify_piece_incomplete where verify_id = %s;", (verify_id,))
- cursor.execute("delete from diskjumble.verify_piece_content where verify_id = %s;", (verify_id,))
- cursor.execute("delete from diskjumble.verify_piece where verify_id = %s", (verify_id,))
-
- def move_piece_content_for_pass(self, verify_id: int) -> None:
- stmt = """
- with content_out as (
- delete from diskjumble.verify_piece_content c
- using diskjumble.verify_piece p
- where (
- c.verify_id = p.verify_id
- and p.verify_id = %s
- )
- returning at, disk_id, disk_sectors
- )
- insert into diskjumble.verify_pass (at, disk_id, disk_sectors)
- select at, disk_id, disk_sectors from content_out
- ;
- """
- with self.conn.cursor() as cursor:
- cursor.execute(stmt, (verify_id,))
-
- def insert_pass_data(self, ts: dt.datetime, disk_id: int, sectors: range) -> None:
- with self.conn.cursor() as cursor:
- cursor.execute(
- "insert into diskjumble.verify_pass values (default, %s, %s, %s);",
- (ts, disk_id, NumericRange(sectors.start, sectors.stop))
- )
-
- def clear_incomplete(self, verify_id: int) -> None:
- with self.conn.cursor() as cursor:
- cursor.execute("delete from diskjumble.verify_piece_incomplete where verify_id = %s;", (verify_id,))
import psycopg2.extras
from disk_jumble import bencode
-from disk_jumble.nettle import Sha1Hasher
-from disk_jumble.verify import do_verify
+from disk_jumble.verify import _do_verify
_BUF_SIZE = 16 * 1024 ** 2 # in bytes
def test_basic_fresh_verify_large_read_size(self):
self._basic_fresh_verify_helper(128)
- def test_resume_fragmentation_unaligned_end(self):
- """
- Test a run where a cached hash state is used, a piece is split on disk, and the end of the torrent isn't
- sector-aligned.
- """
- sector_size = 16
- read_size = 8
-
- other_disk = self._write_disk(1)
- disk = self._write_disk(5)
- with _random_file(60, Random(0), on_disk = False) as torrent_file:
- torrent = _Torrent(torrent_file, piece_size = 64)
- self._write_torrent(torrent)
- with self._conn.cursor() as cursor:
- cursor.executemany(
- "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
- [
- (other_disk.id, NumericRange(0, 1), torrent.info_hash, 0),
- (disk.id, NumericRange(0, 2), torrent.info_hash, 16),
- (disk.id, NumericRange(4, 5), torrent.info_hash, 48),
- ]
- )
-
- # Prepare the saved hasher state by running a verify
- do_verify(self._conn, other_disk.id, sector_size, torrent_file, read_size, read_tries = 1)
- torrent_file.seek(0)
-
- cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
- [(row_count,)] = cursor.fetchall()
- self.assertEqual(row_count, 1)
-
- disk_file = io.BytesIO()
- torrent_file.seek(sector_size)
- disk_file.write(torrent_file.read(sector_size * 2))
- disk_file.seek(disk_file.tell() + sector_size * 2)
- disk_file.write(torrent_file.read())
- disk_file.seek(0)
- do_verify(self._conn, disk.id, sector_size, disk_file, read_size, read_tries = 1)
-
- # Check that there are no verify pieces in the database. Because of integrity constraints, this also
- # guarantees there aren't any stray saved hasher states, failed verifies, or piece contents.
- cursor.execute("select count(*) from diskjumble.verify_piece;")
- [(row_count,)] = cursor.fetchall()
- self.assertEqual(row_count, 0)
-
- cursor.execute("select disk_id, disk_sectors from diskjumble.verify_pass;")
- self.assertEqual(
- cursor.fetchall(),
- [(other_disk.id, NumericRange(0, 1)), (disk.id, NumericRange(0, 2)), (disk.id, NumericRange(4, 5))]
- )
-
- def test_resume_no_completion(self):
- """
- Test a run where a saved hasher state is used and the target disk has subsequent entity data but not the full
- remainder of the piece.
- """
- sector_size = 16
- read_size = 7
- piece_size = 64
-
- other_disk = self._write_disk(1)
- disk = self._write_disk(2)
- with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
- torrent = _Torrent(torrent_file, piece_size)
- self._write_torrent(torrent)
- with self._conn.cursor() as cursor:
- cursor.executemany(
- "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
- [
- (other_disk.id, NumericRange(0, 1), torrent.info_hash, 0),
- (disk.id, NumericRange(0, 2), torrent.info_hash, sector_size),
- ]
- )
-
- do_verify(self._conn, other_disk.id, sector_size, torrent_file, read_size, read_tries = 1)
-
- cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
- [(row_count,)] = cursor.fetchall()
- self.assertEqual(row_count, 1)
-
- disk_file = io.BytesIO()
- torrent_file.seek(sector_size)
- disk_file.write(torrent_file.read(sector_size * 2))
- disk_file.seek(0)
- do_verify(self._conn, disk.id, sector_size, disk_file, read_size, read_tries = 1)
-
- cursor.execute("select count(*) from diskjumble.verify_pass;")
- [(row_count,)] = cursor.fetchall()
- self.assertEqual(row_count, 0)
-
- cursor.execute("select entity_id, piece from diskjumble.verify_piece;")
- [(entity_id, piece_num)] = cursor.fetchall()
- self.assertEqual(bytes(entity_id), torrent.info_hash)
- self.assertEqual(piece_num, 0)
-
- cursor.execute("select disk_id, disk_sectors from diskjumble.verify_piece_content;")
- self.assertCountEqual(
- cursor.fetchall(),
- [(other_disk.id, NumericRange(0, 1)), (disk.id, NumericRange(0, 2))]
- )
-
- cursor.execute("select count(*) from diskjumble.verify_piece_fail;")
- [(row_count,)] = cursor.fetchall()
- self.assertEqual(row_count, 0)
-
- hasher = Sha1Hasher(None)
- torrent_file.seek(0)
- hasher.update(torrent_file.read(sector_size * 3))
- cursor.execute("select hasher_state from diskjumble.verify_piece_incomplete;")
- self.assertEqual(cursor.fetchall(), [(hasher.ctx.serialize(),)])
-
def test_ignore_hasher_beginning_on_disk(self):
"""
Test a run where a saved hasher state is available for use but isn't used due to the beginning of the piece
cursor.execute("select disk_id, disk_sectors from diskjumble.verify_pass;")
self.assertEqual(cursor.fetchall(), [(disk.id, NumericRange(0, disk.sector_count))])
- def test_ignore_hasher_unaligned(self):
- """
- Test a run where a saved hasher isn't used because its entity data offset isn't sector-aligned on the target
- disk.
-
- 0 16 32 48 64 80 96 112 128
- pieces: [-------------- 0 -------------]
- other disk: [--][--][--][--][--]
- disk: [------][------]
- """
- piece_size = 128
-
- other_disk = self._write_disk(5)
- od_ss = 16
- disk = self._write_disk(2)
- d_ss = 32
- with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
- torrent = _Torrent(torrent_file, piece_size)
- self._write_torrent(torrent)
- with self._conn.cursor() as cursor:
- cursor.executemany(
- "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
- [
- (other_disk.id, NumericRange(0, 5), torrent.info_hash, 0),
- (disk.id, NumericRange(0, 2), torrent.info_hash, 64),
- ]
- )
-
- do_verify(self._conn, other_disk.id, od_ss, torrent_file, read_size = 16, read_tries = 1)
- cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
- [(row_count,)] = cursor.fetchall()
- self.assertEqual(row_count, 1)
-
- disk_file = io.BytesIO(torrent_file.getvalue()[64:])
- do_verify(self._conn, disk.id, d_ss, disk_file, read_size = 16, read_tries = 1)
-
- cursor.execute("""
- select disk_id, disk_sectors
- from diskjumble.verify_piece_incomplete natural join diskjumble.verify_piece_content;
- """)
- self.assertEqual(
- cursor.fetchall(),
- [(other_disk.id, NumericRange(0, 5))]
- )
-
- cursor.execute("select count(*) from diskjumble.verify_pass;")
- [(row_count,)] = cursor.fetchall()
- self.assertEqual(row_count, 0)
-
- cursor.execute("select count(*) from diskjumble.verify_piece_fail;")
- [(row_count,)] = cursor.fetchall()
- self.assertEqual(row_count, 0)
-
def test_transient_read_errors(self):
"""
Test a run where a read to the disk fails but fewer times than needed to mark the sector bad.
from __future__ import annotations
from dataclasses import dataclass
-from typing import Iterable, Optional
+from typing import List, Optional
import argparse
import contextlib
import datetime as dt
import hashlib
+import heapq
import io
import itertools
import math
from psycopg2.extras import NumericRange
import psycopg2
-from disk_jumble import bencode, SECTOR_SIZE
-from disk_jumble.db import HasherRef, Slab, Wrapper as DbWrapper
-from disk_jumble.nettle import Sha1Hasher
+from disk_jumble import bencode, BLOCK_SIZE
-_V2_BLOCK_SIZE = 16 * 1024 # in bytes
-
_READ_BUFFER_SIZE = 16 * 1024 ** 2 # in bytes
@dataclass
-class _SlabChunk:
- """A slice of a slab; comprising all or part of a piece to be hashed."""
- slab: Slab
- slice: slice
-
-
-@dataclass
-class _PieceTask:
- """The chunks needed to hash as fully as possible an entity piece."""
- entity_id: bytes
- info_dict: bencode.Bdict
- piece_num: int
- hasher_ref: Optional[HasherRef]
- chunks: list[_SlabChunk]
- complete: bool # do these chunks complete the piece?
-
-
-class _BadSector(Exception):
- pass
-
-
-def do_verify_v1(conn, disk_id: int, sector_size: int, disk_file: io.BufferedIOBase, read_size: int, read_tries: int) -> None:
- db = DbWrapper(conn)
-
- tasks = []
- worklist = db.get_v1_worklist(disk_id, sector_size)
- for (entity_id, group_iter) in itertools.groupby(worklist, lambda t: t[0].entity_id):
- group = list(group_iter)
-
- [info_bytes] = {v for (_, v, _) in group}
- info = bencode.decode(info_bytes)
- piece_len = info[b"piece length"]
- assert piece_len % sector_size == 0
+class _TorrentInfo:
+ id: bytes
+ length: int
+ piece_length: int
+ hashes: List[bytes]
+
+ @classmethod
+ def build(cls, torrent_id: bytes, info: bencode.Bdict):
if b"length" in info:
- torrent_len = info[b"length"]
- else:
- torrent_len = sum(d[b"length"] for d in info[b"files"])
-
- offset = None
- use_hasher = None
- chunks = []
- for (slab, _, hasher_ref) in group:
- if slab.crypt_key is not None:
- raise NotImplementedError("verify of encrypted data")
-
- slab_end = min(slab.entity_offset + len(slab.sectors) * sector_size, torrent_len)
- while offset is None or offset < slab_end:
- if offset is not None and slab.entity_offset > offset:
- if chunks:
- tasks.append(_PieceTask(entity_id, info, offset // piece_len, use_hasher, chunks, False))
- offset = None
- use_hasher = None
- chunks = []
-
- if offset is None:
- aligned = math.ceil(slab.entity_offset / piece_len) * piece_len
- if hasher_ref and hasher_ref.entity_offset < aligned:
- assert hasher_ref.entity_offset < torrent_len
- use_hasher = hasher_ref
- offset = hasher_ref.entity_offset
- elif aligned < slab_end:
- offset = aligned
- else:
- break # no usable data in this slab
-
- if offset is not None:
- piece_end = min(offset + piece_len - offset % piece_len, torrent_len)
- chunk_end = min(piece_end, slab_end)
- chunks.append(_SlabChunk(slab, slice(offset - slab.entity_offset, chunk_end - slab.entity_offset)))
- if chunk_end == piece_end:
- tasks.append(_PieceTask(entity_id, info, offset // piece_len, use_hasher, chunks, True))
- use_hasher = None
- chunks = []
- offset = chunk_end
-
- if chunks:
- tasks.append(_PieceTask(entity_id, info, offset // piece_len, use_hasher, chunks, False))
-
- @dataclass
- class NewVerifyPiece:
- entity_id: bytes
- piece_num: int
- sector_ranges: list[range]
- hasher_state: Optional[dict]
- failed: bool
-
- @dataclass
- class VerifyUpdate:
- seq_start: int
- new_sector_ranges: list[range]
- hasher_state: Optional[dict]
-
- passed_verifies = set()
- failed_verifies = set()
- new_pass_ranges = []
- vp_updates = {}
- new_vps = []
-
- run_ts = dt.datetime.now(dt.timezone.utc)
- for task in tasks:
- hasher = Sha1Hasher(task.hasher_ref.state if task.hasher_ref else None)
- sector_ranges = [
- range(
- chunk.slab.sectors.start + chunk.slice.start // sector_size,
- chunk.slab.sectors.start + math.ceil(chunk.slice.stop / sector_size)
- )
- for chunk in task.chunks
- ]
-
- try:
- for chunk in task.chunks:
- slab_off = chunk.slab.sectors.start * sector_size
- disk_file.seek(slab_off + chunk.slice.start)
- end_pos = slab_off + chunk.slice.stop
- while disk_file.tell() < end_pos:
- pos = disk_file.tell()
- for _ in range(read_tries):
- try:
- data = disk_file.read(min(end_pos - pos, read_size))
- except OSError:
- disk_file.seek(pos)
- else:
- break
- else:
- raise _BadSector()
-
- assert data
- hasher.update(data)
- except _BadSector:
- if task.hasher_ref:
- failed_verifies.add(task.hasher_ref.id)
- vp_updates[task.hasher_ref.id] = VerifyUpdate(task.hasher_ref.seq + 1, sector_ranges, None)
- else:
- new_vps.append(NewVerifyPiece(task.entity_id, task.piece_num, sector_ranges, None, True))
+ length = info[b"length"]
else:
- hasher_state = hasher.ctx.serialize()
- if task.complete:
- s = slice(task.piece_num * 20, task.piece_num * 20 + 20)
- expected_hash = task.info_dict[b"pieces"][s]
- if hasher.digest() == expected_hash:
- write_piece_data = False
- new_pass_ranges.extend(sector_ranges)
- if task.hasher_ref:
- passed_verifies.add(task.hasher_ref.id)
- else:
- failed = True
- write_piece_data = True
- if task.hasher_ref:
- failed_verifies.add(task.hasher_ref.id)
- else:
- failed = False
- write_piece_data = True
-
- if write_piece_data:
- if task.hasher_ref:
- assert task.hasher_ref.id not in vp_updates
- vp_updates[task.hasher_ref.id] = VerifyUpdate(task.hasher_ref.seq + 1, sector_ranges, hasher_state)
- else:
- new_vps.append(NewVerifyPiece(task.entity_id, task.piece_num, sector_ranges, hasher_state, failed))
+ length = sum(d[b"length"] for d in info[b"files"])
- new_pass_ranges.sort(key = lambda r: r.start)
- merged_ranges = []
- for r in new_pass_ranges:
- if merged_ranges and r.start == merged_ranges[-1].stop:
- merged_ranges[-1] = range(merged_ranges[-1].start, r.stop)
- else:
- merged_ranges.append(r)
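+		# A v1 torrent's "pieces" value is a flat blob of concatenated 20-byte SHA-1 digests, one per piece, in piece order.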
+ hash_blob = info[b"pieces"]
+ hashes = [hash_blob[s : s + 20] for s in range(0, len(hash_blob), 20)]
- for vp in new_vps:
- verify_id = db.insert_verify_piece(run_ts, vp.entity_id, vp.piece_num)
- db.insert_verify_piece_content(verify_id, 0, disk_id, vp.sector_ranges)
- if vp.failed:
- db.mark_verify_piece_failed(verify_id)
- else:
- db.upsert_hasher_state(verify_id, vp.hasher_state)
+ return cls(torrent_id, length, info[b"piece length"], hashes)
- for (verify_id, update) in vp_updates.items():
- db.insert_verify_piece_content(verify_id, update.seq_start, disk_id, update.new_sector_ranges)
- if update.hasher_state:
- db.upsert_hasher_state(verify_id, update.hasher_state)
- for verify_id in passed_verifies:
- db.move_piece_content_for_pass(verify_id)
- db.delete_verify_piece(verify_id)
+@dataclass
+class _Run:
+ entity_id: bytes
+ entity_length: int
+ piece_num: int
+ block_ranges: List[range]
+ hash: bytes
- for r in merged_ranges:
- db.insert_pass_data(run_ts, disk_id, r)
- for verify_id in failed_verifies:
- db.clear_incomplete(verify_id)
- db.mark_verify_piece_failed(verify_id)
+@dataclass
+class _V1Run(_Run):
+ piece_length: int # for the entity overall
-@dataclass
-class _VerifyResult:
- sector: int
+class _V2Run(_Run):
+ pass
-@dataclass
-class _VerifyPass(_VerifyResult):
+class _BadSector(Exception):
pass
-@dataclass
-class _VerifyFail(_VerifyResult):
- entity_id: bytes
- piece_num: int
+def _get_v1_worklist(conn, disk_id: int, block_ranges: List[NumericRange]) -> List[_V1Run]:
+ """
+ How this works: First, we fetch some info about each torrent on the disk that has data within the requested blocks.
+ Then, we make one query per torrent (because each one may have a different piece size) to determine which disk
+ blocks to verify. Here's how that works for a given torrent:
+
+ 1. Select slabs on the current disk that have data from the appropriate torrent.
+ 2. Join with the list of requested block ranges to obtain which slabs and what portions of each are to be verified.
+ 3. Convert each requested range to a set of piece numbers and deduplicate.
+ 4. Join again with the slab table to get the complete block ranges required to cover each requested piece.
+ 5. Order by piece number (to facilitate grouping by piece), then entity offset (to produce an ordered run for each
+ piece).
+
+ The runs are then reordered by the number of their first disk block.
+ """
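+	# Worked example (hypothetical numbers): with 16 KiB blocks and a torrent whose piece length is 64 KiB, a slab
+	# at entity_offset 128 KiB occupying disk blocks [1000, 1008) holds entity bytes [128 KiB, 256 KiB). A request
+	# for disk block 1002 targets entity offset 160 KiB, i.e. piece 160 // 64 = 2 (step 3); step 4 then pulls in
+	# blocks [1000, 1004) so that piece 2 can be hashed in full.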
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ select distinct on (entity_id)
+ entity_id, info
+ from (
+ diskjumble.slab
+ left join bittorrent.torrent_info on digest(info, 'sha1') = entity_id
+ )
+		where disk_id = %s and disk_blocks && any (%s) and info is not null
+ """,
+ (disk_id, block_ranges)
+ )
+	infos = [_TorrentInfo.build(bytes(id_), bencode.decode(bytes(info))) for (id_, info) in cursor]
+ for i in infos:
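+		# The worklist math below assumes piece boundaries coincide with block boundaries.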
+ assert i.piece_length % BLOCK_SIZE == 0, f"entity {i.id.hex()} has invalid piece length"
-def _gen_verify_results(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int, read_tries: int) -> Iterable[_VerifyResult]:
- with conn.cursor() as cursor:
+ runs = []
+ for info in infos:
cursor.execute(
"""
with
- slab_plus as (
+ relevant_slab as (
+ select * from diskjumble.slab
+ where disk_id = %(disk_id)s and entity_id = %(entity_id)s
+ ),
+ targeted_slab as (
+ select relevant_slab.*, target_range * disk_blocks as target_span
+ from
+ relevant_slab
+ join unnest(%(block_ranges)s) as target_range
+ on target_range && disk_blocks
+ ),
+ targeted_piece as (
+ select distinct
+ (
+ entity_offset
+ + (
+ (generate_series(lower(target_span), upper(target_span) - 1) - lower(disk_blocks))
+ * %(block_size)s
+ )
+ ) / %(piece_length)s
+ as piece_num
+ from targeted_slab
+ ),
+ out as (
select
- *,
- int8range(
- slab.entity_offset / %(block_size)s,
- slab.entity_offset / %(block_size)s + upper(slab.disk_blocks) - lower(slab.disk_blocks)
- ) as entity_blocks
- from diskjumble.slab
+ entity_offset,
+ disk_blocks,
+ piece_num,
+ disk_blocks * int8range(
+ (piece_num * %(piece_length)s - entity_offset) / %(block_size)s + lower(disk_blocks),
+ ((piece_num + 1) * %(piece_length)s - entity_offset) / %(block_size)s + lower(disk_blocks)
+ ) as use_blocks,
+ crypt_key
+					from relevant_slab cross join targeted_piece
)
- select
- elh.entity_id,
- generate_series(
- lower(slab_plus.entity_blocks * elh.block_range),
- upper(slab_plus.entity_blocks * elh.block_range) - 1
- ) as piece_num,
- generate_series(
- (
- lower(slab_plus.entity_blocks * elh.block_range)
- - lower(slab_plus.entity_blocks)
- + lower(slab_plus.disk_blocks)
- ),
+ select piece_num, use_blocks, crypt_key from out
+ where not isempty(use_blocks)
+ order by
+ piece_num,
+ entity_offset + (lower(use_blocks) - lower(disk_blocks)) * %(block_size)s
+ """,
+ {
+ "disk_id": disk_id,
+ "entity_id": info.id,
+ "block_ranges": block_ranges,
+ "block_size": BLOCK_SIZE,
+ "piece_length": info.piece_length,
+ }
+ )
+ rows = cursor.fetchall()
+
+ if any(crypt_key is not None for (*_, crypt_key) in rows):
+ raise NotImplementedError("verify of encrypted data")
+
+ for (piece_num, ranges) in itertools.groupby(rows, lambda t: t[0]):
+ run = _V1Run(
+ entity_id = info.id,
+ entity_length = info.length,
+ piece_length = info.piece_length,
+ piece_num = piece_num,
+				block_ranges = [range(blocks.lower, blocks.upper) for (_, blocks, _) in ranges],
+ hash = info.hashes[piece_num],
+ )
+
+			# Make sure we ended up with all the block ranges needed to cover the piece. If not, assume the piece
+			# isn't fully present on the disk yet and skip it.
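+			# For example, with 16 KiB blocks and 64 KiB pieces, the final piece of a 70 KiB entity needs
+			# ceil(6 KiB / 16 KiB) = 1 block.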
+ run_length = min(info.piece_length, info.length - piece_num * info.piece_length)
+ if sum(map(len, run.block_ranges)) == math.ceil(run_length / BLOCK_SIZE):
+ runs.append(run)
+
+	runs.sort(key = lambda r: r.block_ranges[0].start)  # order runs by their first disk block
+ return runs
+
+
+def _get_v2_worklist(conn, disk_id: int, block_ranges: List[NumericRange]) -> List[_V2Run]:
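+	"""
+	Build one single-block run per v2 leaf hash that falls within the requested disk blocks, ordered by disk block.
+	"""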
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ with
+ slab_plus as (
+ select
+ *,
+ int8range(
+ slab.entity_offset / %(block_size)s,
+ slab.entity_offset / %(block_size)s + upper(slab.disk_blocks) - lower(slab.disk_blocks)
+ ) as entity_blocks
+ from diskjumble.slab
+ where disk_id = %(disk_id)s
+ ),
+ joined as (
+ select
+ entity_id,
+ entity.length,
+ entity_blocks,
+						slab_plus.disk_blocks,
+						entity_blocks * elh.block_range as check_erange,
+						hashes,
+						crypt_key
+ from
+ entityv2_leaf_hashes elh
+ join slab_plus using (entity_id)
+ left outer join public.entity using (entity_id)
+ ),
+ filtered as (
+ select * from joined where not isempty(check_erange)
+ ),
+ exploded as (
+ select
+ entity_id,
+ length,
+ generate_series(lower(check_erange), upper(check_erange) - 1) as piece_num,
(
- upper(slab_plus.entity_blocks * elh.block_range)
- - lower(slab_plus.entity_blocks)
- + lower(slab_plus.disk_blocks)
- - 1
- )
- ) as sector,
- entity.length as entity_length,
- substring(hashes, generate_series(0, octet_length(hashes) / 32 - 1, 32), 32) as hash,
- crypt_key
- from (
- entityv2_leaf_hashes elh
- join slab_plus on (
- slab_plus.entity_id = elh.entity_id
- and slab_plus.entity_blocks && elh.block_range
- )
- left outer join public.entity on elh.entity_id = entity.entity_id
+ generate_series(lower(check_erange), upper(check_erange) - 1)
+ - lower(entity_blocks) + lower(disk_blocks)
+ ) as block,
+ substring(
+ hashes,
+							generate_series(lower(check_erange), upper(check_erange) - 1) * 32 + 1,  -- substring is 1-based
+ 32
+ ) as hash,
+ crypt_key
+ from filtered
)
- where slab_plus.disk_id = %(disk_id)s
- order by sector
- """,
- {"block_size": _V2_BLOCK_SIZE, "disk_id": disk_id}
+ select * from exploded
+ where block <@ any (%(block_ranges)s)
+ order by block
+ """,
+ {
+ "block_size": BLOCK_SIZE,
+ "disk_id": disk_id,
+ "block_ranges": block_ranges,
+ }
+ )
+ rows = cursor.fetchall()
+
+ if any(crypt_key is not None for (*_, crypt_key) in rows):
+ raise NotImplementedError("verify of encrypted data")
+
+ return [
+ _V2Run(
+ entity_id = entity_id,
+ entity_length = entity_length,
+ piece_num = piece_num,
+ block_ranges = [range(block, block + 1)],
+ hash = hash_,
)
- for (entity_id, piece_num, sector, entity_len, hash_, crypt_key) in cursor:
- if crypt_key is not None:
- raise NotImplementedError("verify of encrypted data")
-
- read_start = sector * _V2_BLOCK_SIZE
- read_end = read_start + min(_V2_BLOCK_SIZE, entity_len - piece_num * _V2_BLOCK_SIZE)
- disk_file.seek(read_start)
- hasher = hashlib.sha256()
- try:
- while disk_file.tell() < read_end:
- pos = disk_file.tell()
- for _ in range(read_tries):
- try:
- data = disk_file.read(min(read_end - pos, read_size))
- except OSError:
- disk_file.seek(pos)
- else:
- break
- else:
- raise _BadSector()
-
- assert data
- hasher.update(data)
- except _BadSector:
- pass_ = False
- else:
- pass_ = hasher.digest() == hash_
+		for (entity_id, entity_length, piece_num, block, hash_, crypt_key) in rows
+ ]
+
+
+def _do_verify(conn, disk_id: int, block_ranges: Optional[List[range]], disk_file: io.BufferedIOBase, read_size: int, read_tries: int):
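+	"""
+	Verify the requested disk blocks (or the whole disk if block_ranges is None), recording passing block ranges in
+	verify_pass and failed pieces in verify_piece, verify_piece_content, and verify_piece_fail.
+	"""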
+ @dataclass
+ class Pass:
+ blocks: range
- if pass_:
- yield _VerifyPass(sector)
+ @dataclass
+ class Fail:
+ run: _Run
+
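+	# Coalesce adjacent and overlapping passing block ranges into single Pass results; a Fail flushes the current
+	# Pass and is passed through unchanged.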
+ def merge_results(results):
+ curr_pass = None
+ for r in results:
+ if isinstance(r, Pass) and curr_pass and r.blocks.start <= curr_pass.stop:
+ curr_pass = range(curr_pass.start, max(r.blocks.stop, curr_pass.stop))
else:
- yield _VerifyFail(sector, entity_id, piece_num)
+ if curr_pass:
+ yield Pass(curr_pass)
+ curr_pass = None
+ if isinstance(r, Pass):
+ curr_pass = r.blocks
+ else:
+ yield r
+ if curr_pass:
+ yield Pass(curr_pass)
+
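+	# No explicit block request means scan the whole disk; a single unbounded range matches every slab.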
+ if block_ranges is None:
+ pg_block_ranges = [NumericRange()]
+ else:
+ pg_block_ranges = [NumericRange(r.start, r.stop) for r in block_ranges]
+
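+	# Both worklists are already ordered by their first disk block, so a heap merge yields all runs in disk order.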
+ worklist = list(heapq.merge(
+ _get_v1_worklist(conn, disk_id, pg_block_ranges),
+ _get_v2_worklist(conn, disk_id, pg_block_ranges),
+		key = lambda run: run.block_ranges[0].start,
+ ))
+
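+	# Fail fast if some requested blocks aren't covered by any run rather than silently skipping them.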
+ if block_ranges is not None:
+ requested_blocks = {
+ block
+ for r in block_ranges
+ for block in r
+ }
+ covered_blocks = {
+ block
+ for run in worklist
+ for block_range in run.block_ranges
+ for block in block_range
+ }
+ missing = requested_blocks - covered_blocks
+ if missing:
+			raise RuntimeError(f"unable to locate {len(missing)} requested block(s) between {min(missing)} and {max(missing)}")
+
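+	# Hash each run's blocks straight off the disk, yielding a Pass per block range on a match and a Fail for the
+	# whole run otherwise (including unreadable blocks).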
+ def generate_results():
+ for run in worklist:
+ if isinstance(run, _V1Run):
+ hasher = hashlib.sha1()
+ entity_off = run.piece_num * run.piece_length
+ else:
+ hasher = hashlib.sha256()
+ entity_off = run.piece_num * BLOCK_SIZE
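+			# entity_off tracks the position within the entity so the final, possibly partial, block is read short.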
+ try:
+ for range_ in run.block_ranges:
+ for block in range_:
+ read_start = block * BLOCK_SIZE
+ read_end = read_start + min(BLOCK_SIZE, run.entity_length - entity_off)
+ disk_file.seek(read_start)
+ while disk_file.tell() < read_end:
+ pos = disk_file.tell()
+ for _ in range(read_tries):
+ try:
+ data = disk_file.read(min(read_end - pos, read_size))
+ except OSError:
+ disk_file.seek(pos)
+ else:
+ break
+ else:
+ raise _BadSector()
+
+ assert data
+ hasher.update(data)
+ entity_off += BLOCK_SIZE
+ except _BadSector:
+ yield Fail(run)
+ else:
+ if hasher.digest() == run.hash:
+ yield from (Pass(r) for r in run.block_ranges)
+ else:
+ yield Fail(run)
-def do_verify_v2(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int, read_tries: int) -> None:
+ cursor = conn.cursor()
ts = dt.datetime.now(dt.timezone.utc)
- with conn.cursor() as cursor:
- def save_pass_range(r):
+ for result in merge_results(generate_results()):
+ if isinstance(result, Pass):
cursor.execute(
- "insert into diskjumble.verify_pass values (default, %s, %s, %s);",
- (ts, disk_id, NumericRange(r.start, r.stop))
+ "insert into diskjumble.verify_pass values (default, %s, %s, %s)",
+				(ts, disk_id, NumericRange(result.blocks.start, result.blocks.stop))
+ )
+ else:
+ assert isinstance(result, Fail)
+ run = result.run
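+			# Record the failed piece, the disk blocks it covered, and the failure marker in a single statement.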
+ cursor.execute(
+ """
+ with
+ new_piece as (
+ insert into diskjumble.verify_piece
+ values (default, %(ts)s, %(entity_id)s, %(piece_num)s)
+ returning verify_id
+ ),
+ _ as (
+ insert into diskjumble.verify_piece_content
+						select verify_id, seq - 1, %(disk_id)s, block_range
+ from
+ new_piece,
+							unnest(%(ranges)s) with ordinality as r(block_range, seq)
+ )
+ insert into diskjumble.verify_piece_fail
+ select verify_id from new_piece
+ """,
+ {
+ "ts": ts,
+ "entity_id": run.entity_id,
+ "piece_num": run.piece_num,
+ "disk_id": disk_id,
+ "ranges": [NumericRange(r.start, r.stop) for r in run.block_ranges],
+ }
)
-
- pass_sectors = None
- for result in _gen_verify_results(conn, disk_id, disk_file, read_size, read_tries):
- if isinstance(result, _VerifyPass):
- if pass_sectors is None:
- pass_sectors = range(result.sector, result.sector + 1)
- elif result.sector == pass_sectors.stop:
- pass_sectors = range(pass_sectors.start, result.sector + 1)
- else:
- save_pass_range(pass_sectors)
- pass_sectors = range(result.sector, result.sector + 1)
- else:
- assert isinstance(result, _VerifyFail)
- if pass_sectors:
- save_pass_range(pass_sectors)
- pass_sectors = None
-
- cursor.execute(
- """
- with
- new_piece as (
- insert into diskjumble.verify_piece
- values (default, %s, %s, %s)
- returning verify_id
- ),
- _ as (
- insert into diskjumble.verify_piece_content
- select verify_id, 0, %s, %s from new_piece
- )
- insert into diskjumble.verify_piece_fail
- select verify_id from new_piece
- """,
- (ts, result.entity_id, result.piece_num, disk_id, NumericRange(result.sector, result.sector + 1))
- )
- if pass_sectors:
- save_pass_range(pass_sectors)
if __name__ == "__main__":
- def read_tries(raw_arg):
- val = int(raw_arg)
+ def read_tries(raw_val):
+ val = int(raw_val)
if val > 0:
return val
else:
raise ValueError()
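+	# Parse a comma-separated list of blocks and half-open ranges, e.g. "0,2-4" -> [range(0, 1), range(2, 4)].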
+ def block_ranges(raw_val):
+ def parse_one(part):
+ if "-" in part:
+ (s, e) = map(int, part.split("-"))
+ if e <= s:
+ raise ValueError()
+ else:
+ return range(s, e)
+ else:
+ s = int(part)
+ return range(s, s + 1)
+
+ return list(map(parse_one, raw_val.split(",")))
+
parser = argparse.ArgumentParser()
parser.add_argument("disk_id", type = int)
parser.add_argument(
"read_tries",
type = read_tries,
- help = "number of times to attempt a particular disk read before giving up on the sector",
+ help = "number of times to attempt a particular disk read before giving up on the block",
+ )
+ parser.add_argument(
+ "block_ranges",
+ type = block_ranges,
+ nargs = "?",
+ help = "if specified, only verify what's needed to cover these disk blocks (\"0,2-4\" means 0, 2, and 3)",
)
args = parser.parse_args()
+ if args.block_ranges is None:
+ block_ranges = None
+ else:
+ block_ranges = []
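+		# Coalesce overlapping or adjacent requested ranges so each block appears in at most one range.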
+ for r in sorted(args.block_ranges, key = lambda r: r.start):
+ if block_ranges and r.start <= block_ranges[-1].stop:
+ prev = block_ranges.pop()
+ block_ranges.append(range(prev.start, max(prev.stop, r.stop)))
+ else:
+ block_ranges.append(r)
+
with contextlib.closing(psycopg2.connect("")) as conn:
- conn.autocommit = True
path = f"/dev/mapper/diskjumble-{args.disk_id}"
with open(path, "rb", buffering = _READ_BUFFER_SIZE) as disk_file:
- do_verify_v1(conn, args.disk_id, SECTOR_SIZE, disk_file, _READ_BUFFER_SIZE, args.read_tries)
- do_verify_v2(conn, args.disk_id, disk_file, _READ_BUFFER_SIZE, args.read_tries)
+			with conn:
+				_do_verify(conn, args.disk_id, block_ranges, disk_file, _READ_BUFFER_SIZE, args.read_tries)