Add tool calling agent example
This commit is contained in:
		
							parent
							
								
									30cb6111b3
								
							
						
					
					
						commit
						32d7bc5e06
					
				|  | @ -0,0 +1,22 @@ | ||||||
|  | from agents.agents import ToolCallingAgent | ||||||
|  | from agents import tool, HfApiEngine, OpenAIEngine, AnthropicEngine | ||||||
|  | 
 | ||||||
|  | # Choose which LLM engine to use! | ||||||
|  | llm_engine = OpenAIEngine("gpt-4o") | ||||||
|  | llm_engine = AnthropicEngine() | ||||||
|  | llm_engine = HfApiEngine("meta-llama/Llama-3.3-70B-Instruct") | ||||||
|  | 
 | ||||||
|  | @tool | ||||||
|  | def get_weather(location: str) -> str: | ||||||
|  |     """ | ||||||
|  |     Get weather in the next days at given location. | ||||||
|  |     Secretly this tool does not care about the location, it hates the weather everywhere. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         location: the location | ||||||
|  |     """ | ||||||
|  |     return "The weather is UNGODLY with torrential rains and temperatures below -10°C" | ||||||
|  | 
 | ||||||
|  | agent = ToolCallingAgent(tools=[get_weather], llm_engine=llm_engine) | ||||||
|  | 
 | ||||||
|  | print(agent.run("What's the weather like in Paris?")) | ||||||
|  | @ -40,7 +40,7 @@ class DockerPythonInterpreter: | ||||||
|         Execute Python code in the container and return stdout and stderr |         Execute Python code in the container and return stdout and stderr | ||||||
|         """ |         """ | ||||||
| 
 | 
 | ||||||
|         if tools != None: |         if tools is not None: | ||||||
|             tool_instance = tools[0]() |             tool_instance = tools[0]() | ||||||
| 
 | 
 | ||||||
|             import_code = f""" |             import_code = f""" | ||||||
|  |  | ||||||
|  | @ -50,7 +50,7 @@ class MessageRole(str, Enum): | ||||||
|         return [r.value for r in cls] |         return [r.value for r in cls] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| llama_role_conversions = { | tool_role_conversions = { | ||||||
|     MessageRole.TOOL_CALL: MessageRole.ASSISTANT, |     MessageRole.TOOL_CALL: MessageRole.ASSISTANT, | ||||||
|     MessageRole.TOOL_RESPONSE: MessageRole.USER, |     MessageRole.TOOL_RESPONSE: MessageRole.USER, | ||||||
| } | } | ||||||
|  | @ -232,7 +232,7 @@ class HfApiEngine(HfEngine): | ||||||
|     ) -> str: |     ) -> str: | ||||||
|         """Generates a text completion for the given message list""" |         """Generates a text completion for the given message list""" | ||||||
|         messages = get_clean_message_list( |         messages = get_clean_message_list( | ||||||
|             messages, role_conversions=llama_role_conversions |             messages, role_conversions=tool_role_conversions | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         # Send messages to the Hugging Face Inference API |         # Send messages to the Hugging Face Inference API | ||||||
|  | @ -260,7 +260,7 @@ class HfApiEngine(HfEngine): | ||||||
|     ): |     ): | ||||||
|         """Generates a tool call for the given message list""" |         """Generates a tool call for the given message list""" | ||||||
|         messages = get_clean_message_list( |         messages = get_clean_message_list( | ||||||
|             messages, role_conversions=llama_role_conversions |             messages, role_conversions=tool_role_conversions | ||||||
|         ) |         ) | ||||||
|         response = self.client.chat.completions.create( |         response = self.client.chat.completions.create( | ||||||
|             messages=messages, |             messages=messages, | ||||||
|  | @ -302,7 +302,7 @@ class TransformersEngine(HfEngine): | ||||||
|         max_tokens: int = 1500, |         max_tokens: int = 1500, | ||||||
|     ) -> str: |     ) -> str: | ||||||
|         messages = get_clean_message_list( |         messages = get_clean_message_list( | ||||||
|             messages, role_conversions=llama_role_conversions |             messages, role_conversions=tool_role_conversions | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         # Get LLM output |         # Get LLM output | ||||||
|  | @ -360,7 +360,7 @@ class OpenAIEngine: | ||||||
|         max_tokens: int = 1500, |         max_tokens: int = 1500, | ||||||
|     ) -> str: |     ) -> str: | ||||||
|         messages = get_clean_message_list( |         messages = get_clean_message_list( | ||||||
|             messages, role_conversions=openai_role_conversions |             messages, role_conversions=tool_role_conversions | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         response = self.client.chat.completions.create( |         response = self.client.chat.completions.create( | ||||||
|  | @ -381,7 +381,7 @@ class OpenAIEngine: | ||||||
|     ): |     ): | ||||||
|         """Generates a tool call for the given message list""" |         """Generates a tool call for the given message list""" | ||||||
|         messages = get_clean_message_list( |         messages = get_clean_message_list( | ||||||
|             messages, role_conversions=llama_role_conversions |             messages, role_conversions=tool_role_conversions | ||||||
|         ) |         ) | ||||||
|         response = self.client.chat.completions.create( |         response = self.client.chat.completions.create( | ||||||
|             model=self.model_name, |             model=self.model_name, | ||||||
|  | @ -448,7 +448,7 @@ class AnthropicEngine: | ||||||
|         max_tokens: int = 1500, |         max_tokens: int = 1500, | ||||||
|     ) -> str: |     ) -> str: | ||||||
|         messages = get_clean_message_list( |         messages = get_clean_message_list( | ||||||
|             messages, role_conversions=llama_role_conversions |             messages, role_conversions=tool_role_conversions | ||||||
|         ) |         ) | ||||||
|         filtered_messages, system_prompt = self.separate_messages_system_prompt( |         filtered_messages, system_prompt = self.separate_messages_system_prompt( | ||||||
|             messages |             messages | ||||||
|  | @ -475,7 +475,7 @@ class AnthropicEngine: | ||||||
|     ): |     ): | ||||||
|         """Generates a tool call for the given message list""" |         """Generates a tool call for the given message list""" | ||||||
|         messages = get_clean_message_list( |         messages = get_clean_message_list( | ||||||
|             messages, role_conversions=llama_role_conversions |             messages, role_conversions=tool_role_conversions | ||||||
|         ) |         ) | ||||||
|         filtered_messages, system_prompt = self.separate_messages_system_prompt( |         filtered_messages, system_prompt = self.separate_messages_system_prompt( | ||||||
|             messages |             messages | ||||||
|  | @ -496,7 +496,7 @@ class AnthropicEngine: | ||||||
| 
 | 
 | ||||||
| __all__ = [ | __all__ = [ | ||||||
|     "MessageRole", |     "MessageRole", | ||||||
|     "llama_role_conversions", |     "tool_role_conversions", | ||||||
|     "get_clean_message_list", |     "get_clean_message_list", | ||||||
|     "HfEngine", |     "HfEngine", | ||||||
|     "TransformersEngine", |     "TransformersEngine", | ||||||
|  |  | ||||||
|  | @ -232,7 +232,7 @@ Action: | ||||||
| 
 | 
 | ||||||
|     def test_additional_args_added_to_task(self): |     def test_additional_args_added_to_task(self): | ||||||
|         agent = CodeAgent(tools=[], llm_engine=fake_code_llm) |         agent = CodeAgent(tools=[], llm_engine=fake_code_llm) | ||||||
|         output = agent.run( |         agent.run( | ||||||
|             "What is 2 multiplied by 3.6452?", additional_instruction="Remember this." |             "What is 2 multiplied by 3.6452?", additional_instruction="Remember this." | ||||||
|         ) |         ) | ||||||
|         assert "Remember this" in agent.task |         assert "Remember this" in agent.task | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue