diff --git a/youtube_transcript_api/formatters.py b/youtube_transcript_api/formatters.py index 986044f..387e565 100644 --- a/youtube_transcript_api/formatters.py +++ b/youtube_transcript_api/formatters.py @@ -79,8 +79,19 @@ class TextFormatter(Formatter): """ return '\n\n\n'.join([self.format_transcript(transcript, **kwargs) for transcript in transcripts]) +class _TextBasedFormatter(TextFormatter): + def _format_timestamp(self, hours, mins, secs, ms): + raise NotImplementedError('A subclass of _TextBasedFormatter must implement ' \ + 'their own .format_timestamp() method.') -class WebVTTFormatter(Formatter): + def _format_transcript_header(self, lines): + raise NotImplementedError('A subclass of _TextBasedFormatter must implement ' \ + 'their own _format_transcript_header method.') + + def _format_transcript_helper(self, i, time_text, line): + raise NotImplementedError('A subclass of _TextBasedFormatter must implement ' \ + 'their own _format_transcript_helper method.') + def _seconds_to_timestamp(self, time): """Helper that converts `time` into a transcript cue timestamp. @@ -95,44 +106,55 @@ class WebVTTFormatter(Formatter): '00:00:06.930' """ time = float(time) - hours, remainder = divmod(time, 3600) - mins, secs = divmod(remainder, 60) + hours_float, remainder = divmod(time, 3600) + mins_float, secs_float = divmod(remainder, 60) + hours, mins, secs = int(hours_float), int(mins_float), int(secs_float) ms = int(round((time - int(time))*1000, 2)) - return "{:02.0f}:{:02.0f}:{:02.0f}.{:03d}".format(hours, mins, secs, ms) + return self._format_timestamp(hours, mins, secs, ms) def format_transcript(self, transcript, **kwargs): - """A basic implementation of WEBVTT formatting. + """A basic implementation of WEBVTT/SRT formatting. :param transcript: - :reference: https://www.w3.org/TR/webvtt1/#introduction-caption + :reference: + https://www.w3.org/TR/webvtt1/#introduction-caption + https://www.3playmedia.com/blog/create-srt-file/ """ lines = [] for i, line in enumerate(transcript): - if i < len(transcript) - 1: - # Looks ahead, use next start time since duration value - # would create an overlap between start times. - time_text = "{} --> {}".format( - self._seconds_to_timestamp(line['start']), - self._seconds_to_timestamp(transcript[i + 1]['start']) + end = line['start'] + line['duration'] + time_text = "{} --> {}".format( + self._seconds_to_timestamp(line['start']), + self._seconds_to_timestamp( + transcript[i + 1]['start'] + if i < len(transcript) - 1 and transcript[i + 1]['start'] < end else end ) - else: - # Reached the end, cannot look ahead, use duration now. - duration = line['start'] + line['duration'] - time_text = "{} --> {}".format( - self._seconds_to_timestamp(line['start']), - self._seconds_to_timestamp(duration) - ) - lines.append("{}\n{}".format(time_text, line['text'])) + ) + lines.append(self._format_transcript_helper(i, time_text, line)) + return self._format_transcript_header(lines) + + +class SRTFormatter(_TextBasedFormatter): + def _format_timestamp(self, hours, mins, secs, ms): + return "{:02d}:{:02d}:{:02d},{:03d}".format(hours, mins, secs, ms) + + def _format_transcript_header(self, lines): + return "\n\n".join(lines) + "\n" + + def _format_transcript_helper(self, i, time_text, line): + return "{}\n{}\n{}".format(i + 1, time_text, line['text']) + + +class WebVTTFormatter(_TextBasedFormatter): + def _format_timestamp(self, hours, mins, secs, ms): + return "{:02d}:{:02d}:{:02d}.{:03d}".format(hours, mins, secs, ms) + + def _format_transcript_header(self, lines): return "WEBVTT\n\n" + "\n\n".join(lines) + "\n" - def format_transcripts(self, transcripts, **kwargs): - """A basic implementation of WEBVTT formatting for a list of transcripts. - - :param transcripts: - :reference: https://www.w3.org/TR/webvtt1/#introduction-caption - """ - return '\n\n\n'.join([self.format_transcript(transcript, **kwargs) for transcript in transcripts]) + def _format_transcript_helper(self, i, time_text, line): + return "{}\n{}".format(time_text, line['text']) class FormatterLoader(object): @@ -141,6 +163,7 @@ class FormatterLoader(object): 'pretty': PrettyPrintFormatter, 'text': TextFormatter, 'webvtt': WebVTTFormatter, + 'srt' : SRTFormatter, } class UnknownFormatterType(Exception): diff --git a/youtube_transcript_api/test/test_formatters.py b/youtube_transcript_api/test/test_formatters.py index 748ed02..b0b3ba2 100644 --- a/youtube_transcript_api/test/test_formatters.py +++ b/youtube_transcript_api/test/test_formatters.py @@ -8,6 +8,7 @@ from youtube_transcript_api.formatters import ( Formatter, JSONFormatter, TextFormatter, + SRTFormatter, WebVTTFormatter, PrettyPrintFormatter, FormatterLoader ) @@ -28,6 +29,38 @@ class TestFormatters(TestCase): with self.assertRaises(NotImplementedError): Formatter().format_transcripts([self.transcript]) + def test_srt_formatter_starting(self): + content = SRTFormatter().format_transcript(self.transcript) + lines = content.split('\n') + + # test starting lines + self.assertEqual(lines[0], "1") + self.assertEqual(lines[1], "00:00:00,000 --> 00:00:01,500") + + def test_srt_formatter_middle(self): + content = SRTFormatter().format_transcript(self.transcript) + lines = content.split('\n') + + # test middle lines + self.assertEqual(lines[4], "2") + self.assertEqual(lines[5], "00:00:01,500 --> 00:00:02,500") + self.assertEqual(lines[6], self.transcript[1]['text']) + + def test_srt_formatter_ending(self): + content = SRTFormatter().format_transcript(self.transcript) + lines = content.split('\n') + + # test ending lines + self.assertEqual(lines[-2], self.transcript[-1]['text']) + self.assertEqual(lines[-1], "") + + def test_srt_formatter_many(self): + formatter = SRTFormatter() + content = formatter.format_transcripts(self.transcripts) + formatted_single_transcript = formatter.format_transcript(self.transcript) + + self.assertEqual(content, formatted_single_transcript + '\n\n\n' + formatted_single_transcript) + def test_webvtt_formatter_starting(self): content = WebVTTFormatter().format_transcript(self.transcript) lines = content.split('\n')