From: Jakob Cornell Date: Thu, 4 Nov 2021 15:41:28 +0000 (-0500) Subject: Add verify program (untested) X-Git-Url: https://jcornell.net/gitweb/gitweb.cgi?a=commitdiff_plain;h=96cb9bdf80216f94456b2e0b4985b6c5e74db123;p=eros.git Add verify program (untested) --- diff --git a/src/disk_jumble/db.py b/src/disk_jumble/db.py index cb8c7fd..9e88108 100644 --- a/src/disk_jumble/db.py +++ b/src/disk_jumble/db.py @@ -1,6 +1,32 @@ -from typing import Iterable, Optional +from dataclasses import dataclass +from typing import Iterable, Mapping, Optional +import datetime as dt +import itertools -from psycopg2.extras import execute_batch +from psycopg2.extras import execute_batch, Json, NumericRange + + +@dataclass +class Disk: + sector_size: int + + +@dataclass +class Slab: + id: int + disk_id: int + sectors: range + entity_id: bytes + entity_offset: int + crypt_key: bytes + + +@dataclass +class HasherRef: + id: int + seq: int + entity_offset: int + state: dict class Wrapper: @@ -12,17 +38,18 @@ class Wrapper: with self.conn.cursor() as cursor: cursor.execute("select passkey from gazelle.passkey where gazelle_tracker_id = %s;", (tracker_id,)) [(passkey,)] = cursor.fetchall() + return passkey def get_torrents(self, tracker_id: int, batch_size: Optional[int] = None) -> Iterable[bytes]: """Iterate the info hashes for the specified tracker which haven't been marked deleted.""" - stmt = ( - "select infohash from gazelle.torrent" - + " where gazelle_tracker_id = %s and not is_deleted" - + " order by infohash asc" - + ";" - ) + stmt = """ + select infohash from gazelle.torrent + where gazelle_tracker_id = %s and not is_deleted + order by infohash asc + ; + """ with self.conn.cursor() as cursor: if batch_size is not None: cursor.itersize = batch_size @@ -34,14 +61,164 @@ class Wrapper: yield bytes(info_hash) def insert_swarm_info(self, tracker_id: int, infos: Iterable["disk_jumble.scrape.ScrapeInfo"]) -> None: - stmt = ( - "insert into gazelle.tracker_stat (gazelle_tracker_id, infohash, ts, complete, incomplete, downloaded)" - + " values (%s, %s, %s, %s, %s, %s)" - + ";" - ) + stmt = """ + insert into gazelle.tracker_stat (gazelle_tracker_id, infohash, ts, complete, incomplete, downloaded) + values (%s, %s, %s, %s, %s, %s) + ; + """ with self.conn.cursor() as cursor: param_sets = [ (tracker_id, i.info_hash, i.timestamp, i.complete, i.incomplete, i.downloaded) for i in infos ] execute_batch(cursor, stmt, param_sets) + + def get_disk(self, id_: int) -> Disk: + with self.conn.cursor() as cursor: + cursor.execute("select sector_size from diskjumble.disk where disk_id = %s;", (id_,)) + [(sector_size,)] = cursor.fetchall() + + return Disk(sector_size) + + def get_slabs_and_hashers(self, disk_id: int) -> Iterable[tuple[Slab, Optional[HasherRef]]]: + """ + Find slabs on the specified disk, and for each also return any lowest-offset saved hasher that left off directly + before or within the slab's entity data. 
+ """ + + stmt = """ + with + incomplete_edge as ( + -- join up incomplete piece info and precompute where the hasher left off within the entity + select + verify_id, seq, entity_id, hasher_state, + entity_offset + (upper(c.disk_sectors) - lower(slab.disk_sectors)) * sector_size as end_off + from + diskjumble.verify_piece_incomplete + natural left join diskjumble.verify_piece p + natural join diskjumble.verify_piece_content c + natural left join diskjumble.disk + left join diskjumble.slab on ( + c.disk_id = slab.disk_id + and upper(c.disk_sectors) <@ slab.disk_sectors + ) + where seq >= all (select seq from diskjumble.verify_piece_content where verify_id = p.verify_id) + ) + select + slab_id, disk_id, disk_sectors, entity_id, entity_offset, crypt_key, verify_id, seq, end_off, + hasher_state + from + diskjumble.slab + natural left join diskjumble.disk + left join incomplete_edge on + incomplete_edge.entity_id = slab.entity_id + and incomplete_edge.end_off % sector_size == 0 + and incomplete_edge.end_off <@ int8range( + slab.entity_offset, + slab.entity_offset + (upper(disk_sectors) - lower(disk_sectors)) * sector_size + ) + where disk_id = %s + order by entity_id, entity_offset, slab_id + ; + """ + with self.conn.cursor() as cursor: + cursor.execute(stmt, (disk_id,)) + for (_, rows_iter) in itertools.groupby(cursor, lambda r: r[0]): + rows = list(rows_iter) + [(slab_id, disk_id, sectors_pg, entity_id, entity_off, key)] = {r[:6] for r in rows} + sectors = range(sectors_pg.lower, sectors_pg.upper) + slab = Slab(slab_id, disk_id, sectors, entity_id, entity_off, key) + + # `None' if no hasher match in outer join, otherwise earliest match + (*_, id_, seq, end_off, state) = min(rows, key = lambda r: r[-2]) + hasher_ref = None if id_ is None else HasherRef(id_, seq, end_off, state) + + yield (slab, hasher_ref) + + def get_torrent_info(self, disk_id: int) -> Mapping[bytes, bytes]: + stmt = """ + with hashed as ( + select digest(info, 'sha1') as info_hash, info + from bittorrent.torrent_info + ) + select + distinct on (info_hash) + info_hash, info + from diskjumble.slab left outer join hashed on entity_id = info_hash + where disk_id = %s + ; + """ + with self.conn.cursor() as cursor: + cursor.execute(stmt, (disk_id,)) + for (info_hash, info) in cursor: + yield (info_hash, info) + + def insert_verify_piece(self, ts: dt.datetime, entity_id: bytes, piece_num: int) -> int: + """Insert new verify piece, returning the ID of the inserted row.""" + + with self.conn.cursor() as cursor: + stmt = "insert into diskjumble.verify_piece values (default, %s, %s, %s) returning verify_id;" + cursor.execute(stmt, (ts, entity_id, piece_num)) + [(row_id,)] = cursor.fetchall() + return row_id + + def insert_verify_piece_content(self, verify_id: int, seq_start: int, disk_id: int, ranges: Iterable[range]) -> None: + with self.conn.cursor() as cursor: + execute_batch( + cursor, + "insert into diskjumble.verify_piece_content values (%s, %s, %s, %s);", + [ + (verify_id, seq, disk_id, NumericRange(r.start, r.stop)) + for (seq, r) in enumerate(ranges, start = seq_start) + ] + ) + + def mark_verify_piece_failed(self, verify_id: int) -> None: + with self.conn.cursor() as cursor: + cursor.execute("insert into diskjumble.verify_piece_fail values (%s);", (verify_id,)) + + def upsert_hasher_state(self, verify_id: int, state: dict) -> None: + stmt = """ + insert into diskjumble.verify_piece_incomplete values (%s, %s) + on conflict (verify_id) do update set hasher_state = excluded.hasher_state + ; + """ + + with self.conn.cursor() as 
cursor:
+ cursor.execute(stmt, (verify_id, Json(state)))
+
+ def delete_verify_piece(self, verify_id: int) -> None:
+ with self.conn.cursor() as cursor:
+ cursor.execute("delete from diskjumble.verify_piece_incomplete where verify_id = %s;", (verify_id,))
+ cursor.execute("delete from diskjumble.verify_piece_content where verify_id = %s;", (verify_id,))
+ cursor.execute("delete from diskjumble.verify_piece where verify_id = %s;", (verify_id,))
+
+ def move_piece_content_for_pass(self, verify_id: int) -> None:
+ stmt = """
+ with content_out as (
+ delete from diskjumble.verify_piece_content c
+ using diskjumble.verify_piece p
+ where (
+ c.verify_id = p.verify_id
+ and p.verify_id = %s
+ )
+ returning at, disk_id, disk_sectors
+ )
+ insert into diskjumble.verify_pass (at, disk_id, disk_sectors)
+ select at, disk_id, disk_sectors from content_out
+ ;
+ """
+
+ with self.conn.cursor() as cursor:
+ cursor.execute(stmt, (verify_id,))
+
+ def insert_pass_data(self, ts: dt.datetime, disk_id: int, sectors: range) -> None:
+ with self.conn.cursor() as cursor:
+ cursor.execute(
+ "insert into diskjumble.verify_pass values (default, %s, %s, %s);",
+ (ts, disk_id, NumericRange(sectors.start, sectors.stop))
+ )
+
+ def clear_incomplete(self, verify_id: int) -> None:
+ with self.conn.cursor() as cursor:
+ cursor.execute("delete from diskjumble.verify_piece_incomplete where verify_id = %s;", (verify_id,))
diff --git a/src/disk_jumble/nettle.py b/src/disk_jumble/nettle.py
new file mode 100644
index 0000000..cd73b1a
--- /dev/null
+++ b/src/disk_jumble/nettle.py
@@ -0,0 +1,65 @@
+"""Python wrappers for some of GNU Nettle."""
+
+from typing import Optional
+import ctypes
+import ctypes.util
+
+
+# The shared library exports `nettle_'-prefixed symbols; the short C names are header macros.
+_LIB = ctypes.CDLL(ctypes.util.find_library("nettle"))
+
+
+class _Sha1Defs:
+ _DIGEST_SIZE = 20 # in bytes
+ _BLOCK_SIZE = 64 # in bytes
+ _DIGEST_LENGTH = 5
+
+ _StateArr = ctypes.c_uint32 * _DIGEST_LENGTH
+
+ _BlockArr = ctypes.c_uint8 * _BLOCK_SIZE
+
+
+class Sha1Hasher(_Sha1Defs):
+ class Context(ctypes.Structure):
+ _fields_ = [
+ ("state", _Sha1Defs._StateArr),
+ ("count", ctypes.c_uint64),
+ ("index", ctypes.c_uint),
+ ("block", _Sha1Defs._BlockArr),
+ ]
+
+ @classmethod
+ def deserialize(cls, data):
+ return cls(
+ _Sha1Defs._StateArr(*data["state"]),
+ data["count"],
+ data["index"],
+ _Sha1Defs._BlockArr(*data["block"]),
+ )
+
+ def serialize(self):
+ return {
+ "state": list(self.state),
+ "count": self.count,
+ "index": self.index,
+ "block": list(self.block),
+ }
+
+ @classmethod
+ def _new_context(cls):
+ ctx = cls.Context()
+ _LIB.nettle_sha1_init(ctypes.byref(ctx))
+ return ctx
+
+ def __init__(self, ctx_dict: Optional[dict]):
+ if ctx_dict:
+ self.ctx = self.Context.deserialize(ctx_dict)
+ else:
+ self.ctx = self._new_context()
+
+ def update(self, data):
+ _LIB.nettle_sha1_update(ctypes.byref(self.ctx), len(data), data)
+
+ def digest(self):
+ """Return the current digest and reset the hasher state."""
+ out = (ctypes.c_uint8 * self._DIGEST_SIZE)()
+ _LIB.nettle_sha1_digest(ctypes.byref(self.ctx), self._DIGEST_SIZE, out)
+ return bytes(out)
diff --git a/src/disk_jumble/verify.py b/src/disk_jumble/verify.py
new file mode 100644
index 0000000..797962a
--- /dev/null
+++ b/src/disk_jumble/verify.py
@@ -0,0 +1,193 @@
+from dataclasses import dataclass
+from typing import Optional
+import argparse
+import contextlib
+import datetime as dt
+import itertools
+import math
+
+import psycopg2
+
+from disk_jumble import bencode
+from disk_jumble.db import HasherRef, Slab, Wrapper as DbWrapper
+from disk_jumble.nettle import Sha1Hasher
+
+
+_READ_BUFFER_SIZE = 16 * 
1024 ** 2 # in bytes + + +@dataclass +class _SlabChunk: + """A slice of a slab; comprising all or part of a piece to be hashed.""" + slab: Slab + slice: slice + + +@dataclass +class _PieceTask: + """The chunks needed to hash as fully as possible an entity piece.""" + entity_id: bytes + piece_num: int + hasher_ref: Optional[HasherRef] + chunks: list[_SlabChunk] + complete: bool # do these chunks complete the piece? + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("disk_id", type = int) + args = parser.parse_args() + + with contextlib.closing(psycopg2.connect("")) as conn: + db = DbWrapper(conn) + disk = db.get_disk(args.disk_id) + + info_dicts = { + info_hash: bencode.decode(info) + for (info_hash, info) in db.get_torrent_info(args.disk_id) + } + + tasks = [] + slabs_and_hashers = db.get_slabs_and_hashers(args.disk_id) + for (entity_id, group) in itertools.groupby(slabs_and_hashers, lambda t: t[0].entity_id): + info = info_dicts[entity_id] + piece_len = info[b"piece length"] + assert piece_len % disk.sector_size == 0 + if b"length" in info: + torrent_len = info[b"length"] + else: + torrent_len = sum(d[b"length"] for d in info[b"files"]) + + offset = None + use_hasher = None + chunks = [] + for (slab, hasher_ref) in group: + slab_end = slab.entity_offset + len(slab.sectors) * disk.sector_size + + if offset is not None and slab.entity_offset > offset: + if chunks: + tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, False)) + offset = None + use_hasher = None + chunks = [] + + if offset is None: + aligned = math.ceil(slab.entity_offset / piece_len) * piece_len + if hasher_ref and hasher_ref.entity_offset < aligned: + assert hasher_ref.entity_offset < torrent_len + use_hasher = hasher_ref + offset = hasher_ref.entity_offset + elif aligned < min(slab_end, torrent_len): + offset = aligned + + if offset is not None: + piece_end = min(offset + piece_len - offset % piece_len, torrent_len) + chunk_end = min(piece_end, slab_end) + chunks.append(_SlabChunk(slab, slice(offset - slab.entity_offset, chunk_end - slab.entity_offset))) + if chunk_end == piece_end: + tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, True)) + offset = None + use_hasher = None + chunks = [] + else: + offset = chunk_end + + if chunks: + tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, False)) + + @dataclass + class NewVerifyPiece: + entity_id: bytes + piece_num: int + sector_ranges: list[range] + hasher_state: dict + failed: bool + + @dataclass + class VerifyUpdate: + seq_start: int + new_sector_ranges: list[range] + hasher_state: dict + + passed_verifies = set() + failed_verifies = set() + new_pass_ranges = [] + vp_updates = {} + new_vps = [] + + run_ts = dt.datetime.now(dt.timezone.utc) + with open(f"/dev/mapper/diskjumble-{args.disk_id}", "rb", buffering = _READ_BUFFER_SIZE) as dev: + for task in tasks: + hasher = Sha1Hasher(task.hasher_ref.state if task.hasher_ref else None) + sector_ranges = [ + range( + chunk.slab.sectors.start + chunk.slice.start // disk.sector_size, + chunk.slab.sectors.start + chunk.slice.stop // disk.sector_size + ) + for chunk in task.chunks + ] + + for chunk in task.chunks: + slab_off = chunk.slab.sectors.start * disk.sector_size + dev.seek(slab_off + chunk.slice.start) + end_pos = slab_off + chunk.slice.stop + while dev.tell() < end_pos: + data = dev.read(min(end_pos - dev.tell(), _READ_BUFFER_SIZE)) + assert data + hasher.update(data) + + hasher_state = hasher.ctx.serialize() + 
if task.complete:
+ s = slice(task.piece_num * 20, task.piece_num * 20 + 20)
+ expected_hash = info_dicts[task.entity_id][b"pieces"][s]
+ if hasher.digest() == expected_hash:
+ write_piece_data = False
+ new_pass_ranges.extend(sector_ranges)
+ if task.hasher_ref:
+ passed_verifies.add(task.hasher_ref.id)
+ else:
+ failed = True
+ write_piece_data = True
+ if task.hasher_ref:
+ failed_verifies.add(task.hasher_ref.id)
+ else:
+ failed = False
+ write_piece_data = True
+
+ if write_piece_data:
+ if task.hasher_ref:
+ assert task.hasher_ref.id not in vp_updates
+ vp_updates[task.hasher_ref.id] = VerifyUpdate(task.hasher_ref.seq + 1, sector_ranges, hasher_state)
+ else:
+ new_vps.append(NewVerifyPiece(task.entity_id, task.piece_num, sector_ranges, hasher_state, failed))
+
+ new_pass_ranges.sort(key = lambda r: r.start)
+ merged_ranges = []
+ for r in new_pass_ranges:
+ if merged_ranges and r.start == merged_ranges[-1].stop:
+ merged_ranges[-1] = range(merged_ranges[-1].start, r.stop)
+ else:
+ merged_ranges.append(r)
+
+ for vp in new_vps:
+ verify_id = db.insert_verify_piece(run_ts, vp.entity_id, vp.piece_num)
+ db.insert_verify_piece_content(verify_id, 0, args.disk_id, vp.sector_ranges)
+ if vp.failed:
+ db.mark_verify_piece_failed(verify_id)
+ else:
+ db.upsert_hasher_state(verify_id, vp.hasher_state)
+
+ for (verify_id, update) in vp_updates.items():
+ db.insert_verify_piece_content(verify_id, update.seq_start, args.disk_id, update.new_sector_ranges)
+ db.upsert_hasher_state(verify_id, update.hasher_state)
+
+ for verify_id in passed_verifies:
+ db.move_piece_content_for_pass(verify_id)
+ db.delete_verify_piece(verify_id)
+
+ for r in merged_ranges:
+ db.insert_pass_data(run_ts, args.disk_id, r)
+
+ for verify_id in failed_verifies:
+ db.clear_incomplete(verify_id)
+ db.mark_verify_piece_failed(verify_id)
+
+ # psycopg2 opens an implicit transaction; commit so the verify results are persisted
+ conn.commit()
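
Not part of the commit: a minimal sketch of the resumable-hash round trip that verify.py relies on, assuming libnettle is installed and the Sha1Hasher wrapper from src/disk_jumble/nettle.py above. The buffer contents, split point, and variable names are illustrative; hashlib is used only as an independent reference.

    # Illustrative only -- not from the patch.
    import hashlib

    from disk_jumble.nettle import Sha1Hasher

    data = b"\x00" * (1 << 20)  # arbitrary stand-in for one piece worth of sector data

    # First run: hash part of the piece, then save the context as a JSON-serializable
    # dict, the same shape db.upsert_hasher_state() stores.
    first = Sha1Hasher(None)
    first.update(data[: len(data) // 2])
    saved_state = first.ctx.serialize()

    # Later run: restore the hasher from the saved context and feed the remaining bytes.
    resumed = Sha1Hasher(saved_state)
    resumed.update(data[len(data) // 2 :])

    assert resumed.digest() == hashlib.sha1(data).digest()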