WIP
authorJakob <jakob@jcornell.net>
Sat, 30 Nov 2019 22:36:51 +0000 (16:36 -0600)
committerJakob <jakob@jcornell.net>
Sat, 30 Nov 2019 22:36:51 +0000 (16:36 -0600)
api.py
auth.py [new file with mode: 0644]
common.py
main.py
oauth.py [deleted file]
util.py

diff --git a/api.py b/api.py
index 78ecbcef7dc5a4fd9cd25836702a8c02270afb32..13bc80c1b9964d9f2775d6d7a0d4adddaeafec8c 100644 (file)
--- a/api.py
+++ b/api.py
@@ -19,7 +19,7 @@ from pathlib import PurePosixPath
 from urllib.parse import urlencode as mk_qs
 import urllib.request
 
-import oauth
+import auth
 import util
 
 def walk_pages(opener, url):
@@ -33,18 +33,15 @@ def walk_pages(opener, url):
                                return
 
 class ApiInterface:
-       def __init__(self, host, path, storage_mgr):
-               self.host = host
-               self.path = path
-               self.opener = urllib.request.build_opener(
-                       oauth.AuthHandler(storage_mgr, self),
-               )
+       def __init__(self, bb_root, auth_handler):
+               self.bb_root = bb_root
+               self.opener = urllib.request.build_opener(auth_handler)
 
        def build_api_url(self, endpoint, query = ''):
                return urllib.parse.urlunparse(urllib.parse.ParseResult(
                        scheme = 'https',
-                       netloc = self.host,
-                       path = str(self.path.joinpath(endpoint)),
+                       netloc = self.bb_root.host,
+                       path = str(self.bb_root.path.joinpath('learn/api/public', endpoint)),
                        query = query,
                        params = '',
                        fragment = '',
diff --git a/auth.py b/auth.py
new file mode 100644 (file)
index 0000000..c305e43
--- /dev/null
+++ b/auth.py
@@ -0,0 +1,310 @@
+# Copyright 2019, Jakob Cornell
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+# 
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+# 
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+import base64
+from datetime import timedelta
+import http.server
+import http.cookiejar
+import os
+from pathlib import PurePosixPath
+import urllib.parse
+import urllib.request
+
+import common
+import util
+
+ADDRESS = 'localhost'
+PORT = 1081
+ENDPOINT = '/redirection_endpoint'
+
+TIMEOUT = timedelta(minutes = 30)
+
+LOGGER = common.LOGGER.getChild('auth')
+
+class OauthCodeRequestHandler(http.server.BaseHTTPRequestHandler):
+       LANDING_PATH = '/landing'
+
+       def __init__(self, channel, csrf_token):
+               self.channel = channel
+               self.csrf_token = csrf_token
+
+               self.status = None
+               self.message = None
+               self.caller_val = None
+
+       def __call__(self, *args):
+               # Yep, `socketserver' invokes the handler by calling it, and the superclass's
+               # code for dispatching requests to the appropriate methods is in the constructor.
+               # If you're confused, it's because it doesn't make any sense.
+               super().__init__(*args)
+
+       def do_GET(self):
+               url = urllib.parse.urlparse(self.path)
+               params = urllib.parse.parse_qs(url.query)
+               path = PurePosixPath(url.path)
+               if path == PurePosixPath(ENDPOINT):
+                       REQD_PARAMS = {'code', 'state'}
+                       if params.keys() == REQD_PARAMS and all(len(params[k]) == 1 for k in REQD_PARAMS):
+                               [code] = params['code']
+                               [token] = params['state']
+                               if token == self.csrf_token:
+                                       self.status = 200
+                                       self.message = "Success! You may close this page."
+                                       self._redirect()
+                                       self.caller_val = code
+                               else:
+                                       self.status = 403
+                                       self.message = "CSRF token check failed"
+                                       self._redirect()
+                       else:
+                               self.status = 400
+                               self.message = "Auth server redirect missing required parameters"
+                               self._redirect()
+                               self.caller_val = AuthError(self.message)
+               elif path == PurePosixPath(self.LANDING_PATH):
+                       self.send_response(self.status)
+                       self.end_headers()
+                       self.wfile.write(self.message.encode('ascii'))
+                       if self.caller_val is not None:
+                               self.channel.put(self.caller_val)
+               else:
+                       self.send_response(404)
+                       self.end_headers()
+                       self.wfile.write('Not Found'.encode('ascii'))
+
+       def _redirect(self):
+               self.send_response(302)
+               self.send_header('Location', self.LANDING_PATH)
+               self.end_headers()
+
+       def log_message(self, *_):
+               pass
+
+def _make_pkce_pair():
+       from collections import namedtuple
+       import hashlib
+
+       verifier = base64.urlsafe_b64encode(os.urandom(32)).strip(b'=')
+       hasher = hashlib.sha256()
+       hasher.update(verifier)
+       challenge = base64.urlsafe_b64encode(hasher.digest()).strip(b'=')
+
+       PkcePair = namedtuple('PkcePair', ['verifier', 'challenge'])
+       return PkcePair(verifier.decode('ascii'), challenge.decode('ascii'))
+
+class AuthError(Exception):
+       pass
+
+class OauthHandler(urllib.request.BaseHandler):
+       # Precede the HTTP error handler
+       handler_order = 750
+
+       def __init__(self, storage_mgr, api_iface):
+               self.storage_mgr = storage_mgr
+               self.api_iface = api_iface
+
+       @staticmethod
+       def _set_auth_header(request, token_doc):
+               request.add_unredirected_header(
+                       'Authorization', 'Bearer {}'.format(token_doc['access_token'])
+               )
+
+       def _handle(self, request):
+               token_doc = self.storage_mgr.get('authInfo')
+               if not token_doc:
+                       LOGGER.info("No stored access token. Requesting a new token.")
+                       token_doc = get_access_token(self.storage_mgr, self.api_iface)
+                       self.storage_mgr['authInfo'] = token_doc
+               self._set_auth_header(request, token_doc)
+               return request
+
+       http_request = _handle
+       https_request = _handle
+
+       def http_error_401(self, request, fp, code, msg, headers):
+               LOGGER.info("Access token expired or revoked. Requesting a new token.")
+               token_doc = get_access_token(self.storage_mgr, self.api_iface)
+               self.storage_mgr['authInfo'] = token_doc
+               self._set_auth_header(request, token_doc)
+               return self.parent.open(request, timeout = request.timeout)
+
+def get_access_token(storage_mgr, api_iface):
+       import json
+       import queue
+       from threading import Thread
+       import webbrowser
+
+       client_id = storage_mgr['clientId']
+       client_secret = storage_mgr['clientSecret']
+       payload = base64.b64encode(
+               (client_id + ':' + client_secret).encode('utf-8')
+       ).decode('ascii')
+
+       handlers = [
+               urllib.request.HTTPSHandler(),
+               util.HeaderAddHandler([('Authorization', 'Basic {}'.format(payload))]),
+       ]
+       opener = urllib.request.OpenerDirector()
+       for handler in handlers:
+               opener.add_handler(handler)
+
+       token_doc = None
+       if 'authInfo' in storage_mgr:
+               refresh_token = storage_mgr['authInfo']['refresh_token']
+               params = {
+                       'grant_type': 'refresh_token',
+                       'refresh_token': refresh_token,
+               }
+               request = urllib.request.Request(
+                       url = api_iface.build_api_url('v1/oauth2/token'),
+                       data = urllib.parse.urlencode(params).encode('utf-8'),
+                       method = 'POST',
+               )
+               with opener.open(request) as resp:
+                       body = json.loads(util.decode_body(resp))
+               if resp.status == 200:
+                       token_doc = body
+               else:
+                       LOGGER.info("Stored refresh token rejected. Obtaining new authorization code.")
+                       assert resp.status == 400
+
+       if token_doc is None:
+               redirect_uri = 'http://' + ADDRESS + ':' + str(PORT) + ENDPOINT
+               pkce_pair = _make_pkce_pair()
+               csrf_token = base64.urlsafe_b64encode(os.urandom(24)).decode('ascii')
+               params = {
+                       'redirect_uri': redirect_uri,
+                       'response_type': 'code',
+                       'client_id': client_id,
+                       'scope': 'read offline',
+                       'state': csrf_token,
+                       'code_challenge': pkce_pair.challenge,
+                       'code_challenge_method': 'S256',
+               }
+               user_url = api_iface.build_api_url(
+                       endpoint = 'v1/oauth2/authorizationcode',
+                       query = urllib.parse.urlencode(params),
+               )
+
+               channel = queue.Queue()
+               server = http.server.HTTPServer(
+                       (ADDRESS, PORT),
+                       OauthCodeRequestHandler(channel, csrf_token)
+               )
+               Thread(target = server.serve_forever, daemon = True).start()
+
+               LOGGER.info("Attempting to launch a web browser to authorize the application…")
+               if not webbrowser.open(user_url):
+                       LOGGER.info("Failed to launch a browser. Please visit this URL to authorize the application:")
+                       LOGGER.info("    {}".format(user_url))
+
+               try:
+                       resp = channel.get(timeout = TIMEOUT.total_seconds())
+               except queue.Empty:
+                       resp = None
+               server.shutdown()
+
+               if resp is None:
+                       raise AuthError("Authorization timed out. Please try again.")
+               elif isinstance(resp, Exception):
+                       raise resp
+               else:
+                       auth_code = resp
+                       params = {
+                               'grant_type': 'authorization_code',
+                               'code': auth_code,
+                               'code_verifier': pkce_pair.verifier,
+                               'redirect_uri': redirect_uri,
+                       }
+                       request = urllib.request.Request(
+                               url = api_iface.build_api_url('v1/oauth2/token'),
+                               data = urllib.parse.urlencode(params).encode('utf-8'),
+                               method = 'POST',
+                       )
+                       with opener.open(request) as resp:
+                               assert resp.status == 200
+                               token_doc = json.loads(util.decode_body(resp))
+
+       return token_doc
+
+
+class StorageMgrCookieJar(http.cookiejar.CookieJar):
+       def __init__(self, storage_mgr):
+               self.storage_mgr = storage_mgr
+
+       def load(self):
+               if 'cookies' not in self.storage_mgr.keys():
+                       self.storage_mgr['cookies'] = {}
+               self._cookies = self.storage_mgr['cookies']
+
+       def save(self):
+               self.storage_mgr['cookies'] = self._cookies
+
+
+class CookieAuthHandler(urllib.request.HTTPCookieProcessor):
+       # Precede the HTTP error handler
+       handler_order = 750
+
+       def __init__(self, storage_mgr, bb_root, ui, *args):
+               super().__init__(*args)
+               self.bb_root = bb_root
+               self.storage_mgr = storage_mgr
+               self.ui = ui
+
+       def log_in(self):
+               import base64
+               import bs4
+
+               encode = lambda s: base64.b64encode(s.encode('utf-8')).decode('ascii')
+               decode = lambda s: base64.b64decode(s.encode('ascii')).decode('utf-8')
+
+               # Obfuscate password to prevent accidental discovery
+               if {'username', 'password'} <= self.storage_mgr.keys():
+                       username = storage_mgr['username']
+                       password = decode(storage_mgr['password'])
+               else:
+                       username = self.ui.ask_username()
+                       password = self.ui.ask_password()
+                       storage_mgr['username'] = username
+                       storage_mgr['password'] = encode(password)
+
+               url = self.bb_root.host + str(self.bb_root.path)
+               with self.parent.open(url) as resp:
+                       soup = bs4.BeautifulSoup(resp, 'lxml')
+               [form] = soup.select('#login-form > form[name="login"]')
+               to_keep = lambda elem: (
+                       elem['type'] != 'submit'
+                       and elem['name'] not in {'user_id', 'password'}
+               )
+               inputs = filter(to_keep, form.find_all('input'))
+               params = dict([
+                       ('user_id', username),
+                       ('password', password),
+                       *((elem['name'], elem['value']) for elem in inputs)
+               ])
+               body = urllib.parse.urlencode(params).encode('ascii')
+
+               req = urllib.request.Request(
+                       util.resolve(form['action'], url),
+                       method = form['method'],
+                       data = body,
+               )
+               with self.parent.open(req) as resp:
+                       resp.read()
+
+       def http_error_401(self, request, fp, code, msg, headers):
+               LOGGER.info("Session cookies missing or expired. Logging in…")
+               self.log_in()
+               return self.parent.open(request, timeout = request.timeout)
index 7dca10d197beeb08449eeb1b78799662f1ac236a..3cc3ea7fa7a7d2a1cd099dde500eee15800b3238 100644 (file)
--- a/common.py
+++ b/common.py
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
-from collections import deque
+from collections import deque, namedtuple
 import json
 import logging
 from pathlib import PurePosixPath
 import urllib.request
 
-API_HOST = 'oberlin-test.blackboard.com'
-API_PATH = PurePosixPath('/learn/api/public')
+BlackboardRoot = namedtuple('BlackboardRoot', ['host', 'path'])
+BB_ROOT = BlackboardRoot(
+       host = 'oberlin-test.blackboard.com',
+       path = PurePosixPath('/'),
+)
 
 LOGGER = logging.getLogger('bb-sync-api')
 logging.basicConfig()
 LOGGER.setLevel(logging.INFO)
 
-class StorageManager:
+class Adapter(logging.LoggerAdapter):
+       def process(self, msg, kwargs):
+               return ('  ' * kwargs['indent'] + msg, kwargs)
+
+'''
+LOGGER = Adapter(LOGGER, {})
+LOGGER.debug('some message', indent = 1)
+
+class IndentingFormatter(logging.Formatter):
+       def format(self, record):
+               prefix = '  ' * record.indent if hasattr(record, 'indent') else ''
+               return prefix + None
+
+LOGGER.setFormatter(logging.Formatter('', style = '{'))
+'''
+
+class StorageManager():
        def __init__(self, path):
                self.path = path
                self.patch = {}
@@ -62,6 +81,9 @@ class StorageManager:
                        self._invalidate()
                        return self.cache.get(key)
 
+       def keys(self):
+               return self.cache.keys() | self.patch.keys()
+
        def __contains__(self, key):
                if key in self.patch:
                        return True
diff --git a/main.py b/main.py
index bfe182030e204a5212c6e0718f5e4de7ac35dab2..c15e992cb861d90ebc05bb4048df2408098691e8 100644 (file)
--- a/main.py
+++ b/main.py
@@ -26,8 +26,8 @@ import urllib.parse
 import urllib.request
 
 from common import *
-import oauth
 import api
+import auth
 import fs
 
 r"""
@@ -38,13 +38,31 @@ Example configuration file (Windows):
 
 """
 
+
+class Cli:
+       @staticmethod
+       def ask_username():
+               return input("Blackboard username: ")
+
+       @staticmethod
+       def ask_password():
+               from getpass import getpass
+               return getpass("Blackboard password: ")
+
+
 cfg_parser = configparser.ConfigParser()
 with open('config.ini') as f:
        cfg_parser.read_file(f)
 cfg_section = cfg_parser['config']
 
-with StorageManager(Path('saved_state')) as storage_mgr:
-       api_iface = api.ApiInterface(API_HOST, API_PATH, storage_mgr)
+with StorageManager(Path('auth_cache')) as storage_mgr:
+       auth_handler = auth.CookieAuthHandler(
+               storage_mgr,
+               BB_ROOT,
+               Cli,
+               auth.StorageMgrCookieJar(storage_mgr),
+       )
+       api_iface = api.ApiInterface(BB_ROOT, auth_handler)
 
        url = input("Blackboard URL: ")
        params = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
@@ -95,24 +113,39 @@ with StorageManager(Path('saved_state')) as storage_mgr:
                        att_id = attachment_doc['id']
 
                        LOGGER.info("  Checking attachment {id}: \"{fileName}\"".format(**attachment_doc))
-                       with api_iface.download_attachment(course_id, child_doc['id'], att_id) as resp:
-                               with tempfile.NamedTemporaryFile(delete = False) as tmp:
-                                       tmp_path = Path(tmp.name)
-                                       shutil.copyfileobj(resp, tmp)
+
+                       class Result:
+                               NoVersions = namedtuple('NoVersions', [])
+                               MultipleLatest = namedtuple('MultipleLatest', ['paths'])
+                               SingleLatest = namedtuple('SingleLatest', ['version', 'matches'])
+
+                               @staticmethod
+                               def to_update(result):
+                                       return (
+                                               isinstance(result, Result.SingleLatest) and not result.matches
+                                               or isinstance(result, Result.NoVersions)
+                                       )
 
                        my_versions = [info for info in versions.keys() if info.bb_id == att_id]
                        if my_versions:
                                latest = max(my_versions, key = attrgetter('version'))
-                               [latest_path] = versions[latest]
-                               match = filecmp.cmp(str(tmp_path), str(latest_path), shallow = False)
+                               latest_paths = versions[latest]
+                               if len(latest_paths) == 1:
+                                       [latest_path] = latest_paths
+                                       with api_iface.download_attachment(course_id, child_doc['id'], att_id) as resp:
+                                               with tempfile.NamedTemporaryFile(delete = False) as tmp:
+                                                       tmp_path = Path(tmp.name)
+                                                       shutil.copyfileobj(resp, tmp)
+                                       matches = filecmp.cmp(str(tmp_path), str(latest_path), shallow = False)
+                                       result = Result.SingleLatest(latest, matches)
+                               else:
+                                       result = Result.MultipleLatest(latest_paths)
                        else:
-                               match = None
+                               result = Result.NoVersions()
 
-                       if match is True:
-                               tmp_path.unlink()
-                       else:
-                               if match is False:
-                                       new_version = latest.next()
+                       if Result.to_update(result):
+                               if isinstance(result, Result.SingleLatest):
+                                       new_version = result.version.next()
                                        LOGGER.info("    Found new revision ({})".format(new_version.version))
                                else:
                                        new_version = fs.VersionInfo(att_id, 0)
@@ -121,3 +154,11 @@ with StorageManager(Path('saved_state')) as storage_mgr:
                                LOGGER.info("    Destination: {}".format(dest))
                                tmp_path.replace(dest)
                                fs.write_metadata(dest, new_version)
+                       elif isinstance(result, Result.SingleLatest):
+                               # versions match
+                               tmp_path.unlink()
+                       elif isinstance(result, Result.MultipleLatest):
+                               LOGGER.error("    Identified multiple latest versions for {id}: {fileName}"
+                                       .format(**attachment_doc)
+                               )
+                               LOGGER.error("    Please delete ")
diff --git a/oauth.py b/oauth.py
deleted file mode 100644 (file)
index 6996b2a..0000000
--- a/oauth.py
+++ /dev/null
@@ -1,239 +0,0 @@
-# Copyright 2019, Jakob Cornell
-
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-# 
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-# GNU General Public License for more details.
-# 
-# You should have received a copy of the GNU General Public License
-# along with this program.  If not, see <https://www.gnu.org/licenses/>.
-
-import base64
-from datetime import timedelta
-import http.server
-import os
-from pathlib import PurePosixPath
-import urllib.parse
-import urllib.request
-
-import common
-import util
-
-ADDRESS = 'localhost'
-PORT = 1081
-ENDPOINT = '/redirection_endpoint'
-
-TIMEOUT = timedelta(minutes = 30)
-
-LOGGER = common.LOGGER.getChild('oauth')
-
-class AuthCodeRequestHandler(http.server.BaseHTTPRequestHandler):
-       LANDING_PATH = '/landing'
-
-       def __init__(self, channel, csrf_token):
-               self.channel = channel
-               self.csrf_token = csrf_token
-
-               self.status = None
-               self.message = None
-               self.caller_val = None
-
-       def __call__(self, *args):
-               # Yep, `socketserver' invokes the handler by calling it, and the superclass's
-               # code for dispatching requests to the appropriate methods is in the constructor.
-               # If you're confused, it's because it doesn't make any sense.
-               super().__init__(*args)
-
-       def do_GET(self):
-               url = urllib.parse.urlparse(self.path)
-               params = urllib.parse.parse_qs(url.query)
-               path = PurePosixPath(url.path)
-               if path == PurePosixPath(ENDPOINT):
-                       REQD_PARAMS = {'code', 'state'}
-                       if params.keys() == REQD_PARAMS and all(len(params[k]) == 1 for k in REQD_PARAMS):
-                               [code] = params['code']
-                               [token] = params['state']
-                               if token == self.csrf_token:
-                                       self.status = 200
-                                       self.message = "Success! You may close this page."
-                                       self._redirect()
-                                       self.caller_val = code
-                               else:
-                                       self.status = 403
-                                       self.message = "CSRF token check failed"
-                                       self._redirect()
-                       else:
-                               self.status = 400
-                               self.message = "Auth server redirect missing required parameters"
-                               self._redirect()
-                               self.caller_val = AuthError(self.message)
-               elif path == PurePosixPath(self.LANDING_PATH):
-                       self.send_response(self.status)
-                       self.end_headers()
-                       self.wfile.write(self.message.encode('ascii'))
-                       if self.caller_val is not None:
-                               self.channel.put(self.caller_val)
-               else:
-                       self.send_response(404)
-                       self.end_headers()
-                       self.wfile.write('Not Found'.encode('ascii'))
-
-       def _redirect(self):
-               self.send_response(302)
-               self.send_header('Location', self.LANDING_PATH)
-               self.end_headers()
-
-       def log_message(self, *_):
-               pass
-
-def _make_pkce_pair():
-       from collections import namedtuple
-       import hashlib
-
-       verifier = base64.urlsafe_b64encode(os.urandom(32)).strip(b'=')
-       hasher = hashlib.sha256()
-       hasher.update(verifier)
-       challenge = base64.urlsafe_b64encode(hasher.digest()).strip(b'=')
-
-       PkcePair = namedtuple('PkcePair', ['verifier', 'challenge'])
-       return PkcePair(verifier.decode('ascii'), challenge.decode('ascii'))
-
-class AuthError(Exception):
-       pass
-
-class AuthHandler(urllib.request.BaseHandler):
-       # Precede the HTTP error handler
-       handler_order = 750
-
-       def __init__(self, storage_mgr, api_iface):
-               self.storage_mgr = storage_mgr
-               self.api_iface = api_iface
-
-       @staticmethod
-       def _set_auth_header(request, token_doc):
-               request.add_unredirected_header(
-                       'Authorization', 'Bearer {}'.format(token_doc['access_token'])
-               )
-
-       def _handle(self, request):
-               token_doc = self.storage_mgr.get('authInfo')
-               if not token_doc:
-                       LOGGER.info("No stored access token. Requesting a new token.")
-                       token_doc = get_access_token(self.storage_mgr, self.api_iface)
-                       self.storage_mgr['authInfo'] = token_doc
-               self._set_auth_header(request, token_doc)
-               return request
-
-       http_request = _handle
-       https_request = _handle
-
-       def http_error_401(self, request, fp, code, msg, headers):
-               LOGGER.info("Access token expired or revoked. Requesting a new token.")
-               token_doc = get_access_token(self.storage_mgr, self.api_iface)
-               self.storage_mgr['authInfo'] = token_doc
-               self._set_auth_header(request, token_doc)
-               return self.parent.open(request, timeout = request.timeout)
-
-def get_access_token(storage_mgr, api_iface):
-       import json
-       import queue
-       from threading import Thread
-       import webbrowser
-
-       client_id = storage_mgr['clientId']
-       client_secret = storage_mgr['clientSecret']
-       payload = base64.b64encode(
-               (client_id + ':' + client_secret).encode('utf-8')
-       ).decode('ascii')
-
-       handlers = [
-               urllib.request.HTTPSHandler(),
-               util.HeaderAddHandler([('Authorization', 'Basic {}'.format(payload))])
-       ]
-       opener = urllib.request.OpenerDirector()
-       for handler in handlers:
-               opener.add_handler(handler)
-
-       token_doc = None
-       if 'authInfo' in storage_mgr:
-               refresh_token = storage_mgr['authInfo']['refresh_token']
-               params = {
-                       'grant_type': 'refresh_token',
-                       'refresh_token': refresh_token,
-               }
-               request = urllib.request.Request(
-                       url = api_iface.build_api_url('v1/oauth2/token'),
-                       data = urllib.parse.urlencode(params).encode('utf-8'),
-                       method = 'POST',
-               )
-               with opener.open(request) as resp:
-                       body = json.loads(util.decode_body(resp))
-               if resp.status == 200:
-                       token_doc = body
-               else:
-                       LOGGER.info("Stored refresh token rejected. Obtaining new authorization code.")
-                       assert resp.status == 400
-
-       if token_doc is None:
-               redirect_uri = 'http://' + ADDRESS + ':' + str(PORT) + ENDPOINT
-               pkce_pair = _make_pkce_pair()
-               csrf_token = base64.urlsafe_b64encode(os.urandom(24)).decode('ascii')
-               params = {
-                       'redirect_uri': redirect_uri,
-                       'response_type': 'code',
-                       'client_id': client_id,
-                       'scope': 'read offline',
-                       'state': csrf_token,
-                       'code_challenge': pkce_pair.challenge,
-                       'code_challenge_method': 'S256',
-               }
-               user_url = api_iface.build_api_url(
-                       endpoint = 'v1/oauth2/authorizationcode',
-                       query = urllib.parse.urlencode(params),
-               )
-
-               channel = queue.Queue()
-               server = http.server.HTTPServer(
-                       (ADDRESS, PORT),
-                       AuthCodeRequestHandler(channel, csrf_token)
-               )
-               Thread(target = server.serve_forever, daemon = True).start()
-
-               LOGGER.info("Attempting to launch a web browser to authorize the application…")
-               if not webbrowser.open(user_url):
-                       LOGGER.info("Failed to launch a browser. Please visit this URL to authorize the application:")
-                       LOGGER.info("    {}".format(user_url))
-
-               try:
-                       resp = channel.get(timeout = TIMEOUT.total_seconds())
-               except queue.Empty:
-                       resp = None
-               server.shutdown()
-
-               if resp is None:
-                       raise AuthError("Authorization timed out. Please try again.")
-               elif isinstance(resp, Exception):
-                       raise resp
-               else:
-                       auth_code = resp
-                       params = {
-                               'grant_type': 'authorization_code',
-                               'code': auth_code,
-                               'code_verifier': pkce_pair.verifier,
-                               'redirect_uri': redirect_uri,
-                       }
-                       request = urllib.request.Request(
-                               url = api_iface.build_api_url('v1/oauth2/token'),
-                               data = urllib.parse.urlencode(params).encode('utf-8'),
-                               method = 'POST',
-                       )
-                       with opener.open(request) as resp:
-                               assert resp.status == 200
-                               token_doc = json.loads(util.decode_body(resp))
-
-       return token_doc
diff --git a/util.py b/util.py
index e2cd9ddb8effe478c567c0467e89f737c7742353..d1b86cdde18127fa34682d010a6ec63e99e1c045 100644 (file)
--- a/util.py
+++ b/util.py
@@ -15,6 +15,7 @@
 
 import urllib.request
 
+
 class HeaderAddHandler(urllib.request.BaseHandler):
        def __init__(self, headers):
                self.headers = headers
@@ -27,6 +28,19 @@ class HeaderAddHandler(urllib.request.BaseHandler):
        http_request = _handle
        https_request = _handle
 
+
+def resolve(target, curr_url):
+       # I hope I do this like a browser...
+       parsed = urllib.parse.urlparse(target)
+       if parsed.scheme:
+               return target
+       elif target.startswith('/'):
+               curr = urllib.parse.urlparse(curr_url)
+               return curr.scheme + '://' + curr.netloc + target
+       else:
+               raise NotImplementedError("relative URI")
+
+
 def decode_body(resp):
        def get_charset():
                if 'Content-Type' in resp.info():