diff --git a/youtube_transcript_api/formatters.py b/youtube_transcript_api/formatters.py index a5f30cd..54f7815 100644 --- a/youtube_transcript_api/formatters.py +++ b/youtube_transcript_api/formatters.py @@ -1,13 +1,9 @@ -from abc import ABCMeta +from abc import ABC from abc import abstractclassmethod from collections import defaultdict import json import re -from xml.etree import ElementTree - -from ._html_unescaping import unescape - def parse_timecode(time): """Converts a `time` into a formatted transcript timecode. @@ -31,23 +27,37 @@ def parse_timecode(time): return f"{hours}:{mins}:{secs},{ms}" -class TranscriptFormatter(metaclass=ABCMeta): - """ - Abstract Base TranscriptFormatter class +class TranscriptFormatter(ABC): + """Abstract Base TranscriptFormatter class This class should be inherited from to create additional custom transcript formatters. - """ HTML_TAG_REGEX = re.compile(r'<[^>]*>', re.IGNORECASE) - + DELIMITER = '' + + @classmethod + def combine(cls, transcripts): + """Subclass may override this class method. + + Default behavior of this method will ''.join() the str() + of each transcript in transcripts. + + :param transcripts: a list of many transcripts + :type transcript_data: list[, ...] + :return: A string joined on the `cls.DELIMITER` to combine transcripts + :rtype: str + """ + return cls.DELIMITER.join( + str(transcript) for transcript in transcripts) + @abstractclassmethod def format(cls, transcript_data): """Any subclass must implement this format class method. :param transcript_data: a list of transcripts, 1 or more. :type transcript_data: list[list[dict], list[dict]] - :return: A list where each item is an individual transcript + :return: A list where each item is an individual transcript as a string. :rtype: list[str] """ @@ -56,9 +66,15 @@ class TranscriptFormatter(metaclass=ABCMeta): class JSONTranscriptFormatter(TranscriptFormatter): """Formatter for outputting JSON data""" + DELIMITER = ',' + + @classmethod + def combine(cls, transcripts): + return json.dumps(transcripts) + @classmethod def format(cls, transcript_data): - return [json.dumps(transcript_data)] if transcript_data else [] + return transcript_data class TextTranscriptFormatter(TranscriptFormatter): @@ -66,55 +82,56 @@ class TextTranscriptFormatter(TranscriptFormatter): Converts the fetched transcript data into separated lines of plain text separated by newline breaks (\n) with no timecodes. - """ + DELIMITER = '\n\n' + @classmethod def format(cls, transcript_data): - return ['\n'.join(line['text'] for transcript in transcript_data - for line in transcript)] + return '{}\n'.format('\n'.join( + line['text']for line in transcript_data)) class SRTTranscriptFormatter(TranscriptFormatter): """Formatter for outputting the SRT Format - - Converts the fetched transcript data into a simple .srt file format. + Converts the fetched transcript data into a simple .srt file format. """ + DELIMITER = '\n\n' + @classmethod def format(cls, transcript_data): - contents = [] - for transcript in transcript_data: - content = [] - for frame, item in enumerate(transcript, start=1): - start_time = float(item.get('start')) - duration = float(item.get('dur', '0.0')) + output = [] + for frame, item in enumerate(transcript_data, start=1): + start_time = float(item.get('start')) + duration = float(item.get('dur', '0.0')) - end_time = parse_timecode(start_time + duration) - start_time = parse_timecode(start_time) + end_time = parse_timecode(start_time + duration) + start_time = parse_timecode(start_time) - content.append("{frame}\n".format(frame=frame)) - content.append("{start_time} --> {end_time}\n".format( - start_time=start_time, end_time=end_time)) - content.append("{text}\n\n".format(text=item.get('text'))) - - contents.append(''.join(content)) - return ['\n\n'.join(contents)] + output.append("{frame}\n".format(frame=frame)) + output.append("{start_time} --> {end_time}\n".format( + start_time=start_time, end_time=end_time)) + output.append("{text}".format(text=item.get('text'))) + if frame < len(transcript_data): + output.append('\n\n') + + return '{}\n'.format(''.join(output)) class TranscriptFormatterFactory: """A Transcript Class Factory - + Allows for adding additional custom Transcript classes for the API - to use. Custom Transcript classes must inherit from the + to use. Custom Transcript classes must inherit from the TranscriptFormatter abstract base class. """ def __init__(self): self._formatters = defaultdict(JSONTranscriptFormatter) - + def add_formatter(self, name, formatter_class): """Allows for creating additional transcript formatters. - + :param name: a name given to the `formatter_class` :type name: str :param formatter_class: a subclass of TranscriptFormatter @@ -124,8 +141,8 @@ class TranscriptFormatterFactory: if not issubclass(formatter_class, TranscriptFormatter): raise TypeError( f'{formatter_class} must be a subclass of TranscriptFormatter') - self._formatters.update({name:formatter_class}) - + self._formatters.update({name: formatter_class}) + def add_formatters(self, formatters_dict): """Allow creation of multiple transcript formatters at a time. @@ -137,7 +154,7 @@ class TranscriptFormatterFactory: """ for name, formatter_class in formatters_dict.items(): self.add_formatter(name, formatter_class) - + def get_formatter(self, name): """Retrieve a formatter class by its assigned name.