Verify tool: switch to auto block targeting and multirange output
author	Jakob Cornell <jakob+gpg@jcornell.net>
Wed, 20 Apr 2022 06:02:50 +0000 (01:02 -0500)
committer	Jakob Cornell <jakob+gpg@jcornell.net>
Wed, 20 Apr 2022 06:02:50 +0000 (01:02 -0500)
disk_jumble/src/disk_jumble/verify.py

index 07dc5e9d0fae12b62780764b6e19650e45bd1ba7..93d5ba2ce7ebb71a50cc45f93b90a260a026f19c 100644 (file)
@@ -1,16 +1,15 @@
 from __future__ import annotations
 from dataclasses import dataclass
-from typing import Iterator, List, Optional, Tuple
+from typing import List, Optional
 import argparse
 import contextlib
-import datetime as dt
 import hashlib
 import heapq
 import io
 import itertools
 import math
 
-from psycopg2.extras import execute_batch, NumericRange
+from psycopg2.extras import NumericRange
 import psycopg2
 
 from disk_jumble import bencode, BLOCK_SIZE
@@ -48,10 +47,6 @@ class _Run:
        hash: bytes
 
 
-def _run_sort_key(run: _Run):
-       return run.block_ranges[0].start
-
-
 @dataclass
 class _V1Run(_Run):
        piece_length: int  # for the entity overall
@@ -65,6 +60,35 @@ class _BadSector(Exception):
        pass
 
 
+def _run_sort_key(run: _Run):
+       return run.block_ranges[0].start
+
+
+def _get_target_ranges(conn, disk_id: int, limit: Optional[int]) -> List[NumericRange]:
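+       """Return the disk's written-but-unverified block ranges in ascending order, optionally capped at limit total blocks."""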
+       ranges = []
+       block_count = 0
+       with conn.cursor() as cursor:
+               cursor.execute(
+                       """
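+                               -- ranges of blocks that have been written but not yet verified, ascending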
+                               select unnest(written_map - coalesce(verified_map, int8multirange())) as range
+                               from diskjumble.disk_maps
+                               where disk_id = %s and written_map is not null
+                               order by range
+                       """,
+                       (disk_id,),
+               )
+               for (r,) in cursor:
+                       size = r.upper - r.lower
+                       if limit is not None and block_count + size > limit:
+                               capped_size = limit - block_count
+                               if capped_size:
+                                       ranges.append(NumericRange(r.lower, r.lower + capped_size))
+                               break
+                       else:
+                               ranges.append(r)
+                               block_count += size
+
+       return ranges
+
+
 def _get_v1_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> List[_V1Run]:
        """
        How this works: First, we fetch some info about each torrent on the disk that has data within the requested blocks.
@@ -255,14 +279,6 @@ def _get_v2_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> L
 
 
 def _do_verify(conn, disk_id: int, target_ranges: Optional[List[range]], disk_file: io.BufferedIOBase, read_size: int, read_tries: int):
-       @dataclass
-       class Pass:
-               blocks: range
-
-       @dataclass
-       class Fail:
-               run: _Run
-
        if target_ranges is None:
                pg_target_ranges = [NumericRange()]
        else:
@@ -290,167 +306,116 @@ def _do_verify(conn, disk_id: int, target_ranges: Optional[List[range]], disk_fi
                if missing:
                        raise RuntimeError(f"unable to locate blocks: {len(missing)} in the range {min(missing)} to {max(missing)}")
 
-       def generate_results():
-               for run in worklist:
-                       if isinstance(run, _V1Run):
-                               hasher = hashlib.sha1()
-                               entity_off = run.piece_num * run.piece_length
-                       else:
-                               hasher = hashlib.sha256()
-                               entity_off = run.piece_num * BLOCK_SIZE
-
-                       try:
-                               for range_ in run.block_ranges:
-                                       for block in range_:
-                                               read_start = block * BLOCK_SIZE
-                                               read_end = read_start + min(BLOCK_SIZE, run.entity_length - entity_off)
-                                               disk_file.seek(read_start)
-                                               while disk_file.tell() < read_end:
-                                                       pos = disk_file.tell()
-                                                       for _ in range(read_tries):
-                                                               try:
-                                                                       data = disk_file.read(min(read_end - pos, read_size))
-                                                               except OSError:
-                                                                       disk_file.seek(pos)
-                                                               else:
-                                                                       break
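+       # Accumulate passing and failing block ranges across all runs; the results are
+       # folded into the disk map in a single statement at the end.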
+       passes = []
+       fails = []
+       for run in worklist:
+               if isinstance(run, _V1Run):
+                       hasher = hashlib.sha1()
+                       entity_off = run.piece_num * run.piece_length
+               else:
+                       hasher = hashlib.sha256()
+                       entity_off = run.piece_num * BLOCK_SIZE
+
+               try:
+                       for range_ in run.block_ranges:
+                               for block in range_:
+                                       read_start = block * BLOCK_SIZE
+                                       read_end = read_start + min(BLOCK_SIZE, run.entity_length - entity_off)
+                                       disk_file.seek(read_start)
+                                       while disk_file.tell() < read_end:
+                                               pos = disk_file.tell()
+                                               for _ in range(read_tries):
+                                                       try:
+                                                               data = disk_file.read(min(read_end - pos, read_size))
+                                                       except OSError:
+                                                               disk_file.seek(pos)
                                                        else:
-                                                               raise _BadSector()
-
-                                                       assert data
-                                                       hasher.update(data)
-                                               entity_off += BLOCK_SIZE
-                       except _BadSector:
-                               yield Fail(run)
+                                                               break
+                                               else:
+                                                       raise _BadSector()
+
+                                               assert data
+                                               hasher.update(data)
+                                       entity_off += BLOCK_SIZE
+               except _BadSector:
+                       fails.extend(run.block_ranges)
+               else:
+                       if hasher.digest() == run.hash:
+                               passes.extend(run.block_ranges)
                        else:
-                               if hasher.digest() == run.hash:
-                                       yield from (Pass(r) for r in run.block_ranges)
-                               else:
-                                       yield Fail(run)
-
-       def batch_results(results) -> Iterator[Tuple[List[range], List[_Run]]]:
-               # Group verify results into batches for output to the database
-               passes = []
-               fails = []
-               batch_size = 0
-
-               for r in results:
-                       if batch_size > 500:
-                               yield (passes, fails)
-                               passes = []
-                               fails = []
-                               batch_size = 0
-                       if isinstance(r, Pass):
-                               if passes and r.blocks.start <= passes[-1].stop:
-                                       new_stop = max(r.blocks.stop, passes[-1].stop)
-                                       batch_size += (new_stop - passes[-1].stop)
-                                       passes[-1] = range(passes[-1].start, new_stop)
-                               else:
-                                       passes.append(r.blocks)
-                                       batch_size += len(r.blocks)
+                               fails.extend(run.block_ranges)
+
+       def clean_up(ranges):
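+               # sort the block ranges and merge any that are exactly adjacent, to keep the arrays sent to the database small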
+               out = []
+               for r in sorted(ranges, key = lambda r: r.start):
+                       if out and r.start == out[-1].stop:
+                               out[-1] = range(out[-1].start, r.stop)
                        else:
-                               fails.append(r.run)
-                               batch_size += sum(map(len, r.run.block_ranges))
-               if passes or fails:
-                       yield (passes, fails)
+                               out.append(r)
+               return out
 
-       cursor = conn.cursor()
-       ts = dt.datetime.now(dt.timezone.utc)
-       for (pass_ranges, fail_runs) in batch_results(generate_results()):
-               if pass_ranges:
-                       execute_batch(
-                               cursor,
-                               """
-                                       insert into diskjumble.verify_pass (verify_pass_id, at, disk_id, disk_sectors)
-                                       values (default, %s, %s, %s)
-                               """,
-                               [(ts, disk_id, NumericRange(r.start, r.stop)) for r in pass_ranges],
-                               page_size = len(pass_ranges),
-                       )
-               if fail_runs:
-                       execute_batch(
-                               cursor,
-                               """
-                                       with
-                                               new_piece as (
-                                                       insert into diskjumble.verify_piece (verify_id, at, entity_id, piece)
-                                                       values (default, %(ts)s, %(entity_id)s, %(piece_num)s)
-                                                       returning verify_id
-                                               ),
-                                               _ as (
-                                                       insert into diskjumble.verify_piece_content (verify_id, seq, disk_id, disk_sectors)
-                                                       select verify_id, ordinality - 1, %(disk_id)s, block_range
-                                                       from
-                                                               new_piece,
-                                                               unnest(%(ranges)s::int8range[]) with ordinality as block_range
-                                               )
-                                       insert into diskjumble.verify_piece_fail (verify_id)
-                                       select verify_id from new_piece
-                               """,
-                               [
-                                       {
-                                               "ts": ts,
-                                               "entity_id": run.entity_id,
-                                               "piece_num": run.piece_num,
-                                               "disk_id": disk_id,
-                                               "ranges": [NumericRange(r.start, r.stop) for r in run.block_ranges],
-                                       }
-                                       for run in fail_runs
-                               ],
-                               page_size = len(fail_runs),
-                       )
+       clean_passes = clean_up(passes)
+       clean_fails = clean_up(fails)
+
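+       # Record results on the disk map row: passing blocks are added to verified_map,
+       # and blocks from failed runs are removed from written_map.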
+       conn.cursor().execute(
+               """
+                       with
+                               new_passes as (
+                                       select coalesce(range_agg(range), int8multirange()) as new_passes
+                                       from unnest(%(pass_ranges)s::int8range[]) as range
+                               ),
+                               new_fails as (
+                                       select coalesce(range_agg(range), int8multirange()) as new_fails
+                                       from unnest(%(fail_ranges)s::int8range[]) as range
+                               )
+                       update diskjumble.disk_maps
+                       set
+                               verified_map = coalesce(verified_map, int8multirange()) + new_passes,
+                               written_map = written_map - new_fails
+                       from new_passes, new_fails
+                       where disk_id = %(disk_id)s
+               """,
+               {
+                       "pass_ranges": [NumericRange(r.start, r.stop) for r in clean_passes],
+                       "fail_ranges": [NumericRange(r.start, r.stop) for r in clean_fails],
+                       "disk_id": disk_id,
+               },
+       )
 
 
 if __name__ == "__main__":
-       def read_tries(raw_val):
+       def pos_int(raw_val):
                val = int(raw_val)
                if val > 0:
                        return val
                else:
                        raise ValueError()
 
-       def block_ranges(raw_val):
-               def parse_one(part):
-                       if "-" in part:
-                               (s, e) = map(int, part.split("-"))
-                               if e <= s:
-                                       raise ValueError()
-                               else:
-                                       return range(s, e)
-                       else:
-                               s = int(part)
-                               return range(s, s + 1)
-
-               return list(map(parse_one, raw_val.split(",")))
+       def nonneg_int(raw_val):
+               val = int(raw_val)
+               if val >= 0:
+                       return val
+               else:
+                       raise ValueError()
 
        parser = argparse.ArgumentParser()
        parser.add_argument("disk_id", type = int)
        parser.add_argument(
                "read_tries",
-               type = read_tries,
+               type = pos_int,
                help = "number of times to attempt a particular disk read before giving up on the block",
        )
        parser.add_argument(
-               "block_ranges",
-               type = block_ranges,
+               "block_limit",
+               type = nonneg_int,
                nargs = "?",
-               help = "if specified, only verify what's needed to cover these disk blocks (\"0,2-4\" means 0, 2, and 3)",
+               help = "if specified, only target this many eligible blocks",
        )
        args = parser.parse_args()
 
-       if args.block_ranges is None:
-               target_ranges = None
-       else:
-               target_ranges = []
-               for r in sorted(args.block_ranges, key = lambda r: r.start):
-                       if target_ranges and r.start <= target_ranges[-1].stop:
-                               prev = target_ranges.pop()
-                               target_ranges.append(range(prev.start, max(prev.stop, r.stop)))
-                       else:
-                               target_ranges.append(r)
-
+       path = f"/dev/mapper/diskjumble-{args.disk_id}"
        with contextlib.closing(psycopg2.connect("")) as conn:
-               path = f"/dev/mapper/diskjumble-{args.disk_id}"
+               conn.autocommit = True
+               target_ranges = _get_target_ranges(conn, args.disk_id, args.block_limit)
                with open(path, "rb", buffering = _READ_BUFFER_SIZE) as disk_file:
-                       with conn:
-                               _do_verify(conn, args.disk_id, target_ranges, disk_file, _READ_BUFFER_SIZE, args.read_tries)
+                       _do_verify(conn, args.disk_id, target_ranges, disk_file, _READ_BUFFER_SIZE, args.read_tries)
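
For reference, a standalone sketch of the block_limit capping that _get_target_ranges applies to the written-but-unverified ranges it reads from diskjumble.disk_maps. It operates on plain Python ranges rather than database rows, and cap_ranges is a hypothetical name used only for illustration:

    from typing import List, Optional

    def cap_ranges(unverified: List[range], limit: Optional[int]) -> List[range]:
        # ranges arrive sorted by starting block, as in the disk_maps query
        out: List[range] = []
        block_count = 0
        for r in unverified:
            size = len(r)
            if limit is not None and block_count + size > limit:
                # truncate the final range so the total number of targeted blocks equals the limit
                capped = limit - block_count
                if capped:
                    out.append(range(r.start, r.start + capped))
                break
            out.append(r)
            block_count += size
        return out

    # e.g. cap_ranges([range(0, 10), range(20, 30)], 15) -> [range(0, 10), range(20, 25)]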