Redid errors for cookies, improved testing coverage
This commit is contained in:
		
							parent
							
								
									49ccba7f95
								
							
						
					
					
						commit
						31b8f4a179
					
				|  | @ -8,4 +8,6 @@ from ._errors import ( | ||||||
|     NotTranslatable, |     NotTranslatable, | ||||||
|     TranslationLanguageNotAvailable, |     TranslationLanguageNotAvailable, | ||||||
|     NoTranscriptAvailable, |     NoTranscriptAvailable, | ||||||
|  |     CookiePathInvalid, | ||||||
|  |     CookiesInvalid | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | @ -1,8 +1,10 @@ | ||||||
| import requests | import requests | ||||||
| try: | try: | ||||||
|     import http.cookiejar as cookiejar |     import http.cookiejar as cookiejar | ||||||
|  |     CookieLoadError = (FileNotFoundError, cookiejar.LoadError) | ||||||
| except ImportError: | except ImportError: | ||||||
|     import cookielib as cookiejar |     import cookielib as cookiejar | ||||||
|  |     CookieLoadError = IOError | ||||||
| 
 | 
 | ||||||
| from ._transcripts import TranscriptListFetcher | from ._transcripts import TranscriptListFetcher | ||||||
| 
 | 
 | ||||||
|  | @ -63,7 +65,7 @@ class YouTubeTranscriptApi(): | ||||||
|         """ |         """ | ||||||
|         with requests.Session() as http_client: |         with requests.Session() as http_client: | ||||||
|             if cookies: |             if cookies: | ||||||
|                 http_client.cookies = cls.load_cookies(cookies) |                 http_client.cookies = cls._load_cookies(cookies, video_id) | ||||||
|             http_client.proxies = proxies if proxies else {} |             http_client.proxies = proxies if proxies else {} | ||||||
|             return TranscriptListFetcher(http_client).fetch(video_id) |             return TranscriptListFetcher(http_client).fetch(video_id) | ||||||
| 
 | 
 | ||||||
|  | @ -126,15 +128,13 @@ class YouTubeTranscriptApi(): | ||||||
|         return cls.list_transcripts(video_id, proxies, cookies).find_transcript(languages).fetch() |         return cls.list_transcripts(video_id, proxies, cookies).find_transcript(languages).fetch() | ||||||
|      |      | ||||||
|     @classmethod |     @classmethod | ||||||
|     def load_cookies(cls, cookies): |     def _load_cookies(cls, cookies, video_id): | ||||||
|         cj = {} |         cookie_jar = {} | ||||||
|         try: |         try: | ||||||
|             cj = cookiejar.MozillaCookieJar() |             cookie_jar = cookiejar.MozillaCookieJar() | ||||||
|             cj.load(cookies) |             cookie_jar.load(cookies) | ||||||
|         except IOError as e: |         except CookieLoadError: | ||||||
|             raise CookiePathInvalid |             raise CookiePathInvalid(video_id) | ||||||
|         except FileNotFoundError as e: |         if not cookie_jar: | ||||||
|             raise CookiePathInvalid |             raise CookiesInvalid(video_id) | ||||||
|         if not cj: |         return cookie_jar  | ||||||
|             raise CookiesInvalid |  | ||||||
|         return cj  |  | ||||||
|  |  | ||||||
|  | @ -56,7 +56,7 @@ class TranslationLanguageNotAvailable(CouldNotRetrieveTranscript): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class CookiePathInvalid(CouldNotRetrieveTranscript): | class CookiePathInvalid(CouldNotRetrieveTranscript): | ||||||
|     CAUSE_MESSAGE = 'Path to cookie file was not valid' |     CAUSE_MESSAGE = 'The provided cookie file was unable to be loaded' | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class CookiesInvalid(CouldNotRetrieveTranscript): | class CookiesInvalid(CouldNotRetrieveTranscript): | ||||||
|  |  | ||||||
|  | @ -15,6 +15,8 @@ from youtube_transcript_api import ( | ||||||
|     NoTranscriptAvailable, |     NoTranscriptAvailable, | ||||||
|     NotTranslatable, |     NotTranslatable, | ||||||
|     TranslationLanguageNotAvailable, |     TranslationLanguageNotAvailable, | ||||||
|  |     CookiePathInvalid, | ||||||
|  |     CookiesInvalid | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -167,6 +169,20 @@ class TestYouTubeTranscriptApi(TestCase): | ||||||
|             ] |             ] | ||||||
|         ) |         ) | ||||||
|      |      | ||||||
|  |     def test_get_transcript__with_cookies(self): | ||||||
|  |         dirname, filename = os.path.split(os.path.abspath(__file__)) | ||||||
|  |         cookies = dirname + '/example_cookies.txt' | ||||||
|  |         transcript = YouTubeTranscriptApi.get_transcript('GJLlxj_dtq8', cookies=cookies) | ||||||
|  | 
 | ||||||
|  |         self.assertEqual( | ||||||
|  |             transcript, | ||||||
|  |             [ | ||||||
|  |                 {'text': 'Hey, this is just a test', 'start': 0.0, 'duration': 1.54}, | ||||||
|  |                 {'text': 'this is not the original transcript', 'start': 1.54, 'duration': 4.16}, | ||||||
|  |                 {'text': 'just something shorter, I made up for testing', 'start': 5.7, 'duration': 3.239} | ||||||
|  |             ] | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|     @patch('youtube_transcript_api.YouTubeTranscriptApi.get_transcript') |     @patch('youtube_transcript_api.YouTubeTranscriptApi.get_transcript') | ||||||
|     def test_get_transcripts(self, mock_get_transcript): |     def test_get_transcripts(self, mock_get_transcript): | ||||||
|         video_id_1 = 'video_id_1' |         video_id_1 = 'video_id_1' | ||||||
|  | @ -209,15 +225,16 @@ class TestYouTubeTranscriptApi(TestCase): | ||||||
|     def test_load_cookies(self): |     def test_load_cookies(self): | ||||||
|         dirname, filename = os.path.split(os.path.abspath(__file__)) |         dirname, filename = os.path.split(os.path.abspath(__file__)) | ||||||
|         cookies = dirname + '/example_cookies.txt' |         cookies = dirname + '/example_cookies.txt' | ||||||
|         session_cookies = YouTubeTranscriptApi.load_cookies(cookies) |         session_cookies = YouTubeTranscriptApi._load_cookies(cookies, 'GJLlxj_dtq8') | ||||||
|         self.assertEqual({'TEST_FIELD': 'TEST_VALUE'},  requests.utils.dict_from_cookiejar(session_cookies)) |         self.assertEqual({'TEST_FIELD': 'TEST_VALUE'},  requests.utils.dict_from_cookiejar(session_cookies)) | ||||||
| 
 | 
 | ||||||
|     def test_load_cookies__bad_files(self): |     def test_load_cookies__bad_file_path(self): | ||||||
|         bad_cookies = 'nonexistent_cookies.txt' |         bad_cookies = 'nonexistent_cookies.txt' | ||||||
|         with self.assertRaises(Exception): |         with self.assertRaises(CookiePathInvalid): | ||||||
|             YouTubeTranscriptApi.load_cookies(bad_cookies) |             YouTubeTranscriptApi._load_cookies(bad_cookies, 'GJLlxj_dtq8') | ||||||
| 
 | 
 | ||||||
|  |     def test_load_cookies__no_valid_cookies(self): | ||||||
|         dirname, filename = os.path.split(os.path.abspath(__file__)) |         dirname, filename = os.path.split(os.path.abspath(__file__)) | ||||||
|         expired_cookies = dirname + '/expired_example_cookies.txt' |         expired_cookies = dirname + '/expired_example_cookies.txt' | ||||||
|         with self.assertRaises(Exception): |         with self.assertRaises(CookiesInvalid): | ||||||
|             YouTubeTranscriptApi.load_cookies(expired_cookies) |             YouTubeTranscriptApi._load_cookies(expired_cookies, 'GJLlxj_dtq8') | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue