Update verify v1 for database schema changes
author     Jakob Cornell <jakob+gpg@jcornell.net>
           Mon, 21 Feb 2022 05:57:25 +0000 (23:57 -0600)
committer  Jakob Cornell <jakob+gpg@jcornell.net>
           Mon, 21 Feb 2022 20:29:46 +0000 (14:29 -0600)
Also renamed the test module for clarity
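
The diskjumble.disk table no longer carries a per-disk sector size, so the
v1 verify path takes the sector size as an explicit parameter and the CLI
passes the new package-level constant. A sketch of the changed entry point
(signatures as in the verify.py hunks below; conn, disk_file, etc. as in
the tests):

    from disk_jumble import SECTOR_SIZE  # 16 KiB, added in __init__.py
    from disk_jumble.verify import do_verify

    # old: do_verify(conn, disk_id, disk_file, read_size, read_tries)
    do_verify(conn, disk_id, SECTOR_SIZE, disk_file, read_size, read_tries)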

disk_jumble/src/disk_jumble/__init__.py
disk_jumble/src/disk_jumble/db.py
disk_jumble/src/disk_jumble/tests/test_verify.py [deleted file]
disk_jumble/src/disk_jumble/tests/test_verify_v1.py [new file with mode: 0644]
disk_jumble/src/disk_jumble/tests/verify_setup.sql
disk_jumble/src/disk_jumble/verify.py

diff --git a/disk_jumble/src/disk_jumble/__init__.py b/disk_jumble/src/disk_jumble/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..4b7b490d6feb6aaaf9fd49cd8a3f31744ad5bdc9 100644
--- a/disk_jumble/src/disk_jumble/__init__.py
+++ b/disk_jumble/src/disk_jumble/__init__.py
@@ -0,0 +1 @@
+SECTOR_SIZE = 16 * 1024  # in bytes
diff --git a/disk_jumble/src/disk_jumble/db.py b/disk_jumble/src/disk_jumble/db.py
index 68488f095ebc519296ec76130a0d415889539287..5b7e2f87c0b989f6066e32e71177c72648c8d806 100644
--- a/disk_jumble/src/disk_jumble/db.py
+++ b/disk_jumble/src/disk_jumble/db.py
@@ -7,11 +7,6 @@ import itertools
 from psycopg2.extras import execute_batch, Json, NumericRange
 
 
-@dataclass
-class Disk:
-       sector_size: int
-
-
 @dataclass
 class Slab:
        id: int
@@ -74,14 +69,7 @@ class Wrapper:
                        ]
                        execute_batch(cursor, stmt, param_sets)
 
-       def get_disk(self, id_: int) -> Disk:
-               with self.conn.cursor() as cursor:
-                       cursor.execute("select sector_size from diskjumble.disk where disk_id = %s;", (id_,))
-                       [(sector_size,)] = cursor.fetchall()
-
-               return Disk(sector_size)
-
-       def get_slabs_and_hashers(self, disk_id: int) -> Iterable[tuple[Slab, Optional[HasherRef]]]:
+       def get_slabs_and_hashers(self, disk_id: int, sector_size: int) -> Iterable[tuple[Slab, Optional[HasherRef]]]:
                """
                Find slabs on the specified disk, and for each also return any lowest-offset saved hasher that left off directly
                before or within the slab's entity data.
@@ -93,7 +81,7 @@ class Wrapper:
                                        -- join up incomplete piece info and precompute where the hasher left off within the entity
                                        select
                                                verify_id, seq, slab.entity_id, hasher_state,
-                                               entity_offset + (upper(c.disk_sectors) - lower(slab.disk_sectors)) * sector_size as end_off
+                                               entity_offset + (upper(c.disk_sectors) - lower(slab.disk_blocks)) * %(sector_size)s as end_off
                                        from
                                                diskjumble.verify_piece_incomplete
                                                natural left join diskjumble.verify_piece p
@@ -101,12 +89,12 @@ class Wrapper:
                                                natural left join diskjumble.disk
                                                left join diskjumble.slab on (
                                                        c.disk_id = slab.disk_id
-                                                       and upper(c.disk_sectors) <@ int8range(lower(slab.disk_sectors), upper(slab.disk_sectors), '[]')
+                                                       and upper(c.disk_sectors) <@ int8range(lower(slab.disk_blocks), upper(slab.disk_blocks), '[]')
                                                )
                                        where seq >= all (select seq from diskjumble.verify_piece_content where verify_id = p.verify_id)
                                )
                        select
-                               slab_id, disk_id, disk_sectors, slab.entity_id, entity_offset, crypt_key, verify_id, seq, end_off,
+                               slab_id, disk_id, disk_blocks, slab.entity_id, entity_offset, crypt_key, verify_id, seq, end_off,
                                hasher_state
                        from
                                diskjumble.slab
@@ -115,15 +103,15 @@ class Wrapper:
                                        incomplete_edge.entity_id = slab.entity_id
                                        and incomplete_edge.end_off <@ int8range(
                                                slab.entity_offset,
-                                               slab.entity_offset + (upper(disk_sectors) - lower(disk_sectors)) * sector_size
+                                               slab.entity_offset + (upper(disk_blocks) - lower(disk_blocks)) * %(sector_size)s
                                        )
-                                       and (incomplete_edge.end_off - slab.entity_offset) %% sector_size = 0
-                       where disk_id = %s
+                                       and (incomplete_edge.end_off - slab.entity_offset) %% %(sector_size)s = 0
+                       where disk_id = %(disk_id)s
                        order by slab.entity_id, entity_offset, slab_id
                        ;
                """
                with self.conn.cursor() as cursor:
-                       cursor.execute(stmt, (disk_id,))
+                       cursor.execute(stmt, {"disk_id": disk_id, "sector_size": sector_size})
                        for (_, rows_iter) in itertools.groupby(cursor, lambda r: r[0]):
                                rows = list(rows_iter)
                                [(slab_id, disk_id, sectors_pg, entity_id, entity_off, key)] = {r[:6] for r in rows}
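
Note: with the Disk dataclass and get_disk() removed, callers of the
wrapper supply the sector size themselves. A minimal usage sketch
(psycopg2 connection setup elided; SECTOR_SIZE is the new constant from
disk_jumble/__init__.py):

    from disk_jumble import SECTOR_SIZE
    from disk_jumble.db import Wrapper

    db = Wrapper(conn)  # conn: an open psycopg2 connection
    for (slab, hasher_ref) in db.get_slabs_and_hashers(disk_id, SECTOR_SIZE):
        # hasher_ref is None unless a saved hasher state can resume within
        # this slab's entity data
        ...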
diff --git a/disk_jumble/src/disk_jumble/tests/test_verify.py b/disk_jumble/src/disk_jumble/tests/test_verify.py
deleted file mode 100644
index 75f166d..0000000
--- a/disk_jumble/src/disk_jumble/tests/test_verify.py
+++ /dev/null
@@ -1,451 +0,0 @@
-"""
-Tests for the verification program `disk_jumble.verify'
-
-Like the verification program itself, these tests take database connection information from the environment.  The
-necessary schemas and tables are set up from scratch by the test code, so environment variables should point to a
-database that's not hosting a live instance of Disk Jumble.  Ideally, this is a completely empty local database created
-for the purposes of testing, but other options are possible.
-
-The tests need access to an SQL source file containing the definitions for the required tables and other Postgres
-objects; see `test_util/dump_db.py'.
-"""
-
-from dataclasses import dataclass
-from importlib import resources
-from random import Random
-from typing import Optional
-import hashlib
-import io
-import tempfile
-import unittest
-import uuid
-
-from psycopg2.extras import NumericRange
-import psycopg2
-import psycopg2.extras
-
-from disk_jumble import bencode
-from disk_jumble.nettle import Sha1Hasher
-from disk_jumble.verify import do_verify
-
-
-_BUF_SIZE = 16 * 1024 ** 2  # in bytes
-
-
-class Tests(unittest.TestCase):
-       _SCHEMAS = {"public", "diskjumble", "bittorrent"}
-
-       def _basic_fresh_verify_helper(self, read_size):
-               sector_size = 32
-               piece_size = 64
-
-               torrent_len = 3 * piece_size
-               disk = self._write_disk(sector_size, torrent_len // sector_size)
-               with _random_file(torrent_len, Random(0), on_disk = False) as torrent_file:
-                       torrent = _Torrent(torrent_file, piece_size)
-                       self._write_torrent(torrent)
-                       with self._conn.cursor() as cursor:
-                               cursor.execute(
-                                       "insert into diskjumble.slab values (default, %s, %s, %s, 0, null);",
-                                       (disk.id, NumericRange(0, disk.sector_count), torrent.info_hash)
-                               )
-
-                               do_verify(self._conn, disk.id, torrent_file, read_size, read_tries = 1)
-
-                               cursor.execute("select * from diskjumble.verify_pass;")
-                               self.assertEqual(cursor.rowcount, 1)
-                               (_, _, disk_id, sectors) = cursor.fetchone()
-                               self.assertEqual(disk_id, disk.id)
-                               self.assertEqual(sectors, NumericRange(0, torrent_len // sector_size))
-
-       def test_basic_fresh_verify_small_read_size(self):
-               self._basic_fresh_verify_helper(16)
-
-       def test_basic_fresh_verify_large_read_size(self):
-               self._basic_fresh_verify_helper(128)
-
-       def test_resume_fragmentation_unaligned_end(self):
-               """
-               Test a run where a cached hash state is used, a piece is split on disk, and the end of the torrent isn't
-               sector-aligned.
-               """
-               read_size = 8
-               piece_size = 64
-
-               other_disk = self._write_disk(16, 1)
-               disk = self._write_disk(32, 3)
-               with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
-                       torrent = _Torrent(torrent_file, piece_size)
-                       self._write_torrent(torrent)
-                       with self._conn.cursor() as cursor:
-                               cursor.executemany(
-                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
-                                       [
-                                               (other_disk.id, NumericRange(0, 1), torrent.info_hash, 0),
-                                               (disk.id, NumericRange(0, 1), torrent.info_hash, other_disk.sector_size),
-                                               (disk.id, NumericRange(2, 3), torrent.info_hash, other_disk.sector_size + disk.sector_size),
-                                       ]
-                               )
-
-                               # Prepare the saved hasher state by running a verify
-                               do_verify(self._conn, other_disk.id, torrent_file, read_size, read_tries = 1)
-                               torrent_file.seek(0)
-
-                               cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
-                               [(row_count,)] = cursor.fetchall()
-                               self.assertEqual(row_count, 1)
-
-                               disk_file = io.BytesIO()
-                               torrent_file.seek(other_disk.sector_size)
-                               disk_file.write(torrent_file.read(disk.sector_size))
-                               disk_file.seek(disk_file.tell() + disk.sector_size)
-                               disk_file.write(torrent_file.read())
-                               disk_file.seek(0)
-                               do_verify(self._conn, disk.id, disk_file, read_size, read_tries = 1)
-
-                               # Check that there are no verify pieces in the database.  Because of integrity constraints, this also
-                               # guarantees there aren't any stray saved hasher states, failed verifies, or piece contents.
-                               cursor.execute("select count(*) from diskjumble.verify_piece;")
-                               [(row_count,)] = cursor.fetchall()
-                               self.assertEqual(row_count, 0)
-
-                               cursor.execute("select disk_id, disk_sectors from diskjumble.verify_pass;")
-                               self.assertEqual(
-                                       cursor.fetchall(),
-                                       [(other_disk.id, NumericRange(0, 1)), (disk.id, NumericRange(0, 1)), (disk.id, NumericRange(2, 3))]
-                               )
-
-       def test_resume_no_completion(self):
-               """
-               Test a run where a saved hasher state is used and the target disk has subsequent entity data but not the full
-               remainder of the piece.
-               """
-               read_size = 7
-               piece_size = 64
-
-               other_disk = self._write_disk(16, 1)
-               disk = self._write_disk(32, 1)
-               with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
-                       torrent = _Torrent(torrent_file, piece_size)
-                       self._write_torrent(torrent)
-                       with self._conn.cursor() as cursor:
-                               cursor.executemany(
-                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
-                                       [
-                                               (other_disk.id, NumericRange(0, 1), torrent.info_hash, 0),
-                                               (disk.id, NumericRange(0, 1), torrent.info_hash, other_disk.sector_size),
-                                       ]
-                               )
-
-                               do_verify(self._conn, other_disk.id, torrent_file, read_size, read_tries = 1)
-
-                               cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
-                               [(row_count,)] = cursor.fetchall()
-                               self.assertEqual(row_count, 1)
-
-                               disk_file = io.BytesIO()
-                               torrent_file.seek(other_disk.sector_size)
-                               disk_file.write(torrent_file.read(disk.sector_size))
-                               disk_file.seek(0)
-                               do_verify(self._conn, disk.id, disk_file, read_size, read_tries = 1)
-
-                               cursor.execute("select count(*) from diskjumble.verify_pass;")
-                               [(row_count,)] = cursor.fetchall()
-                               self.assertEqual(row_count, 0)
-
-                               cursor.execute("select entity_id, piece from diskjumble.verify_piece;")
-                               [(entity_id, piece_num)] = cursor.fetchall()
-                               self.assertEqual(bytes(entity_id), torrent.info_hash)
-                               self.assertEqual(piece_num, 0)
-
-                               cursor.execute("select disk_id, disk_sectors from diskjumble.verify_piece_content;")
-                               self.assertCountEqual(
-                                       cursor.fetchall(),
-                                       [(other_disk.id, NumericRange(0, 1)), (disk.id, NumericRange(0, 1))]
-                               )
-
-                               cursor.execute("select count(*) from diskjumble.verify_piece_fail;")
-                               [(row_count,)] = cursor.fetchall()
-                               self.assertEqual(row_count, 0)
-
-                               hasher = Sha1Hasher(None)
-                               torrent_file.seek(0)
-                               hasher.update(torrent_file.read(other_disk.sector_size + disk.sector_size))
-                               cursor.execute("select hasher_state from diskjumble.verify_piece_incomplete;")
-                               self.assertEqual(cursor.fetchall(), [(hasher.ctx.serialize(),)])
-
-       def test_ignore_hasher_beginning_on_disk(self):
-               """
-               Test a run where a saved hasher state is available for use but isn't used due to the beginning of the piece
-               being on disk.
-               """
-               piece_size = 64
-
-               other_disk = self._write_disk(16, 1)
-               disk = self._write_disk(16, 4)
-               with _random_file(piece_size * 2, Random(0), on_disk = False) as torrent_file:
-                       torrent = _Torrent(torrent_file, piece_size)
-                       self._write_torrent(torrent)
-                       with self._conn.cursor() as cursor:
-                               cursor.executemany(
-                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
-                                       [
-                                               (other_disk.id, NumericRange(0, other_disk.sector_count), torrent.info_hash, piece_size),
-                                               (disk.id, NumericRange(0, disk.sector_count), torrent.info_hash, piece_size),
-                                       ]
-                               )
-
-                               do_verify(self._conn, other_disk.id, torrent_file, read_size = 128, read_tries = 1)
-
-                               cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
-                               [(row_count,)] = cursor.fetchall()
-                               self.assertEqual(row_count, 1)
-
-                               torrent_file.seek(piece_size)
-                               disk_file = io.BytesIO(torrent_file.read())
-                               do_verify(self._conn, disk.id, disk_file, read_size = 128, read_tries = 1)
-
-                               cursor.execute(
-                                       "select disk_id from diskjumble.verify_piece_content natural join diskjumble.verify_piece_incomplete;"
-                               )
-                               self.assertEqual(cursor.fetchall(), [(other_disk.id,)])
-
-                               cursor.execute("select disk_id, disk_sectors from diskjumble.verify_pass;")
-                               self.assertEqual(cursor.fetchall(), [(disk.id, NumericRange(0, disk.sector_count))])
-
-       def test_ignore_hasher_unaligned(self):
-               """
-               Test a run where a saved hasher isn't used because its entity data offset isn't sector-aligned on the target
-               disk.
-
-                           0   16  32  48  64  80  96  112 128
-               pieces:     [-------------- 0 -------------]
-               other disk: [--][--][--][--][--]
-               disk:                       [------][------]
-               """
-               piece_size = 128
-
-               other_disk = self._write_disk(16, 5)
-               disk = self._write_disk(32, 2)
-               with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
-                       torrent = _Torrent(torrent_file, piece_size)
-                       self._write_torrent(torrent)
-                       with self._conn.cursor() as cursor:
-                               cursor.executemany(
-                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
-                                       [
-                                               (other_disk.id, NumericRange(0, 5), torrent.info_hash, 0),
-                                               (disk.id, NumericRange(0, 2), torrent.info_hash, 64),
-                                       ]
-                               )
-
-                               do_verify(self._conn, other_disk.id, torrent_file, read_size = 16, read_tries = 1)
-                               cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
-                               [(row_count,)] = cursor.fetchall()
-                               self.assertEqual(row_count, 1)
-
-                               disk_file = io.BytesIO(torrent_file.getvalue()[64:])
-                               do_verify(self._conn, disk.id, disk_file, read_size = 16, read_tries = 1)
-
-                               cursor.execute("""
-                                       select disk_id, disk_sectors
-                                       from diskjumble.verify_piece_incomplete natural join diskjumble.verify_piece_content;
-                               """)
-                               self.assertEqual(
-                                       cursor.fetchall(),
-                                       [(other_disk.id, NumericRange(0, 5))]
-                               )
-
-                               cursor.execute("select count(*) from diskjumble.verify_pass;")
-                               [(row_count,)] = cursor.fetchall()
-                               self.assertEqual(row_count, 0)
-
-                               cursor.execute("select count(*) from diskjumble.verify_piece_fail;")
-                               [(row_count,)] = cursor.fetchall()
-                               self.assertEqual(row_count, 0)
-
-       def test_transient_read_errors(self):
-               Test a run where a read from the disk fails but fewer times than needed to mark the sector bad.
-               Test a run where a read to the disk fails but fewer times than needed to mark the sector bad.
-               """
-               piece_size = 32
-
-               disk = self._write_disk(32, 1)
-               with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
-                       torrent = _Torrent(torrent_file, piece_size)
-                       self._write_torrent(torrent)
-                       with self._conn.cursor() as cursor:
-                               cursor.execute(
-                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
-                                       (disk.id, NumericRange(0, 1), torrent.info_hash, 0)
-                               )
-
-                               disk_file = _ReadErrorProxy(torrent_file, error_pos = 12, error_count = 2)
-                               do_verify(self._conn, disk.id, disk_file, read_size = 4, read_tries = 3)
-
-                               self.assertEqual(disk_file.triggered, 2)
-
-                               cursor.execute("select count(*) from diskjumble.verify_piece;")
-                               [(row_count,)] = cursor.fetchall()
-                               self.assertEqual(row_count, 0)
-
-                               cursor.execute("select disk_id, disk_sectors from diskjumble.verify_pass;")
-                               self.assertEqual(cursor.fetchall(), [(disk.id, NumericRange(0, 1))])
-
-       def test_persistent_read_errors(self):
-               """
-               Test a run where a disk read fails enough times to trigger the bad sector logic.
-               """
-               piece_size = 64
-
-               other_a = self._write_disk(16, 1)
-               other_b = self._write_disk(16, 2)
-               disk = self._write_disk(16, 1)
-               with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
-                       torrent = _Torrent(torrent_file, piece_size)
-                       self._write_torrent(torrent)
-                       with self._conn.cursor() as cursor:
-                               cursor.executemany(
-                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
-                                       [
-                                               (other_a.id, NumericRange(0, 1), torrent.info_hash, 0),
-                                               (other_b.id, NumericRange(0, 2), torrent.info_hash, 16),
-                                               (disk.id, NumericRange(0, 1), torrent.info_hash, 48),
-                                       ]
-                               )
-
-                               do_verify(self._conn, other_a.id, torrent_file, read_size = 16, read_tries = 1)
-                               other_b_file = io.BytesIO(torrent_file.getvalue()[16:48])
-                               do_verify(self._conn, other_b.id, other_b_file, read_size = 16, read_tries = 1)
-
-                               cursor.execute("select verify_id from diskjumble.verify_piece;")
-                               [(verify_id,)] = cursor.fetchall()
-
-                               data = torrent_file.getvalue()[48:]
-                               disk_file = _ReadErrorProxy(io.BytesIO(data), error_pos = 5, error_count = None)
-                               do_verify(self._conn, disk.id, disk_file, read_size = 4, read_tries = 3)
-
-                               cursor.execute("select count(*) from diskjumble.verify_pass;")
-                               [(row_count,)] = cursor.fetchall()
-                               self.assertEqual(row_count, 0)
-
-                               cursor.execute("select disk_id, disk_sectors from diskjumble.verify_piece_content;")
-                               self.assertCountEqual(
-                                       cursor.fetchall(),
-                                       [(other_a.id, NumericRange(0, 1)), (other_b.id, NumericRange(0, 2)), (disk.id, NumericRange(0, 1))]
-                               )
-
-                               cursor.execute("select verify_id from diskjumble.verify_piece_fail;")
-                               self.assertEqual(cursor.fetchall(), [(verify_id,)])
-
-                               cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
-                               [(row_count,)] = cursor.fetchall()
-                               self.assertEqual(row_count, 0)
-
-       def _write_torrent(self, torrent: "_Torrent") -> None:
-               with self._conn.cursor() as cursor:
-                       cursor.execute("insert into bittorrent.torrent_info values (%s);", (torrent.info,))
-
-       def _write_disk(self, sector_size: int, sector_count: int) -> "_Disk":
-               with self._conn.cursor() as cursor:
-                       cursor.execute(
-                               "insert into diskjumble.disk values (default, %s, null, %s, %s, false) returning disk_id;",
-                               (uuid.uuid4(), sector_size, sector_count)
-                       )
-                       [(id_,)] = cursor.fetchall()
-               return _Disk(id_, sector_size, sector_count)
-
-       @classmethod
-       def setUpClass(cls):
-               psycopg2.extras.register_uuid()
-               prod_schema_sql = resources.files(__package__).joinpath("verify_setup.sql").read_text()
-               schema_sql = "\n".join(
-                       l for l in prod_schema_sql.splitlines()
-                       if (
-                               not l.startswith("SET")
-                               and not l.startswith("SELECT")  # ignore server settings
-                               and "OWNER TO" not in l  # and ownership changes
-                       )
-               )
-               cls._conn = psycopg2.connect("")
-               with cls._conn, cls._conn.cursor() as cursor:
-                       for name in cls._SCHEMAS:
-                               cursor.execute(f"create schema {name};")
-                       cursor.execute(schema_sql)
-
-       @classmethod
-       def tearDownClass(self):
-               try:
-                       with self._conn, self._conn.cursor() as cursor:
-                               for name in self._SCHEMAS:
-                                       cursor.execute(f"drop schema {name} cascade;")
-               finally:
-                       self._conn.close()
-
-       def tearDown(self):
-               self._conn.rollback()
-
-
-@dataclass
-class _Disk:
-       id: int
-       sector_size: int
-       sector_count: int
-
-
-class _Torrent:
-       def __init__(self, data: io.BufferedIOBase, piece_size: int) -> None:
-               data.seek(0)
-               hashes = []
-               while True:
-                       piece = data.read(piece_size)
-                       if piece:
-                               hashes.append(hashlib.sha1(piece).digest())
-                       else:
-                               break
-
-               info_dict = {
-                       b"piece length": piece_size,
-                       b"length": data.tell(),
-                       b"pieces": b"".join(hashes),
-               }
-               self.data = data
-               self.info = bencode.encode(info_dict)
-               self.info_hash = hashlib.sha1(self.info).digest()
-
-
-def _random_file(size: int, rand_src: Random, on_disk: bool) -> io.BufferedIOBase:
-       f = tempfile.NamedTemporaryFile(buffering = _BUF_SIZE) if on_disk else io.BytesIO()
-       try:
-               while f.tell() < size:
-                       write_size = min(size - f.tell(), _BUF_SIZE)
-                       f.write(bytes(rand_src.getrandbits(8) for _ in range(write_size)))
-               f.seek(0)
-               return f
-       except Exception:
-               f.close()
-               raise
-
-
-@dataclass
-class _ReadErrorProxy(io.BufferedIOBase):
-       wrapped: io.BufferedIOBase
-       error_pos: int
-       error_count: Optional[int]
-
-       def __post_init__(self):
-               self.triggered = 0
-
-       def read(self, size: Optional[int] = None) -> bytes:
-               pre_pos = self.wrapped.tell()
-               data = self.wrapped.read(size)
-               erroring = self.error_count is None or self.triggered < self.error_count
-               in_range = 0 <= pre_pos - self.error_pos < len(data)
-               if erroring and in_range:
-                       self.triggered += 1
-                       raise OSError("simulated")
-               else:
-                       return data
-
-       def seek(self, *args, **kwargs) -> int:
-               return self.wrapped.seek(*args, **kwargs)
diff --git a/disk_jumble/src/disk_jumble/tests/test_verify_v1.py b/disk_jumble/src/disk_jumble/tests/test_verify_v1.py
new file mode 100644
index 0000000..bc18e28
--- /dev/null
+++ b/disk_jumble/src/disk_jumble/tests/test_verify_v1.py
@@ -0,0 +1,457 @@
+"""
+Tests for the verification program `disk_jumble.verify'
+
+Like the verification program itself, these tests take database connection information from the environment.  The
+necessary schemas and tables are set up from scratch by the test code, so environment variables should point to a
+database that's not hosting a live instance of Disk Jumble.  Ideally, this is a completely empty local database created
+for the purposes of testing, but other options are possible.
+
+The tests need access to an SQL source file containing the definitions for the required tables and other Postgres
+objects; see `test_util/dump_db.py'.
+"""
+
+from dataclasses import dataclass
+from importlib import resources
+from random import Random
+from typing import Optional
+import hashlib
+import io
+import tempfile
+import unittest
+import uuid
+
+from psycopg2.extras import NumericRange
+import psycopg2
+import psycopg2.extras
+
+from disk_jumble import bencode
+from disk_jumble.nettle import Sha1Hasher
+from disk_jumble.verify import do_verify
+
+
+_BUF_SIZE = 16 * 1024 ** 2  # in bytes
+
+
+class Tests(unittest.TestCase):
+       _SCHEMAS = {"public", "diskjumble", "bittorrent"}
+
+       def _basic_fresh_verify_helper(self, read_size):
+               sector_size = 32
+               piece_size = 64
+
+               torrent_len = 3 * piece_size
+               disk = self._write_disk(torrent_len // sector_size)
+               with _random_file(torrent_len, Random(0), on_disk = False) as torrent_file:
+                       torrent = _Torrent(torrent_file, piece_size)
+                       self._write_torrent(torrent)
+                       with self._conn.cursor() as cursor:
+                               cursor.execute(
+                                       "insert into diskjumble.slab values (default, %s, %s, %s, 0, null);",
+                                       (disk.id, NumericRange(0, disk.sector_count), torrent.info_hash)
+                               )
+
+                               do_verify(self._conn, disk.id, sector_size, torrent_file, read_size, read_tries = 1)
+
+                               cursor.execute("select * from diskjumble.verify_pass;")
+                               self.assertEqual(cursor.rowcount, 1)
+                               (_, _, disk_id, sectors) = cursor.fetchone()
+                               self.assertEqual(disk_id, disk.id)
+                               self.assertEqual(sectors, NumericRange(0, torrent_len // sector_size))
+
+       def test_basic_fresh_verify_small_read_size(self):
+               self._basic_fresh_verify_helper(16)
+
+       def test_basic_fresh_verify_large_read_size(self):
+               self._basic_fresh_verify_helper(128)
+
+       def test_resume_fragmentation_unaligned_end(self):
+               """
+               Test a run where a cached hash state is used, a piece is split on disk, and the end of the torrent isn't
+               sector-aligned.
+               """
+               sector_size = 16
+               read_size = 8
+
+               other_disk = self._write_disk(1)
+               disk = self._write_disk(5)
+               with _random_file(60, Random(0), on_disk = False) as torrent_file:
+                       torrent = _Torrent(torrent_file, piece_size = 64)
+                       self._write_torrent(torrent)
+                       with self._conn.cursor() as cursor:
+                               cursor.executemany(
+                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
+                                       [
+                                               (other_disk.id, NumericRange(0, 1), torrent.info_hash, 0),
+                                               (disk.id, NumericRange(0, 2), torrent.info_hash, 16),
+                                               (disk.id, NumericRange(4, 5), torrent.info_hash, 48),
+                                       ]
+                               )
+
+                               # Prepare the saved hasher state by running a verify
+                               do_verify(self._conn, other_disk.id, sector_size, torrent_file, read_size, read_tries = 1)
+                               torrent_file.seek(0)
+
+                               cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 1)
+
+                               disk_file = io.BytesIO()
+                               torrent_file.seek(sector_size)
+                               disk_file.write(torrent_file.read(sector_size * 2))
+                               disk_file.seek(disk_file.tell() + sector_size * 2)
+                               disk_file.write(torrent_file.read())
+                               disk_file.seek(0)
+                               do_verify(self._conn, disk.id, sector_size, disk_file, read_size, read_tries = 1)
+
+                               # Check that there are no verify pieces in the database.  Because of integrity constraints, this also
+                               # guarantees there aren't any stray saved hasher states, failed verifies, or piece contents.
+                               cursor.execute("select count(*) from diskjumble.verify_piece;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 0)
+
+                               cursor.execute("select disk_id, disk_sectors from diskjumble.verify_pass;")
+                               self.assertEqual(
+                                       cursor.fetchall(),
+                                       [(other_disk.id, NumericRange(0, 1)), (disk.id, NumericRange(0, 2)), (disk.id, NumericRange(4, 5))]
+                               )
+
+       def test_resume_no_completion(self):
+               """
+               Test a run where a saved hasher state is used and the target disk has subsequent entity data but not the full
+               remainder of the piece.
+               """
+               sector_size = 16
+               read_size = 7
+               piece_size = 64
+
+               other_disk = self._write_disk(1)
+               disk = self._write_disk(2)
+               with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
+                       torrent = _Torrent(torrent_file, piece_size)
+                       self._write_torrent(torrent)
+                       with self._conn.cursor() as cursor:
+                               cursor.executemany(
+                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
+                                       [
+                                               (other_disk.id, NumericRange(0, 1), torrent.info_hash, 0),
+                                               (disk.id, NumericRange(0, 2), torrent.info_hash, sector_size),
+                                       ]
+                               )
+
+                               do_verify(self._conn, other_disk.id, sector_size, torrent_file, read_size, read_tries = 1)
+
+                               cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 1)
+
+                               disk_file = io.BytesIO()
+                               torrent_file.seek(sector_size)
+                               disk_file.write(torrent_file.read(sector_size * 2))
+                               disk_file.seek(0)
+                               do_verify(self._conn, disk.id, sector_size, disk_file, read_size, read_tries = 1)
+
+                               cursor.execute("select count(*) from diskjumble.verify_pass;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 0)
+
+                               cursor.execute("select entity_id, piece from diskjumble.verify_piece;")
+                               [(entity_id, piece_num)] = cursor.fetchall()
+                               self.assertEqual(bytes(entity_id), torrent.info_hash)
+                               self.assertEqual(piece_num, 0)
+
+                               cursor.execute("select disk_id, disk_sectors from diskjumble.verify_piece_content;")
+                               self.assertCountEqual(
+                                       cursor.fetchall(),
+                                       [(other_disk.id, NumericRange(0, 1)), (disk.id, NumericRange(0, 2))]
+                               )
+
+                               cursor.execute("select count(*) from diskjumble.verify_piece_fail;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 0)
+
+                               hasher = Sha1Hasher(None)
+                               torrent_file.seek(0)
+                               hasher.update(torrent_file.read(sector_size * 3))
+                               cursor.execute("select hasher_state from diskjumble.verify_piece_incomplete;")
+                               self.assertEqual(cursor.fetchall(), [(hasher.ctx.serialize(),)])
+
+       def test_ignore_hasher_beginning_on_disk(self):
+               """
+               Test a run where a saved hasher state is available for use but isn't used due to the beginning of the piece
+               being on disk.
+               """
+               piece_size = 64
+
+               other_disk = self._write_disk(1)
+               od_ss = 16
+               disk = self._write_disk(4)
+               d_ss = 16
+               with _random_file(piece_size * 2, Random(0), on_disk = False) as torrent_file:
+                       torrent = _Torrent(torrent_file, piece_size)
+                       self._write_torrent(torrent)
+                       with self._conn.cursor() as cursor:
+                               cursor.executemany(
+                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
+                                       [
+                                               (other_disk.id, NumericRange(0, other_disk.sector_count), torrent.info_hash, piece_size),
+                                               (disk.id, NumericRange(0, disk.sector_count), torrent.info_hash, piece_size),
+                                       ]
+                               )
+
+                               do_verify(self._conn, other_disk.id, od_ss, torrent_file, read_size = 128, read_tries = 1)
+
+                               cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 1)
+
+                               torrent_file.seek(piece_size)
+                               disk_file = io.BytesIO(torrent_file.read())
+                               do_verify(self._conn, disk.id, d_ss, disk_file, read_size = 128, read_tries = 1)
+
+                               cursor.execute(
+                                       "select disk_id from diskjumble.verify_piece_content natural join diskjumble.verify_piece_incomplete;"
+                               )
+                               self.assertEqual(cursor.fetchall(), [(other_disk.id,)])
+
+                               cursor.execute("select disk_id, disk_sectors from diskjumble.verify_pass;")
+                               self.assertEqual(cursor.fetchall(), [(disk.id, NumericRange(0, disk.sector_count))])
+
+       def test_ignore_hasher_unaligned(self):
+               """
+               Test a run where a saved hasher isn't used because its entity data offset isn't sector-aligned on the target
+               disk.
+
+                           0   16  32  48  64  80  96  112 128
+               pieces:     [-------------- 0 -------------]
+               other disk: [--][--][--][--][--]
+               disk:                       [------][------]
+               """
+               piece_size = 128
+
+               other_disk = self._write_disk(5)
+               od_ss = 16
+               disk = self._write_disk(2)
+               d_ss = 32
+               with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
+                       torrent = _Torrent(torrent_file, piece_size)
+                       self._write_torrent(torrent)
+                       with self._conn.cursor() as cursor:
+                               cursor.executemany(
+                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
+                                       [
+                                               (other_disk.id, NumericRange(0, 5), torrent.info_hash, 0),
+                                               (disk.id, NumericRange(0, 2), torrent.info_hash, 64),
+                                       ]
+                               )
+
+                               do_verify(self._conn, other_disk.id, od_ss, torrent_file, read_size = 16, read_tries = 1)
+                               cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 1)
+
+                               disk_file = io.BytesIO(torrent_file.getvalue()[64:])
+                               do_verify(self._conn, disk.id, d_ss, disk_file, read_size = 16, read_tries = 1)
+
+                               cursor.execute("""
+                                       select disk_id, disk_sectors
+                                       from diskjumble.verify_piece_incomplete natural join diskjumble.verify_piece_content;
+                               """)
+                               self.assertEqual(
+                                       cursor.fetchall(),
+                                       [(other_disk.id, NumericRange(0, 5))]
+                               )
+
+                               cursor.execute("select count(*) from diskjumble.verify_pass;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 0)
+
+                               cursor.execute("select count(*) from diskjumble.verify_piece_fail;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 0)
+
+       def test_transient_read_errors(self):
+               """
+               Test a run where a read from the disk fails but fewer times than needed to mark the sector bad.
+               """
+               sector_size = 32
+               piece_size = 32
+
+               disk = self._write_disk(1)
+               with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
+                       torrent = _Torrent(torrent_file, piece_size)
+                       self._write_torrent(torrent)
+                       with self._conn.cursor() as cursor:
+                               cursor.execute(
+                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
+                                       (disk.id, NumericRange(0, 1), torrent.info_hash, 0)
+                               )
+
+                               disk_file = _ReadErrorProxy(torrent_file, error_pos = 12, error_count = 2)
+                               do_verify(self._conn, disk.id, sector_size, disk_file, read_size = 4, read_tries = 3)
+
+                               self.assertEqual(disk_file.triggered, 2)
+
+                               cursor.execute("select count(*) from diskjumble.verify_piece;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 0)
+
+                               cursor.execute("select disk_id, disk_sectors from diskjumble.verify_pass;")
+                               self.assertEqual(cursor.fetchall(), [(disk.id, NumericRange(0, 1))])
+
+       def test_persistent_read_errors(self):
+               """
+               Test a run where a disk read fails enough times to trigger the bad sector logic.
+               """
+               sector_size = 16
+               piece_size = 64
+
+               other_a = self._write_disk(1)
+               other_b = self._write_disk(2)
+               disk = self._write_disk(1)
+               with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
+                       torrent = _Torrent(torrent_file, piece_size)
+                       self._write_torrent(torrent)
+                       with self._conn.cursor() as cursor:
+                               cursor.executemany(
+                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
+                                       [
+                                               (other_a.id, NumericRange(0, 1), torrent.info_hash, 0),
+                                               (other_b.id, NumericRange(0, 2), torrent.info_hash, 16),
+                                               (disk.id, NumericRange(0, 1), torrent.info_hash, 48),
+                                       ]
+                               )
+
+                               do_verify(self._conn, other_a.id, sector_size, torrent_file, read_size = 16, read_tries = 1)
+                               other_b_file = io.BytesIO(torrent_file.getvalue()[16:48])
+                               do_verify(self._conn, other_b.id, sector_size, other_b_file, read_size = 16, read_tries = 1)
+
+                               cursor.execute("select verify_id from diskjumble.verify_piece;")
+                               [(verify_id,)] = cursor.fetchall()
+
+                               data = torrent_file.getvalue()[48:]
+                               disk_file = _ReadErrorProxy(io.BytesIO(data), error_pos = 5, error_count = None)
+                               do_verify(self._conn, disk.id, sector_size, disk_file, read_size = 4, read_tries = 3)
+
+                               cursor.execute("select count(*) from diskjumble.verify_pass;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 0)
+
+                               cursor.execute("select disk_id, disk_sectors from diskjumble.verify_piece_content;")
+                               self.assertCountEqual(
+                                       cursor.fetchall(),
+                                       [(other_a.id, NumericRange(0, 1)), (other_b.id, NumericRange(0, 2)), (disk.id, NumericRange(0, 1))]
+                               )
+
+                               cursor.execute("select verify_id from diskjumble.verify_piece_fail;")
+                               self.assertEqual(cursor.fetchall(), [(verify_id,)])
+
+                               cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 0)
+
+       def _write_torrent(self, torrent: "_Torrent") -> None:
+               with self._conn.cursor() as cursor:
+                       cursor.execute("insert into bittorrent.torrent_info values (%s);", (torrent.info,))
+
+       def _write_disk(self, sector_count: int) -> "_Disk":
+               with self._conn.cursor() as cursor:
+                       cursor.execute(
+                               "insert into diskjumble.disk values (default, %s, null, %s, false) returning disk_id;",
+                               (uuid.uuid4(), sector_count)
+                       )
+                       [(id_,)] = cursor.fetchall()
+               return _Disk(id_, sector_count)
+
+       @classmethod
+       def setUpClass(cls):
+               psycopg2.extras.register_uuid()
+               prod_schema_sql = resources.files(__package__).joinpath("verify_setup.sql").read_text()
+               schema_sql = "\n".join(
+                       l for l in prod_schema_sql.splitlines()
+                       if (
+                               not l.startswith("SET")
+                               and not l.startswith("SELECT")  # ignore server settings
+                               and "OWNER TO" not in l  # and ownership changes
+                       )
+               )
+               cls._conn = psycopg2.connect("")
+               with cls._conn, cls._conn.cursor() as cursor:
+                       for name in cls._SCHEMAS:
+                               cursor.execute(f"create schema {name};")
+                       cursor.execute(schema_sql)
+
+       @classmethod
+       def tearDownClass(self):
+               try:
+                       with self._conn, self._conn.cursor() as cursor:
+                               for name in self._SCHEMAS:
+                                       cursor.execute(f"drop schema {name} cascade;")
+               finally:
+                       self._conn.close()
+
+       def tearDown(self):
+               self._conn.rollback()
+
+
+@dataclass
+class _Disk:
+       id: int
+       sector_count: int
+
+
+class _Torrent:
+       def __init__(self, data: io.BufferedIOBase, piece_size: int) -> None:
+               data.seek(0)
+               hashes = []
+               while True:
+                       piece = data.read(piece_size)
+                       if piece:
+                               hashes.append(hashlib.sha1(piece).digest())
+                       else:
+                               break
+
+               info_dict = {
+                       b"piece length": piece_size,
+                       b"length": data.tell(),
+                       b"pieces": b"".join(hashes),
+               }
+               self.data = data
+               self.info = bencode.encode(info_dict)
+               self.info_hash = hashlib.sha1(self.info).digest()
+
+
+def _random_file(size: int, rand_src: Random, on_disk: bool) -> io.BufferedIOBase:
+       f = tempfile.NamedTemporaryFile(buffering = _BUF_SIZE) if on_disk else io.BytesIO()
+       try:
+               while f.tell() < size:
+                       write_size = min(size - f.tell(), _BUF_SIZE)
+                       f.write(bytes(rand_src.getrandbits(8) for _ in range(write_size)))
+               f.seek(0)
+               return f
+       except Exception:
+               f.close()
+               raise
+
+
+@dataclass
+class _ReadErrorProxy(io.BufferedIOBase):
+       wrapped: io.BufferedIOBase
+       error_pos: int
+       error_count: Optional[int]
+
+       def __post_init__(self):
+               self.triggered = 0
+
+       def read(self, size: Optional[int] = None) -> bytes:
+               pre_pos = self.wrapped.tell()
+               data = self.wrapped.read(size)
+               erroring = self.error_count is None or self.triggered < self.error_count
+               in_range = 0 <= pre_pos - self.error_pos < len(data)
+               if erroring and in_range:
+                       self.triggered += 1
+                       raise OSError("simulated")
+               else:
+                       return data
+
+       def seek(self, *args, **kwargs) -> int:
+               return self.wrapped.seek(*args, **kwargs)
diff --git a/disk_jumble/src/disk_jumble/tests/verify_setup.sql b/disk_jumble/src/disk_jumble/tests/verify_setup.sql
index aa0c5fb347364822066db181f30b8ef477bd7f2a..9189c4e35a81e09d4e876173c90b8927fe3e1224 100644
--- a/disk_jumble/src/disk_jumble/tests/verify_setup.sql
+++ b/disk_jumble/src/disk_jumble/tests/verify_setup.sql
@@ -9,8 +9,8 @@ AS '$libdir/pgcrypto', $function$pg_digest$function$
 -- PostgreSQL database dump
 --
 
--- Dumped from database version 13.2 (Debian 13.2-1.pgdg100+1)
--- Dumped by pg_dump version 13.4 (Debian 13.4-0+deb11u1)
+-- Dumped from database version 13.5 (Debian 13.5-1.pgdg100+1)
+-- Dumped by pg_dump version 13.5 (Debian 13.5-0+deb11u1)
 
 SET statement_timeout = 0;
 SET lock_timeout = 0;
@@ -46,8 +46,7 @@ CREATE TABLE diskjumble.disk (
     disk_id integer NOT NULL,
     dev_uuid uuid NOT NULL,
     dev_serial text,
-    sector_size integer NOT NULL,
-    num_sectors bigint NOT NULL,
+    num_blocks bigint NOT NULL,
     failed boolean DEFAULT false NOT NULL
 );
 
@@ -83,10 +82,11 @@ ALTER SEQUENCE diskjumble.disk_id_seq OWNED BY diskjumble.disk.disk_id;
 CREATE TABLE diskjumble.slab (
     slab_id integer NOT NULL,
     disk_id integer NOT NULL,
-    disk_sectors int8range NOT NULL,
+    disk_blocks int8range NOT NULL,
     entity_id bytea NOT NULL,
     entity_offset bigint NOT NULL,
-    crypt_key bytea
+    crypt_key bytea,
+    realized boolean DEFAULT false
 );
 
 
@@ -271,7 +271,7 @@ ALTER TABLE ONLY diskjumble.disk
 --
 
 ALTER TABLE ONLY diskjumble.slab
-    ADD CONSTRAINT slab_disk_id_disk_sectors_excl EXCLUDE USING gist (disk_id WITH =, disk_sectors WITH &&);
+    ADD CONSTRAINT slab_disk_id_disk_sectors_excl EXCLUDE USING gist (disk_id WITH =, disk_blocks WITH &&);
 
 
 --
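
Note: SQL written against the old column names needs the rename as well. A
sketch of a slab insert against the new shape (values illustrative;
`realized' is left to its DEFAULT of false):

    from psycopg2.extras import NumericRange

    cursor.execute(
        "insert into diskjumble.slab (disk_id, disk_blocks, entity_id, entity_offset, crypt_key)"
        " values (%s, %s, %s, %s, null);",
        (disk_id, NumericRange(0, 4), entity_id, 0),
    )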
diff --git a/disk_jumble/src/disk_jumble/verify.py b/disk_jumble/src/disk_jumble/verify.py
index f644d9a6427ed3ef623ac3e44a8aa2341437d074..6c15143390fb4b0bd27d395dd0d6950e4c0f7f8c 100644
--- a/disk_jumble/src/disk_jumble/verify.py
+++ b/disk_jumble/src/disk_jumble/verify.py
@@ -12,7 +12,7 @@ import math
 from psycopg2.extras import NumericRange
 import psycopg2
 
-from disk_jumble import bencode
+from disk_jumble import bencode, SECTOR_SIZE
 from disk_jumble.db import HasherRef, Slab, Wrapper as DbWrapper
 from disk_jumble.nettle import Sha1Hasher
 
@@ -43,9 +43,8 @@ class _BadSector(Exception):
        pass
 
 
-def do_verify(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int, read_tries: int) -> None:
+def do_verify(conn, disk_id: int, sector_size: int, disk_file: io.BufferedIOBase, read_size: int, read_tries: int) -> None:
        db = DbWrapper(conn)
-       disk = db.get_disk(disk_id)
 
        info_dicts = {
                info_hash: bencode.decode(info)
@@ -53,11 +52,11 @@ def do_verify(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int,
        }
 
        tasks = []
-       slabs_and_hashers = db.get_slabs_and_hashers(disk_id)
+       slabs_and_hashers = db.get_slabs_and_hashers(disk_id, sector_size)
        for (entity_id, group) in itertools.groupby(slabs_and_hashers, lambda t: t[0].entity_id):
                info = info_dicts[entity_id]
                piece_len = info[b"piece length"]
-               assert piece_len % disk.sector_size == 0
+               assert piece_len % sector_size == 0
                if b"length" in info:
                        torrent_len = info[b"length"]
                else:
@@ -67,7 +66,7 @@ def do_verify(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int,
                use_hasher = None
                chunks = []
                for (slab, hasher_ref) in group:
-                       slab_end = min(slab.entity_offset + len(slab.sectors) * disk.sector_size, torrent_len)
+                       slab_end = min(slab.entity_offset + len(slab.sectors) * sector_size, torrent_len)
 
                        while offset is None or offset < slab_end:
                                if offset is not None and slab.entity_offset > offset:
@@ -126,15 +125,15 @@ def do_verify(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int,
                hasher = Sha1Hasher(task.hasher_ref.state if task.hasher_ref else None)
                sector_ranges = [
                        range(
-                               chunk.slab.sectors.start + chunk.slice.start // disk.sector_size,
-                               chunk.slab.sectors.start + math.ceil(chunk.slice.stop / disk.sector_size)
+                               chunk.slab.sectors.start + chunk.slice.start // sector_size,
+                               chunk.slab.sectors.start + math.ceil(chunk.slice.stop / sector_size)
                        )
                        for chunk in task.chunks
                ]
 
                try:
                        for chunk in task.chunks:
-                               slab_off = chunk.slab.sectors.start * disk.sector_size
+                               slab_off = chunk.slab.sectors.start * sector_size
                                disk_file.seek(slab_off + chunk.slice.start)
                                end_pos = slab_off + chunk.slice.stop
                                while disk_file.tell() < end_pos:
@@ -379,5 +378,7 @@ if __name__ == "__main__":
                conn.autocommit = True
                path = f"/dev/mapper/diskjumble-{args.disk_id}"
                with open(path, "rb", buffering = _READ_BUFFER_SIZE) as disk_file:
-                       verify_func = do_verify_v2 if args.entity_version == "2" else do_verify
-                       verify_func(conn, args.disk_id, disk_file, _READ_BUFFER_SIZE, args.read_tries)
+                       if args.entity_version == "1":
+                               do_verify(conn, args.disk_id, SECTOR_SIZE, disk_file, _READ_BUFFER_SIZE, args.read_tries)
+                       else:
+                               do_verify_v2(conn, args.disk_id, disk_file, _READ_BUFFER_SIZE, args.read_tries)