Implement handling of disk read errors
author Jakob Cornell <jakob+gpg@jcornell.net>
Sat, 13 Nov 2021 05:25:38 +0000 (23:25 -0600)
committer Jakob Cornell <jakob+gpg@jcornell.net>
Sat, 13 Nov 2021 05:25:38 +0000 (23:25 -0600)
src/disk_jumble/tests/test_verify.py
src/disk_jumble/verify.py

diff --git a/src/disk_jumble/tests/test_verify.py b/src/disk_jumble/tests/test_verify.py
index 538f5fb1229e94e6e301fd0ca864ae7ea507af2c..8443bcefa918b4f93b547b390d18099b056a78a6 100644
@@ -62,7 +62,7 @@ class Tests(unittest.TestCase):
                                        (disk.id, NumericRange(0, disk.sector_count), torrent.info_hash)
                                )
 
-                               do_verify(self._conn, disk.id, torrent_file, read_size)
+                               do_verify(self._conn, disk.id, torrent_file, read_size, 1)
 
                                cursor.execute("select * from diskjumble.verify_pass;")
                                self.assertEqual(cursor.rowcount, 1)
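The updated test only threads the new read_tries argument through; it does not simulate a failing read. Below is a minimal sketch of how the retry path could be exercised. The _FlakyFile wrapper and its fail_count parameter are hypothetical and not part of this commit; only the do_verify(conn, disk_id, disk_file, read_size, read_tries) signature comes from the change above.

```python
import io


class _FlakyFile(io.BufferedIOBase):
	"""Wrap a file-like object and fail the first `fail_count` reads with OSError (illustrative only)."""

	def __init__(self, wrapped, fail_count: int):
		self._wrapped = wrapped
		self._fails_left = fail_count

	def seek(self, *args, **kwargs):
		return self._wrapped.seek(*args, **kwargs)

	def tell(self):
		return self._wrapped.tell()

	def read(self, size = -1):
		if self._fails_left > 0:
			self._fails_left -= 1
			raise OSError("simulated read error")
		return self._wrapped.read(size)


# Hypothetical use inside the test: two tries absorb a single failed read,
# so the verify should still pass.
#
#     do_verify(self._conn, disk.id, _FlakyFile(torrent_file, fail_count = 1), read_size, 2)
```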
diff --git a/src/disk_jumble/verify.py b/src/disk_jumble/verify.py
index f403c8d8569050c5e3c02844a027a5668304d0d5..0adbb037cce77119461543342c4e612bd24d2282 100644
@@ -36,7 +36,11 @@ class _PieceTask:
        complete: bool  # do these chunks complete the piece?
 
 
-def do_verify(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int) -> None:
+class _BadSector(Exception):
+       pass
+
+
+def do_verify(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int, read_tries: int) -> None:
        db = DbWrapper(conn)
        disk = db.get_disk(disk_id)
 
@@ -99,14 +103,14 @@ def do_verify(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int)
                entity_id: bytes
                piece_num: int
                sector_ranges: list[range]
-               hasher_state: dict
+               hasher_state: Optional[dict]
                failed: bool
 
        @dataclass
        class VerifyUpdate:
                seq_start: int
                new_sector_ranges: list[range]
-               hasher_state: dict
+               hasher_state: Optional[dict]
 
        passed_verifies = set()
        failed_verifies = set()
@@ -125,39 +129,56 @@ def do_verify(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int)
                        for chunk in task.chunks
                ]
 
-               for chunk in task.chunks:
-                       slab_off = chunk.slab.sectors.start * disk.sector_size
-                       disk_file.seek(slab_off + chunk.slice.start)
-                       end_pos = slab_off + chunk.slice.stop
-                       while disk_file.tell() < end_pos:
-                               data = disk_file.read(min(end_pos - disk_file.tell(), read_size))
-                               assert data
-                               hasher.update(data)
-
-               hasher_state = hasher.ctx.serialize()
-               if task.complete:
-                       s = slice(task.piece_num * 20, task.piece_num * 20 + 20)
-                       expected_hash = info_dicts[task.entity_id][b"pieces"][s]
-                       if hasher.digest() == expected_hash:
-                               write_piece_data = False
-                               new_pass_ranges.extend(sector_ranges)
-                               if task.hasher_ref:
-                                       passed_verifies.add(task.hasher_ref.id)
-                       else:
-                               failed = True
-                               write_piece_data = True
-                               if task.hasher_ref:
-                                       failed_verifies.add(task.hasher_ref.id)
-               else:
-                       failed = False
-                       write_piece_data = True
+               try:
+                       for chunk in task.chunks:
+                               slab_off = chunk.slab.sectors.start * disk.sector_size
+                               disk_file.seek(slab_off + chunk.slice.start)
+                               end_pos = slab_off + chunk.slice.stop
+                               while disk_file.tell() < end_pos:
+                                       pos = disk_file.tell()
+                                       for _ in range(read_tries):
+                                               try:
+                                                       data = disk_file.read(min(end_pos - pos, read_size))
+                                               except OSError:
+                                                       disk_file.seek(pos)
+                                               else:
+                                                       break
+                                       else:
+                                               raise _BadSector()
 
-               if write_piece_data:
+                                       assert data
+                                       hasher.update(data)
+               except _BadSector:
                        if task.hasher_ref:
-                               assert task.hasher_ref.id not in vp_updates
-                               vp_updates[task.hasher_ref.id] = VerifyUpdate(task.hasher_ref.seq + 1, sector_ranges, hasher_state)
+                               failed_verifies.add(task.hasher_ref.id)
+                               vp_updates[task.hasher_ref.id] = VerifyUpdate(task.hasher_ref.seq + 1, sector_ranges, None)
                        else:
-                               new_vps.append(NewVerifyPiece(task.entity_id, task.piece_num, sector_ranges, hasher_state, failed))
+                               new_vps.append(NewVerifyPiece(task.entity_id, task.piece_num, sector_ranges, None, True))
+               else:
+                       hasher_state = hasher.ctx.serialize()
+                       if task.complete:
+                               s = slice(task.piece_num * 20, task.piece_num * 20 + 20)
+                               expected_hash = info_dicts[task.entity_id][b"pieces"][s]
+                               if hasher.digest() == expected_hash:
+                                       write_piece_data = False
+                                       new_pass_ranges.extend(sector_ranges)
+                                       if task.hasher_ref:
+                                               passed_verifies.add(task.hasher_ref.id)
+                               else:
+                                       failed = True
+                                       write_piece_data = True
+                                       if task.hasher_ref:
+                                               failed_verifies.add(task.hasher_ref.id)
+                       else:
+                               failed = False
+                               write_piece_data = True
+
+                       if write_piece_data:
+                               if task.hasher_ref:
+                                       assert task.hasher_ref.id not in vp_updates
+                                       vp_updates[task.hasher_ref.id] = VerifyUpdate(task.hasher_ref.seq + 1, sector_ranges, hasher_state)
+                               else:
+                                       new_vps.append(NewVerifyPiece(task.entity_id, task.piece_num, sector_ranges, hasher_state, failed))
 
        new_pass_ranges.sort(key = lambda r: r.start)
        merged_ranges = []
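Stripped of the surrounding bookkeeping, the retry logic added above amounts to the pattern sketched below: remember the position, retry the read up to read_tries times (re-seeking after each failure), and give up on the sector by raising _BadSector once the attempts are exhausted. read_with_retries is an illustrative helper, not a function in this module.

```python
class _BadSector(Exception):
	pass


def read_with_retries(disk_file, pos: int, length: int, read_tries: int) -> bytes:
	"""Illustrative helper: read `length` bytes at `pos`, retrying failed reads up to `read_tries` times."""
	disk_file.seek(pos)
	for _ in range(read_tries):
		try:
			return disk_file.read(length)
		except OSError:
			# A failed read can leave the file position unspecified; rewind before retrying.
			disk_file.seek(pos)
	# every attempt raised OSError
	raise _BadSector()
```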
@@ -177,7 +198,8 @@ def do_verify(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int)
 
        for (verify_id, update) in vp_updates.items():
                db.insert_verify_piece_content(verify_id, update.seq_start, disk_id, update.new_sector_ranges)
-               db.upsert_hasher_state(verify_id, update.hasher_state)
+               if update.hasher_state is not None:
+                       db.upsert_hasher_state(verify_id, update.hasher_state)
 
        for verify_id in passed_verifies:
                db.move_piece_content_for_pass(verify_id)
@@ -192,12 +214,24 @@ def do_verify(conn, disk_id: int, disk_file: io.BufferedIOBase, read_size: int)
 
 
 if __name__ == "__main__":
+       def read_tries(raw_arg):
+               val = int(raw_arg)
+               if val > 0:
+                       return val
+               else:
+                       raise ValueError()
+
        parser = argparse.ArgumentParser()
        parser.add_argument("disk_id", type = int)
+       parser.add_argument(
+               "read_tries",
+               type = read_tries,
+               help = "number of times to attempt a particular disk read before giving up on the sector",
+       )
        args = parser.parse_args()
 
        with contextlib.closing(psycopg2.connect("")) as conn:
                conn.autocommit = True
                path = f"/dev/mapper/diskjumble-{args.disk_id}"
-               with open(path, "rb", buffering = _READ_BUFFER_SIZE, read_size = _READ_BUFFER_SIZE) as disk_file:
-                       do_verify(conn, args.disk_id, disk_file)
+               with open(path, "rb", buffering = _READ_BUFFER_SIZE) as disk_file:
+                       do_verify(conn, args.disk_id, disk_file, _READ_BUFFER_SIZE, args.read_tries)
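As a quick, self-contained check of the new argument handling (the prog name below is just for the example): argparse treats the ValueError raised by the read_tries converter as an invalid value and exits with a usage error, so non-positive counts never reach do_verify.

```python
import argparse


def read_tries(raw_arg):
	val = int(raw_arg)
	if val > 0:
		return val
	else:
		raise ValueError()


parser = argparse.ArgumentParser(prog = "verify")
parser.add_argument("disk_id", type = int)
parser.add_argument("read_tries", type = read_tries)

print(parser.parse_args(["42", "3"]))  # Namespace(disk_id=42, read_tries=3)
# parser.parse_args(["42", "0"])  # exits: "argument read_tries: invalid read_tries value: '0'"
```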