Add support for OpenTelemetry instrumentation 📊 (#200)
This commit is contained in:
		
							parent
							
								
									ce1cd6d906
								
							
						
					
					
						commit
						450934ce79
					
				|  | @ -14,7 +14,7 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| from dataclasses import dataclass | ||||
| from dataclasses import dataclass, asdict | ||||
| import json | ||||
| import logging | ||||
| import os | ||||
|  | @ -54,12 +54,29 @@ if _is_package_available("litellm"): | |||
|     import litellm | ||||
| 
 | ||||
| 
 | ||||
| def get_dict_from_nested_dataclasses(obj): | ||||
|     def convert(obj): | ||||
|         if hasattr(obj, "__dataclass_fields__"): | ||||
|             return {k: convert(v) for k, v in asdict(obj).items()} | ||||
|         return obj | ||||
| 
 | ||||
|     return convert(obj) | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
| class ChatMessageToolCallDefinition: | ||||
|     arguments: Any | ||||
|     name: str | ||||
|     description: Optional[str] = None | ||||
| 
 | ||||
|     @classmethod | ||||
|     def from_hf_api(cls, tool_call_definition) -> "ChatMessageToolCallDefinition": | ||||
|         return cls( | ||||
|             arguments=tool_call_definition.arguments, | ||||
|             name=tool_call_definition.name, | ||||
|             description=tool_call_definition.description, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
| class ChatMessageToolCall: | ||||
|  | @ -67,6 +84,14 @@ class ChatMessageToolCall: | |||
|     id: str | ||||
|     type: str | ||||
| 
 | ||||
|     @classmethod | ||||
|     def from_hf_api(cls, tool_call) -> "ChatMessageToolCall": | ||||
|         return cls( | ||||
|             function=ChatMessageToolCallDefinition.from_hf_api(tool_call.function), | ||||
|             id=tool_call.id, | ||||
|             type=tool_call.type, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| @dataclass | ||||
| class ChatMessage: | ||||
|  | @ -74,6 +99,19 @@ class ChatMessage: | |||
|     content: Optional[str] = None | ||||
|     tool_calls: Optional[List[ChatMessageToolCall]] = None | ||||
| 
 | ||||
|     def model_dump_json(self): | ||||
|         return json.dumps(get_dict_from_nested_dataclasses(self)) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def from_hf_api(cls, message) -> "ChatMessage": | ||||
|         tool_calls = None | ||||
|         if getattr(message, "tool_calls", None) is not None: | ||||
|             tool_calls = [ | ||||
|                 ChatMessageToolCall.from_hf_api(tool_call) | ||||
|                 for tool_call in message.tool_calls | ||||
|             ] | ||||
|         return cls(role=message.role, content=message.content, tool_calls=tool_calls) | ||||
| 
 | ||||
| 
 | ||||
| class MessageRole(str, Enum): | ||||
|     USER = "user" | ||||
|  | @ -283,7 +321,7 @@ class HfApiModel(Model): | |||
|             ) | ||||
|         self.last_input_token_count = response.usage.prompt_tokens | ||||
|         self.last_output_token_count = response.usage.completion_tokens | ||||
|         return response.choices[0].message | ||||
|         return ChatMessage.from_hf_api(response.choices[0].message) | ||||
| 
 | ||||
| 
 | ||||
| class TransformersModel(Model): | ||||
|  | @ -315,14 +353,18 @@ class TransformersModel(Model): | |||
|         logger.info(f"Using device: {self.device}") | ||||
|         try: | ||||
|             self.tokenizer = AutoTokenizer.from_pretrained(model_id) | ||||
|             self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device) | ||||
|             self.model = AutoModelForCausalLM.from_pretrained( | ||||
|                 model_id, device_map=self.device | ||||
|             ) | ||||
|         except Exception as e: | ||||
|             logger.warning( | ||||
|                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}." | ||||
|             ) | ||||
|             self.model_id = default_model_id | ||||
|             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id) | ||||
|             self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device) | ||||
|             self.model = AutoModelForCausalLM.from_pretrained( | ||||
|                 model_id, device_map=self.device | ||||
|             ) | ||||
| 
 | ||||
|     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList: | ||||
|         class StopOnStrings(StoppingCriteria): | ||||
|  | @ -551,4 +593,5 @@ __all__ = [ | |||
|     "HfApiModel", | ||||
|     "LiteLLMModel", | ||||
|     "OpenAIServerModel", | ||||
|     "ChatMessage", | ||||
| ] | ||||
|  |  | |||
|  | @ -13,9 +13,10 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import unittest | ||||
| import json | ||||
| from typing import Optional | ||||
| 
 | ||||
| from smolagents import models, tool | ||||
| from smolagents import models, tool, ChatMessage, HfApiModel | ||||
| 
 | ||||
| 
 | ||||
| class ModelTests(unittest.TestCase): | ||||
|  | @ -38,3 +39,13 @@ class ModelTests(unittest.TestCase): | |||
|                 "properties" | ||||
|             ]["celsius"] | ||||
|         ) | ||||
| 
 | ||||
|     def test_chatmessage_has_model_dumps_json(self): | ||||
|         message = ChatMessage("user", "Hello!") | ||||
|         data = json.loads(message.model_dump_json()) | ||||
|         assert data["content"] == "Hello!" | ||||
| 
 | ||||
|     def test_get_hfapi_message_no_tool(self): | ||||
|         model = HfApiModel() | ||||
|         messages = [{"role": "user", "content": "Hello!"}] | ||||
|         model(messages, stop_sequences=["great"]) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue