Improve GradioUI file upload system
This commit is contained in:
		
							parent
							
								
									1f96560c92
								
							
						
					
					
						commit
						1d846072eb
					
				|  | @ -5,7 +5,7 @@ from smolagents import ( | |||
| ) | ||||
| 
 | ||||
| agent = CodeAgent( | ||||
|     tools=[], model=HfApiModel(), max_steps=4, verbose=True | ||||
|     tools=[], model=HfApiModel(), max_steps=4, verbosity_level=0 | ||||
| ) | ||||
| 
 | ||||
| GradioUI(agent, file_upload_folder='./data').launch() | ||||
|  |  | |||
|  | @ -396,7 +396,7 @@ class MultiStepAgent: | |||
|             } | ||||
|         ] | ||||
|         try: | ||||
|             return self.model(self.input_messages) | ||||
|             return self.model(self.input_messages).content | ||||
|         except Exception as e: | ||||
|             return f"Error in generating final LLM output:\n{e}" | ||||
| 
 | ||||
|  | @ -666,7 +666,9 @@ You have been provided with these additional arguments, that you can access usin | |||
| Now begin!""", | ||||
|             } | ||||
| 
 | ||||
|             answer_facts = self.model([message_prompt_facts, message_prompt_task]) | ||||
|             answer_facts = self.model( | ||||
|                 [message_prompt_facts, message_prompt_task] | ||||
|             ).content | ||||
| 
 | ||||
|             message_system_prompt_plan = { | ||||
|                 "role": MessageRole.SYSTEM, | ||||
|  | @ -688,7 +690,7 @@ Now begin!""", | |||
|             answer_plan = self.model( | ||||
|                 [message_system_prompt_plan, message_user_prompt_plan], | ||||
|                 stop_sequences=["<end_plan>"], | ||||
|             ) | ||||
|             ).content | ||||
| 
 | ||||
|             final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task: | ||||
| ``` | ||||
|  | @ -722,7 +724,7 @@ Now begin!""", | |||
|             } | ||||
|             facts_update = self.model( | ||||
|                 [facts_update_system_prompt] + agent_memory + [facts_update_message] | ||||
|             ) | ||||
|             ).content | ||||
| 
 | ||||
|             # Redact updated plan | ||||
|             plan_update_message = { | ||||
|  | @ -807,17 +809,26 @@ class ToolCallingAgent(MultiStepAgent): | |||
|                 tools_to_call_from=list(self.tools.values()), | ||||
|                 stop_sequences=["Observation:"], | ||||
|             ) | ||||
|              | ||||
| 
 | ||||
|             # Extract tool call from model output | ||||
|             if type(model_message.tool_calls) is list and len(model_message.tool_calls) > 0: | ||||
|             if ( | ||||
|                 type(model_message.tool_calls) is list | ||||
|                 and len(model_message.tool_calls) > 0 | ||||
|             ): | ||||
|                 tool_calls = model_message.tool_calls[0] | ||||
|                 tool_arguments = tool_calls.function.arguments | ||||
|                 tool_name, tool_call_id = tool_calls.function.name, tool_calls.id | ||||
|             else: | ||||
|                 start, end = model_message.content.find('{'), model_message.content.rfind('}') + 1 | ||||
|                 start, end = ( | ||||
|                     model_message.content.find("{"), | ||||
|                     model_message.content.rfind("}") + 1, | ||||
|                 ) | ||||
|                 tool_calls = json.loads(model_message.content[start:end]) | ||||
|                 tool_arguments = tool_calls["tool_arguments"] | ||||
|                 tool_name, tool_call_id = tool_calls["tool_name"], f"call_{len(self.logs)}" | ||||
|                 tool_name, tool_call_id = ( | ||||
|                     tool_calls["tool_name"], | ||||
|                     f"call_{len(self.logs)}", | ||||
|                 ) | ||||
| 
 | ||||
|         except Exception as e: | ||||
|             raise AgentGenerationError( | ||||
|  |  | |||
|  | @ -27,14 +27,15 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True): | |||
|     """Extract ChatMessage objects from agent steps""" | ||||
|     if isinstance(step_log, ActionStep): | ||||
|         yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "") | ||||
|         if step_log.tool_call is not None: | ||||
|             used_code = step_log.tool_call.name == "code interpreter" | ||||
|             content = step_log.tool_call.arguments | ||||
|         if step_log.tool_calls is not None: | ||||
|             first_tool_call = step_log.tool_calls[0] | ||||
|             used_code = first_tool_call.name == "code interpreter" | ||||
|             content = first_tool_call.arguments | ||||
|             if used_code: | ||||
|                 content = f"```py\n{content}\n```" | ||||
|             yield gr.ChatMessage( | ||||
|                 role="assistant", | ||||
|                 metadata={"title": f"🛠️ Used tool {step_log.tool_call.name}"}, | ||||
|                 metadata={"title": f"🛠️ Used tool {first_tool_call.name}"}, | ||||
|                 content=str(content), | ||||
|             ) | ||||
|         if step_log.observations is not None: | ||||
|  | @ -103,6 +104,7 @@ class GradioUI: | |||
|     def upload_file( | ||||
|         self, | ||||
|         file, | ||||
|         file_uploads_log, | ||||
|         allowed_file_types=[ | ||||
|             "application/pdf", | ||||
|             "application/vnd.openxmlformats-officedocument.wordprocessingml.document", | ||||
|  | @ -110,14 +112,12 @@ class GradioUI: | |||
|         ], | ||||
|     ): | ||||
|         """ | ||||
|         Handle file uploads, default allowed types are pdf, docx, and .txt | ||||
|         Handle file uploads, default allowed types are .pdf, .docx, and .txt | ||||
|         """ | ||||
| 
 | ||||
|         # Check if file is uploaded | ||||
|         if file is None: | ||||
|             return "No file uploaded" | ||||
| 
 | ||||
|         # Check if file is in allowed filetypes | ||||
|         try: | ||||
|             mime_type, _ = mimetypes.guess_type(file.name) | ||||
|         except Exception as e: | ||||
|  | @ -148,11 +148,23 @@ class GradioUI: | |||
|         ) | ||||
|         shutil.copy(file.name, file_path) | ||||
| 
 | ||||
|         return f"File uploaded successfully to {self.file_upload_folder}" | ||||
|         return gr.Textbox( | ||||
|             f"File uploaded: {file_path}", visible=True | ||||
|         ), file_uploads_log + [file_path] | ||||
| 
 | ||||
|     def log_user_message(self, text_input, file_uploads_log): | ||||
|         return ( | ||||
|             text_input | ||||
|             + f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}" | ||||
|             if len(file_uploads_log) > 0 | ||||
|             else "", | ||||
|             "", | ||||
|         ) | ||||
| 
 | ||||
|     def launch(self): | ||||
|         with gr.Blocks() as demo: | ||||
|             stored_message = gr.State([]) | ||||
|             stored_messages = gr.State([]) | ||||
|             file_uploads_log = gr.State([]) | ||||
|             chatbot = gr.Chatbot( | ||||
|                 label="Agent", | ||||
|                 type="messages", | ||||
|  | @ -163,14 +175,21 @@ class GradioUI: | |||
|             ) | ||||
|             # If an upload folder is provided, enable the upload feature | ||||
|             if self.file_upload_folder is not None: | ||||
|                 upload_file = gr.File(label="Upload a file") | ||||
|                 upload_status = gr.Textbox(label="Upload Status", interactive=False) | ||||
| 
 | ||||
|                 upload_file.change(self.upload_file, [upload_file], [upload_status]) | ||||
|                 upload_file = gr.File(label="Upload a file", height=1) | ||||
|                 upload_status = gr.Textbox( | ||||
|                     label="Upload Status", interactive=False, visible=False | ||||
|                 ) | ||||
|                 upload_file.change( | ||||
|                     self.upload_file, | ||||
|                     [upload_file, file_uploads_log], | ||||
|                     [upload_status, file_uploads_log], | ||||
|                 ) | ||||
|             text_input = gr.Textbox(lines=1, label="Chat Message") | ||||
|             text_input.submit( | ||||
|                 lambda s: (s, ""), [text_input], [stored_message, text_input] | ||||
|             ).then(self.interact_with_agent, [stored_message, chatbot], [chatbot]) | ||||
|                 self.log_user_message, | ||||
|                 [text_input, file_uploads_log], | ||||
|                 [stored_messages, text_input], | ||||
|             ).then(self.interact_with_agent, [stored_messages, chatbot], [chatbot]) | ||||
| 
 | ||||
|         demo.launch() | ||||
| 
 | ||||
|  |  | |||
|  | @ -36,6 +36,8 @@ from transformers import ( | |||
|     StoppingCriteriaList, | ||||
|     is_torch_available, | ||||
| ) | ||||
| from transformers.utils.import_utils import _is_package_available | ||||
| 
 | ||||
| import openai | ||||
| 
 | ||||
| from .tools import Tool | ||||
|  | @ -52,13 +54,9 @@ DEFAULT_CODEAGENT_REGEX_GRAMMAR = { | |||
|     "value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```<end_code>", | ||||
| } | ||||
| 
 | ||||
| try: | ||||
| if _is_package_available("litellm"): | ||||
|     import litellm | ||||
| 
 | ||||
|     is_litellm_available = True | ||||
| except ImportError: | ||||
|     is_litellm_available = False | ||||
| 
 | ||||
| 
 | ||||
| class MessageRole(str, Enum): | ||||
|     USER = "user" | ||||
|  | @ -159,7 +157,7 @@ class Model: | |||
|         stop_sequences: Optional[List[str]] = None, | ||||
|         grammar: Optional[str] = None, | ||||
|         max_tokens: int = 1500, | ||||
|     ) -> str: | ||||
|     ) -> ChatCompletionOutputMessage: | ||||
|         """Process the input messages and return the model's response. | ||||
| 
 | ||||
|         Parameters: | ||||
|  | @ -174,15 +172,7 @@ class Model: | |||
|         Returns: | ||||
|             `str`: The text content of the model's response. | ||||
|         """ | ||||
|         if not isinstance(messages, List): | ||||
|             raise ValueError( | ||||
|                 "Messages should be a list of dictionaries with 'role' and 'content' keys." | ||||
|             ) | ||||
|         if stop_sequences is None: | ||||
|             stop_sequences = [] | ||||
|         response = self.generate(messages, stop_sequences, grammar, max_tokens) | ||||
| 
 | ||||
|         return remove_stop_sequences(response, stop_sequences) | ||||
|         pass  # To be implemented in child classes! | ||||
| 
 | ||||
| 
 | ||||
| class HfApiModel(Model): | ||||
|  | @ -238,7 +228,7 @@ class HfApiModel(Model): | |||
|         grammar: Optional[str] = None, | ||||
|         max_tokens: int = 1500, | ||||
|         tools_to_call_from: Optional[List[Tool]] = None, | ||||
|     ) -> str: | ||||
|     ) -> ChatCompletionOutputMessage: | ||||
|         """ | ||||
|         Gets an LLM output message for the given list of input messages. | ||||
|         If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call. | ||||
|  | @ -407,7 +397,7 @@ class LiteLLMModel(Model): | |||
|         api_key=None, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         if not is_litellm_available: | ||||
|         if not _is_package_available("litellm"): | ||||
|             raise ImportError( | ||||
|                 "litellm not found. Install it with `pip install litellm`" | ||||
|             ) | ||||
|  | @ -426,7 +416,7 @@ class LiteLLMModel(Model): | |||
|         grammar: Optional[str] = None, | ||||
|         max_tokens: int = 1500, | ||||
|         tools_to_call_from: Optional[List[Tool]] = None, | ||||
|     ) -> str: | ||||
|     ) -> ChatCompletionOutputMessage: | ||||
|         messages = get_clean_message_list( | ||||
|             messages, role_conversions=tool_role_conversions | ||||
|         ) | ||||
|  | @ -497,7 +487,7 @@ class OpenAIServerModel(Model): | |||
|         grammar: Optional[str] = None, | ||||
|         max_tokens: int = 1500, | ||||
|         tools_to_call_from: Optional[List[Tool]] = None, | ||||
|     ) -> str: | ||||
|     ) -> ChatCompletionOutputMessage: | ||||
|         messages = get_clean_message_list( | ||||
|             messages, role_conversions=tool_role_conversions | ||||
|         ) | ||||
|  |  | |||
|  | @ -367,9 +367,10 @@ class AgentTests(unittest.TestCase): | |||
|             model=fake_code_model_no_return,  # use this callable because it never ends | ||||
|             max_steps=5, | ||||
|         ) | ||||
|         agent.run("What is 2 multiplied by 3.6452?") | ||||
|         answer = agent.run("What is 2 multiplied by 3.6452?") | ||||
|         assert len(agent.logs) == 8 | ||||
|         assert type(agent.logs[-1].error) is AgentMaxStepsError | ||||
|         assert isinstance(answer, str) | ||||
| 
 | ||||
|     def test_tool_descriptions_get_baked_in_system_prompt(self): | ||||
|         tool = PythonInterpreterTool() | ||||
|  |  | |||
|  | @ -486,6 +486,7 @@ if char.isalpha(): | |||
|         code = "import numpy.random as rd" | ||||
|         evaluate_python_code(code, authorized_imports=["numpy.random"], state={}) | ||||
|         evaluate_python_code(code, authorized_imports=["numpy"], state={}) | ||||
|         evaluate_python_code(code, authorized_imports=["*"], state={}) | ||||
|         with pytest.raises(InterpreterError): | ||||
|             evaluate_python_code(code, authorized_imports=["random"], state={}) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue