Implement roaming playlists and fix directory name regression
author Jakob Cornell <jakob+gpg@jcornell.net>
Tue, 6 Oct 2020 02:11:06 +0000 (21:11 -0500)
committer Jakob Cornell <jakob+gpg@jcornell.net>
Tue, 6 Oct 2020 02:11:06 +0000 (21:11 -0500)
hls_watch/__init__.py
hls_watch/__main__.py
hls_watch/playlist.py [new file with mode: 0644]
hls_watch/test.py
hls_watch/util.py [new file with mode: 0644]

index 06ae9f6f6342bb34d8b59afe16f7b2e95ddbaca8..5d54ffc42f4ab5f7b79c4915add6cb7b0a4a83f2 100644 (file)
-from collections import namedtuple
-import datetime
 import itertools
-import re
 
-
-OUT_VERSION = 3
+from hls_watch.playlist import Discont, Endlist, M3u, Segment, TargDur, Version, unparse
+from hls_watch.util import _namedtuple
 
 
-class _TypeCompatEq:
-       def __eq__(self, other):
-               return (
-                       (
-                               isinstance(self, type(other))
-                               or isinstance(other, type(self))
-                       )
-                       and super().__eq__(other)
-               )
-
-def _namedtuple(name, *args):
-       # namedtuple with type-sensitive comparison
-       nt_base = namedtuple(name, *args)
-       return type(name, (_TypeCompatEq, nt_base), {})
+OUT_VERSION = 3
 
 
 Capture = _namedtuple('Capture', ['time', 'playlist'])
 
-
-_Playlist = _namedtuple('Playlist', ['media_start', 'targ_dur', 'is_end', 'contents'])
-
-class Playlist(_Playlist):
-       """ Playlist, with metadata tags separated from media stream data """
-
-       @classmethod
-       def from_entries(class_, entries):
-               es = [e for e in entries if isinstance(e, MediaSeq)]
-               if es:
-                       [e] = es
-                       media_start = e.number
-               else:
-                       media_start = 0
-
-               [e] = [e for e in entries if isinstance(e, TargDur)]
-               targ_dur = e.seconds
-
-               is_end = Endlist() in entries
-               contents = [e for e in entries if isinstance(e, (Segment, Discont))]
-
-               return class_(media_start, targ_dur, is_end, contents)
-
-       def media_end(self, contents = None):
-               if contents is None:
-                       contents = self.contents
-               return self.media_start + len([e for e in contents if isinstance(e, Segment)])
-
-       def intersect_right(self, other):
-               """
-               Try to find a point where `other' overlaps to the right of `self'.
-               Returns a playlist of the segments after the overlap.
-               """
-               if self.media_start <= other.media_start <= self.media_end():
-                       left, right = self.contents, other.contents
-                       for l in reversed(range(len(right) + 1)):
-                               l_split = len(left) - l
-                               match = (
-                                       right[:l] == left[l_split:]
-                                       and self.media_end(left[:l_split]) == other.media_start
-                               )
-                               if match and (self.is_end, other.is_end) != (True, False):
-                                       [targ_dur] = {self.targ_dur, other.targ_dur}
-                                       return Playlist(
-                                               media_start = other.media_end(right[:l]),
-                                               targ_dur = targ_dur,
-                                               is_end = other.is_end,
-                                               contents = right[l:],
-                                       )
-               return None
-
-       def __getitem__(self, s):
-               if isinstance(s, slice):
-                       assert s.step is None
-                       contents = self.contents[s]
-                       return self._replace(
-                               contents = contents,
-                               media_start = self.media_end(self.contents[:s.start]),
-                       )
-               else:
-                       return super().__getitem__(s)
-
-       def sync_crop(self):
-               """ Return a sub-playlist starting at the first sync point, if any """
-               if self.media_start == 0:
-                       return self
-               elif Discont() in self.contents:
-                       i = self.contents.index(Discont()) + 1
-                       return self[i:]
-               else:
-                       return None
-
-
-# playlist entry record types
-
-M3u = _namedtuple('M3u', [])
-Version = _namedtuple('Version', ['number'])
-
-Segment = _namedtuple('Segment', ['uri', 'seconds'])
-Discont = _namedtuple('Discont', [])
-Endlist = _namedtuple('Endlist', [])
-MediaSeq = _namedtuple('MediaSeq', ['number'])
-TargDur = _namedtuple('TargDur', ['seconds'])
-
-
-def parse(resp):
-       _SKIP = {
-               r'$',
-               r'#(?!EXT)', # comment
-       }
-
-       def gen_parsed(lines):
-               lines = iter(lines)
-               while True:
-                       try:
-                               line = next(lines)
-                       except StopIteration:
-                               break
-
-                       if re.match(r'#EXTM3U$', line):
-                               yield M3u()
-                               continue
-
-                       match = re.match(r'#EXT-X-VERSION:(?P<number>\d+)$', line)
-                       if match:
-                               yield Version(int(match.group('number')))
-                               continue
-
-                       match = re.match(r'#EXTINF:(?P<duration>[\d.]+),', line)
-                       if match:
-                               name = next(lines)
-                               yield Segment(name, match.group('duration'))
-                               continue
-
-                       if re.match(r'#EXT-X-DISCONTINUITY$', line):
-                               yield Discont()
-                               continue
-
-                       if re.match(r'#EXT-X-ENDLIST$', line):
-                               yield Endlist()
-                               continue
-
-                       match = re.match(r'#EXT-X-MEDIA-SEQUENCE:(?P<ord>\d+)$', line)
-                       if match:
-                               yield MediaSeq(int(match.group('ord')))
-                               continue
-
-                       match = re.match(r'#EXT-X-TARGETDURATION:(?P<duration>\d+)$', line)
-                       if match:
-                               yield TargDur(int(match.group('duration')))
-                               continue
-
-                       if not any(re.match(expr, line) for expr in _SKIP):
-                               raise NotImplementedError("Unhandled line: {}".format(line))
-
-       body = resp.read().decode('utf-8')
-       parsed = list(gen_parsed(body.splitlines()))
-       return Playlist.from_entries(parsed)
-
-def _unparse(entry):
-       if isinstance(entry, Segment):
-               return [
-                       '#EXTINF:{},'.format(entry.seconds),
-                       entry.uri,
-               ]
-       elif isinstance(entry, Discont):
-               return ['#EXT-X-DISCONTINUITY']
-       elif isinstance(entry, Endlist):
-               return ['#EXT-X-ENDLIST']
-       elif isinstance(entry, MediaSeq):
-               return ['#EXT-X-MEDIA-SEQUENCE:{}'.format(entry.number)]
-       elif isinstance(entry, Version):
-               return ['#EXT-X-VERSION:{}'.format(entry.number)]
-       elif isinstance(entry, M3u):
-               return ['#EXTM3U']
-       elif isinstance(entry, TargDur):
-               return ['#EXT-X-TARGETDURATION:{}'.format(entry.seconds)]
-       else:
-               raise NotImplementedError()
+class Stream(_namedtuple('Stream', ['endpoint', 'variant'])):
+       """ An endpoint name paired with a variant of that endpoint """
+
+       def full(self):
+               """ Return the combined name, formatted as `<endpoint>_<variant>' """
+               (endpoint, variant) = self
+               return '{}_{}'.format(endpoint, variant)
 
 
 class CaptureHandler:
        def __init__(self, writer):
-               self.writer = writer
-               self.sync_tail = None
+               self._writer = writer
+               self._sync_tail = None
                self.in_sess = False
 
        def _end_list(self):
@@ -197,7 +25,7 @@ class CaptureHandler:
 
        def on_404(self):
                self._end_list()
-               self.sync_tail = None
+               self._sync_tail = None
 
        def ensure_sess(self, in_sess, targ_dur = None):
                if not self.in_sess and in_sess:
@@ -207,23 +35,23 @@ class CaptureHandler:
                                TargDur(targ_dur),
                        ]
                        for e in header:
-                               self.writer.put_entry(e)
+                               self._writer.put_entry(e)
                        self.in_sess = True
                elif self.in_sess and not in_sess:
-                       self.writer.put_entry(Endlist())
-                       self.writer.close()
+                       self._writer.put_entry(Endlist())
+                       self._writer.close()
                        self.in_sess = False
 
        def update(self, capture):
-               if self.sync_tail:
-                       new = self.sync_tail.intersect_right(capture.playlist)
+               if self._sync_tail:
+                       new = self._sync_tail.intersect_right(capture.playlist)
                        if new:
                                to_process = new.contents
                        else:
                                # desynced
-                               self.sync_tail = None
+                               self._sync_tail = None
 
-               if not self.sync_tail:
+               if not self._sync_tail:
                        synced = capture.playlist.sync_crop()
                        if synced:
                                to_process = synced.contents
@@ -235,20 +63,20 @@ class CaptureHandler:
                        for e in to_process:
                                if isinstance(e, Segment):
                                        self.ensure_sess(True, capture.playlist.targ_dur)
-                                       self.writer.put_entry(e, time = capture.time)
+                                       self._writer.put_entry(e, time = capture.time)
                                elif isinstance(e, Discont):
                                        self.ensure_sess(False)
                                else:
                                        raise AssertionError()
 
-                       if self.sync_tail:
-                               new_end = capture.playlist.is_end and not self.sync_tail.is_end
+                       if self._sync_tail:
+                               new_end = capture.playlist.is_end and not self._sync_tail.is_end
                        else:
                                new_end = capture.playlist.is_end
                        if new_end:
                                self.ensure_sess(False)
 
-                       self.sync_tail = capture.playlist
+                       self._sync_tail = capture.playlist
 
        close = _end_list
 
@@ -278,19 +106,19 @@ class BufferedWriter:
 
 
 class PlaylistFileWriter:
-       def __init__(self, base_dir, endpoint):
+       def __init__(self, base_dir, stream):
                self.base_dir = base_dir
-               self.endpoint = endpoint
+               self.stream = stream
 
                self.file = None
 
        def _get_file(self, time):
                def path_for(ord_):
                        return self.base_dir.joinpath(
-                               '{time}{ord}_{endpoint}'.format(
+                               '{time}{ord}_{stream}'.format(
                                        time = time.astimezone().isoformat(),
                                        ord = '' if ord_ is None else '_' + int(ord_),
-                                       endpoint = self.endpoint,
+                                       stream = self.stream.full(),
                                ),
                                'index.m3u8',
                        )
@@ -302,7 +130,7 @@ class PlaylistFileWriter:
                                pass
 
        def _output(self, entry):
-               for line in _unparse(entry):
+               for line in unparse(entry):
                        print(line, file = self.file)
 
        def put_entry(self, entry, time = None):
index b0e1ec6bfbb840e0716db30da6e8bca6008db73a..0450a25a4e59d8776b5b1ed68695973dead62e61 100644 (file)
@@ -1,12 +1,14 @@
 import datetime
+import re
 
-from hls_watch import *
+from hls_watch import BufferedWriter, Capture, CaptureHandler, PlaylistFileWriter, Stream
+from hls_watch.playlist import parse_master, parse_media
 
 
 _DELAY = datetime.timedelta(seconds = 10)
 
 
-def get_time(resp):
+def _get_time(resp):
        time_str = resp.info().get('Date')
        time = None
        if time_str:
@@ -19,6 +21,10 @@ def get_time(resp):
        return None
 
 
+def _get_lines(http_resp):
+       """
+       Read the whole body of `http_resp', decode it as UTF-8, and return it as a
+       list of lines with the line terminators removed.
+       """
+       return http_resp.read().decode('utf-8').splitlines()
+
+
 if __name__ == '__main__':
        from urllib.parse import urlparse, urlunparse
        from pathlib import Path, PurePosixPath
@@ -46,15 +52,18 @@ if __name__ == '__main__':
        with cfg_path.joinpath('config.toml').open() as f:
                config = toml.load(f)
 
-       try:
+       if '_' in args.endpoint:
                (endpoint, s_num) = args.endpoint.rsplit('_', 1)
-       except:
-               raise ValueError("The endpoint arg doesn't look correct")
+               stream_spec = Stream(endpoint, s_num)
+       else:
+               stream_spec = Stream(args.endpoint, None)
+
        base_url = urlparse(config['endpoints'][endpoint])
-       path_seg = '{}_{}.m3u8'.format(endpoint, s_num)
-       url = urlunparse(base_url._replace(
-               path = str(PurePosixPath(base_url.path).joinpath(path_seg)),
-       ))
+
+       def url_for_stream(file_name):
+               """ Build the URL of `file_name' placed under the path of `base_url' """
+               full_path = PurePosixPath(base_url.path).joinpath(file_name)
+               return urlunparse(base_url._replace(path = str(full_path)))
 
        # Follow redirects but don't throw on 4XX
        opener = urllib.request.OpenerDirector()
@@ -66,18 +75,41 @@ if __name__ == '__main__':
        for handler in handlers:
                opener.add_handler(handler)
 
-       writer = PlaylistFileWriter(base_path, endpoint)
+       curr_variant = stream_spec.variant
+       master_url = url_for_stream('{}.m3u8'.format(stream_spec.endpoint))
+
+       writer = PlaylistFileWriter(base_path, None)
        buff_writer = BufferedWriter(writer)
        handler = CaptureHandler(buff_writer)
        with contextlib.closing(handler):
                while True:
-                       with opener.open(url) as resp:
-                               if resp.status == 404:
-                                       handler.on_404()
-                               else:
+                       locate_stream = (
+                               curr_variant is None
+                               or stream_spec.variant is None and not handler.in_sess
+                       )
+                       if locate_stream:
+                               with opener.open(master_url) as resp:
                                        assert resp.status == 200
-                                       handler.update(Capture(
-                                               time = get_time(resp),
-                                               playlist = parse(resp),
-                                       ))
-                       time.sleep(_DELAY.total_seconds())
+                                       lines = resp.read().decode('utf-8').splitlines()
+                               time.sleep(_DELAY.total_seconds())
+
+                               streams = parse_master(lines)
+                               if streams:
+                                       winner = max(streams, key = lambda s: s.attrs['RESOLUTION'].vertical)
+                                       (ep, curr_variant) = re.match(r'(.*)_(.*?)\.m3u8$', winner.uri).groups()
+                                       assert ep == stream_spec.endpoint
+
+                       if curr_variant is not None:
+                               writer.stream = stream_spec._replace(variant = curr_variant)
+                               media_url = url_for_stream('{}.m3u8'.format(writer.stream.full()))
+                               with opener.open(media_url) as resp:
+                                       if resp.status == 404:
+                                               handler.on_404()
+                                       else:
+                                               assert resp.status == 200
+                                               lines = resp.read().decode('utf-8').splitlines()
+                                               handler.update(Capture(
+                                                       time = _get_time(resp),
+                                                       playlist = parse_media(lines),
+                                               ))
+                               time.sleep(_DELAY.total_seconds())
diff --git a/hls_watch/playlist.py b/hls_watch/playlist.py
new file mode 100644 (file)
index 0000000..36fe912
--- /dev/null
@@ -0,0 +1,227 @@
+import re
+
+from hls_watch.util import _namedtuple
+
+
+# playlist entry record types
+#
+# Each record stands for one playlist tag.  All of them are built with
+# _namedtuple, so records of unrelated types do not compare equal with `=='
+# even when their fields match.
+
+# `#EXTM3U' and `#EXT-X-VERSION:<number>' (number is an int)
+M3u = _namedtuple('M3u', [])
+Version = _namedtuple('Version', ['number'])
+# NOTE(review): presumably stands for `#EXT-X-INDEPENDENT-SEGMENTS'; the parser
+# in this module skips that tag rather than yielding this record -- confirm use
+IndSegments = _namedtuple('IndSegments', [])
+
+# Media playlist entries.  The parser fills Segment.seconds with the duration
+# text exactly as it appears after `#EXTINF:', and Segment.uri with the line
+# that follows the tag; MediaSeq.number and TargDur.seconds are ints.
+Segment = _namedtuple('Segment', ['uri', 'seconds'])
+Discont = _namedtuple('Discont', [])
+Endlist = _namedtuple('Endlist', [])
+MediaSeq = _namedtuple('MediaSeq', ['number'])
+TargDur = _namedtuple('TargDur', ['seconds'])
+
+# Master playlist entry for `#EXT-X-STREAM-INF': `uri' is the line following the
+# tag and `attrs' is the dict parsed from the tag's attribute list
+StreamInf = _namedtuple('StreamInf', ['uri', 'attrs'])
+
+
+_Playlist = _namedtuple('Playlist', ['media_start', 'targ_dur', 'is_end', 'contents'])
+
+class Playlist(_Playlist):
+       """
+       Playlist, with metadata tags separated from media stream data
+
+       Fields:
+       media_start -- media sequence number of the first segment in `contents'
+       targ_dur -- value of the playlist's target duration entry
+       is_end -- whether the playlist contained an end-of-list entry
+       contents -- list of Segment and Discont entries, in playlist order
+       """
+
+       @classmethod
+       def from_entries(class_, entries):
+               """
+               Build a playlist from a list of parsed entry records.
+
+               The media sequence number defaults to 0 when no MediaSeq entry is
+               present.  Raises ValueError (from the unpacking) if there is more than
+               one MediaSeq entry, or if there is not exactly one TargDur entry.
+               """
+               seqs = [e for e in entries if isinstance(e, MediaSeq)]
+               if seqs:
+                       [seq] = seqs
+                       media_start = seq.number
+               else:
+                       media_start = 0
+
+               [dur] = [e for e in entries if isinstance(e, TargDur)]
+               targ_dur = dur.seconds
+
+               is_end = Endlist() in entries
+               contents = [e for e in entries if isinstance(e, (Segment, Discont))]
+
+               return class_(media_start, targ_dur, is_end, contents)
+
+       def media_end(self, contents = None):
+               """
+               Return the media sequence number that follows the last segment of
+               `contents' (default: this playlist's own contents), counting from
+               this playlist's `media_start'.  Only Segment entries are counted.
+               """
+               if contents is None:
+                       contents = self.contents
+               return self.media_start + len([e for e in contents if isinstance(e, Segment)])
+
+       def intersect_right(self, other):
+               """
+               Try to find a point where `other' overlaps to the right of `self'.
+               Returns a playlist of the segments after the overlap, or None if no
+               overlap is found.
+
+               Raises ValueError (from the unpacking) if the two playlists overlap
+               but have different target durations.
+               """
+               if self.media_start <= other.media_start <= self.media_end():
+                       left, right = self.contents, other.contents
+                       # Try the longest candidate overlap first: the first `l' entries
+                       # of `right' against the last `l' entries of `left'
+                       for l in reversed(range(len(right) + 1)):
+                               l_split = len(left) - l
+                               match = (
+                                       right[:l] == left[l_split:]
+                                       and self.media_end(left[:l_split]) == other.media_start
+                               )
+                               # A playlist that has ended can't be followed by one that hasn't
+                               if match and (self.is_end, other.is_end) != (True, False):
+                                       [targ_dur] = {self.targ_dur, other.targ_dur}
+                                       return Playlist(
+                                               media_start = other.media_end(right[:l]),
+                                               targ_dur = targ_dur,
+                                               is_end = other.is_end,
+                                               contents = right[l:],
+                                       )
+               return None
+
+       def __getitem__(self, s):
+               """
+               Slicing returns a sub-playlist over the sliced contents, with
+               `media_start' advanced past the segments that were cut off the front.
+               Slices with a step are not supported.  Any other index is handled as
+               for a plain tuple.
+               """
+               if isinstance(s, slice):
+                       assert s.step is None
+                       # With no start nothing is cut off the front.  This case needs
+                       # separate handling because `self.contents[:None]' is the whole
+                       # list, which would count every segment as skipped.
+                       skipped = [] if s.start is None else self.contents[:s.start]
+                       return self._replace(
+                               contents = self.contents[s],
+                               media_start = self.media_end(skipped),
+                       )
+               else:
+                       return super().__getitem__(s)
+
+       def sync_crop(self):
+               """ Return a sub-playlist starting at the first sync point, if any """
+               if self.media_start == 0:
+                       # The start of the stream is itself a sync point
+                       return self
+               elif Discont() in self.contents:
+                       # Otherwise start just after the first discontinuity
+                       i = self.contents.index(Discont()) + 1
+                       return self[i:]
+               else:
+                       return None
+
+
+# Value types used by _parse_attrs for attribute values that are not plain
+# numbers or strings
+# hexadecimal-sequence: `characters' holds the digits that follow the `0x' prefix
+HexSeq = _namedtuple('HexSeq', ['characters'])
+# decimal-resolution
+Resolution = _namedtuple('Resolution', ['horizontal', 'vertical'])
+
+def _parse_attrs(string):
+       """
+       Parse the attribute list of a playlist tag into a dict.
+
+       `string' is the text following the tag's colon, for example
+       `BANDWIDTH=1280000,RESOLUTION=1280x720,CODECS="avc1.4d401f,mp4a.40.2"'.
+       Each value is converted according to its syntax: decimal numbers become
+       int (or float when they contain a `.'), `0x...' becomes HexSeq, quoted
+       strings become str with the quotes removed, `<digits>x<digits>' becomes
+       Resolution, and anything else is kept as an unquoted str.
+
+       Raises ValueError if the list is malformed and AssertionError if an
+       attribute name occurs more than once.
+       """
+       # Every value is followed either by the end of the string or by a comma
+       # and at least one more character.  The comma is consumed here so that
+       # the remainder handed back starts at the next attribute name.
+       tail = r'(?:,(?P<rest>.+))?$'
+
+       def rest_of(m):
+               return m.group('rest') or ''
+
+       def pop_key(string):
+               m = re.match(r'(?P<v>[A-Z\d-]*)(?P<rest>=.*)$', string)
+               if m is None:
+                       raise ValueError("Malformed attribute list: {!r}".format(string))
+               return (m.group('v'), m.group('rest'))
+
+       def pop_value(string):
+               # decimal-integer, decimal-floating-point, and signed-decimal-floating-point
+               m = re.match(r'(?P<v>-?[\d.]+)' + tail, string)
+               if m:
+                       val = (float if '.' in m.group('v') else int)(m.group('v'))
+                       return (val, rest_of(m))
+
+               # hexadecimal-sequence
+               m = re.match(r'0x(?P<v>[\dA-F]+)' + tail, string, re.IGNORECASE)
+               if m:
+                       return (HexSeq(m.group('v')), rest_of(m))
+
+               # quoted-string (may itself contain commas)
+               m = re.match(r'"(?P<v>[^"]*)"' + tail, string)
+               if m:
+                       return (m.group('v'), rest_of(m))
+
+               # decimal-resolution
+               m = re.match(r'(?P<h>\d+)x(?P<v>\d+)' + tail, string)
+               if m:
+                       val = Resolution(int(m.group('h')), int(m.group('v')))
+                       return (val, rest_of(m))
+
+               # anything else: unquoted text up to the next comma
+               m = re.match(r'(?P<v>[^,]*)' + tail, string)
+               if m is None:
+                       raise ValueError("Malformed attribute value: {!r}".format(string))
+               return (m.group('v'), rest_of(m))
+
+       items = []
+       rest = string
+       while rest:
+               (key, rest) = pop_key(rest)
+               # pop_key leaves the `=' at the front of the remainder
+               assert rest[0] == '='
+               (val, rest) = pop_value(rest[1:])
+               items.append((key, val))
+
+       attrs = dict(items)
+       # a repeated attribute name would have been collapsed by dict()
+       assert len(attrs) == len(items)
+       return attrs
+
+
+def _gen_parsed(lines):
+       """
+       Generate an entry record for each recognized tag in `lines', an iterable
+       of playlist lines without line terminators.
+
+       Blank lines, comments, and `#EXT-X-INDEPENDENT-SEGMENTS' produce nothing.
+       Tags that carry a URI (`#EXTINF' and `#EXT-X-STREAM-INF') consume the line
+       that follows them.
+
+       Raises NotImplementedError for any other line, and ValueError if the
+       input ends where a URI line is expected.
+       """
+       _SKIP = {
+               '$',
+               '#(?!EXT)',  # comment
+               '#EXT-X-INDEPENDENT-SEGMENTS$',
+       }
+
+       # A single iterator is shared by the loop and by uri_after(), so a URI line
+       # taken here is not seen again by the loop
+       lines = iter(lines)
+
+       def uri_after(tag_line):
+               # StopIteration must not escape from a generator body (it would be
+               # turned into a bare RuntimeError), so report truncation explicitly
+               try:
+                       return next(lines)
+               except StopIteration:
+                       raise ValueError("Playlist ends after: {}".format(tag_line)) from None
+
+       for line in lines:
+               if re.match(r'#EXTM3U$', line):
+                       yield M3u()
+                       continue
+
+               match = re.match(r'#EXT-X-VERSION:(?P<number>\d+)$', line)
+               if match:
+                       yield Version(int(match.group('number')))
+                       continue
+
+               match = re.match(r'#EXTINF:(?P<duration>[\d.]+),', line)
+               if match:
+                       name = uri_after(line)
+                       # the duration is kept as text, exactly as written in the playlist
+                       yield Segment(name, match.group('duration'))
+                       continue
+
+               if re.match(r'#EXT-X-DISCONTINUITY$', line):
+                       yield Discont()
+                       continue
+
+               if re.match(r'#EXT-X-ENDLIST$', line):
+                       yield Endlist()
+                       continue
+
+               match = re.match(r'#EXT-X-MEDIA-SEQUENCE:(?P<ord>\d+)$', line)
+               if match:
+                       yield MediaSeq(int(match.group('ord')))
+                       continue
+
+               match = re.match(r'#EXT-X-TARGETDURATION:(?P<duration>\d+)$', line)
+               if match:
+                       yield TargDur(int(match.group('duration')))
+                       continue
+
+               match = re.match('#EXT-X-STREAM-INF:(?P<attrs>.+)$', line)
+               if match:
+                       uri = uri_after(line).strip()
+                       yield StreamInf(uri, _parse_attrs(match.group('attrs')))
+                       continue
+
+               if not any(re.match(expr, line) for expr in _SKIP):
+                       raise NotImplementedError("Unhandled line: {}".format(line))
+
+
+def parse_media(lines):
+       """ Parse the lines of a media playlist and return them as a Playlist """
+       return Playlist.from_entries(list(_gen_parsed(lines)))
+
+
+def parse_master(lines):
+       """
+       Parse the lines of a master playlist and return its StreamInf entries as
+       a list.  M3u and Version entries are dropped; any other kind of entry
+       fails the assertion.
+       """
+       streams = []
+       for entry in _gen_parsed(lines):
+               if isinstance(entry, (M3u, Version)):
+                       continue
+               streams.append(entry)
+
+       assert all(isinstance(s, StreamInf) for s in streams)
+       return streams
+
+
+def unparse(entry):
+       """
+       Return the playlist lines that represent `entry', as a list of strings
+       without line terminators.  Raises NotImplementedError for an entry type
+       that has no output form.
+       """
+       # A segment is the only entry that takes up two lines
+       if isinstance(entry, Segment):
+               info_line = '#EXTINF:{},'.format(entry.seconds)
+               return [info_line, entry.uri]
+
+       if isinstance(entry, Discont):
+               return ['#EXT-X-DISCONTINUITY']
+       if isinstance(entry, Endlist):
+               return ['#EXT-X-ENDLIST']
+       if isinstance(entry, MediaSeq):
+               return ['#EXT-X-MEDIA-SEQUENCE:{}'.format(entry.number)]
+       if isinstance(entry, Version):
+               return ['#EXT-X-VERSION:{}'.format(entry.number)]
+       if isinstance(entry, M3u):
+               return ['#EXTM3U']
+       if isinstance(entry, TargDur):
+               return ['#EXT-X-TARGETDURATION:{}'.format(entry.seconds)]
+
+       raise NotImplementedError()
index 63d9f65fa22ccba0856931b49579956c1eaf1286..1a1fb855f7f9521b1e44d3315f8affb707e97b9b 100644 (file)
@@ -1,6 +1,9 @@
 import unittest
 
-from hls_watch import *
+from hls_watch import BufferedWriter, Capture, CaptureHandler, OUT_VERSION
+from hls_watch.playlist import (
+       Discont, Endlist, M3u, MediaSeq, Playlist, Segment, TargDur, Version,
+)
 
 
 class MockWriter:
diff --git a/hls_watch/util.py b/hls_watch/util.py
new file mode 100644 (file)
index 0000000..40e46d5
--- /dev/null
@@ -0,0 +1,17 @@
+from collections import namedtuple
+
+
+class _TypeCompatEq:
+       """
+       Mixin that makes comparison type-sensitive: two objects are equal only if
+       one is an instance of the other's type and the next class in the MRO also
+       considers them equal.
+       """
+
+       def __eq__(self, other):
+               return (
+                       (
+                               isinstance(self, type(other))
+                               or isinstance(other, type(self))
+                       )
+                       and super().__eq__(other)
+               )
+
+       def __ne__(self, other):
+               # tuple defines its own __ne__, which would otherwise be found through
+               # the MRO and compare the fields alone, ignoring the type check above
+               result = self.__eq__(other)
+               return result if result is NotImplemented else not result
+
+       def __hash__(self):
+               # Overriding __eq__ alone sets __hash__ to None, which would make every
+               # record unhashable.  Equal records have equal fields, so the inherited
+               # hash stays consistent with __eq__.
+               return super().__hash__()
+
+def _namedtuple(name, *args, **kwargs):
+       """
+       Create a namedtuple class whose `==' and `!=' are type-sensitive.
+
+       Takes the same arguments as collections.namedtuple.  Instances of two
+       classes made by separate calls never compare equal, even when their
+       fields match; instances of a class and of its subclass still can.
+       """
+       # namedtuple with type-sensitive comparison
+       nt_base = namedtuple(name, *args, **kwargs)
+       return type(name, (_TypeCompatEq, nt_base), {})