From a050b7db0de646aa681a5bad7aa24d9a169f7903 Mon Sep 17 00:00:00 2001 From: Jakob Cornell Date: Fri, 3 Jan 2020 13:13:48 -0600 Subject: [PATCH] Initial commit --- .gitignore | 1 + hls_watch/__init__.py | 330 ++++++++++++++++++++++++++++++++++++++++++ hls_watch/__main__.py | 49 +++++++ hls_watch/test.py | 295 +++++++++++++++++++++++++++++++++++++ setup.py | 8 + 5 files changed, 683 insertions(+) create mode 100644 .gitignore create mode 100644 hls_watch/__init__.py create mode 100644 hls_watch/__main__.py create mode 100644 hls_watch/test.py create mode 100644 setup.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c18dd8d --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__/ diff --git a/hls_watch/__init__.py b/hls_watch/__init__.py new file mode 100644 index 0000000..c7981f4 --- /dev/null +++ b/hls_watch/__init__.py @@ -0,0 +1,330 @@ +from collections import namedtuple +import contextlib +import datetime +import itertools +import re + + +OUT_VERSION = 3 + + +class _TypeCompatEq: + def __eq__(self, other): + return ( + ( + isinstance(self, type(other)) + or isinstance(other, type(self)) + ) + and super().__eq__(other) + ) + +def _namedtuple(name, *args): + # namedtuple with type-sensitive comparison + nt_base = namedtuple(name, *args) + return type(name, (_TypeCompatEq, nt_base), {}) + + +Capture = _namedtuple('Capture', ['time', 'playlist']) + + +_Playlist = _namedtuple('Playlist', ['media_start', 'targ_dur', 'is_end', 'contents']) + +class Playlist(_Playlist): + """ Playlist, with metadata tags separated from media stream data """ + + @classmethod + def from_entries(class_, entries): + es = [e for e in entries if isinstance(e, MediaSeq)] + if es: + [e] = es + media_start = e.number + else: + media_start = 0 + + [e] = [e for e in entries if isinstance(e, TargDur)] + targ_dur = e.seconds + + is_end = any(isinstance(e, Endlist) for e in entries) + contents = [e for e in entries if isinstance(e, (Segment, Discont))] + + return class_(media_start, targ_dur, is_end, contents) + + def media_end(self, contents = None): + if contents is None: + contents = self.contents + return self.media_start + len([e for e in contents if isinstance(e, Segment)]) + + def intersect_right(self, other): + """ + Try to find a point where `other' overlaps to the right of `self'. + Returns a playlist of the segments after the overlap. + """ + if self.media_start <= other.media_start <= self.media_end(): + left, right = self.contents, other.contents + for l in reversed(range(len(right) + 1)): + l_split = len(left) - l + match = ( + right[:l] == left[l_split:] + and self.media_end(left[:l_split]) == other.media_start + ) + if match and (self.is_end, other.is_end) != (True, False): + [targ_dur] = {self.targ_dur, other.targ_dur} + return Playlist( + media_start = other.media_end(right[:l]), + targ_dur = targ_dur, + is_end = other.is_end, + contents = right[l:], + ) + return None + + def __getitem__(self, s): + if isinstance(s, slice): + assert s.step is None + contents = self.contents[s] + return self._replace( + contents = contents, + media_start = self.media_end(self.contents[:s.start]), + ) + else: + return super().__getitem__(s) + + def sync_crop(self): + """ Return a sub-playlist starting at the first sync point, if any """ + if self.media_start == 0: + return self + elif Discont() in self.contents: + i = self.contents.index(Discont()) + 1 + return self[i:] + else: + return None + + +# playlist entry record types + +M3u = _namedtuple('M3u', []) +Version = _namedtuple('Version', ['number']) + +Segment = _namedtuple('Segment', ['uri', 'seconds']) +Discont = _namedtuple('Discont', []) +Endlist = _namedtuple('Endlist', []) +MediaSeq = _namedtuple('MediaSeq', ['number']) +TargDur = _namedtuple('TargDur', ['seconds']) + + +def parse(resp): + _SKIP = { + r'$', + r'#(?!EXT)', # comment + } + + time_str = resp.info().get('Date') + time = None + if time_str: + try: + time = datetime.datetime.strptime(time_str, '%a, %d %b %Y %H:%M:%S %Z') + except ValueError: + pass + else: + time = time.replace(tzinfo = datetime.timezone.utc) + + def gen_parsed(lines): + lines = iter(lines) + while True: + try: + line = next(lines) + except StopIteration: + break + + if re.match(r'#EXTM3U$', line): + yield M3u() + continue + + match = re.match(r'#EXT-X-VERSION:(?P\d+)$', line) + if match: + yield Version(int(match.group('number'))) + continue + + match = re.match(r'#EXTINF:(?P[\d.]+),', line) + if match: + name = next(lines) + yield Segment(name, match.group('duration')) + continue + + if re.match(r'#EXT-X-DISCONTINUITY$', line): + yield Discont() + continue + + if re.match(r'#EXT-X-ENDLIST$', line): + yield Endlist() + continue + + match = re.match(r'#EXT-X-MEDIA-SEQUENCE:(?P\d+)$', line) + if match: + yield MediaSeq(int(match.group('ord'))) + continue + + match = re.match(r'#EXT-X-TARGETDURATION:(?P\d+)$', line) + if match: + yield TargDur(int(match.group('duration'))) + continue + + if not any(re.match(expr, line) for expr in _SKIP): + raise NotImplementedError("Unhandled line: {}".format(line)) + + body = resp.read().decode('utf-8') + parsed = list(gen_parsed(body.splitlines())) + playlist = Playlist.from_entries(parsed) + return Capture(time, playlist) + +def _unparse(entry): + if isinstance(entry, Segment): + return [ + '#EXTINF:{},'.format(entry.seconds), + entry.uri, + ] + elif isinstance(entry, Discont): + return ['#EXT-X-DISCONTINUITY'] + elif isinstance(entry, Endlist): + return ['#EXT-X-ENDLIST'] + elif isinstance(entry, MediaSeq): + return ['#EXT-X-MEDIA-SEQUENCE:{}'.format(entry.number)] + elif isinstance(entry, Version): + return ['#EXT-X-VERSION:{}'.format(entry.number)] + elif isinstance(entry, M3u): + return ['#EXTM3U'] + elif isinstance(entry, TargDur): + return ['#EXT-X-TARGETDURATION:{}'.format(entry.seconds)] + else: + raise NotImplementedError() + + +class CaptureHandler: + def __init__(self, writer): + self.writer = writer + self.sync_tail = None + self.in_sess = False + + def _end_list(self): + self.ensure_sess(False) + + def on_404(self): + self._end_list() + self.sync_tail = None + + def ensure_sess(self, in_sess, targ_dur = None): + if not self.in_sess and in_sess: + header = [ + M3u(), + Version(OUT_VERSION), + TargDur(targ_dur), + ] + for e in header: + self.writer.put_entry(e) + self.in_sess = True + elif self.in_sess and not in_sess: + self.writer.put_entry(Endlist()) + self.writer.close() + self.in_sess = False + + def update(self, capture): + from functools import partial + + if self.sync_tail: + new = self.sync_tail.intersect_right(capture.playlist) + if new: + to_process = new.contents + else: + # desynced + self.sync_tail = None + + if not self.sync_tail: + synced = capture.playlist.sync_crop() + if synced: + to_process = synced.contents + else: + # no way to sync here + to_process = None + + if to_process is not None: + for e in to_process: + if isinstance(e, Segment): + self.ensure_sess(True, capture.playlist.targ_dur) + self.writer.put_entry(e, time = capture.time) + elif isinstance(e, Discont): + self.ensure_sess(False) + else: + raise AssertionError() + + if self.sync_tail: + new_end = capture.playlist.is_end and not self.sync_tail.is_end + else: + new_end = capture.playlist.is_end + if new_end: + self.ensure_sess(False) + + self.sync_tail = capture.playlist + + close = _end_list + + +class BufferedWriter: + """ Buffers entries and writes them to downstream writer if and when a segment is sent """ + + def __init__(self, dest): + self.dest = dest + self.buffer = [] + self.flushed = False + + def put_entry(self, entry, time = None): + if isinstance(entry, Segment): + while self.buffer: + self.dest.put_entry(self.buffer.pop(0)) + self.flushed = True + self.dest.put_entry(entry, time) + elif self.flushed and isinstance(entry, Endlist): + self.dest.put_entry(entry, time) + else: + self.buffer.append(entry) + + def close(self): + self.dest.close() + self.flushed = False + + +class PlaylistFileWriter: + def __init__(self, base_dir, endpoint): + self.base_dir = base_dir + self.endpoint = endpoint + + self.file = None + + def _get_file(self, time): + def path_for(ord_): + return self.base_dir.joinpath( + '{time}{ord}_{endpoint}'.format( + time = time.astimezone().isoformat(), + ord = '' if ord_ is None else '_' + int(ord_), + endpoint = self.endpoint, + ), + 'index.m3u8', + ) + for path in map(path_for, itertools.chain([None], itertools.count())): + try: + path.parent.mkdir() + return path.open('x') + except OSError: + pass + + def _output(self, entry): + for line in _unparse(entry): + print(line, file = self.file) + + def put_entry(self, entry, time = None): + if not self.file: + self.file = self._get_file(time) + self._output(entry) + + def close(self): + if self.file: + self.file.close() + self.file = None diff --git a/hls_watch/__main__.py b/hls_watch/__main__.py new file mode 100644 index 0000000..25c4e97 --- /dev/null +++ b/hls_watch/__main__.py @@ -0,0 +1,49 @@ +from hls_watch import * + + +_DELAY = datetime.timedelta(seconds = 10) + + +if __name__ == '__main__': + import argparse + import contextlib + import pathlib + import time + import urllib.parse + import urllib.request + + ap = argparse.ArgumentParser() + ap.add_argument("base_path") + ap.add_argument("url") + + args = ap.parse_args() + + base_path = pathlib.Path(args.base_path) + assert base_path.is_dir() + + url = args.url + url_path = pathlib.PurePosixPath(urllib.parse.urlparse(url).path) + endpoint = url_path.stem + + # Follow redirects but don't throw on 4XX + opener = urllib.request.OpenerDirector() + handlers = { + urllib.request.HTTPHandler(), + urllib.request.HTTPSHandler(), + urllib.request.HTTPRedirectHandler(), + } + for handler in handlers: + opener.add_handler(handler) + + writer = PlaylistFileWriter(base_path, endpoint) + buff_writer = BufferedWriter(writer) + handler = CaptureHandler(buff_writer) + with contextlib.closing(handler): + while True: + with opener.open(url) as resp: + if resp.status == 404: + handler.on_404() + else: + assert resp.status == 200 + handler.update(parse(resp)) + time.sleep(_DELAY.total_seconds()) diff --git a/hls_watch/test.py b/hls_watch/test.py new file mode 100644 index 0000000..aac9fe1 --- /dev/null +++ b/hls_watch/test.py @@ -0,0 +1,295 @@ +import unittest + +from hls_watch import * + + +class MockWriter: + def __init__(self): + self.lists = [] + self.curr_list = [] + + def put_entry(self, entry, time = None): + self.curr_list.append(entry) + + def close(self): + if self.curr_list: + self.lists.append(self.curr_list) + self.curr_list = [] + + +class _404: + pass + +class TestCases(unittest.TestCase): + def _run_case(self, inputs, output): + writer = MockWriter() + buff_writer = BufferedWriter(writer) + handler = CaptureHandler(buff_writer) + for input_ in inputs: + if isinstance(input_, _404): + handler.on_404() + else: + handler.update(Capture(None, Playlist.from_entries(input_))) + handler.close() + self.assertEqual(writer.lists, output) + + def test_static_beginning(self): + self._run_case( + [ + [ + M3u(), + Version(3), + TargDur(15), + Segment('seg0', 12), + Segment('seg1', 12), + Segment('seg2', 12), + ], + [ + M3u(), + Version(3), + TargDur(15), + Segment('seg0', 12), + Segment('seg1', 12), + Segment('seg2', 12), + ], + ], + [ + [ + M3u(), + Version(OUT_VERSION), + TargDur(15), + Segment('seg0', 12), + Segment('seg1', 12), + Segment('seg2', 12), + Endlist(), + ] + ] + ) + + def test_static_mid(self): + self._run_case( + [ + [ + M3u(), + Version(3), + TargDur(15), + MediaSeq(12), + Segment('seg12', 12), + Segment('seg13', 12), + Segment('seg14', 12), + ], + [ + M3u(), + Version(3), + TargDur(15), + MediaSeq(14), + Segment('seg14', 12), + Segment('seg15', 12), + Segment('seg16', 12), + ], + ], + [] + ) + + def test_in_on_discont(self): + self._run_case( + [ + [ + M3u(), + Version(3), + TargDur(15), + MediaSeq(100), + Segment('seg100', 12), + Segment('seg101', 4), + Discont(), + Segment('seg102', 15), + Segment('seg103', 12), + ], + [ + M3u(), + Version(3), + TargDur(15), + MediaSeq(102), + Discont(), + Segment('seg102', 15), + Segment('seg103', 12), + Segment('seg104', 12), + Segment('seg105', 12), + ], + ], + [ + [ + M3u(), + Version(OUT_VERSION), + TargDur(15), + Segment('seg102', 15), + Segment('seg103', 12), + Segment('seg104', 12), + Segment('seg105', 12), + Endlist(), + ], + ] + ) + + def test_media_seq(self): + self._run_case( + [ + [ + M3u(), + Version(3), + TargDur(15), + MediaSeq(0), + Segment('seg0', 12), + Segment('seg1', 12), + ], + ], + [ + [ + M3u(), + Version(OUT_VERSION), + TargDur(15), + Segment('seg0', 12), + Segment('seg1', 12), + Endlist(), + ] + ] + ) + + def test_threshold_joins(self): + self._run_case( + [ + [ + M3u(), + Version(3), + TargDur(15), + Segment('seg0', 12), + ], + [ + M3u(), + Version(3), + TargDur(15), + MediaSeq(1), + Segment('seg1', 12), + Segment('seg2', 12), + ], + [ + M3u(), + Version(3), + TargDur(15), + MediaSeq(3), + Segment('seg3', 12), + ], + ], + [ + [ + M3u(), + Version(OUT_VERSION), + TargDur(15), + Segment('seg0', 12), + Segment('seg1', 12), + Segment('seg2', 12), + Segment('seg3', 12), + Endlist(), + ], + ] + ) + + def test_404(self): + self._run_case( + [ + [ + M3u(), + Version(3), + TargDur(15), + Segment('seg0', 12), + Segment('seg1', 12), + Endlist(), + ], + _404(), + [ + M3u(), + Version(3), + TargDur(13), + Segment('seg-0-0', 11), + Segment('seg-0-1', 11), + ], + ], + [ + [ + M3u(), + Version(OUT_VERSION), + TargDur(15), + Segment('seg0', 12), + Segment('seg1', 12), + Endlist(), + ], + [ + M3u(), + Version(OUT_VERSION), + TargDur(13), + Segment('seg-0-0', 11), + Segment('seg-0-1', 11), + Endlist(), + ], + ] + ) + + def test_404_resync(self): + self._run_case( + [ + [ + M3u(), + Version(3), + TargDur(15), + Segment('seg-0-0', 12), + Endlist(), + ], + _404(), + [ + M3u(), + Version(3), + TargDur(100), + MediaSeq(128), + Segment('ign-128', 50), + Segment('ign-129', 50), + ], + [ + M3u(), + Version(3), + TargDur(100), + MediaSeq(129), + Segment('ign-129', 50), + Segment('ign-130', 50), + Discont(), + ], + [ + M3u(), + Version(3), + TargDur(100), + MediaSeq(131), + Segment('seg-1-131', 50), + Segment('seg-1-132', 50), + ], + ], + [ + [ + M3u(), + Version(OUT_VERSION), + TargDur(15), + Segment('seg-0-0', 12), + Endlist(), + ], + [ + M3u(), + Version(OUT_VERSION), + TargDur(100), + Segment('seg-1-131', 50), + Segment('seg-1-132', 50), + Endlist(), + ], + ] + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..ebf9e47 --- /dev/null +++ b/setup.py @@ -0,0 +1,8 @@ +from setuptools import setup, find_packages + +setup( + name = "hls_watch", + version = "0.0.0", + packages = find_packages(), + install_requires = [], +) -- 2.30.2