WIP
author     Jakob <jakob@jcornell.net>    Sat, 23 Nov 2019 01:24:27 +0000 (19:24 -0600)
committer  Jakob <jakob@jcornell.net>    Sat, 23 Nov 2019 01:24:27 +0000 (19:24 -0600)
.gitignore [new file with mode: 0644]
api.py [new file with mode: 0644]
common.py [new file with mode: 0644]
fs.py [new file with mode: 0644]
main.py
oauth.py
util.py

diff --git a/.gitignore b/.gitignore
new file mode 100644 (file)
index 0000000..17ca792
--- /dev/null
@@ -0,0 +1,4 @@
+/config.ini
+/saved_state
+
+__pycache__/
diff --git a/api.py b/api.py
new file mode 100644 (file)
index 0000000..78ecbce
--- /dev/null
+++ b/api.py
@@ -0,0 +1,97 @@
+# Copyright 2019, Jakob Cornell
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+# 
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+# 
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+import json
+from pathlib import PurePosixPath
+import urllib.parse
+import urllib.request
+
+import oauth
+import util
+
+def walk_pages(opener, url):
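+       """
+       Yield the items in each page's `results' array, following the API's
+       `paging.nextPage' URL until no further page is reported.
+       """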
+       while True:
+               with opener.open(url) as resp:
+                       data = json.loads(util.decode_body(resp))
+                       yield from data['results']
+                       if 'paging' in data and 'nextPage' in data['paging']:
+                               url = data['paging']['nextPage']
+                       else:
+                               return
+
+class ApiInterface:
+       def __init__(self, host, path, storage_mgr):
+               self.host = host
+               self.path = path
+               self.opener = urllib.request.build_opener(
+                       oauth.AuthHandler(storage_mgr, self),
+               )
+
+       def build_api_url(self, endpoint, query = ''):
+               return urllib.parse.urlunparse(urllib.parse.ParseResult(
+                       scheme = 'https',
+                       netloc = self.host,
+                       path = str(self.path.joinpath(endpoint)),
+                       query = query,
+                       params = '',
+                       fragment = '',
+               ))
+
+       def get_content_path(self, course_id, content_id):
+               from collections import deque
+
+               def get_info(content_id):
+                       url = self.build_api_url(
+                               PurePosixPath('v1', 'courses', course_id, 'contents', content_id)
+                       )
+                       with self.opener.open(url) as resp:
+                               content_doc = json.loads(util.decode_body(resp))
+                       return content_doc
+
+               curr_content = get_info(content_id)
+               path = deque()
+               while 'parentId' in curr_content:
+                       path.appendleft(curr_content)
+                       curr_content = get_info(curr_content['parentId'])
+
+               return list(path)
+
+       def get_children(self, course_id, content_id):
+               url = self.build_api_url(
+                       PurePosixPath('v1', 'courses', course_id, 'contents', content_id, 'children'),
+               )
+               return walk_pages(self.opener, url)
+
+       def get_attachments(self, course_id, content_id):
+               url = self.build_api_url(
+                       PurePosixPath('v1', 'courses', course_id, 'contents', content_id, 'attachments'),
+               )
+               return walk_pages(self.opener, url)
+
+       def download_attachment(self, course_id, content_id, attachment_id):
+               url = self.build_api_url(
+                       PurePosixPath(
+                               'v1',
+                               'courses',
+                               course_id,
+                               'contents',
+                               content_id,
+                               'attachments',
+                               attachment_id,
+                               'download',
+                       )
+               )
+               return self.opener.open(url)
diff --git a/common.py b/common.py
new file mode 100644 (file)
index 0000000..7dca10d
--- /dev/null
+++ b/common.py
@@ -0,0 +1,73 @@
+# Copyright 2019, Jakob Cornell
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+# 
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+# 
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+import json
+import logging
+from pathlib import PurePosixPath
+
+API_HOST = 'oberlin-test.blackboard.com'
+API_PATH = PurePosixPath('/learn/api/public')
+
+LOGGER = logging.getLogger('bb-sync-api')
+logging.basicConfig()
+LOGGER.setLevel(logging.INFO)
+
+class StorageManager:
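+       """
+       A small JSON-file-backed key/value store. Writes are buffered in an
+       in-memory patch; reads prefer the patch and otherwise re-read the backing
+       file. On context manager exit, the file's current contents are merged with
+       the patch and written back out.
+       """
+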
+       def __init__(self, path):
+               self.path = path
+               self.patch = {}
+               self._invalidate()
+
+       def __enter__(self):
+               return self
+
+       def __exit__(self, *_):
+               self._invalidate()
+               self.cache.update(self.patch)
+               with self.path.open('w') as f:
+                       json.dump(self.cache, f)
+
+       def _invalidate(self):
+               try:
+                       with self.path.open() as f:
+                               self.cache = json.load(f)
+               except FileNotFoundError:
+                       self.cache = {}
+
+       def __getitem__(self, key):
+               if key in self.patch:
+                       return self.patch[key]
+               else:
+                       self._invalidate()
+                       return self.cache[key]
+
+       def get(self, key):
+               if key in self.patch:
+                       return self.patch[key]
+               else:
+                       self._invalidate()
+                       return self.cache.get(key)
+
+       def __contains__(self, key):
+               if key in self.patch:
+                       return True
+               else:
+                       self._invalidate()
+                       return key in self.cache
+
+       def __setitem__(self, key, value):
+               self.patch[key] = value
diff --git a/fs.py b/fs.py
new file mode 100644 (file)
index 0000000..ea8d04b
--- /dev/null
+++ b/fs.py
@@ -0,0 +1,173 @@
+# Copyright 2018-2019, Jakob Cornell
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+# 
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+# 
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+from collections import namedtuple
+import json
+import pathlib
+import re
+
+
+class WinPath(pathlib.WindowsPath):
+       r"""
+       This constructs a path for an NTFS alternate data stream on Windows.
+       In most cases, for directories, a trailing slash is permitted and optional:
+
+               - `C:/dir/dir2:stream'
+               - `C:/dir/dir2/:stream'
+
+       The exception is volume root paths. `C:/:stream' is fine, but Windows rejects
+       `C::stream'. For `.' we also need to form `./:stream', not `.:stream'.
+       Whether the `name' attribute is empty seems to be a sufficient indicator of
+       which case applies.
+
+       The other complication is that Windows allows relative paths with drive parts,
+       like `C:some/path', and that is what `joinpath' produces when given `C:'. So
+       we manually add a separator to ensure the resulting path actually points to
+       an ADS rather than to such a relative path.
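+
+       A rough usage sketch (hypothetical paths; the results shown assume the
+       standard pathlib behavior described above):
+
+               WinPath(r'C:\dir\dir2').get_ads('meta')  # WindowsPath('C:/dir/dir2:meta')
+               WinPath('C:/').get_ads('meta')           # WindowsPath('C:/:meta')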
+       """
+       def get_ads(self, name):
+               full_name = self.name + ':' + name
+               if self.name:
+                       return self.with_name(full_name)
+               else:
+                       return self.joinpath('/', full_name)
+
+
+def clean_win_path(seg):
+       # https://docs.microsoft.com/en-us/windows/desktop/FileIO/naming-a-file
+       bad = {*range(0, 31 + 1), *map(ord, r'<>:"/\|?*')}
+       return seg.translate({ch: ' ' for ch in bad})
+
+
+def _split_name(name):
+       suff_len = sum(len(s) for s in pathlib.Path(name).suffixes)
+       stem = name[slice(None, len(name) - suff_len)]
+       suff = name[slice(len(name) - suff_len, None)]
+       return (stem, suff)
+
+
+def content_path(course_path, segments):
+       path = course_path
+       for (id_, name) in segments:
+               if path.exists():
+                       cands = [child for child in path.iterdir() if re.search(r'\({}\)$'.format(re.escape(id_)), child.name)]
+                       if cands:
+                               [path] = cands
+                               continue
+               path = path.joinpath(clean_win_path('{}({})'.format(name, id_)))
+       return path
+
+
+BB_META_STREAM_NAME = '8f3b98ea-e227-478f-bb58-5c31db476409'
+
+
+VersionInfo = namedtuple('VersionInfo', ['bb_id', 'version'])
+VersionInfo.with_version = lambda self, v: VersionInfo(self.bb_id, v)
+
+
+def _extract_version(path):
+       if not (path.is_file() or path.is_dir()):
+               return None
+
+       info = None
+       stream_path = WinPath(path).get_ads(BB_META_STREAM_NAME)
+       if stream_path.exists():
+               # NTFS ADS metadata
+               with stream_path.open() as f:
+                       try:
+                               metadata = json.load(f)
+                       except json.decoder.JSONDecodeError:
+                               pass
+                       else:
+                               bb_id = metadata.get('bbId')
+                               version = metadata.get('version')
+                               version_typecheck = lambda v: isinstance(v, int) if path.is_file() else v is None
+                               if isinstance(bb_id, str) and version_typecheck(version):
+                                       info = VersionInfo(bb_id, version)
+       else:
+               # old in-filename metadata
+               (stem, _) = _split_name(path.name)
+               match = re.search(r'\((?P<id>[\d_]+)(?:,(?P<version>\d+))?\)$', stem)
+               if match:
+                       version = int(match.group('version')) if match.group('version') else None
+                       if (version is None) == path.is_dir():
+                               info = VersionInfo(match.group('id'), version)
+
+       return info
+
+
+def get_child_versions(path):
+       from collections import defaultdict
+
+       results = defaultdict(set)
+       for child in path.iterdir():
+               info = _extract_version(child)
+               if info:
+                       results[info].add(child)
+       return results
+
+
+# https://docs.microsoft.com/en-us/windows/desktop/FileIO/naming-a-file
+def sanitize(seg):
+       bad = {*map(chr, range(0, 31 + 1)), *r'<>:"/\|?*'}
+       return seg.translate({ch: '!' for ch in bad})
+
+
+def join_content_path(path, content_doc):
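+       """
+       Return the child of `path' that mirrors the given content item, matched by
+       the Blackboard id recorded in its metadata, creating a new directory for
+       the item if no match exists.
+       """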
+       ver_map = get_child_versions(path)
+       versions = [v for v in ver_map if v.bb_id == content_doc['id']]
+       if versions:
+               [info] = versions
+               [child_path] = ver_map[info]
+               return child_path
+       else:
+               return set_up_new_dir(
+                       path,
+                       content_doc['title'],
+                       VersionInfo(bb_id = content_doc['id'], version = None),
+               )
+
+
+def set_up_new_file(*args):
+       return _set_up_new_path(
+               *args,
+               creator = lambda path: path.open('x').close()
+       )
+
+def set_up_new_dir(*args):
+       return _set_up_new_path(*args, creator = pathlib.Path.mkdir)
+
+def _set_up_new_path(parent, base_name, version_info, creator):
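+       """
+       Create a file or directory for `base_name' under `parent', appending
+       `(0)', `(1)', ... to the stem as needed to avoid name collisions, then
+       record `version_info' in the new path's metadata stream.
+       """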
+       import itertools
+
+       sanitized = sanitize(base_name)
+       (stem, exts) = _split_name(sanitized)
+       def path_for(ordinal):
+               if ordinal is None:
+                       seg = sanitized
+               else:
+                       seg = '{stem}({ordinal}){exts}'.format(stem=stem, ordinal=ordinal, exts=exts)
+                       assert seg == sanitize(seg)
+               return parent.joinpath(seg)
+
+       path = next(
+               path
+               for path in map(path_for, itertools.chain([None], itertools.count()))
+               if not path.exists()
+       )
+
+       creator(path)
+       with WinPath(path).get_ads(BB_META_STREAM_NAME).open('x') as f:
+               json.dump({'bbId': version_info.bb_id, 'version': version_info.version}, f)
+       return path
diff --git a/main.py b/main.py
index f46f06f30e012b48ca20f6f86a380d5e7f34539c..06ce2acd0128436922195bdcf155d4be541632a8 100644 (file)
--- a/main.py
+++ b/main.py
@@ -1,4 +1,4 @@
-# Copyright 2019 Jakob Cornell
+# Copyright 2019, Jakob Cornell
 
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
-import urllib.request
-from pathlib import Path
-import tempfile
-from collections import deque
-import shutil
+import configparser
 import filecmp
+import functools
 import json
+from operator import attrgetter
+from pathlib import Path
+import shutil
+import sys
+import tempfile
+import urllib.parse
+import urllib.request
 
-import toml
+from common import *
+import oauth
+import api
+import fs
 
-HOST = 'blackboard.oberlin.edu'
-API_PATH = '/learn/api/public/v1'
+r"""
+Example configuration file (Windows):
 
-def get_uri(path):
-       return 'https://' + HOST + API_PATH + path
+       [config]
+       base_path: C:\Users\Jakob\blackboard\math_course
 
-def walk_pages(uri):
-       while True:
-               with urllib.request.urlopen(uri) as resp:
-                       data = json.load(resp)
-                       yield from data['results']
-                       if 'paging' in data and 'nextPage' in data['paging']:
-                               uri = data['paging']['nextPage']
-                       else:
-                               return
-
-def get_path(course, leaf_id):
-       path = deque()
-       id_ = leaf_id
-       while True:
-               with urllib.request.urlopen(get_uri('/courses/{}/contents/{}'.format(course, leaf_id))) as resp:
-                       data = json.load(resp)
-               path.appendleft('{} ({})'.format(data['title'], id_))
-               if 'parentId' in data:
-                       id_ = data['parentId']
-               else:
-                       return list(path)
-
-# https://docs.microsoft.com/en-us/windows/desktop/FileIO/naming-a-file
-def sanitize(seg):
-       bad = {*map(chr, range(0, 31 + 1)), *r'<>:"/\|?*'}
-       return seg.translate({ch: None for ch in bad})
-
-# parse directory contents into mirror state for individual item
-def read_metadata(path):
-       info = {}
-       
-
-try:
-       with open('config.toml') as f:
-               config = toml.load(f)
-except OSError:
-       print("Cannot open configuration file `config.toml`:", file = sys.stderr)
-       raise
-if 'base_path' not in config:
-       print("`base_path` not in config file")
-       raise KeyError()
+"""
 
-import sys
+cfg_parser = configparser.ConfigParser()
+with open('config.ini') as f:
+       cfg_parser.read_file(f)
+cfg_section = cfg_parser['config']
 
-args = sys.argv[1:]
-if len(args) != 1:
-       print("Please pass the Blackboard URL to sync as an argument", file = sys.stderr)
-       raise AssertionError()
+with StorageManager(Path('saved_state')) as storage_mgr:
+       api_iface = api.ApiInterface(API_HOST, API_PATH, storage_mgr)
 
-url = args[0]
-try:
+       url = input("Blackboard URL: ")
        params = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
-except ValueError:
-       print("That URL doesn't look right:", file = sys.stderr)
-       raise
-
-if 'course_id' not in params or 'content_id' not in params:
-       print("That URL doesn't look right.", file = sys.stderr)
-       raise ValueError()
-course = params['course_id']
-page = params['content_id']
-base_path = Path(config['base_path'], *map(sanitize, get_path(page)))
-
-for item in walk_pages(get_uri('/courses/{}/contents/{}/children'.format(course, page))):
-       for attachment in walk_pages(get_uri('/courses/{}/contents/{}/attachments'.format(course, item['id']))):
-               dir_ = base_path.joinpath(sanitize('{} ({})'.format(item['title'], item['id'])))
-               orig_name = sanitize('{} ({})'.format(attachment['fileName'], attachment['id']))
-               dir_.mkdir(parents = True, exist_ok = True)
-               with tempfile.NamedTemporaryFile(prefix = orig_name, dir = str(dir_), delete = False) as temp:
-                       with urllib.request.urlopen(
-                               '/courses/{}/contents/{}/attachments/{}/download'.format(course, item['id'], attachment['id'])
-                       ) as resp:
-                               shutil.copyfileobj(resp, temp)
-                       temp_name = temp.name
-               orig = dir_.joinpath(orig_name)
-               temp = dir_.joinpath(temp_name)
-               if dir_.joinpath(name).exists():
-                       if not filecmp.cmp(str(dir_.joinpath(temp_name)), str(dir_.joinpath(name)), shallow = False):
-                               print("")
-               else:
-                       shutil.move(str(temp), str(orig))
+       course_ids = params.get('course_id')
+       content_ids = params.get('content_id')
+       if not (course_ids and content_ids):
+               url_ok = False
+       else:
+               course_ids = set(course_ids)
+               content_ids = set(content_ids)
+               url_ok = len(course_ids) == len(content_ids) == 1
+       if not url_ok:
+               raise ValueError("Unrecognized URL")
+       [course_id] = course_ids
+       [content_id] = content_ids
+
+       local_course_root = fs.WinPath(cfg_section['base_path'])
+       meta_path = local_course_root.get_ads(fs.BB_META_STREAM_NAME)
+       if meta_path.is_file():
+               with meta_path.open() as f:
+                       try:
+                               metadata = json.load(f)
+                       except json.decoder.JSONDecodeError:
+                               found_course_id = None
+                       else:
+                               found_course_id = metadata.get('course_id')
+       else:
+               found_course_id = None
+       if found_course_id:
+               if found_course_id != course_id:
+                       LOGGER.warning("Using a course root previously used for another course!")
+                       LOGGER.warning("File versioning may misbehave.")
+       else:
+               # 'w', not 'x': the metadata stream may exist but hold unusable JSON
+               with meta_path.open('w') as f:
+                       json.dump({'course_id': course_id}, f)
+
+       content_path = api_iface.get_content_path(course_id, content_id)
+       LOGGER.info("Blackboard content path: {}".format('/'.join(seg['id'] for seg in content_path)))
+       local_content_root = functools.reduce(fs.join_content_path, content_path, local_course_root)
+       LOGGER.info("Local content path: {}".format(local_content_root))
+
+       for child_doc in api_iface.get_children(course_id, content_id):
+               LOGGER.info("Processing content item {id}: \"{title}\"".format(**child_doc))
+               local_path = fs.join_content_path(local_content_root, child_doc)
+               versions = fs.get_child_versions(local_path)
+
+               for attachment_doc in api_iface.get_attachments(course_id, child_doc['id']):
+                       att_id = attachment_doc['id']
+
+                       LOGGER.info("  Checking attachment {id}: \"{fileName}\"".format(**attachment_doc))
+                       with api_iface.download_attachment(course_id, child_doc['id'], att_id) as resp:
+                               with tempfile.NamedTemporaryFile(delete = False) as tmp:
+                                       tmp_path = Path(tmp.name)
+                                       shutil.copyfileobj(resp, tmp)
+
+                       my_versions = [info for info in versions.keys() if info.bb_id == att_id]
+                       if my_versions:
+                               latest = max(my_versions, key = attrgetter('version'))
+                               [latest_path] = versions[latest]
+                               match = filecmp.cmp(str(tmp_path), str(latest_path), shallow = False)
+                       else:
+                               match = None
+
+                       if match is True:
+                               tmp_path.unlink()
+                       else:
+                               if match is False:
+                                       new_version = latest.with_version(latest.version + 1)
+                                       LOGGER.info("    Found new revision ({})".format(new_version.version))
+                               else:
+                                       new_version = fs.VersionInfo(att_id, 0)
+                                       LOGGER.info("    Storing initial revision")
+                               dest = fs.set_up_new_file(local_path, attachment_doc['fileName'], new_version)
+                               LOGGER.info("    Destination: {}".format(dest))
+                               tmp_path.replace(dest)
diff --git a/oauth.py b/oauth.py
index 6aeb982abbdd927544841fce5866b7c3db475cd5..6996b2a9474731de41e1ea7fff2e0ece28cb47dc 100644 (file)
--- a/oauth.py
+++ b/oauth.py
@@ -1,4 +1,4 @@
-# Copyright 2019 Jakob Cornell
+# Copyright 2019, Jakob Cornell
 
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
+import base64
+from datetime import timedelta
 import http.server
-from queue import Queue
-from threading import Thread
+import os
+from pathlib import PurePosixPath
 import urllib.parse
-import webbrowser
+import urllib.request
+
+import common
+import util
 
 ADDRESS = 'localhost'
 PORT = 1081
+ENDPOINT = '/redirection_endpoint'
+
+TIMEOUT = timedelta(minutes = 30)
+
+LOGGER = common.LOGGER.getChild('oauth')
+
+class AuthCodeRequestHandler(http.server.BaseHTTPRequestHandler):
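+       """
+       Handle the OAuth redirect: check the `state' CSRF token, capture the
+       authorization code, send the browser to a small landing page, and pass
+       the code (or an AuthError) back to the caller through `channel'.
+       """
+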
+       LANDING_PATH = '/landing'
 
-def _get_handler(queue):
-       # Why does your API require me to use such hacks?
+       def __init__(self, channel, csrf_token):
+               self.channel = channel
+               self.csrf_token = csrf_token
 
-       class AuthCodeRequestHandler(http.server.BaseHTTPRequestHandler):
-               def do_GET(self):
-                       self.send_response(200)
+               self.status = None
+               self.message = None
+               self.caller_val = None
+
+       def __call__(self, *args):
+               # Yep, `socketserver' invokes the handler by calling it, and the superclass's
+               # code for dispatching requests to the appropriate methods is in the constructor.
+               # If you're confused, it's because it doesn't make any sense.
+               super().__init__(*args)
+
+       def do_GET(self):
+               url = urllib.parse.urlparse(self.path)
+               params = urllib.parse.parse_qs(url.query)
+               path = PurePosixPath(url.path)
+               if path == PurePosixPath(ENDPOINT):
+                       REQD_PARAMS = {'code', 'state'}
+                       if params.keys() == REQD_PARAMS and all(len(params[k]) == 1 for k in REQD_PARAMS):
+                               [code] = params['code']
+                               [token] = params['state']
+                               if token == self.csrf_token:
+                                       self.status = 200
+                                       self.message = "Success! You may close this page."
+                                       self._redirect()
+                                       self.caller_val = code
+                               else:
+                                       self.status = 403
+                                       self.message = "CSRF token check failed"
+                                       self._redirect()
+                       else:
+                               self.status = 400
+                               self.message = "Auth server redirect missing required parameters"
+                               self._redirect()
+                               self.caller_val = AuthError(self.message)
+               elif path == PurePosixPath(self.LANDING_PATH):
+                       self.send_response(self.status)
                        self.end_headers()
-                       self.wfile.write(self.path.encode('ascii'))
-                       queue.put(None)
-
-               def log_message(self, *_):
-                       pass
-
-       return AuthCodeRequestHandler
-
-def get_auth_code(bb_host, api_path, client_id):
-       redirect_uri = 'http://' + ADDRESS + ':' + str(PORT) + '/'
-       params = {
-               'redirect_uri': redirect_uri,
-               'response_type': 'code',
-               'client_id': client_id,
-       }
-       qs = urllib.parse.urlencode(params)
-       user_url = urllib.parse.urlunparse(urllib.parse.ParseResult(
-               scheme = 'https',
-               netloc = bb_host,
-               path = str(api_path.joinpath('oauth2/authorizationcode')),
-               query = qs,
-               params = '',
-               fragment = '',
-       ))
-       webbrowser.open(user_url)
-
-       queue = Queue()
-       server = http.server.HTTPServer((ADDRESS, PORT), _get_handler(queue))
-       Thread(target = lambda: server.serve_forever()).start()
-       queue.get()
-       server.shutdown()
-
-import pathlib
-get_auth_code('oberlin-test.blackboard.com', pathlib.PurePosixPath('/learn/api/public/v1'))
+                       self.wfile.write(self.message.encode('ascii'))
+                       if self.caller_val is not None:
+                               self.channel.put(self.caller_val)
+               else:
+                       self.send_response(404)
+                       self.end_headers()
+                       self.wfile.write('Not Found'.encode('ascii'))
+
+       def _redirect(self):
+               self.send_response(302)
+               self.send_header('Location', self.LANDING_PATH)
+               self.end_headers()
+
+       def log_message(self, *_):
+               pass
+
+def _make_pkce_pair():
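+       """
+       Generate a PKCE verifier/challenge pair (RFC 7636, `S256' method): the
+       challenge is the unpadded URL-safe base64 encoding of the SHA-256 digest
+       of the verifier.
+       """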
+       from collections import namedtuple
+       import hashlib
+
+       verifier = base64.urlsafe_b64encode(os.urandom(32)).strip(b'=')
+       hasher = hashlib.sha256()
+       hasher.update(verifier)
+       challenge = base64.urlsafe_b64encode(hasher.digest()).strip(b'=')
+
+       PkcePair = namedtuple('PkcePair', ['verifier', 'challenge'])
+       return PkcePair(verifier.decode('ascii'), challenge.decode('ascii'))
+
+class AuthError(Exception):
+       pass
+
+class AuthHandler(urllib.request.BaseHandler):
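+       """
+       urllib handler that attaches the stored bearer token to outgoing requests
+       and, on a 401 response, obtains a fresh token and retries the request.
+       """
+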
+       # Precede the HTTP error handler
+       handler_order = 750
+
+       def __init__(self, storage_mgr, api_iface):
+               self.storage_mgr = storage_mgr
+               self.api_iface = api_iface
+
+       @staticmethod
+       def _set_auth_header(request, token_doc):
+               request.add_unredirected_header(
+                       'Authorization', 'Bearer {}'.format(token_doc['access_token'])
+               )
+
+       def _handle(self, request):
+               token_doc = self.storage_mgr.get('authInfo')
+               if not token_doc:
+                       LOGGER.info("No stored access token. Requesting a new token.")
+                       token_doc = get_access_token(self.storage_mgr, self.api_iface)
+                       self.storage_mgr['authInfo'] = token_doc
+               self._set_auth_header(request, token_doc)
+               return request
+
+       http_request = _handle
+       https_request = _handle
+
+       def http_error_401(self, request, fp, code, msg, headers):
+               LOGGER.info("Access token expired or revoked. Requesting a new token.")
+               token_doc = get_access_token(self.storage_mgr, self.api_iface)
+               self.storage_mgr['authInfo'] = token_doc
+               self._set_auth_header(request, token_doc)
+               return self.parent.open(request, timeout = request.timeout)
+
+def get_access_token(storage_mgr, api_iface):
+       import json
+       import queue
+       from threading import Thread
+       import webbrowser
+
+       client_id = storage_mgr['clientId']
+       client_secret = storage_mgr['clientSecret']
+       payload = base64.b64encode(
+               (client_id + ':' + client_secret).encode('utf-8')
+       ).decode('ascii')
+
+       handlers = [
+               urllib.request.HTTPSHandler(),
+               util.HeaderAddHandler([('Authorization', 'Basic {}'.format(payload))])
+       ]
+       opener = urllib.request.OpenerDirector()
+       for handler in handlers:
+               opener.add_handler(handler)
+
+       token_doc = None
+       if 'authInfo' in storage_mgr:
+               refresh_token = storage_mgr['authInfo']['refresh_token']
+               params = {
+                       'grant_type': 'refresh_token',
+                       'refresh_token': refresh_token,
+               }
+               request = urllib.request.Request(
+                       url = api_iface.build_api_url('v1/oauth2/token'),
+                       data = urllib.parse.urlencode(params).encode('utf-8'),
+                       method = 'POST',
+               )
+               with opener.open(request) as resp:
+                       body = json.loads(util.decode_body(resp))
+               if resp.status == 200:
+                       token_doc = body
+               else:
+                       LOGGER.info("Stored refresh token rejected. Obtaining new authorization code.")
+                       assert resp.status == 400
+
+       if token_doc is None:
+               redirect_uri = 'http://' + ADDRESS + ':' + str(PORT) + ENDPOINT
+               pkce_pair = _make_pkce_pair()
+               csrf_token = base64.urlsafe_b64encode(os.urandom(24)).decode('ascii')
+               params = {
+                       'redirect_uri': redirect_uri,
+                       'response_type': 'code',
+                       'client_id': client_id,
+                       'scope': 'read offline',
+                       'state': csrf_token,
+                       'code_challenge': pkce_pair.challenge,
+                       'code_challenge_method': 'S256',
+               }
+               user_url = api_iface.build_api_url(
+                       endpoint = 'v1/oauth2/authorizationcode',
+                       query = urllib.parse.urlencode(params),
+               )
+
+               channel = queue.Queue()
+               server = http.server.HTTPServer(
+                       (ADDRESS, PORT),
+                       AuthCodeRequestHandler(channel, csrf_token)
+               )
+               Thread(target = server.serve_forever, daemon = True).start()
+
+               LOGGER.info("Attempting to launch a web browser to authorize the application…")
+               if not webbrowser.open(user_url):
+                       LOGGER.info("Failed to launch a browser. Please visit this URL to authorize the application:")
+                       LOGGER.info("    {}".format(user_url))
+
+               try:
+                       resp = channel.get(timeout = TIMEOUT.total_seconds())
+               except queue.Empty:
+                       resp = None
+               server.shutdown()
+
+               if resp is None:
+                       raise AuthError("Authorization timed out. Please try again.")
+               elif isinstance(resp, Exception):
+                       raise resp
+               else:
+                       auth_code = resp
+                       params = {
+                               'grant_type': 'authorization_code',
+                               'code': auth_code,
+                               'code_verifier': pkce_pair.verifier,
+                               'redirect_uri': redirect_uri,
+                       }
+                       request = urllib.request.Request(
+                               url = api_iface.build_api_url('v1/oauth2/token'),
+                               data = urllib.parse.urlencode(params).encode('utf-8'),
+                               method = 'POST',
+                       )
+                       with opener.open(request) as resp:
+                               assert resp.status == 200
+                               token_doc = json.loads(util.decode_body(resp))
+
+       return token_doc
diff --git a/util.py b/util.py
index 9ce2138007806e1dd9e512a95c06b7eff4a369f3..e2cd9ddb8effe478c567c0467e89f737c7742353 100644 (file)
--- a/util.py
+++ b/util.py
@@ -1,4 +1,4 @@
-# Copyright 2018, Anders Cornell and Jakob Cornell
+# Copyright 2018-2019, Jakob Cornell
 
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
 # You should have received a copy of the GNU General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 
-from collections import namedtuple
-import itertools
-import logging
-import pathlib
-import re
-import urllib.parse
+import urllib.request
 
-LOGGER_NAME = 'bb_sync_api'
+class HeaderAddHandler(urllib.request.BaseHandler):
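+       """
+       urllib handler that adds the given (name, value) header pairs to every
+       outgoing request.
+       """
+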
+       def __init__(self, headers):
+               self.headers = headers
 
-def resolve(target, curr_url):
-       # I hope I do this like a browser...
-       parsed = urllib.parse.urlparse(target)
-       if parsed.scheme:
-               return target
-       elif target.startswith('/'):
-               curr = urllib.parse.urlparse(curr_url)
-               return curr.scheme + '://' + curr.netloc + target
-       else:
-               raise NotImplementedError("relative URI")
+       def _handle(self, request):
+               for name, value in self.headers:
+                       request.add_unredirected_header(name, value)
+               return request
 
-def extract_ids(url):
-       qs = urllib.parse.parse_qs(urllib.parse.urlparse(url).query)
-       [course_id] = qs['course_id']
-       [content_id] = qs['content_id']
-       return {
-               'course': course_id.strip('_'),
-               'content': content_id.strip('_'),
-       }
+       http_request = _handle
+       https_request = _handle
 
-def clean_win_path(seg):
-       # https://docs.microsoft.com/en-us/windows/desktop/FileIO/naming-a-file
-       bad = {*range(0, 31 + 1), *map(ord, r'<>:"/\|?*')}
-       return seg.translate({ch: ' ' for ch in bad})
-
-def _split_name(name):
-       suff_len = sum(len(s) for s in pathlib.Path(name).suffixes)
-       stem = name[slice(None, len(name) - suff_len)]
-       suff = name[slice(len(name) - suff_len, None)]
-       return (stem, suff)
-
-def content_path(course_path, segments):
-       path = course_path
-       for (id_, name) in segments:
-               if path.exists():
-                       cands = [child for child in path.iterdir() if re.search(r'\({}\)$'.format(re.escape(id_)), child.name)]
-                       if cands:
-                               [path] = cands
-                               continue
-               path = path.joinpath(clean_win_path('{}({})'.format(name, id_)))
-       return path
-
-_BB_ID_STREAM_NAME = '8f3b98ea-e227-478f-bb58-5c31db476409'
-
-ParseResult = namedtuple('ParseResult', ['id_', 'version'])
-
-def _parse_path(path):
-       (stem, _) = _split_name(path.name)
-       match = re.search(r'\((?P<id>[\d_]+,)?(?P<version>\d+)\)$', stem)
-       if match:
-               stream_path = path.with_name(path.name + ':' + _BB_ID_STREAM_NAME)
-               if stream_path.exists():
-                       with stream_path.open() as f:
-                               id_ = f.read()
+def decode_body(resp):
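+       """
+       Decode an HTTP response body using the charset from its Content-Type
+       header, falling back to UTF-8.
+       """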
+       def get_charset():
+               if 'Content-Type' in resp.info():
+                       import cgi
+                       _, options = cgi.parse_header(resp.info()['Content-Type'])
+                       return options.get('charset')
                else:
-                       id_ = match.group('id')
-                       assert id_ is not None
-               return ParseResult(id_ = id_, version = match.group('version'))
-       else:
-               return None
-
-def unparse_path(parse_result):
-
-def get_latest_versions(content_path):
-       results = {}
-       for path in content_path.iterdir():
-               result = _parse_path(path)
-               if result and (result.id_ not in results or results[result._id] < result.version):
-                       results[result.id_] = result.version
-       return results
+                       return None
+       return resp.read().decode(get_charset() or 'utf-8')