from dataclasses import dataclass
from importlib import resources
from random import Random
+from typing import Optional
import hashlib
import io
import tempfile
import psycopg2.extras
from disk_jumble import bencode
+from disk_jumble.nettle import Sha1Hasher
from disk_jumble.verify import do_verify
_BUF_SIZE = 16 * 1024 ** 2 # in bytes
-def _random_file(size: int, rand_src: Random, on_disk: bool) -> tempfile.NamedTemporaryFile:
- f = tempfile.NamedTemporaryFile(buffering = _BUF_SIZE) if on_disk else io.BytesIO()
- try:
- while f.tell() < size:
- write_size = min(size - f.tell(), _BUF_SIZE)
- f.write(bytes(rand_src.getrandbits(8) for _ in range(write_size)))
- f.seek(0)
- return f
- except Exception:
- f.close()
- raise
-
-
class Tests(unittest.TestCase):
_SCHEMAS = {"public", "diskjumble", "bittorrent"}
disk = self._write_disk(sector_size, torrent_len // sector_size)
with _random_file(torrent_len, Random(0), on_disk = False) as torrent_file:
torrent = _Torrent(torrent_file, piece_size)
- torrent_file.seek(0)
self._write_torrent(torrent)
with self._conn.cursor() as cursor:
cursor.execute(
disk = self._write_disk(32, 3)
with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
torrent = _Torrent(torrent_file, piece_size)
- torrent_file.seek(0)
self._write_torrent(torrent)
with self._conn.cursor() as cursor:
cursor.executemany(
- "insert into diskjumble.slab values (default, %s, %s, %s, %s, null)",
+ "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
[
(other_disk.id, NumericRange(0, 1), torrent.info_hash, 0),
(disk.id, NumericRange(0, 1), torrent.info_hash, other_disk.sector_size),
[(other_disk.id, NumericRange(0, 1)), (disk.id, NumericRange(0, 1)), (disk.id, NumericRange(2, 3))]
)
- # TODO ignore useless hasher state
+ def test_resume_no_completion(self):
+ """
+ Test a run where a saved hasher state is used and the target disk has subsequent entity data but not the full
+ remainder of the piece.
+ """
+ read_size = 7
+ piece_size = 64
+
+ other_disk = self._write_disk(16, 1)
+ disk = self._write_disk(32, 1)
+ with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
+ torrent = _Torrent(torrent_file, piece_size)
+ self._write_torrent(torrent)
+ with self._conn.cursor() as cursor:
+ cursor.executemany(
+ "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
+ [
+ (other_disk.id, NumericRange(0, 1), torrent.info_hash, 0),
+ (disk.id, NumericRange(0, 1), torrent.info_hash, other_disk.sector_size),
+ ]
+ )
+
+ do_verify(self._conn, other_disk.id, torrent_file, read_size, read_tries = 1)
+
+ cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
+ [(row_count,)] = cursor.fetchall()
+ self.assertEqual(row_count, 1)
+
+ disk_file = io.BytesIO()
+ torrent_file.seek(other_disk.sector_size)
+ disk_file.write(torrent_file.read(disk.sector_size))
+ disk_file.seek(0)
+ do_verify(self._conn, disk.id, disk_file, read_size, read_tries = 1)
+
+ cursor.execute("select count(*) from diskjumble.verify_pass;")
+ [(row_count,)] = cursor.fetchall()
+ self.assertEqual(row_count, 0)
+
+ cursor.execute("select entity_id, piece from diskjumble.verify_piece;")
+ [(entity_id, piece_num)] = cursor.fetchall()
+ self.assertEqual(bytes(entity_id), torrent.info_hash)
+ self.assertEqual(piece_num, 0)
+
+ cursor.execute("select disk_id, disk_sectors from diskjumble.verify_piece_content;")
+ self.assertCountEqual(
+ cursor.fetchall(),
+ [(other_disk.id, NumericRange(0, 1)), (disk.id, NumericRange(0, 1))]
+ )
+
+ cursor.execute("select count(*) from diskjumble.verify_piece_fail;")
+ [(row_count,)] = cursor.fetchall()
+ self.assertEqual(row_count, 0)
+
+ hasher = Sha1Hasher(None)
+ torrent_file.seek(0)
+ hasher.update(torrent_file.read(other_disk.sector_size + disk.sector_size))
+ cursor.execute("select hasher_state from diskjumble.verify_piece_incomplete;")
+ self.assertEqual(cursor.fetchall(), [(hasher.ctx.serialize(),)])
+
+ def test_ignore_hasher_state(self):
+ """
+ Test a run where a saved hasher state is available for use but isn't used due to the beginning of the piece
+ being on disk.
+ """
+ piece_size = 64
+
+ other_disk = self._write_disk(16, 1)
+ disk = self._write_disk(16, 4)
+ with _random_file(piece_size * 2, Random(0), on_disk = False) as torrent_file:
+ torrent = _Torrent(torrent_file, piece_size)
+ self._write_torrent(torrent)
+ with self._conn.cursor() as cursor:
+ cursor.executemany(
+ "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
+ [
+ (other_disk.id, NumericRange(0, other_disk.sector_count), torrent.info_hash, piece_size),
+ (disk.id, NumericRange(0, disk.sector_count), torrent.info_hash, piece_size),
+ ]
+ )
+
+ do_verify(self._conn, other_disk.id, torrent_file, read_size = 128, read_tries = 1)
+
+ cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
+ [(row_count,)] = cursor.fetchall()
+ self.assertEqual(row_count, 1)
+
+ torrent_file.seek(piece_size)
+ disk_file = io.BytesIO(torrent_file.read())
+ do_verify(self._conn, disk.id, disk_file, read_size = 128, read_tries = 1)
+
+ cursor.execute(
+ "select disk_id from diskjumble.verify_piece_content natural join diskjumble.verify_piece_incomplete;"
+ )
+ self.assertEqual(cursor.fetchall(), [(other_disk.id,)])
+
+ cursor.execute("select disk_id, disk_sectors from diskjumble.verify_pass;")
+ self.assertEqual(cursor.fetchall(), [(disk.id, NumericRange(0, disk.sector_count))])
+
+ def test_transient_read_errors(self):
+ """
+ Test a run where a read to the disk fails but fewer times than needed to mark the sector bad.
+ """
+ piece_size = 32
+
+ disk = self._write_disk(32, 1)
+ with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
+ torrent = _Torrent(torrent_file, piece_size)
+ self._write_torrent(torrent)
+ with self._conn.cursor() as cursor:
+ cursor.execute(
+ "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
+ (disk.id, NumericRange(0, 1), torrent.info_hash, 0)
+ )
+
+ disk_file = _ReadErrorProxy(torrent_file, error_pos = 12, error_count = 2)
+ do_verify(self._conn, disk.id, disk_file, read_size = 4, read_tries = 3)
+
+ self.assertEqual(disk_file.triggered, 2)
+
+ cursor.execute("select count(*) from diskjumble.verify_piece;")
+ [(row_count,)] = cursor.fetchall()
+ self.assertEqual(row_count, 0)
+
+ cursor.execute("select disk_id, disk_sectors from diskjumble.verify_pass;")
+ self.assertEqual(cursor.fetchall(), [(disk.id, NumericRange(0, 1))])
+
+ def test_persistent_read_errors(self):
+ """
+ Test a run where a disk read fails enough times to trigger the bad sector logic.
+ """
+ piece_size = 64
+
+ other_a = self._write_disk(16, 1)
+ other_b = self._write_disk(16, 2)
+ disk = self._write_disk(16, 1)
+ with _random_file(piece_size, Random(0), on_disk = False) as torrent_file:
+ torrent = _Torrent(torrent_file, piece_size)
+ self._write_torrent(torrent)
+ with self._conn.cursor() as cursor:
+ cursor.executemany(
+ "insert into diskjumble.slab values (default, %s, %s, %s, %s, null);",
+ [
+ (other_a.id, NumericRange(0, 1), torrent.info_hash, 0),
+ (other_b.id, NumericRange(0, 2), torrent.info_hash, 16),
+ (disk.id, NumericRange(0, 1), torrent.info_hash, 48),
+ ]
+ )
+
+ do_verify(self._conn, other_a.id, torrent_file, read_size = 16, read_tries = 1)
+ other_b_file = io.BytesIO(torrent_file.getvalue()[16:48])
+ do_verify(self._conn, other_b.id, other_b_file, read_size = 16, read_tries = 1)
+
+ cursor.execute("select verify_id from diskjumble.verify_piece;")
+ [(verify_id,)] = cursor.fetchall()
- # TODO read errors
+ data = torrent_file.getvalue()[48:]
+ disk_file = _ReadErrorProxy(io.BytesIO(data), error_pos = 5, error_count = None)
+ do_verify(self._conn, disk.id, disk_file, read_size = 4, read_tries = 3)
+
+ cursor.execute("select count(*) from diskjumble.verify_pass;")
+ [(row_count,)] = cursor.fetchall()
+ self.assertEqual(row_count, 0)
+
+ cursor.execute("select disk_id, disk_sectors from diskjumble.verify_piece_content;")
+ self.assertCountEqual(
+ cursor.fetchall(),
+ [(other_a.id, NumericRange(0, 1)), (other_b.id, NumericRange(0, 2)), (disk.id, NumericRange(0, 1))]
+ )
+
+ cursor.execute("select verify_id from diskjumble.verify_piece_fail;")
+ self.assertEqual(cursor.fetchall(), [(verify_id,)])
+
+ cursor.execute("select count(*) from diskjumble.verify_piece_incomplete;")
+ [(row_count,)] = cursor.fetchall()
+ self.assertEqual(row_count, 0)
def _write_torrent(self, torrent: "_Torrent") -> None:
with self._conn.cursor() as cursor:
self.data = data
self.info = bencode.encode(info_dict)
self.info_hash = hashlib.sha1(self.info).digest()
+
+
+def _random_file(size: int, rand_src: Random, on_disk: bool) -> tempfile.NamedTemporaryFile:
+ f = tempfile.NamedTemporaryFile(buffering = _BUF_SIZE) if on_disk else io.BytesIO()
+ try:
+ while f.tell() < size:
+ write_size = min(size - f.tell(), _BUF_SIZE)
+ f.write(bytes(rand_src.getrandbits(8) for _ in range(write_size)))
+ f.seek(0)
+ return f
+ except Exception:
+ f.close()
+ raise
+
+
+@dataclass
+class _ReadErrorProxy(io.BufferedIOBase):
+ wrapped: io.BufferedIOBase
+ error_pos: int
+ error_count: Optional[int]
+
+ def __post_init__(self):
+ self.triggered = 0
+
+ def read(self, size: int = -1) -> bytes:
+ pre_pos = self.wrapped.tell()
+ data = self.wrapped.read(size)
+ erroring = self.error_count is None or self.triggered < self.error_count
+ in_range = 0 <= pre_pos - self.error_pos < len(data)
+ if erroring and in_range:
+ self.triggered += 1
+ raise OSError("simulated")
+ else:
+ return data
+
+ def seek(self, *args, **kwargs) -> int:
+ return self.wrapped.seek(*args, **kwargs)