feat(agent_utils.py): add _handle_fallback_response function to streamline fallback handling logic
refactor(agent_utils.py): extract fallback handling logic from run_agent_with_retry to improve code readability fix(ciayn_agent.py): update stream method parameter name for consistency chore(agents_alias.py): reorder import statements to follow best practices style(fallback_handler.py): reorder exception imports for consistency and clarity
This commit is contained in:
parent
115cde98b6
commit
63e48db9de
|
|
@ -34,8 +34,8 @@ from ra_aid.console.formatting import print_error, print_stage_header
|
|||
from ra_aid.console.output import print_agent_output
|
||||
from ra_aid.exceptions import (
|
||||
AgentInterrupt,
|
||||
ToolExecutionError,
|
||||
FallbackToolExecutionError,
|
||||
ToolExecutionError,
|
||||
)
|
||||
from ra_aid.fallback_handler import FallbackHandler
|
||||
from ra_aid.logging_config import get_logger
|
||||
|
|
@ -852,6 +852,23 @@ def get_agent_type(agent: RAgents) -> Literal["CiaynAgent", "React"]:
|
|||
else:
|
||||
return "React"
|
||||
|
||||
def _handle_fallback_response(
|
||||
error: ToolExecutionError,
|
||||
fallback_handler,
|
||||
agent: RAgents,
|
||||
agent_type: str,
|
||||
msg_list: list
|
||||
) -> None:
|
||||
"""
|
||||
Handle fallback response by invoking fallback_handler and updating msg_list.
|
||||
"""
|
||||
if not fallback_handler:
|
||||
return
|
||||
fallback_response = fallback_handler.handle_failure(error, agent)
|
||||
if fallback_response and agent_type == "React":
|
||||
msg_list_response = [SystemMessage(str(msg)) for msg in fallback_response]
|
||||
msg_list.extend(msg_list_response)
|
||||
|
||||
|
||||
def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
|
||||
for chunk in agent.stream({"messages": msg_list}, config):
|
||||
|
|
@ -904,16 +921,7 @@ def run_agent_with_retry(
|
|||
logger.debug("Agent run completed successfully")
|
||||
return "Agent run completed successfully"
|
||||
except ToolExecutionError as e:
|
||||
if not fallback_handler:
|
||||
continue
|
||||
|
||||
fallback_response = fallback_handler.handle_failure(e, agent)
|
||||
if fallback_response:
|
||||
if agent_type == "React":
|
||||
msg_list_response = [
|
||||
SystemMessage(str(msg)) for msg in fallback_response
|
||||
]
|
||||
msg_list.extend(msg_list_response)
|
||||
_handle_fallback_response(e, fallback_handler, agent, agent_type, msg_list)
|
||||
continue
|
||||
except FallbackToolExecutionError as e:
|
||||
msg_list.append(
|
||||
|
|
|
|||
|
|
@ -6,15 +6,15 @@ from langchain_core.language_models import BaseChatModel
|
|||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.tools.expert import get_model
|
||||
from ra_aid.prompts import CIAYN_AGENT_BASE_PROMPT, EXTRACT_TOOL_CALL_PROMPT
|
||||
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
|
||||
from ra_aid.console.output import cpm
|
||||
from ra_aid.exceptions import ToolExecutionError
|
||||
from ra_aid.fallback_handler import FallbackHandler
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
||||
from ra_aid.prompts import CIAYN_AGENT_BASE_PROMPT, EXTRACT_TOOL_CALL_PROMPT
|
||||
from ra_aid.tools.expert import get_model
|
||||
from ra_aid.tools.reflection import get_function_info
|
||||
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
@ -271,7 +271,7 @@ class CiaynAgent:
|
|||
return initial_messages + chat_history
|
||||
|
||||
def stream(
|
||||
self, messages_dict: Dict[str, List[Any]], config: Dict[str, Any] = None
|
||||
self, messages_dict: Dict[str, List[Any]], _config: Dict[str, Any] = None
|
||||
) -> Generator[Dict[str, Any], None, None]:
|
||||
"""Stream agent responses in a format compatible with print_agent_output."""
|
||||
initial_messages = messages_dict.get("messages", [])
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from langgraph.graph.graph import CompiledGraph
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langgraph.graph.graph import CompiledGraph
|
||||
|
||||
# Unfortunately need this to avoid Circular Imports
|
||||
if TYPE_CHECKING:
|
||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from ra_aid.config import (
|
|||
RETRY_FALLBACK_COUNT,
|
||||
)
|
||||
from ra_aid.console.output import cpm
|
||||
from ra_aid.exceptions import ToolExecutionError, FallbackToolExecutionError
|
||||
from ra_aid.exceptions import FallbackToolExecutionError, ToolExecutionError
|
||||
from ra_aid.llm import initialize_llm, validate_provider_env
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.tool_configs import get_all_tools
|
||||
|
|
@ -383,3 +383,13 @@ class FallbackHandler:
|
|||
):
|
||||
tool_calls = response.get("additional_kwargs").get("tool_calls")
|
||||
return tool_calls
|
||||
|
||||
def handle_failure_response(self, error: ToolExecutionError, agent, agent_type: str):
|
||||
"""
|
||||
Handle a tool failure by calling handle_failure and, if a fallback response is returned and the agent type is "React",
|
||||
return a list of SystemMessage objects wrapping each message from the fallback response.
|
||||
"""
|
||||
fallback_response = self.handle_failure(error, agent)
|
||||
if fallback_response and agent_type == "React":
|
||||
return [SystemMessage(str(msg)) for msg in fallback_response]
|
||||
return None
|
||||
|
|
|
|||
Loading…
Reference in New Issue