"""Example: serve a CodeAgent through GradioUI with file uploads enabled."""
from smolagents import CodeAgent, GradioUI, HfApiModel

# A tool-less agent backed by HfApiModel(), limited to four steps.
agent = CodeAgent(
    tools=[],
    model=HfApiModel(),
    max_steps=4,
    verbose=True,
)

# Supplying file_upload_folder is what makes GradioUI show the upload widget.
ui = GradioUI(agent, file_upload_folder='./data')
ui.launch()
import gradio as gr +import shutil +import os +import mimetypes +import re from .agents import ActionStep, AgentStep, MultiStepAgent from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types @@ -82,8 +85,12 @@ def stream_to_gradio( class GradioUI: """A one-line interface to launch your agent in Gradio""" - def __init__(self, agent: MultiStepAgent): + def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None=None): self.agent = agent + self.file_upload_folder = file_upload_folder + if self.file_upload_folder is not None: + if not os.path.exists(file_upload_folder): + os.mkdir(file_upload_folder) def interact_with_agent(self, prompt, messages): messages.append(gr.ChatMessage(role="user", content=prompt)) @@ -93,6 +100,45 @@ class GradioUI: yield messages yield messages + def upload_file(self, file, allowed_file_types=["application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "text/plain"]): + """ + Handle file uploads, default allowed types are pdf, docx, and .txt + """ + + # Check if file is uploaded + if file is None: + return "No file uploaded" + + # Check if file is in allowed filetypes + name = os.path.basename(file.name) + try: + mime_type, _ = mimetypes.guess_type(file.name) + except Exception as e: + return f"Error: {e}" + + if mime_type not in allowed_file_types: + return "File type disallowed" + + # Sanitize file name + original_name = os.path.basename(file.name) + sanitized_name = re.sub(r'[^\w\-.]', '_', original_name) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores + + type_to_ext = {} + for ext, t in mimetypes.types_map.items(): + if t not in type_to_ext: + type_to_ext[t] = ext + + # Ensure the extension correlates to the mime type + sanitized_name = sanitized_name.split(".")[:-1] + sanitized_name.append("" + type_to_ext[mime_type]) + sanitized_name = "".join(sanitized_name) + + # Save the uploaded file to the specified folder + file_path = 
os.path.join(self.file_upload_folder, os.path.basename(sanitized_name)) + shutil.copy(file.name, file_path) + + return f"File uploaded successfully to {self.file_upload_folder}" + def launch(self): with gr.Blocks() as demo: stored_message = gr.State([]) @@ -104,6 +150,14 @@ class GradioUI: "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png", ), ) + # If an upload folder is provided, enable the upload feature + if self.file_upload_folder is not None: + upload_file = gr.File(label="Upload a file") + upload_status = gr.Textbox(label="Upload Status", interactive=False) + + upload_file.change( + self.upload_file, [upload_file], [upload_status] + ) text_input = gr.Textbox(lines=1, label="Chat Message") text_input.submit( lambda s: (s, ""), [text_input], [stored_message, text_input]