Test get_clean_message_list (#448)
* Test get_clean_message_list * Test get_clean_message_list
This commit is contained in:
		
							parent
							
								
									b228ffa328
								
							
						
					
					
						commit
						cedf63cde7
					
				|  | @ -209,16 +209,18 @@ def get_clean_message_list( | ||||||
|             message["role"] = role_conversions[role] |             message["role"] = role_conversions[role] | ||||||
|         # encode images if needed |         # encode images if needed | ||||||
|         if isinstance(message["content"], list): |         if isinstance(message["content"], list): | ||||||
|             for i, element in enumerate(message["content"]): |             for element in message["content"]: | ||||||
|                 if element["type"] == "image": |                 if element["type"] == "image": | ||||||
|                     assert not flatten_messages_as_text, f"Cannot use images with {flatten_messages_as_text=}" |                     assert not flatten_messages_as_text, f"Cannot use images with {flatten_messages_as_text=}" | ||||||
|                     if convert_images_to_image_urls: |                     if convert_images_to_image_urls: | ||||||
|                         message["content"][i] = { |                         element.update( | ||||||
|                             "type": "image_url", |                             { | ||||||
|                             "image_url": {"url": make_image_url(encode_image_base64(element["image"]))}, |                                 "type": "image_url", | ||||||
|                         } |                                 "image_url": {"url": make_image_url(encode_image_base64(element.pop("image")))}, | ||||||
|  |                             } | ||||||
|  |                         ) | ||||||
|                     else: |                     else: | ||||||
|                         message["content"][i]["image"] = encode_image_base64(element["image"]) |                         element["image"] = encode_image_base64(element["image"]) | ||||||
| 
 | 
 | ||||||
|         if len(output_message_list) > 0 and message["role"] == output_message_list[-1]["role"]: |         if len(output_message_list) > 0 and message["role"] == output_message_list[-1]["role"]: | ||||||
|             assert isinstance(message["content"], list), "Error: wrong content:" + str(message["content"]) |             assert isinstance(message["content"], list), "Error: wrong content:" + str(message["content"]) | ||||||
|  |  | ||||||
|  | @ -17,12 +17,13 @@ import os | ||||||
| import unittest | import unittest | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import Optional | from typing import Optional | ||||||
|  | from unittest.mock import patch | ||||||
| 
 | 
 | ||||||
| import pytest | import pytest | ||||||
| from transformers.testing_utils import get_tests_dir | from transformers.testing_utils import get_tests_dir | ||||||
| 
 | 
 | ||||||
| from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool | from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool | ||||||
| from smolagents.models import parse_json_if_needed | from smolagents.models import get_clean_message_list, parse_json_if_needed | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ModelTests(unittest.TestCase): | class ModelTests(unittest.TestCase): | ||||||
|  | @ -100,3 +101,81 @@ class ModelTests(unittest.TestCase): | ||||||
|         args = 3 |         args = 3 | ||||||
|         parsed_args = parse_json_if_needed(args) |         parsed_args = parse_json_if_needed(args) | ||||||
|         assert parsed_args == 3 |         assert parsed_args == 3 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_get_clean_message_list_basic(): | ||||||
|  |     messages = [ | ||||||
|  |         {"role": "user", "content": [{"type": "text", "text": "Hello!"}]}, | ||||||
|  |         {"role": "assistant", "content": [{"type": "text", "text": "Hi there!"}]}, | ||||||
|  |     ] | ||||||
|  |     result = get_clean_message_list(messages) | ||||||
|  |     assert len(result) == 2 | ||||||
|  |     assert result[0]["role"] == "user" | ||||||
|  |     assert result[0]["content"][0]["text"] == "Hello!" | ||||||
|  |     assert result[1]["role"] == "assistant" | ||||||
|  |     assert result[1]["content"][0]["text"] == "Hi there!" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_get_clean_message_list_role_conversions(): | ||||||
|  |     messages = [ | ||||||
|  |         {"role": "tool-call", "content": [{"type": "text", "text": "Calling tool..."}]}, | ||||||
|  |         {"role": "tool-response", "content": [{"type": "text", "text": "Tool response"}]}, | ||||||
|  |     ] | ||||||
|  |     result = get_clean_message_list(messages, role_conversions={"tool-call": "assistant", "tool-response": "user"}) | ||||||
|  |     assert len(result) == 2 | ||||||
|  |     assert result[0]["role"] == "assistant" | ||||||
|  |     assert result[0]["content"][0]["text"] == "Calling tool..." | ||||||
|  |     assert result[1]["role"] == "user" | ||||||
|  |     assert result[1]["content"][0]["text"] == "Tool response" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.mark.parametrize( | ||||||
|  |     "convert_images_to_image_urls, expected_clean_message", | ||||||
|  |     [ | ||||||
|  |         ( | ||||||
|  |             False, | ||||||
|  |             { | ||||||
|  |                 "role": "user", | ||||||
|  |                 "content": [ | ||||||
|  |                     {"type": "image", "image": "encoded_image"}, | ||||||
|  |                     {"type": "image", "image": "second_encoded_image"}, | ||||||
|  |                 ], | ||||||
|  |             }, | ||||||
|  |         ), | ||||||
|  |         ( | ||||||
|  |             True, | ||||||
|  |             { | ||||||
|  |                 "role": "user", | ||||||
|  |                 "content": [ | ||||||
|  |                     {"type": "image_url", "image_url": {"url": "data:image/png;base64,encoded_image"}}, | ||||||
|  |                     {"type": "image_url", "image_url": {"url": "data:image/png;base64,second_encoded_image"}}, | ||||||
|  |                 ], | ||||||
|  |             }, | ||||||
|  |         ), | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | def test_get_clean_message_list_image_encoding(convert_images_to_image_urls, expected_clean_message): | ||||||
|  |     messages = [ | ||||||
|  |         { | ||||||
|  |             "role": "user", | ||||||
|  |             "content": [{"type": "image", "image": b"image_data"}, {"type": "image", "image": b"second_image_data"}], | ||||||
|  |         } | ||||||
|  |     ] | ||||||
|  |     with patch("smolagents.models.encode_image_base64") as mock_encode: | ||||||
|  |         mock_encode.side_effect = ["encoded_image", "second_encoded_image"] | ||||||
|  |         result = get_clean_message_list(messages, convert_images_to_image_urls=convert_images_to_image_urls) | ||||||
|  |         mock_encode.assert_any_call(b"image_data") | ||||||
|  |         mock_encode.assert_any_call(b"second_image_data") | ||||||
|  |         assert len(result) == 1 | ||||||
|  |         assert result[0] == expected_clean_message | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_get_clean_message_list_flatten_messages_as_text(): | ||||||
|  |     messages = [ | ||||||
|  |         {"role": "user", "content": [{"type": "text", "text": "Hello!"}]}, | ||||||
|  |         {"role": "user", "content": [{"type": "text", "text": "How are you?"}]}, | ||||||
|  |     ] | ||||||
|  |     result = get_clean_message_list(messages, flatten_messages_as_text=True) | ||||||
|  |     assert len(result) == 1 | ||||||
|  |     assert result[0]["role"] == "user" | ||||||
|  |     assert result[0]["content"] == "Hello!How are you?" | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue