diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 18ed975..62c78f1 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -414,8 +414,10 @@ def run_research_agent( if agent is not None: logger.debug("Research agent created successfully") - fallback_handler = FallbackHandler(config, tools) - _result = run_agent_with_retry(agent, prompt, run_config, fallback_handler) + none_or_fallback_handler = init_fallback_handler(agent, config, tools) + _result = run_agent_with_retry( + agent, prompt, run_config, none_or_fallback_handler + ) if _result: # Log research completion log_work_event(f"Completed research phase for: {base_task_or_query}") @@ -531,8 +533,10 @@ def run_web_research_agent( console.print(Panel(Markdown(console_message), title="🔬 Researching...")) logger.debug("Web research agent completed successfully") - fallback_handler = FallbackHandler(config, tools) - _result = run_agent_with_retry(agent, prompt, run_config, fallback_handler) + none_or_fallback_handler = init_fallback_handler(agent, config, tools) + _result = run_agent_with_retry( + agent, prompt, run_config, none_or_fallback_handler + ) if _result: # Log web research completion log_work_event(f"Completed web research phase for: {query}") @@ -637,9 +641,9 @@ def run_planning_agent( try: print_stage_header("Planning Stage") logger.debug("Planning agent completed successfully") - fallback_handler = FallbackHandler(config, tools) + none_or_fallback_handler = init_fallback_handler(agent, config, tools) _result = run_agent_with_retry( - agent, planning_prompt, run_config, fallback_handler + agent, planning_prompt, run_config, none_or_fallback_handler ) if _result: # Log planning completion @@ -745,8 +749,10 @@ def run_task_implementation_agent( try: logger.debug("Implementation agent completed successfully") - fallback_handler = FallbackHandler(config, tools) - _result = run_agent_with_retry(agent, prompt, run_config, fallback_handler) + none_or_fallback_handler = init_fallback_handler(agent, config, tools) + _result = run_agent_with_retry( + agent, prompt, run_config, none_or_fallback_handler + ) if _result: # Log task implementation completion log_work_event(f"Completed implementation of task: {task}") @@ -846,18 +852,29 @@ def get_agent_type(agent: RAgents) -> Literal["CiaynAgent", "React"]: Determines the type of the agent. Returns "CiaynAgent" if agent is an instance of CiaynAgent, otherwise "React". """ - + if isinstance(agent, CiaynAgent): return "CiaynAgent" else: return "React" + +def init_fallback_handler(agent: RAgents, config: Dict[str, Any], tools: List[Any]): + """ + Initialize fallback handler if agent is of type "React"; otherwise return None. + """ + agent_type = get_agent_type(agent) + if agent_type == "React": + return FallbackHandler(config, tools) + return None + + def _handle_fallback_response( error: ToolExecutionError, - fallback_handler, + fallback_handler: Optional[FallbackHandler], agent: RAgents, agent_type: str, - msg_list: list + msg_list: list, ) -> None: """ Handle fallback response by invoking fallback_handler and updating msg_list. @@ -921,7 +938,9 @@ def run_agent_with_retry( logger.debug("Agent run completed successfully") return "Agent run completed successfully" except ToolExecutionError as e: - _handle_fallback_response(e, fallback_handler, agent, agent_type, msg_list) + _handle_fallback_response( + e, fallback_handler, agent, agent_type, msg_list + ) continue except FallbackToolExecutionError as e: msg_list.append( diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index aa2b3ce..df12722 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -288,6 +288,7 @@ class CiaynAgent: logger.debug(f"Code generated by agent: {response.content}") last_result = self._execute_tool(response) self.chat_history.append(response) + self.fallback_handler.reset_fallback_handler() yield {} except ToolExecutionError as e: diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index 86c4ab8..30f5690 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -384,7 +384,9 @@ class FallbackHandler: tool_calls = response.get("additional_kwargs").get("tool_calls") return tool_calls - def handle_failure_response(self, error: ToolExecutionError, agent, agent_type: str): + def handle_failure_response( + self, error: ToolExecutionError, agent, agent_type: str + ): """ Handle a tool failure by calling handle_failure and, if a fallback response is returned and the agent type is "React", return a list of SystemMessage objects wrapping each message from the fallback response.