From c4bd41d39c5eb4a0e1dfb32337d6e2bb6272fb1a Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Fri, 7 Feb 2025 10:56:42 +0100 Subject: [PATCH] Test HfApiModel call with custom_role_conversions (#517) --- tests/test_models.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 4369844..b73e283 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -17,13 +17,13 @@ import os import unittest from pathlib import Path from typing import Optional -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from transformers.testing_utils import get_tests_dir from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool -from smolagents.models import get_clean_message_list, parse_json_if_needed +from smolagents.models import MessageRole, get_clean_message_list, parse_json_if_needed class ModelTests(unittest.TestCase): @@ -103,6 +103,19 @@ class ModelTests(unittest.TestCase): assert parsed_args == 3 +class TestHfApiModel: + def test_call_with_custom_role_conversions(self): + custom_role_conversions = {MessageRole.USER: MessageRole.SYSTEM} + model = HfApiModel(model_id="test-model", custom_role_conversions=custom_role_conversions) + model.client = MagicMock() + messages = [{"role": "user", "content": "Test message"}] + _ = model(messages) + # Verify that the role conversion was applied + assert model.client.chat_completion.call_args.kwargs["messages"][0]["role"] == "system", ( + "role conversion should be applied" + ) + + def test_get_clean_message_list_basic(): messages = [ {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},