Refactor verify program and add basic tests
authorJakob Cornell <jakob+gpg@jcornell.net>
Fri, 12 Nov 2021 01:19:45 +0000 (19:19 -0600)
committerJakob Cornell <jakob+gpg@jcornell.net>
Fri, 12 Nov 2021 01:19:45 +0000 (19:19 -0600)
src/disk_jumble/bencode.py
src/disk_jumble/db.py
src/disk_jumble/scrape.py
src/disk_jumble/tests/__init__.py [new file with mode: 0644]
src/disk_jumble/tests/test_verify.py [new file with mode: 0644]
src/disk_jumble/tests/verify_setup.sql [new file with mode: 0644]
src/disk_jumble/verify.py
test_util/dump_db.py [new file with mode: 0644]

index 2c9667e6641e64176c81a351667ca3816bfa958a..7883b53363548c36e8dc488dcd35523c876414f8 100644 (file)
@@ -1,15 +1,17 @@
-from typing import Dict, List, Union
+from __future__ import annotations
+from typing import Union
+import itertools
 
 
-Bdict = Dict[bytes, 'Type']
-Type = Union[bytes, int, List['Type'], Bdict]
+Bdict = dict[bytes, 'Type']
+Type = Union[bytes, int, list['Type'], Bdict]
 
 
-class DecodeError(Exception):
+class CodecError(Exception):
        pass
 
 
-def _pop_bytes(vals: List[bytes]) -> bytes:
+def _pop_bytes(vals: list[bytes]) -> bytes:
        len_parts = []
        while vals and vals[0].isdigit():
                len_parts.append(vals.pop(0))
@@ -17,40 +19,40 @@ def _pop_bytes(vals: List[bytes]) -> bytes:
        try:
                length = int(b"".join(len_parts).decode("ascii"))
        except ValueError:
-               raise DecodeError()
+               raise CodecError()
 
        try:
                if vals.pop(0) != b":":
-                       raise DecodeError()
+                       raise CodecError()
        except IndexError:
-               raise DecodeError()
+               raise CodecError()
 
        if length > len(vals):
-               raise DecodeError()
+               raise CodecError()
 
        out = b"".join(vals[:length])
        del vals[:length]
        return out
 
 
-def _pop_int(vals: List[bytes]) -> int:
+def _pop_int(vals: list[bytes]) -> int:
        assert vals.pop(0) == b"i"
 
        try:
                end = vals.index(b"e")
        except ValueError:
-               raise DecodeError()
+               raise CodecError()
 
        try:
                out = int(b"".join(vals[:end]).decode("ascii"))
        except ValueError:
-               raise DecodeError()
+               raise CodecError()
 
        del vals[slice(end + 1)]
        return out
 
 
-def _pop_list(vals: List[bytes]) -> List[Type]:
+def _pop_list(vals: list[bytes]) -> list[Type]:
        assert vals.pop(0) == b"l"
 
        out = []
@@ -61,10 +63,10 @@ def _pop_list(vals: List[bytes]) -> List[Type]:
                del vals[0]
                return out
        else:
-               raise DecodeError()
+               raise CodecError()
 
 
-def _pop_dict(vals: List[bytes]) -> Bdict:
+def _pop_dict(vals: list[bytes]) -> Bdict:
        assert vals.pop(0) == b"d"
 
        out = {}
@@ -76,10 +78,10 @@ def _pop_dict(vals: List[bytes]) -> Bdict:
                del vals[0]
                return out
        else:
-               raise DecodeError()
+               raise CodecError()
 
 
-def _pop_value(vals: List[bytes]) -> Type:
+def _pop_value(vals: list[bytes]) -> Type:
        if vals:
                if vals[0].isdigit():
                        return _pop_bytes(vals)
@@ -90,15 +92,33 @@ def _pop_value(vals: List[bytes]) -> Type:
                elif vals[0] == b"d":
                        return _pop_dict(vals)
                else:
-                       raise DecodeError()
+                       raise CodecError()
        else:
-               raise DecodeError()
+               raise CodecError()
 
 
 def decode(data: bytes) -> Type:
        vals = [bytes([v]) for v in data]
        out = _pop_value(vals)
        if vals:
-               raise DecodeError()
+               raise CodecError()
        else:
                return out
+
+
+def _encode_helper(data: Type) -> list[bytes]:
+       if isinstance(data, bytes):
+               return [str(len(data)).encode("ascii"), b":", data]
+       elif isinstance(data, int):
+               return [b"i", str(data).encode("ascii"), b"e"]
+       elif isinstance(data, list):
+               return [b"l", *itertools.chain.from_iterable(map(_encode_helper, data)), b"e"]
+       elif isinstance(data, dict):
+               contents = itertools.chain.from_iterable(data.items())
+               return [b"d", *itertools.chain.from_iterable(map(_encode_helper, contents)), b"e"]
+       else:
+               raise CodecError()
+
+
+def encode(data: Type) -> bytes:
+       return b"".join(_encode_helper(data))
index b9461deb63b21151e496fc857a13c03dd28ca216..8bb49e35ee9244f3c0810d57ef5896dae89273a0 100644 (file)
@@ -30,10 +30,9 @@ class HasherRef:
        state: dict
 
 
+@dataclass
 class Wrapper:
-       def __init__(self, conn) -> None:
-               conn.autocommit = True
-               self.conn = conn
+       conn: Any
 
        def get_passkey(self, tracker_id: int) -> str:
                with self.conn.cursor() as cursor:
@@ -185,7 +184,6 @@ class Wrapper:
                        on conflict (verify_id) do update set hasher_state = excluded.hasher_state
                        ;
                """
-
                with self.conn.cursor() as cursor:
                        cursor.execute(stmt, (verify_id, Json(state)))
 
@@ -210,7 +208,6 @@ class Wrapper:
                        select at, disk_id, disk_sectors from content_out
                        ;
                """
-
                with self.conn.cursor() as cursor:
                        cursor.execute(stmt, (verify_id,))
 
index 61cb3dddcba9c84e77bed387b49f36f5322ba43b..4143b8a6a74fa1f8852923dcfa4a8c4f1df30c93 100644 (file)
@@ -119,6 +119,7 @@ if __name__ == "__main__":
 
        params = {n: getattr(args, n) for n in PSQL_PARAMS if getattr(args, n)}
        with contextlib.closing(psycopg2.connect(**params)) as conn:
+               conn.autocommit = True
                db_wrapper = DbWrapper(conn)
                passkey = db_wrapper.get_passkey(tracker.gazelle_id)
                info_hashes = iter(db_wrapper.get_torrents(tracker.gazelle_id, args.batch_size))
diff --git a/src/disk_jumble/tests/__init__.py b/src/disk_jumble/tests/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/src/disk_jumble/tests/test_verify.py b/src/disk_jumble/tests/test_verify.py
new file mode 100644 (file)
index 0000000..facd67e
--- /dev/null
@@ -0,0 +1,150 @@
+"""
+Tests for the verification program `disk_jumble.verify'
+
+Like the verification program itself, these tests take database connection information from the environment.  The
+necessary schemas and tables are set up from scratch by the test code, so environment variables should point to a
+database that's not hosting a live instance of Disk Jumble.  Ideally, this is a completely empty local database created
+for the purposes of testing, but other options are possible.
+
+The tests need access to an SQL source file containing the definitions for the required tables and other Postgres
+objects; see `test_util/dump_db.py'.
+"""
+
+from dataclasses import dataclass
+from importlib import resources
+from random import Random
+import hashlib
+import io
+import tempfile
+import unittest
+import uuid
+
+from psycopg2.extras import NumericRange
+import psycopg2
+import psycopg2.extras
+
+from disk_jumble import bencode
+from disk_jumble.verify import do_verify
+
+
+_BUF_SIZE = 16 * 1024 ** 2  # in bytes
+
+
+def _random_file(size: int) -> tempfile.NamedTemporaryFile:
+       r = Random(0)
+       f = tempfile.NamedTemporaryFile(buffering = _BUF_SIZE)
+       try:
+               while f.tell() < size:
+                       write_size = min(size - f.tell(), _BUF_SIZE)
+                       f.write(bytes(r.getrandbits(8) for _ in range(write_size)))
+               f.seek(0)
+               return f
+       except Exception:
+               f.close()
+               raise
+
+
+class Tests(unittest.TestCase):
+       _SCHEMAS = {"public", "diskjumble", "bittorrent"}
+       _BUF_SIZE = 16 * 1024 ** 2  # in bytes
+
+       def _basic_fresh_verify_helper(self, read_size):
+               sector_size = 32
+               piece_size = 64
+
+               torrent_len = 3 * piece_size
+               disk = self._write_disk(sector_size, torrent_len // sector_size)
+               with _random_file(torrent_len) as torrent_file:
+                       torrent = _Torrent(torrent_file, piece_size)
+                       torrent_file.seek(0)
+                       self._write_torrent(torrent)
+                       with self._conn.cursor() as cursor:
+                               cursor.execute(
+                                       "insert into diskjumble.slab values (default, %s, %s, %s, 0, null);",
+                                       (disk.id, NumericRange(0, disk.sector_count), torrent.info_hash)
+                               )
+
+                               do_verify(self._conn, disk.id, torrent_file, read_size)
+
+                               cursor.execute("select * from diskjumble.verify_pass;")
+                               self.assertEqual(cursor.rowcount, 1)
+                               (_, _, disk_id, sectors) = cursor.fetchone()
+                               self.assertEqual(disk_id, disk.id)
+                               self.assertEqual(sectors, NumericRange(0, torrent_len // sector_size))
+
+       def test_basic_fresh_verify_small_read_size(self):
+               self._basic_fresh_verify_helper(16)
+
+       def test_basic_fresh_verify_large_read_size(self):
+               self._basic_fresh_verify_helper(128)
+
+       def _write_torrent(self, torrent: "_Torrent") -> None:
+               with self._conn.cursor() as cursor:
+                       cursor.execute("insert into bittorrent.torrent_info values (%s);", (torrent.info,))
+
+       def _write_disk(self, sector_size: int, sector_count: int) -> "_Disk":
+               with self._conn.cursor() as cursor:
+                       cursor.execute(
+                               "insert into diskjumble.disk values (default, %s, null, %s, %s, false) returning disk_id;",
+                               (uuid.uuid4(), sector_size, sector_count)
+                       )
+                       [(id_,)] = cursor.fetchall()
+               return _Disk(id_, sector_size, sector_count)
+
+       @classmethod
+       def setUpClass(cls):
+               psycopg2.extras.register_uuid()
+               prod_schema_sql = resources.files(__package__).joinpath("verify_setup.sql").read_text()
+               schema_sql = "\n".join(
+                       l for l in prod_schema_sql.splitlines()
+                       if (
+                               not l.startswith("SET")
+                               and not l.startswith("SELECT")  # ignore server settings
+                               and not "OWNER TO" in l  # and ownership changes
+                       )
+               )
+               cls._conn = psycopg2.connect("")
+               with cls._conn, cls._conn.cursor() as cursor:
+                       for name in cls._SCHEMAS:
+                               cursor.execute(f"create schema {name};")
+                       cursor.execute(schema_sql)
+
+       @classmethod
+       def tearDownClass(self):
+               try:
+                       with self._conn, self._conn.cursor() as cursor:
+                               for name in self._SCHEMAS:
+                                       cursor.execute(f"drop schema {name} cascade;")
+               finally:
+                       self._conn.close()
+
+       def tearDown(self):
+               self._conn.rollback()
+
+
+@dataclass
+class _Disk:
+       id: int
+       sector_size: int
+       sector_count: int
+
+
+class _Torrent:
+       def __init__(self, data: io.BufferedIOBase, piece_size: int) -> None:
+               data.seek(0)
+               hashes = []
+               while True:
+                       piece = data.read(piece_size)
+                       if piece:
+                               hashes.append(hashlib.sha1(piece).digest())
+                       else:
+                               break
+
+               info_dict = {
+                       b"piece length": piece_size,
+                       b"length": data.tell(),
+                       b"pieces": b"".join(hashes),
+               }
+               self.data = data
+               self.info = bencode.encode(info_dict)
+               self.info_hash = hashlib.sha1(self.info).digest()
diff --git a/src/disk_jumble/tests/verify_setup.sql b/src/disk_jumble/tests/verify_setup.sql
new file mode 100644 (file)
index 0000000..aa0c5fb
--- /dev/null
@@ -0,0 +1,383 @@
+create extension btree_gist;
+CREATE OR REPLACE FUNCTION public.digest(bytea, text)
+ RETURNS bytea
+ LANGUAGE c
+ IMMUTABLE PARALLEL SAFE STRICT
+AS '$libdir/pgcrypto', $function$pg_digest$function$
+;
+--
+-- PostgreSQL database dump
+--
+
+-- Dumped from database version 13.2 (Debian 13.2-1.pgdg100+1)
+-- Dumped by pg_dump version 13.4 (Debian 13.4-0+deb11u1)
+
+SET statement_timeout = 0;
+SET lock_timeout = 0;
+SET idle_in_transaction_session_timeout = 0;
+SET client_encoding = 'UTF8';
+SET standard_conforming_strings = on;
+SELECT pg_catalog.set_config('search_path', '', false);
+SET check_function_bodies = false;
+SET xmloption = content;
+SET client_min_messages = warning;
+SET row_security = off;
+
+SET default_tablespace = '';
+
+SET default_table_access_method = heap;
+
+--
+-- Name: torrent_info; Type: TABLE; Schema: bittorrent; Owner: eros
+--
+
+CREATE TABLE bittorrent.torrent_info (
+    info bytea NOT NULL
+);
+
+
+ALTER TABLE bittorrent.torrent_info OWNER TO eros;
+
+--
+-- Name: disk; Type: TABLE; Schema: diskjumble; Owner: eros
+--
+
+CREATE TABLE diskjumble.disk (
+    disk_id integer NOT NULL,
+    dev_uuid uuid NOT NULL,
+    dev_serial text,
+    sector_size integer NOT NULL,
+    num_sectors bigint NOT NULL,
+    failed boolean DEFAULT false NOT NULL
+);
+
+
+ALTER TABLE diskjumble.disk OWNER TO eros;
+
+--
+-- Name: disk_id_seq; Type: SEQUENCE; Schema: diskjumble; Owner: eros
+--
+
+CREATE SEQUENCE diskjumble.disk_id_seq
+    AS integer
+    START WITH 1
+    INCREMENT BY 1
+    NO MINVALUE
+    NO MAXVALUE
+    CACHE 1;
+
+
+ALTER TABLE diskjumble.disk_id_seq OWNER TO eros;
+
+--
+-- Name: disk_id_seq; Type: SEQUENCE OWNED BY; Schema: diskjumble; Owner: eros
+--
+
+ALTER SEQUENCE diskjumble.disk_id_seq OWNED BY diskjumble.disk.disk_id;
+
+
+--
+-- Name: slab; Type: TABLE; Schema: diskjumble; Owner: eros
+--
+
+CREATE TABLE diskjumble.slab (
+    slab_id integer NOT NULL,
+    disk_id integer NOT NULL,
+    disk_sectors int8range NOT NULL,
+    entity_id bytea NOT NULL,
+    entity_offset bigint NOT NULL,
+    crypt_key bytea
+);
+
+
+ALTER TABLE diskjumble.slab OWNER TO eros;
+
+--
+-- Name: slab_id_seq; Type: SEQUENCE; Schema: diskjumble; Owner: eros
+--
+
+CREATE SEQUENCE diskjumble.slab_id_seq
+    START WITH 1
+    INCREMENT BY 1
+    NO MINVALUE
+    NO MAXVALUE
+    CACHE 1;
+
+
+ALTER TABLE diskjumble.slab_id_seq OWNER TO eros;
+
+--
+-- Name: slab_id_seq; Type: SEQUENCE OWNED BY; Schema: diskjumble; Owner: eros
+--
+
+ALTER SEQUENCE diskjumble.slab_id_seq OWNED BY diskjumble.slab.slab_id;
+
+
+--
+-- Name: verify_pass; Type: TABLE; Schema: diskjumble; Owner: eros
+--
+
+CREATE TABLE diskjumble.verify_pass (
+    verify_pass_id integer NOT NULL,
+    at timestamp with time zone,
+    disk_id integer NOT NULL,
+    disk_sectors int8range NOT NULL
+);
+
+
+ALTER TABLE diskjumble.verify_pass OWNER TO eros;
+
+--
+-- Name: verify_pass_verify_pass_id_seq; Type: SEQUENCE; Schema: diskjumble; Owner: eros
+--
+
+CREATE SEQUENCE diskjumble.verify_pass_verify_pass_id_seq
+    AS integer
+    START WITH 1
+    INCREMENT BY 1
+    NO MINVALUE
+    NO MAXVALUE
+    CACHE 1;
+
+
+ALTER TABLE diskjumble.verify_pass_verify_pass_id_seq OWNER TO eros;
+
+--
+-- Name: verify_pass_verify_pass_id_seq; Type: SEQUENCE OWNED BY; Schema: diskjumble; Owner: eros
+--
+
+ALTER SEQUENCE diskjumble.verify_pass_verify_pass_id_seq OWNED BY diskjumble.verify_pass.verify_pass_id;
+
+
+--
+-- Name: verify_piece; Type: TABLE; Schema: diskjumble; Owner: eros
+--
+
+CREATE TABLE diskjumble.verify_piece (
+    verify_id integer NOT NULL,
+    at timestamp with time zone,
+    entity_id bytea NOT NULL,
+    piece integer NOT NULL
+);
+
+
+ALTER TABLE diskjumble.verify_piece OWNER TO eros;
+
+--
+-- Name: verify_piece_content; Type: TABLE; Schema: diskjumble; Owner: eros
+--
+
+CREATE TABLE diskjumble.verify_piece_content (
+    verify_id integer NOT NULL,
+    seq integer NOT NULL,
+    disk_id integer NOT NULL,
+    disk_sectors int8range NOT NULL
+);
+
+
+ALTER TABLE diskjumble.verify_piece_content OWNER TO eros;
+
+--
+-- Name: verify_piece_fail; Type: TABLE; Schema: diskjumble; Owner: eros
+--
+
+CREATE TABLE diskjumble.verify_piece_fail (
+    verify_id integer NOT NULL
+);
+
+
+ALTER TABLE diskjumble.verify_piece_fail OWNER TO eros;
+
+--
+-- Name: verify_piece_incomplete; Type: TABLE; Schema: diskjumble; Owner: eros
+--
+
+CREATE TABLE diskjumble.verify_piece_incomplete (
+    verify_id integer NOT NULL,
+    hasher_state json
+);
+
+
+ALTER TABLE diskjumble.verify_piece_incomplete OWNER TO eros;
+
+--
+-- Name: verify_piece_verify_id_seq; Type: SEQUENCE; Schema: diskjumble; Owner: eros
+--
+
+CREATE SEQUENCE diskjumble.verify_piece_verify_id_seq
+    AS integer
+    START WITH 1
+    INCREMENT BY 1
+    NO MINVALUE
+    NO MAXVALUE
+    CACHE 1;
+
+
+ALTER TABLE diskjumble.verify_piece_verify_id_seq OWNER TO eros;
+
+--
+-- Name: verify_piece_verify_id_seq; Type: SEQUENCE OWNED BY; Schema: diskjumble; Owner: eros
+--
+
+ALTER SEQUENCE diskjumble.verify_piece_verify_id_seq OWNED BY diskjumble.verify_piece.verify_id;
+
+
+--
+-- Name: disk disk_id; Type: DEFAULT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.disk ALTER COLUMN disk_id SET DEFAULT nextval('diskjumble.disk_id_seq'::regclass);
+
+
+--
+-- Name: slab slab_id; Type: DEFAULT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.slab ALTER COLUMN slab_id SET DEFAULT nextval('diskjumble.slab_id_seq'::regclass);
+
+
+--
+-- Name: verify_pass verify_pass_id; Type: DEFAULT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.verify_pass ALTER COLUMN verify_pass_id SET DEFAULT nextval('diskjumble.verify_pass_verify_pass_id_seq'::regclass);
+
+
+--
+-- Name: verify_piece verify_id; Type: DEFAULT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.verify_piece ALTER COLUMN verify_id SET DEFAULT nextval('diskjumble.verify_piece_verify_id_seq'::regclass);
+
+
+--
+-- Name: disk disk_dev_uuid_key; Type: CONSTRAINT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.disk
+    ADD CONSTRAINT disk_dev_uuid_key UNIQUE (dev_uuid);
+
+
+--
+-- Name: disk disk_pkey; Type: CONSTRAINT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.disk
+    ADD CONSTRAINT disk_pkey PRIMARY KEY (disk_id);
+
+
+--
+-- Name: slab slab_disk_id_disk_sectors_excl; Type: CONSTRAINT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.slab
+    ADD CONSTRAINT slab_disk_id_disk_sectors_excl EXCLUDE USING gist (disk_id WITH =, disk_sectors WITH &&);
+
+
+--
+-- Name: slab slab_pkey; Type: CONSTRAINT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.slab
+    ADD CONSTRAINT slab_pkey PRIMARY KEY (slab_id);
+
+
+--
+-- Name: verify_pass verify_pass_pkey; Type: CONSTRAINT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.verify_pass
+    ADD CONSTRAINT verify_pass_pkey PRIMARY KEY (verify_pass_id);
+
+
+--
+-- Name: verify_piece_content verify_piece_content_pkey; Type: CONSTRAINT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.verify_piece_content
+    ADD CONSTRAINT verify_piece_content_pkey PRIMARY KEY (verify_id, seq);
+
+
+--
+-- Name: verify_piece_fail verify_piece_fail_pkey; Type: CONSTRAINT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.verify_piece_fail
+    ADD CONSTRAINT verify_piece_fail_pkey PRIMARY KEY (verify_id);
+
+
+--
+-- Name: verify_piece_incomplete verify_piece_incomplete_pkey; Type: CONSTRAINT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.verify_piece_incomplete
+    ADD CONSTRAINT verify_piece_incomplete_pkey PRIMARY KEY (verify_id);
+
+
+--
+-- Name: verify_piece verify_piece_pkey; Type: CONSTRAINT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.verify_piece
+    ADD CONSTRAINT verify_piece_pkey PRIMARY KEY (verify_id);
+
+
+--
+-- Name: torrent_info_digest_idx; Type: INDEX; Schema: bittorrent; Owner: eros
+--
+
+CREATE UNIQUE INDEX torrent_info_digest_idx ON bittorrent.torrent_info USING btree (public.digest(info, 'sha1'::text));
+
+
+--
+-- Name: slab slab_disk_id_fkey; Type: FK CONSTRAINT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.slab
+    ADD CONSTRAINT slab_disk_id_fkey FOREIGN KEY (disk_id) REFERENCES diskjumble.disk(disk_id);
+
+
+--
+-- Name: verify_pass verify_pass_disk_id_fkey; Type: FK CONSTRAINT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.verify_pass
+    ADD CONSTRAINT verify_pass_disk_id_fkey FOREIGN KEY (disk_id) REFERENCES diskjumble.disk(disk_id);
+
+
+--
+-- Name: verify_piece_content verify_piece_content_disk_id_fkey; Type: FK CONSTRAINT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.verify_piece_content
+    ADD CONSTRAINT verify_piece_content_disk_id_fkey FOREIGN KEY (disk_id) REFERENCES diskjumble.disk(disk_id);
+
+
+--
+-- Name: verify_piece_content verify_piece_content_verify_id_fkey; Type: FK CONSTRAINT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.verify_piece_content
+    ADD CONSTRAINT verify_piece_content_verify_id_fkey FOREIGN KEY (verify_id) REFERENCES diskjumble.verify_piece(verify_id);
+
+
+--
+-- Name: verify_piece_fail verify_piece_fail_verify_id_fkey; Type: FK CONSTRAINT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.verify_piece_fail
+    ADD CONSTRAINT verify_piece_fail_verify_id_fkey FOREIGN KEY (verify_id) REFERENCES diskjumble.verify_piece(verify_id);
+
+
+--
+-- Name: verify_piece_incomplete verify_piece_incomplete_verify_id_fkey; Type: FK CONSTRAINT; Schema: diskjumble; Owner: eros
+--
+
+ALTER TABLE ONLY diskjumble.verify_piece_incomplete
+    ADD CONSTRAINT verify_piece_incomplete_verify_id_fkey FOREIGN KEY (verify_id) REFERENCES diskjumble.verify_piece(verify_id);
+
+
+--
+-- PostgreSQL database dump complete
+--
+
index f160a963e2d25f7e9e73e0ff6450f8b9f445c279..f403c8d8569050c5e3c02844a027a5668304d0d5 100644 (file)
@@ -1,9 +1,11 @@
 from __future__ import annotations
+from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
 from typing import Optional
 import argparse
 import contextlib
 import datetime as dt
+import io
 import itertools
 import math
 
@@ -34,162 +36,168 @@ class _PieceTask:
        complete: bool  # do these chunks complete the piece?
 
 
-if __name__ == "__main__":
-       parser = argparse.ArgumentParser()
-       parser.add_argument("disk_id", type = int)
-       args = parser.parse_args()
-
-       with contextlib.closing(psycopg2.connect("")) as conn:
-               db = DbWrapper(conn)
-               disk = db.get_disk(args.disk_id)
-
-               info_dicts = {
-                       info_hash: bencode.decode(info)
-                       for (info_hash, info) in db.get_torrent_info(args.disk_id)
-               }
-
-               tasks = []
-               slabs_and_hashers = db.get_slabs_and_hashers(args.disk_id)
-               for (entity_id, group) in itertools.groupby(slabs_and_hashers, lambda t: t[0].entity_id):
-                       info = info_dicts[entity_id]
-                       piece_len = info[b"piece length"]
-                       assert piece_len % disk.sector_size == 0
-                       if b"length" in info:
-                               torrent_len = info[b"length"]
-                       else:
-                               torrent_len = sum(d[b"length"] for d in info[b"files"])
-
-                       offset = None
-                       use_hasher = None
-                       chunks = []
-                       for (slab, hasher_ref) in group:
-                               slab_end = min(slab.entity_offset + len(slab.sectors) * disk.sector_size, torrent_len)
-
-                               while offset is None or offset < slab_end:
-                                       if offset is not None and slab.entity_offset > offset:
-                                               if chunks:
-                                                       tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, False))
-                                               offset = None
+def do_verify(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int) -> None:
+       db = DbWrapper(conn)
+       disk = db.get_disk(disk_id)
+
+       info_dicts = {
+               info_hash: bencode.decode(info)
+               for (info_hash, info) in db.get_torrent_info(disk_id)
+       }
+
+       tasks = []
+       slabs_and_hashers = db.get_slabs_and_hashers(disk_id)
+       for (entity_id, group) in itertools.groupby(slabs_and_hashers, lambda t: t[0].entity_id):
+               info = info_dicts[entity_id]
+               piece_len = info[b"piece length"]
+               assert piece_len % disk.sector_size == 0
+               if b"length" in info:
+                       torrent_len = info[b"length"]
+               else:
+                       torrent_len = sum(d[b"length"] for d in info[b"files"])
+
+               offset = None
+               use_hasher = None
+               chunks = []
+               for (slab, hasher_ref) in group:
+                       slab_end = min(slab.entity_offset + len(slab.sectors) * disk.sector_size, torrent_len)
+
+                       while offset is None or offset < slab_end:
+                               if offset is not None and slab.entity_offset > offset:
+                                       if chunks:
+                                               tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, False))
+                                       offset = None
+                                       use_hasher = None
+                                       chunks = []
+
+                               if offset is None:
+                                       aligned = math.ceil(slab.entity_offset / piece_len) * piece_len
+                                       if hasher_ref and hasher_ref.entity_offset < aligned:
+                                               assert hasher_ref.entity_offset < torrent_len
+                                               use_hasher = hasher_ref
+                                               offset = hasher_ref.entity_offset
+                                       elif aligned < slab_end:
+                                               offset = aligned
+                                       else:
+                                               break  # no usable data in this slab
+
+                               if offset is not None:
+                                       piece_end = min(offset + piece_len - offset % piece_len, torrent_len)
+                                       chunk_end = min(piece_end, slab_end)
+                                       chunks.append(_SlabChunk(slab, slice(offset - slab.entity_offset, chunk_end - slab.entity_offset)))
+                                       if chunk_end == piece_end:
+                                               tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, True))
                                                use_hasher = None
                                                chunks = []
+                                       offset = chunk_end
+
+               if chunks:
+                       tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, False))
+
+       @dataclass
+       class NewVerifyPiece:
+               entity_id: bytes
+               piece_num: int
+               sector_ranges: list[range]
+               hasher_state: dict
+               failed: bool
+
+       @dataclass
+       class VerifyUpdate:
+               seq_start: int
+               new_sector_ranges: list[range]
+               hasher_state: dict
+
+       passed_verifies = set()
+       failed_verifies = set()
+       new_pass_ranges = []
+       vp_updates = {}
+       new_vps = []
+
+       run_ts = dt.datetime.now(dt.timezone.utc)
+       for task in tasks:
+               hasher = Sha1Hasher(task.hasher_ref.state if task.hasher_ref else None)
+               sector_ranges = [
+                       range(
+                               chunk.slab.sectors.start + chunk.slice.start // disk.sector_size,
+                               chunk.slab.sectors.start + math.ceil(chunk.slice.stop / disk.sector_size)
+                       )
+                       for chunk in task.chunks
+               ]
+
+               for chunk in task.chunks:
+                       slab_off = chunk.slab.sectors.start * disk.sector_size
+                       disk_file.seek(slab_off + chunk.slice.start)
+                       end_pos = slab_off + chunk.slice.stop
+                       while disk_file.tell() < end_pos:
+                               data = disk_file.read(min(end_pos - disk_file.tell(), read_size))
+                               assert data
+                               hasher.update(data)
+
+               hasher_state = hasher.ctx.serialize()
+               if task.complete:
+                       s = slice(task.piece_num * 20, task.piece_num * 20 + 20)
+                       expected_hash = info_dicts[task.entity_id][b"pieces"][s]
+                       if hasher.digest() == expected_hash:
+                               write_piece_data = False
+                               new_pass_ranges.extend(sector_ranges)
+                               if task.hasher_ref:
+                                       passed_verifies.add(task.hasher_ref.id)
+                       else:
+                               failed = True
+                               write_piece_data = True
+                               if task.hasher_ref:
+                                       failed_verifies.add(task.hasher_ref.id)
+               else:
+                       failed = False
+                       write_piece_data = True
+
+               if write_piece_data:
+                       if task.hasher_ref:
+                               assert task.hasher_ref.id not in vp_updates
+                               vp_updates[task.hasher_ref.id] = VerifyUpdate(task.hasher_ref.seq + 1, sector_ranges, hasher_state)
+                       else:
+                               new_vps.append(NewVerifyPiece(task.entity_id, task.piece_num, sector_ranges, hasher_state, failed))
+
+       new_pass_ranges.sort(key = lambda r: r.start)
+       merged_ranges = []
+       for r in new_pass_ranges:
+               if merged_ranges and r.start == merged_ranges[-1].stop:
+                       merged_ranges[-1] = range(merged_ranges[-1].start, r.stop)
+               else:
+                       merged_ranges.append(r)
+
+       for vp in new_vps:
+               verify_id = db.insert_verify_piece(run_ts, vp.entity_id, vp.piece_num)
+               db.insert_verify_piece_content(verify_id, 0, disk_id, vp.sector_ranges)
+               if vp.failed:
+                       db.mark_verify_piece_failed(verify_id)
+               else:
+                       db.upsert_hasher_state(verify_id, vp.hasher_state)
 
-                                       if offset is None:
-                                               aligned = math.ceil(slab.entity_offset / piece_len) * piece_len
-                                               if hasher_ref and hasher_ref.entity_offset < aligned:
-                                                       assert hasher_ref.entity_offset < torrent_len
-                                                       use_hasher = hasher_ref
-                                                       offset = hasher_ref.entity_offset
-                                               elif aligned < slab_end:
-                                                       offset = aligned
-                                               else:
-                                                       break  # no usable data in this slab
-
-                                       if offset is not None:
-                                               piece_end = min(offset + piece_len - offset % piece_len, torrent_len)
-                                               chunk_end = min(piece_end, slab_end)
-                                               chunks.append(_SlabChunk(slab, slice(offset - slab.entity_offset, chunk_end - slab.entity_offset)))
-                                               if chunk_end == piece_end:
-                                                       tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, True))
-                                                       use_hasher = None
-                                                       chunks = []
-                                               offset = chunk_end
-
-                       if chunks:
-                               tasks.append(_PieceTask(entity_id, offset // piece_len, use_hasher, chunks, False))
-
-               @dataclass
-               class NewVerifyPiece:
-                       entity_id: bytes
-                       piece_num: int
-                       sector_ranges: list[range]
-                       hasher_state: dict
-                       failed: bool
-
-               @dataclass
-               class VerifyUpdate:
-                       seq_start: int
-                       new_sector_ranges: list[range]
-                       hasher_state: dict
-
-               passed_verifies = set()
-               failed_verifies = set()
-               new_pass_ranges = []
-               vp_updates = {}
-               new_vps = []
-
-               run_ts = dt.datetime.now(dt.timezone.utc)
-               with open(f"/dev/mapper/diskjumble-{args.disk_id}", "rb", buffering = _READ_BUFFER_SIZE) as dev:
-                       for task in tasks:
-                               hasher = Sha1Hasher(task.hasher_ref.state if task.hasher_ref else None)
-                               sector_ranges = [
-                                       range(
-                                               chunk.slab.sectors.start + chunk.slice.start // disk.sector_size,
-                                               chunk.slab.sectors.start + math.ceil(chunk.slice.stop / disk.sector_size)
-                                       )
-                                       for chunk in task.chunks
-                               ]
-
-                               for chunk in task.chunks:
-                                       slab_off = chunk.slab.sectors.start * disk.sector_size
-                                       dev.seek(slab_off + chunk.slice.start)
-                                       end_pos = slab_off + chunk.slice.stop
-                                       while dev.tell() < end_pos:
-                                               data = dev.read(min(end_pos - dev.tell(), _READ_BUFFER_SIZE))
-                                               assert data
-                                               hasher.update(data)
-
-                               hasher_state = hasher.ctx.serialize()
-                               if task.complete:
-                                       s = slice(task.piece_num * 20, task.piece_num * 20 + 20)
-                                       expected_hash = info_dicts[task.entity_id][b"pieces"][s]
-                                       if hasher.digest() == expected_hash:
-                                               write_piece_data = False
-                                               new_pass_ranges.extend(sector_ranges)
-                                               if task.hasher_ref:
-                                                       passed_verifies.add(task.hasher_ref.id)
-                                       else:
-                                               failed = True
-                                               write_piece_data = True
-                                               if task.hasher_ref:
-                                                       failed_verifies.add(task.hasher_ref.id)
-                               else:
-                                       failed = False
-                                       write_piece_data = True
-
-                               if write_piece_data:
-                                       if task.hasher_ref:
-                                               assert task.hasher_ref.id not in vp_updates
-                                               vp_updates[task.hasher_ref.id] = VerifyUpdate(task.hasher_ref.seq + 1, sector_ranges, hasher_state)
-                                       else:
-                                               new_vps.append(NewVerifyPiece(task.entity_id, task.piece_num, sector_ranges, hasher_state, failed))
+       for (verify_id, update) in vp_updates.items():
+               db.insert_verify_piece_content(verify_id, update.seq_start, disk_id, update.new_sector_ranges)
+               db.upsert_hasher_state(verify_id, update.hasher_state)
 
-               new_pass_ranges.sort(key = lambda r: r.start)
-               merged_ranges = []
-               for r in new_pass_ranges:
-                       if merged_ranges and r.start == merged_ranges[-1].stop:
-                               merged_ranges[-1] = range(merged_ranges[-1].start, r.stop)
-                       else:
-                               merged_ranges.append(r)
+       for verify_id in passed_verifies:
+               db.move_piece_content_for_pass(verify_id)
+               db.delete_verify_piece(verify_id)
 
-               for vp in new_vps:
-                       verify_id = db.insert_verify_piece(run_ts, vp.entity_id, vp.piece_num)
-                       db.insert_verify_piece_content(verify_id, 0, args.disk_id, vp.sector_ranges)
-                       if vp.failed:
-                               db.mark_verify_piece_failed(verify_id)
-                       else:
-                               db.upsert_hasher_state(verify_id, vp.hasher_state)
+       for r in merged_ranges:
+               db.insert_pass_data(run_ts, disk_id, r)
 
-               for (verify_id, update) in vp_updates.items():
-                       db.insert_verify_piece_content(verify_id, update.seq_start, args.disk_id, update.new_sector_ranges)
-                       db.upsert_hasher_state(verify_id, update.hasher_state)
+       for verify_id in failed_verifies:
+               db.clear_incomplete(verify_id)
+               db.mark_verify_piece_failed(verify_id)
 
-               for verify_id in passed_verifies:
-                       db.move_piece_content_for_pass(verify_id)
-                       db.delete_verify_piece(verify_id)
 
-               for r in merged_ranges:
-                       db.insert_pass_data(run_ts, args.disk_id, r)
+if __name__ == "__main__":
+       parser = argparse.ArgumentParser()
+       parser.add_argument("disk_id", type = int)
+       args = parser.parse_args()
 
-               for verify_id in failed_verifies:
-                       db.clear_incomplete(verify_id)
-                       db.mark_verify_piece_failed(verify_id)
+       with contextlib.closing(psycopg2.connect("")) as conn:
+               conn.autocommit = True
+               path = f"/dev/mapper/diskjumble-{args.disk_id}"
+               with open(path, "rb", buffering = _READ_BUFFER_SIZE, read_size = _READ_BUFFER_SIZE) as disk_file:
+                       do_verify(conn, args.disk_id, disk_file)
diff --git a/test_util/dump_db.py b/test_util/dump_db.py
new file mode 100644 (file)
index 0000000..9551bf9
--- /dev/null
@@ -0,0 +1,46 @@
+"""
+Using the live database, dump creation code for extensions, tables, and functions needed for local testing
+
+For testing the verification script, write the output of this script to:
+
+       src/disk_jumble/tests/verify_setup.sql
+"""
+
+import contextlib
+import itertools
+import os
+import subprocess
+
+import psycopg2
+
+
+procedures = [
+       "public.digest(bytea, text)",
+]
+
+extensions = [
+       "btree_gist",
+]
+
+with contextlib.closing(psycopg2.connect("")) as conn:
+       conn.autocommit = True
+       with conn.cursor() as cursor:
+               for ext in extensions:
+                       print(f"create extension {ext};", flush = True)
+               for sig in procedures:
+                       cursor.execute("select pg_get_functiondef(to_regprocedure(%s));", (sig,))
+                       [(sql,)] = cursor.fetchall()
+                       print(sql + ";", flush = True)
+
+tables = [
+       "diskjumble.disk",
+       "diskjumble.slab",
+       "diskjumble.verify_pass",
+       "diskjumble.verify_piece",
+       "diskjumble.verify_piece_content",
+       "diskjumble.verify_piece_fail",
+       "diskjumble.verify_piece_incomplete",
+       "bittorrent.torrent_info",
+]
+argv = ["pg_dump", *itertools.chain.from_iterable(["-t", table] for table in tables), "--schema-only", os.environ["PGDATABASE"]]
+subprocess.run(argv, check = True)