Support multi-agent
This commit is contained in:
		
							parent
							
								
									43a3f46835
								
							
						
					
					
						commit
						23ab4a9df3
					
				|  | @ -92,6 +92,7 @@ class ActionStep: | ||||||
|     final_answer: Any = None |     final_answer: Any = None | ||||||
|     error: AgentError | None = None |     error: AgentError | None = None | ||||||
|     step_duration: float | None = None |     step_duration: float | None = None | ||||||
|  |     llm_output: str | None = None | ||||||
| 
 | 
 | ||||||
| @dataclass | @dataclass | ||||||
| class PlanningStep: | class PlanningStep: | ||||||
|  | @ -440,23 +441,20 @@ class ReactAgent(BaseAgent): | ||||||
|         else: |         else: | ||||||
|             self.logs.append(TaskStep(task=task)) |             self.logs.append(TaskStep(task=task)) | ||||||
| 
 | 
 | ||||||
|         with console.status( |         if oneshot: | ||||||
|             "Agent is running...", spinner="aesthetic" |             step_start_time = time.time() | ||||||
|         ): |             step_log = ActionStep(start_time=step_start_time) | ||||||
|             if oneshot: |             step_log.step_end_time = time.time() | ||||||
|                 step_start_time = time.time() |             step_log.step_duration = step_log.step_end_time - step_start_time | ||||||
|                 step_log = ActionStep(start_time=step_start_time) |  | ||||||
|                 step_log.step_end_time = time.time() |  | ||||||
|                 step_log.step_duration = step_log.step_end_time - step_start_time |  | ||||||
| 
 | 
 | ||||||
|                 # Run the agent's step |             # Run the agent's step | ||||||
|                 result = self.step(step_log) |             result = self.step(step_log) | ||||||
|                 return result |             return result | ||||||
| 
 | 
 | ||||||
|             if stream: |         if stream: | ||||||
|                 return self.stream_run(task) |             return self.stream_run(task) | ||||||
|             else: |         else: | ||||||
|                 return self.direct_run(task) |             return self.direct_run(task) | ||||||
| 
 | 
 | ||||||
|     def stream_run(self, task: str): |     def stream_run(self, task: str): | ||||||
|         """ |         """ | ||||||
|  | @ -468,6 +466,9 @@ class ReactAgent(BaseAgent): | ||||||
|             step_start_time = time.time() |             step_start_time = time.time() | ||||||
|             step_log = ActionStep(iteration=iteration, start_time=step_start_time) |             step_log = ActionStep(iteration=iteration, start_time=step_start_time) | ||||||
|             try: |             try: | ||||||
|  |                 if self.planning_interval is not None and iteration % self.planning_interval == 0: | ||||||
|  |                     self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration) | ||||||
|  |                 console.rule("[bold]New step") | ||||||
|                 self.step(step_log) |                 self.step(step_log) | ||||||
|                 if step_log.final_answer is not None: |                 if step_log.final_answer is not None: | ||||||
|                     final_answer = step_log.final_answer |                     final_answer = step_log.final_answer | ||||||
|  | @ -484,7 +485,6 @@ class ReactAgent(BaseAgent): | ||||||
| 
 | 
 | ||||||
|         if final_answer is None and iteration == self.max_iterations: |         if final_answer is None and iteration == self.max_iterations: | ||||||
|             error_message = "Reached max iterations." |             error_message = "Reached max iterations." | ||||||
|             console.print(f"[bold red]{error_message}") |  | ||||||
|             final_step_log = ActionStep(error=AgentMaxIterationsError(error_message)) |             final_step_log = ActionStep(error=AgentMaxIterationsError(error_message)) | ||||||
|             self.logs.append(final_step_log) |             self.logs.append(final_step_log) | ||||||
|             final_answer = self.provide_final_answer(task) |             final_answer = self.provide_final_answer(task) | ||||||
|  | @ -509,6 +509,7 @@ class ReactAgent(BaseAgent): | ||||||
|             try: |             try: | ||||||
|                 if self.planning_interval is not None and iteration % self.planning_interval == 0: |                 if self.planning_interval is not None and iteration % self.planning_interval == 0: | ||||||
|                     self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration) |                     self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration) | ||||||
|  |                 console.rule("[bold]New step") | ||||||
|                 self.step(step_log) |                 self.step(step_log) | ||||||
|                 if step_log.final_answer is not None: |                 if step_log.final_answer is not None: | ||||||
|                     final_answer = step_log.final_answer |                     final_answer = step_log.final_answer | ||||||
|  | @ -527,7 +528,6 @@ class ReactAgent(BaseAgent): | ||||||
|             error_message = "Reached max iterations." |             error_message = "Reached max iterations." | ||||||
|             final_step_log = ActionStep(error=AgentMaxIterationsError(error_message)) |             final_step_log = ActionStep(error=AgentMaxIterationsError(error_message)) | ||||||
|             self.logs.append(final_step_log) |             self.logs.append(final_step_log) | ||||||
|             console.print(f"[bold red]{error_message}") |  | ||||||
|             final_answer = self.provide_final_answer(task) |             final_answer = self.provide_final_answer(task) | ||||||
|             final_step_log.final_answer = final_answer |             final_step_log.final_answer = final_answer | ||||||
|             final_step_log.step_duration = 0 |             final_step_log.step_duration = 0 | ||||||
|  | @ -677,7 +677,6 @@ class JsonAgent(ReactAgent): | ||||||
|         agent_memory = self.write_inner_memory_from_logs() |         agent_memory = self.write_inner_memory_from_logs() | ||||||
| 
 | 
 | ||||||
|         self.prompt = agent_memory |         self.prompt = agent_memory | ||||||
|         console.rule("New step") |  | ||||||
| 
 | 
 | ||||||
|         # Add new step in logs |         # Add new step in logs | ||||||
|         log_entry.agent_memory = agent_memory.copy() |         log_entry.agent_memory = agent_memory.copy() | ||||||
|  | @ -692,12 +691,14 @@ class JsonAgent(ReactAgent): | ||||||
|             llm_output = self.llm_engine( |             llm_output = self.llm_engine( | ||||||
|                 self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args |                 self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args | ||||||
|             ) |             ) | ||||||
|  |             log_entry.llm_output = llm_output | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             raise AgentGenerationError(f"Error in generating llm output: {e}.") |             raise AgentGenerationError(f"Error in generating llm output: {e}.") | ||||||
|         console.rule("Output message of the LLM") |  | ||||||
|         console.print(llm_output) |  | ||||||
|         log_entry.llm_output = llm_output |  | ||||||
| 
 | 
 | ||||||
|  |         if self.verbose: | ||||||
|  |             console.rule("[italic]Output message of the LLM:") | ||||||
|  |             console.print(llm_output) | ||||||
|  |          | ||||||
|         # Parse |         # Parse | ||||||
|         rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:") |         rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:") | ||||||
| 
 | 
 | ||||||
|  | @ -796,7 +797,6 @@ class CodeAgent(ReactAgent): | ||||||
|         agent_memory = self.write_inner_memory_from_logs() |         agent_memory = self.write_inner_memory_from_logs() | ||||||
| 
 | 
 | ||||||
|         self.prompt = agent_memory.copy() |         self.prompt = agent_memory.copy() | ||||||
|         console.rule("New step") |  | ||||||
| 
 | 
 | ||||||
|         # Add new step in logs |         # Add new step in logs | ||||||
|         log_entry.agent_memory = agent_memory.copy() |         log_entry.agent_memory = agent_memory.copy() | ||||||
|  | @ -811,13 +811,13 @@ class CodeAgent(ReactAgent): | ||||||
|             llm_output = self.llm_engine( |             llm_output = self.llm_engine( | ||||||
|                 self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args |                 self.prompt, stop_sequences=["<end_action>", "Observation:"], **additional_args | ||||||
|             ) |             ) | ||||||
|  |             log_entry.llm_output = llm_output | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             raise AgentGenerationError(f"Error in generating llm output: {e}.") |             raise AgentGenerationError(f"Error in generating llm output: {e}.") | ||||||
| 
 | 
 | ||||||
|         if self.verbose: |         if self.verbose: | ||||||
|             console.rule("[italic]Output message of the LLM:") |             console.rule("[italic]Output message of the LLM:") | ||||||
|             console.print(Syntax(llm_output, lexer='markdown', background_color='default')) |             console.print(Syntax(llm_output, lexer='markdown', background_color='default')) | ||||||
|         log_entry.llm_output = llm_output |  | ||||||
| 
 | 
 | ||||||
|         # Parse |         # Parse | ||||||
|         try: |         try: | ||||||
|  |  | ||||||
|  | @ -185,3 +185,13 @@ class FinalAnswerTool(Tool): | ||||||
| 
 | 
 | ||||||
|     def forward(self, answer): |     def forward(self, answer): | ||||||
|         return answer |         return answer | ||||||
|  | 
 | ||||||
|  | class UserInputTool(Tool): | ||||||
|  |     name = "user_input" | ||||||
|  |     description = "Asks for user's input on a specific question" | ||||||
|  |     inputs = {"question": {"type": "string", "description": "The question to ask the user"}} | ||||||
|  |     output_type = "string" | ||||||
|  | 
 | ||||||
|  |     def forward(self, question): | ||||||
|  |         user_input = input(f"{question} => ") | ||||||
|  |         return user_input | ||||||
|  |  | ||||||
|  | @ -0,0 +1,73 @@ | ||||||
|  | import re | ||||||
|  | import requests | ||||||
|  | from markdownify import markdownify as md | ||||||
|  | from requests.exceptions import RequestException | ||||||
|  | 
 | ||||||
|  | from agents import ( | ||||||
|  |     tool, | ||||||
|  |     CodeAgent, | ||||||
|  |     JsonAgent, | ||||||
|  |     HfApiEngine, | ||||||
|  |     ManagedAgent, | ||||||
|  | ) | ||||||
|  | from agents.default_tools import UserInputTool | ||||||
|  | from agents.search import DuckDuckGoSearchTool | ||||||
|  | from agents.utils import console | ||||||
|  | 
 | ||||||
|  | model = "Qwen/Qwen2.5-72B-Instruct" | ||||||
|  | 
 | ||||||
|  | @tool | ||||||
|  | def visit_webpage(url: str) -> str: | ||||||
|  |     """Visits a webpage at the given URL and returns its content as a markdown string. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         url: The URL of the webpage to visit. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         The content of the webpage converted to Markdown, or an error message if the request fails. | ||||||
|  |     """ | ||||||
|  |     try: | ||||||
|  |         # Send a GET request to the URL | ||||||
|  |         response = requests.get(url) | ||||||
|  |         response.raise_for_status()  # Raise an exception for bad status codes | ||||||
|  | 
 | ||||||
|  |         # Convert the HTML content to Markdown | ||||||
|  |         markdown_content = md(response.text).strip() | ||||||
|  | 
 | ||||||
|  |         # Remove multiple line breaks | ||||||
|  |         markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content) | ||||||
|  | 
 | ||||||
|  |         return markdown_content | ||||||
|  | 
 | ||||||
|  |     except RequestException as e: | ||||||
|  |         return f"Error fetching the webpage: {str(e)}" | ||||||
|  |     except Exception as e: | ||||||
|  |         return f"An unexpected error occurred: {str(e)}" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | llm_engine = HfApiEngine(model) | ||||||
|  | 
 | ||||||
|  | web_agent = JsonAgent( | ||||||
|  |     tools=[DuckDuckGoSearchTool(), visit_webpage], | ||||||
|  |     llm_engine=llm_engine, | ||||||
|  |     max_iterations=10, | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | managed_web_agent = ManagedAgent( | ||||||
|  |     agent=web_agent, | ||||||
|  |     name="search", | ||||||
|  |     description="Runs web searches for you. Give it your query as an argument.", | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | manager_agent = CodeAgent( | ||||||
|  |     tools=[UserInputTool()], | ||||||
|  |     llm_engine=llm_engine, | ||||||
|  |     managed_agents=[managed_web_agent], | ||||||
|  |     additional_authorized_imports=["time", "datetime"], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | with console.status( | ||||||
|  |     "Agent is running...", spinner="aesthetic" | ||||||
|  | ): | ||||||
|  |     manager_agent.run("""How many years ago was Stripe founded? | ||||||
|  |     You should ask for user input on wether the answer is correct before returning your final answer.""") | ||||||
|  | @ -1,5 +1,26 @@ | ||||||
| # This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. | # This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "beautifulsoup4" | ||||||
|  | version = "4.12.3" | ||||||
|  | description = "Screen-scraping library" | ||||||
|  | optional = false | ||||||
|  | python-versions = ">=3.6.0" | ||||||
|  | files = [ | ||||||
|  |     {file = "beautifulsoup4-4.12.3-py3-none-any.whl", hash = "sha256:b80878c9f40111313e55da8ba20bdba06d8fa3969fc68304167741bbf9e082ed"}, | ||||||
|  |     {file = "beautifulsoup4-4.12.3.tar.gz", hash = "sha256:74e3d1928edc070d21748185c46e3fb33490f22f52a3addee9aee0f4f7781051"}, | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | [package.dependencies] | ||||||
|  | soupsieve = ">1.2" | ||||||
|  | 
 | ||||||
|  | [package.extras] | ||||||
|  | cchardet = ["cchardet"] | ||||||
|  | chardet = ["chardet"] | ||||||
|  | charset-normalizer = ["charset-normalizer"] | ||||||
|  | html5lib = ["html5lib"] | ||||||
|  | lxml = ["lxml"] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "certifi" | name = "certifi" | ||||||
| version = "2024.8.30" | version = "2024.8.30" | ||||||
|  | @ -338,6 +359,21 @@ profiling = ["gprof2dot"] | ||||||
| rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] | rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] | ||||||
| testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] | testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "markdownify" | ||||||
|  | version = "0.14.1" | ||||||
|  | description = "Convert HTML to markdown." | ||||||
|  | optional = false | ||||||
|  | python-versions = "*" | ||||||
|  | files = [ | ||||||
|  |     {file = "markdownify-0.14.1-py3-none-any.whl", hash = "sha256:4c46a6c0c12c6005ddcd49b45a5a890398b002ef51380cd319db62df5e09bc2a"}, | ||||||
|  |     {file = "markdownify-0.14.1.tar.gz", hash = "sha256:a62a7a216947ed0b8dafb95b99b2ef4a0edd1e18d5653c656f68f03db2bfb2f1"}, | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | [package.dependencies] | ||||||
|  | beautifulsoup4 = ">=4.9,<5" | ||||||
|  | six = ">=1.15,<2" | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "markupsafe" | name = "markupsafe" | ||||||
| version = "3.0.2" | version = "3.0.2" | ||||||
|  | @ -1096,6 +1132,17 @@ files = [ | ||||||
|     {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, |     {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "soupsieve" | ||||||
|  | version = "2.6" | ||||||
|  | description = "A modern CSS selector implementation for Beautiful Soup." | ||||||
|  | optional = false | ||||||
|  | python-versions = ">=3.8" | ||||||
|  | files = [ | ||||||
|  |     {file = "soupsieve-2.6-py3-none-any.whl", hash = "sha256:e72c4ff06e4fb6e4b5a9f0f55fe6e81514581fca1515028625d0f299c602ccc9"}, | ||||||
|  |     {file = "soupsieve-2.6.tar.gz", hash = "sha256:e2e68417777af359ec65daac1057404a3c8a5455bb8abc36f1a9866ab1a51abb"}, | ||||||
|  | ] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "tokenizers" | name = "tokenizers" | ||||||
| version = "0.21.0" | version = "0.21.0" | ||||||
|  | @ -1301,4 +1348,4 @@ zstd = ["zstandard (>=0.18.0)"] | ||||||
| [metadata] | [metadata] | ||||||
| lock-version = "2.0" | lock-version = "2.0" | ||||||
| python-versions = ">=3.10,<3.13" | python-versions = ">=3.10,<3.13" | ||||||
| content-hash = "6c3841968936d66bf70e11c6c8e0a16fec6c2f4d88d79cd8ac5a412225e7cf56" | content-hash = "3a0896faf882952a0d780efcc862017989612fcb421a6ee01e4eec0ba6c0f638" | ||||||
|  |  | ||||||
|  | @ -67,6 +67,7 @@ pandas = "^2.2.3" | ||||||
| jinja2 = "^3.1.4" | jinja2 = "^3.1.4" | ||||||
| pillow = "^11.0.0" | pillow = "^11.0.0" | ||||||
| llama-cpp-python = "^0.3.4" | llama-cpp-python = "^0.3.4" | ||||||
|  | markdownify = "^0.14.1" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| [build-system] | [build-system] | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue