Batch verify pass writes to decrease database I/O overhead
author Jakob Cornell <jakob+gpg@jcornell.net>
Tue, 19 Apr 2022 22:45:10 +0000 (17:45 -0500)
committer Jakob Cornell <jakob+gpg@jcornell.net>
Tue, 19 Apr 2022 23:30:15 +0000 (18:30 -0500)
disk_jumble/src/disk_jumble/verify.py

index 3f762c98323ee411c49138e0fb74129255eed779..07dc5e9d0fae12b62780764b6e19650e45bd1ba7 100644 (file)
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Iterator, List, Optional, Tuple
 import argparse
 import contextlib
 import datetime as dt
@@ -10,7 +10,7 @@ import io
 import itertools
 import math
 
-from psycopg2.extras import NumericRange
+from psycopg2.extras import execute_batch, NumericRange
 import psycopg2
 
 from disk_jumble import bencode, BLOCK_SIZE
@@ -263,22 +263,6 @@ def _do_verify(conn, disk_id: int, target_ranges: Optional[List[range]], disk_fi
        class Fail:
                run: _Run
 
-       def merge_results(results):
-               curr_pass = None
-               for r in results:
-                       if isinstance(r, Pass) and curr_pass and r.blocks.start <= curr_pass.stop:
-                               curr_pass = range(curr_pass.start, max(r.blocks.stop, curr_pass.stop))
-                       else:
-                               if curr_pass:
-                                       yield Pass(curr_pass)
-                                       curr_pass = None
-                               if isinstance(r, Pass):
-                                       curr_pass = r.blocks
-                               else:
-                                       yield r
-               if curr_pass:
-                       yield Pass(curr_pass)
-
        if target_ranges is None:
                pg_target_ranges = [NumericRange()]
        else:
@@ -344,21 +328,48 @@ def _do_verify(conn, disk_id: int, target_ranges: Optional[List[range]], disk_fi
                                else:
                                        yield Fail(run)
 
+       def batch_results(results) -> Iterator[Tuple[List[range], List[_Run]]]:
+               # Group verify results into (pass ranges, failed runs) batches for output to the database; batch_size counts the blocks covered so far
+               passes = []
+               fails = []
+               batch_size = 0
+
+               for r in results:
+                       if batch_size > 500:  # checked before r is added, so a yielded batch can cover more than 500 blocks
+                               yield (passes, fails)
+                               passes = []  # fresh lists: a pass range is never merged across a batch boundary
+                               fails = []
+                               batch_size = 0
+                       if isinstance(r, Pass):
+                               if passes and r.blocks.start <= passes[-1].stop:  # r starts at or before the end of the last pass range: extend that range instead of appending
+                                       new_stop = max(r.blocks.stop, passes[-1].stop)
+                                       batch_size += (new_stop - passes[-1].stop)  # count only the newly covered blocks
+                                       passes[-1] = range(passes[-1].start, new_stop)
+                               else:
+                                       passes.append(r.blocks)
+                                       batch_size += len(r.blocks)
+                       else:
+                               fails.append(r.run)  # anything that is not a Pass is handled as a failure carrying a run
+                               batch_size += sum(map(len, r.run.block_ranges))
+               if passes or fails:  # emit the final partial batch, if any
+                       yield (passes, fails)
+
        cursor = conn.cursor()
        ts = dt.datetime.now(dt.timezone.utc)
-       for result in merge_results(generate_results()):
-               if isinstance(result, Pass):
-                       cursor.execute(
+       for (pass_ranges, fail_runs) in batch_results(generate_results()):
+               if pass_ranges:
+                       execute_batch(
+                               cursor,
                                """
                                        insert into diskjumble.verify_pass (verify_pass_id, at, disk_id, disk_sectors)
                                        values (default, %s, %s, %s)
                                """,
-                               (ts, disk_id, NumericRange(result.blocks.start, result.blocks.stop))
+                               [(ts, disk_id, NumericRange(r.start, r.stop)) for r in pass_ranges],
+                               page_size = len(pass_ranges),
                        )
-               else:
-                       assert isinstance(result, Fail)
-                       run = result.run
-                       cursor.execute(
+               if fail_runs:
+                       execute_batch(
+                               cursor,
                                """
                                        with
                                                new_piece as (
@@ -376,13 +387,17 @@ def _do_verify(conn, disk_id: int, target_ranges: Optional[List[range]], disk_fi
                                        insert into diskjumble.verify_piece_fail (verify_id)
                                        select verify_id from new_piece
                                """,
-                               {
-                                       "ts": ts,
-                                       "entity_id": run.entity_id,
-                                       "piece_num": run.piece_num,
-                                       "disk_id": disk_id,
-                                       "ranges": [NumericRange(r.start, r.stop) for r in run.block_ranges],
-                               }
+                               [
+                                       {
+                                               "ts": ts,
+                                               "entity_id": run.entity_id,
+                                               "piece_num": run.piece_num,
+                                               "disk_id": disk_id,
+                                               "ranges": [NumericRange(r.start, r.stop) for r in run.block_ranges],
+                                       }
+                                       for run in fail_runs
+                               ],
+                               page_size = len(fail_runs),
                        )