from __future__ import annotations
-from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
-from typing import Optional
+from typing import Iterable, Optional
import argparse
import contextlib
import datetime as dt
+import hashlib
import io
import itertools
import math
+from psycopg2.extras import NumericRange
import psycopg2
from disk_jumble import bencode
from disk_jumble.nettle import Sha1Hasher
+_V2_BLOCK_SIZE = 16 * 1024 # in bytes
+
_READ_BUFFER_SIZE = 16 * 1024 ** 2 # in bytes
for _ in range(read_tries):
try:
data = disk_file.read(min(end_pos - pos, read_size))
- except OSError as e:
+ except OSError:
disk_file.seek(pos)
else:
break
db.mark_verify_piece_failed(verify_id)
+@dataclass
+class _VerifyResult:
+	"""Base class for the outcome of verifying one disk sector."""
+	sector: int  # disk sector (v2 block) number that was checked
+
+
+@dataclass
+class _VerifyPass(_VerifyResult):
+	"""Sector content matched the expected piece hash."""
+	pass
+
+
+@dataclass
+class _VerifyFail(_VerifyResult):
+	"""Sector failed verification (hash mismatch or unreadable sector)."""
+	entity_id: bytes  # entity whose piece occupies this sector
+	piece_num: int  # zero-based piece index within that entity
+
+
+def _gen_verify_results(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int, read_tries: int) -> Iterable[_VerifyResult]:
+ with conn.cursor() as cursor:
+ cursor.execute(
+ """
+ with
+ slab_plus as (
+ select
+ *,
+ int8range(
+ slab.entity_offset / %(block_size)s,
+ salb.entity_offset / %(block_size)s + upper(slab.disk_blocks) - lower(slab.disk_blocks)
+ ) as entity_blocks
+ from diskjumble.slab
+ )
+ select
+ entity_id,
+ generate_series(
+ lower(slab_plus.entity_blocks * elh.block_range),
+ upper(slab_plus.entity_blocks * elh.block_range) - 1
+ ) as piece_num,
+ generate_series(
+ (
+ lower(slab_plus.entity_blocks * elh.block_range)
+ - lower(slab_plus.entity_blocks)
+ + lower(slab_plus.disk_blocks)
+ ),
+ (
+ upper(slab_plus.entity_blocks * elh.block_range)
+ - lower(slab_plus.entity_blocks)
+ + lower(slab_plus.disk_blocks)
+ - 1
+ )
+ ) as sector,
+ entity.length as entity_length
+ substring(hashes, generate_series(0, octet_length(hashes) / 32 - 1, 32), 32) as hash
+ from (
+ entityv2_leaf_hashes elh
+ join slab_plus on (
+ slab_plus.entity_id = elh.entity_id
+ and slab_plus.entity_blocks && elh.block_range
+ )
+ left outer join public.entity using (entity_id)
+ )
+ where slab_plus.disk_id = %(disk_id)s
+ order by sector
+ """,
+ {"block_size": _V2_BLOCK_SIZE, "disk_id": disk_id}
+ )
+ for (entity_id, piece_num, sector, entity_len, hash_) in cursor:
+ read_start = sector * _V2_BLOCK_SIZE
+ read_end = read_start + min(_V2_BLOCK_SIZE, entity_len - piece_num * _V2_BLOCK_SIZE)
+ disk_file.seek(read_start)
+ hasher = hashlib.sha256()
+ try:
+ while disk_file.tell() < read_end:
+ pos = disk_file.tell()
+ for _ in range(read_tries):
+ try:
+ data = disk_file.read(min(read_end - pos, read_size))
+ except OSError:
+ disk_file.seek(pos)
+ else:
+ break
+ else:
+ raise _BadSector()
+
+ assert data
+ hasher.update(data)
+ except _BadSector:
+ pass_ = False
+ else:
+ pass_ = hasher.digest() == hash_
+
+ if pass_:
+ yield _VerifyPass(sector)
+ else:
+ yield _VerifyFail(sector, entity_id, piece_num)
+
+
+def do_verify_v2(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int, read_tries: int) -> None:
+ ts = dt.datetime.now(dt.timezone.utc)
+ with conn.cursor() as cursor:
+ def save_pass_range(r):
+ cursor.execute(
+ "insert into diskjumble.verify_pass values (default, %s, %s, %s);",
+ (ts, disk_id, NumericRange(r.start, r.stop))
+ )
+
+ pass_sectors = None
+ for result in _gen_verify_results(conn, disk_id, disk_file, read_size, read_tries):
+ if isinstance(result, _VerifyPass):
+ if pass_sectors is None:
+ pass_sectors = range(result.sector, result.sector + 1)
+ elif result.sector == pass_sectors.stop:
+ pass_sectors = range(pass_sectors.start, result.sector + 1)
+ else:
+ save_pass_range(pass_sectors)
+ pass_sectors = range(result.sector, result.sector + 1)
+ else:
+ assert isinstance(result, _VerifyFail)
+ if pass_sectors:
+ save_pass_range(pass_sectors)
+ pass_sectors = None
+
+ cursor.execute(
+ """
+ with
+ new_piece as (
+ insert into diskjumble.verify_piece
+ values (default, %s, %s, %s)
+ returning verify_id
+ ),
+ _ as (
+ insert into diskjumble.verify_piece_content
+ values (new_piece.verify_id, 0, %s, %s)
+ )
+ insert into diskjumble.verify_piece_fail
+ select verify_id from new_piece
+ """,
+ (ts, result.entity_id, result.piece_num, disk_id, NumericRange(result.sector, result.sector + 1))
+ )
+ if pass_sectors:
+ save_pass_range(pass_sectors)
+
+
if __name__ == "__main__":
def read_tries(raw_arg):
val = int(raw_arg)
raise ValueError()
parser = argparse.ArgumentParser()
+ parser.add_argument("entity_version", choices = ["1", "2"])
parser.add_argument("disk_id", type = int)
parser.add_argument(
"read_tries",
conn.autocommit = True
path = f"/dev/mapper/diskjumble-{args.disk_id}"
with open(path, "rb", buffering = _READ_BUFFER_SIZE) as disk_file:
- do_verify(conn, args.disk_id, disk_file, _READ_BUFFER_SIZE, args.read_tries)
+ verify_func = do_verify_v2 if args.entity_version == "2" else do_verify
+ verify_func(conn, args.disk_id, disk_file, _READ_BUFFER_SIZE, args.read_tries)