Add support for OpenTelemetry instrumentation 📊 (#200)
This commit is contained in:
		
							parent
							
								
									ce1cd6d906
								
							
						
					
					
						commit
						450934ce79
					
				|  | @ -14,7 +14,7 @@ | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass, asdict | ||||||
| import json | import json | ||||||
| import logging | import logging | ||||||
| import os | import os | ||||||
|  | @ -54,12 +54,29 @@ if _is_package_available("litellm"): | ||||||
|     import litellm |     import litellm | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
def get_dict_from_nested_dataclasses(obj):
    """Recursively convert a (possibly nested) dataclass instance to a plain dict.

    Non-dataclass values are returned unchanged.

    Note: ``dataclasses.asdict`` already recurses into nested dataclasses,
    lists, tuples and dicts on its own, so a single call covers the whole
    tree — the original per-value re-conversion loop was dead code (after
    ``asdict`` no value is a dataclass anymore) and has been removed.
    """
    if hasattr(obj, "__dataclass_fields__"):
        return asdict(obj)
    return obj
 | ||||||
@dataclass
class ChatMessageToolCallDefinition:
    # The "function" part of a tool call: which tool to invoke, with what
    # arguments, plus an optional human-readable description.
    arguments: Any
    name: str
    description: Optional[str] = None

    @classmethod
    def from_hf_api(cls, tool_call_definition) -> "ChatMessageToolCallDefinition":
        """Build an instance from a huggingface_hub API tool-call definition object."""
        source = tool_call_definition
        return cls(source.arguments, source.name, source.description)
| 
 | 
 | ||||||
| @dataclass | @dataclass | ||||||
| class ChatMessageToolCall: | class ChatMessageToolCall: | ||||||
|  | @ -67,6 +84,14 @@ class ChatMessageToolCall: | ||||||
|     id: str |     id: str | ||||||
|     type: str |     type: str | ||||||
| 
 | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def from_hf_api(cls, tool_call) -> "ChatMessageToolCall": | ||||||
|  |         return cls( | ||||||
|  |             function=ChatMessageToolCallDefinition.from_hf_api(tool_call.function), | ||||||
|  |             id=tool_call.id, | ||||||
|  |             type=tool_call.type, | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| @dataclass | @dataclass | ||||||
| class ChatMessage: | class ChatMessage: | ||||||
|  | @ -74,6 +99,19 @@ class ChatMessage: | ||||||
|     content: Optional[str] = None |     content: Optional[str] = None | ||||||
|     tool_calls: Optional[List[ChatMessageToolCall]] = None |     tool_calls: Optional[List[ChatMessageToolCall]] = None | ||||||
| 
 | 
 | ||||||
|  |     def model_dump_json(self): | ||||||
|  |         return json.dumps(get_dict_from_nested_dataclasses(self)) | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def from_hf_api(cls, message) -> "ChatMessage": | ||||||
|  |         tool_calls = None | ||||||
|  |         if getattr(message, "tool_calls", None) is not None: | ||||||
|  |             tool_calls = [ | ||||||
|  |                 ChatMessageToolCall.from_hf_api(tool_call) | ||||||
|  |                 for tool_call in message.tool_calls | ||||||
|  |             ] | ||||||
|  |         return cls(role=message.role, content=message.content, tool_calls=tool_calls) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class MessageRole(str, Enum): | class MessageRole(str, Enum): | ||||||
|     USER = "user" |     USER = "user" | ||||||
|  | @ -283,7 +321,7 @@ class HfApiModel(Model): | ||||||
|             ) |             ) | ||||||
|         self.last_input_token_count = response.usage.prompt_tokens |         self.last_input_token_count = response.usage.prompt_tokens | ||||||
|         self.last_output_token_count = response.usage.completion_tokens |         self.last_output_token_count = response.usage.completion_tokens | ||||||
|         return response.choices[0].message |         return ChatMessage.from_hf_api(response.choices[0].message) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TransformersModel(Model): | class TransformersModel(Model): | ||||||
|  | @ -315,14 +353,18 @@ class TransformersModel(Model): | ||||||
|         logger.info(f"Using device: {self.device}") |         logger.info(f"Using device: {self.device}") | ||||||
|         try: |         try: | ||||||
|             self.tokenizer = AutoTokenizer.from_pretrained(model_id) |             self.tokenizer = AutoTokenizer.from_pretrained(model_id) | ||||||
|             self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device) |             self.model = AutoModelForCausalLM.from_pretrained( | ||||||
|  |                 model_id, device_map=self.device | ||||||
|  |             ) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             logger.warning( |             logger.warning( | ||||||
|                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}." |                 f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}." | ||||||
|             ) |             ) | ||||||
|             self.model_id = default_model_id |             self.model_id = default_model_id | ||||||
|             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id) |             self.tokenizer = AutoTokenizer.from_pretrained(default_model_id) | ||||||
|             self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device) |             self.model = AutoModelForCausalLM.from_pretrained( | ||||||
|  |                 model_id, device_map=self.device | ||||||
|  |             ) | ||||||
| 
 | 
 | ||||||
|     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList: |     def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList: | ||||||
|         class StopOnStrings(StoppingCriteria): |         class StopOnStrings(StoppingCriteria): | ||||||
|  | @ -551,4 +593,5 @@ __all__ = [ | ||||||
|     "HfApiModel", |     "HfApiModel", | ||||||
|     "LiteLLMModel", |     "LiteLLMModel", | ||||||
|     "OpenAIServerModel", |     "OpenAIServerModel", | ||||||
|  |     "ChatMessage", | ||||||
| ] | ] | ||||||
|  |  | ||||||
|  | @ -13,9 +13,10 @@ | ||||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | # limitations under the License. | ||||||
| import unittest | import unittest | ||||||
|  | import json | ||||||
| from typing import Optional | from typing import Optional | ||||||
| 
 | 
 | ||||||
| from smolagents import models, tool | from smolagents import models, tool, ChatMessage, HfApiModel | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class ModelTests(unittest.TestCase): | class ModelTests(unittest.TestCase): | ||||||
|  | @ -38,3 +39,13 @@ class ModelTests(unittest.TestCase): | ||||||
|                 "properties" |                 "properties" | ||||||
|             ]["celsius"] |             ]["celsius"] | ||||||
|         ) |         ) | ||||||
|  | 
 | ||||||
|  |     def test_chatmessage_has_model_dumps_json(self): | ||||||
|  |         message = ChatMessage("user", "Hello!") | ||||||
|  |         data = json.loads(message.model_dump_json()) | ||||||
|  |         assert data["content"] == "Hello!" | ||||||
|  | 
 | ||||||
|  |     def test_get_hfapi_message_no_tool(self): | ||||||
|  |         model = HfApiModel() | ||||||
|  |         messages = [{"role": "user", "content": "Hello!"}] | ||||||
|  |         model(messages, stop_sequences=["great"]) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue