From 342ebfba8c6266908065d5804f9f858962670991 Mon Sep 17 00:00:00 2001 From: Jakob Cornell Date: Sun, 25 Sep 2022 13:03:16 -0500 Subject: [PATCH] Add support for database TLS client cert auth --- strikebot/docs/sample_config.ini | 4 +++- strikebot/setup.cfg | 2 +- strikebot/src/strikebot/__main__.py | 20 ++++++++++++++++++-- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/strikebot/docs/sample_config.ini b/strikebot/docs/sample_config.ini index 423f220..34dec07 100644 --- a/strikebot/docs/sample_config.ini +++ b/strikebot/docs/sample_config.ini @@ -58,6 +58,8 @@ WS silent limit = 30 WS warmup time = 0.3 -# Postgres database configuration; same options as in a connect string +# Postgres database configuration; same options as in a connect string. Note that because Python's `SSLContext' is used +# to implement TLS client auth, using `sslkey' to specify a key obtained from an OpenSSL engine may not be supported (a +# file path is the only option). [db connect params] host = example.org diff --git a/strikebot/setup.cfg b/strikebot/setup.cfg index f200179..29e5900 100644 --- a/strikebot/setup.cfg +++ b/strikebot/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = strikebot -version = 0.0.0 +version = 0.0.2 [options] package_dir = diff --git a/strikebot/src/strikebot/__main__.py b/strikebot/src/strikebot/__main__.py index a655692..da645b9 100644 --- a/strikebot/src/strikebot/__main__.py +++ b/strikebot/src/strikebot/__main__.py @@ -10,6 +10,7 @@ import argparse import configparser import datetime as dt import logging +import ssl from trio import open_memory_channel, open_nursery, open_signal_receiver @@ -138,11 +139,26 @@ ws_pool_size = main_cfg.getint("WS pool size") ws_silent_limit = dt.timedelta(seconds = main_cfg.getfloat("WS silent limit")) ws_warmup = dt.timedelta(seconds = main_cfg.getfloat("WS warmup seconds")) + db_cfg = parser["db connect params"] +cert_path = db_cfg.pop("sslcert", None) +key_path = db_cfg.pop("sslkey", None) +key_pass = db_cfg.pop("sslpassword", None) + getters = { "port": db_cfg.getint, } -db_connect_params = {k: getters.get(k, db_cfg.get)(k) for k in db_cfg} +db_connect_kwargs = {k: getters.get(k, db_cfg.get)(k) for k in db_cfg} + +if cert_path: + # patch TLS client cert auth support; asyncpg adds it in a later version + db_ssl_ctx = ssl.create_default_context() + db_ssl_ctx.load_cert_chain( + cert_path, + key_path, + None if key_pass is None else lambda: key_pass, + ) + db_connect_kwargs["ssl"] = db_ssl_ctx if read_only and enforcing: @@ -166,7 +182,7 @@ if _DEBUG_LOG_PATH: # TODO remove this ad hoc setup async def main(): async with ( - triopg.connect(**db_connect_params) as db_conn, + triopg.connect(**db_connect_kwargs) as db_conn, open_nursery() as nursery_a, open_nursery() as nursery_b, open_nursery() as ws_pool_nursery, -- 2.30.2