Remove tokenizer in HfApiEngine

parent b11abbf27e
commit 382ee534ab
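In short: the `HfEngine` base class no longer owns an `AutoTokenizer` just for token counting. `HfApiEngine` now reads `prompt_tokens`/`completion_tokens` from the API's usage report, while `TransformersEngine` loads its own tokenizer locally. A minimal usage sketch of the post-commit behavior (illustrative only; model name and token handling are examples, and the module's `HfApiEngine` is assumed to be in scope):

```python
import os

# Illustrative sketch; mirrors the post-commit interface of this module.
engine = HfApiEngine(model_id="Qwen/Qwen2.5-Coder-32B-Instruct", token=os.getenv("HF_TOKEN"))
messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
response = engine(messages, stop_sequences=["END"])

# Counts now come from the API's usage report, not from a local tokenizer.
print(engine.last_input_token_count, engine.last_output_token_count)
```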
@@ -58,10 +58,15 @@ llama_role_conversions = {
     MessageRole.TOOL_RESPONSE: MessageRole.USER,
 }
 
 
+def remove_stop_sequences(content: str, stop_sequences: List[str]) -> str:
+    for stop_seq in stop_sequences:
+        if content[-len(stop_seq) :] == stop_seq:
+            content = content[: -len(stop_seq)]
+    return content
+
+
 def get_clean_message_list(
     message_list: List[Dict[str, str]], role_conversions: Dict[str, str] = {}
-):
+) -> List[Dict[str, str]]:
     """
     Subsequent messages with the same role will be concatenated to a single message.
 
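The new `remove_stop_sequences` helper strips at most one trailing stop string per candidate. A standalone sanity check (the helper is duplicated here so the snippet runs on its own; the stop strings are illustrative):

```python
from typing import List


def remove_stop_sequences(content: str, stop_sequences: List[str]) -> str:
    # Same logic as the helper added above: drop a stop string only when
    # the output actually ends with it.
    for stop_seq in stop_sequences:
        if content[-len(stop_seq) :] == stop_seq:
            content = content[: -len(stop_seq)]
    return content


assert remove_stop_sequences("final answer<end_action>", ["<end_action>", "Observation:"]) == "final answer"
assert remove_stop_sequences("no stop here", ["<end_action>"]) == "no stop here"
```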
@@ -94,21 +99,9 @@ def get_clean_message_list(
 
 
 class HfEngine:
-    def __init__(self, model_id: Optional[str] = None):
+    def __init__(self):
         self.last_input_token_count = None
         self.last_output_token_count = None
-        if model_id is None:
-            model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
-            logger.warning(f"Using default model for token counting: '{model_id}'")
-        try:
-            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-        except Exception as e:
-            logger.warning(
-                f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead."
-            )
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                "HuggingFaceTB/SmolLM2-1.7B-Instruct"
-            )
 
     def get_token_counts(self):
         return {
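With the tokenizer gone from the base class, each subclass's `generate` is now responsible for setting the two count attributes itself. A hypothetical minimal subclass showing that contract (the class name and its trivial behavior are invented for illustration; the module's `HfEngine` is assumed to be in scope):

```python
from typing import Dict, List, Optional


class UppercaseEngine(HfEngine):
    # Hypothetical engine: echoes the last message in uppercase.
    def generate(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        max_tokens: int = 1500,
    ) -> str:
        response = messages[-1]["content"].upper()
        # No tokenizer here: a subclass reports whatever counts it can.
        self.last_input_token_count = sum(len(m["content"].split()) for m in messages)
        self.last_output_token_count = len(response.split())
        return response
```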
@@ -134,8 +127,6 @@ class HfEngine:
     ) -> str:
         """Process the input messages and return the model's response.
 
         This method sends a list of messages to the Hugging Face Inference API, optionally with stop sequences and grammar customization.
 
         Parameters:
             messages (`List[Dict[str, str]]`):
                 A list of message dictionaries to be processed. Each dictionary should have the structure `{"role": "user/system", "content": "message content"}`.
 
@@ -143,22 +134,10 @@ class HfEngine:
                 A list of strings that will stop the generation if encountered in the model's output.
             grammar (`str`, *optional*):
                 The grammar or formatting structure to use in the model's response.
 
             max_tokens (`int`, *optional*):
                 The maximum count of tokens to generate.
         Returns:
             `str`: The text content of the model's response.
 
-        Example:
-        ```python
-        >>> engine = HfApiEngine(
-        ...     model="Qwen/Qwen2.5-Coder-32B-Instruct",
-        ...     token="your_hf_token_here",
-        ...     max_tokens=2000
-        ... )
-        >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
-        >>> response = engine(messages, stop_sequences=["END"])
-        >>> print(response)
-        "Quantum mechanics is the branch of physics that studies..."
-        ```
         """
         if not isinstance(messages, List):
             raise ValueError(
@@ -167,16 +146,8 @@ class HfEngine:
         if stop_sequences is None:
             stop_sequences = []
         response = self.generate(messages, stop_sequences, grammar, max_tokens)
-        self.last_input_token_count = len(
-            self.tokenizer.apply_chat_template(messages, tokenize=True)
-        )
-        self.last_output_token_count = len(self.tokenizer.encode(response))
 
-        # Remove stop sequences from LLM output
-        for stop_seq in stop_sequences:
-            if response[-len(stop_seq) :] == stop_seq:
-                response = response[: -len(stop_seq)]
-        return response
+        return remove_stop_sequences(response, stop_sequences)
 
 
 class HfApiEngine(HfEngine):
@@ -199,19 +170,32 @@ class HfApiEngine(HfEngine):
     Raises:
         ValueError:
             If the model name is not provided.
+
+    Example:
+    ```python
+    >>> engine = HfApiEngine(
+    ...     model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
+    ...     token="your_hf_token_here",
+    ...     max_tokens=2000
+    ... )
+    >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
+    >>> response = engine(messages, stop_sequences=["END"])
+    >>> print(response)
+    "Quantum mechanics is the branch of physics that studies..."
+    ```
     """
 
     def __init__(
         self,
-        model: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
+        model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
         token: Optional[str] = None,
         timeout: Optional[int] = 120,
     ):
-        super().__init__(model_id=model)
-        self.model = model
+        super().__init__()
+        self.model_id = model_id
         if token is None:
             token = os.getenv("HF_TOKEN")
-        self.client = InferenceClient(self.model, token=token, timeout=timeout)
+        self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
 
     def generate(
         self,
@@ -239,6 +223,8 @@ class HfApiEngine(HfEngine):
         )
 
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
         response = response.choices[0].message.content
         return response
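Note the ordering: `usage` must be read before `response` is rebound to the message text. The counts for the API engine now come straight from `InferenceClient.chat_completion`'s usage report; a standalone sketch of that pattern (model name is illustrative; assumes a valid `HF_TOKEN` in the environment):

```python
import os

from huggingface_hub import InferenceClient

client = InferenceClient("Qwen/Qwen2.5-Coder-32B-Instruct", token=os.getenv("HF_TOKEN"))
response = client.chat_completion(
    [{"role": "user", "content": "Say hello."}], max_tokens=16
)

# The server reports both sides of the exchange, so no local tokenizer is needed.
print(response.usage.prompt_tokens, response.usage.completion_tokens)
print(response.choices[0].message.content)
```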
@@ -246,7 +232,19 @@ class TransformersEngine(HfEngine):
     """This engine uses a pre-initialized local text-generation pipeline."""
 
     def __init__(self, pipeline: Pipeline, model_id: Optional[str] = None):
-        super().__init__(model_id)
+        super().__init__()
+        if model_id is None:
+            model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
+            logger.warning(f"`model_id` not provided, using this default tokenizer for token counts: '{model_id}'")
+        try:
+            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+        except Exception as e:
+            logger.warning(
+                f"Failed to load tokenizer for model {model_id}: {e}. Loading default tokenizer instead."
+            )
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                "HuggingFaceTB/SmolLM2-1.7B-Instruct"
+            )
         self.pipeline = pipeline
 
     def generate(
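For the local engine the tradeoff goes the other way: there is no API usage report, so the tokenizer-loading logic moves here. A construction sketch (the model choice is illustrative; any chat-tuned checkpoint works, and the module's `TransformersEngine` is assumed to be in scope):

```python
from transformers import pipeline

# The pipeline generates; the engine's own tokenizer (loaded from model_id
# in __init__) handles token accounting.
pipe = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-1.7B-Instruct")
engine = TransformersEngine(pipe, model_id="HuggingFaceTB/SmolLM2-1.7B-Instruct")
```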
@@ -275,6 +273,10 @@ class TransformersEngine(HfEngine):
         )
 
         response = output[0]["generated_text"][-1]["content"]
+        self.last_input_token_count = len(
+            self.tokenizer.apply_chat_template(messages, tokenize=True)
+        )
+        self.last_output_token_count = len(self.tokenizer.encode(response))
         return response
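The input count uses `apply_chat_template` with `tokenize=True`, which returns the token ids of the fully templated conversation, so chat-template and special tokens are included in the count. A standalone illustration (model and strings are illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-1.7B-Instruct")
messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]

# Count the prompt as the model actually sees it: template applied, then tokenized.
input_tokens = len(tokenizer.apply_chat_template(messages, tokenize=True))
output_tokens = len(tokenizer.encode("Quantum mechanics is..."))
print(input_tokens, output_tokens)
```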
@@ -320,6 +322,8 @@ class OpenAIEngine:
             temperature=0.5,
             max_tokens=max_tokens,
         )
+        self.last_input_token_count = response.usage.prompt_tokens
+        self.last_output_token_count = response.usage.completion_tokens
         return response.choices[0].message.content
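`OpenAIEngine` gets the same treatment: counts are taken from the response's usage object before the message content is returned. The equivalent raw-client pattern, for reference (model name is illustrative; assumes `OPENAI_API_KEY` is set):

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hello."}],
    max_tokens=16,
)

# Same usage-based accounting as HfApiEngine after this commit.
print(response.usage.prompt_tokens, response.usage.completion_tokens)
print(response.choices[0].message.content)
```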