Fix additional args in stream_to_gradio (#221)
This commit is contained in:
parent 2a69f1574e
commit fdf4fe49ba
@@ -75,12 +75,12 @@ class ToolCall:
     id: str


-class AgentStep:
+class AgentStepLog:
     pass


 @dataclass
-class ActionStep(AgentStep):
+class ActionStep(AgentStepLog):
     agent_memory: List[Dict[str, str]] | None = None
     tool_calls: List[ToolCall] | None = None
     start_time: float | None = None
@@ -94,18 +94,18 @@ class ActionStep(AgentStep):


 @dataclass
-class PlanningStep(AgentStep):
+class PlanningStep(AgentStepLog):
     plan: str
     facts: str


 @dataclass
-class TaskStep(AgentStep):
+class TaskStep(AgentStepLog):
     task: str


 @dataclass
-class SystemPromptStep(AgentStep):
+class SystemPromptStep(AgentStepLog):
     system_prompt: str

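The first file only renames the base class of the step records from AgentStep to AgentStepLog; the concrete subclasses keep their names, so code that dispatches on them keeps working, while annotations against the old base must be updated (as the second file does below). A minimal sketch of such a consumer, assuming the classes are importable from smolagents.agents as in this diff; the summarize_step helper is illustrative only and not part of this commit:

from smolagents.agents import ActionStep, AgentStepLog, PlanningStep, TaskStep


def summarize_step(step_log: AgentStepLog) -> str:
    # Illustrative helper: annotate against the renamed base, dispatch on the subclasses.
    if isinstance(step_log, ActionStep):
        return f"action step with {len(step_log.tool_calls or [])} tool call(s)"
    if isinstance(step_log, PlanningStep):
        return f"planning step: {step_log.plan[:60]}"
    if isinstance(step_log, TaskStep):
        return f"task step: {step_log.task}"
    return type(step_log).__name__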
@@ -19,11 +19,13 @@ import os
 import mimetypes
 import re

-from .agents import ActionStep, AgentStep, MultiStepAgent
+from typing import Optional
+
+from .agents import ActionStep, AgentStepLog, MultiStepAgent
 from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types


-def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
+def pull_messages_from_step(step_log: AgentStepLog, test_mode: bool = True):
     """Extract ChatMessage objects from agent steps"""
     if isinstance(step_log, ActionStep):
         yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
@@ -53,11 +55,13 @@ def stream_to_gradio(
     task: str,
     test_mode: bool = False,
     reset_agent_memory: bool = False,
-    **kwargs,
+    additional_args: Optional[dict] = None,
 ):
     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""

-    for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs):
+    for step_log in agent.run(
+        task, stream=True, reset=reset_agent_memory, additional_args=additional_args
+    ):
         for message in pull_messages_from_step(step_log, test_mode=test_mode):
             yield message

@@ -172,7 +176,7 @@ class GradioUI:
                 type="messages",
                 avatar_images=(
                     None,
-                    "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
+                    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
                 ),
             )
             # If an upload folder is provided, enable the upload feature
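With this fix, extra inputs are passed to stream_to_gradio through the explicit additional_args parameter, which it forwards to agent.run, instead of arbitrary **kwargs. A usage sketch, assuming stream_to_gradio is exported at the package top level and that CodeAgent and HfApiModel are the agent and model classes used in the library's examples at this point; the task string and the "audio" key are illustrative:

from smolagents import CodeAgent, HfApiModel, stream_to_gradio

agent = CodeAgent(tools=[], model=HfApiModel())

# Extra runtime inputs travel in the additional_args dict and reach agent.run unchanged;
# each yielded message is a gradio ChatMessage with .role and .content.
for message in stream_to_gradio(
    agent,
    task="Transcribe the attached audio file and summarize it.",
    additional_args={"audio": "path/to/recording.mp3"},
):
    print(message.role, message.content)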