From d06bf1b3f1b1463f4e0d7451f61de043ec11af90 Mon Sep 17 00:00:00 2001 From: Jakob Cornell Date: Sat, 27 Nov 2021 10:53:47 -0600 Subject: [PATCH] Minor test code cleanup and add new tests --- src/disk_jumble/tests/test_verify.py | 232 ++++++++++++++++++++++++--- 1 file changed, 214 insertions(+), 18 deletions(-) diff --git a/src/disk_jumble/tests/test_verify.py b/src/disk_jumble/tests/test_verify.py index f0b4b95..1edc3cc 100644 --- a/src/disk_jumble/tests/test_verify.py +++ b/src/disk_jumble/tests/test_verify.py @@ -13,6 +13,7 @@ objects; see `test_util/dump_db.py'. from dataclasses import dataclass from importlib import resources from random import Random +from typing import Optional import hashlib import io import tempfile @@ -24,25 +25,13 @@ import psycopg2 import psycopg2.extras from disk_jumble import bencode +from disk_jumble.nettle import Sha1Hasher from disk_jumble.verify import do_verify _BUF_SIZE = 16 * 1024 ** 2 # in bytes -def _random_file(size: int, rand_src: Random, on_disk: bool) -> tempfile.NamedTemporaryFile: - f = tempfile.NamedTemporaryFile(buffering = _BUF_SIZE) if on_disk else io.BytesIO() - try: - while f.tell() < size: - write_size = min(size - f.tell(), _BUF_SIZE) - f.write(bytes(rand_src.getrandbits(8) for _ in range(write_size))) - f.seek(0) - return f - except Exception: - f.close() - raise - - class Tests(unittest.TestCase): _SCHEMAS = {"public", "diskjumble", "bittorrent"} @@ -54,7 +43,6 @@ class Tests(unittest.TestCase): disk = self._write_disk(sector_size, torrent_len // sector_size) with _random_file(torrent_len, Random(0), on_disk = False) as torrent_file: torrent = _Torrent(torrent_file, piece_size) - torrent_file.seek(0) self._write_torrent(torrent) with self._conn.cursor() as cursor: cursor.execute( @@ -88,11 +76,10 @@ class Tests(unittest.TestCase): disk = self._write_disk(32, 3) with _random_file(piece_size, Random(0), on_disk = False) as torrent_file: torrent = _Torrent(torrent_file, piece_size) - torrent_file.seek(0) self._write_torrent(torrent) with self._conn.cursor() as cursor: cursor.executemany( - "insert into diskjumble.slab values (default, %s, %s, %s, %s, null)", + "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);", [ (other_disk.id, NumericRange(0, 1), torrent.info_hash, 0), (disk.id, NumericRange(0, 1), torrent.info_hash, other_disk.sector_size), @@ -128,9 +115,181 @@ class Tests(unittest.TestCase): [(other_disk.id, NumericRange(0, 1)), (disk.id, NumericRange(0, 1)), (disk.id, NumericRange(2, 3))] ) - # TODO ignore useless hasher state + def test_resume_no_completion(self): + """ + Test a run where a saved hasher state is used and the target disk has subsequent entity data but not the full + remainder of the piece. + """ + read_size = 7 + piece_size = 64 + + other_disk = self._write_disk(16, 1) + disk = self._write_disk(32, 1) + with _random_file(piece_size, Random(0), on_disk = False) as torrent_file: + torrent = _Torrent(torrent_file, piece_size) + self._write_torrent(torrent) + with self._conn.cursor() as cursor: + cursor.executemany( + "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);", + [ + (other_disk.id, NumericRange(0, 1), torrent.info_hash, 0), + (disk.id, NumericRange(0, 1), torrent.info_hash, other_disk.sector_size), + ] + ) + + do_verify(self._conn, other_disk.id, torrent_file, read_size, read_tries = 1) + + cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;") + [(row_count,)] = cursor.fetchall() + self.assertEqual(row_count, 1) + + disk_file = io.BytesIO() + torrent_file.seek(other_disk.sector_size) + disk_file.write(torrent_file.read(disk.sector_size)) + disk_file.seek(0) + do_verify(self._conn, disk.id, disk_file, read_size, read_tries = 1) + + cursor.execute("select count(*) from diskjumble.verify_pass;") + [(row_count,)] = cursor.fetchall() + self.assertEqual(row_count, 0) + + cursor.execute("select entity_id, piece from diskjumble.verify_piece;") + [(entity_id, piece_num)] = cursor.fetchall() + self.assertEqual(bytes(entity_id), torrent.info_hash) + self.assertEqual(piece_num, 0) + + cursor.execute("select disk_id, disk_sectors from diskjumble.verify_piece_content;") + self.assertCountEqual( + cursor.fetchall(), + [(other_disk.id, NumericRange(0, 1)), (disk.id, NumericRange(0, 1))] + ) + + cursor.execute("select count(*) from diskjumble.verify_piece_fail;") + [(row_count,)] = cursor.fetchall() + self.assertEqual(row_count, 0) + + hasher = Sha1Hasher(None) + torrent_file.seek(0) + hasher.update(torrent_file.read(other_disk.sector_size + disk.sector_size)) + cursor.execute("select hasher_state from diskjumble.verify_piece_incomplete;") + self.assertEqual(cursor.fetchall(), [(hasher.ctx.serialize(),)]) + + def test_ignore_hasher_state(self): + """ + Test a run where a saved hasher state is available for use but isn't used due to the beginning of the piece + being on disk. + """ + piece_size = 64 + + other_disk = self._write_disk(16, 1) + disk = self._write_disk(16, 4) + with _random_file(piece_size * 2, Random(0), on_disk = False) as torrent_file: + torrent = _Torrent(torrent_file, piece_size) + self._write_torrent(torrent) + with self._conn.cursor() as cursor: + cursor.executemany( + "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);", + [ + (other_disk.id, NumericRange(0, other_disk.sector_count), torrent.info_hash, piece_size), + (disk.id, NumericRange(0, disk.sector_count), torrent.info_hash, piece_size), + ] + ) + + do_verify(self._conn, other_disk.id, torrent_file, read_size = 128, read_tries = 1) + + cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;") + [(row_count,)] = cursor.fetchall() + self.assertEqual(row_count, 1) + + torrent_file.seek(piece_size) + disk_file = io.BytesIO(torrent_file.read()) + do_verify(self._conn, disk.id, disk_file, read_size = 128, read_tries = 1) + + cursor.execute( + "select disk_id from diskjumble.verify_piece_content natural join diskjumble.verify_piece_incomplete;" + ) + self.assertEqual(cursor.fetchall(), [(other_disk.id,)]) + + cursor.execute("select disk_id, disk_sectors from diskjumble.verify_pass;") + self.assertEqual(cursor.fetchall(), [(disk.id, NumericRange(0, disk.sector_count))]) + + def test_transient_read_errors(self): + """ + Test a run where a read to the disk fails but fewer times than needed to mark the sector bad. + """ + piece_size = 32 + + disk = self._write_disk(32, 1) + with _random_file(piece_size, Random(0), on_disk = False) as torrent_file: + torrent = _Torrent(torrent_file, piece_size) + self._write_torrent(torrent) + with self._conn.cursor() as cursor: + cursor.execute( + "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);", + (disk.id, NumericRange(0, 1), torrent.info_hash, 0) + ) + + disk_file = _ReadErrorProxy(torrent_file, error_pos = 12, error_count = 2) + do_verify(self._conn, disk.id, disk_file, read_size = 4, read_tries = 3) + + self.assertEqual(disk_file.triggered, 2) + + cursor.execute("select count(*) from diskjumble.verify_piece;") + [(row_count,)] = cursor.fetchall() + self.assertEqual(row_count, 0) + + cursor.execute("select disk_id, disk_sectors from diskjumble.verify_pass;") + self.assertEqual(cursor.fetchall(), [(disk.id, NumericRange(0, 1))]) + + def test_persistent_read_errors(self): + """ + Test a run where a disk read fails enough times to trigger the bad sector logic. + """ + piece_size = 64 + + other_a = self._write_disk(16, 1) + other_b = self._write_disk(16, 2) + disk = self._write_disk(16, 1) + with _random_file(piece_size, Random(0), on_disk = False) as torrent_file: + torrent = _Torrent(torrent_file, piece_size) + self._write_torrent(torrent) + with self._conn.cursor() as cursor: + cursor.executemany( + "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);", + [ + (other_a.id, NumericRange(0, 1), torrent.info_hash, 0), + (other_b.id, NumericRange(0, 2), torrent.info_hash, 16), + (disk.id, NumericRange(0, 1), torrent.info_hash, 48), + ] + ) + + do_verify(self._conn, other_a.id, torrent_file, read_size = 16, read_tries = 1) + other_b_file = io.BytesIO(torrent_file.getvalue()[16:48]) + do_verify(self._conn, other_b.id, other_b_file, read_size = 16, read_tries = 1) + + cursor.execute("select verify_id from diskjumble.verify_piece;") + [(verify_id,)] = cursor.fetchall() - # TODO read errors + data = torrent_file.getvalue()[48:] + disk_file = _ReadErrorProxy(io.BytesIO(data), error_pos = 5, error_count = None) + do_verify(self._conn, disk.id, disk_file, read_size = 4, read_tries = 3) + + cursor.execute("select count(*) from diskjumble.verify_pass;") + [(row_count,)] = cursor.fetchall() + self.assertEqual(row_count, 0) + + cursor.execute("select disk_id, disk_sectors from diskjumble.verify_piece_content;") + self.assertCountEqual( + cursor.fetchall(), + [(other_a.id, NumericRange(0, 1)), (other_b.id, NumericRange(0, 2)), (disk.id, NumericRange(0, 1))] + ) + + cursor.execute("select verify_id from diskjumble.verify_piece_fail;") + self.assertEqual(cursor.fetchall(), [(verify_id,)]) + + cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;") + [(row_count,)] = cursor.fetchall() + self.assertEqual(row_count, 0) def _write_torrent(self, torrent: "_Torrent") -> None: with self._conn.cursor() as cursor: @@ -202,3 +361,40 @@ class _Torrent: self.data = data self.info = bencode.encode(info_dict) self.info_hash = hashlib.sha1(self.info).digest() + + +def _random_file(size: int, rand_src: Random, on_disk: bool) -> tempfile.NamedTemporaryFile: + f = tempfile.NamedTemporaryFile(buffering = _BUF_SIZE) if on_disk else io.BytesIO() + try: + while f.tell() < size: + write_size = min(size - f.tell(), _BUF_SIZE) + f.write(bytes(rand_src.getrandbits(8) for _ in range(write_size))) + f.seek(0) + return f + except Exception: + f.close() + raise + + +@dataclass +class _ReadErrorProxy(io.BufferedIOBase): + wrapped: io.BufferedIOBase + error_pos: int + error_count: Optional[int] + + def __post_init__(self): + self.triggered = 0 + + def read(self, size: int = -1) -> bytes: + pre_pos = self.wrapped.tell() + data = self.wrapped.read(size) + erroring = self.error_count is None or self.triggered < self.error_count + in_range = 0 <= pre_pos - self.error_pos < len(data) + if erroring and in_range: + self.triggered += 1 + raise OSError("simulated") + else: + return data + + def seek(self, *args, **kwargs) -> int: + return self.wrapped.seek(*args, **kwargs) -- 2.30.2