Remove tokenizer in HfApiEngine
This commit is contained in:
		
							parent
							
								
									b11abbf27e
								
							
						
					
					
						commit
						382ee534ab
					
				|  | @ -58,10 +58,15 @@ llama_role_conversions = { | |||
|     MessageRole.TOOL_RESPONSE: MessageRole.USER, | ||||
| } | ||||
| 
 | ||||
| def remove_stop_sequences(content: str, stop_sequences: List[str]) -> str: | ||||
|     for stop_seq in stop_sequences: | ||||
|         if content[-len(stop_seq) :] == stop_seq: | ||||
|             content = content[: -len(stop_seq)] | ||||
|     return content | ||||
| 
 | ||||
| def get_clean_message_list( | ||||
|     message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {} | ||||
| ): | ||||
| ) -> List[Dict[str, str]]: | ||||
|     """ | ||||
|     Subsequent messages with the same role will be concatenated to a single message. | ||||
| 
 | ||||
|  | @ -94,21 +99,9 @@ def get_clean_message_list( | |||
| 
 | ||||
| 
 | ||||
| class HfEngine: | ||||
|     def __init__(self, model_id: Optional[str] = None): | ||||
|     def __init__(self): | ||||
|         self.last_input_token_count = None | ||||
|         self.last_output_token_count = None | ||||
|         if model_id is None: | ||||
|             model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct" | ||||
|             logger.warning(f"Using default model for token counting: '{model_id}'") | ||||
|         try: | ||||
|             self.tokenizer = AutoTokenizer.from_pretrained(model_id) | ||||
|         except Exception as e: | ||||
|             logger.warning( | ||||
|                 f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead." | ||||
|             ) | ||||
|             self.tokenizer = AutoTokenizer.from_pretrained( | ||||
|                 "HuggingFaceTB/SmolLM2-1.7B-Instruct" | ||||
|             ) | ||||
| 
 | ||||
|     def get_token_counts(self): | ||||
|         return { | ||||
|  | @ -134,8 +127,6 @@ class HfEngine: | |||
|     ) -> str: | ||||
|         """Process the input messages and return the model's response. | ||||
| 
 | ||||
|         This method sends a list of messages to the Hugging Face Inference API, optionally with stop sequences and grammar customization. | ||||
| 
 | ||||
|         Parameters: | ||||
|             messages (`List[Dict[str, str]]`): | ||||
|                 A list of message dictionaries to be processed. Each dictionary should have the structure `{"role": "user/system", "content": "message content"}`. | ||||
|  | @ -143,22 +134,10 @@ class HfEngine: | |||
|                 A list of strings that will stop the generation if encountered in the model's output. | ||||
|             grammar (`str`, *optional*): | ||||
|                 The grammar or formatting structure to use in the model's response. | ||||
| 
 | ||||
|             max_tokens (`int`, *optional*): | ||||
|                 The maximum count of tokens to generate. | ||||
|         Returns: | ||||
|             `str`: The text content of the model's response. | ||||
| 
 | ||||
|         Example: | ||||
|             ```python | ||||
|             >>> engine = HfApiEngine( | ||||
|             ...     model="Qwen/Qwen2.5-Coder-32B-Instruct", | ||||
|             ...     token="your_hf_token_here", | ||||
|             ...     max_tokens=2000 | ||||
|             ... ) | ||||
|             >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}] | ||||
|             >>> response = engine(messages, stop_sequences=["END"]) | ||||
|             >>> print(response) | ||||
|             "Quantum mechanics is the branch of physics that studies..." | ||||
|             ``` | ||||
|         """ | ||||
|         if not isinstance(messages, List): | ||||
|             raise ValueError( | ||||
|  | @ -167,16 +146,8 @@ class HfEngine: | |||
|         if stop_sequences is None: | ||||
|             stop_sequences = [] | ||||
|         response = self.generate(messages, stop_sequences, grammar, max_tokens) | ||||
|         self.last_input_token_count = len( | ||||
|             self.tokenizer.apply_chat_template(messages, tokenize=True) | ||||
|         ) | ||||
|         self.last_output_token_count = len(self.tokenizer.encode(response)) | ||||
| 
 | ||||
|         # Remove stop sequences from LLM output | ||||
|         for stop_seq in stop_sequences: | ||||
|             if response[-len(stop_seq) :] == stop_seq: | ||||
|                 response = response[: -len(stop_seq)] | ||||
|         return response | ||||
|         return remove_stop_sequences(response, stop_sequences) | ||||
| 
 | ||||
| 
 | ||||
| class HfApiEngine(HfEngine): | ||||
|  | @ -199,19 +170,32 @@ class HfApiEngine(HfEngine): | |||
|     Raises: | ||||
|         ValueError: | ||||
|             If the model name is not provided. | ||||
| 
 | ||||
|     Example: | ||||
|     ```python | ||||
|     >>> engine = HfApiEngine( | ||||
|     ...     model="Qwen/Qwen2.5-Coder-32B-Instruct", | ||||
|     ...     token="your_hf_token_here", | ||||
|     ...     max_tokens=2000 | ||||
|     ... ) | ||||
|     >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}] | ||||
|     >>> response = engine(messages, stop_sequences=["END"]) | ||||
|     >>> print(response) | ||||
|     "Quantum mechanics is the branch of physics that studies..." | ||||
|     ``` | ||||
|     """ | ||||
| 
 | ||||
|     def __init__( | ||||
|         self, | ||||
|         model: str = "Qwen/Qwen2.5-Coder-32B-Instruct", | ||||
|         model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct", | ||||
|         token: Optional[str] = None, | ||||
|         timeout: Optional[int] = 120, | ||||
|     ): | ||||
|         super().__init__(model_id=model) | ||||
|         self.model = model | ||||
|         super().__init__() | ||||
|         self.model_id = model_id | ||||
|         if token is None: | ||||
|             token = os.getenv("HF_TOKEN") | ||||
|         self.client = InferenceClient(self.model, token=token, timeout=timeout) | ||||
|         self.client = InferenceClient(self.model_id, token=token, timeout=timeout) | ||||
| 
 | ||||
|     def generate( | ||||
|         self, | ||||
|  | @ -239,6 +223,8 @@ class HfApiEngine(HfEngine): | |||
|             ) | ||||
| 
 | ||||
|         response = response.choices[0].message.content | ||||
|         self.last_input_token_count = response.usage.prompt_tokens | ||||
|         self.last_output_token_count = response.usage.completion_tokens | ||||
|         return response | ||||
| 
 | ||||
| 
 | ||||
|  | @ -246,7 +232,19 @@ class TransformersEngine(HfEngine): | |||
|     """This engine uses a pre-initialized local text-generation pipeline.""" | ||||
| 
 | ||||
|     def __init__(self, pipeline: Pipeline, model_id: Optional[str] = None): | ||||
|         super().__init__(model_id) | ||||
|         super().__init__() | ||||
|         if model_id is None: | ||||
|             model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct" | ||||
|             logger.warning(f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'") | ||||
|         try: | ||||
|             self.tokenizer = AutoTokenizer.from_pretrained(model_id) | ||||
|         except Exception as e: | ||||
|             logger.warning( | ||||
|                 f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead." | ||||
|             ) | ||||
|             self.tokenizer = AutoTokenizer.from_pretrained( | ||||
|                 "HuggingFaceTB/SmolLM2-1.7B-Instruct" | ||||
|             ) | ||||
|         self.pipeline = pipeline | ||||
| 
 | ||||
|     def generate( | ||||
|  | @ -275,6 +273,10 @@ class TransformersEngine(HfEngine): | |||
|         ) | ||||
| 
 | ||||
|         response = output[0]["generated_text"][-1]["content"] | ||||
|         self.last_input_token_count = len( | ||||
|             self.tokenizer.apply_chat_template(messages, tokenize=True) | ||||
|         ) | ||||
|         self.last_output_token_count = len(self.tokenizer.encode(response)) | ||||
|         return response | ||||
| 
 | ||||
| 
 | ||||
|  | @ -320,6 +322,8 @@ class OpenAIEngine: | |||
|             temperature=0.5, | ||||
|             max_tokens=max_tokens, | ||||
|         ) | ||||
|         self.last_input_token_count = response.usage.prompt_tokens | ||||
|         self.last_output_token_count = response.usage.completion_tokens | ||||
|         return response.choices[0].message.content | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue