Solve additional args not being passed to task
This commit is contained in:
		
							parent
							
								
									b38d842c2d
								
							
						
					
					
						commit
						ba87dd98c8
					
				|  | @ -528,7 +528,7 @@ class ReactAgent(BaseAgent): | ||||||
|                 self.logs.append(system_prompt_step) |                 self.logs.append(system_prompt_step) | ||||||
| 
 | 
 | ||||||
|         console.print(Group(Rule("[bold]New task", characters="="), Text(self.task))) |         console.print(Group(Rule("[bold]New task", characters="="), Text(self.task))) | ||||||
|         self.logs.append(TaskStep(task=task)) |         self.logs.append(TaskStep(task=self.task)) | ||||||
| 
 | 
 | ||||||
|         if oneshot: |         if oneshot: | ||||||
|             step_start_time = time.time() |             step_start_time = time.time() | ||||||
|  | @ -541,9 +541,9 @@ class ReactAgent(BaseAgent): | ||||||
|             return result |             return result | ||||||
| 
 | 
 | ||||||
|         if stream: |         if stream: | ||||||
|             return self.stream_run(task) |             return self.stream_run(self.task) | ||||||
|         else: |         else: | ||||||
|             return self.direct_run(task) |             return self.direct_run(self.task) | ||||||
| 
 | 
 | ||||||
|     def stream_run(self, task: str): |     def stream_run(self, task: str): | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|  | @ -230,6 +230,12 @@ Action: | ||||||
|             tool_arguments="final_answer(7.2904)", |             tool_arguments="final_answer(7.2904)", | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|  |     def test_additional_args_added_to_task(self): | ||||||
|  |         agent = CodeAgent(tools=[], llm_engine=fake_code_llm) | ||||||
|  |         output = agent.run("What is 2 multiplied by 3.6452?", additional_instruction="Remember this.") | ||||||
|  |         assert "Remember this" in agent.task | ||||||
|  |         assert "Remember this" in str(agent.prompt_messages) | ||||||
|  | 
 | ||||||
|     def test_reset_conversations(self): |     def test_reset_conversations(self): | ||||||
|         agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm) |         agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm) | ||||||
|         output = agent.run("What is 2 multiplied by 3.6452?", reset=True) |         output = agent.run("What is 2 multiplied by 3.6452?", reset=True) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue