Improve inference choice examples (#311)

* Improve inference choice examples
* Fix style

---------

Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
parent 0196dc7b21
commit de7b0ee799
@@ -0,0 +1,51 @@
+from typing import Optional
+
+from smolagents import HfApiModel, LiteLLMModel, TransformersModel, tool
+from smolagents.agents import CodeAgent, ToolCallingAgent
+
+
+# Choose which inference type to use!
+
+available_inferences = ["hf_api", "transformers", "ollama", "litellm"]
+chosen_inference = "transformers"
+
+print(f"Chosen inference: {chosen_inference}")
+
+if chosen_inference == "hf_api":
+    model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
+
+elif chosen_inference == "transformers":
+    model = TransformersModel(model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct", device_map="auto", max_new_tokens=1000)
+
+elif chosen_inference == "ollama":
+    model = LiteLLMModel(
+        model_id="ollama_chat/llama3.2",
+        api_base="http://localhost:11434",  # replace with a remote OpenAI-compatible server if necessary
+        api_key="your-api-key",  # replace with your API key if necessary
+    )
+
+elif chosen_inference == "litellm":
+    # For Anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-latest'
+    model = LiteLLMModel(model_id="gpt-4o")
+
+
+@tool
+def get_weather(location: str, celsius: Optional[bool] = False) -> str:
+    """
+    Get the weather in the next days at the given location.
+    Secretly this tool does not care about the location, it hates the weather everywhere.
+
+    Args:
+        location: the location
+        celsius: whether to return the temperature in Celsius
+    """
+    return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
+
+
+agent = ToolCallingAgent(tools=[get_weather], model=model)
+
+print("ToolCallingAgent:", agent.run("What's the weather like in Paris?"))
+
+agent = CodeAgent(tools=[get_weather], model=model)
+
+print("CodeAgent:", agent.run("What's the weather like in Paris?"))
@@ -1,30 +0,0 @@
-from typing import Optional
-
-from smolagents import LiteLLMModel, tool
-from smolagents.agents import ToolCallingAgent
-
-
-# Choose which LLM engine to use!
-# model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
-# model = TransformersModel(model_id="meta-llama/Llama-3.2-2B-Instruct")
-
-# For anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-20240620'
-model = LiteLLMModel(model_id="gpt-4o")
-
-
-@tool
-def get_weather(location: str, celsius: Optional[bool] = False) -> str:
-    """
-    Get weather in the next days at given location.
-    Secretly this tool does not care about the location, it hates the weather everywhere.
-
-    Args:
-        location: the location
-        celsius: the temperature
-    """
-    return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
-
-
-agent = ToolCallingAgent(tools=[get_weather], model=model)
-
-print(agent.run("What's the weather like in Paris?"))
@@ -1,29 +0,0 @@
-"""An example of loading a ToolCollection directly from an MCP server.
-
-Requirements: to run this example, you need to have uv installed and in your path in
-order to run the MCP server with uvx see `mcp_server_params` below.
-
-Note this is just a demo MCP server that was implemented for the purpose of this example.
-It only provide a single tool to search amongst pubmed papers abstracts.
-
-Usage:
->>> uv run examples/tool_calling_agent_mcp.py
-"""
-
-import os
-
-from mcp import StdioServerParameters
-
-from smolagents import CodeAgent, HfApiModel, ToolCollection
-
-
-mcp_server_params = StdioServerParameters(
-    command="uvx",
-    args=["--quiet", "pubmedmcp@0.1.3"],
-    env={"UV_PYTHON": "3.12", **os.environ},
-)
-
-with ToolCollection.from_mcp(mcp_server_params) as tool_collection:
-    # print(tool_collection.tools[0](request={"term": "efficient treatment hangover"}))
-    agent = CodeAgent(tools=tool_collection.tools, model=HfApiModel(), max_steps=4)
-    agent.run("Find me one risk associated with drinking alcohol regularly on low doses for humans.")
@@ -1,29 +0,0 @@
-from typing import Optional
-
-from smolagents import LiteLLMModel, tool
-from smolagents.agents import ToolCallingAgent
-
-
-model = LiteLLMModel(
-    model_id="ollama_chat/llama3.2",
-    api_base="http://localhost:11434",  # replace with remote open-ai compatible server if necessary
-    api_key="your-api-key",  # replace with API key if necessary
-)
-
-
-@tool
-def get_weather(location: str, celsius: Optional[bool] = False) -> str:
-    """
-    Get weather in the next days at given location.
-    Secretly this tool does not care about the location, it hates the weather everywhere.
-
-    Args:
-        location: the location
-        celsius: the temperature
-    """
-    return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
-
-
-agent = ToolCallingAgent(tools=[get_weather], model=model)
-
-print(agent.run("What's the weather like in Paris?"))
@@ -480,7 +480,6 @@ class TransformersModel(Model):
             messages=messages,
             stop_sequences=stop_sequences,
             grammar=grammar,
-            tools_to_call_from=tools_to_call_from,
             **kwargs,
         )
 
@@ -497,9 +496,6 @@ class TransformersModel(Model):
         if max_new_tokens:
             completion_kwargs["max_new_tokens"] = max_new_tokens
 
-        if stop_sequences:
-            completion_kwargs["stopping_criteria"] = self.make_stopping_criteria(stop_sequences)
-
         if tools_to_call_from is not None:
             prompt_tensor = self.tokenizer.apply_chat_template(
                 messages,
@@ -518,7 +514,11 @@ class TransformersModel(Model):
         prompt_tensor = prompt_tensor.to(self.model.device)
         count_prompt_tokens = prompt_tensor["input_ids"].shape[1]
 
-        out = self.model.generate(**prompt_tensor, **completion_kwargs)
+        out = self.model.generate(
+            **prompt_tensor,
+            stopping_criteria=(self.make_stopping_criteria(stop_sequences) if stop_sequences else None),
+            **completion_kwargs,
+        )
         generated_tokens = out[0, count_prompt_tokens:]
         output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
         self.last_input_token_count = count_prompt_tokens
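The last hunk moves stop-sequence handling out of `completion_kwargs` and into the `generate()` call itself, so `stopping_criteria` is always passed explicitly (or as `None`). For readers unfamiliar with the mechanism, a stop-sequence criterion can be built on the standard `transformers` StoppingCriteria interface roughly as below. This is a minimal sketch under that assumption; `StopOnStrings` and the stop strings are illustrative names, not the actual smolagents helper behind `make_stopping_criteria`:

# Illustrative sketch of a stop-sequence criterion (not smolagents code).
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList


class StopOnStrings(StoppingCriteria):
    def __init__(self, stop_strings, tokenizer):
        self.stop_strings = stop_strings
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Decode the sequence generated so far and stop once any stop
        # string appears at the end of the text.
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return any(text.endswith(stop) for stop in self.stop_strings)


tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct")
# Stop strings here are placeholders; agents would pass their own stop_sequences.
stopping_criteria = StoppingCriteriaList([StopOnStrings(["Observation:"], tokenizer)])
# model.generate(**prompt_tensor, stopping_criteria=stopping_criteria, ...)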