Add verify program (untested)
author     Jakob Cornell <jakob+gpg@jcornell.net>
           Thu, 4 Nov 2021 15:41:28 +0000 (10:41 -0500)
committer  Jakob Cornell <jakob+gpg@jcornell.net>
           Thu, 4 Nov 2021 15:41:28 +0000 (10:41 -0500)
src/disk_jumble/db.py
src/disk_jumble/nettle.py [new file with mode: 0644]
src/disk_jumble/verify.py [new file with mode: 0644]

index cb8c7fd461e29bb4a6939dbff7832c5c8d0adcce..9e881080ab1ce8ec60e36bece8a5a2b4804480e3 100644 (file)
@@ -1,6 +1,32 @@
-from typing import Iterable, Optional
+from dataclasses import dataclass
+from typing import Iterable, Optional
+import datetime as dt
+import itertools
 
-from psycopg2.extras import execute_batch
+from psycopg2.extras import execute_batch, Json, NumericRange
+
+
+@dataclass
+class Disk:
+       sector_size: int
+
+
+@dataclass
+class Slab:
+       id: int
+       disk_id: int
+       sectors: range
+       entity_id: bytes
+       entity_offset: int
+       crypt_key: bytes
+
+
+@dataclass
+class HasherRef:
+       id: int
+       seq: int
+       entity_offset: int
+       state: dict
 
 
 class Wrapper:
@@ -12,17 +38,18 @@ class Wrapper:
                with self.conn.cursor() as cursor:
                        cursor.execute("select passkey from gazelle.passkey where gazelle_tracker_id = %s;", (tracker_id,))
                        [(passkey,)] = cursor.fetchall()
+
                return passkey
 
        def get_torrents(self, tracker_id: int, batch_size: Optional[int] = None) -> Iterable[bytes]:
                """Iterate the info hashes for the specified tracker which haven't been marked deleted."""
 
-               stmt = (
-                       "select infohash from gazelle.torrent"
-                       + " where gazelle_tracker_id = %s and not is_deleted"
-                       + " order by infohash asc"
-                       + ";"
-               )
+               stmt = """
+                       select infohash from gazelle.torrent
+                       where gazelle_tracker_id = %s and not is_deleted
+                       order by infohash asc
+                       ;
+               """
                with self.conn.cursor() as cursor:
                        if batch_size is not None:
                                cursor.itersize = batch_size
@@ -34,14 +61,164 @@ class Wrapper:
                                yield bytes(info_hash)
 
        def insert_swarm_info(self, tracker_id: int, infos: Iterable["disk_jumble.scrape.ScrapeInfo"]) -> None:
-               stmt = (
-                       "insert into gazelle.tracker_stat (gazelle_tracker_id, infohash, ts, complete, incomplete, downloaded)"
-                       + " values (%s, %s, %s, %s, %s, %s)"
-                       + ";"
-               )
+               stmt = """
+                       insert into gazelle.tracker_stat (gazelle_tracker_id, infohash, ts, complete, incomplete, downloaded)
+                       values (%s, %s, %s, %s, %s, %s)
+                       ;
+               """
                with self.conn.cursor() as cursor:
                        param_sets = [
                                (tracker_id, i.info_hash, i.timestamp, i.complete, i.incomplete, i.downloaded)
                                for i in infos
                        ]
                        execute_batch(cursor, stmt, param_sets)
+
+       def get_disk(self, id_: int) -> Disk:
+               with self.conn.cursor() as cursor:
+                       cursor.execute("select sector_size from diskjumble.disk where disk_id = %s;", (id_,))
+                       [(sector_size,)] = cursor.fetchall()
+
+               return Disk(sector_size)
+
+       def get_slabs_and_hashers(self, disk_id: int) -> Iterable[tuple[Slab, Optional[HasherRef]]]:
+               """
+               Yield each slab on the specified disk along with the lowest-offset saved hasher (if any) that left off
+               immediately before or within the slab's entity data.
+               """
+
+               stmt = """
+                       with
+                               incomplete_edge as (
+                                       -- join up incomplete piece info and precompute where the hasher left off within the entity
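+                                       -- end_off: entity offset just past the last sector already hashed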
+                                       select
+                                               verify_id, seq, p.entity_id, hasher_state,
+                                               entity_offset + (upper(c.disk_sectors) - lower(slab.disk_sectors)) * sector_size as end_off
+                                       from
+                                               diskjumble.verify_piece_incomplete
+                                               natural left join diskjumble.verify_piece p
+                                               natural join diskjumble.verify_piece_content c
+                                               natural left join diskjumble.disk
+                                               left join diskjumble.slab on (
+                                                       c.disk_id = slab.disk_id
+                                                       and upper(c.disk_sectors) <@ slab.disk_sectors
+                                               )
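+                                       -- keep only the highest-seq (most recent) content row of each incomplete piece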
+                                       where seq >= all (select seq from diskjumble.verify_piece_content where verify_id = p.verify_id)
+                               )
+                       select
+                               slab_id, disk_id, disk_sectors, slab.entity_id, entity_offset, crypt_key, verify_id, seq, end_off,
+                               hasher_state
+                       from
+                               diskjumble.slab
+                               natural left join diskjumble.disk
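+                               -- attach a saved hasher only if it stopped on a sector boundary at or inside this slab's entity range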
+                               left join incomplete_edge on
+                                       incomplete_edge.entity_id = slab.entity_id
+                                       and incomplete_edge.end_off %% sector_size = 0  -- %% escapes the modulo operator for psycopg2
+                                       and incomplete_edge.end_off <@ int8range(
+                                               slab.entity_offset,
+                                               slab.entity_offset + (upper(disk_sectors) - lower(disk_sectors)) * sector_size
+                                       )
+                       where disk_id = %s
+                       order by entity_id, entity_offset, slab_id
+                       ;
+               """
+               with self.conn.cursor() as cursor:
+                       cursor.execute(stmt, (disk_id,))
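+                       # rows are grouped per slab; several rows for one slab mean several candidate saved hashers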
+                       for (_, rows_iter) in itertools.groupby(cursor, lambda r: r[0]):
+                               rows = list(rows_iter)
+                               [(slab_id, disk_id, sectors_pg, entity_id, entity_off, key)] = {r[:6] for r in rows}
+                               sectors = range(sectors_pg.lower, sectors_pg.upper)
+                               slab = Slab(slab_id, disk_id, sectors, entity_id, entity_off, key)
+
+                               # `None' if no hasher match in outer join, otherwise earliest match
+                               (*_, id_, seq, end_off, state) = min(rows, key = lambda r: r[-2])
+                               hasher_ref = None if id_ is None else HasherRef(id_, seq, end_off, state)
+
+                               yield (slab, hasher_ref)
+
+       def get_torrent_info(self, disk_id: int) -> Iterable[tuple[bytes, bytes]]:
+               stmt = """
+                       with hashed as (
+                               select digest(info, 'sha1') as info_hash, info
+                               from bittorrent.torrent_info
+                       )
+                       select
+                               distinct on (info_hash)
+                               info_hash, info
+                               from diskjumble.slab left outer join hashed on entity_id = info_hash
+                               where disk_id = %s
+                       ;
+               """
+               with self.conn.cursor() as cursor:
+                       cursor.execute(stmt, (disk_id,))
+                       for (info_hash, info) in cursor:
+                               yield (info_hash, info)
+
+       def insert_verify_piece(self, ts: dt.datetime, entity_id: bytes, piece_num: int) -> int:
+               """Insert new verify piece, returning the ID of the inserted row."""
+
+               with self.conn.cursor() as cursor:
+                       stmt = "insert into diskjumble.verify_piece values (default, %s, %s, %s) returning verify_id;"
+                       cursor.execute(stmt, (ts, entity_id, piece_num))
+                       [(row_id,)] = cursor.fetchall()
+               return row_id
+
+       def insert_verify_piece_content(self, verify_id: int, seq_start: int, disk_id: int, ranges: Iterable[range]) -> None:
+               with self.conn.cursor() as cursor:
+                       execute_batch(
+                               cursor,
+                               "insert into diskjumble.verify_piece_content values (%s, %s, %s, %s);",
+                               [
+                                       (verify_id, seq, disk_id, NumericRange(r.start, r.stop))
+                                       for (seq, r) in enumerate(ranges, start = seq_start)
+                               ]
+                       )
+
+       def mark_verify_piece_failed(self, verify_id: int) -> None:
+               with self.conn.cursor() as cursor:
+                       cursor.execute("insert into diskjumble.verify_piece_fail values (%s);", (verify_id,))
+
+       def upsert_hasher_state(self, verify_id: int, state: dict) -> None:
+               stmt = """
+                       insert into diskjumble.verify_piece_incomplete values (%s, %s)
+                       on conflict (verify_id) do update set hasher_state = excluded.hasher_state
+                       ;
+               """
+
+               with self.conn.cursor() as cursor:
+                       cursor.execute(stmt, (verify_id, Json(state)))
+
+       def delete_verify_piece(self, verify_id: int) -> None:
+               with self.conn.cursor() as cursor:
+                       cursor.execute("delete from diskjumble.verify_piece_incomplete where verify_id = %s;", (verify_id,))
+                       cursor.execute("delete from diskjumble.verify_piece_content where verify_id = %s;", (verify_id,))
+                       cursor.execute("delete from diskjumble.verify_piece where verify_id = %s;", (verify_id,))
+
+       def move_piece_content_for_pass(self, verify_id: int) -> None:
+               stmt = """
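+                       -- move the piece's verified sector ranges out of verify_piece_content and into verify_pass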
+                       with content_out as (
+                               delete from diskjumble.verify_piece_content c
+                               using diskjumble.verify_piece p
+                               where (
+                                       c.verify_id = p.verify_id
+                                       and p.verify_id = %s
+                               )
+                               returning at, disk_id, disk_sectors
+                       )
+                       insert into diskjumble.verify_pass (at, disk_id, disk_sectors)
+                       select at, disk_id, disk_sectors from content_out
+                       ;
+               """
+
+               with self.conn.cursor() as cursor:
+                       cursor.execute(stmt, (verify_id,))
+
+       def insert_pass_data(self, ts: dt.datetime, disk_id: int, sectors: range) -> None:
+               with self.conn.cursor() as cursor:
+                       cursor.execute(
+                               "insert into diskjumble.verify_pass values (default, %s, %s, %s);",
+                               (ts, disk_id, NumericRange(sectors.start, sectors.stop))
+                       )
+
+       def clear_incomplete(self, verify_id: int) -> None:
+               with self.conn.cursor() as cursor:
+                       cursor.execute("delete from diskjumble.verify_piece_incomplete where verify_id = %s;", (verify_id,))
diff --git a/src/disk_jumble/nettle.py b/src/disk_jumble/nettle.py
new file mode 100644 (file)
index 0000000..cd73b1a
--- /dev/null
@@ -0,0 +1,65 @@
+"""Python wrappers for some of GnuTLS Nettle."""
+
+from typing import Optional
+import ctypes
+import ctypes.util
+
+
+_LIB = ctypes.CDLL(ctypes.util.find_library("nettle"))
+
+
+class _Sha1Defs:
+       _DIGEST_SIZE = 20  # in bytes
+       _BLOCK_SIZE = 64  # in bytes
+       _DIGEST_LENGTH = 5  # state words (32-bit each)
+
+       _StateArr = ctypes.c_uint32 * _DIGEST_LENGTH
+
+       _BlockArr = ctypes.c_uint8 * _BLOCK_SIZE
+
+
+class Sha1Hasher(_Sha1Defs):
+       class Context(ctypes.Structure):
+               _fields_ = [
+                       ("state", _Sha1Defs._StateArr),
+                       ("count", ctypes.c_uint64),
+                       ("index", ctypes.c_uint),
+                       ("block", _Sha1Defs._BlockArr),
+               ]
+
+               @classmethod
+               def deserialize(cls, data):
+                       return cls(
+                               _Sha1Defs._StateArr(*data["state"]),
+                               data["count"],
+                               data["index"],
+                               _Sha1Defs._BlockArr(*data["block"]),
+                       )
+
+               def serialize(self):
+                       return {
+                               "state": list(self.state),
+                               "count": self.count,
+                               "index": self.index,
+                               "block": list(self.block),
+                       }
+
+       @classmethod
+       def _new_context(cls):
+               ctx = cls.Context()
+               _LIB.sha1_init(ctypes.byref(ctx))
+               return ctx
+
+       def __init__(self, ctx_dict: Optional[dict]):
+               if ctx_dict:
+                       self.ctx = self.Context.deserialize(ctx_dict)
+               else:
+                       self.ctx = self._new_context()
+
+       def update(self, data):
+               _LIB.sha1_update(ctypes.byref(self.ctx), len(data), data)
+
+       def digest(self):
+               """Return the current digest and reset the hasher state."""
+               out = (ctypes.c_uint8 * self._DIGEST_SIZE)()
+               _LIB.sha1_digest(ctypes.byref(self.ctx), self._DIGEST_SIZE, out)
+               return bytes(out)
diff --git a/src/disk_jumble/verify.py b/src/disk_jumble/verify.py
new file mode 100644 (file)
index 0000000..797962a
--- /dev/null
@@ -0,0 +1,193 @@
+from dataclasses import dataclass
+from typing import Optional
+import argparse
+import contextlib
+import datetime as dt
+import itertools
+import math
+
+import psycopg2
+
+from disk_jumble import bencode
+from disk_jumble.db import HasherRef, Slab, Wrapper as DbWrapper
+from disk_jumble.nettle import Sha1Hasher
+
+
+_READ_BUFFER_SIZE = 16 * 1024 ** 2  # in bytes
+
+
+@dataclass
+class _SlabChunk:
+       """A slice of a slab, comprising all or part of a piece to be hashed."""
+       slab: Slab
+       slice: slice
+
+
+@dataclass
+class _PieceTask:
+       """The chunks needed to hash an entity piece as fully as possible."""
+       entity_id: bytes
+       piece_num: int
+       hasher_ref: Optional[HasherRef]
+       chunks: list[_SlabChunk]
+       complete: bool  # do these chunks complete the piece?
+
+
+if __name__ == "__main__":
+       parser = argparse.ArgumentParser()
+       parser.add_argument("disk_id", type = int)
+       args = parser.parse_args()
+
+       with contextlib.closing(psycopg2.connect("")) as conn:
+               db = DbWrapper(conn)
+               disk = db.get_disk(args.disk_id)
+
+               info_dicts = {
+                       info_hash: bencode.decode(info)
+                       for (info_hash, info) in db.get_torrent_info(args.disk_id)
+               }
+
+               tasks = []
+               slabs_and_hashers = db.get_slabs_and_hashers(args.disk_id)
+               for (entity_id, group) in itertools.groupby(slabs_and_hashers, lambda t: t[0].entity_id):
+                       info = info_dicts[entity_id]
+                       piece_len = info[b"piece length"]
+                       assert piece_len % disk.sector_size == 0
+                       if b"length" in info:
+                               torrent_len = info[b"length"]
+                       else:
+                               torrent_len = sum(d[b"length"] for d in info[b"files"])
+
+                       offset = None
+                       use_hasher = None
+                       chunks = []
+                       for (slab, hasher_ref) in group:
+                               slab_end = slab.entity_offset + len(slab.sectors) * disk.sector_size
+
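+                               # a gap in this entity's on-disk coverage ends the current run: flush any partial piece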
+                               if offset is not None and slab.entity_offset > offset:
+                                       if chunks:
+                                               tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, False))
+                                       offset = None
+                                       use_hasher = None
+                                       chunks = []
+
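+                               # not mid-piece: resume from a saved hasher if one applies, else start at the next piece boundary within this slab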
+                               if offset is None:
+                                       aligned = math.ceil(slab.entity_offset / piece_len) * piece_len
+                                       if hasher_ref and hasher_ref.entity_offset < aligned:
+                                               assert hasher_ref.entity_offset < torrent_len
+                                               use_hasher = hasher_ref
+                                               offset = hasher_ref.entity_offset
+                                       elif aligned < min(slab_end, torrent_len):
+                                               offset = aligned
+
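+                               # hash as much of the current piece as this slab covers; reaching the piece boundary completes a task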
+                               if offset is not None:
+                                       piece_end = min(offset + piece_len - offset % piece_len, torrent_len)
+                                       chunk_end = min(piece_end, slab_end)
+                                       chunks.append(_SlabChunk(slab, slice(offset - slab.entity_offset, chunk_end - slab.entity_offset)))
+                                       if chunk_end == piece_end:
+                                               tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, True))
+                                               offset = None
+                                               use_hasher = None
+                                               chunks = []
+                                       else:
+                                               offset = chunk_end
+
+                       if chunks:
+                               tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, False))
+
+               @dataclass
+               class NewVerifyPiece:
+                       entity_id: bytes
+                       piece_num: int
+                       sector_ranges: list[range]
+                       hasher_state: dict
+                       failed: bool
+
+               @dataclass
+               class VerifyUpdate:
+                       seq_start: int
+                       new_sector_ranges: list[range]
+                       hasher_state: dict
+
+               passed_verifies = set()
+               failed_verifies = set()
+               new_pass_ranges = []
+               vp_updates = {}
+               new_vps = []
+
+               run_ts = dt.datetime.now(dt.timezone.utc)
+               with open(f"/dev/mapper/diskjumble-{args.disk_id}", "rb", buffering = _READ_BUFFER_SIZE) as dev:
+                       for task in tasks:
+                               hasher = Sha1Hasher(task.hasher_ref.state if task.hasher_ref else None)
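+                               # translate each chunk's byte slice within its slab into a range of absolute disk sectors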
+                               sector_ranges = [
+                                       range(
+                                               chunk.slab.sectors.start + chunk.slice.start // disk.sector_size,
+                                               chunk.slab.sectors.start + chunk.slice.stop // disk.sector_size
+                                       )
+                                       for chunk in task.chunks
+                               ]
+
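+                               # stream each chunk from the device in buffer-sized reads and feed it to the hasher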
+                               for chunk in task.chunks:
+                                       slab_off = chunk.slab.sectors.start * disk.sector_size
+                                       dev.seek(slab_off + chunk.slice.start)
+                                       end_pos = slab_off + chunk.slice.stop
+                                       while dev.tell() < end_pos:
+                                               data = dev.read(min(end_pos - dev.tell(), _READ_BUFFER_SIZE))
+                                               assert data
+                                               hasher.update(data)
+
+                               hasher_state = hasher.ctx.serialize()
+                               if task.complete:
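+                                       # the expected hash is the piece's 20-byte slice of the torrent's "pieces" string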
+                                       s = slice(task.piece_num * 20, task.piece_num * 20 + 20)
+                                       expected_hash = info_dicts[task.entity_id][b"pieces"][s]
+                                       if hasher.digest() == expected_hash:
+                                               write_piece_data = False
+                                               new_pass_ranges.extend(sector_ranges)
+                                               if task.hasher_ref:
+                                                       passed_verifies.add(task.hasher_ref.id)
+                                       else:
+                                               failed = True
+                                               write_piece_data = True
+                                               if task.hasher_ref:
+                                                       failed_verifies.add(task.hasher_ref.id)
+                               else:
+                                       failed = False
+                                       write_piece_data = True
+
+                               if write_piece_data:
+                                       if task.hasher_ref:
+                                               assert task.hasher_ref.id not in vp_updates
+                                               vp_updates[task.hasher_ref.id] = VerifyUpdate(task.hasher_ref.seq + 1, sector_ranges, hasher_state)
+                                       else:
+                                               new_vps.append(NewVerifyPiece(task.entity_id, task.piece_num, sector_ranges, hasher_state, failed))
+
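+               # coalesce adjacent sector ranges so each verify_pass row covers a maximal contiguous run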
+               new_pass_ranges.sort(key = lambda r: r.start)
+               merged_ranges = []
+               for r in new_pass_ranges:
+                       if merged_ranges and r.start == merged_ranges[-1].stop:
+                               merged_ranges[-1] = range(merged_ranges[-1].start, r.stop)
+                       else:
+                               merged_ranges.append(r)
+
+               for vp in new_vps:
+                       verify_id = db.insert_verify_piece(run_ts, vp.entity_id, vp.piece_num)
+                       db.insert_verify_piece_content(verify_id, 0, args.disk_id, vp.sector_ranges)
+                       if vp.failed:
+                               db.mark_verify_piece_failed(verify_id)
+                       else:
+                               db.upsert_hasher_state(verify_id, vp.hasher_state)
+
+               for (verify_id, update) in vp_updates.items():
+                       db.insert_verify_piece_content(verify_id, update.seq_start, args.disk_id, update.new_sector_ranges)
+                       db.upsert_hasher_state(verify_id, update.hasher_state)
+
+               for verify_id in passed_verifies:
+                       db.move_piece_content_for_pass(verify_id)
+                       db.delete_verify_piece(verify_id)
+
+               for r in merged_ranges:
+                       db.insert_pass_data(run_ts, args.disk_id, r)
+
+               for verify_id in failed_verifies:
+                       db.clear_incomplete(verify_id)
+                       db.mark_verify_piece_failed(verify_id)