Remove tokenizer in HfApiEngine

parent b11abbf27e
commit 382ee534ab
@@ -58,10 +58,15 @@ llama_role_conversions = {
     MessageRole.TOOL_RESPONSE: MessageRole.USER,
 }
 
+def remove_stop_sequences(content: str, stop_sequences: List[str]) -> str:
+    for stop_seq in stop_sequences:
+        if content[-len(stop_seq) :] == stop_seq:
+            content = content[: -len(stop_seq)]
+    return content
 
 def get_clean_message_list(
     message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}
-):
+) -> List[Dict[str, str]]:
     """
     Subsequent messages with the same role will be concatenated to a single message.
 
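Note on the new `remove_stop_sequences` helper: it strips at most one trailing occurrence of each stop sequence, checked in list order. A quick standalone sketch of that behavior (the example strings are invented, not from the commit):

```python
from typing import List


def remove_stop_sequences(content: str, stop_sequences: List[str]) -> str:
    # Copied from the hunk above: trim each stop sequence once if it ends the output.
    for stop_seq in stop_sequences:
        if content[-len(stop_seq) :] == stop_seq:
            content = content[: -len(stop_seq)]
    return content


assert remove_stop_sequences("answer END", ["END"]) == "answer "
assert remove_stop_sequences("answer", ["END"]) == "answer"  # no-op when absent
assert remove_stop_sequences("a<end><end>", ["<end>"]) == "a<end>"  # one pass per sequence
```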
@@ -94,21 +99,9 @@ def get_clean_message_list(
 
 
 class HfEngine:
-    def __init__(self, model_id: Optional[str] = None):
+    def __init__(self):
         self.last_input_token_count = None
         self.last_output_token_count = None
-        if model_id is None:
-            model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
-            logger.warning(f"Using default model for token counting: '{model_id}'")
-        try:
-            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-        except Exception as e:
-            logger.warning(
-                f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead."
-            )
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                "HuggingFaceTB/SmolLM2-1.7B-Instruct"
-            )
 
     def get_token_counts(self):
         return {
@@ -134,8 +127,6 @@ class HfEngine:
     ) -> str:
         """Process the input messages and return the model's response.
 
-        This method sends a list of messages to the Hugging Face Inference API, optionally with stop sequences and grammar customization.
-
         Parameters:
             messages (`List[Dict[str, str]]`):
                 A list of message dictionaries to be processed. Each dictionary should have the structure `{"role": "user/system", "content": "message content"}`.
@@ -143,22 +134,10 @@ class HfEngine:
                 A list of strings that will stop the generation if encountered in the model's output.
             grammar (`str`, *optional*):
                 The grammar or formatting structure to use in the model's response.
-
+            max_tokens (`int`, *optional*):
+                The maximum number of tokens to generate.
         Returns:
             `str`: The text content of the model's response.
-
-        Example:
-            ```python
-            >>> engine = HfApiEngine(
-            ...     model="Qwen/Qwen2.5-Coder-32B-Instruct",
-            ...     token="your_hf_token_here",
-            ...     max_tokens=2000
-            ... )
-            >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
-            >>> response = engine(messages, stop_sequences=["END"])
-            >>> print(response)
-            "Quantum mechanics is the branch of physics that studies..."
-            ```
         """
         if not isinstance(messages, List):
             raise ValueError(
@@ -167,16 +146,8 @@ class HfEngine:
         if stop_sequences is None:
             stop_sequences = []
         response = self.generate(messages, stop_sequences, grammar, max_tokens)
-        self.last_input_token_count = len(
-            self.tokenizer.apply_chat_template(messages, tokenize=True)
-        )
-        self.last_output_token_count = len(self.tokenizer.encode(response))
 
-        # Remove stop sequences from LLM output
-        for stop_seq in stop_sequences:
-            if response[-len(stop_seq) :] == stop_seq:
-                response = response[: -len(stop_seq)]
-        return response
+        return remove_stop_sequences(response, stop_sequences)
 
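With counting gone from `HfEngine.__call__`, each subclass's `generate` is now expected to fill in `last_input_token_count` and `last_output_token_count` itself. A minimal sketch of that contract (`DummyEngine`, its whitespace "token" counting, and the `max_tokens` default are invented for illustration):

```python
class DummyEngine(HfEngine):
    def generate(self, messages, stop_sequences=None, grammar=None, max_tokens=1500):
        # A real engine would call an API or a pipeline here; the point is that
        # the subclass, not HfEngine.__call__, sets the counters.
        self.last_input_token_count = sum(len(m["content"].split()) for m in messages)
        self.last_output_token_count = 2
        return "ok END"


engine = DummyEngine()
print(engine([{"role": "user", "content": "hi"}], stop_sequences=["END"]))  # "ok "
print(engine.get_token_counts())
```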
@@ -199,19 +170,32 @@ class HfApiEngine(HfEngine):
     Raises:
         ValueError:
             If the model name is not provided.
+
+    Example:
+    ```python
+    >>> engine = HfApiEngine(
+    ...     model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
+    ...     token="your_hf_token_here",
+    ...     timeout=120
+    ... )
+    >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
+    >>> response = engine(messages, stop_sequences=["END"])
+    >>> print(response)
+    "Quantum mechanics is the branch of physics that studies..."
+    ```
     """
 
     def __init__(
         self,
-        model: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
+        model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
         token: Optional[str] = None,
         timeout: Optional[int] = 120,
     ):
-        super().__init__(model_id=model)
-        self.model = model
+        super().__init__()
+        self.model_id = model_id
         if token is None:
             token = os.getenv("HF_TOKEN")
-        self.client = InferenceClient(self.model, token=token, timeout=timeout)
+        self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
 
     def generate(
         self,
@@ -239,6 +223,8 @@ class HfApiEngine(HfEngine):
             )
 
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
         response = response.choices[0].message.content
         return response
 
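For the API-backed engine the counts now come from the server-reported `usage` field on the chat completion, so no local tokenizer is needed; note the counters must be read before `response` is rebound to the message string, as ordered above. A standalone sketch using `huggingface_hub.InferenceClient` (model choice mirrors the default above):

```python
from huggingface_hub import InferenceClient

client = InferenceClient("Qwen/Qwen2.5-Coder-32B-Instruct")
completion = client.chat_completion(
    messages=[{"role": "user", "content": "Say hello."}],
    max_tokens=16,
)
# Token counts are reported by the server; no AutoTokenizer involved.
print(completion.usage.prompt_tokens, completion.usage.completion_tokens)
print(completion.choices[0].message.content)
```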
@@ -246,7 +232,19 @@ class TransformersEngine(HfEngine):
     """This engine uses a pre-initialized local text-generation pipeline."""
 
     def __init__(self, pipeline: Pipeline, model_id: Optional[str] = None):
-        super().__init__(model_id)
+        super().__init__()
+        if model_id is None:
+            model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
+            logger.warning(f"`model_id` not provided, using this default tokenizer for token counts: '{model_id}'")
+        try:
+            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+        except Exception as e:
+            logger.warning(
+                f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead."
+            )
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                "HuggingFaceTB/SmolLM2-1.7B-Instruct"
+            )
         self.pipeline = pipeline
 
     def generate(
@@ -275,6 +273,10 @@ class TransformersEngine(HfEngine):
         )
 
         response = output[0]["generated_text"][-1]["content"]
+        self.last_input_token_count = len(
+            self.tokenizer.apply_chat_template(messages, tokenize=True)
+        )
+        self.last_output_token_count = len(self.tokenizer.encode(response))
         return response
 
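A local `transformers` pipeline returns no usage metadata, which is why the tokenizer moves into `TransformersEngine` and the counts are recomputed client-side. A sketch of just the counting step (same default model as above; the response string is a stand-in):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct")
messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]

# apply_chat_template(tokenize=True) returns the token ids of the rendered
# prompt, so its length approximates what the pipeline consumed.
input_count = len(tokenizer.apply_chat_template(messages, tokenize=True))
output_count = len(tokenizer.encode("Quantum mechanics is..."))
print(input_count, output_count)
```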
@@ -320,6 +322,8 @@ class OpenAIEngine:
             temperature=0.5,
             max_tokens=max_tokens,
         )
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
         return response.choices[0].message.content
 
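The `OpenAIEngine` change mirrors the Inference API one: the SDK response carries a `usage` object, read before the content is extracted. A sketch with the v1 `openai` client (model name is a placeholder; assumes `OPENAI_API_KEY` is set):

```python
from openai import OpenAI

client = OpenAI()
response = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model
    messages=[{"role": "user", "content": "Say hello."}],
    max_tokens=16,
)
print(response.usage.prompt_tokens, response.usage.completion_tokens)
print(response.choices[0].message.content)
```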