diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index c272839..18ed975 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -34,8 +34,8 @@ from ra_aid.console.formatting import print_error, print_stage_header from ra_aid.console.output import print_agent_output from ra_aid.exceptions import ( AgentInterrupt, - ToolExecutionError, FallbackToolExecutionError, + ToolExecutionError, ) from ra_aid.fallback_handler import FallbackHandler from ra_aid.logging_config import get_logger @@ -846,12 +846,29 @@ def get_agent_type(agent: RAgents) -> Literal["CiaynAgent", "React"]: Determines the type of the agent. Returns "CiaynAgent" if agent is an instance of CiaynAgent, otherwise "React". """ - + if isinstance(agent, CiaynAgent): return "CiaynAgent" else: return "React" +def _handle_fallback_response( + error: ToolExecutionError, + fallback_handler, + agent: RAgents, + agent_type: str, + msg_list: list +) -> None: + """ + Handle fallback response by invoking fallback_handler and updating msg_list. + """ + if not fallback_handler: + return + fallback_response = fallback_handler.handle_failure(error, agent) + if fallback_response and agent_type == "React": + msg_list_response = [SystemMessage(str(msg)) for msg in fallback_response] + msg_list.extend(msg_list_response) + def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict): for chunk in agent.stream({"messages": msg_list}, config): @@ -904,16 +921,7 @@ def run_agent_with_retry( logger.debug("Agent run completed successfully") return "Agent run completed successfully" except ToolExecutionError as e: - if not fallback_handler: - continue - - fallback_response = fallback_handler.handle_failure(e, agent) - if fallback_response: - if agent_type == "React": - msg_list_response = [ - SystemMessage(str(msg)) for msg in fallback_response - ] - msg_list.extend(msg_list_response) + _handle_fallback_response(e, fallback_handler, agent, agent_type, msg_list) continue except FallbackToolExecutionError as e: msg_list.append( diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index 1f2cb46..aa2b3ce 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -6,15 +6,15 @@ from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import BaseTool -from ra_aid.logging_config import get_logger -from ra_aid.tools.expert import get_model -from ra_aid.prompts import CIAYN_AGENT_BASE_PROMPT, EXTRACT_TOOL_CALL_PROMPT +from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES from ra_aid.console.output import cpm from ra_aid.exceptions import ToolExecutionError from ra_aid.fallback_handler import FallbackHandler +from ra_aid.logging_config import get_logger from ra_aid.models_params import DEFAULT_TOKEN_LIMIT +from ra_aid.prompts import CIAYN_AGENT_BASE_PROMPT, EXTRACT_TOOL_CALL_PROMPT +from ra_aid.tools.expert import get_model from ra_aid.tools.reflection import get_function_info -from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES logger = get_logger(__name__) @@ -271,7 +271,7 @@ class CiaynAgent: return initial_messages + chat_history def stream( - self, messages_dict: Dict[str, List[Any]], config: Dict[str, Any] = None + self, messages_dict: Dict[str, List[Any]], _config: Dict[str, Any] = None ) -> Generator[Dict[str, Any], None, None]: """Stream agent responses in a format compatible with print_agent_output.""" initial_messages = messages_dict.get("messages", []) diff --git a/ra_aid/agents_alias.py b/ra_aid/agents_alias.py index d3e74c0..2cf6077 100644 --- a/ra_aid/agents_alias.py +++ b/ra_aid/agents_alias.py @@ -1,6 +1,7 @@ -from langgraph.graph.graph import CompiledGraph from typing import TYPE_CHECKING +from langgraph.graph.graph import CompiledGraph + # Unfortunately need this to avoid Circular Imports if TYPE_CHECKING: from ra_aid.agents.ciayn_agent import CiaynAgent diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index b42da97..86c4ab8 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -13,7 +13,7 @@ from ra_aid.config import ( RETRY_FALLBACK_COUNT, ) from ra_aid.console.output import cpm -from ra_aid.exceptions import ToolExecutionError, FallbackToolExecutionError +from ra_aid.exceptions import FallbackToolExecutionError, ToolExecutionError from ra_aid.llm import initialize_llm, validate_provider_env from ra_aid.logging_config import get_logger from ra_aid.tool_configs import get_all_tools @@ -383,3 +383,13 @@ class FallbackHandler: ): tool_calls = response.get("additional_kwargs").get("tool_calls") return tool_calls + + def handle_failure_response(self, error: ToolExecutionError, agent, agent_type: str): + """ + Handle a tool failure by calling handle_failure and, if a fallback response is returned and the agent type is "React", + return a list of SystemMessage objects wrapping each message from the fallback response. + """ + fallback_response = self.handle_failure(error, agent) + if fallback_response and agent_type == "React": + return [SystemMessage(str(msg)) for msg in fallback_response] + return None