--- /dev/null
+from collections import namedtuple
+import contextlib
+import datetime
+import itertools
+import re
+
+
# Protocol version number written into the Version header entry of every
# output playlist session.
OUT_VERSION = 3
+
+
class _TypeCompatEq:
    """
    Mixin that makes comparison type-sensitive.

    Two objects can only compare equal when one of them is an instance of
    the other's type; if so, the decision is delegated to the next `__eq__'
    in the MRO.  Meant to be placed before a tuple-like base so that records
    of unrelated types with identical fields do not compare equal.
    """

    def __eq__(self, other):
        compatible = (
            isinstance(self, type(other))
            or isinstance(other, type(self))
        )
        return compatible and super().__eq__(other)

    def __ne__(self, other):
        # A tuple base supplies its own __ne__, which would otherwise be used
        # and ignore the type check above; derive the answer from __eq__ so
        # that `==' and `!=' always agree.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Overriding __eq__ implicitly sets __hash__ to None; restore the
        # inherited hash.  Equal objects still hash equal, because equality
        # here implies equality in the base class.
        return super().__hash__()
+
def _namedtuple(name, *args):
    """
    Create a namedtuple class with type-sensitive comparison.

    `name' and `args' are forwarded unchanged to collections.namedtuple; the
    resulting class is then combined with _TypeCompatEq (listed first so its
    comparison takes precedence) under the same name.
    """
    plain = namedtuple(name, *args)
    bases = (_TypeCompatEq, plain)
    return type(name, bases, {})
+
+
# One fetch of a playlist: the fetch time (may be None) and the Playlist.
Capture = _namedtuple('Capture', ('time', 'playlist'))


# Plain record underlying the Playlist class.
_Playlist = _namedtuple(
    'Playlist',
    ('media_start', 'targ_dur', 'is_end', 'contents'),
)
+
class Playlist(_Playlist):
    """
    Playlist, with metadata tags separated from media stream data

    Fields:
      media_start -- media sequence number of the first segment in `contents'
      targ_dur    -- target duration taken from the TargDur entry
      is_end      -- True when the playlist carried an Endlist entry
      contents    -- list of Segment and Discont entries, in playlist order
    """

    @classmethod
    def from_entries(class_, entries):
        """
        Build a Playlist from a list of parsed entry records.

        The media sequence defaults to 0 when no MediaSeq entry is present.
        Raises ValueError when there is more than one MediaSeq entry, or
        when there is not exactly one TargDur entry.
        """
        seqs = [e for e in entries if isinstance(e, MediaSeq)]
        if seqs:
            [seq] = seqs
            media_start = seq.number
        else:
            media_start = 0

        [targ] = [e for e in entries if isinstance(e, TargDur)]
        targ_dur = targ.seconds

        is_end = any(isinstance(e, Endlist) for e in entries)
        contents = [e for e in entries if isinstance(e, (Segment, Discont))]

        return class_(media_start, targ_dur, is_end, contents)

    def media_end(self, contents = None):
        """
        Return the media sequence number just past the segments of `contents'
        (default: this playlist's own contents), counted from `media_start'.
        Only Segment entries advance the number; Discont entries do not.
        """
        if contents is None:
            contents = self.contents
        return self.media_start + sum(isinstance(e, Segment) for e in contents)

    def intersect_right(self, other):
        """
        Try to find a point where `other' overlaps to the right of `self'.
        Returns a playlist of the segments after the overlap.

        The longest matching overlap is preferred; an overlap of zero entries
        (i.e. `other' starting exactly where `self' ends) is accepted too.
        Returns None when no such point exists, or when `self' is ended but
        `other' is not.  Raises ValueError when an overlap is found but the
        two playlists disagree on the target duration.
        """
        if self.media_start <= other.media_start <= self.media_end():
            left, right = self.contents, other.contents
            # an overlap can never be longer than either list
            longest = min(len(left), len(right))
            for overlap in range(longest, -1, -1):
                l_split = len(left) - overlap
                match = (
                    right[:overlap] == left[l_split:]
                    and self.media_end(left[:l_split]) == other.media_start
                )
                if match and (self.is_end, other.is_end) != (True, False):
                    # both playlists must agree on a single target duration
                    [targ_dur] = {self.targ_dur, other.targ_dur}
                    return Playlist(
                        media_start = other.media_end(right[:overlap]),
                        targ_dur = targ_dur,
                        is_end = other.is_end,
                        contents = right[overlap:],
                    )
        return None

    def __getitem__(self, s):
        """
        Integer indexes address the record fields as usual.  A slice (without
        step) instead slices `contents' and returns a Playlist whose
        `media_start' accounts for the segments dropped from the front.
        """
        if isinstance(s, slice):
            assert s.step is None
            # Normalise the start so that an omitted start means "nothing
            # skipped" rather than slicing the whole list with [:None].
            start, _stop, _step = s.indices(len(self.contents))
            return self._replace(
                contents = self.contents[s],
                media_start = self.media_end(self.contents[:start]),
            )
        else:
            return super().__getitem__(s)

    def sync_crop(self):
        """
        Return a sub-playlist starting at the first sync point, if any

        A playlist starting at media sequence 0 is returned whole; otherwise
        the part following the first Discont entry is returned.  Returns None
        when neither applies.
        """
        if self.media_start == 0:
            return self
        try:
            cut = self.contents.index(Discont()) + 1
        except ValueError:
            return None
        return self[cut:]
+
+
+# playlist entry record types
+
# header entries
M3u = _namedtuple('M3u', ())
Version = _namedtuple('Version', ('number',))

# media entries and playlist-level tags
Segment = _namedtuple('Segment', ('uri', 'seconds'))
Discont = _namedtuple('Discont', ())
Endlist = _namedtuple('Endlist', ())
MediaSeq = _namedtuple('MediaSeq', ('number',))
TargDur = _namedtuple('TargDur', ('seconds',))
+
+
def parse(resp):
    """
    Parse a playlist response into a Capture.

    `resp' must provide info() (returning headers with a get() method) and
    read() (returning the UTF-8 encoded body).  The capture time is taken
    from the `Date' header; it is None when the header is missing or does not
    match the expected format.

    Segment durations are kept as the text found in the playlist.

    Raises NotImplementedError for a line that is neither recognised nor
    skippable, and ValueError when an #EXTINF tag is not followed by a URI
    line or when the entries do not form a valid playlist (see
    Playlist.from_entries).
    """
    _SKIP = {
        r'$',
        r'#(?!EXT)', # comment
    }

    time_str = resp.info().get('Date')
    time = None
    if time_str:
        try:
            time = datetime.datetime.strptime(time_str, '%a, %d %b %Y %H:%M:%S %Z')
        except ValueError:
            pass
        else:
            # strptime yields a naive value here; stamp it as UTC
            time = time.replace(tzinfo = datetime.timezone.utc)

    def gen_parsed(lines):
        """ Yield one entry record per recognised tag """
        lines = iter(lines)
        for line in lines:
            if re.match(r'#EXTM3U$', line):
                yield M3u()
                continue

            match = re.match(r'#EXT-X-VERSION:(?P<number>\d+)$', line)
            if match:
                yield Version(int(match.group('number')))
                continue

            match = re.match(r'#EXTINF:(?P<duration>[\d.]+),', line)
            if match:
                # the segment URI is on the line following the tag
                try:
                    name = next(lines)
                except StopIteration:
                    # A bare StopIteration inside a generator would surface
                    # as an opaque RuntimeError; report the real problem.
                    raise ValueError(
                        "#EXTINF without a following URI line"
                    ) from None
                yield Segment(name, match.group('duration'))
                continue

            if re.match(r'#EXT-X-DISCONTINUITY$', line):
                yield Discont()
                continue

            if re.match(r'#EXT-X-ENDLIST$', line):
                yield Endlist()
                continue

            match = re.match(r'#EXT-X-MEDIA-SEQUENCE:(?P<ord>\d+)$', line)
            if match:
                yield MediaSeq(int(match.group('ord')))
                continue

            match = re.match(r'#EXT-X-TARGETDURATION:(?P<duration>\d+)$', line)
            if match:
                yield TargDur(int(match.group('duration')))
                continue

            if not any(re.match(expr, line) for expr in _SKIP):
                raise NotImplementedError("Unhandled line: {}".format(line))

    body = resp.read().decode('utf-8')
    parsed = list(gen_parsed(body.splitlines()))
    playlist = Playlist.from_entries(parsed)
    return Capture(time, playlist)
+
def _unparse(entry):
    """
    Render one entry record as the list of playlist text lines it stands for.

    A Segment produces two lines (tag, then URI); every other known record
    produces one.  Raises NotImplementedError for anything else.
    """
    if isinstance(entry, Segment):
        return ['#EXTINF:{},'.format(entry.seconds), entry.uri]

    # single-line tags, tried in order: (record type, renderer)
    renderers = (
        (Discont, lambda e: '#EXT-X-DISCONTINUITY'),
        (Endlist, lambda e: '#EXT-X-ENDLIST'),
        (MediaSeq, lambda e: '#EXT-X-MEDIA-SEQUENCE:{}'.format(e.number)),
        (Version, lambda e: '#EXT-X-VERSION:{}'.format(e.number)),
        (M3u, lambda e: '#EXTM3U'),
        (TargDur, lambda e: '#EXT-X-TARGETDURATION:{}'.format(e.seconds)),
    )
    for kind, render in renderers:
        if isinstance(entry, kind):
            return [render(entry)]

    raise NotImplementedError()
+
+
class CaptureHandler:
    """
    Feeds successive playlist captures to a writer as contiguous sessions.

    A session is a header (M3u, Version, TargDur), a run of segments and a
    closing Endlist.  Segments are only written while the handler is in sync
    with the stream, i.e. it knows how the current capture connects to what
    came before (see Playlist.sync_crop and Playlist.intersect_right).

    `writer' must provide put_entry(entry, time = None) and close().
    """

    def __init__(self, writer):
        self.writer = writer
        # playlist of the previous capture while in sync, otherwise None
        self.sync_tail = None
        # whether a session header has been written without its Endlist yet
        self.in_sess = False

    def _end_list(self):
        """ End the current session, if one is open """
        self.ensure_sess(False)

    def on_404(self):
        """ The playlist disappeared: end the session and drop sync """
        self._end_list()
        self.sync_tail = None

    def ensure_sess(self, in_sess, targ_dur = None):
        """
        Open or end a session so that the session state equals `in_sess'.

        Opening writes the header entries, using `targ_dur' for the TargDur
        entry; ending writes an Endlist and closes the writer.  Nothing
        happens when the state already matches.
        """
        if not self.in_sess and in_sess:
            header = [
                M3u(),
                Version(OUT_VERSION),
                TargDur(targ_dur),
            ]
            for e in header:
                self.writer.put_entry(e)
            self.in_sess = True
        elif self.in_sess and not in_sess:
            self.writer.put_entry(Endlist())
            self.writer.close()
            self.in_sess = False

    def update(self, capture):
        """
        Process a new capture, writing whatever it adds to the stream.

        While in sync, only the entries following the overlap with the
        previous capture are written.  Otherwise the capture is cropped to
        its first sync point; when it has none, nothing is written and the
        handler stays out of sync.  A Discont entry ends the current session,
        as does a playlist that newly reports its end.
        """
        playlist = capture.playlist
        to_process = None

        if self.sync_tail is not None:
            new = self.sync_tail.intersect_right(playlist)
            if new is not None:
                to_process = new.contents
            else:
                # desynced: whatever follows is not contiguous with what was
                # written, so it must not land in the same session
                self.ensure_sess(False)
                self.sync_tail = None

        if self.sync_tail is None:
            synced = playlist.sync_crop()
            if synced is not None:
                to_process = synced.contents
            # else: no way to sync here

        if to_process is not None:
            for e in to_process:
                if isinstance(e, Segment):
                    self.ensure_sess(True, playlist.targ_dur)
                    self.writer.put_entry(e, time = capture.time)
                elif isinstance(e, Discont):
                    self.ensure_sess(False)
                else:
                    raise AssertionError()

        if self.sync_tail is not None:
            new_end = playlist.is_end and not self.sync_tail.is_end
        else:
            new_end = playlist.is_end
        if new_end:
            self.ensure_sess(False)

        # Only a playlist that could be placed in the stream may serve as the
        # reference for the next capture; remembering an unsynced one would
        # make the next capture's new segments look in sync.
        self.sync_tail = playlist if to_process is not None else None

    close = _end_list
+
+
class BufferedWriter:
    """
    Buffers entries and writes them to downstream writer if and when a segment is sent

    `dest' must provide put_entry(entry, time) and close().
    """

    def __init__(self, dest):
        self.dest = dest
        # entries held back until a segment arrives
        self.buffer = []
        # whether a segment has been sent downstream since the last close()
        self.flushed = False

    def put_entry(self, entry, time = None):
        """
        Forward a Segment, preceded by everything buffered so far; forward an
        Endlist only once a segment has been sent; buffer anything else.
        """
        if isinstance(entry, Segment):
            while self.buffer:
                # Buffered entries were queued before any time was known;
                # hand them the segment's time so a downstream writer that
                # needs one on its first entry gets it.
                self.dest.put_entry(self.buffer.pop(0), time)
            self.flushed = True
            self.dest.put_entry(entry, time)
        elif self.flushed and isinstance(entry, Endlist):
            self.dest.put_entry(entry, time)
        else:
            self.buffer.append(entry)

    def close(self):
        """ Close the downstream writer and reset for the next session """
        self.dest.close()
        # Entries still buffered belong to a session that never produced a
        # segment; drop them so they cannot leak into the next session.
        self.buffer.clear()
        self.flushed = False
+
+
class PlaylistFileWriter:
    """
    Writes entries to `index.m3u8' files, one freshly created directory per
    session, below `base_dir'.

    `base_dir' must be a pathlib-style path; `endpoint' becomes part of the
    directory name.  A new file is started on the first entry after
    construction or after close().
    """

    def __init__(self, base_dir, endpoint):
        self.base_dir = base_dir
        self.endpoint = endpoint

        # currently open output file, or None between sessions
        self.file = None

    def _get_file(self, time):
        """
        Create a new session directory named after `time' and `endpoint' and
        return its `index.m3u8' opened for exclusive creation.

        When `time' is None the current time is used.  If the directory name
        is taken, a numeric suffix (_0, _1, ...) is inserted until a free
        name is found.  Any other failure to create the directory (e.g. a
        missing `base_dir') is raised rather than retried.
        """
        if time is None:
            time = datetime.datetime.now(datetime.timezone.utc)

        def path_for(ord_):
            return self.base_dir.joinpath(
                '{time}{ord}_{endpoint}'.format(
                    time = time.astimezone().isoformat(),
                    ord = '' if ord_ is None else '_' + str(ord_),
                    endpoint = self.endpoint,
                ),
                'index.m3u8',
            )

        for path in map(path_for, itertools.chain([None], itertools.count())):
            try:
                path.parent.mkdir()
            except FileExistsError:
                # name already used by an earlier session: try the next one
                continue
            return path.open('x')

    def _output(self, entry):
        """ Write the text lines for `entry' to the current file """
        for line in _unparse(entry):
            print(line, file = self.file)

    def put_entry(self, entry, time = None):
        """ Write `entry', opening a new file named after `time' if needed """
        if not self.file:
            self.file = self._get_file(time)
        self._output(entry)

    def close(self):
        """ Close the current file, if any """
        if self.file:
            self.file.close()
            self.file = None
--- /dev/null
+import unittest
+
+from hls_watch import *
+
+
class MockWriter:
    """
    In-memory writer for tests: collects entries into the current list and
    files that list under `lists' on close().
    """

    def __init__(self):
        self.lists = []
        self.curr_list = []

    def put_entry(self, entry, time = None):
        """ Record `entry'; `time' is ignored """
        self.curr_list.append(entry)

    def close(self):
        """ Archive the current list, unless it is empty, and start afresh """
        if not self.curr_list:
            return
        finished, self.curr_list = self.curr_list, []
        self.lists.append(finished)
+
+
class _404:
    """ Marker placed among test inputs where the handler's on_404() is to be called """
+
class TestCases(unittest.TestCase):
    """
    End-to-end cases: captures (or _404 markers) go into a CaptureHandler
    writing through a BufferedWriter, and the sessions collected by a
    MockWriter are compared with the expected ones.
    """

    @staticmethod
    def _playlist(targ_dur, *body, media_seq = None):
        """ Entries of an input playlist: header, optional MediaSeq, body """
        entries = [M3u(), Version(3), TargDur(targ_dur)]
        if media_seq is not None:
            entries.append(MediaSeq(media_seq))
        entries.extend(body)
        return entries

    @staticmethod
    def _session(targ_dur, *segments):
        """ Entries of an expected output session around `segments' """
        return [
            M3u(),
            Version(OUT_VERSION),
            TargDur(targ_dur),
            *segments,
            Endlist(),
        ]

    def _run_case(self, inputs, output):
        sink = MockWriter()
        handler = CaptureHandler(BufferedWriter(sink))
        for item in inputs:
            if isinstance(item, _404):
                handler.on_404()
                continue
            handler.update(Capture(None, Playlist.from_entries(item)))
        handler.close()
        self.assertEqual(sink.lists, output)

    def test_static_beginning(self):
        # the same playlist fetched twice is written once
        segments = (
            Segment('seg0', 12),
            Segment('seg1', 12),
            Segment('seg2', 12),
        )
        self._run_case(
            [
                self._playlist(15, *segments),
                self._playlist(15, *segments),
            ],
            [self._session(15, *segments)],
        )

    def test_static_mid(self):
        # joining mid-stream without a sync point writes nothing
        first = self._playlist(
            15,
            Segment('seg12', 12),
            Segment('seg13', 12),
            Segment('seg14', 12),
            media_seq = 12,
        )
        second = self._playlist(
            15,
            Segment('seg14', 12),
            Segment('seg15', 12),
            Segment('seg16', 12),
            media_seq = 14,
        )
        self._run_case([first, second], [])

    def test_in_on_discont(self):
        # output starts right after the first discontinuity
        first = self._playlist(
            15,
            Segment('seg100', 12),
            Segment('seg101', 4),
            Discont(),
            Segment('seg102', 15),
            Segment('seg103', 12),
            media_seq = 100,
        )
        second = self._playlist(
            15,
            Discont(),
            Segment('seg102', 15),
            Segment('seg103', 12),
            Segment('seg104', 12),
            Segment('seg105', 12),
            media_seq = 102,
        )
        expected = self._session(
            15,
            Segment('seg102', 15),
            Segment('seg103', 12),
            Segment('seg104', 12),
            Segment('seg105', 12),
        )
        self._run_case([first, second], [expected])

    def test_media_seq(self):
        # an explicit media sequence of 0 counts as the stream start
        only = self._playlist(
            15,
            Segment('seg0', 12),
            Segment('seg1', 12),
            media_seq = 0,
        )
        expected = self._session(
            15,
            Segment('seg0', 12),
            Segment('seg1', 12),
        )
        self._run_case([only], [expected])

    def test_threshold_joins(self):
        # playlists that merely touch end-to-start are joined
        first = self._playlist(
            15,
            Segment('seg0', 12),
        )
        second = self._playlist(
            15,
            Segment('seg1', 12),
            Segment('seg2', 12),
            media_seq = 1,
        )
        third = self._playlist(
            15,
            Segment('seg3', 12),
            media_seq = 3,
        )
        expected = self._session(
            15,
            Segment('seg0', 12),
            Segment('seg1', 12),
            Segment('seg2', 12),
            Segment('seg3', 12),
        )
        self._run_case([first, second, third], [expected])

    def test_404(self):
        # a 404 separates two independent sessions
        before = self._playlist(
            15,
            Segment('seg0', 12),
            Segment('seg1', 12),
            Endlist(),
        )
        after = self._playlist(
            13,
            Segment('seg-0-0', 11),
            Segment('seg-0-1', 11),
        )
        expected = [
            self._session(
                15,
                Segment('seg0', 12),
                Segment('seg1', 12),
            ),
            self._session(
                13,
                Segment('seg-0-0', 11),
                Segment('seg-0-1', 11),
            ),
        ]
        self._run_case([before, _404(), after], expected)

    def test_404_resync(self):
        # after a 404, segments are ignored until a sync point is seen
        before = self._playlist(
            15,
            Segment('seg-0-0', 12),
            Endlist(),
        )
        unsynced = self._playlist(
            100,
            Segment('ign-128', 50),
            Segment('ign-129', 50),
            media_seq = 128,
        )
        ends_in_discont = self._playlist(
            100,
            Segment('ign-129', 50),
            Segment('ign-130', 50),
            Discont(),
            media_seq = 129,
        )
        resynced = self._playlist(
            100,
            Segment('seg-1-131', 50),
            Segment('seg-1-132', 50),
            media_seq = 131,
        )
        expected = [
            self._session(
                15,
                Segment('seg-0-0', 12),
            ),
            self._session(
                100,
                Segment('seg-1-131', 50),
                Segment('seg-1-132', 50),
            ),
        ]
        self._run_case(
            [before, _404(), unsynced, ends_in_discont, resynced],
            expected,
        )
+
+
if __name__ == '__main__':
    # executed as a script: let unittest discover and run the cases above
    unittest.main()