feat(agent_utils.py): introduce get_agent_type function to determine agent type and improve code clarity
refactor(agent_utils.py): update _run_agent_stream to utilize agent type for output printing fix(ciayn_agent.py): modify _execute_tool to handle BaseMessage and improve error reporting feat(ciayn_agent.py): add extract_tool_name method to identify tool names from code chore(agents_alias.py): create agents_alias module to avoid circular imports and define RAgents type refactor(config.py): remove direct import of CiaynAgent and update RAgents definition fix(output.py): update print_agent_output to accept agent type for better error handling fix(exceptions.py): add CiaynToolExecutionError for distinguishing tool execution failures refactor(fallback_handler.py): improve logging and error handling in fallback mechanism
This commit is contained in:
parent
96b41458a1
commit
e508e4d1f2
|
|
@ -19,7 +19,6 @@ from langchain_core.messages import (
|
|||
)
|
||||
from langchain_core.tools import tool
|
||||
from langgraph.checkpoint.memory import MemorySaver
|
||||
from langgraph.graph.graph import CompiledGraph
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||
from litellm import get_model_info
|
||||
|
|
@ -28,7 +27,8 @@ from rich.markdown import Markdown
|
|||
from rich.panel import Panel
|
||||
|
||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT, RAgents
|
||||
from ra_aid.agents_alias import RAgents
|
||||
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
|
||||
from ra_aid.console.formatting import print_error, print_stage_header
|
||||
from ra_aid.console.output import print_agent_output
|
||||
from ra_aid.exceptions import AgentInterrupt, ToolExecutionError
|
||||
|
|
@ -836,16 +836,24 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
|
|||
time.sleep(0.1)
|
||||
|
||||
|
||||
def get_agent_type(agent: RAgents) -> Literal["CiaynAgent", "React"]:
|
||||
"""
|
||||
Determines the type of the agent.
|
||||
Returns "CiaynAgent" if agent is an instance of CiaynAgent, otherwise "React".
|
||||
"""
|
||||
|
||||
if isinstance(agent, CiaynAgent):
|
||||
return "CiaynAgent"
|
||||
else:
|
||||
return "React"
|
||||
|
||||
|
||||
def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
|
||||
for chunk in agent.stream({"messages": msg_list}, config):
|
||||
logger.debug("Agent output: %s", chunk)
|
||||
check_interrupt()
|
||||
print_agent_output(chunk)
|
||||
if _global_memory["plan_completed"] or _global_memory["task_completed"]:
|
||||
reset_agent_completion_flags()
|
||||
break
|
||||
check_interrupt()
|
||||
print_agent_output(chunk)
|
||||
agent_type = get_agent_type(agent)
|
||||
print_agent_output(chunk, agent_type)
|
||||
if _global_memory["plan_completed"] or _global_memory["task_completed"]:
|
||||
reset_agent_completion_flags()
|
||||
break
|
||||
|
|
@ -889,10 +897,13 @@ def run_agent_with_retry(
|
|||
logger.debug("Agent run completed successfully")
|
||||
return "Agent run completed successfully"
|
||||
except ToolExecutionError as e:
|
||||
fallback_response = fallback_handler.handle_failure(e, agent)
|
||||
if fallback_response:
|
||||
msg_list.extend(fallback_response)
|
||||
continue
|
||||
print("except ToolExecutionError in AGENT UTILS")
|
||||
if not isinstance(agent, CiaynAgent):
|
||||
logger.debug("AGENT UTILS ToolExecutionError called!")
|
||||
fallback_response = fallback_handler.handle_failure(e, agent)
|
||||
if fallback_response:
|
||||
msg_list.extend(fallback_response)
|
||||
continue
|
||||
except (KeyboardInterrupt, AgentInterrupt):
|
||||
raise
|
||||
except (
|
||||
|
|
|
|||
|
|
@ -2,9 +2,13 @@ import re
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from ra_aid.console.output import cpm
|
||||
from ra_aid.exceptions import ToolExecutionError
|
||||
from ra_aid.fallback_handler import FallbackHandler
|
||||
from ra_aid.logging_config import get_logger
|
||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
||||
from ra_aid.tools.reflection import get_function_info
|
||||
|
|
@ -70,8 +74,8 @@ class CiaynAgent:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
tools: list,
|
||||
model: BaseChatModel,
|
||||
tools: list[BaseTool],
|
||||
max_history_messages: int = 50,
|
||||
max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT,
|
||||
config: Optional[dict] = None,
|
||||
|
|
@ -97,8 +101,10 @@ class CiaynAgent:
|
|||
self.available_functions = []
|
||||
for t in tools:
|
||||
self.available_functions.append(get_function_info(t.func))
|
||||
|
||||
self.tool_failure_current_provider = None
|
||||
self.tool_failure_current_model = None
|
||||
self.fallback_handler = FallbackHandler(config, tools)
|
||||
|
||||
def _build_prompt(self, last_result: Optional[str] = None) -> str:
|
||||
"""Build the prompt for the agent including available tools and context."""
|
||||
|
|
@ -229,8 +235,11 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
|||
|
||||
return base_prompt
|
||||
|
||||
def _execute_tool(self, code: str) -> str:
|
||||
def _execute_tool(self, msg: BaseMessage) -> str:
|
||||
"""Execute a tool call and return its result."""
|
||||
|
||||
cpm(f"execute_tool msg: { msg }")
|
||||
code = msg.content
|
||||
globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
|
||||
|
||||
try:
|
||||
|
|
@ -240,9 +249,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
|||
# if the eval fails, try to extract it via a model call
|
||||
if validate_function_call_pattern(code):
|
||||
functions_list = "\n\n".join(self.available_functions)
|
||||
logger.debug(f"_execute_tool: code before extraction: {code}")
|
||||
code = _extract_tool_call(code, functions_list)
|
||||
logger.debug(f"_execute_tool: code after extraction: {code}")
|
||||
|
||||
logger.debug(
|
||||
f"_execute_tool: evaluating code: {code} with globals: {list(globals_dict.keys())}"
|
||||
|
|
@ -251,8 +258,15 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
|||
logger.debug(f"_execute_tool: result: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
error_msg = f"Error executing code: {str(e)}"
|
||||
raise ToolExecutionError(error_msg)
|
||||
error_msg = f"Error: {str(e)} \n Could not excute code: {code}"
|
||||
tool_name = self.extract_tool_name(code)
|
||||
raise ToolExecutionError(error_msg, base_message=msg, tool_name=tool_name)
|
||||
|
||||
def extract_tool_name(self, code: str) -> str:
|
||||
match = re.match(r"\s*([\w_\-]+)\s*\(", code)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return ""
|
||||
|
||||
def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
|
||||
"""Create an agent chunk in the format expected by print_agent_output."""
|
||||
|
|
@ -354,18 +368,31 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
|||
|
||||
try:
|
||||
logger.debug(f"Code generated by agent: {response.content}")
|
||||
last_result = self._execute_tool(response.content)
|
||||
last_result = self._execute_tool(response)
|
||||
chat_history.append(response)
|
||||
first_iteration = False
|
||||
yield {}
|
||||
|
||||
except ToolExecutionError as e:
|
||||
chat_history.append(
|
||||
HumanMessage(
|
||||
content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."
|
||||
fallback_response = self.fallback_handler.handle_failure(e, self)
|
||||
print(f"fallback_response={fallback_response}")
|
||||
if fallback_response:
|
||||
hm = HumanMessage(
|
||||
content="The fallback handler has fixed your tool call results are in the last System message."
|
||||
)
|
||||
)
|
||||
yield self._create_error_chunk(str(e))
|
||||
chat_history.extend(fallback_response)
|
||||
chat_history.append(hm)
|
||||
logger.debug("Appended fallback response to chat history.")
|
||||
yield {}
|
||||
else:
|
||||
yield self._create_error_chunk(str(e))
|
||||
# yield {"messages": [fallback_response[-1]]}
|
||||
|
||||
# chat_history.append(
|
||||
# HumanMessage(
|
||||
# content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."
|
||||
# )
|
||||
# )
|
||||
|
||||
|
||||
def _extract_tool_call(code: str, functions_list: str) -> str:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
from langgraph.graph.graph import CompiledGraph
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# Unfortunately need this to avoid Circular Imports
|
||||
if TYPE_CHECKING:
|
||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||
|
||||
RAgents = CompiledGraph | CiaynAgent
|
||||
else:
|
||||
RAgents = CompiledGraph
|
||||
|
|
@ -15,8 +15,3 @@ VALID_PROVIDERS = [
|
|||
"deepseek",
|
||||
"gemini",
|
||||
]
|
||||
|
||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||
from langgraph.graph.graph import CompiledGraph
|
||||
|
||||
RAgents = CompiledGraph | CiaynAgent
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from rich.markdown import Markdown
|
||||
|
|
@ -10,7 +10,9 @@ from ra_aid.exceptions import ToolExecutionError
|
|||
from .formatting import console
|
||||
|
||||
|
||||
def print_agent_output(chunk: Dict[str, Any]) -> None:
|
||||
def print_agent_output(
|
||||
chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"]
|
||||
) -> None:
|
||||
"""Print only the agent's message content, not tool calls.
|
||||
|
||||
Args:
|
||||
|
|
@ -44,7 +46,10 @@ def print_agent_output(chunk: Dict[str, Any]) -> None:
|
|||
)
|
||||
)
|
||||
tool_name = getattr(msg, "name", None)
|
||||
raise ToolExecutionError(err_msg, tool_name=tool_name, base_message=msg)
|
||||
if agent_type == "React":
|
||||
raise ToolExecutionError(
|
||||
err_msg, tool_name=tool_name, base_message=msg
|
||||
)
|
||||
|
||||
|
||||
def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None:
|
||||
|
|
|
|||
|
|
@ -31,3 +31,21 @@ class ToolExecutionError(Exception):
|
|||
super().__init__(message)
|
||||
self.base_message = base_message
|
||||
self.tool_name = tool_name
|
||||
|
||||
|
||||
class CiaynToolExecutionError(Exception):
|
||||
"""Exception raised when a tool execution fails.
|
||||
|
||||
This exception is used to distinguish tool execution failures
|
||||
from other types of errors in the agent system.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
base_message: Optional[BaseMessage] = None,
|
||||
tool_name: Optional[str] = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.base_message = base_message
|
||||
self.tool_name = tool_name
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@ from langchain_core.messages import HumanMessage, SystemMessage
|
|||
from langchain_core.tools import BaseTool
|
||||
from langgraph.graph.message import BaseMessage
|
||||
|
||||
from ra_aid.agents_alias import RAgents
|
||||
from ra_aid.config import (
|
||||
DEFAULT_MAX_TOOL_FAILURES,
|
||||
FALLBACK_TOOL_MODEL_LIMIT,
|
||||
RETRY_FALLBACK_COUNT,
|
||||
RAgents,
|
||||
)
|
||||
from ra_aid.console.output import cpm
|
||||
from ra_aid.exceptions import ToolExecutionError
|
||||
|
|
@ -51,7 +51,8 @@ class FallbackHandler:
|
|||
self.current_tool_to_bind: None | BaseTool = None
|
||||
|
||||
cpm(
|
||||
"Fallback models selected: " + ", ".join([self._format_model(m) for m in self.fallback_tool_models]),
|
||||
"Fallback models selected: "
|
||||
+ ", ".join([self._format_model(m) for m in self.fallback_tool_models]),
|
||||
title="Fallback Models",
|
||||
)
|
||||
|
||||
|
|
@ -263,14 +264,18 @@ class FallbackHandler:
|
|||
|
||||
tool_call_result = self.invoke_prompt_tool_call(tool_call)
|
||||
cpm(str(tool_call_result), title="Fallback Tool Call Result")
|
||||
logger.debug(f"Fallback call successful with model: {self._format_model(fallback_model)}")
|
||||
logger.debug(
|
||||
f"Fallback call successful with model: {self._format_model(fallback_model)}"
|
||||
)
|
||||
|
||||
self.reset_fallback_handler()
|
||||
return [response, tool_call_result]
|
||||
except Exception as e:
|
||||
if isinstance(e, KeyboardInterrupt):
|
||||
raise
|
||||
logger.error(f"Fallback with model {self._format_model(fallback_model)} failed: {e}")
|
||||
logger.error(
|
||||
f"Fallback with model {self._format_model(fallback_model)} failed: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def construct_prompt_msg_list(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue