Gradio chatbot: step duration, number, token count, support nested thoughts (#384)

* Add enhanced Gradio UI with nested agent calls, execution logs, errors
- Add virtual separation between steps
- Highlight final answer as requested in internal discussion
- Show step numbers and token counts
- Include step duration tracking
- Improve message display structure

---------

Co-authored-by: Yuvraj Sharma <48665385+yvrjsharma@users.noreply.github.com>
parent 3b5c99e87a
commit 49c34f625c
pyproject.toml

@@ -41,7 +41,7 @@ e2b = [
   "python-dotenv>=1.0.1",
 ]
 gradio = [
-  "gradio>=5.8.0",
+  "gradio>=5.13.0",
 ]
 litellm = [
   "litellm>=1.55.10",
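The floor moves to 5.13.0 because the UI below leans on Gradio's nested `ChatMessage` metadata (`id` / `parent_id` / `status`), which older releases do not render. A minimal sanity-check sketch, assuming `packaging` is installed (it usually is as a transitive dependency):

```python
# Sketch: fail fast if the installed Gradio predates the nested
# ChatMessage metadata (id / parent_id / status) the new UI relies on.
from importlib.metadata import version

from packaging.version import Version

if Version(version("gradio")) < Version("5.13.0"):
    raise RuntimeError("Upgrade Gradio, e.g. `pip install -U 'smolagents[gradio]'`")
```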
src/smolagents/gradio_ui.py

@@ -24,31 +24,102 @@ from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
 from .utils import _is_package_available


-def pull_messages_from_step(step_log: AgentStepLog):
-    """Extract ChatMessage objects from agent steps"""
+def pull_messages_from_step(
+    step_log: AgentStepLog,
+):
+    """Extract ChatMessage objects from agent steps with proper nesting"""
     import gradio as gr

     if isinstance(step_log, ActionStep):
-        yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
-        if step_log.tool_calls is not None:
+        # Output the step number
+        step_number = f"Step {step_log.step_number}" if step_log.step_number is not None else ""
+        yield gr.ChatMessage(role="assistant", content=f"**{step_number}**")
+
+        # First yield the thought/reasoning from the LLM
+        if hasattr(step_log, "llm_output") and step_log.llm_output is not None:
+            # Clean up the LLM output
+            llm_output = step_log.llm_output.strip()
+            # Remove any trailing <end_code> and extra backticks, handling multiple possible formats
+            llm_output = re.sub(r"```\s*<end_code>", "```", llm_output)  # handles ```<end_code>
+            llm_output = re.sub(r"<end_code>\s*```", "```", llm_output)  # handles <end_code>```
+            llm_output = re.sub(r"```\s*\n\s*<end_code>", "```", llm_output)  # handles ```\n<end_code>
+            llm_output = llm_output.strip()
+            yield gr.ChatMessage(role="assistant", content=llm_output)
+
+        # For tool calls, create a parent message
+        if hasattr(step_log, "tool_calls") and step_log.tool_calls is not None:
             first_tool_call = step_log.tool_calls[0]
-            used_code = first_tool_call.name == "code interpreter"
-            content = first_tool_call.arguments
+            used_code = first_tool_call.name == "python_interpreter"
+            parent_id = f"call_{len(step_log.tool_calls)}"
+
+            # Tool call becomes the parent message with timing info
+            # First we will handle arguments based on type
+            args = first_tool_call.arguments
+            if isinstance(args, dict):
+                content = str(args.get("answer", str(args)))
+            else:
+                content = str(args).strip()
+
             if used_code:
-                content = f"```py\n{content}\n```"
-            yield gr.ChatMessage(
+                # Clean up the content by removing any end code tags
+                content = re.sub(r"```.*?\n", "", content)  # Remove existing code blocks
+                content = re.sub(r"\s*<end_code>\s*", "", content)  # Remove end_code tags
+                content = content.strip()
+                if not content.startswith("```python"):
+                    content = f"```python\n{content}\n```"
+
+            parent_message_tool = gr.ChatMessage(
                 role="assistant",
-                metadata={"title": f"🛠️ Used tool {first_tool_call.name}"},
-                content=str(content),
+                content=content,
+                metadata={
+                    "title": f"🛠️ Used tool {first_tool_call.name}",
+                    "id": parent_id,
+                    "status": "pending",
+                },
             )
-        if step_log.observations is not None:
-            yield gr.ChatMessage(role="assistant", content=step_log.observations)
-        if step_log.error is not None:
-            yield gr.ChatMessage(
-                role="assistant",
-                content=str(step_log.error),
-                metadata={"title": "💥 Error"},
-            )
+            yield parent_message_tool
+
+            # Nesting execution logs under the tool call if they exist
+            if hasattr(step_log, "observations") and (
+                step_log.observations is not None and step_log.observations.strip()
+            ):  # Only yield execution logs if there's actual content
+                log_content = step_log.observations.strip()
+                if log_content:
+                    log_content = re.sub(r"^Execution logs:\s*", "", log_content)
+                    yield gr.ChatMessage(
+                        role="assistant",
+                        content=f"{log_content}",
+                        metadata={"title": "📝 Execution Logs", "parent_id": parent_id, "status": "done"},
+                    )
+
+            # Nesting any errors under the tool call
+            if hasattr(step_log, "error") and step_log.error is not None:
+                yield gr.ChatMessage(
+                    role="assistant",
+                    content=str(step_log.error),
+                    metadata={"title": "💥 Error", "parent_id": parent_id, "status": "done"},
+                )
+
+            # Update parent message metadata to done status without yielding a new message
+            parent_message_tool.metadata["status"] = "done"
+
+        # Handle standalone errors but not from tool calls
+        elif hasattr(step_log, "error") and step_log.error is not None:
+            yield gr.ChatMessage(role="assistant", content=str(step_log.error), metadata={"title": "💥 Error"})
+
+        # Calculate duration and token information
+        step_footnote = f"{step_number}"
+        if hasattr(step_log, "input_token_count") and hasattr(step_log, "output_token_count"):
+            token_str = (
+                f" | Input-tokens:{step_log.input_token_count:,} | Output-tokens:{step_log.output_token_count:,}"
+            )
+            step_footnote += token_str
+        if hasattr(step_log, "duration"):
+            step_duration = f" | Duration: {round(float(step_log.duration), 2)}" if step_log.duration else None
+            step_footnote += step_duration
+        step_footnote = f"""<span style="color: #bbbbc2; font-size: 12px;">{step_footnote}</span> """
+        yield gr.ChatMessage(role="assistant", content=f"{step_footnote}")
+        yield gr.ChatMessage(role="assistant", content="-----")


 def stream_to_gradio(
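For context on how the `id` / `parent_id` / `status` metadata in the hunk above drives the nesting, here is a minimal self-contained sketch; the conversation contents are invented for illustration, and only `gr.ChatMessage`, `gr.Chatbot`, and the metadata keys come from Gradio itself:

```python
# Minimal illustration of nested "thoughts" in a Gradio chatbot.
# A message whose metadata carries an "id" renders as a collapsible
# parent; messages whose metadata carries a matching "parent_id"
# nest underneath it. Contents below are made up for demonstration.
import gradio as gr

history = [
    gr.ChatMessage(role="user", content="What is 2 + 2?"),
    gr.ChatMessage(role="assistant", content="**Step 1**"),
    gr.ChatMessage(
        role="assistant",
        content="```python\nresult = 2 + 2\n```",
        metadata={"title": "🛠️ Used tool python_interpreter", "id": "call_1", "status": "done"},
    ),
    gr.ChatMessage(
        role="assistant",
        content="4",
        metadata={"title": "📝 Execution Logs", "parent_id": "call_1", "status": "done"},
    ),
    gr.ChatMessage(role="assistant", content="**Final answer:** 4"),
]

with gr.Blocks() as demo:
    gr.Chatbot(value=history, type="messages")

if __name__ == "__main__":
    demo.launch()
```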
@@ -60,12 +131,25 @@ def stream_to_gradio(
     """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
     if not _is_package_available("gradio"):
         raise ModuleNotFoundError(
-            "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[audio]'`"
+            "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
         )
     import gradio as gr

+    total_input_tokens = 0
+    total_output_tokens = 0
+
     for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
-        for message in pull_messages_from_step(step_log):
+        # Track tokens if model provides them
+        if hasattr(agent.model, "last_input_token_count"):
+            total_input_tokens += agent.model.last_input_token_count
+            total_output_tokens += agent.model.last_output_token_count
+            if isinstance(step_log, ActionStep):
+                step_log.input_token_count = agent.model.last_input_token_count
+                step_log.output_token_count = agent.model.last_output_token_count
+
+        for message in pull_messages_from_step(
+            step_log,
+        ):
             yield message

     final_answer = step_log  # Last log is the run's final_answer
@@ -87,7 +171,7 @@ def stream_to_gradio(
             content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
         )
     else:
-        yield gr.ChatMessage(role="assistant", content=str(final_answer))
+        yield gr.ChatMessage(role="assistant", content=f"**Final answer:** {str(final_answer)}")


 class GradioUI:
@@ -176,7 +260,7 @@ class GradioUI:
     def launch(self, **kwargs):
         import gradio as gr

-        with gr.Blocks() as demo:
+        with gr.Blocks(fill_height=True) as demo:
             stored_messages = gr.State([])
             file_uploads_log = gr.State([])
             chatbot = gr.Chatbot(
@@ -187,6 +271,7 @@ class GradioUI:
                     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
                 ),
                 resizeable=True,
+                scale=1,
             )
             # If an upload folder is provided, enable the upload feature
             if self.file_upload_folder is not None:
@@ -204,7 +289,7 @@ class GradioUI:
                 [stored_messages, text_input],
             ).then(self.interact_with_agent, [stored_messages, chatbot], [chatbot])

-        demo.launch(**kwargs)
+        demo.launch(debug=True, share=True, **kwargs)


 __all__ = ["stream_to_gradio", "GradioUI"]
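End to end, the updated UI is typically driven as in this sketch, assuming smolagents' `CodeAgent` and `HfApiModel` (any model/agent pairing streams the same way):

```python
# Sketch: launching the enhanced chat UI. With the changes above, each
# agent step streams a step header, the tool call with nested logs or
# errors, a token/duration footnote, and a separator; the final answer
# is then highlighted in bold.
from smolagents import CodeAgent, GradioUI, HfApiModel

agent = CodeAgent(tools=[], model=HfApiModel())
GradioUI(agent).launch()
```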
tests/test_monitoring.py

@@ -71,7 +71,7 @@ class MonitoringTester(unittest.TestCase):
         self.assertEqual(agent.monitor.total_input_token_count, 10)
         self.assertEqual(agent.monitor.total_output_token_count, 20)

-    def test_json_agent_metrics(self):
+    def test_toolcalling_agent_metrics(self):
         agent = ToolCallingAgent(
             tools=[],
             model=FakeLLMModel(),
@@ -134,7 +134,7 @@ class MonitoringTester(unittest.TestCase):
         # Use stream_to_gradio to capture the output
         outputs = list(stream_to_gradio(agent, task="Test task"))

-        self.assertEqual(len(outputs), 4)
+        self.assertEqual(len(outputs), 7)
         final_message = outputs[-1]
         self.assertEqual(final_message.role, "assistant")
         self.assertIn("This is the final answer.", final_message.content)
@@ -155,7 +155,7 @@ class MonitoringTester(unittest.TestCase):
             )
         )

-        self.assertEqual(len(outputs), 3)
+        self.assertEqual(len(outputs), 5)
         final_message = outputs[-1]
         self.assertEqual(final_message.role, "assistant")
         self.assertIsInstance(final_message.content, dict)
@@ -177,7 +177,7 @@ class MonitoringTester(unittest.TestCase):
         # Use stream_to_gradio to capture the output
         outputs = list(stream_to_gradio(agent, task="Test task"))

-        self.assertEqual(len(outputs), 5)
+        self.assertEqual(len(outputs), 9)
         final_message = outputs[-1]
         self.assertEqual(final_message.role, "assistant")
         self.assertIn("Simulated agent error", final_message.content)
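The expected counts grow because each `ActionStep` now emits additional messages (step header, nested tool-call entries, footnote, separator). If a count drifts again, enumerating the stream shows exactly which messages make up the total; a sketch, assuming `agent` is constructed as in the tests and the import path matches the module above:

```python
# Sketch: print every streamed ChatMessage to account for the expected
# counts in the assertions above.
from smolagents.gradio_ui import stream_to_gradio

for i, message in enumerate(stream_to_gradio(agent, task="Test task")):
    title = (message.metadata or {}).get("title", "")
    print(i, message.role, title, repr(str(message.content))[:60])
```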