add cod
authorAnders Cornell <anders@acorn.pw>
Tue, 1 Mar 2022 16:31:33 +0000 (11:31 -0500)
committerAnders Cornell <anders@acorn.pw>
Tue, 1 Mar 2022 16:36:21 +0000 (11:36 -0500)
disk_jumble/src/disk_jumble/disk.py [new file with mode: 0644]

diff --git a/disk_jumble/src/disk_jumble/disk.py b/disk_jumble/src/disk_jumble/disk.py
new file mode 100644 (file)
index 0000000..8ddb41c
--- /dev/null
@@ -0,0 +1,210 @@
+import os
+import uuid
+import subprocess
+import psycopg2
+import contextlib
+import struct
+
+@contextlib.contextmanager
+def openfd(path, flags, **kwargs):
+       fd = os.open(path, flags, **kwargs)
+       try:
+               yield fd
+       finally:
+               os.close(fd)
+
+class Device:
+       magic = uuid.UUID('f4bb403e-f8fa-4923-b342-b5e35da4a346').bytes
+
+       def __init__(self, fd):
+               self.fd = fd
+               self.rdev = divmod(os.fstat(self.fd).st_rdev, 256)
+               self.rdev_str = '{}:{}'.format(*self.rdev)
+               self.sector_size = int(open('/sys/dev/block/{}/queue/logical_block_size'.format(self.rdev_str)).read())
+               size_in_512b = int(open('/sys/dev/block/{}/size'.format(self.rdev_str)).read())
+               if self.sector_size == 4096:
+                       assert size_in_512b % 8 == 0
+                       self.num_sectors = size_in_512b / 8
+                       self.header_sectors = 256
+               elif self.sector_size == 512:
+                       self.num_sectors = size_in_512b
+                       self.header_sectors = 2048
+               else:
+                       raise RuntimeError('expected sector size to be 512 or 4096')
+               assert self.num_sectors > self.header_sectors
+
+       def get_serial(self): # or roll this into __init__
+               return None # TODO
+
+       @staticmethod
+       def find(dev_uuid, dev_serial=None):
+               with openfd('/dev/disk/by-id', os.O_RDONLY | os.O_DIRECTORY) as disks:
+                       for entry in os.scandir(disks):
+                               name_components = entry.name.split('-')
+                               name_attachment = name_components[0]
+                               if name_attachment == 'scsi':
+                                       name_serial = name_components[1]
+               #               elif name_attachment == 'ata':
+               #                       name_serial = name_components[3]
+                               if name_attachment not in ('scsi',):
+                                       continue
+                               if dev_serial and name_serial != dev_serial:
+                                       continue
+                               dev_fd = os.open(entry.name, os.O_RDONLY, dir_fd=disks)
+                               try:
+                                       device = Device(dev_fd)
+                                       header_data = device.read_header()
+                                       if header_data is None or header_data[0] != dev_uuid:
+                                               os.close(dev_fd)
+                                               continue
+                                       assert header_data[3] == self.num_blocks
+                               except:
+                                       os.close(dev_fd)
+                                       raise
+                               else:
+                                       self.device = device
+                                       break
+                       else:
+                               raise RuntimeError("Didn't find disk {}".format(disk_id))
+
+       def read_header(self):
+               assert self.fd is not None
+               found_headers = set()
+               for i in range(0, self.header_sectors):
+                       buf = os.pread(self.fd, 56, i*self.sector_size)
+                       if buf[0:16] == Device.magic:
+                               found_headers.add(buf)
+               if found_headers:
+                       [found_header] = found_headers
+                       (dev_uuid, sector_size, header_sectors, data_sectors) = struct.unpack('>16x16sQQQ', found_header)
+                       dev_uuid = uuid.UUID(bytes=dev_uuid)
+                       assert sector_size == self.sector_size
+                       assert header_sectors == self.header_sectors
+                       sectors_per_block = 16384 / self.sector_size
+                       assert data_sectors % sectors_per_block == 0
+                       num_blocks = data_sectors / sectors_per_block
+                       assert self.header_sectors + data_sectors <= self.num_sectors
+                       return (dev_uuid, self.sector_size, self.header_sectors, num_blocks)
+               else:
+                       return None
+
+       def write_header(self, dev_uuid, num_blocks=None, force=False):
+               if (not force) and self.read_header():
+                       raise RuntimeError('Refusing to overwrite existing diskjumble header')
+               data_sectors_avail = self.num_sectors - self.header_sectors
+               sectors_per_block = 16384 / self.sector_size
+               if num_blocks is None:
+                       num_blocks = data_sectors_avail // sectors_per_block
+               data_sectors = num_blocks * sectors_per_block
+               assert self.header_sectors + data_sectors <= self.num_sectors
+               header_data = struct.pack('>16s16sQQQ', Device.magic, dev_uuid.bytes, self.sector_size, self.header_sectors, data_sectors).ljust(sector_size, b'\0')
+               for i in range(0, self.header_sectors):
+                       os.pwrite(self.fd, header_data, i*sector_size)
+       
+       def close(self):
+               if self.fd:
+                       os.close(self.fd)
+                       self.fd = None
+
+class Disk:
+       def __init__(self, conn, disk_id=None, dev_uuid=None, dev_serial=None, num_blocks=None, device=None):
+               self.conn = conn
+
+               self.disk_id = disk_id
+               self.dev_uuid = dev_uuid
+               self.dev_serial = dev_serial
+               self.num_blocks = num_blocks
+
+               self.device = device
+
+       @staticmethod
+       def load(conn, disk_id=None, devname=None):
+               with conn.cursor() as cursor:
+                       device = None
+                       dev_uuid = None
+                       if devname is not None:
+                               device = Device(os.open(devname, os.O_RDWR))
+                               header_data = device.read_header()
+                               if header_data:
+                                       dev_uuid = header_data[0]
+                       cursor.execute('''SELECT disk_id, dev_uuid, dev_serial, num_blocks
+                                       FROM diskjumble.disk
+                                       WHERE coalesce(disk_id=%s, true) AND coalesce(dev_uuid=%s, true)''', (disk_id, dev_uuid))
+                       [(row)] = cursor.fetchall()
+                       disk_id = row[0]
+                       dev_uuid = uuid.UUID(row[1])
+                       dev_serial = row[2]
+                       num_blocks = row[3]
+
+                       if devname is None:
+                               device = Device.find(dev_uuid, dev_serial)
+
+                       return Disk(conn, disk_id=disk_id, dev_uuid=dev_uuid,
+                                       dev_serial=dev_serial, num_blocks=num_blocks, device=device)
+
+       def save(self):
+               with self.conn.cursor() as cursor:
+                       cursor.execute("INSERT INTO diskjumble.disk (disk_id, dev_uuid, dev_serial, num_blocks) VALUES (%s, %s, %s, %s) RETURNING disk_id", (self.disk_id, str(self.dev_uuid), self.dev_serial, self.num_blocks))
+                       [(self.disk_id)] = cursor.fetchall()
+
+       @staticmethod
+       def create(conn, devname, disk_id=None, num_blocks=None):
+               device = Device(os.open(devname, os.O_RDWR))
+
+               data_sectors_avail = device.num_sectors - device.header_sectors
+               sectors_per_block = 16384 / device.sector_size
+
+               if num_blocks is None:
+                       num_blocks = data_sectors_avail // sectors_per_block
+
+               return Disk(conn, disk_id=disk_id, dev_uuid=uuid.uuid4(), dev_serial=device.get_serial(),
+                               num_blocks=num_blocks, device=device)
+
+       def close(self):
+               if self.device is not None:
+                       self.device.close()
+
+       def write_header(self, force=False):
+               self.device.write_header(self.dev_uuid, force)
+
+       def mapping_name(self):
+               assert self.disk_id is not None
+               return 'diskjumble-{}'.format(self.disk_id)
+
+       def create_mapping(self):
+               dm_size = 32 * self.num_blocks # 16384 / 512
+               dm_dev = self.device.rdev_str
+               subprocess.run(['sudo', 'dmsetup', 'create', '--concise', '{},,,rw,0 {} linear {} 2048'.format(self.mapping_name(), dm_size, dm_dev)], check=True)
+
+       def destroy_mapping(self):
+               subprocess.run(['sudo', 'dmsetup', 'remove', '--retry', self.mapping_name()], check=True)
+
+if __name__ == '__main__':
+       import argparse
+       arg_parser = argparse.ArgumentParser()
+       arg_parser.add_argument('--id', help='integer disk_id')
+       arg_parser.add_argument('--dev', help='path to underlying device')
+       group = arg_parser.add_mutually_exclusive_group()
+       group.add_argument('--up', action='store_true', help='create device mapper mapping')
+       group.add_argument('--down', action='store_true', help='destroy device mapper mapping')
+       arg_parser.add_argument('--new', action='store_true', help='initialize a new disk (implies --format)')
+       arg_parser.add_argument('--format', action='store_true', help='write diskjumble header')
+       arg_parser.add_argument('--force', action='store_true', help='write header even if there is one already')
+       args = arg_parser.parse_args()
+
+       with contextlib.closing(psycopg2.connect("")) as conn:
+               conn.autocommit = 'off'
+               if args.new:
+                       args.format = True
+                       inst = Disk.create(conn, args.dev, disk_id=args.id)
+               else:
+                       inst = Disk.load(conn, disk_id=args.id, devname=args.dev)
+               if args.down:
+                       inst.destroy_mapping()
+               if args.format:
+                       inst.write_header(args.force)
+               if args.up:
+                       inst.create_mapping()
+               if args.new:
+                       inst.save(conn)
+                       conn.commit()