feat(agent_utils.py): introduce get_agent_type function to determine agent type and improve code clarity
refactor(agent_utils.py): update _run_agent_stream to utilize agent type for output printing fix(ciayn_agent.py): modify _execute_tool to handle BaseMessage and improve error reporting feat(ciayn_agent.py): add extract_tool_name method to identify tool names from code chore(agents_alias.py): create agents_alias module to avoid circular imports and define RAgents type refactor(config.py): remove direct import of CiaynAgent and update RAgents definition fix(output.py): update print_agent_output to accept agent type for better error handling fix(exceptions.py): add CiaynToolExecutionError for distinguishing tool execution failures refactor(fallback_handler.py): improve logging and error handling in fallback mechanism
This commit is contained in:
parent
96b41458a1
commit
e508e4d1f2
|
|
@ -19,7 +19,6 @@ from langchain_core.messages import (
|
||||||
)
|
)
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
from langgraph.graph.graph import CompiledGraph
|
|
||||||
from langgraph.prebuilt import create_react_agent
|
from langgraph.prebuilt import create_react_agent
|
||||||
from langgraph.prebuilt.chat_agent_executor import AgentState
|
from langgraph.prebuilt.chat_agent_executor import AgentState
|
||||||
from litellm import get_model_info
|
from litellm import get_model_info
|
||||||
|
|
@ -28,7 +27,8 @@ from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
|
|
||||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||||
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT, RAgents
|
from ra_aid.agents_alias import RAgents
|
||||||
|
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
|
||||||
from ra_aid.console.formatting import print_error, print_stage_header
|
from ra_aid.console.formatting import print_error, print_stage_header
|
||||||
from ra_aid.console.output import print_agent_output
|
from ra_aid.console.output import print_agent_output
|
||||||
from ra_aid.exceptions import AgentInterrupt, ToolExecutionError
|
from ra_aid.exceptions import AgentInterrupt, ToolExecutionError
|
||||||
|
|
@ -836,16 +836,24 @@ def _handle_api_error(e, attempt, max_retries, base_delay):
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_agent_type(agent: RAgents) -> Literal["CiaynAgent", "React"]:
|
||||||
|
"""
|
||||||
|
Determines the type of the agent.
|
||||||
|
Returns "CiaynAgent" if agent is an instance of CiaynAgent, otherwise "React".
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(agent, CiaynAgent):
|
||||||
|
return "CiaynAgent"
|
||||||
|
else:
|
||||||
|
return "React"
|
||||||
|
|
||||||
|
|
||||||
def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
|
def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
|
||||||
for chunk in agent.stream({"messages": msg_list}, config):
|
for chunk in agent.stream({"messages": msg_list}, config):
|
||||||
logger.debug("Agent output: %s", chunk)
|
logger.debug("Agent output: %s", chunk)
|
||||||
check_interrupt()
|
check_interrupt()
|
||||||
print_agent_output(chunk)
|
agent_type = get_agent_type(agent)
|
||||||
if _global_memory["plan_completed"] or _global_memory["task_completed"]:
|
print_agent_output(chunk, agent_type)
|
||||||
reset_agent_completion_flags()
|
|
||||||
break
|
|
||||||
check_interrupt()
|
|
||||||
print_agent_output(chunk)
|
|
||||||
if _global_memory["plan_completed"] or _global_memory["task_completed"]:
|
if _global_memory["plan_completed"] or _global_memory["task_completed"]:
|
||||||
reset_agent_completion_flags()
|
reset_agent_completion_flags()
|
||||||
break
|
break
|
||||||
|
|
@ -889,6 +897,9 @@ def run_agent_with_retry(
|
||||||
logger.debug("Agent run completed successfully")
|
logger.debug("Agent run completed successfully")
|
||||||
return "Agent run completed successfully"
|
return "Agent run completed successfully"
|
||||||
except ToolExecutionError as e:
|
except ToolExecutionError as e:
|
||||||
|
print("except ToolExecutionError in AGENT UTILS")
|
||||||
|
if not isinstance(agent, CiaynAgent):
|
||||||
|
logger.debug("AGENT UTILS ToolExecutionError called!")
|
||||||
fallback_response = fallback_handler.handle_failure(e, agent)
|
fallback_response = fallback_handler.handle_failure(e, agent)
|
||||||
if fallback_response:
|
if fallback_response:
|
||||||
msg_list.extend(fallback_response)
|
msg_list.extend(fallback_response)
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,13 @@ import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, Generator, List, Optional, Union
|
from typing import Any, Dict, Generator, List, Optional, Union
|
||||||
|
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
from ra_aid.console.output import cpm
|
||||||
from ra_aid.exceptions import ToolExecutionError
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
|
from ra_aid.fallback_handler import FallbackHandler
|
||||||
from ra_aid.logging_config import get_logger
|
from ra_aid.logging_config import get_logger
|
||||||
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
|
||||||
from ra_aid.tools.reflection import get_function_info
|
from ra_aid.tools.reflection import get_function_info
|
||||||
|
|
@ -70,8 +74,8 @@ class CiaynAgent:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model: BaseChatModel,
|
||||||
tools: list,
|
tools: list[BaseTool],
|
||||||
max_history_messages: int = 50,
|
max_history_messages: int = 50,
|
||||||
max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT,
|
max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT,
|
||||||
config: Optional[dict] = None,
|
config: Optional[dict] = None,
|
||||||
|
|
@ -97,8 +101,10 @@ class CiaynAgent:
|
||||||
self.available_functions = []
|
self.available_functions = []
|
||||||
for t in tools:
|
for t in tools:
|
||||||
self.available_functions.append(get_function_info(t.func))
|
self.available_functions.append(get_function_info(t.func))
|
||||||
|
|
||||||
self.tool_failure_current_provider = None
|
self.tool_failure_current_provider = None
|
||||||
self.tool_failure_current_model = None
|
self.tool_failure_current_model = None
|
||||||
|
self.fallback_handler = FallbackHandler(config, tools)
|
||||||
|
|
||||||
def _build_prompt(self, last_result: Optional[str] = None) -> str:
|
def _build_prompt(self, last_result: Optional[str] = None) -> str:
|
||||||
"""Build the prompt for the agent including available tools and context."""
|
"""Build the prompt for the agent including available tools and context."""
|
||||||
|
|
@ -229,8 +235,11 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
||||||
|
|
||||||
return base_prompt
|
return base_prompt
|
||||||
|
|
||||||
def _execute_tool(self, code: str) -> str:
|
def _execute_tool(self, msg: BaseMessage) -> str:
|
||||||
"""Execute a tool call and return its result."""
|
"""Execute a tool call and return its result."""
|
||||||
|
|
||||||
|
cpm(f"execute_tool msg: { msg }")
|
||||||
|
code = msg.content
|
||||||
globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
|
globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -240,9 +249,7 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
||||||
# if the eval fails, try to extract it via a model call
|
# if the eval fails, try to extract it via a model call
|
||||||
if validate_function_call_pattern(code):
|
if validate_function_call_pattern(code):
|
||||||
functions_list = "\n\n".join(self.available_functions)
|
functions_list = "\n\n".join(self.available_functions)
|
||||||
logger.debug(f"_execute_tool: code before extraction: {code}")
|
|
||||||
code = _extract_tool_call(code, functions_list)
|
code = _extract_tool_call(code, functions_list)
|
||||||
logger.debug(f"_execute_tool: code after extraction: {code}")
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"_execute_tool: evaluating code: {code} with globals: {list(globals_dict.keys())}"
|
f"_execute_tool: evaluating code: {code} with globals: {list(globals_dict.keys())}"
|
||||||
|
|
@ -251,8 +258,15 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
||||||
logger.debug(f"_execute_tool: result: {result}")
|
logger.debug(f"_execute_tool: result: {result}")
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error executing code: {str(e)}"
|
error_msg = f"Error: {str(e)} \n Could not excute code: {code}"
|
||||||
raise ToolExecutionError(error_msg)
|
tool_name = self.extract_tool_name(code)
|
||||||
|
raise ToolExecutionError(error_msg, base_message=msg, tool_name=tool_name)
|
||||||
|
|
||||||
|
def extract_tool_name(self, code: str) -> str:
|
||||||
|
match = re.match(r"\s*([\w_\-]+)\s*\(", code)
|
||||||
|
if match:
|
||||||
|
return match.group(1)
|
||||||
|
return ""
|
||||||
|
|
||||||
def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
|
def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
|
||||||
"""Create an agent chunk in the format expected by print_agent_output."""
|
"""Create an agent chunk in the format expected by print_agent_output."""
|
||||||
|
|
@ -354,18 +368,31 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.debug(f"Code generated by agent: {response.content}")
|
logger.debug(f"Code generated by agent: {response.content}")
|
||||||
last_result = self._execute_tool(response.content)
|
last_result = self._execute_tool(response)
|
||||||
chat_history.append(response)
|
chat_history.append(response)
|
||||||
first_iteration = False
|
first_iteration = False
|
||||||
yield {}
|
yield {}
|
||||||
|
|
||||||
except ToolExecutionError as e:
|
except ToolExecutionError as e:
|
||||||
chat_history.append(
|
fallback_response = self.fallback_handler.handle_failure(e, self)
|
||||||
HumanMessage(
|
print(f"fallback_response={fallback_response}")
|
||||||
content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."
|
if fallback_response:
|
||||||
)
|
hm = HumanMessage(
|
||||||
|
content="The fallback handler has fixed your tool call results are in the last System message."
|
||||||
)
|
)
|
||||||
|
chat_history.extend(fallback_response)
|
||||||
|
chat_history.append(hm)
|
||||||
|
logger.debug("Appended fallback response to chat history.")
|
||||||
|
yield {}
|
||||||
|
else:
|
||||||
yield self._create_error_chunk(str(e))
|
yield self._create_error_chunk(str(e))
|
||||||
|
# yield {"messages": [fallback_response[-1]]}
|
||||||
|
|
||||||
|
# chat_history.append(
|
||||||
|
# HumanMessage(
|
||||||
|
# content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."
|
||||||
|
# )
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
def _extract_tool_call(code: str, functions_list: str) -> str:
|
def _extract_tool_call(code: str, functions_list: str) -> str:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
from langgraph.graph.graph import CompiledGraph
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
# Unfortunately need this to avoid Circular Imports
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ra_aid.agents.ciayn_agent import CiaynAgent
|
||||||
|
|
||||||
|
RAgents = CompiledGraph | CiaynAgent
|
||||||
|
else:
|
||||||
|
RAgents = CompiledGraph
|
||||||
|
|
@ -15,8 +15,3 @@ VALID_PROVIDERS = [
|
||||||
"deepseek",
|
"deepseek",
|
||||||
"gemini",
|
"gemini",
|
||||||
]
|
]
|
||||||
|
|
||||||
from ra_aid.agents.ciayn_agent import CiaynAgent
|
|
||||||
from langgraph.graph.graph import CompiledGraph
|
|
||||||
|
|
||||||
RAgents = CompiledGraph | CiaynAgent
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Literal, Optional
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
|
|
@ -10,7 +10,9 @@ from ra_aid.exceptions import ToolExecutionError
|
||||||
from .formatting import console
|
from .formatting import console
|
||||||
|
|
||||||
|
|
||||||
def print_agent_output(chunk: Dict[str, Any]) -> None:
|
def print_agent_output(
|
||||||
|
chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"]
|
||||||
|
) -> None:
|
||||||
"""Print only the agent's message content, not tool calls.
|
"""Print only the agent's message content, not tool calls.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -44,7 +46,10 @@ def print_agent_output(chunk: Dict[str, Any]) -> None:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
tool_name = getattr(msg, "name", None)
|
tool_name = getattr(msg, "name", None)
|
||||||
raise ToolExecutionError(err_msg, tool_name=tool_name, base_message=msg)
|
if agent_type == "React":
|
||||||
|
raise ToolExecutionError(
|
||||||
|
err_msg, tool_name=tool_name, base_message=msg
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None:
|
def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None:
|
||||||
|
|
|
||||||
|
|
@ -31,3 +31,21 @@ class ToolExecutionError(Exception):
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
self.base_message = base_message
|
self.base_message = base_message
|
||||||
self.tool_name = tool_name
|
self.tool_name = tool_name
|
||||||
|
|
||||||
|
|
||||||
|
class CiaynToolExecutionError(Exception):
|
||||||
|
"""Exception raised when a tool execution fails.
|
||||||
|
|
||||||
|
This exception is used to distinguish tool execution failures
|
||||||
|
from other types of errors in the agent system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
base_message: Optional[BaseMessage] = None,
|
||||||
|
tool_name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
super().__init__(message)
|
||||||
|
self.base_message = base_message
|
||||||
|
self.tool_name = tool_name
|
||||||
|
|
|
||||||
|
|
@ -6,11 +6,11 @@ from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langgraph.graph.message import BaseMessage
|
from langgraph.graph.message import BaseMessage
|
||||||
|
|
||||||
|
from ra_aid.agents_alias import RAgents
|
||||||
from ra_aid.config import (
|
from ra_aid.config import (
|
||||||
DEFAULT_MAX_TOOL_FAILURES,
|
DEFAULT_MAX_TOOL_FAILURES,
|
||||||
FALLBACK_TOOL_MODEL_LIMIT,
|
FALLBACK_TOOL_MODEL_LIMIT,
|
||||||
RETRY_FALLBACK_COUNT,
|
RETRY_FALLBACK_COUNT,
|
||||||
RAgents,
|
|
||||||
)
|
)
|
||||||
from ra_aid.console.output import cpm
|
from ra_aid.console.output import cpm
|
||||||
from ra_aid.exceptions import ToolExecutionError
|
from ra_aid.exceptions import ToolExecutionError
|
||||||
|
|
@ -51,7 +51,8 @@ class FallbackHandler:
|
||||||
self.current_tool_to_bind: None | BaseTool = None
|
self.current_tool_to_bind: None | BaseTool = None
|
||||||
|
|
||||||
cpm(
|
cpm(
|
||||||
"Fallback models selected: " + ", ".join([self._format_model(m) for m in self.fallback_tool_models]),
|
"Fallback models selected: "
|
||||||
|
+ ", ".join([self._format_model(m) for m in self.fallback_tool_models]),
|
||||||
title="Fallback Models",
|
title="Fallback Models",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -263,14 +264,18 @@ class FallbackHandler:
|
||||||
|
|
||||||
tool_call_result = self.invoke_prompt_tool_call(tool_call)
|
tool_call_result = self.invoke_prompt_tool_call(tool_call)
|
||||||
cpm(str(tool_call_result), title="Fallback Tool Call Result")
|
cpm(str(tool_call_result), title="Fallback Tool Call Result")
|
||||||
logger.debug(f"Fallback call successful with model: {self._format_model(fallback_model)}")
|
logger.debug(
|
||||||
|
f"Fallback call successful with model: {self._format_model(fallback_model)}"
|
||||||
|
)
|
||||||
|
|
||||||
self.reset_fallback_handler()
|
self.reset_fallback_handler()
|
||||||
return [response, tool_call_result]
|
return [response, tool_call_result]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if isinstance(e, KeyboardInterrupt):
|
if isinstance(e, KeyboardInterrupt):
|
||||||
raise
|
raise
|
||||||
logger.error(f"Fallback with model {self._format_model(fallback_model)} failed: {e}")
|
logger.error(
|
||||||
|
f"Fallback with model {self._format_model(fallback_model)} failed: {e}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def construct_prompt_msg_list(self):
|
def construct_prompt_msg_list(self):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue