from __future__ import annotations
from dataclasses import dataclass
-from typing import List, Optional
+from typing import Iterator, List, Optional
+from warnings import warn
import argparse
import contextlib
import hashlib
import heapq
import io
from psycopg2.extras import NumericRange
def _get_target_ranges(conn, disk_id: int, limit: Optional[int]) -> List[range]:
ranges = []
block_count = 0
- with conn.cursor() as cursor:
+ with conn:
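+ # a named cursor is server-side: psycopg2 streams rows in batches instead of fetching the whole result at once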
+ cursor = conn.cursor("target_ranges")
cursor.execute(
"""
select unnest(written_map - coalesce(verified_map, int8multirange())) as range
The runs are then sorted by the number of their first disk block.
"""
- cursor = conn.cursor()
+ cursor = conn.cursor("v1_worklist")
cursor.execute(
"""
select distinct on (entity_id)
return runs
-def _get_v2_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> List[_V2Run]:
- cursor = conn.cursor()
+def _get_v2_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> Iterator[_V2Run]:
+ cursor = conn.cursor("v2_worklist")
cursor.execute(
"""
with
"target_ranges": target_ranges,
}
)
- rows = cursor.fetchall()
-
- if any(crypt_key is not None for (*_, crypt_key) in rows):
- raise NotImplementedError("verify of encrypted data")
-
- return [
- _V2Run(
- entity_id = bytes(entity_id),
- entity_length = entity_length,
- piece_num = piece_num,
- block_ranges = [range(block, block + 1)],
- hash = bytes(hash_),
- )
- for (entity_id, entity_length, piece_num, block, hash_, _) in rows
- ]
+ for (entity_id, entity_length, piece_num, block, hash_, crypt_key) in cursor:
+ if crypt_key is None:
+ yield _V2Run(
+ entity_id = bytes(entity_id),
+ entity_length = entity_length,
+ piece_num = piece_num,
+ block_ranges = [range(block, block + 1)],
+ hash = bytes(hash_),
+ )
+ else:
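+ # encrypted entities cannot be verified yet; fail loudly rather than silently skipping them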
+ raise NotImplementedError("verify of encrypted data")
def _do_verify(conn, disk_id: int, target_ranges: List[range], disk_file: io.BufferedIOBase, read_size: int, read_tries: int):
- pg_target_ranges = [NumericRange(r.start, r.stop) for r in target_ranges]
- worklist = list(heapq.merge(
- _get_v1_worklist(conn, disk_id, pg_target_ranges),
- _get_v2_worklist(conn, disk_id, pg_target_ranges),
- key = _run_sort_key,
- ))
requested_blocks = {
block
for r in target_ranges
for block in r
}
- covered_blocks = {
- block
- for run in worklist
- for block_range in run.block_ranges
- for block in block_range
- }
- missing = requested_blocks - covered_blocks
- if missing:
- raise RuntimeError(f"unable to locate blocks: {len(missing)} in the range {min(missing)} to {max(missing)}")
-
- passes = []
- fails = []
- for run in worklist:
- if isinstance(run, _V1Run):
- hasher = hashlib.sha1()
- entity_off = run.piece_num * run.piece_length
- else:
- hasher = hashlib.sha256()
- entity_off = run.piece_num * BLOCK_SIZE
- assert len(run.hash) == hasher.digest_size, "incorrect validation hash length"
-
- try:
- for range_ in run.block_ranges:
- for block in range_:
- read_start = block * BLOCK_SIZE
- read_end = read_start + min(BLOCK_SIZE, run.entity_length - entity_off)
- disk_file.seek(read_start)
- while disk_file.tell() < read_end:
- pos = disk_file.tell()
- for _ in range(read_tries):
- try:
- data = disk_file.read(min(read_end - pos, read_size))
- except OSError:
- disk_file.seek(pos)
- else:
- break
- else:
- raise _BadSector()
-
- assert data
- hasher.update(data)
- entity_off += BLOCK_SIZE
- except _BadSector:
- fails.extend(run.block_ranges)
- else:
- if hasher.digest() == run.hash:
- passes.extend(run.block_ranges)
+ pg_target_ranges = [NumericRange(r.start, r.stop) for r in target_ranges]
+
+ # the named cursors behind both worklists live only as long as a transaction, so everything that consumes them stays inside this block
+ with conn:
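+ # heapq.merge is lazy: runs stream out in _run_sort_key order without materializing the merged worklist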
+ worklist = heapq.merge(
+ _get_v1_worklist(conn, disk_id, pg_target_ranges),
+ _get_v2_worklist(conn, disk_id, pg_target_ranges),
+ key = _run_sort_key,
+ )
+
+ covered_blocks = set()
+ passes = []
+ fails = []
+ for run in worklist:
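+ # v1 runs hash whole pieces with SHA-1; v2 runs hash single blocks with SHA-256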
+ if isinstance(run, _V1Run):
+ hasher = hashlib.sha1()
+ entity_off = run.piece_num * run.piece_length
else:
+ hasher = hashlib.sha256()
+ entity_off = run.piece_num * BLOCK_SIZE
+ assert len(run.hash) == hasher.digest_size, "incorrect validation hash length"
+
+ try:
+ for range_ in run.block_ranges:
+ for block in range_:
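+ # the final block of an entity may be partial, so clamp the read to entity_length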
+ read_start = block * BLOCK_SIZE
+ read_end = read_start + min(BLOCK_SIZE, run.entity_length - entity_off)
+ disk_file.seek(read_start)
+ while disk_file.tell() < read_end:
+ pos = disk_file.tell()
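+ # retry transient read errors up to read_tries times, seeking back to pos after each failure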
+ for _ in range(read_tries):
+ try:
+ data = disk_file.read(min(read_end - pos, read_size))
+ except OSError:
+ disk_file.seek(pos)
+ else:
+ break
+ else:
+ raise _BadSector()
+
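+ # an empty read means EOF before read_end; the assert prevents an infinite loop on a truncated disk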
+ assert data
+ hasher.update(data)
+ entity_off += BLOCK_SIZE
+ except _BadSector:
fails.extend(run.block_ranges)
+ else:
+ if hasher.digest() == run.hash:
+ passes.extend(run.block_ranges)
+ else:
+ fails.extend(run.block_ranges)
+
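+ # track which blocks this run covered so any gaps can be reported afterwards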
+ covered_blocks.update(
+ block
+ for block_range in run.block_ranges
+ for block in block_range
+ )
+
+ missing = requested_blocks - covered_blocks
+ if missing:
+ warn(f"unable to locate blocks: {len(missing)} in the range {min(missing)} to {max(missing)}")
def clean_up(ranges):
out = []