feat(agent_utils.py): add _handle_fallback_response function to streamline fallback handling logic
refactor(agent_utils.py): extract fallback handling logic from run_agent_with_retry to improve code readability fix(ciayn_agent.py): update stream method parameter name for consistency chore(agents_alias.py): reorder import statements to follow best practices style(fallback_handler.py): reorder exception imports for consistency and clarity
This commit is contained in:
parent
115cde98b6
commit
63e48db9de
|
|
@ -34,8 +34,8 @@ from ra_aid.console.formatting import print_error, print_stage_header
|
||||||
from ra_aid.console.output import print_agent_output
|
from ra_aid.console.output import print_agent_output
|
||||||
from ra_aid.exceptions import (
|
from ra_aid.exceptions import (
|
||||||
AgentInterrupt,
|
AgentInterrupt,
|
||||||
ToolExecutionError,
|
|
||||||
FallbackToolExecutionError,
|
FallbackToolExecutionError,
|
||||||
|
ToolExecutionError,
|
||||||
)
|
)
|
||||||
from ra_aid.fallback_handler import FallbackHandler
|
from ra_aid.fallback_handler import FallbackHandler
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
|
|
@ -852,6 +852,23 @@ def get_agent_type(agent: RAgents) -> Literal["CiaynAgent", "React"]:
|
||||||
else:
|
else:
|
||||||
return "React"
|
return "React"
|
||||||
|
|
||||||
|
def _handle_fallback_response(
|
||||||
|
error: ToolExecutionError,
|
||||||
|
fallback_handler,
|
||||||
|
agent: RAgents,
|
||||||
|
agent_type: str,
|
||||||
|
msg_list: list
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Handle fallback response by invoking fallback_handler and updating msg_list.
|
||||||
|
"""
|
||||||
|
if not fallback_handler:
|
||||||
|
return
|
||||||
|
fallback_response = fallback_handler.handle_failure(error, agent)
|
||||||
|
if fallback_response and agent_type == "React":
|
||||||
|
msg_list_response = [SystemMessage(str(msg)) for msg in fallback_response]
|
||||||
|
msg_list.extend(msg_list_response)
|
||||||
|
|
||||||
|
|
||||||
def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
|
def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
|
||||||
for chunk in agent.stream({"messages": msg_list}, config):
|
for chunk in agent.stream({"messages": msg_list}, config):
|
||||||
|
|
@ -904,16 +921,7 @@ def run_agent_with_retry(
|
||||||
logger.debug("Agent run completed successfully")
|
logger.debug("Agent run completed successfully")
|
||||||
return "Agent run completed successfully"
|
return "Agent run completed successfully"
|
||||||
except ToolExecutionError as e:
|
except ToolExecutionError as e:
|
||||||
if not fallback_handler:
|
_handle_fallback_response(e, fallback_handler, agent, agent_type, msg_list)
|
||||||
continue
|
|
||||||
|
|
||||||
fallback_response = fallback_handler.handle_failure(e, agent)
|
|
||||||
if fallback_response:
|
|
||||||
if agent_type == "React":
|
|
||||||
msg_list_response = [
|
|
||||||
SystemMessage(str(msg)) for msg in fallback_response
|
|
||||||
]
|
|
||||||
msg_list.extend(msg_list_response)
|
|
||||||
continue
|
continue
|
||||||
except FallbackToolExecutionError as e:
|
except FallbackToolExecutionError as e:
|
||||||
msg_list.append(
|
msg_list.append(
|
||||||
|
|
|
||||||
|
|
@ -6,15 +6,15 @@ from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
|
||||||
from ra_aid.tools.expert import get_model
|
|
||||||
from ra_aid.prompts import CIAYN_AGENT_BASE_PROMPT, EXTRACT_TOOL_CALL_PROMPT
|
|
||||||
from ra_aid.console.output import cpm
|
from ra_aid.console.output import cpm
|
||||||
from ra_aid.exceptions import ToolExecutionError
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
from ra_aid.fallback_handler import FallbackHandler
|
from ra_aid.fallback_handler import FallbackHandler
|
||||||
|
from ra_aid.logging_config import get_logger
|
||||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
||||||
|
from ra_aid.prompts import CIAYN_AGENT_BASE_PROMPT, EXTRACT_TOOL_CALL_PROMPT
|
||||||
|
from ra_aid.tools.expert import get_model
|
||||||
from ra_aid.tools.reflection import get_function_info
|
from ra_aid.tools.reflection import get_function_info
|
||||||
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
@ -271,7 +271,7 @@ class CiaynAgent:
|
||||||
return initial_messages + chat_history
|
return initial_messages + chat_history
|
||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
self, messages_dict: Dict[str, List[Any]], config: Dict[str, Any] = None
|
self, messages_dict: Dict[str, List[Any]], _config: Dict[str, Any] = None
|
||||||
) -> Generator[Dict[str, Any], None, None]:
|
) -> Generator[Dict[str, Any], None, None]:
|
||||||
"""Stream agent responses in a format compatible with print_agent_output."""
|
"""Stream agent responses in a format compatible with print_agent_output."""
|
||||||
initial_messages = messages_dict.get("messages", [])
|
initial_messages = messages_dict.get("messages", [])
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from langgraph.graph.graph import CompiledGraph
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from langgraph.graph.graph import CompiledGraph
|
||||||
|
|
||||||
# Unfortunately need this to avoid Circular Imports
|
# Unfortunately need this to avoid Circular Imports
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from ra_aid.config import (
|
||||||
RETRY_FALLBACK_COUNT,
|
RETRY_FALLBACK_COUNT,
|
||||||
)
|
)
|
||||||
from ra_aid.console.output import cpm
|
from ra_aid.console.output import cpm
|
||||||
from ra_aid.exceptions import ToolExecutionError, FallbackToolExecutionError
|
from ra_aid.exceptions import FallbackToolExecutionError, ToolExecutionError
|
||||||
from ra_aid.llm import initialize_llm, validate_provider_env
|
from ra_aid.llm import initialize_llm, validate_provider_env
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
from ra_aid.tool_configs import get_all_tools
|
from ra_aid.tool_configs import get_all_tools
|
||||||
|
|
@ -383,3 +383,13 @@ class FallbackHandler:
|
||||||
):
|
):
|
||||||
tool_calls = response.get("additional_kwargs").get("tool_calls")
|
tool_calls = response.get("additional_kwargs").get("tool_calls")
|
||||||
return tool_calls
|
return tool_calls
|
||||||
|
|
||||||
|
def handle_failure_response(self, error: ToolExecutionError, agent, agent_type: str):
|
||||||
|
"""
|
||||||
|
Handle a tool failure by calling handle_failure and, if a fallback response is returned and the agent type is "React",
|
||||||
|
return a list of SystemMessage objects wrapping each message from the fallback response.
|
||||||
|
"""
|
||||||
|
fallback_response = self.handle_failure(error, agent)
|
||||||
|
if fallback_response and agent_type == "React":
|
||||||
|
return [SystemMessage(str(msg)) for msg in fallback_response]
|
||||||
|
return None
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue