Gradio chatbot: step duration, number, token count, support nested thoughts (#384)
* Add enhanced Gradio UI with nested agents calls, execution-logs, errors - Add virtual separation between steps - Highlight final answer as required in internal discussion - Show step numbers and token counts - Include step duration tracking - Improve message display structure --------- Co-authored-by: Yuvraj Sharma <48665385+yvrjsharma@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									3b5c99e87a
								
							
						
					
					
						commit
						49c34f625c
					
				|  | @ -41,7 +41,7 @@ e2b = [ | |||
|   "python-dotenv>=1.0.1", | ||||
| ] | ||||
| gradio = [ | ||||
|   "gradio>=5.8.0", | ||||
|   "gradio>=5.13.0", | ||||
| ] | ||||
| litellm = [ | ||||
|   "litellm>=1.55.10", | ||||
|  |  | |||
|  | @ -24,31 +24,102 @@ from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types | |||
| from .utils import _is_package_available | ||||
| 
 | ||||
| 
 | ||||
| def pull_messages_from_step(step_log: AgentStepLog): | ||||
|     """Extract ChatMessage objects from agent steps""" | ||||
| def pull_messages_from_step( | ||||
|     step_log: AgentStepLog, | ||||
| ): | ||||
|     """Extract ChatMessage objects from agent steps with proper nesting""" | ||||
|     import gradio as gr | ||||
| 
 | ||||
|     if isinstance(step_log, ActionStep): | ||||
|         yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "") | ||||
|         if step_log.tool_calls is not None: | ||||
|         # Output the step number | ||||
|         step_number = f"Step {step_log.step_number}" if step_log.step_number is not None else "" | ||||
|         yield gr.ChatMessage(role="assistant", content=f"**{step_number}**") | ||||
| 
 | ||||
|         # First yield the thought/reasoning from the LLM | ||||
|         if hasattr(step_log, "llm_output") and step_log.llm_output is not None: | ||||
|             # Clean up the LLM output | ||||
|             llm_output = step_log.llm_output.strip() | ||||
|             # Remove any trailing <end_code> and extra backticks, handling multiple possible formats | ||||
|             llm_output = re.sub(r"```\s*<end_code>", "```", llm_output)  # handles ```<end_code> | ||||
|             llm_output = re.sub(r"<end_code>\s*```", "```", llm_output)  # handles <end_code>``` | ||||
|             llm_output = re.sub(r"```\s*\n\s*<end_code>", "```", llm_output)  # handles ```\n<end_code> | ||||
|             llm_output = llm_output.strip() | ||||
|             yield gr.ChatMessage(role="assistant", content=llm_output) | ||||
| 
 | ||||
|         # For tool calls, create a parent message | ||||
|         if hasattr(step_log, "tool_calls") and step_log.tool_calls is not None: | ||||
|             first_tool_call = step_log.tool_calls[0] | ||||
|             used_code = first_tool_call.name == "code interpreter" | ||||
|             content = first_tool_call.arguments | ||||
|             used_code = first_tool_call.name == "python_interpreter" | ||||
|             parent_id = f"call_{len(step_log.tool_calls)}" | ||||
| 
 | ||||
|             # Tool call becomes the parent message with timing info | ||||
|             # First we will handle arguments based on type | ||||
|             args = first_tool_call.arguments | ||||
|             if isinstance(args, dict): | ||||
|                 content = str(args.get("answer", str(args))) | ||||
|             else: | ||||
|                 content = str(args).strip() | ||||
| 
 | ||||
|             if used_code: | ||||
|                 content = f"```py\n{content}\n```" | ||||
|             yield gr.ChatMessage( | ||||
|                 # Clean up the content by removing any end code tags | ||||
|                 content = re.sub(r"```.*?\n", "", content)  # Remove existing code blocks | ||||
|                 content = re.sub(r"\s*<end_code>\s*", "", content)  # Remove end_code tags | ||||
|                 content = content.strip() | ||||
|                 if not content.startswith("```python"): | ||||
|                     content = f"```python\n{content}\n```" | ||||
| 
 | ||||
|             parent_message_tool = gr.ChatMessage( | ||||
|                 role="assistant", | ||||
|                 metadata={"title": f"🛠️ Used tool {first_tool_call.name}"}, | ||||
|                 content=str(content), | ||||
|                 content=content, | ||||
|                 metadata={ | ||||
|                     "title": f"🛠️ Used tool {first_tool_call.name}", | ||||
|                     "id": parent_id, | ||||
|                     "status": "pending", | ||||
|                 }, | ||||
|             ) | ||||
|         if step_log.observations is not None: | ||||
|             yield gr.ChatMessage(role="assistant", content=step_log.observations) | ||||
|         if step_log.error is not None: | ||||
|             yield gr.ChatMessage( | ||||
|                 role="assistant", | ||||
|                 content=str(step_log.error), | ||||
|                 metadata={"title": "💥 Error"}, | ||||
|             yield parent_message_tool | ||||
| 
 | ||||
|             # Nesting execution logs under the tool call if they exist | ||||
|             if hasattr(step_log, "observations") and ( | ||||
|                 step_log.observations is not None and step_log.observations.strip() | ||||
|             ):  # Only yield execution logs if there's actual content | ||||
|                 log_content = step_log.observations.strip() | ||||
|                 if log_content: | ||||
|                     log_content = re.sub(r"^Execution logs:\s*", "", log_content) | ||||
|                     yield gr.ChatMessage( | ||||
|                         role="assistant", | ||||
|                         content=f"{log_content}", | ||||
|                         metadata={"title": "📝 Execution Logs", "parent_id": parent_id, "status": "done"}, | ||||
|                     ) | ||||
| 
 | ||||
|             # Nesting any errors under the tool call | ||||
|             if hasattr(step_log, "error") and step_log.error is not None: | ||||
|                 yield gr.ChatMessage( | ||||
|                     role="assistant", | ||||
|                     content=str(step_log.error), | ||||
|                     metadata={"title": "💥 Error", "parent_id": parent_id, "status": "done"}, | ||||
|                 ) | ||||
| 
 | ||||
|             # Update parent message metadata to done status without yielding a new message | ||||
|             parent_message_tool.metadata["status"] = "done" | ||||
| 
 | ||||
|         # Handle standalone errors but not from tool calls | ||||
|         elif hasattr(step_log, "error") and step_log.error is not None: | ||||
|             yield gr.ChatMessage(role="assistant", content=str(step_log.error), metadata={"title": "💥 Error"}) | ||||
| 
 | ||||
|         # Calculate duration and token information | ||||
|         step_footnote = f"{step_number}" | ||||
|         if hasattr(step_log, "input_token_count") and hasattr(step_log, "output_token_count"): | ||||
|             token_str = ( | ||||
|                 f" | Input-tokens:{step_log.input_token_count:,} | Output-tokens:{step_log.output_token_count:,}" | ||||
|             ) | ||||
|             step_footnote += token_str | ||||
|         if hasattr(step_log, "duration"): | ||||
|             step_duration = f" | Duration: {round(float(step_log.duration), 2)}" if step_log.duration else None | ||||
|             step_footnote += step_duration | ||||
|         step_footnote = f"""<span style="color: #bbbbc2; font-size: 12px;">{step_footnote}</span> """ | ||||
|         yield gr.ChatMessage(role="assistant", content=f"{step_footnote}") | ||||
|         yield gr.ChatMessage(role="assistant", content="-----") | ||||
| 
 | ||||
| 
 | ||||
| def stream_to_gradio( | ||||
|  | @ -60,12 +131,25 @@ def stream_to_gradio( | |||
|     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.""" | ||||
|     if not _is_package_available("gradio"): | ||||
|         raise ModuleNotFoundError( | ||||
|             "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[audio]'`" | ||||
|             "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`" | ||||
|         ) | ||||
|     import gradio as gr | ||||
| 
 | ||||
|     total_input_tokens = 0 | ||||
|     total_output_tokens = 0 | ||||
| 
 | ||||
|     for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args): | ||||
|         for message in pull_messages_from_step(step_log): | ||||
|         # Track tokens if model provides them | ||||
|         if hasattr(agent.model, "last_input_token_count"): | ||||
|             total_input_tokens += agent.model.last_input_token_count | ||||
|             total_output_tokens += agent.model.last_output_token_count | ||||
|             if isinstance(step_log, ActionStep): | ||||
|                 step_log.input_token_count = agent.model.last_input_token_count | ||||
|                 step_log.output_token_count = agent.model.last_output_token_count | ||||
| 
 | ||||
|         for message in pull_messages_from_step( | ||||
|             step_log, | ||||
|         ): | ||||
|             yield message | ||||
| 
 | ||||
|     final_answer = step_log  # Last log is the run's final_answer | ||||
|  | @ -87,7 +171,7 @@ def stream_to_gradio( | |||
|             content={"path": final_answer.to_string(), "mime_type": "audio/wav"}, | ||||
|         ) | ||||
|     else: | ||||
|         yield gr.ChatMessage(role="assistant", content=str(final_answer)) | ||||
|         yield gr.ChatMessage(role="assistant", content=f"**Final answer:** {str(final_answer)}") | ||||
| 
 | ||||
| 
 | ||||
| class GradioUI: | ||||
|  | @ -176,7 +260,7 @@ class GradioUI: | |||
|     def launch(self, **kwargs): | ||||
|         import gradio as gr | ||||
| 
 | ||||
|         with gr.Blocks() as demo: | ||||
|         with gr.Blocks(fill_height=True) as demo: | ||||
|             stored_messages = gr.State([]) | ||||
|             file_uploads_log = gr.State([]) | ||||
|             chatbot = gr.Chatbot( | ||||
|  | @ -187,6 +271,7 @@ class GradioUI: | |||
|                     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png", | ||||
|                 ), | ||||
|                 resizeable=True, | ||||
|                 scale=1, | ||||
|             ) | ||||
|             # If an upload folder is provided, enable the upload feature | ||||
|             if self.file_upload_folder is not None: | ||||
|  | @ -204,7 +289,7 @@ class GradioUI: | |||
|                 [stored_messages, text_input], | ||||
|             ).then(self.interact_with_agent, [stored_messages, chatbot], [chatbot]) | ||||
| 
 | ||||
|         demo.launch(**kwargs) | ||||
|         demo.launch(debug=True, share=True, **kwargs) | ||||
| 
 | ||||
| 
 | ||||
| __all__ = ["stream_to_gradio", "GradioUI"] | ||||
|  |  | |||
|  | @ -71,7 +71,7 @@ class MonitoringTester(unittest.TestCase): | |||
|         self.assertEqual(agent.monitor.total_input_token_count, 10) | ||||
|         self.assertEqual(agent.monitor.total_output_token_count, 20) | ||||
| 
 | ||||
|     def test_json_agent_metrics(self): | ||||
|     def test_toolcalling_agent_metrics(self): | ||||
|         agent = ToolCallingAgent( | ||||
|             tools=[], | ||||
|             model=FakeLLMModel(), | ||||
|  | @ -134,7 +134,7 @@ class MonitoringTester(unittest.TestCase): | |||
|         # Use stream_to_gradio to capture the output | ||||
|         outputs = list(stream_to_gradio(agent, task="Test task")) | ||||
| 
 | ||||
|         self.assertEqual(len(outputs), 4) | ||||
|         self.assertEqual(len(outputs), 7) | ||||
|         final_message = outputs[-1] | ||||
|         self.assertEqual(final_message.role, "assistant") | ||||
|         self.assertIn("This is the final answer.", final_message.content) | ||||
|  | @ -155,7 +155,7 @@ class MonitoringTester(unittest.TestCase): | |||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         self.assertEqual(len(outputs), 3) | ||||
|         self.assertEqual(len(outputs), 5) | ||||
|         final_message = outputs[-1] | ||||
|         self.assertEqual(final_message.role, "assistant") | ||||
|         self.assertIsInstance(final_message.content, dict) | ||||
|  | @ -177,7 +177,7 @@ class MonitoringTester(unittest.TestCase): | |||
|         # Use stream_to_gradio to capture the output | ||||
|         outputs = list(stream_to_gradio(agent, task="Test task")) | ||||
| 
 | ||||
|         self.assertEqual(len(outputs), 5) | ||||
|         self.assertEqual(len(outputs), 9) | ||||
|         final_message = outputs[-1] | ||||
|         self.assertEqual(final_message.role, "assistant") | ||||
|         self.assertIn("Simulated agent error", final_message.content) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue