Improve GradioUI file upload system
parent 1f96560c92
commit 1d846072eb

This commit threads an upload log through the Gradio UI (a gr.State list of uploaded file paths that gets appended to the user's message), switches model calls to return a full chat-message object (call sites now read `.content`), and replaces the try/except litellm import probe with transformers' `_is_package_available`.
@@ -5,7 +5,7 @@ from smolagents import (
 )
 
 agent = CodeAgent(
-    tools=[], model=HfApiModel(), max_steps=4, verbose=True
+    tools=[], model=HfApiModel(), max_steps=4, verbosity_level=0
 )
 
 GradioUI(agent, file_upload_folder='./data').launch()
@@ -396,7 +396,7 @@ class MultiStepAgent:
             }
         ]
         try:
-            return self.model(self.input_messages)
+            return self.model(self.input_messages).content
         except Exception as e:
             return f"Error in generating final LLM output:\n{e}"
 
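The model call sites in the following hunks all change for the same reason: the model callable now returns a chat-message object (`ChatCompletionOutputMessage`) instead of a plain string, so callers unwrap the text with `.content`. A minimal runnable sketch of the new convention, using a stand-in type (names below are illustrative, not the library's):

```python
from dataclasses import dataclass
from typing import Any, List, Optional

@dataclass
class OutputMessage:
    # stand-in for ChatCompletionOutputMessage: text plus optional tool calls
    content: str
    tool_calls: Optional[List[Any]] = None

def fake_model(messages: List[dict]) -> OutputMessage:
    # a real Model subclass would query an LLM here
    return OutputMessage(content="final answer")

# call sites now unwrap the text explicitly, as in the hunk above:
text = fake_model([{"role": "user", "content": "hi"}]).content
print(text)
```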
@@ -666,7 +666,9 @@ You have been provided with these additional arguments, that you can access using
Now begin!""",
             }
 
-            answer_facts = self.model([message_prompt_facts, message_prompt_task])
+            answer_facts = self.model(
+                [message_prompt_facts, message_prompt_task]
+            ).content
 
             message_system_prompt_plan = {
                 "role": MessageRole.SYSTEM,
@@ -688,7 +690,7 @@ Now begin!""",
             answer_plan = self.model(
                 [message_system_prompt_plan, message_user_prompt_plan],
                 stop_sequences=["<end_plan>"],
-            )
+            ).content
 
             final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
 ```
@@ -722,7 +724,7 @@ Now begin!""",
             }
             facts_update = self.model(
                 [facts_update_system_prompt] + agent_memory + [facts_update_message]
-            )
+            ).content
 
             # Redact updated plan
             plan_update_message = {
@@ -807,17 +809,26 @@ class ToolCallingAgent(MultiStepAgent):
                 tools_to_call_from=list(self.tools.values()),
                 stop_sequences=["Observation:"],
             )
-            
+
             # Extract tool call from model output
-            if type(model_message.tool_calls) is list and len(model_message.tool_calls) > 0:
+            if (
+                type(model_message.tool_calls) is list
+                and len(model_message.tool_calls) > 0
+            ):
                 tool_calls = model_message.tool_calls[0]
                 tool_arguments = tool_calls.function.arguments
                 tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
             else:
-                start, end = model_message.content.find('{'), model_message.content.rfind('}') + 1
+                start, end = (
+                    model_message.content.find("{"),
+                    model_message.content.rfind("}") + 1,
+                )
                 tool_calls = json.loads(model_message.content[start:end])
                 tool_arguments = tool_calls["tool_arguments"]
-                tool_name, tool_call_id = tool_calls["tool_name"], f"call_{len(self.logs)}"
+                tool_name, tool_call_id = (
+                    tool_calls["tool_name"],
+                    f"call_{len(self.logs)}",
+                )
 
         except Exception as e:
             raise AgentGenerationError(
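The `else` branch above covers providers that emit the tool call as plain text rather than structured `tool_calls`: it slices out the outermost JSON object and parses it. A self-contained sketch of that extraction (the payload is made up):

```python
import json

raw = 'I will call a tool now.\n{"tool_name": "web_search", "tool_arguments": {"query": "smolagents"}}'

# take everything from the first "{" to the last "}" and parse it as JSON
start, end = raw.find("{"), raw.rfind("}") + 1
call = json.loads(raw[start:end])

print(call["tool_name"])       # web_search
print(call["tool_arguments"])  # {'query': 'smolagents'}
```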
@@ -27,14 +27,15 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
     """Extract ChatMessage objects from agent steps"""
     if isinstance(step_log, ActionStep):
         yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
-        if step_log.tool_call is not None:
-            used_code = step_log.tool_call.name == "code interpreter"
-            content = step_log.tool_call.arguments
+        if step_log.tool_calls is not None:
+            first_tool_call = step_log.tool_calls[0]
+            used_code = first_tool_call.name == "code interpreter"
+            content = first_tool_call.arguments
             if used_code:
                 content = f"```py\n{content}\n```"
             yield gr.ChatMessage(
                 role="assistant",
-                metadata={"title": f"🛠️ Used tool {step_log.tool_call.name}"},
+                metadata={"title": f"🛠️ Used tool {first_tool_call.name}"},
                 content=str(content),
             )
         if step_log.observations is not None:
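The step log's single `tool_call` becomes a `tool_calls` list here, and the UI renders only the first entry. A sketch of that rendering logic with a stand-in record type (names illustrative):

```python
from dataclasses import dataclass

@dataclass
class ToolCall:
    name: str
    arguments: str

tool_calls = [ToolCall("code interpreter", "print(2 * 3.6452)")]
first_tool_call = tool_calls[0]

content = first_tool_call.arguments
if first_tool_call.name == "code interpreter":
    content = f"```py\n{content}\n```"  # show code calls as a fenced block
print(content)
```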
@@ -103,6 +104,7 @@ class GradioUI:
     def upload_file(
         self,
         file,
+        file_uploads_log,
         allowed_file_types=[
             "application/pdf",
             "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
@@ -110,14 +112,12 @@ class GradioUI:
         ],
     ):
         """
-        Handle file uploads, default allowed types are pdf, docx, and .txt
+        Handle file uploads, default allowed types are .pdf, .docx, and .txt
         """
 
-        # Check if file is uploaded
         if file is None:
             return "No file uploaded"
 
-        # Check if file is in allowed filetypes
         try:
             mime_type, _ = mimetypes.guess_type(file.name)
         except Exception as e:
@@ -148,11 +148,23 @@ class GradioUI:
         )
         shutil.copy(file.name, file_path)
 
-        return f"File uploaded successfully to {self.file_upload_folder}"
+        return gr.Textbox(
+            f"File uploaded: {file_path}", visible=True
+        ), file_uploads_log + [file_path]
+
+    def log_user_message(self, text_input, file_uploads_log):
+        return (
+            text_input
+            + f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
+            if len(file_uploads_log) > 0
+            else "",
+            "",
+        )
 
     def launch(self):
         with gr.Blocks() as demo:
-            stored_message = gr.State([])
+            stored_messages = gr.State([])
+            file_uploads_log = gr.State([])
             chatbot = gr.Chatbot(
                 label="Agent",
                 type="messages",
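One thing worth noting in the new `log_user_message`: the conditional expression binds to the whole concatenation, i.e. `(text_input + suffix) if files else ""`, so with an empty upload log the user's text is replaced by an empty string. A sketch of the grouping that was presumably intended (an assumption on my part, not what the commit does):

```python
def log_user_message(text_input, file_uploads_log):
    # build the file hint separately so an empty log leaves the text intact
    suffix = (
        f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
        if len(file_uploads_log) > 0
        else ""
    )
    return text_input + suffix, ""  # second value clears the textbox
```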
@@ -163,14 +175,21 @@ class GradioUI:
             )
             # If an upload folder is provided, enable the upload feature
             if self.file_upload_folder is not None:
-                upload_file = gr.File(label="Upload a file")
-                upload_status = gr.Textbox(label="Upload Status", interactive=False)
-
-                upload_file.change(self.upload_file, [upload_file], [upload_status])
+                upload_file = gr.File(label="Upload a file", height=1)
+                upload_status = gr.Textbox(
+                    label="Upload Status", interactive=False, visible=False
+                )
+                upload_file.change(
+                    self.upload_file,
+                    [upload_file, file_uploads_log],
+                    [upload_status, file_uploads_log],
+                )
             text_input = gr.Textbox(lines=1, label="Chat Message")
-            text_input.submit(
-                lambda s: (s, ""), [text_input], [stored_message, text_input]
-            ).then(self.interact_with_agent, [stored_message, chatbot], [chatbot])
+            text_input.submit(
+                self.log_user_message,
+                [text_input, file_uploads_log],
+                [stored_messages, text_input],
+            ).then(self.interact_with_agent, [stored_messages, chatbot], [chatbot])
 
         demo.launch()
 
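The wiring pattern in `launch` — a `gr.State` list threaded through event handlers, with `.submit(...).then(...)` chaining — can be tried standalone. A minimal runnable sketch (handler names and labels are mine; it assumes a Gradio version where `gr.File` hands the handler a file object with a `.name` path, as the code above does):

```python
import gradio as gr

def record_upload(file, uploads_log):
    if file is None:
        return "No file uploaded", uploads_log
    # return the status text plus the grown session log
    return f"File uploaded: {file.name}", uploads_log + [file.name]

def log_message(text, uploads_log):
    hint = f"\nFiles so far: {uploads_log}" if uploads_log else ""
    return text + hint, ""  # store the message, clear the textbox

with gr.Blocks() as demo:
    uploads_log = gr.State([])  # per-session list, persists across events
    upload = gr.File(label="Upload a file")
    status = gr.Textbox(label="Upload Status", interactive=False)
    stored = gr.Textbox(label="Stored message", interactive=False)
    msg = gr.Textbox(lines=1, label="Chat Message")

    upload.change(record_upload, [upload, uploads_log], [status, uploads_log])
    # a real UI would chain the agent call here with .then(...)
    msg.submit(log_message, [msg, uploads_log], [stored, msg])

demo.launch()
```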
@@ -36,6 +36,8 @@ from transformers import (
     StoppingCriteriaList,
     is_torch_available,
 )
+from transformers.utils.import_utils import _is_package_available
+
 import openai
 
 from .tools import Tool
@@ -52,13 +54,9 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = {
     "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>",
 }
 
-try:
-    import litellm
-
-    is_litellm_available = True
-except ImportError:
-    is_litellm_available = False
-
+if _is_package_available("litellm"):
+    import litellm
 
 
 class MessageRole(str, Enum):
     USER = "user"
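`_is_package_available` is a private transformers helper. If relying on a private API is a concern, the same probe can be written against the standard library; a sketch:

```python
import importlib.util

def is_package_available(name: str) -> bool:
    # True when the import machinery can locate the package, without
    # importing it and without a try/except ImportError at module load
    return importlib.util.find_spec(name) is not None

if is_package_available("litellm"):
    import litellm  # noqa: F401
```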
@@ -159,7 +157,7 @@ class Model:
         stop_sequences: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
-    ) -> str:
+    ) -> ChatCompletionOutputMessage:
         """Process the input messages and return the model's response.
 
         Parameters:
@@ -174,15 +172,7 @@ class Model:
         Returns:
             `str`: The text content of the model's response.
         """
-        if not isinstance(messages, List):
-            raise ValueError(
-                "Messages should be a list of dictionaries with 'role' and 'content' keys."
-            )
-        if stop_sequences is None:
-            stop_sequences = []
-        response = self.generate(messages, stop_sequences, grammar, max_tokens)
-
-        return remove_stop_sequences(response, stop_sequences)
+        pass  # To be implemented in child classes!
 
 
 class HfApiModel(Model):
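With the shared validation and stop-sequence handling deleted, the base `Model.__call__` is a stub: called on the base class it now silently returns `None`. A sketch of a stricter variant that fails loudly instead (a common alternative, not what this commit does):

```python
from typing import List, Optional

class Model:
    def __call__(
        self,
        messages: List[dict],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        max_tokens: int = 1500,
    ):
        # subclasses (HfApiModel, LiteLLMModel, OpenAIServerModel) override this
        raise NotImplementedError("Model subclasses must implement __call__")
```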
@@ -238,7 +228,7 @@ class HfApiModel(Model):
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
-    ) -> str:
+    ) -> ChatCompletionOutputMessage:
         """
         Gets an LLM output message for the given list of input messages.
         If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
@@ -407,7 +397,7 @@ class LiteLLMModel(Model):
         api_key=None,
         **kwargs,
     ):
-        if not is_litellm_available:
+        if not _is_package_available("litellm"):
             raise ImportError(
                 "litellm not found. Install it with `pip install litellm`"
             )
@@ -426,7 +416,7 @@ class LiteLLMModel(Model):
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
-    ) -> str:
+    ) -> ChatCompletionOutputMessage:
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
@@ -497,7 +487,7 @@ class OpenAIServerModel(Model):
         grammar: Optional[str] = None,
         max_tokens: int = 1500,
         tools_to_call_from: Optional[List[Tool]] = None,
-    ) -> str:
+    ) -> ChatCompletionOutputMessage:
         messages = get_clean_message_list(
             messages, role_conversions=tool_role_conversions
         )
 
@@ -367,9 +367,10 @@ class AgentTests(unittest.TestCase):
             model=fake_code_model_no_return,  # use this callable because it never ends
             max_steps=5,
         )
-        agent.run("What is 2 multiplied by 3.6452?")
+        answer = agent.run("What is 2 multiplied by 3.6452?")
         assert len(agent.logs) == 8
         assert type(agent.logs[-1].error) is AgentMaxStepsError
+        assert isinstance(answer, str)
 
     def test_tool_descriptions_get_baked_in_system_prompt(self):
         tool = PythonInterpreterTool()
 
@@ -486,6 +486,7 @@ if char.isalpha():
         code = "import numpy.random as rd"
         evaluate_python_code(code, authorized_imports=["numpy.random"], state={})
         evaluate_python_code(code, authorized_imports=["numpy"], state={})
+        evaluate_python_code(code, authorized_imports=["*"], state={})
         with pytest.raises(InterpreterError):
             evaluate_python_code(code, authorized_imports=["random"], state={})
 
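The new test line exercises a `"*"` wildcard in `authorized_imports`. A hypothetical helper showing the kind of check the interpreter presumably performs (not the library's actual implementation):

```python
def import_allowed(module: str, authorized: list) -> bool:
    if "*" in authorized:
        return True  # wildcard authorizes every import
    parts = module.split(".")
    # "numpy" authorizes "numpy.random"; "random" does not
    prefixes = {".".join(parts[: i + 1]) for i in range(len(parts))}
    return any(a in prefixes for a in authorized)

assert import_allowed("numpy.random", ["numpy.random"])
assert import_allowed("numpy.random", ["numpy"])
assert import_allowed("numpy.random", ["*"])
assert not import_allowed("numpy.random", ["random"])
```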