feat(agent_utils.py): add _handle_fallback_response function to streamline fallback handling logic

refactor(agent_utils.py): extract fallback handling logic from run_agent_with_retry to improve code readability
fix(ciayn_agent.py): update stream method parameter name for consistency
chore(agents_alias.py): reorder import statements to follow best practices
style(fallback_handler.py): reorder exception imports for consistency and clarity
This commit is contained in:
Ariel Frischer 2025-02-13 16:47:31 -08:00
parent 115cde98b6
commit 63e48db9de
4 changed files with 38 additions and 19 deletions

View File

@ -34,8 +34,8 @@ from ra_aid.console.formatting import print_error, print_stage_header
from ra_aid.console.output import print_agent_output
from ra_aid.exceptions import (
AgentInterrupt,
ToolExecutionError,
FallbackToolExecutionError,
ToolExecutionError,
)
from ra_aid.fallback_handler import FallbackHandler
from ra_aid.logging_config import get_logger
@ -852,6 +852,23 @@ def get_agent_type(agent: RAgents) -> Literal["CiaynAgent", "React"]:
else:
return "React"
def _handle_fallback_response(
error: ToolExecutionError,
fallback_handler,
agent: RAgents,
agent_type: str,
msg_list: list
) -> None:
"""
Handle fallback response by invoking fallback_handler and updating msg_list.
"""
if not fallback_handler:
return
fallback_response = fallback_handler.handle_failure(error, agent)
if fallback_response and agent_type == "React":
msg_list_response = [SystemMessage(str(msg)) for msg in fallback_response]
msg_list.extend(msg_list_response)
def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
for chunk in agent.stream({"messages": msg_list}, config):
@ -904,16 +921,7 @@ def run_agent_with_retry(
logger.debug("Agent run completed successfully")
return "Agent run completed successfully"
except ToolExecutionError as e:
if not fallback_handler:
continue
fallback_response = fallback_handler.handle_failure(e, agent)
if fallback_response:
if agent_type == "React":
msg_list_response = [
SystemMessage(str(msg)) for msg in fallback_response
]
msg_list.extend(msg_list_response)
_handle_fallback_response(e, fallback_handler, agent, agent_type, msg_list)
continue
except FallbackToolExecutionError as e:
msg_list.append(

View File

@ -6,15 +6,15 @@ from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
from ra_aid.logging_config import get_logger
from ra_aid.tools.expert import get_model
from ra_aid.prompts import CIAYN_AGENT_BASE_PROMPT, EXTRACT_TOOL_CALL_PROMPT
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
from ra_aid.console.output import cpm
from ra_aid.exceptions import ToolExecutionError
from ra_aid.fallback_handler import FallbackHandler
from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
from ra_aid.prompts import CIAYN_AGENT_BASE_PROMPT, EXTRACT_TOOL_CALL_PROMPT
from ra_aid.tools.expert import get_model
from ra_aid.tools.reflection import get_function_info
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
logger = get_logger(__name__)
@ -271,7 +271,7 @@ class CiaynAgent:
return initial_messages + chat_history
def stream(
self, messages_dict: Dict[str, List[Any]], config: Dict[str, Any] = None
self, messages_dict: Dict[str, List[Any]], _config: Dict[str, Any] = None
) -> Generator[Dict[str, Any], None, None]:
"""Stream agent responses in a format compatible with print_agent_output."""
initial_messages = messages_dict.get("messages", [])

View File

@ -1,6 +1,7 @@
from langgraph.graph.graph import CompiledGraph
from typing import TYPE_CHECKING
from langgraph.graph.graph import CompiledGraph
# Unfortunately need this to avoid Circular Imports
if TYPE_CHECKING:
from ra_aid.agents.ciayn_agent import CiaynAgent

View File

@ -13,7 +13,7 @@ from ra_aid.config import (
RETRY_FALLBACK_COUNT,
)
from ra_aid.console.output import cpm
from ra_aid.exceptions import ToolExecutionError, FallbackToolExecutionError
from ra_aid.exceptions import FallbackToolExecutionError, ToolExecutionError
from ra_aid.llm import initialize_llm, validate_provider_env
from ra_aid.logging_config import get_logger
from ra_aid.tool_configs import get_all_tools
@ -383,3 +383,13 @@ class FallbackHandler:
):
tool_calls = response.get("additional_kwargs").get("tool_calls")
return tool_calls
def handle_failure_response(self, error: ToolExecutionError, agent, agent_type: str):
"""
Handle a tool failure by calling handle_failure and, if a fallback response is returned and the agent type is "React",
return a list of SystemMessage objects wrapping each message from the fallback response.
"""
fallback_response = self.handle_failure(error, agent)
if fallback_response and agent_type == "React":
return [SystemMessage(str(msg)) for msg in fallback_response]
return None