Fix additional args in stream_to_gradio (#221)
This commit is contained in:
		
							parent
							
								
									2a69f1574e
								
							
						
					
					
						commit
						fdf4fe49ba
					
				|  | @ -75,12 +75,12 @@ class ToolCall: | ||||||
|     id: str |     id: str | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class AgentStep: | class AgentStepLog: | ||||||
|     pass |     pass | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @dataclass | @dataclass | ||||||
| class ActionStep(AgentStep): | class ActionStep(AgentStepLog): | ||||||
|     agent_memory: List[Dict[str, str]] | None = None |     agent_memory: List[Dict[str, str]] | None = None | ||||||
|     tool_calls: List[ToolCall] | None = None |     tool_calls: List[ToolCall] | None = None | ||||||
|     start_time: float | None = None |     start_time: float | None = None | ||||||
|  | @ -94,18 +94,18 @@ class ActionStep(AgentStep): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @dataclass | @dataclass | ||||||
| class PlanningStep(AgentStep): | class PlanningStep(AgentStepLog): | ||||||
|     plan: str |     plan: str | ||||||
|     facts: str |     facts: str | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @dataclass | @dataclass | ||||||
| class TaskStep(AgentStep): | class TaskStep(AgentStepLog): | ||||||
|     task: str |     task: str | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @dataclass | @dataclass | ||||||
| class SystemPromptStep(AgentStep): | class SystemPromptStep(AgentStepLog): | ||||||
|     system_prompt: str |     system_prompt: str | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -19,11 +19,13 @@ import os | ||||||
| import mimetypes | import mimetypes | ||||||
| import re | import re | ||||||
| 
 | 
 | ||||||
| from .agents import ActionStep, AgentStep, MultiStepAgent | from typing import Optional | ||||||
|  | 
 | ||||||
|  | from .agents import ActionStep, AgentStepLog, MultiStepAgent | ||||||
| from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types | from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True): | def pull_messages_from_step(step_log: AgentStepLog, test_mode: bool = True): | ||||||
|     """Extract ChatMessage objects from agent steps""" |     """Extract ChatMessage objects from agent steps""" | ||||||
|     if isinstance(step_log, ActionStep): |     if isinstance(step_log, ActionStep): | ||||||
|         yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "") |         yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "") | ||||||
|  | @ -53,11 +55,13 @@ def stream_to_gradio( | ||||||
|     task: str, |     task: str, | ||||||
|     test_mode: bool = False, |     test_mode: bool = False, | ||||||
|     reset_agent_memory: bool = False, |     reset_agent_memory: bool = False, | ||||||
|     **kwargs, |     additional_args: Optional[dict] = None, | ||||||
| ): | ): | ||||||
|     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.""" |     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.""" | ||||||
| 
 | 
 | ||||||
|     for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs): |     for step_log in agent.run( | ||||||
|  |         task, stream=True, reset=reset_agent_memory, additional_args=additional_args | ||||||
|  |     ): | ||||||
|         for message in pull_messages_from_step(step_log, test_mode=test_mode): |         for message in pull_messages_from_step(step_log, test_mode=test_mode): | ||||||
|             yield message |             yield message | ||||||
| 
 | 
 | ||||||
|  | @ -172,7 +176,7 @@ class GradioUI: | ||||||
|                 type="messages", |                 type="messages", | ||||||
|                 avatar_images=( |                 avatar_images=( | ||||||
|                     None, |                     None, | ||||||
|                     "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png", |                     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png", | ||||||
|                 ), |                 ), | ||||||
|             ) |             ) | ||||||
|             # If an upload folder is provided, enable the upload feature |             # If an upload folder is provided, enable the upload feature | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue