Clean up test code and add new tests
authorJakob Cornell <jakob+gpg@jcornell.net>
Sat, 27 Nov 2021 16:53:47 +0000 (10:53 -0600)
committerJakob Cornell <jakob+gpg@jcornell.net>
Sat, 27 Nov 2021 16:53:47 +0000 (10:53 -0600)
src/disk_jumble/tests/test_verify.py

index f0b4b95d20cfeac4897d817d340f9171760840fb..1edc3cc14c0361090705ecadcca51304d9f2d36c 100644 (file)
@@ -13,6 +13,7 @@ objects; see `test_util/dump_db.py'.
 from dataclasses import dataclass
 from importlib import resources
 from random import Random
+from typing import IO, Optional
 import hashlib
 import io
 import tempfile
@@ -24,25 +25,13 @@ import psycopg2
 import psycopg2.extras
 
 from disk_jumble import bencode
+from disk_jumble.nettle import Sha1Hasher
 from disk_jumble.verify import do_verify
 
 
 _BUF_SIZE = 16 * 1024 ** 2  # in bytes
 
 
-def _random_file(size: int, rand_src: Random, on_disk: bool) -> tempfile.NamedTemporaryFile:
-       f = tempfile.NamedTemporaryFile(buffering = _BUF_SIZE) if on_disk else io.BytesIO()
-       try:
-               while f.tell() < size:
-                       write_size = min(size - f.tell(), _BUF_SIZE)
-                       f.write(bytes(rand_src.getrandbits(8) for _ in range(write_size)))
-               f.seek(0)
-               return f
-       except Exception:
-               f.close()
-               raise
-
-
 class Tests(unittest.TestCase):
        _SCHEMAS = {"public", "diskjumble", "bittorrent"}
 
@@ -54,7 +43,6 @@ class Tests(unittest.TestCase):
                disk = self._write_disk(sector_size, torrent_len // sector_size)
                with _random_file(torrent_len, Random(0), on_disk = False) as torrent_file:
                        torrent = _Torrent(torrent_file, piece_size)
-                       torrent_file.seek(0)
                        self._write_torrent(torrent)
                        with self._conn.cursor() as cursor:
                                cursor.execute(
@@ -88,11 +76,10 @@ class Tests(unittest.TestCase):
                disk = self._write_disk(32, 3)
                with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
                        torrent = _Torrent(torrent_file, piece_size)
-                       torrent_file.seek(0)
                        self._write_torrent(torrent)
                        with self._conn.cursor() as cursor:
                                cursor.executemany(
-                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null)",
+                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
                                        [
                                                (other_disk.id, NumericRange(0, 1), torrent.info_hash, 0),
                                                (disk.id, NumericRange(0, 1), torrent.info_hash, other_disk.sector_size),
@@ -128,9 +115,181 @@ class Tests(unittest.TestCase):
                                        [(other_disk.id, NumericRange(0, 1)), (disk.id, NumericRange(0, 1)), (disk.id, NumericRange(2, 3))]
                                )
 
-       # TODO ignore useless hasher state
+       def test_resume_no_completion(self):
+               """
+               Test a run that resumes from a saved hasher state, where the target disk holds more of the entity's data but
+               not the full remainder of the piece.
+               """
+               read_size = 7
+               piece_size = 64
+
+               other_disk = self._write_disk(16, 1)
+               disk = self._write_disk(32, 1)
+               with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
+                       torrent = _Torrent(torrent_file, piece_size)
+                       self._write_torrent(torrent)
+                       with self._conn.cursor() as cursor:
+                               cursor.executemany(
+                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
+                                       [
+                                               (other_disk.id, NumericRange(0, 1), torrent.info_hash, 0),
+                                               (disk.id, NumericRange(0, 1), torrent.info_hash, other_disk.sector_size),
+                                       ]
+                               )
+
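+                               # the first run covers only bytes [0, 16) of the piece, leaving an incomplete row with saved hasher state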
+                               do_verify(self._conn, other_disk.id, torrent_file, read_size, read_tries = 1)
+
+                               cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 1)
+
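+                               # resume from the saved state with bytes [16, 48); the 64-byte piece still can't be completed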
+                               disk_file = io.BytesIO()
+                               torrent_file.seek(other_disk.sector_size)
+                               disk_file.write(torrent_file.read(disk.sector_size))
+                               disk_file.seek(0)
+                               do_verify(self._conn, disk.id, disk_file, read_size, read_tries = 1)
+
+                               cursor.execute("select count(*) from diskjumble.verify_pass;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 0)
+
+                               cursor.execute("select entity_id, piece from diskjumble.verify_piece;")
+                               [(entity_id, piece_num)] = cursor.fetchall()
+                               self.assertEqual(bytes(entity_id), torrent.info_hash)
+                               self.assertEqual(piece_num, 0)
+
+                               cursor.execute("select disk_id, disk_sectors from diskjumble.verify_piece_content;")
+                               self.assertCountEqual(
+                                       cursor.fetchall(),
+                                       [(other_disk.id, NumericRange(0, 1)), (disk.id, NumericRange(0, 1))]
+                               )
+
+                               cursor.execute("select count(*) from diskjumble.verify_piece_fail;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 0)
+
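+                               # the stored hasher state should match a SHA-1 context fed the first 48 bytes of the piece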
+                               hasher = Sha1Hasher(None)
+                               torrent_file.seek(0)
+                               hasher.update(torrent_file.read(other_disk.sector_size + disk.sector_size))
+                               cursor.execute("select hasher_state from diskjumble.verify_piece_incomplete;")
+                               self.assertEqual(cursor.fetchall(), [(hasher.ctx.serialize(),)])
+
+       def test_ignore_hasher_state(self):
+               """
+               Test a run where a saved hasher state is available but goes unused because the beginning of the piece is
+               itself on the disk being verified.
+               """
+               piece_size = 64
+
+               other_disk = self._write_disk(16, 1)
+               disk = self._write_disk(16, 4)
+               with _random_file(piece_size * 2, Random(0), on_disk = False) as torrent_file:
+                       torrent = _Torrent(torrent_file, piece_size)
+                       self._write_torrent(torrent)
+                       with self._conn.cursor() as cursor:
+                               cursor.executemany(
+                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
+                                       [
+                                               (other_disk.id, NumericRange(0, other_disk.sector_count), torrent.info_hash, piece_size),
+                                               (disk.id, NumericRange(0, disk.sector_count), torrent.info_hash, piece_size),
+                                       ]
+                               )
+
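+                               # the slab on other_disk covers only the first 16 bytes of piece 1, so this run saves an incomplete row with hasher state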
+                               do_verify(self._conn, other_disk.id, torrent_file, read_size = 128, read_tries = 1)
+
+                               cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 1)
+
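+                               # disk holds all of piece 1, so the saved state should be ignored and the piece verified from scratch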
+                               torrent_file.seek(piece_size)
+                               disk_file = io.BytesIO(torrent_file.read())
+                               do_verify(self._conn, disk.id, disk_file, read_size = 128, read_tries = 1)
+
+                               cursor.execute(
+                                       "select disk_id from diskjumble.verify_piece_content natural join diskjumble.verify_piece_incomplete;"
+                               )
+                               self.assertEqual(cursor.fetchall(), [(other_disk.id,)])
+
+                               cursor.execute("select disk_id, disk_sectors from diskjumble.verify_pass;")
+                               self.assertEqual(cursor.fetchall(), [(disk.id, NumericRange(0, disk.sector_count))])
+
+       def test_transient_read_errors(self):
+               """
+               Test a run where a disk read fails, but fewer times than needed to mark the sector bad.
+               """
+               piece_size = 32
+
+               disk = self._write_disk(32, 1)
+               with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
+                       torrent = _Torrent(torrent_file, piece_size)
+                       self._write_torrent(torrent)
+                       with self._conn.cursor() as cursor:
+                               cursor.execute(
+                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
+                                       (disk.id, NumericRange(0, 1), torrent.info_hash, 0)
+                               )
+
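+                               # reads covering offset 12 fail twice; with read_tries = 3 the final attempt succeeds and the piece verifies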
+                               disk_file = _ReadErrorProxy(torrent_file, error_pos = 12, error_count = 2)
+                               do_verify(self._conn, disk.id, disk_file, read_size = 4, read_tries = 3)
+
+                               self.assertEqual(disk_file.triggered, 2)
+
+                               cursor.execute("select count(*) from diskjumble.verify_piece;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 0)
+
+                               cursor.execute("select disk_id, disk_sectors from diskjumble.verify_pass;")
+                               self.assertEqual(cursor.fetchall(), [(disk.id, NumericRange(0, 1))])
+
+       def test_persistent_read_errors(self):
+               """
+               Test a run where a disk read fails enough times to trigger the bad sector logic.
+               """
+               piece_size = 64
+
+               other_a = self._write_disk(16, 1)
+               other_b = self._write_disk(16, 2)
+               disk = self._write_disk(16, 1)
+               with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
+                       torrent = _Torrent(torrent_file, piece_size)
+                       self._write_torrent(torrent)
+                       with self._conn.cursor() as cursor:
+                               cursor.executemany(
+                                       "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
+                                       [
+                                               (other_a.id, NumericRange(0, 1), torrent.info_hash, 0),
+                                               (other_b.id, NumericRange(0, 2), torrent.info_hash, 16),
+                                               (disk.id, NumericRange(0, 1), torrent.info_hash, 48),
+                                       ]
+                               )
+
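+                               # verify the slabs on other_a and other_b first so the piece already has an incomplete row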
+                               do_verify(self._conn, other_a.id, torrent_file, read_size = 16, read_tries = 1)
+                               other_b_file = io.BytesIO(torrent_file.getvalue()[16:48])
+                               do_verify(self._conn, other_b.id, other_b_file, read_size = 16, read_tries = 1)
+
+                               cursor.execute("select verify_id from diskjumble.verify_piece;")
+                               [(verify_id,)] = cursor.fetchall()
 
-       # TODO read errors
+                               data = torrent_file.getvalue()[48:]
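+                               # reads covering offset 5 always fail (error_count = None), triggering the bad sector logic and failing the piece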
+                               disk_file = _ReadErrorProxy(io.BytesIO(data), error_pos = 5, error_count = None)
+                               do_verify(self._conn, disk.id, disk_file, read_size = 4, read_tries = 3)
+
+                               cursor.execute("select count(*) from diskjumble.verify_pass;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 0)
+
+                               cursor.execute("select disk_id, disk_sectors from diskjumble.verify_piece_content;")
+                               self.assertCountEqual(
+                                       cursor.fetchall(),
+                                       [(other_a.id, NumericRange(0, 1)), (other_b.id, NumericRange(0, 2)), (disk.id, NumericRange(0, 1))]
+                               )
+
+                               cursor.execute("select verify_id from diskjumble.verify_piece_fail;")
+                               self.assertEqual(cursor.fetchall(), [(verify_id,)])
+
+                               cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
+                               [(row_count,)] = cursor.fetchall()
+                               self.assertEqual(row_count, 0)
 
        def _write_torrent(self, torrent: "_Torrent") -> None:
                with self._conn.cursor() as cursor:
@@ -202,3 +361,40 @@ class _Torrent:
                self.data = data
                self.info = bencode.encode(info_dict)
                self.info_hash = hashlib.sha1(self.info).digest()
+
+
+def _random_file(size: int, rand_src: Random, on_disk: bool) -> IO[bytes]:
+       f = tempfile.NamedTemporaryFile(buffering = _BUF_SIZE) if on_disk else io.BytesIO()
+       try:
+               while f.tell() < size:
+                       write_size = min(size - f.tell(), _BUF_SIZE)
+                       f.write(bytes(rand_src.getrandbits(8) for _ in range(write_size)))
+               f.seek(0)
+               return f
+       except Exception:
+               f.close()
+               raise
+
+
+@dataclass
+class _ReadErrorProxy(io.BufferedIOBase):
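+       """
+       Wrap a binary file and raise OSError for reads that cover error_pos, at most error_count times
+       (unlimited if None). The triggered attribute counts the simulated failures.
+       """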
+       wrapped: io.BufferedIOBase
+       error_pos: int
+       error_count: Optional[int]
+
+       def __post_init__(self):
+               self.triggered = 0
+
+       def read(self, size: int = -1) -> bytes:
+               pre_pos = self.wrapped.tell()
+               data = self.wrapped.read(size)
+               erroring = self.error_count is None or self.triggered < self.error_count
+               in_range = 0 <= self.error_pos - pre_pos < len(data)
+               if erroring and in_range:
+                       self.triggered += 1
+                       raise OSError("simulated")
+               else:
+                       return data
+
+       def seek(self, *args, **kwargs) -> int:
+               return self.wrapped.seek(*args, **kwargs)