feat(agent_utils.py): introduce get_agent_type function to determine agent type and improve code clarity

refactor(agent_utils.py): update _run_agent_stream to utilize agent type for output printing
fix(ciayn_agent.py): modify _execute_tool to handle BaseMessage and improve error reporting
feat(ciayn_agent.py): add extract_tool_name method to identify tool names from code
chore(agents_alias.py): create agents_alias module to avoid circular imports and define RAgents type
refactor(config.py): remove direct import of CiaynAgent and update RAgents definition
fix(output.py): update print_agent_output to accept agent type for better error handling
fix(exceptions.py): add CiaynToolExecutionError for distinguishing tool execution failures
refactor(fallback_handler.py): improve logging and error handling in fallback mechanism
This commit is contained in:
Ariel Frischer 2025-02-12 17:55:43 -08:00
parent 96b41458a1
commit e508e4d1f2
7 changed files with 108 additions and 37 deletions

View File

@ -19,7 +19,6 @@ from langchain_core.messages import (
)
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt import create_react_agent
from langgraph.prebuilt.chat_agent_executor import AgentState
from litellm import get_model_info
@ -28,7 +27,8 @@ from rich.markdown import Markdown
from rich.panel import Panel
from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT, RAgents
from ra_aid.agents_alias import RAgents
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
from ra_aid.console.formatting import print_error, print_stage_header
from ra_aid.console.output import print_agent_output
from ra_aid.exceptions import AgentInterrupt, ToolExecutionError
@ -836,16 +836,24 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
time.sleep(0.1)
def get_agent_type(agent: RAgents) -> Literal["CiaynAgent", "React"]:
    """Classify an agent instance for output handling.

    Returns:
        ``"CiaynAgent"`` when *agent* is a :class:`CiaynAgent`,
        ``"React"`` for any other (LangGraph ``CompiledGraph``) agent.
    """
    return "CiaynAgent" if isinstance(agent, CiaynAgent) else "React"
def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
for chunk in agent.stream({"messages": msg_list}, config):
logger.debug("Agent output: %s", chunk)
check_interrupt()
print_agent_output(chunk)
if _global_memory["plan_completed"] or _global_memory["task_completed"]:
reset_agent_completion_flags()
break
check_interrupt()
print_agent_output(chunk)
agent_type = get_agent_type(agent)
print_agent_output(chunk, agent_type)
if _global_memory["plan_completed"] or _global_memory["task_completed"]:
reset_agent_completion_flags()
break
@ -889,10 +897,13 @@ def run_agent_with_retry(
logger.debug("Agent run completed successfully")
return "Agent run completed successfully"
except ToolExecutionError as e:
fallback_response = fallback_handler.handle_failure(e, agent)
if fallback_response:
msg_list.extend(fallback_response)
continue
print("except ToolExecutionError in AGENT UTILS")
if not isinstance(agent, CiaynAgent):
logger.debug("AGENT UTILS ToolExecutionError called!")
fallback_response = fallback_handler.handle_failure(e, agent)
if fallback_response:
msg_list.extend(fallback_response)
continue
except (KeyboardInterrupt, AgentInterrupt):
raise
except (

View File

@ -2,9 +2,13 @@ import re
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional, Union
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
from ra_aid.console.output import cpm
from ra_aid.exceptions import ToolExecutionError
from ra_aid.fallback_handler import FallbackHandler
from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
from ra_aid.tools.reflection import get_function_info
@ -70,8 +74,8 @@ class CiaynAgent:
def __init__(
self,
model,
tools: list,
model: BaseChatModel,
tools: list[BaseTool],
max_history_messages: int = 50,
max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT,
config: Optional[dict] = None,
@ -97,8 +101,10 @@ class CiaynAgent:
self.available_functions = []
for t in tools:
self.available_functions.append(get_function_info(t.func))
self.tool_failure_current_provider = None
self.tool_failure_current_model = None
self.fallback_handler = FallbackHandler(config, tools)
def _build_prompt(self, last_result: Optional[str] = None) -> str:
"""Build the prompt for the agent including available tools and context."""
@ -229,8 +235,11 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
return base_prompt
def _execute_tool(self, code: str) -> str:
def _execute_tool(self, msg: BaseMessage) -> str:
"""Execute a tool call and return its result."""
cpm(f"execute_tool msg: { msg }")
code = msg.content
globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
try:
@ -240,9 +249,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
# if the eval fails, try to extract it via a model call
if validate_function_call_pattern(code):
functions_list = "\n\n".join(self.available_functions)
logger.debug(f"_execute_tool: code before extraction: {code}")
code = _extract_tool_call(code, functions_list)
logger.debug(f"_execute_tool: code after extraction: {code}")
logger.debug(
f"_execute_tool: evaluating code: {code} with globals: {list(globals_dict.keys())}"
@ -251,8 +258,15 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
logger.debug(f"_execute_tool: result: {result}")
return result
except Exception as e:
error_msg = f"Error executing code: {str(e)}"
raise ToolExecutionError(error_msg)
error_msg = f"Error: {str(e)} \n Could not excute code: {code}"
tool_name = self.extract_tool_name(code)
raise ToolExecutionError(error_msg, base_message=msg, tool_name=tool_name)
def extract_tool_name(self, code: str) -> str:
    """Return the name of the tool invoked at the start of ``code``.

    Looks for an identifier (hyphens tolerated) followed by an opening
    parenthesis, skipping any leading whitespace. Returns the matched
    name, or an empty string when ``code`` does not begin with a call.
    """
    found = re.match(r"\s*([\w_\-]+)\s*\(", code)
    return found.group(1) if found else ""
def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
"""Create an agent chunk in the format expected by print_agent_output."""
@ -354,18 +368,31 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
try:
logger.debug(f"Code generated by agent: {response.content}")
last_result = self._execute_tool(response.content)
last_result = self._execute_tool(response)
chat_history.append(response)
first_iteration = False
yield {}
except ToolExecutionError as e:
chat_history.append(
HumanMessage(
content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."
fallback_response = self.fallback_handler.handle_failure(e, self)
print(f"fallback_response={fallback_response}")
if fallback_response:
hm = HumanMessage(
content="The fallback handler has fixed your tool call results are in the last System message."
)
)
yield self._create_error_chunk(str(e))
chat_history.extend(fallback_response)
chat_history.append(hm)
logger.debug("Appended fallback response to chat history.")
yield {}
else:
yield self._create_error_chunk(str(e))
# yield {"messages": [fallback_response[-1]]}
# chat_history.append(
# HumanMessage(
# content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."
# )
# )
def _extract_tool_call(code: str, functions_list: str) -> str:

10
ra_aid/agents_alias.py Normal file
View File

@ -0,0 +1,10 @@
"""Agent type alias kept in its own module to break an import cycle.

``CiaynAgent`` lives in ``ra_aid.agents.ciayn_agent``, whose import chain
needs the ``RAgents`` alias; importing it here only under ``TYPE_CHECKING``
lets type checkers see the full union while avoiding a circular import at
runtime, where ``RAgents`` degrades to plain ``CompiledGraph``.
"""

from typing import TYPE_CHECKING

from langgraph.graph.graph import CompiledGraph

if TYPE_CHECKING:
    from ra_aid.agents.ciayn_agent import CiaynAgent

    # Full static type: either a compiled LangGraph agent or a CiaynAgent.
    RAgents = CompiledGraph | CiaynAgent
else:
    # Runtime fallback: CiaynAgent cannot be imported here without a cycle.
    RAgents = CompiledGraph

View File

@ -15,8 +15,3 @@ VALID_PROVIDERS = [
"deepseek",
"gemini",
]
from ra_aid.agents.ciayn_agent import CiaynAgent
from langgraph.graph.graph import CompiledGraph
RAgents = CompiledGraph | CiaynAgent

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, Literal, Optional
from langchain_core.messages import AIMessage
from rich.markdown import Markdown
@ -10,7 +10,9 @@ from ra_aid.exceptions import ToolExecutionError
from .formatting import console
def print_agent_output(chunk: Dict[str, Any]) -> None:
def print_agent_output(
chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"]
) -> None:
"""Print only the agent's message content, not tool calls.
Args:
@ -44,7 +46,10 @@ def print_agent_output(chunk: Dict[str, Any]) -> None:
)
)
tool_name = getattr(msg, "name", None)
raise ToolExecutionError(err_msg, tool_name=tool_name, base_message=msg)
if agent_type == "React":
raise ToolExecutionError(
err_msg, tool_name=tool_name, base_message=msg
)
def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None:

View File

@ -31,3 +31,21 @@ class ToolExecutionError(Exception):
super().__init__(message)
self.base_message = base_message
self.tool_name = tool_name
class CiaynToolExecutionError(Exception):
    """Raised when a CIAYN agent's tool invocation fails.

    Carries the same payload as ``ToolExecutionError`` but as a distinct
    type, so callers can tell failures originating from the CiaynAgent
    apart from other tool-execution errors in the agent system.
    """

    def __init__(
        self,
        message: str,
        base_message: Optional[BaseMessage] = None,
        tool_name: Optional[str] = None,
    ):
        super().__init__(message)
        # Model message that produced the failing tool call, if available.
        self.base_message = base_message
        # Name of the failing tool, when it could be determined.
        self.tool_name = tool_name

View File

@ -6,11 +6,11 @@ from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
from langgraph.graph.message import BaseMessage
from ra_aid.agents_alias import RAgents
from ra_aid.config import (
DEFAULT_MAX_TOOL_FAILURES,
FALLBACK_TOOL_MODEL_LIMIT,
RETRY_FALLBACK_COUNT,
RAgents,
)
from ra_aid.console.output import cpm
from ra_aid.exceptions import ToolExecutionError
@ -51,7 +51,8 @@ class FallbackHandler:
self.current_tool_to_bind: None | BaseTool = None
cpm(
"Fallback models selected: " + ", ".join([self._format_model(m) for m in self.fallback_tool_models]),
"Fallback models selected: "
+ ", ".join([self._format_model(m) for m in self.fallback_tool_models]),
title="Fallback Models",
)
@ -263,14 +264,18 @@ class FallbackHandler:
tool_call_result = self.invoke_prompt_tool_call(tool_call)
cpm(str(tool_call_result), title="Fallback Tool Call Result")
logger.debug(f"Fallback call successful with model: {self._format_model(fallback_model)}")
logger.debug(
f"Fallback call successful with model: {self._format_model(fallback_model)}"
)
self.reset_fallback_handler()
return [response, tool_call_result]
except Exception as e:
if isinstance(e, KeyboardInterrupt):
raise
logger.error(f"Fallback with model {self._format_model(fallback_model)} failed: {e}")
logger.error(
f"Fallback with model {self._format_model(fallback_model)} failed: {e}"
)
return None
def construct_prompt_msg_list(self):