From e508e4d1f2abda10c3e8002940fd6fc7048d5bb8 Mon Sep 17 00:00:00 2001 From: Ariel Frischer Date: Wed, 12 Feb 2025 17:55:43 -0800 Subject: [PATCH] feat(agent_utils.py): introduce get_agent_type function to determine agent type and improve code clarity refactor(agent_utils.py): update _run_agent_stream to utilize agent type for output printing fix(ciayn_agent.py): modify _execute_tool to handle BaseMessage and improve error reporting feat(ciayn_agent.py): add extract_tool_name method to identify tool names from code chore(agents_alias.py): create agents_alias module to avoid circular imports and define RAgents type refactor(config.py): remove direct import of CiaynAgent and update RAgents definition fix(output.py): update print_agent_output to accept agent type for better error handling fix(exceptions.py): add CiaynToolExecutionError for distinguishing tool execution failures refactor(fallback_handler.py): improve logging and error handling in fallback mechanism --- ra_aid/agent_utils.py | 35 ++++++++++++++++-------- ra_aid/agents/ciayn_agent.py | 53 +++++++++++++++++++++++++++--------- ra_aid/agents_alias.py | 10 +++++++ ra_aid/config.py | 5 ---- ra_aid/console/output.py | 11 ++++++-- ra_aid/exceptions.py | 18 ++++++++++++ ra_aid/fallback_handler.py | 13 ++++++--- 7 files changed, 108 insertions(+), 37 deletions(-) create mode 100644 ra_aid/agents_alias.py diff --git a/ra_aid/agent_utils.py b/ra_aid/agent_utils.py index 75b95b7..0f67623 100644 --- a/ra_aid/agent_utils.py +++ b/ra_aid/agent_utils.py @@ -19,7 +19,6 @@ from langchain_core.messages import ( ) from langchain_core.tools import tool from langgraph.checkpoint.memory import MemorySaver -from langgraph.graph.graph import CompiledGraph from langgraph.prebuilt import create_react_agent from langgraph.prebuilt.chat_agent_executor import AgentState from litellm import get_model_info @@ -28,7 +27,8 @@ from rich.markdown import Markdown from rich.panel import Panel from ra_aid.agents.ciayn_agent import CiaynAgent -from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT, RAgents +from ra_aid.agents_alias import RAgents +from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT from ra_aid.console.formatting import print_error, print_stage_header from ra_aid.console.output import print_agent_output from ra_aid.exceptions import AgentInterrupt, ToolExecutionError @@ -836,16 +836,24 @@ def _handle_api_error(e, attempt, max_retries, base_delay): time.sleep(0.1) +def get_agent_type(agent: RAgents) -> Literal["CiaynAgent", "React"]: + """ + Determines the type of the agent. + Returns "CiaynAgent" if agent is an instance of CiaynAgent, otherwise "React". + """ + + if isinstance(agent, CiaynAgent): + return "CiaynAgent" + else: + return "React" + + def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict): for chunk in agent.stream({"messages": msg_list}, config): logger.debug("Agent output: %s", chunk) check_interrupt() - print_agent_output(chunk) - if _global_memory["plan_completed"] or _global_memory["task_completed"]: - reset_agent_completion_flags() - break - check_interrupt() - print_agent_output(chunk) + agent_type = get_agent_type(agent) + print_agent_output(chunk, agent_type) if _global_memory["plan_completed"] or _global_memory["task_completed"]: reset_agent_completion_flags() break @@ -889,10 +897,13 @@ def run_agent_with_retry( logger.debug("Agent run completed successfully") return "Agent run completed successfully" except ToolExecutionError as e: - fallback_response = fallback_handler.handle_failure(e, agent) - if fallback_response: - msg_list.extend(fallback_response) - continue + print("except ToolExecutionError in AGENT UTILS") + if not isinstance(agent, CiaynAgent): + logger.debug("AGENT UTILS ToolExecutionError called!") + fallback_response = fallback_handler.handle_failure(e, agent) + if fallback_response: + msg_list.extend(fallback_response) + continue except (KeyboardInterrupt, AgentInterrupt): raise except ( diff --git a/ra_aid/agents/ciayn_agent.py b/ra_aid/agents/ciayn_agent.py index 4060684..59e4e2d 100644 --- a/ra_aid/agents/ciayn_agent.py +++ b/ra_aid/agents/ciayn_agent.py @@ -2,9 +2,13 @@ import re from dataclasses import dataclass from typing import Any, Dict, Generator, List, Optional, Union +from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage +from langchain_core.tools import BaseTool +from ra_aid.console.output import cpm from ra_aid.exceptions import ToolExecutionError +from ra_aid.fallback_handler import FallbackHandler from ra_aid.logging_config import get_logger from ra_aid.models_params import DEFAULT_TOKEN_LIMIT from ra_aid.tools.reflection import get_function_info @@ -70,8 +74,8 @@ class CiaynAgent: def __init__( self, - model, - tools: list, + model: BaseChatModel, + tools: list[BaseTool], max_history_messages: int = 50, max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT, config: Optional[dict] = None, @@ -97,8 +101,10 @@ class CiaynAgent: self.available_functions = [] for t in tools: self.available_functions.append(get_function_info(t.func)) + self.tool_failure_current_provider = None self.tool_failure_current_model = None + self.fallback_handler = FallbackHandler(config, tools) def _build_prompt(self, last_result: Optional[str] = None) -> str: """Build the prompt for the agent including available tools and context.""" @@ -229,8 +235,11 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" return base_prompt - def _execute_tool(self, code: str) -> str: + def _execute_tool(self, msg: BaseMessage) -> str: """Execute a tool call and return its result.""" + + cpm(f"execute_tool msg: { msg }") + code = msg.content globals_dict = {tool.func.__name__: tool.func for tool in self.tools} try: @@ -240,9 +249,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" # if the eval fails, try to extract it via a model call if validate_function_call_pattern(code): functions_list = "\n\n".join(self.available_functions) - logger.debug(f"_execute_tool: code before extraction: {code}") code = _extract_tool_call(code, functions_list) - logger.debug(f"_execute_tool: code after extraction: {code}") logger.debug( f"_execute_tool: evaluating code: {code} with globals: {list(globals_dict.keys())}" @@ -251,8 +258,15 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" logger.debug(f"_execute_tool: result: {result}") return result except Exception as e: - error_msg = f"Error executing code: {str(e)}" - raise ToolExecutionError(error_msg) + error_msg = f"Error: {str(e)} \n Could not excute code: {code}" + tool_name = self.extract_tool_name(code) + raise ToolExecutionError(error_msg, base_message=msg, tool_name=tool_name) + + def extract_tool_name(self, code: str) -> str: + match = re.match(r"\s*([\w_\-]+)\s*\(", code) + if match: + return match.group(1) + return "" def _create_agent_chunk(self, content: str) -> Dict[str, Any]: """Create an agent chunk in the format expected by print_agent_output.""" @@ -354,18 +368,31 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**""" try: logger.debug(f"Code generated by agent: {response.content}") - last_result = self._execute_tool(response.content) + last_result = self._execute_tool(response) chat_history.append(response) first_iteration = False yield {} except ToolExecutionError as e: - chat_history.append( - HumanMessage( - content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again." + fallback_response = self.fallback_handler.handle_failure(e, self) + print(f"fallback_response={fallback_response}") + if fallback_response: + hm = HumanMessage( + content="The fallback handler has fixed your tool call results are in the last System message." ) - ) - yield self._create_error_chunk(str(e)) + chat_history.extend(fallback_response) + chat_history.append(hm) + logger.debug("Appended fallback response to chat history.") + yield {} + else: + yield self._create_error_chunk(str(e)) + # yield {"messages": [fallback_response[-1]]} + + # chat_history.append( + # HumanMessage( + # content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again." + # ) + # ) def _extract_tool_call(code: str, functions_list: str) -> str: diff --git a/ra_aid/agents_alias.py b/ra_aid/agents_alias.py new file mode 100644 index 0000000..d3e74c0 --- /dev/null +++ b/ra_aid/agents_alias.py @@ -0,0 +1,10 @@ +from langgraph.graph.graph import CompiledGraph +from typing import TYPE_CHECKING + +# Unfortunately need this to avoid Circular Imports +if TYPE_CHECKING: + from ra_aid.agents.ciayn_agent import CiaynAgent + + RAgents = CompiledGraph | CiaynAgent +else: + RAgents = CompiledGraph diff --git a/ra_aid/config.py b/ra_aid/config.py index 8bfaf5c..54d7995 100644 --- a/ra_aid/config.py +++ b/ra_aid/config.py @@ -15,8 +15,3 @@ VALID_PROVIDERS = [ "deepseek", "gemini", ] - -from ra_aid.agents.ciayn_agent import CiaynAgent -from langgraph.graph.graph import CompiledGraph - -RAgents = CompiledGraph | CiaynAgent diff --git a/ra_aid/console/output.py b/ra_aid/console/output.py index ee6fe08..9b06508 100644 --- a/ra_aid/console/output.py +++ b/ra_aid/console/output.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Literal, Optional from langchain_core.messages import AIMessage from rich.markdown import Markdown @@ -10,7 +10,9 @@ from ra_aid.exceptions import ToolExecutionError from .formatting import console -def print_agent_output(chunk: Dict[str, Any]) -> None: +def print_agent_output( + chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"] +) -> None: """Print only the agent's message content, not tool calls. Args: @@ -44,7 +46,10 @@ def print_agent_output(chunk: Dict[str, Any]) -> None: ) ) tool_name = getattr(msg, "name", None) - raise ToolExecutionError(err_msg, tool_name=tool_name, base_message=msg) + if agent_type == "React": + raise ToolExecutionError( + err_msg, tool_name=tool_name, base_message=msg + ) def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None: diff --git a/ra_aid/exceptions.py b/ra_aid/exceptions.py index 34710d9..b7c714b 100644 --- a/ra_aid/exceptions.py +++ b/ra_aid/exceptions.py @@ -31,3 +31,21 @@ class ToolExecutionError(Exception): super().__init__(message) self.base_message = base_message self.tool_name = tool_name + + +class CiaynToolExecutionError(Exception): + """Exception raised when a tool execution fails. + + This exception is used to distinguish tool execution failures + from other types of errors in the agent system. + """ + + def __init__( + self, + message: str, + base_message: Optional[BaseMessage] = None, + tool_name: Optional[str] = None, + ): + super().__init__(message) + self.base_message = base_message + self.tool_name = tool_name diff --git a/ra_aid/fallback_handler.py b/ra_aid/fallback_handler.py index c1c5ea2..e9bb9f3 100644 --- a/ra_aid/fallback_handler.py +++ b/ra_aid/fallback_handler.py @@ -6,11 +6,11 @@ from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.tools import BaseTool from langgraph.graph.message import BaseMessage +from ra_aid.agents_alias import RAgents from ra_aid.config import ( DEFAULT_MAX_TOOL_FAILURES, FALLBACK_TOOL_MODEL_LIMIT, RETRY_FALLBACK_COUNT, - RAgents, ) from ra_aid.console.output import cpm from ra_aid.exceptions import ToolExecutionError @@ -51,7 +51,8 @@ class FallbackHandler: self.current_tool_to_bind: None | BaseTool = None cpm( - "Fallback models selected: " + ", ".join([self._format_model(m) for m in self.fallback_tool_models]), + "Fallback models selected: " + + ", ".join([self._format_model(m) for m in self.fallback_tool_models]), title="Fallback Models", ) @@ -263,14 +264,18 @@ class FallbackHandler: tool_call_result = self.invoke_prompt_tool_call(tool_call) cpm(str(tool_call_result), title="Fallback Tool Call Result") - logger.debug(f"Fallback call successful with model: {self._format_model(fallback_model)}") + logger.debug( + f"Fallback call successful with model: {self._format_model(fallback_model)}" + ) self.reset_fallback_handler() return [response, tool_call_result] except Exception as e: if isinstance(e, KeyboardInterrupt): raise - logger.error(f"Fallback with model {self._format_model(fallback_model)} failed: {e}") + logger.error( + f"Fallback with model {self._format_model(fallback_model)} failed: {e}" + ) return None def construct_prompt_msg_list(self):