diff --git a/src/smolagents/memory.py b/src/smolagents/memory.py index 5b38e48..80c8c7d 100644 --- a/src/smolagents/memory.py +++ b/src/smolagents/memory.py @@ -37,9 +37,8 @@ class ToolCall: } +@dataclass class MemoryStep: - raw: Any # This is a placeholder for the raw data that the agent logs - def dict(self): return asdict(self) diff --git a/tests/test_memory.py b/tests/test_memory.py index b2f2ffc..0362272 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -1,7 +1,10 @@ +import pytest + from smolagents.memory import ( ActionStep, AgentMemory, ChatMessage, + MemoryStep, Message, MessageRole, PlanningStep, @@ -18,6 +21,21 @@ class TestAgentMemory: assert memory.steps == [] +class TestMemoryStep: + def test_initialization(self): + step = MemoryStep() + assert isinstance(step, MemoryStep) + + def test_dict(self): + step = MemoryStep() + assert step.dict() == {} + + def test_to_messages(self): + step = MemoryStep() + with pytest.raises(NotImplementedError): + step.to_messages() + + def test_action_step_to_messages(): action_step = ActionStep( model_input_messages=[Message(role=MessageRole.USER, content="Hello")],