Implement v2 verify and bump version, misc. cleanup
authorJakob Cornell <jakob+gpg@jcornell.net>
Fri, 4 Feb 2022 01:18:30 +0000 (19:18 -0600)
committerJakob Cornell <jakob+gpg@jcornell.net>
Fri, 4 Feb 2022 01:18:30 +0000 (19:18 -0600)
disk_jumble/setup.cfg
disk_jumble/src/disk_jumble/verify.py

index 3fc8d704801a298d8390366f49b2b0767e782023..bd8288f5b517871b3daad0c64ea8ac930eac3d4a 100644 (file)
@@ -1,6 +1,6 @@
 [metadata]
 name = disk_jumble
-version = 0.0.1
+version = 0.0.2
 
 [options]
 package_dir =
index 0adbb037cce77119461543342c4e612bd24d2282..f644d9a6427ed3ef623ac3e44a8aa2341437d074 100644 (file)
@@ -1,14 +1,15 @@
 from __future__ import annotations
-from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
-from typing import Optional
+from typing import Iterable, Optional
 import argparse
 import contextlib
 import datetime as dt
+import hashlib
 import io
 import itertools
 import math
 
+from psycopg2.extras import NumericRange
 import psycopg2
 
 from disk_jumble import bencode
@@ -16,6 +17,8 @@ from disk_jumble.db import HasherRef, Slab, Wrapper as DbWrapper
 from disk_jumble.nettle import Sha1Hasher
 
 
+_V2_BLOCK_SIZE = 16 * 1024  # in bytes
+
 _READ_BUFFER_SIZE = 16 * 1024 ** 2  # in bytes
 
 
@@ -139,7 +142,7 @@ def do_verify(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int,
                                        for _ in range(read_tries):
                                                try:
                                                        data = disk_file.read(min(end_pos - pos, read_size))
-                                               except OSError as e:
+                                               except OSError:
                                                        disk_file.seek(pos)
                                                else:
                                                        break
@@ -213,6 +216,147 @@ def do_verify(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int,
                db.mark_verify_piece_failed(verify_id)
 
 
+@dataclass
+class _VerifyResult:
+       sector: int
+
+
+@dataclass
+class _VerifyPass(_VerifyResult):
+       pass
+
+
+@dataclass
+class _VerifyFail(_VerifyResult):
+       entity_id: bytes
+       piece_num: int
+
+
+def _gen_verify_results(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int, read_tries: int) -> Iterable[_VerifyResult]:
+       with conn.cursor() as cursor:
+               cursor.execute(
+                       """
+                               with
+                                       slab_plus as (
+                                               select
+                                                       *,
+                                                       int8range(
+                                                               slab.entity_offset / %(block_size)s,
+                                                               salb.entity_offset / %(block_size)s + upper(slab.disk_blocks) - lower(slab.disk_blocks)
+                                                       ) as entity_blocks
+                                               from diskjumble.slab
+                                       )
+                               select
+                                       entity_id,
+                                       generate_series(
+                                               lower(slab_plus.entity_blocks * elh.block_range),
+                                               upper(slab_plus.entity_blocks * elh.block_range) - 1
+                                       ) as piece_num,
+                                       generate_series(
+                                               (
+                                                       lower(slab_plus.entity_blocks * elh.block_range)
+                                                       - lower(slab_plus.entity_blocks)
+                                                       + lower(slab_plus.disk_blocks)
+                                               ),
+                                               (
+                                                       upper(slab_plus.entity_blocks * elh.block_range)
+                                                       - lower(slab_plus.entity_blocks)
+                                                       + lower(slab_plus.disk_blocks)
+                                                       - 1
+                                               )
+                                       ) as sector,
+                                       entity.length as entity_length
+                                       substring(hashes, generate_series(0, octet_length(hashes) / 32 - 1, 32), 32) as hash
+                               from (
+                                       entityv2_leaf_hashes elh
+                                       join slab_plus on (
+                                               slab_plus.entity_id = elh.entity_id
+                                               and slab_plus.entity_blocks && elh.block_range
+                                       )
+                                       left outer join public.entity using (entity_id)
+                               )
+                               where slab_plus.disk_id = %(disk_id)s
+                               order by sector
+                       """,
+                       {"block_size": _V2_BLOCK_SIZE, "disk_id": disk_id}
+               )
+               for (entity_id, piece_num, sector, entity_len, hash_) in cursor:
+                       read_start = sector * _V2_BLOCK_SIZE
+                       read_end = read_start + min(_V2_BLOCK_SIZE, entity_len - piece_num * _V2_BLOCK_SIZE)
+                       disk_file.seek(read_start)
+                       hasher = hashlib.sha256()
+                       try:
+                               while disk_file.tell() < read_end:
+                                       pos = disk_file.tell()
+                                       for _ in range(read_tries):
+                                               try:
+                                                       data = disk_file.read(min(read_end - pos, read_size))
+                                               except OSError:
+                                                       disk_file.seek(pos)
+                                               else:
+                                                       break
+                                       else:
+                                               raise _BadSector()
+
+                                       assert data
+                                       hasher.update(data)
+                       except _BadSector:
+                               pass_ = False
+                       else:
+                               pass_ = hasher.digest() == hash_
+
+                       if pass_:
+                               yield _VerifyPass(sector)
+                       else:
+                               yield _VerifyFail(sector, entity_id, piece_num)
+
+
+def do_verify_v2(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int, read_tries: int) -> None:
+       ts = dt.datetime.now(dt.timezone.utc)
+       with conn.cursor() as cursor:
+               def save_pass_range(r):
+                       cursor.execute(
+                               "insert into diskjumble.verify_pass values (default, %s, %s, %s);",
+                               (ts, disk_id, NumericRange(r.start, r.stop))
+                       )
+
+               pass_sectors = None
+               for result in _gen_verify_results(conn, disk_id, disk_file, read_size, read_tries):
+                       if isinstance(result, _VerifyPass):
+                               if pass_sectors is None:
+                                       pass_sectors = range(result.sector, result.sector + 1)
+                               elif result.sector == pass_sectors.stop:
+                                       pass_sectors = range(pass_sectors.start, result.sector + 1)
+                               else:
+                                       save_pass_range(pass_sectors)
+                                       pass_sectors = range(result.sector, result.sector + 1)
+                       else:
+                               assert isinstance(result, _VerifyFail)
+                               if pass_sectors:
+                                       save_pass_range(pass_sectors)
+                                       pass_sectors = None
+
+                               cursor.execute(
+                                       """
+                                               with
+                                                       new_piece as (
+                                                               insert into diskjumble.verify_piece
+                                                               values (default, %s, %s, %s)
+                                                               returning verify_id
+                                                       ),
+                                                       _ as (
+                                                               insert into diskjumble.verify_piece_content
+                                                               values (new_piece.verify_id, 0, %s, %s)
+                                                       )
+                                               insert into diskjumble.verify_piece_fail
+                                               select verify_id from new_piece
+                                       """,
+                                       (ts, result.entity_id, result.piece_num, disk_id, NumericRange(result.sector, result.sector + 1))
+                               )
+               if pass_sectors:
+                       save_pass_range(pass_sectors)
+
+
 if __name__ == "__main__":
        def read_tries(raw_arg):
                val = int(raw_arg)
@@ -222,6 +366,7 @@ if __name__ == "__main__":
                        raise ValueError()
 
        parser = argparse.ArgumentParser()
+       parser.add_argument("entity_version", choices = ["1", "2"])
        parser.add_argument("disk_id", type = int)
        parser.add_argument(
                "read_tries",
@@ -234,4 +379,5 @@ if __name__ == "__main__":
                conn.autocommit = True
                path = f"/dev/mapper/diskjumble-{args.disk_id}"
                with open(path, "rb", buffering = _READ_BUFFER_SIZE) as disk_file:
-                       do_verify(conn, args.disk_id, disk_file, _READ_BUFFER_SIZE, args.read_tries)
+                       verify_func = do_verify_v2 if args.entity_version == "2" else do_verify
+                       verify_func(conn, args.disk_id, disk_file, _READ_BUFFER_SIZE, args.read_tries)