Add support for database TLS client cert auth
authorJakob Cornell <jakob+gpg@jcornell.net>
Sun, 25 Sep 2022 18:03:16 +0000 (13:03 -0500)
committerJakob Cornell <jakob+gpg@jcornell.net>
Sun, 25 Sep 2022 18:03:16 +0000 (13:03 -0500)
strikebot/docs/sample_config.ini
strikebot/setup.cfg
strikebot/src/strikebot/__main__.py

index 423f22023b15df587f46c47fc3c43759e2e35f6f..34dec077b80e8cec3e2718e5ec98aa725bd4a025 100644 (file)
@@ -58,6 +58,8 @@ WS silent limit = 30
 WS warmup time = 0.3
 
 
-# Postgres database configuration; same options as in a connect string
+# Postgres database configuration; same options as in a connect string. Note that because Python's `SSLContext' is used
+# to implement TLS client auth, using `sslkey' to specify a key obtained from an OpenSSL engine may not be supported (a
+# file path is the only option).
 [db connect params]
 host = example.org
index f200179a3adeb4c0a412e451c169d38a20f88401..29e5900e48483cd7f36a24ac63d59384aeaaf6d9 100644 (file)
@@ -1,6 +1,6 @@
 [metadata]
 name = strikebot
-version = 0.0.0
+version = 0.0.2
 
 [options]
 package_dir =
index a655692a5dee2b8ec8f68c98ca5db427f260ff07..da645b93eedf8f7ae24abc69867cc778fde8d537 100644 (file)
@@ -10,6 +10,7 @@ import argparse
 import configparser
 import datetime as dt
 import logging
+import ssl
 
 from trio import open_memory_channel, open_nursery, open_signal_receiver
 
@@ -138,11 +139,26 @@ ws_pool_size = main_cfg.getint("WS pool size")
 ws_silent_limit = dt.timedelta(seconds = main_cfg.getfloat("WS silent limit"))
 ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup seconds"))
 
+
 db_cfg = parser["db connect params"]
+cert_path = db_cfg.pop("sslcert", None)
+key_path = db_cfg.pop("sslkey", None)
+key_pass = db_cfg.pop("sslpassword", None)
+
 getters = {
        "port": db_cfg.getint,
 }
-db_connect_params = {k: getters.get(k, db_cfg.get)(k) for k in db_cfg}
+db_connect_kwargs = {k: getters.get(k, db_cfg.get)(k) for k in db_cfg}
+
+if cert_path:
+       # patch TLS client cert auth support; asyncpg adds it in a later version
+       db_ssl_ctx = ssl.create_default_context()
+       db_ssl_ctx.load_cert_chain(
+               cert_path,
+               key_path,
+               None if key_pass is None else lambda: key_pass,
+       )
+       db_connect_kwargs["ssl"] = db_ssl_ctx
 
 
 if read_only and enforcing:
@@ -166,7 +182,7 @@ if _DEBUG_LOG_PATH:  # TODO remove this ad hoc setup
 
 async def main():
        async with (
-               triopg.connect(**db_connect_params) as db_conn,
+               triopg.connect(**db_connect_kwargs) as db_conn,
                open_nursery() as nursery_a,
                open_nursery() as nursery_b,
                open_nursery() as ws_pool_nursery,