From 01bb3a9a8609401db639f2af2c2b7b0383367dcc Mon Sep 17 00:00:00 2001 From: Jakob Cornell Date: Thu, 21 Apr 2022 20:21:26 -0500 Subject: [PATCH] Stream from the database to reduce memory footprint --- disk_jumble/src/disk_jumble/verify.py | 151 +++++++++++++------------- 1 file changed, 77 insertions(+), 74 deletions(-) diff --git a/disk_jumble/src/disk_jumble/verify.py b/disk_jumble/src/disk_jumble/verify.py index b1df745..12c7d9d 100644 --- a/disk_jumble/src/disk_jumble/verify.py +++ b/disk_jumble/src/disk_jumble/verify.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import List, Optional +from typing import Iterator, List, Optional +from warnings import warn import argparse import contextlib import hashlib @@ -67,7 +68,8 @@ def _run_sort_key(run: _Run): def _get_target_ranges(conn, disk_id: int, limit: Optional[int]) -> List[range]: ranges = [] block_count = 0 - with conn.cursor() as cursor: + with conn: + cursor = conn.cursor("target_ranges") cursor.execute( """ select unnest(written_map - coalesce(verified_map, int8multirange())) as range @@ -105,7 +107,7 @@ def _get_v1_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> L The runs are then reordered by the number of their first disk block. """ - cursor = conn.cursor() + cursor = conn.cursor("v1_worklist") cursor.execute( """ select distinct on (entity_id) @@ -206,8 +208,8 @@ def _get_v1_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> L return runs -def _get_v2_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> List[_V2Run]: - cursor = conn.cursor() +def _get_v2_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> Iterator[_V2Run]: + cursor = conn.cursor("v2_worklist") cursor.execute( """ with @@ -259,85 +261,86 @@ def _get_v2_worklist(conn, disk_id: int, target_ranges: List[NumericRange]) -> L "target_ranges": target_ranges, } ) - rows = cursor.fetchall() - - if any(crypt_key is not None for (*_, crypt_key) in rows): - raise NotImplementedError("verify of encrypted data") - - return [ - _V2Run( - entity_id = bytes(entity_id), - entity_length = entity_length, - piece_num = piece_num, - block_ranges = [range(block, block + 1)], - hash = bytes(hash_), - ) - for (entity_id, entity_length, piece_num, block, hash_, _) in rows - ] + for (entity_id, entity_length, piece_num, block, hash_, crypt_key) in cursor: + if crypt_key is None: + yield _V2Run( + entity_id = bytes(entity_id), + entity_length = entity_length, + piece_num = piece_num, + block_ranges = [range(block, block + 1)], + hash = bytes(hash_), + ) + else: + raise NotImplementedError("verify of encrypted data") -def _do_verify(conn, disk_id: int, target_ranges: List[range], disk_file: io.BufferedIOBase, read_size: int, read_tries: int): - pg_target_ranges = [NumericRange(r.start, r.stop) for r in target_ranges] - worklist = list(heapq.merge( - _get_v1_worklist(conn, disk_id, pg_target_ranges), - _get_v2_worklist(conn, disk_id, pg_target_ranges), - key = _run_sort_key, - )) +def _do_verify(conn, disk_id: int, target_ranges: List[range], disk_file: io.BufferedIOBase, read_size: int, read_tries: int): requested_blocks = { block for r in target_ranges for block in r } - covered_blocks = { - block - for run in worklist - for block_range in run.block_ranges - for block in block_range - } - missing = requested_blocks - covered_blocks - if missing: - raise RuntimeError(f"unable to locate blocks: {len(missing)} in the range {min(missing)} to {max(missing)}") - - passes = [] - fails = [] - for run in worklist: - if isinstance(run, _V1Run): - hasher = hashlib.sha1() - entity_off = run.piece_num * run.piece_length - else: - hasher = hashlib.sha256() - entity_off = run.piece_num * BLOCK_SIZE - assert len(run.hash) == hasher.digest_size, "incorrect validation hash length" - - try: - for range_ in run.block_ranges: - for block in range_: - read_start = block * BLOCK_SIZE - read_end = read_start + min(BLOCK_SIZE, run.entity_length - entity_off) - disk_file.seek(read_start) - while disk_file.tell() < read_end: - pos = disk_file.tell() - for _ in range(read_tries): - try: - data = disk_file.read(min(read_end - pos, read_size)) - except OSError: - disk_file.seek(pos) - else: - break - else: - raise _BadSector() - - assert data - hasher.update(data) - entity_off += BLOCK_SIZE - except _BadSector: - fails.extend(run.block_ranges) - else: - if hasher.digest() == run.hash: - passes.extend(run.block_ranges) + pg_target_ranges = [NumericRange(r.start, r.stop) for r in target_ranges] + + # transaction is required for named cursors in worklist generation + with conn: + worklist = heapq.merge( + _get_v1_worklist(conn, disk_id, pg_target_ranges), + _get_v2_worklist(conn, disk_id, pg_target_ranges), + key = _run_sort_key, + ) + + covered_blocks = set() + passes = [] + fails = [] + for run in worklist: + if isinstance(run, _V1Run): + hasher = hashlib.sha1() + entity_off = run.piece_num * run.piece_length else: + hasher = hashlib.sha256() + entity_off = run.piece_num * BLOCK_SIZE + assert len(run.hash) == hasher.digest_size, "incorrect validation hash length" + + try: + for range_ in run.block_ranges: + for block in range_: + read_start = block * BLOCK_SIZE + read_end = read_start + min(BLOCK_SIZE, run.entity_length - entity_off) + disk_file.seek(read_start) + while disk_file.tell() < read_end: + pos = disk_file.tell() + for _ in range(read_tries): + try: + data = disk_file.read(min(read_end - pos, read_size)) + except OSError: + disk_file.seek(pos) + else: + break + else: + raise _BadSector() + + assert data + hasher.update(data) + entity_off += BLOCK_SIZE + except _BadSector: fails.extend(run.block_ranges) + else: + if hasher.digest() == run.hash: + passes.extend(run.block_ranges) + else: + fails.extend(run.block_ranges) + + covered_blocks.update( + block + for block_range in run.block_ranges + for block in block_range + ) + + missing = requested_blocks - covered_blocks + if missing: + warn(f"unable to locate blocks: {len(missing)} in the range {min(missing)} to {max(missing)}") def clean_up(ranges): out = [] -- 2.30.2