Merge pull request #90 from ariel-frischer/fallback-tools

Experimental Tool Fallback Handler
Ariel Frischer 2025-02-17 15:39:49 -08:00 committed by GitHub
commit 9b0027a922
20 changed files with 2140 additions and 324 deletions

View File

@ -175,6 +175,8 @@ More information is available in our [Usage Examples](https://docs.ra-aid.ai/cat
- `--hil, -H`: Enable human-in-the-loop mode for interactive assistance during task execution
- `--chat`: Enable chat mode with direct human interaction (implies --hil)
- `--verbose`: Enable verbose logging output
- `--experimental-fallback-handler`: Enable the experimental fallback handler, which attempts to fix tool calls after they fail 3 times consecutively
- `--pretty-logger`: Enable panel-based, Markdown-formatted logger messages for debugging purposes
- `--temperature`: LLM temperature (0.0-2.0) to control randomness in responses
- `--disable-limit-tokens`: Disable token limiting for Anthropic Claude react agents
- `--recursion-limit`: Maximum recursion depth for agent operations (default: 100)
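
For example, the two new options can be combined with any of the existing flags, e.g. `ra-aid --chat --experimental-fallback-handler --pretty-logger --verbose` (illustrative; add your usual options).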

View File

@ -12,7 +12,6 @@ from rich.text import Text
from ra_aid import print_error, print_stage_header
from ra_aid.__version__ import __version__
from ra_aid.agent_utils import (
AgentInterrupt,
create_agent,
run_agent_with_retry,
run_planning_agent,
@ -22,11 +21,15 @@ from ra_aid.config import (
DEFAULT_MAX_TEST_CMD_RETRIES,
DEFAULT_RECURSION_LIMIT,
DEFAULT_TEST_CMD_TIMEOUT,
VALID_PROVIDERS,
)
from ra_aid.console.output import cpm
from ra_aid.dependencies import check_dependencies
from ra_aid.env import validate_environment
from ra_aid.exceptions import AgentInterrupt
from ra_aid.llm import initialize_llm
from ra_aid.logging_config import get_logger, setup_logging
from ra_aid.models_params import DEFAULT_TEMPERATURE, models_params
from ra_aid.project_info import format_project_info, get_project_info
from ra_aid.prompts import CHAT_PROMPT, WEB_RESEARCH_PROMPT_SECTION_CHAT
from ra_aid.tool_configs import get_chat_tools
@ -45,14 +48,6 @@ def launch_webui(host: str, port: int):
def parse_arguments(args=None):
VALID_PROVIDERS = [
"anthropic",
"openai",
"openrouter",
"openai-compatible",
"deepseek",
"gemini",
]
ANTHROPIC_DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
OPENAI_DEFAULT_MODEL = "gpt-4o"
@ -145,6 +140,9 @@ Examples:
parser.add_argument(
"--verbose", action="store_true", help="Enable verbose logging output"
)
parser.add_argument(
"--pretty-logger", action="store_true", help="Enable pretty logging output"
)
parser.add_argument(
"--temperature",
type=float,
@ -156,6 +154,11 @@ Examples:
action="store_false",
help="Whether to disable token limiting for Anthropic Claude react agents. Token limiter removes older messages to prevent maximum token limit API errors.",
)
parser.add_argument(
"--experimental-fallback-handler",
action="store_true",
help="Enable experimental fallback handler.",
)
parser.add_argument(
"--recursion-limit",
type=int,
@ -286,7 +289,7 @@ def is_stage_requested(stage: str) -> bool:
def main():
"""Main entry point for the ra-aid command line tool."""
args = parse_arguments()
setup_logging(args.verbose)
setup_logging(args.verbose, args.pretty_logger)
logger.debug("Starting RA.Aid with arguments: %s", args)
# Launch web interface if requested
@ -304,7 +307,6 @@ def main():
logger.debug("Environment validation successful")
# Validate model configuration early
from ra_aid.models_params import models_params
model_config = models_params.get(args.provider, {}).get(args.model or "", {})
supports_temperature = model_config.get(
@ -316,10 +318,10 @@ def main():
if supports_temperature and args.temperature is None:
args.temperature = model_config.get("default_temperature")
if args.temperature is None:
print_error(
f"Temperature must be provided for model {args.model} which supports temperature"
cpm(
f"This model supports temperature argument but none was given. Setting default temperature to {DEFAULT_TEMPERATURE}."
)
sys.exit(1)
args.temperature = DEFAULT_TEMPERATURE
logger.debug(
f"Using default temperature {args.temperature} for model {args.model}"
)
@ -445,6 +447,7 @@ def main():
"auto_test": args.auto_test,
"test_cmd": args.test_cmd,
"max_test_cmd_retries": args.max_test_cmd_retries,
"experimental_fallback_handler": args.experimental_fallback_handler,
"test_cmd_timeout": args.test_cmd_timeout,
}

View File

@ -12,7 +12,12 @@ from typing import Any, Dict, List, Literal, Optional, Sequence
import litellm
from anthropic import APIError, APITimeoutError, InternalServerError, RateLimitError
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, HumanMessage, trim_messages
from langchain_core.messages import (
BaseMessage,
HumanMessage,
SystemMessage,
trim_messages,
)
from langchain_core.tools import tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
@ -23,10 +28,16 @@ from rich.markdown import Markdown
from rich.panel import Panel
from ra_aid.agents.ciayn_agent import CiaynAgent
from ra_aid.agents_alias import RAgents
from ra_aid.config import DEFAULT_MAX_TEST_CMD_RETRIES, DEFAULT_RECURSION_LIMIT
from ra_aid.console.formatting import print_error, print_stage_header
from ra_aid.console.output import print_agent_output
from ra_aid.exceptions import AgentInterrupt
from ra_aid.exceptions import (
AgentInterrupt,
FallbackToolExecutionError,
ToolExecutionError,
)
from ra_aid.fallback_handler import FallbackHandler
from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT, models_params
from ra_aid.project_info import (
@ -238,7 +249,7 @@ def create_agent(
*,
checkpointer: Any = None,
agent_type: str = "default",
) -> Any:
):
"""Create a react agent with the given configuration.
Args:
@ -270,7 +281,7 @@ def create_agent(
return create_react_agent(model, tools, **agent_kwargs)
else:
logger.debug("Using CiaynAgent agent instance")
return CiaynAgent(model, tools, max_tokens=max_input_tokens)
return CiaynAgent(model, tools, max_tokens=max_input_tokens, config=config)
except Exception as e:
# Default to REACT agent if provider/model detection fails
@ -332,9 +343,6 @@ def run_research_agent(
if memory is None:
memory = MemorySaver()
if thread_id is None:
thread_id = str(uuid.uuid4())
tools = get_research_tools(
research_only=research_only,
expert_enabled=expert_enabled,
@ -405,8 +413,11 @@ def run_research_agent(
display_project_status(project_info)
if agent is not None:
logger.debug("Research agent completed successfully")
_result = run_agent_with_retry(agent, prompt, run_config)
logger.debug("Research agent created successfully")
none_or_fallback_handler = init_fallback_handler(agent, config, tools)
_result = run_agent_with_retry(
agent, prompt, run_config, none_or_fallback_handler
)
if _result:
# Log research completion
log_work_event(f"Completed research phase for: {base_task_or_query}")
@ -522,7 +533,10 @@ def run_web_research_agent(
console.print(Panel(Markdown(console_message), title="🔬 Researching..."))
logger.debug("Web research agent completed successfully")
_result = run_agent_with_retry(agent, prompt, run_config)
none_or_fallback_handler = init_fallback_handler(agent, config, tools)
_result = run_agent_with_retry(
agent, prompt, run_config, none_or_fallback_handler
)
if _result:
# Log web research completion
log_work_event(f"Completed web research phase for: {query}")
@ -627,7 +641,10 @@ def run_planning_agent(
try:
print_stage_header("Planning Stage")
logger.debug("Planning agent completed successfully")
_result = run_agent_with_retry(agent, planning_prompt, run_config)
none_or_fallback_handler = init_fallback_handler(agent, config, tools)
_result = run_agent_with_retry(
agent, planning_prompt, run_config, none_or_fallback_handler
)
if _result:
# Log planning completion
log_work_event(f"Completed planning phase for: {base_task}")
@ -732,7 +749,10 @@ def run_task_implementation_agent(
try:
logger.debug("Implementation agent completed successfully")
_result = run_agent_with_retry(agent, prompt, run_config)
none_or_fallback_handler = init_fallback_handler(agent, config, tools)
_result = run_agent_with_retry(
agent, prompt, run_config, none_or_fallback_handler
)
if _result:
# Log task implementation completion
log_work_event(f"Completed implementation of task: {task}")
@ -775,49 +795,141 @@ def check_interrupt():
raise AgentInterrupt("Interrupt requested")
def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
"""Run an agent with retry logic for API errors."""
logger.debug("Running agent with prompt length: %d", len(prompt))
original_handler = None
# New helper functions for run_agent_with_retry refactoring
def _setup_interrupt_handling():
if threading.current_thread() is threading.main_thread():
original_handler = signal.getsignal(signal.SIGINT)
signal.signal(signal.SIGINT, _request_interrupt)
return original_handler
return None
def _restore_interrupt_handling(original_handler):
if original_handler and threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGINT, original_handler)
def _increment_agent_depth():
current_depth = _global_memory.get("agent_depth", 0)
_global_memory["agent_depth"] = current_depth + 1
def _decrement_agent_depth():
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
def reset_agent_completion_flags():
_global_memory["plan_completed"] = False
_global_memory["task_completed"] = False
_global_memory["completion_message"] = ""
def _execute_test_command_wrapper(original_prompt, config, test_attempts, auto_test):
return execute_test_command(config, original_prompt, test_attempts, auto_test)
def _handle_api_error(e, attempt, max_retries, base_delay):
if isinstance(e, ValueError):
error_str = str(e).lower()
if "code" not in error_str or "429" not in error_str:
raise e
if attempt == max_retries - 1:
logger.error("Max retries reached, failing: %s", str(e))
raise RuntimeError(f"Max retries ({max_retries}) exceeded. Last error: {e}")
logger.warning("API error (attempt %d/%d): %s", attempt + 1, max_retries, str(e))
delay = base_delay * (2**attempt)
print_error(
f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
)
start = time.monotonic()
while time.monotonic() - start < delay:
check_interrupt()
time.sleep(0.1)
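
For reference, the retry delay above is a plain exponential backoff; a tiny illustrative sketch with the defaults used below (base_delay = 1):

```python
# Illustrative only: the delay computed by _handle_api_error for each attempt.
base_delay = 1
for attempt in range(5):
    print(f"attempt {attempt + 1}: wait {base_delay * (2 ** attempt)}s")
# attempt 1: wait 1s ... attempt 5: wait 16s
```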
def get_agent_type(agent: RAgents) -> Literal["CiaynAgent", "React"]:
"""
Determines the type of the agent.
Returns "CiaynAgent" if agent is an instance of CiaynAgent, otherwise "React".
"""
if isinstance(agent, CiaynAgent):
return "CiaynAgent"
else:
return "React"
def init_fallback_handler(agent: RAgents, config: Dict[str, Any], tools: List[Any]):
"""
Initialize fallback handler if agent is of type "React" and experimental_fallback_handler is enabled; otherwise return None.
"""
if not config.get("experimental_fallback_handler", False):
return None
agent_type = get_agent_type(agent)
if agent_type == "React":
return FallbackHandler(config, tools)
return None
def _handle_fallback_response(
error: ToolExecutionError,
fallback_handler: Optional[FallbackHandler],
agent: RAgents,
msg_list: list,
) -> None:
"""
Handle fallback response by invoking fallback_handler and updating msg_list.
"""
if not fallback_handler:
return
fallback_response = fallback_handler.handle_failure(error, agent, msg_list)
agent_type = get_agent_type(agent)
if fallback_response and agent_type == "React":
msg_list_response = [HumanMessage(str(msg)) for msg in fallback_response]
msg_list.extend(msg_list_response)
def _run_agent_stream(agent: RAgents, msg_list: list[BaseMessage], config: dict):
for chunk in agent.stream({"messages": msg_list}, config):
logger.debug("Agent output: %s", chunk)
check_interrupt()
agent_type = get_agent_type(agent)
print_agent_output(chunk, agent_type)
if _global_memory["plan_completed"] or _global_memory["task_completed"]:
reset_agent_completion_flags()
break
def run_agent_with_retry(
agent: RAgents,
prompt: str,
config: dict,
fallback_handler: Optional[FallbackHandler] = None,
) -> Optional[str]:
"""Run an agent with retry logic for API errors."""
logger.debug("Running agent with prompt length: %d", len(prompt))
original_handler = _setup_interrupt_handling()
max_retries = 20
base_delay = 1
test_attempts = 0
_max_test_retries = config.get("max_test_cmd_retries", DEFAULT_MAX_TEST_CMD_RETRIES)
auto_test = config.get("auto_test", False)
original_prompt = prompt
msg_list = [HumanMessage(content=prompt)]
with InterruptibleSection():
try:
# Track agent execution depth
current_depth = _global_memory.get("agent_depth", 0)
_global_memory["agent_depth"] = current_depth + 1
_increment_agent_depth()
for attempt in range(max_retries):
logger.debug("Attempt %d/%d", attempt + 1, max_retries)
check_interrupt()
try:
for chunk in agent.stream(
{"messages": [HumanMessage(content=prompt)]}, config
):
logger.debug("Agent output: %s", chunk)
check_interrupt()
print_agent_output(chunk)
if _global_memory["plan_completed"]:
_global_memory["plan_completed"] = False
_global_memory["task_completed"] = False
break
if _global_memory["task_completed"]:
_global_memory["task_completed"] = False
break
# Execute test command if configured
_run_agent_stream(agent, msg_list, config)
if fallback_handler:
fallback_handler.reset_fallback_handler()
should_break, prompt, auto_test, test_attempts = (
execute_test_command(
config, original_prompt, test_attempts, auto_test
_execute_test_command_wrapper(
original_prompt, config, test_attempts, auto_test
)
)
if should_break:
@ -827,6 +939,13 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
logger.debug("Agent run completed successfully")
return "Agent run completed successfully"
except ToolExecutionError as e:
_handle_fallback_response(e, fallback_handler, agent, msg_list)
continue
except FallbackToolExecutionError as e:
msg_list.append(
SystemMessage(f"FallbackToolExecutionError:{str(e)}")
)
except (KeyboardInterrupt, AgentInterrupt):
raise
except (
@ -836,35 +955,7 @@ def run_agent_with_retry(agent, prompt: str, config: dict) -> Optional[str]:
APIError,
ValueError,
) as e:
if isinstance(e, ValueError):
error_str = str(e).lower()
if "code" not in error_str or "429" not in error_str:
raise # Re-raise ValueError if it's not a Lambda 429
if attempt == max_retries - 1:
logger.error("Max retries reached, failing: %s", str(e))
raise RuntimeError(
f"Max retries ({max_retries}) exceeded. Last error: {e}"
)
logger.warning(
"API error (attempt %d/%d): %s",
attempt + 1,
max_retries,
str(e),
)
delay = base_delay * (2**attempt)
print_error(
f"Encountered {e.__class__.__name__}: {e}. Retrying in {delay}s... (Attempt {attempt+1}/{max_retries})"
)
start = time.monotonic()
while time.monotonic() - start < delay:
check_interrupt()
time.sleep(0.1)
_handle_api_error(e, attempt, max_retries, base_delay)
finally:
# Reset depth tracking
_global_memory["agent_depth"] = _global_memory.get("agent_depth", 1) - 1
if (
original_handler
and threading.current_thread() is threading.main_thread()
):
signal.signal(signal.SIGINT, original_handler)
_decrement_agent_depth()
_restore_interrupt_handling(original_handler)
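
Taken together, a minimal usage sketch of the refactored retry loop with an optional fallback handler; the agent and tools below are placeholders for objects produced elsewhere (e.g. by create_agent and the tool_configs helpers):

```python
# A minimal sketch, not the canonical call path: `agent` and `tools` stand in
# for objects created elsewhere in the project.
from typing import Any, List, Optional

from ra_aid.agent_utils import init_fallback_handler, run_agent_with_retry

def run_with_fallback(agent: Any, tools: List[Any], prompt: str) -> Optional[str]:
    """Opt in to the experimental fallback handler and run the retry loop."""
    config = {"experimental_fallback_handler": True, "auto_test": False}
    fallback = init_fallback_handler(agent, config, tools)  # None for CiaynAgent
    return run_agent_with_retry(agent, prompt, config, fallback)
```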

View File

@ -2,11 +2,17 @@ import re
from dataclasses import dataclass
from typing import Any, Dict, Generator, List, Optional, Union
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
from ra_aid.config import DEFAULT_MAX_TOOL_FAILURES
from ra_aid.exceptions import ToolExecutionError
from ra_aid.fallback_handler import FallbackHandler
from ra_aid.logging_config import get_logger
from ra_aid.models_params import DEFAULT_TOKEN_LIMIT
from ra_aid.prompts import CIAYN_AGENT_BASE_PROMPT, EXTRACT_TOOL_CALL_PROMPT
from ra_aid.tools.expert import get_model
from ra_aid.tools.reflection import get_function_info
logger = get_logger(__name__)
@ -70,10 +76,11 @@ class CiaynAgent:
def __init__(
self,
model,
tools: list,
model: BaseChatModel,
tools: list[BaseTool],
max_history_messages: int = 50,
max_tokens: Optional[int] = DEFAULT_TOKEN_LIMIT,
config: Optional[dict] = None,
):
"""Initialize the agent with a model and list of tools.
@ -82,18 +89,35 @@ class CiaynAgent:
tools: List of tools available to the agent
max_history_messages: Maximum number of messages to keep in chat history
max_tokens: Maximum number of tokens allowed in message history (None for no limit)
config: Optional configuration dictionary
"""
if config is None:
config = {}
self.config = config
self.provider = config.get("provider", "openai")
self.model = model
self.tools = tools
self.max_history_messages = max_history_messages
self.max_tokens = max_tokens
self.chat_history = []
self.available_functions = []
for t in tools:
self.available_functions.append(get_function_info(t.func))
self.fallback_handler = FallbackHandler(config, tools)
self.sys_message = SystemMessage(
"Execute efficiently yet completely as a fully autonomous agent."
)
self.error_message_template = "Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."
self.fallback_fixed_msg = HumanMessage(
"Fallback tool handler has fixed the tool call see: <fallback tool call result> for the output."
)
def _build_prompt(self, last_result: Optional[str] = None) -> str:
"""Build the prompt for the agent including available tools and context."""
base_prompt = ""
if last_result is not None:
base_prompt += f"\n<last result>{last_result}</last result>"
@ -101,143 +125,62 @@ class CiaynAgent:
functions_list = "\n\n".join(self.available_functions)
# Build the complete prompt without f-strings for the static parts
base_prompt += (
"""
base_prompt += CIAYN_AGENT_BASE_PROMPT.format(functions_list=functions_list)
<agent instructions>
You are a ReAct agent. You run in a loop and use ONE of the available functions per iteration, but you will be called in a loop, so you will be able to accomplish the task over many iterations.
The result of that function call will be given to you in the next message.
Call one function at a time. Function arguments can be complex objects, long strings, etc. if needed.
The user cannot see the results of function calls, so you have to explicitly use a tool like ask_human if you want them to see something.
You must always respond with a single line of python that calls one of the available tools.
Use as many steps as you need to in order to fully complete the task.
Start by asking the user what they want.
You must carefully review the conversation history, which functions were called so far, returned results, etc., and make sure the very next function call you make makes sense in order to achieve the original goal.
You are expected to use as many steps as necessary to completely achieve the user's request, making many tool calls along the way.
Think hard about what the best *next* tool call is, knowing that you can make as many calls as you need to after that.
You typically don't want to keep calling the same function over and over with the same parameters.
</agent instructions>
You must ONLY use ONE of the following functions (these are the ONLY functions that exist):
<available functions>"""
+ functions_list
+ """
</available functions>
You may use any of the above functions to complete your job. Use the best one for the current step you are on. Be efficient, avoid getting stuck in repetitive loops, and do not hesitate to call functions which delegate your work to make your life easier.
But you MUST NOT assume tools exist that are not in the above list, e.g. put_complete_file_contents.
Consider your task done only once you have taken *ALL* the steps required to complete it.
--- EXAMPLE BAD OUTPUTS ---
This tool is not in available functions, so this is a bad tool call:
<example bad output>
put_complete_file_contents(...)
</example bad output>
This tool call has a syntax error (unclosed parenthesis, quotes), so it is bad:
<example bad output>
put_complete_file_contents("asdf
</example bad output>
This tool call is bad because it includes a message as well as backticks:
<example bad output>
Sure, I'll make the following tool call to accomplish what you asked me:
```
list_directory_tree('.')
```
</example bad output>
This tool call is bad because the output code is surrounded with backticks:
<example bad output>
```
list_directory_tree('.')
```
</example bad output>
The following is bad because it makes the same tool call multiple times in a row with the exact same parameters, for no reason, getting stuck in a loop:
<example bad output>
<response 1>
list_directory_tree('.')
</response 1>
<response 2>
list_directory_tree('.')
</response 2>
</example bad output>
The following is bad because it makes more than one tool call in one response:
<example bad output>
list_directory_tree('.')
read_file_tool('README.md') # Now we've made
</example bad output>
This is a good output because it calls the tool appropriately and with correct syntax:
--- EXAMPLE GOOD OUTPUTS ---
<example good output>
request_research_and_implementation(\"\"\"
Example query.
\"\"\")
</example good output>
This is good output because it uses a multiple line string when needed and properly calls the tool, does not output backticks or extra information:
<example good output>
run_programming_task(\"\"\"
# Example Programming Task
Implement a widget factory satisfying the following requirements:
- Requirement A
- Requirement B
...
\"\"\")
</example good output>
As an agent, you will carefully plan ahead, carefully analyze tool call responses, and adapt to circumstances in order to accomplish your goal.
You will make as many tool calls as you feel necessary in order to fully complete the task.
We're entrusting you with a lot of autonomy and power, so be efficient and don't mess up.
You have often been criticized for:
- Making the same function calls over and over, getting stuck in a loop.
DO NOT CLAIM YOU ARE FINISHED UNTIL YOU ACTUALLY ARE!
Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
)
# base_prompt += "\n\nYou must reply with ONLY ONE of the functions given in available functions."
return base_prompt
def _execute_tool(self, code: str) -> str:
def _execute_tool(self, msg: BaseMessage) -> str:
"""Execute a tool call and return its result."""
code = msg.content
globals_dict = {tool.func.__name__: tool.func for tool in self.tools}
try:
code = code.strip()
# code = code.replace("\n", " ")
if code.startswith("```"):
code = code[3:].strip()
if code.endswith("```"):
code = code[:-3].strip()
# if the eval fails, try to extract it via a model call
if validate_function_call_pattern(code):
functions_list = "\n\n".join(self.available_functions)
code = _extract_tool_call(code, functions_list)
code = self._extract_tool_call(code, functions_list)
pass
result = eval(code.strip(), globals_dict)
return result
except Exception as e:
error_msg = f"Error executing code: {str(e)}"
raise ToolExecutionError(error_msg)
error_msg = f"Error: {str(e)} \n Could not execute code: {code}"
tool_name = self.extract_tool_name(code)
raise ToolExecutionError(
error_msg, base_message=msg, tool_name=tool_name
) from e
def extract_tool_name(self, code: str) -> str:
match = re.match(r"\s*([\w_\-]+)\s*\(", code)
if match:
return match.group(1)
return ""
def handle_fallback_response(
self, fallback_response: list[Any], e: ToolExecutionError
) -> str:
err_msg = HumanMessage(content=self.error_message_template.format(e=e))
if not fallback_response:
self.chat_history.append(err_msg)
return ""
self.chat_history.append(self.fallback_fixed_msg)
msg = f"Fallback tool handler has triggered after consecutive failed tool calls reached {DEFAULT_MAX_TOOL_FAILURES} failures.\n"
# Passing the fallback raw invocation may confuse our llm, as invocation methods may differ.
# msg += f"<fallback llm raw invocation>{fallback_response[0]}</fallback llm raw invocation>\n"
msg += f"<fallback tool name>{e.tool_name}</fallback tool name>\n"
msg += f"<fallback tool call result>\n{fallback_response[1]}\n</fallback tool call result>\n"
return msg
def _create_agent_chunk(self, content: str) -> Dict[str, Any]:
"""Create an agent chunk in the format expected by print_agent_output."""
@ -251,7 +194,6 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
@staticmethod
def _estimate_tokens(content: Optional[Union[str, BaseMessage]]) -> int:
"""Estimate number of tokens in content using simple byte length heuristic.
Estimates 1 token per 2.0 bytes of content. For messages, uses the content field.
Args:
@ -277,6 +219,22 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
return len(text.encode("utf-8")) // 2.0
def _extract_tool_call(self, code: str, functions_list: str) -> str:
model = get_model()
prompt = EXTRACT_TOOL_CALL_PROMPT.format(
functions_list=functions_list, code=code
)
response = model.invoke(prompt)
response = response.content
pattern = r"([\w_\-]+)\((.*?)\)"
matches = re.findall(pattern, response, re.DOTALL)
if len(matches) == 0:
raise ToolExecutionError("Failed to extract tool call")
ma = matches[0][0].strip()
mb = matches[0][1].strip().replace("\n", " ")
return f"{ma}({mb})"
def _trim_chat_history(
self, initial_messages: List[Any], chat_history: List[Any]
) -> List[Any]:
@ -315,72 +273,28 @@ Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**"""
return initial_messages + chat_history
def stream(
self, messages_dict: Dict[str, List[Any]], config: Dict[str, Any] = None
self, messages_dict: Dict[str, List[Any]], _config: Dict[str, Any] = None
) -> Generator[Dict[str, Any], None, None]:
"""Stream agent responses in a format compatible with print_agent_output."""
initial_messages = messages_dict.get("messages", [])
chat_history = []
self.chat_history = []
last_result = None
first_iteration = True
while True:
base_prompt = self._build_prompt(None if first_iteration else last_result)
chat_history.append(HumanMessage(content=base_prompt))
full_history = self._trim_chat_history(initial_messages, chat_history)
response = self.model.invoke(
[
SystemMessage(
"Execute efficiently yet completely as a fully autonomous agent."
)
]
+ full_history
)
base_prompt = self._build_prompt(last_result)
self.chat_history.append(HumanMessage(content=base_prompt))
full_history = self._trim_chat_history(initial_messages, self.chat_history)
response = self.model.invoke([self.sys_message] + full_history)
try:
logger.debug(f"Code generated by agent: {response.content}")
last_result = self._execute_tool(response.content)
chat_history.append(response)
first_iteration = False
last_result = self._execute_tool(response)
self.chat_history.append(response)
self.fallback_handler.reset_fallback_handler()
yield {}
except ToolExecutionError as e:
chat_history.append(
HumanMessage(
content=f"Your tool call caused an error: {e}\n\nPlease correct your tool call and try again."
)
fallback_response = self.fallback_handler.handle_failure(
e, self, self.chat_history
)
yield self._create_error_chunk(str(e))
def _extract_tool_call(code: str, functions_list: str) -> str:
from ra_aid.tools.expert import get_model
model = get_model()
prompt = f"""
I'm conversing with an AI model and require responses in a particular format: A function call with any parameters escaped. Here is an example:
```
run_programming_task("blah \" blah\" blah")
```
The following tasks are allowed:
{functions_list}
I got this invalid response from the model, can you format it so it becomes a correct function call?
```
{code}
```
"""
response = model.invoke(prompt)
response = response.content
pattern = r"([\w_\-]+)\((.*?)\)"
matches = re.findall(pattern, response, re.DOTALL)
if len(matches) == 0:
raise ToolExecutionError("Failed to extract tool call")
ma = matches[0][0].strip()
mb = matches[0][1].strip().replace("\n", " ")
return f"{ma}({mb})"
last_result = self.handle_fallback_response(fallback_response, e)
yield {}
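
A hedged construction sketch for the updated CiaynAgent signature; the model and tools are assumed inputs, and in the project the actual call site is create_agent in agent_utils.py:

```python
# Illustrative sketch: CiaynAgent now takes an optional config dict, which it
# forwards to its internal FallbackHandler.
from typing import Optional

from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool

from ra_aid.agents.ciayn_agent import CiaynAgent

def build_ciayn_agent(
    model: BaseChatModel, tools: list[BaseTool], config: Optional[dict] = None
) -> CiaynAgent:
    """Construct the agent roughly as create_agent in agent_utils.py now does."""
    return CiaynAgent(model, tools, config=config or {})
```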

11
ra_aid/agents_alias.py Normal file
View File

@ -0,0 +1,11 @@
from typing import TYPE_CHECKING
from langgraph.graph.graph import CompiledGraph
# Unfortunately need this to avoid Circular Imports
if TYPE_CHECKING:
from ra_aid.agents.ciayn_agent import CiaynAgent
RAgents = CompiledGraph | CiaynAgent
else:
RAgents = CompiledGraph

View File

@ -2,4 +2,17 @@
DEFAULT_RECURSION_LIMIT = 100
DEFAULT_MAX_TEST_CMD_RETRIES = 3
DEFAULT_MAX_TOOL_FAILURES = 3
FALLBACK_TOOL_MODEL_LIMIT = 5
RETRY_FALLBACK_COUNT = 3
DEFAULT_TEST_CMD_TIMEOUT = 60 * 5 # 5 minutes in seconds
VALID_PROVIDERS = [
"anthropic",
"openai",
"openrouter",
"openai-compatible",
"deepseek",
"gemini",
]

View File

@ -1,18 +1,23 @@
from typing import Any, Dict
from typing import Any, Dict, Literal, Optional
from langchain_core.messages import AIMessage
from rich.markdown import Markdown
from rich.panel import Panel
from ra_aid.exceptions import ToolExecutionError
# Import shared console instance
from .formatting import console
def print_agent_output(chunk: Dict[str, Any]) -> None:
def print_agent_output(
chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"]
) -> None:
"""Print only the agent's message content, not tool calls.
Args:
chunk: A dictionary containing agent or tool messages
chunk: A dictionary containing agent or tool messages.
agent_type: Specifies the type of agent. 'CiaynAgent' handles tool errors internally, while 'React' raises a ToolExecutionError.
"""
if "agent" in chunk and "messages" in chunk["agent"]:
messages = chunk["agent"]["messages"]
@ -33,10 +38,31 @@ def print_agent_output(chunk: Dict[str, Any]) -> None:
elif "tools" in chunk and "messages" in chunk["tools"]:
for msg in chunk["tools"]["messages"]:
if msg.status == "error" and msg.content:
err_msg = msg.content.strip()
console.print(
Panel(
Markdown(msg.content.strip()),
Markdown(err_msg),
title="❌ Tool Error",
border_style="red bold",
)
)
tool_name = getattr(msg, "name", None)
# CiaynAgent handles this internally
if agent_type == "React":
raise ToolExecutionError(
err_msg, tool_name=tool_name, base_message=msg
)
def cpm(message: str, title: Optional[str] = None, border_style: str = "blue") -> None:
"""
Print a message using a Panel with Markdown formatting.
Args:
message (str): The message content to display.
title (Optional[str]): An optional title for the panel.
border_style (str): Border style for the panel.
"""
console.print(Panel(Markdown(message), title=title, border_style=border_style))
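
For example, elsewhere in this diff the helper is called like this (shown with the default border style and an override):

```python
# Minimal usage sketch of the new cpm helper.
from ra_aid.console.output import cpm

cpm("**Tool fallback activated**: Attempting fallback.", title="Fallback Notification")
cpm("All fallback models have failed.", title="Fallback Failed", border_style="red")
```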

View File

@ -1,5 +1,9 @@
"""Custom exceptions for RA.Aid."""
from typing import Optional
from langchain_core.messages import BaseMessage
class AgentInterrupt(Exception):
"""Exception raised when an agent's execution is interrupted.
@ -18,4 +22,18 @@ class ToolExecutionError(Exception):
from other types of errors in the agent system.
"""
def __init__(
self,
message: str,
base_message: Optional[BaseMessage] = None,
tool_name: Optional[str] = None,
):
super().__init__(message)
self.base_message = base_message
self.tool_name = tool_name
class FallbackToolExecutionError(Exception):
"""Exception raised when a fallback tool execution fails."""
pass
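
A small illustrative sketch of the enriched exception: it now carries the offending message and tool name so the fallback handler can reconstruct the failed call (the message content below is made up):

```python
# Illustrative only: raising and inspecting the enriched ToolExecutionError.
from langchain_core.messages import AIMessage
from ra_aid.exceptions import ToolExecutionError

try:
    raise ToolExecutionError(
        "Error: name 'foo' is not defined \n Could not execute code: foo('.')",
        base_message=AIMessage("foo('.')"),
        tool_name="foo",
    )
except ToolExecutionError as e:
    print(e.tool_name)                    # "foo"
    print(type(e.base_message).__name__)  # "AIMessage"
```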

437
ra_aid/fallback_handler.py Normal file
View File

@ -0,0 +1,437 @@
import json
import re
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
from langgraph.graph.message import BaseMessage
from ra_aid.agents_alias import RAgents
from ra_aid.config import (
DEFAULT_MAX_TOOL_FAILURES,
FALLBACK_TOOL_MODEL_LIMIT,
RETRY_FALLBACK_COUNT,
)
from ra_aid.console.output import cpm
from ra_aid.exceptions import FallbackToolExecutionError, ToolExecutionError
from ra_aid.llm import initialize_llm, validate_provider_env
from ra_aid.logging_config import get_logger
from ra_aid.tool_configs import get_all_tools
from ra_aid.tool_leaderboard import supported_top_tool_models
logger = get_logger(__name__)
class FallbackHandler:
"""
FallbackHandler manages fallback logic when tool execution fails.
It loads fallback models from configuration and validated provider settings,
maintains failure counts, and triggers appropriate fallback methods for both
prompt-based and function-calling tool invocations. It also resets internal
counters when a tool call succeeds.
"""
def __init__(self, config, tools):
"""
Initialize the FallbackHandler with the given configuration and tools.
Args:
config (dict): Configuration dictionary that may include fallback settings.
tools (list): List of available tools.
"""
self.config = config
self.tools: list[BaseTool] = tools
self.fallback_enabled = config.get("experimental_fallback_handler", False)
self.fallback_tool_models = self._load_fallback_tool_models(config)
self.max_failures = config.get("max_tool_failures", DEFAULT_MAX_TOOL_FAILURES)
self.tool_failure_consecutive_failures = 0
self.failed_messages: list[BaseMessage] = []
self.current_failing_tool_name = ""
self.current_tool_to_bind: None | BaseTool = None
self.msg_list: list[BaseMessage] = []
cpm(
"Fallback models selected: "
+ ", ".join([self._format_model(m) for m in self.fallback_tool_models]),
title="Fallback Models",
)
def _format_model(self, m: dict) -> str:
return f"{m.get('model', '')} ({m.get('type', 'prompt')})"
def _load_fallback_tool_models(self, _config):
"""
Load and return fallback tool models based on the provided configuration.
If the config specifies 'fallback_tool_models', those are used (assuming comma-separated names).
Otherwise, this method filters the supported_top_tool_models based on provider environment validation,
selecting up to FALLBACK_TOOL_MODEL_LIMIT models.
Args:
config (dict): Configuration dictionary.
Returns:
list of dict: Each dictionary contains keys 'model' and 'type' representing a fallback model.
"""
supported = []
skipped = []
for item in supported_top_tool_models:
provider = item.get("provider")
model_name = item.get("model")
if validate_provider_env(provider):
supported.append(item)
if len(supported) == FALLBACK_TOOL_MODEL_LIMIT:
break
else:
skipped.append(model_name)
final_models = []
for item in supported:
if "type" not in item:
item["type"] = "prompt"
item["model"] = item["model"].lower()
final_models.append(item)
message = "Fallback models selected: " + ", ".join(
[m["model"] for m in final_models]
)
if skipped:
message += (
"\nSkipped top tool calling models due to missing provider ENV API keys: "
+ ", ".join(skipped)
)
return final_models
def handle_failure(
self, error: ToolExecutionError, agent: RAgents, msg_list: list[BaseMessage]
):
"""
Handle a tool failure by incrementing the failure counter and triggering fallback if thresholds are exceeded.
Args:
error (Exception): The exception raised during execution. If the exception has a 'base_message' attribute, that message is recorded.
agent: The agent instance on which fallback may be executed.
"""
if not self.fallback_enabled:
return None
if self.tool_failure_consecutive_failures == 0:
self.init_msg_list(msg_list)
failed_tool_call_name = self.extract_failed_tool_name(error)
self._reset_on_new_failure(failed_tool_call_name)
logger.debug(
f"_handle_tool_failure: tool failure encountered for code '{failed_tool_call_name}' with error: {error}"
)
self.current_failing_tool_name = failed_tool_call_name
self.current_tool_to_bind = self._find_tool_to_bind(
agent, failed_tool_call_name
)
if hasattr(error, "base_message") and error.base_message:
self.failed_messages.append(error.base_message)
self.tool_failure_consecutive_failures += 1
logger.debug(
f"_handle_tool_failure: failure count {self.tool_failure_consecutive_failures}, max_failures {self.max_failures}"
)
if (
self.tool_failure_consecutive_failures >= self.max_failures
and self.fallback_tool_models
):
logger.debug(
"_handle_tool_failure: threshold reached, invoking fallback mechanism."
)
return self.attempt_fallback()
def attempt_fallback(self):
"""
Initiate the fallback process by iterating over all fallback models to attempt to fix the failing tool call.
Returns:
List of [raw_llm_response (SystemMessage), tool_call_result (SystemMessage)] or None.
"""
logger.debug(
f"Tool call failed {self.tool_failure_consecutive_failures} times. Attempting fallback for tool: {self.current_failing_tool_name}"
)
cpm(
f"**Tool fallback activated**: Attempting fallback for tool {self.current_failing_tool_name}.",
title="Fallback Notification",
)
for fallback_model in self.fallback_tool_models:
result_list = self.invoke_fallback(fallback_model)
if result_list:
return result_list
cpm("All fallback models have failed.", title="Fallback Failed")
current_failing_tool_name = self.current_failing_tool_name
self.reset_fallback_handler()
raise FallbackToolExecutionError(
f"All fallback models have failed for tool: {current_failing_tool_name}"
)
def reset_fallback_handler(self):
"""
Reset the fallback handler's internal failure counters and clear the record of used fallback models.
"""
self.tool_failure_consecutive_failures = 0
self.failed_messages.clear()
self.fallback_tool_models = self._load_fallback_tool_models(self.config)
self.current_failing_tool_name = ""
self.current_tool_to_bind = None
self.msg_list = []
def _reset_on_new_failure(self, failed_tool_call_name):
if (
self.current_failing_tool_name
and failed_tool_call_name != self.current_failing_tool_name
):
logger.debug(
"New failing tool name identified. Resetting consecutive tool failures."
)
self.reset_fallback_handler()
def extract_failed_tool_name(self, error: ToolExecutionError):
if error.tool_name:
failed_tool_call_name = error.tool_name
else:
msg = str(error)
logger.debug("Error message: %s", msg)
match = re.search(r"name=['\"](\w+)['\"]", msg)
if match:
failed_tool_call_name = str(match.group(1))
logger.debug(
"Extracted failed_tool_call_name using regex: %s",
failed_tool_call_name,
)
else:
failed_tool_call_name = "Tool execution error"
raise FallbackToolExecutionError(
"Fallback failed: Could not extract failed tool name."
)
return failed_tool_call_name
def _find_tool_to_bind(self, agent, failed_tool_call_name):
logger.debug(f"failed_tool_call_name={failed_tool_call_name}")
tool_to_bind = None
if hasattr(agent, "tools"):
tool_to_bind = next(
(t for t in agent.tools if t.func.__name__ == failed_tool_call_name),
None,
)
if tool_to_bind is None:
all_tools = get_all_tools()
tool_to_bind = next(
(t for t in all_tools if t.func.__name__ == failed_tool_call_name),
None,
)
if tool_to_bind is None:
# TODO: Would be nice to try fuzzy matching or Levenshtein string distance to find the closest corresponding tool name
raise FallbackToolExecutionError(
f"Fallback failed failed_tool_call_name: '{failed_tool_call_name}' not found in any available tools."
)
return tool_to_bind
def _bind_tool_model(self, simple_model: BaseChatModel, fallback_model):
if fallback_model.get("type", "prompt").lower() == "fc":
# Force tool calling with tool_choice param.
bound_model = simple_model.bind_tools(
[self.current_tool_to_bind],
tool_choice=self.current_failing_tool_name,
)
else:
# Do not force tool calling (Prompt method)
bound_model = simple_model.bind_tools([self.current_tool_to_bind])
return bound_model
def invoke_fallback(self, fallback_model):
"""
Attempt a Prompt or function-calling fallback by invoking the current failing tool with the given fallback model.
Args:
fallback_model (dict): The fallback model to use.
Returns:
The response from the fallback model invocation, or None if failed.
"""
try:
logger.debug(f"Trying fallback model: {self._format_model(fallback_model)}")
simple_model = initialize_llm(
fallback_model["provider"], fallback_model["model"]
)
bound_model = self._bind_tool_model(simple_model, fallback_model)
retry_model = bound_model.with_retry(
stop_after_attempt=RETRY_FALLBACK_COUNT
)
msg_list = self.construct_prompt_msg_list()
response = retry_model.invoke(msg_list)
tool_call = self.base_message_to_tool_call_dict(response)
tool_call_result = self.invoke_prompt_tool_call(tool_call)
# cpm(str(tool_call_result), title="Fallback Tool Call Result")
logger.debug(
f"Fallback call successful with model: {self._format_model(fallback_model)}"
)
self.reset_fallback_handler()
return [response, tool_call_result]
except Exception as e:
if isinstance(e, KeyboardInterrupt):
raise
logger.error(
f"Fallback with model {self._format_model(fallback_model)} failed: {e}"
)
return None
def construct_prompt_msg_list(self):
"""
Construct a list of chat messages for the fallback prompt.
The initial message instructs the assistant that it is a fallback tool caller.
Then includes the failed tool call messages from self.failed_messages.
Finally, it appends a human message asking it to retry calling the tool with correct valid arguments.
Returns:
list: A list of chat messages.
"""
prompt_msg_list: list[BaseMessage] = []
prompt_msg_list.append(
SystemMessage(
content="You are a fallback tool caller. Your only responsibility is to figure out what the previous failed tool call was trying to do and to call that tool with the correct format and arguments, using the provided failure messages."
)
)
# TODO: Have some way to use the correct message type in the future, don't just convert everything to a system message.
# This may be difficult as each model type may require different chat structures and throw API errors.
prompt_msg_list.extend(SystemMessage(str(msg)) for msg in self.msg_list)
if self.failed_messages:
# Convert to system messages to avoid API errors asking for correct msg structure
prompt_msg_list.extend(
[SystemMessage(str(msg)) for msg in self.failed_messages]
)
prompt_msg_list.append(
HumanMessage(
content=f"Retry using the tool: '{self.current_failing_tool_name}' with correct arguments and formatting."
)
)
return prompt_msg_list
def invoke_prompt_tool_call(self, tool_call_request: dict):
"""
Invoke a tool call from a prompt-based fallback response.
Args:
tool_call_request (dict): The tool call request containing keys 'type', 'name', and 'arguments'.
Returns:
The result of invoking the tool.
"""
tool_name_to_tool = {
getattr(tool.func, "__name__", None): tool for tool in self.tools
}
name = tool_call_request["name"]
arguments = tool_call_request["arguments"]
if name in tool_name_to_tool:
return tool_name_to_tool[name].invoke(arguments)
elif (
self.current_tool_to_bind is not None
and getattr(self.current_tool_to_bind.func, "__name__", None) == name
):
return self.current_tool_to_bind.invoke(arguments)
else:
raise FallbackToolExecutionError(
f"Tool '{name}' not found in available tools."
)
def base_message_to_tool_call_dict(self, response: BaseMessage):
"""
Extracts a tool call dictionary from a BaseMessage.
Args:
response: The response object containing tool call data.
Returns:
A tool call dictionary with keys 'id', 'type', 'name', and 'arguments' if a tool call is found,
otherwise None.
"""
tool_calls = self.get_tool_calls(response)
if not tool_calls:
raise FallbackToolExecutionError(
f"Could not extract tool_call_dict from response: {response}"
)
if len(tool_calls) > 1:
logger.warning("Multiple tool calls detected, using the first one")
tool_call = tool_calls[0]
return {
"id": tool_call["id"],
"type": tool_call["type"],
"name": tool_call["function"]["name"],
"arguments": self._parse_tool_arguments(tool_call["function"]["arguments"]),
}
def _parse_tool_arguments(self, tool_arguments):
"""
Helper method to parse tool call arguments.
If tool_arguments is a string, it returns the JSON-parsed dictionary.
Otherwise, returns tool_arguments as is.
"""
if isinstance(tool_arguments, str):
return json.loads(tool_arguments)
return tool_arguments
def get_tool_calls(self, response: BaseMessage):
"""
Extracts tool calls list from a fallback response.
Args:
response: The response object containing tool call data.
Returns:
The tool calls list if present, otherwise None.
"""
tool_calls = None
if hasattr(response, "additional_kwargs") and response.additional_kwargs.get(
"tool_calls"
):
tool_calls = response.additional_kwargs.get("tool_calls")
elif hasattr(response, "tool_calls"):
tool_calls = response.tool_calls
elif isinstance(response, dict) and response.get("additional_kwargs", {}).get(
"tool_calls"
):
tool_calls = response.get("additional_kwargs").get("tool_calls")
return tool_calls
def handle_failure_response(
self, error: ToolExecutionError, agent, agent_type: str
):
"""
Handle a tool failure by calling handle_failure and, if a fallback response is returned and the agent type is "React",
return a list of SystemMessage objects wrapping each message from the fallback response.
"""
fallback_response = self.handle_failure(error, agent)
if fallback_response and agent_type == "React":
return [SystemMessage(str(msg)) for msg in fallback_response]
return None
def init_msg_list(self, full_msg_list: list[BaseMessage]) -> None:
first_two = full_msg_list[:2]
last_two = full_msg_list[-2:]
merged = first_two.copy()
for msg in last_two:
if msg not in merged:
merged.append(msg)
self.msg_list = merged
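
Putting it together, a hedged sketch of how the retry loop in agent_utils.py is expected to drive this handler; the agent, tools, and message list below are placeholders. Note that on the first failure, init_msg_list keeps only the first two and last two messages of the conversation as context for the fallback prompt.

```python
# Illustrative sketch only; the real wiring is _handle_fallback_response in agent_utils.py.
from typing import Any, List, Optional

from ra_aid.exceptions import FallbackToolExecutionError, ToolExecutionError
from ra_aid.fallback_handler import FallbackHandler

def recover(agent: Any, tools: List[Any], msg_list: list, error: ToolExecutionError) -> Optional[list]:
    handler = FallbackHandler({"experimental_fallback_handler": True}, tools)
    try:
        # Returns [raw LLM response, tool call result] once consecutive failures
        # reach DEFAULT_MAX_TOOL_FAILURES, otherwise None.
        return handler.handle_failure(error, agent, msg_list)
    except FallbackToolExecutionError:
        # Raised when every fallback model has failed for this tool.
        return None
```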

View File

@ -8,6 +8,7 @@ from langchain_openai import ChatOpenAI
from openai import OpenAI
from ra_aid.chat_models.deepseek_chat import ChatDeepseekReasoner
from ra_aid.console.output import cpm
from ra_aid.logging_config import get_logger
from .models_params import models_params
@ -96,9 +97,9 @@ def create_deepseek_client(
return ChatDeepseekReasoner(
api_key=api_key,
base_url=base_url,
temperature=0
if is_expert
else (temperature if temperature is not None else 1),
temperature=(
0 if is_expert else (temperature if temperature is not None else 1)
),
model=model_name,
timeout=LLM_REQUEST_TIMEOUT,
max_retries=LLM_MAX_RETRIES,
@ -125,9 +126,9 @@ def create_openrouter_client(
return ChatDeepseekReasoner(
api_key=api_key,
base_url="https://openrouter.ai/api/v1",
temperature=0
if is_expert
else (temperature if temperature is not None else 1),
temperature=(
0 if is_expert else (temperature if temperature is not None else 1)
),
model=model_name,
timeout=LLM_REQUEST_TIMEOUT,
max_retries=LLM_MAX_RETRIES,
@ -171,7 +172,12 @@ def get_provider_config(provider: str, is_expert: bool = False) -> Dict[str, Any
"base_url": "https://api.deepseek.com",
},
}
return configs.get(provider, {})
config = configs.get(provider, {})
if not config or not config.get("api_key"):
raise ValueError(
f"Missing required environment variable for provider: {provider}"
)
return config
def create_llm_client(
@ -222,8 +228,9 @@ def create_llm_client(
temp_kwargs = {"temperature": 0} if supports_temperature else {}
elif supports_temperature:
if temperature is None:
raise ValueError(
f"Temperature must be provided for model {model_name} which supports temperature"
temperature = 0.7
cpm(
"This model supports temperature argument but none was given. Setting default temperature to 0.7."
)
temp_kwargs = {"temperature": temperature}
else:
@ -298,3 +305,19 @@ def initialize_llm(
def initialize_expert_llm(provider: str, model_name: str) -> BaseChatModel:
"""Initialize an expert language model client based on the specified provider and model."""
return create_llm_client(provider, model_name, temperature=None, is_expert=True)
def validate_provider_env(provider: str) -> bool:
"""Check if the required environment variables for a provider are set."""
required_vars = {
"openai": "OPENAI_API_KEY",
"anthropic": "ANTHROPIC_API_KEY",
"openrouter": "OPENROUTER_API_KEY",
"openai-compatible": "OPENAI_API_KEY",
"gemini": "GEMINI_API_KEY",
"deepseek": "DEEPSEEK_API_KEY",
}
key = required_vars.get(provider.lower())
if key:
return bool(os.getenv(key))
return False
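
A quick sketch of the check FallbackHandler._load_fallback_tool_models uses to filter the leaderboard models (the API key value is a placeholder):

```python
# Illustrative only: a provider passes the check once its API key variable is set.
import os

from ra_aid.llm import validate_provider_env

os.environ.setdefault("ANTHROPIC_API_KEY", "placeholder-key")  # assumption for the sketch
print(validate_provider_env("anthropic"))  # True
print(validate_provider_env("unknown"))    # False: no entry in required_vars
```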

View File

@ -2,17 +2,54 @@ import logging
import sys
from typing import Optional
from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
def setup_logging(verbose: bool = False) -> None:
class PrettyHandler(logging.Handler):
def __init__(self, level=logging.NOTSET):
super().__init__(level)
self.console = Console()
def emit(self, record):
try:
msg = self.format(record)
# Determine title and style based on log level
if record.levelno >= logging.CRITICAL:
title = "🔥 CRITICAL"
style = "bold red"
elif record.levelno >= logging.ERROR:
title = "❌ ERROR"
style = "red"
elif record.levelno >= logging.WARNING:
title = "⚠️ WARNING"
style = "yellow"
elif record.levelno >= logging.INFO:
title = " INFO"
style = "green"
else:
title = "🐞 DEBUG"
style = "blue"
self.console.print(Panel(Markdown(msg.strip()), title=title, style=style))
except Exception:
self.handleError(record)
def setup_logging(verbose: bool = False, pretty: bool = False) -> None:
logger = logging.getLogger("ra_aid")
logger.setLevel(logging.DEBUG if verbose else logging.INFO)
if not logger.handlers:
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
handler.setFormatter(formatter)
if pretty:
handler = PrettyHandler()
else:
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
handler.setFormatter(formatter)
logger.addHandler(handler)
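
In main() this is driven by setup_logging(args.verbose, args.pretty_logger); a minimal standalone sketch:

```python
# Illustrative sketch: enable the panel-based pretty logger.
from ra_aid.logging_config import get_logger, setup_logging

setup_logging(verbose=True, pretty=True)
logger = get_logger(__name__)
logger.warning("This message is rendered inside a yellow '⚠️ WARNING' panel")
```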

View File

@ -984,3 +984,130 @@ Remember, if you do not make any tool call (e.g. ask_human to tell them a messag
NEVER ANNOUNCE WHAT YOU ARE DOING, JUST DO IT!
"""
EXTRACT_TOOL_CALL_PROMPT = """I'm conversing with a AI model and requiring responses in a particular format: A function call with any parameters escaped. Here is an example:
```
run_programming_task("blah \" blah\" blah")
```
The following tasks are allowed:
{functions_list}
I got this invalid response from the model, can you format it so it becomes a correct function call?
```
{code}
```"""
CIAYN_AGENT_BASE_PROMPT = """<agent instructions>
You are a ReAct agent. You run in a loop and use ONE of the available functions per iteration, but you will be called in a loop, so you will be able to accomplish the task over many iterations.
The result of that function call will be given to you in the next message.
Call one function at a time. Function arguments can be complex objects, long strings, etc. if needed.
The user cannot see the results of function calls, so you have to explicitly use a tool like ask_human if you want them to see something.
You must always respond with a single line of python that calls one of the available tools.
Use as many steps as you need to in order to fully complete the task.
Start by asking the user what they want.
You must carefully review the conversation history, which functions were called so far, returned results, etc., and make sure the very next function call you make makes sense in order to achieve the original goal.
You are expected to use as many steps as necessary to completely achieve the user's request, making many tool calls along the way.
Think hard about what the best *next* tool call is, knowing that you can make as many calls as you need to after that.
You typically don't want to keep calling the same function over and over with the same parameters.
</agent instructions>
You must ONLY use ONE of the following functions (these are the ONLY functions that exist):
<available functions>{functions_list}
</available functions>
You may use any of the above functions to complete your job. Use the best one for the current step you are on. Be efficient, avoid getting stuck in repetitive loops, and do not hesitate to call functions which delegate your work to make your life easier.
But you MUST NOT assume tools exist that are not in the above list, e.g. write_file_tool.
Consider your task done only once you have taken *ALL* the steps required to complete it.
--- EXAMPLE BAD OUTPUTS ---
This tool is not in available functions, so this is a bad tool call:
<example bad output>
write_file_tool(...)
</example bad output>
This tool call has a syntax error (unclosed parenthesis, quotes), so it is bad:
<example bad output>
write_file_tool("asdf
</example bad output>
This tool call is bad because it includes a message as well as backticks:
<example bad output>
Sure, I'll make the following tool call to accomplish what you asked me:
```
list_directory_tree('.')
```
</example bad output>
This tool call is bad because the output code is surrounded with backticks:
<example bad output>
```
list_directory_tree('.')
```
</example bad output>
The following is bad because it makes the same tool call multiple times in a row with the exact same parameters, for no reason, getting stuck in a loop:
<example bad output>
<response 1>
list_directory_tree('.')
</response 1>
<response 2>
list_directory_tree('.')
</response 2>
</example bad output>
The following is bad because it makes more than one tool call in one response:
<example bad output>
list_directory_tree('.')
read_file_tool('README.md') # Now we've made
</example bad output>
This is a good output because it calls the tool appropriately and with correct syntax:
--- EXAMPLE GOOD OUTPUTS ---
<example good output>
request_research_and_implementation(\"\"\"
Example query.
\"\"\")
</example good output>
This is good output because it uses a multiple line string when needed and properly calls the tool, does not output backticks or extra information:
<example good output>
run_programming_task(\"\"\"
# Example Programming Task
Implement a widget factory satisfying the following requirements:
- Requirement A
- Requirement B
...
\"\"\")
</example good output>
As an agent, you will carefully plan ahead, carefully analyze tool call responses, and adapt to circumstances in order to accomplish your goal.
You will make as many tool calls as you feel necessary in order to fully complete the task.
We're entrusting you with a lot of autonomy and power, so be efficient and don't mess up.
You have often been criticized for:
- Making the same function calls over and over, getting stuck in a loop.
DO NOT CLAIM YOU ARE FINISHED UNTIL YOU ACTUALLY ARE!
Output **ONLY THE CODE** and **NO MARKDOWN BACKTICKS**
"""

View File

@ -1,3 +1,5 @@
from langchain_core.tools import BaseTool
from ra_aid.tools import (
ask_expert,
ask_human,
@ -8,7 +10,6 @@ from ra_aid.tools import (
emit_research_notes,
fuzzy_find_project_files,
list_directory_tree,
plan_implementation_completed,
read_file_tool,
ripgrep_search,
run_programming_task,
@ -23,13 +24,13 @@ from ra_aid.tools.agent import (
request_task_implementation,
request_web_research,
)
from ra_aid.tools.memory import one_shot_completed
from ra_aid.tools.memory import one_shot_completed, plan_implementation_completed
# Read-only tools that don't modify system state
def get_read_only_tools(
human_interaction: bool = False, web_research_enabled: bool = False
) -> list:
):
"""Get the list of read-only tools, optionally including human interaction tools.
Args:
@ -63,6 +64,18 @@ def get_read_only_tools(
return tools
def get_all_tools() -> list[BaseTool]:
"""Return a list containing all available tools from different groups."""
all_tools = []
all_tools.extend(get_read_only_tools())
all_tools.extend(MODIFICATION_TOOLS)
all_tools.extend(EXPERT_TOOLS)
all_tools.extend(RESEARCH_TOOLS)
all_tools.extend(get_web_research_tools())
all_tools.extend(get_chat_tools())
return all_tools
# Define constant tool groups
READ_ONLY_TOOLS = get_read_only_tools()
# MODIFICATION_TOOLS = [run_programming_task, put_complete_file_contents]
@ -85,7 +98,7 @@ def get_research_tools(
expert_enabled: bool = True,
human_interaction: bool = False,
web_research_enabled: bool = False,
) -> list:
):
"""Get the list of research tools based on mode and whether expert is enabled.
Args:
@ -165,7 +178,7 @@ def get_implementation_tools(
return tools
def get_web_research_tools(expert_enabled: bool = True) -> list:
def get_web_research_tools(expert_enabled: bool = True):
"""Get the list of tools available for web research.
Args:
@ -185,9 +198,7 @@ def get_web_research_tools(expert_enabled: bool = True) -> list:
return tools
def get_chat_tools(
expert_enabled: bool = True, web_research_enabled: bool = False
) -> list:
def get_chat_tools(expert_enabled: bool = True, web_research_enabled: bool = False):
"""Get the list of tools available in chat mode.
Chat mode includes research and implementation capabilities but excludes

529
ra_aid/tool_leaderboard.py Normal file
View File

@ -0,0 +1,529 @@
from ra_aid.config import VALID_PROVIDERS
# Data extracted at 2/10/2025:
# https://gorilla.cs.berkeley.edu/leaderboard.html
# In order of overall_acc
leaderboard_data = [
{
"overall_acc": 74.31,
"model": "watt-tool-70B",
"type": "FC",
"link": "https://huggingface.co/watt-ai/watt-tool-70B/",
"cost": "N/A",
"latency": 3.4,
"ast_summary": 84.06,
"exec_summary": 89.39,
"live_ast_acc": 77.74,
"multi_turn_acc": 58.75,
"relevance": 94.44,
"irrelevance": 76.32,
"organization": "Watt AI Lab",
"license": "Apache-2.0",
"provider": "unknown",
},
{
"overall_acc": 72.08,
"model": "gpt-4o-2024-11-20",
"type": "Prompt",
"link": "https://openai.com/index/hello-gpt-4o/",
"cost": 13.54,
"latency": 0.78,
"ast_summary": 88.1,
"exec_summary": 89.38,
"live_ast_acc": 79.83,
"multi_turn_acc": 47.62,
"relevance": 83.33,
"irrelevance": 83.76,
"organization": "OpenAI",
"license": "Proprietary",
"provider": "openai",
},
{
"overall_acc": 69.58,
"model": "gpt-4o-2024-11-20",
"type": "FC",
"link": "https://openai.com/index/hello-gpt-4o/",
"cost": 8.23,
"latency": 1.11,
"ast_summary": 87.42,
"exec_summary": 89.2,
"live_ast_acc": 79.65,
"multi_turn_acc": 41,
"relevance": 83.33,
"irrelevance": 83.15,
"organization": "OpenAI",
"license": "Proprietary",
"provider": "openai",
},
{
"overall_acc": 67.98,
"model": "watt-tool-8B",
"type": "FC",
"link": "https://huggingface.co/watt-ai/watt-tool-8B/",
"cost": "N/A",
"latency": 1.31,
"ast_summary": 86.56,
"exec_summary": 89.34,
"live_ast_acc": 76.5,
"multi_turn_acc": 39.12,
"relevance": 83.33,
"irrelevance": 83.15,
"organization": "Watt AI Lab",
"license": "Apache-2.0",
"provider": "unknown",
},
{
"overall_acc": 67.88,
"model": "GPT-4-turbo-2024-04-09",
"type": "FC",
"link": "https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo",
"cost": 33.22,
"latency": 2.47,
"ast_summary": 84.73,
"exec_summary": 85.21,
"live_ast_acc": 80.5,
"multi_turn_acc": 38.12,
"relevance": 72.22,
"irrelevance": 83.81,
"organization": "OpenAI",
"license": "Proprietary",
"provider": "openai",
},
{
"overall_acc": 66.73,
"model": "o1-2024-12-17",
"type": "Prompt",
"link": "https://openai.com/o1/",
"cost": 102.47,
"latency": 5.3,
"ast_summary": 85.67,
"exec_summary": 79.77,
"live_ast_acc": 80.63,
"multi_turn_acc": 36,
"relevance": 72.22,
"irrelevance": 87.78,
"organization": "OpenAI",
"license": "Proprietary",
"provider": "openai",
},
{
"overall_acc": 64.1,
"model": "GPT-4o-mini-2024-07-18",
"type": "FC",
"link": "https://openai.com/index/gpt-4o-mini-advancing-cost-efficient-intelligence/",
"cost": 0.51,
"latency": 1.49,
"ast_summary": 85.21,
"exec_summary": 83.57,
"live_ast_acc": 74.41,
"multi_turn_acc": 34.12,
"relevance": 83.33,
"irrelevance": 74.75,
"organization": "OpenAI",
"license": "Proprietary",
"provider": "openai",
},
{
"overall_acc": 62.79,
"model": "o1-mini-2024-09-12",
"type": "Prompt",
"link": "https://openai.com/index/openai-o1-mini-advancing-cost-efficient-reasoning/",
"cost": 29.76,
"latency": 8.44,
"ast_summary": 78.92,
"exec_summary": 82.7,
"live_ast_acc": 78.14,
"multi_turn_acc": 28.25,
"relevance": 61.11,
"irrelevance": 89.62,
"organization": "OpenAI",
"license": "Proprietary",
"provider": "openai",
},
{
"overall_acc": 62.73,
"model": "Functionary-Medium-v3.1",
"type": "FC",
"link": "https://huggingface.co/meetkai/functionary-medium-v3.1",
"cost": "N/A",
"latency": 14.06,
"ast_summary": 89.88,
"exec_summary": 91.32,
"live_ast_acc": 76.63,
"multi_turn_acc": 21.62,
"relevance": 72.22,
"irrelevance": 76.08,
"organization": "MeetKai",
"license": "MIT",
"provider": "unknown",
},
{
"overall_acc": 62.19,
"model": "Gemini-1.5-Pro-002",
"type": "Prompt",
"link": "https://deepmind.google/technologies/gemini/pro/",
"cost": 7.05,
"latency": 5.94,
"ast_summary": 88.58,
"exec_summary": 91.27,
"live_ast_acc": 76.72,
"multi_turn_acc": 20.75,
"relevance": 72.22,
"irrelevance": 78.15,
"organization": "Google",
"license": "Proprietary",
"provider": "google",
},
{
"overall_acc": 61.83,
"model": "Hammer2.1-7b",
"type": "FC",
"link": "https://huggingface.co/MadeAgents/Hammer2.1-7b",
"cost": "N/A",
"latency": 2.08,
"ast_summary": 88.65,
"exec_summary": 85.48,
"live_ast_acc": 75.11,
"multi_turn_acc": 23.5,
"relevance": 82.35,
"irrelevance": 78.59,
"organization": "MadeAgents",
"license": "cc-by-nc-4.0",
"provider": "unknown",
},
{
"overall_acc": 61.74,
"model": "Gemini-2.0-Flash-Exp",
"type": "Prompt",
"link": "https://deepmind.google/technologies/gemini/flash/",
"cost": 0.0,
"latency": 1.18,
"ast_summary": 89.96,
"exec_summary": 79.89,
"live_ast_acc": 82.01,
"multi_turn_acc": 17.88,
"relevance": 77.78,
"irrelevance": 86.44,
"organization": "Google",
"license": "Proprietary",
"provider": "google",
},
{
"overall_acc": 61.38,
"model": "Amazon-Nova-Pro-v1:0",
"type": "FC",
"link": "https://aws.amazon.com/cn/ai/generative-ai/nova/",
"cost": 5.26,
"latency": 2.67,
"ast_summary": 84.46,
"exec_summary": 85.64,
"live_ast_acc": 74.32,
"multi_turn_acc": 26.12,
"relevance": 77.78,
"irrelevance": 70.98,
"organization": "Amazon",
"license": "Proprietary",
"provider": "unknown",
},
{
"overall_acc": 61.31,
"model": "Qwen2.5-72B-Instruct",
"type": "Prompt",
"link": "https://huggingface.co/Qwen/Qwen2.5-72B-Instruct",
"cost": "N/A",
"latency": 3.72,
"ast_summary": 90.81,
"exec_summary": 92.7,
"live_ast_acc": 75.3,
"multi_turn_acc": 18,
"relevance": 100,
"irrelevance": 72.81,
"organization": "Qwen",
"license": "qwen",
"provider": "unknown",
},
{
"overall_acc": 60.97,
"model": "Gemini-1.5-Pro-002",
"type": "FC",
"link": "https://deepmind.google/technologies/gemini/pro/",
"cost": 5.39,
"latency": 2.07,
"ast_summary": 87.29,
"exec_summary": 84.61,
"live_ast_acc": 76.28,
"multi_turn_acc": 21.62,
"relevance": 72.22,
"irrelevance": 76.9,
"organization": "Google",
"license": "Proprietary",
"provider": "google",
},
{
"overall_acc": 60.89,
"model": "GPT-4o-mini-2024-07-18",
"type": "Prompt",
"link": "https://openai.com/index/gpt-4o-mini-advancing-cost-efficient-intelligence/",
"cost": 0.84,
"latency": 1.31,
"ast_summary": 86.77,
"exec_summary": 80.84,
"live_ast_acc": 76.5,
"multi_turn_acc": 22,
"relevance": 83.33,
"irrelevance": 80.67,
"organization": "OpenAI",
"license": "Proprietary",
"provider": "openai",
},
{
"overall_acc": 60.59,
"model": "Gemini-2.0-Flash-Exp",
"type": "FC",
"link": "https://deepmind.google/technologies/gemini/flash/",
"cost": 0.0,
"latency": 0.85,
"ast_summary": 85.1,
"exec_summary": 77.46,
"live_ast_acc": 79.03,
"multi_turn_acc": 20.25,
"relevance": 55.56,
"irrelevance": 91.51,
"organization": "Google",
"license": "Proprietary",
"provider": "google",
},
{
"overall_acc": 60.46,
"model": "Gemini-1.5-Pro-001",
"type": "Prompt",
"link": "https://deepmind.google/technologies/gemini/pro/",
"cost": 7.0,
"latency": 1.54,
"ast_summary": 85.56,
"exec_summary": 85.77,
"live_ast_acc": 76.68,
"multi_turn_acc": 18.88,
"relevance": 55.56,
"irrelevance": 84.81,
"organization": "Google",
"license": "Proprietary",
"provider": "google",
},
{
"overall_acc": 60.38,
"model": "Gemini-Exp-1206",
"type": "FC",
"link": "https://blog.google/feed/gemini-exp-1206/",
"cost": 0.0,
"latency": 3.42,
"ast_summary": 85.17,
"exec_summary": 80.86,
"live_ast_acc": 78.54,
"multi_turn_acc": 20.25,
"relevance": 77.78,
"irrelevance": 79.64,
"organization": "Google",
"license": "Proprietary",
"provider": "google",
},
{
"overall_acc": 59.67,
"model": "Qwen2.5-32B-Instruct",
"type": "Prompt",
"link": "https://huggingface.co/Qwen/Qwen2.5-32B-Instruct",
"cost": "N/A",
"latency": 2.26,
"ast_summary": 85.81,
"exec_summary": 89.79,
"live_ast_acc": 74.23,
"multi_turn_acc": 17.75,
"relevance": 100,
"irrelevance": 73.75,
"organization": "Qwen",
"license": "apache-2.0",
"provider": "unknown",
},
{
"overall_acc": 59.57,
"model": "GPT-4-turbo-2024-04-09",
"type": "Prompt",
"link": "https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo",
"cost": 58.87,
"latency": 1.24,
"ast_summary": 90.88,
"exec_summary": 89.45,
"live_ast_acc": 63.84,
"multi_turn_acc": 30.25,
"relevance": 100,
"irrelevance": 35.57,
"organization": "OpenAI",
"license": "Proprietary",
"provider": "openai",
},
{
"overall_acc": 59.42,
"model": "Gemini-1.5-Pro-001",
"type": "FC",
"link": "https://deepmind.google/technologies/gemini/pro/",
"cost": 5.1,
"latency": 1.43,
"ast_summary": 84.33,
"exec_summary": 87.95,
"live_ast_acc": 76.23,
"multi_turn_acc": 16,
"relevance": 50,
"irrelevance": 84.39,
"organization": "Google",
"license": "Proprietary",
"provider": "google",
},
{
"overall_acc": 59.07,
"model": "Hammer2.1-3b",
"type": "FC",
"link": "https://huggingface.co/MadeAgents/Hammer2.1-3b",
"cost": "N/A",
"latency": 1.95,
"ast_summary": 86.85,
"exec_summary": 84.09,
"live_ast_acc": 74.04,
"multi_turn_acc": 17.38,
"relevance": 82.35,
"irrelevance": 81.87,
"organization": "MadeAgents",
"license": "qwen-research",
"provider": "unknown",
},
{
"overall_acc": 58.45,
"model": "mistral-large-2407",
"type": "FC",
"link": "https://mistral.ai/news/mistral-large-2407/",
"cost": 12.68,
"latency": 3.12,
"ast_summary": 86.81,
"exec_summary": 84.38,
"live_ast_acc": 69.88,
"multi_turn_acc": 23.75,
"relevance": 72.22,
"irrelevance": 52.85,
"organization": "Mistral AI",
"license": "Proprietary",
"provider": "mistral",
},
{
"overall_acc": 58.42,
"model": "ToolACE-8B",
"type": "FC",
"link": "https://huggingface.co/Team-ACE/ToolACE-8B",
"cost": "N/A",
"latency": 5.24,
"ast_summary": 87.54,
"exec_summary": 89.21,
"live_ast_acc": 78.59,
"multi_turn_acc": 7.75,
"relevance": 83.33,
"irrelevance": 87.88,
"organization": "Huawei Noah & USTC",
"license": "Apache-2.0",
"provider": "unknown",
},
{
"overall_acc": 57.78,
"model": "xLAM-8x22b-r",
"type": "FC",
"link": "https://huggingface.co/Salesforce/xLAM-8x22b-r",
"cost": "N/A",
"latency": 9.26,
"ast_summary": 83.69,
"exec_summary": 87.88,
"live_ast_acc": 72.59,
"multi_turn_acc": 16.25,
"relevance": 88.89,
"irrelevance": 67.81,
"organization": "Salesforce",
"license": "cc-by-nc-4.0",
"provider": "unknown",
},
{
"overall_acc": 57.68,
"model": "Qwen2.5-14B-Instruct",
"type": "Prompt",
"link": "https://huggingface.co/Qwen/Qwen2.5-14B-Instruct",
"cost": "N/A",
"latency": 2.02,
"ast_summary": 85.69,
"exec_summary": 88.84,
"live_ast_acc": 74.14,
"multi_turn_acc": 12.25,
"relevance": 77.78,
"irrelevance": 77.06,
"organization": "Qwen",
"license": "apache-2.0",
"provider": "unknown",
},
{
"overall_acc": 57.23,
"model": "DeepSeek-V3",
"type": "FC",
"link": "https://api-docs.deepseek.com/news/news1226",
"cost": "N/A",
"latency": 2.58,
"ast_summary": 89.17,
"exec_summary": 83.39,
"live_ast_acc": 68.41,
"multi_turn_acc": 18.62,
"relevance": 88.89,
"irrelevance": 59.36,
"organization": "DeepSeek",
"license": "DeepSeek License",
"provider": "unknown",
},
{
"overall_acc": 57.09,
"model": "Gemini-1.5-Flash-001",
"type": "Prompt",
"link": "https://deepmind.google/technologies/gemini/flash/",
"cost": 0.48,
"latency": 0.71,
"ast_summary": 85.69,
"exec_summary": 83.59,
"live_ast_acc": 68.9,
"multi_turn_acc": 19.5,
"relevance": 83.33,
"irrelevance": 62.78,
"organization": "Google",
"license": "Proprietary",
"provider": "google",
},
{
"overall_acc": 56.79,
"model": "Gemini-1.5-Flash-002",
"type": "Prompt",
"link": "https://deepmind.google/technologies/gemini/flash/",
"cost": 0.46,
"latency": 0.81,
"ast_summary": 81.65,
"exec_summary": 80.64,
"live_ast_acc": 76.72,
"multi_turn_acc": 12.5,
"relevance": 83.33,
"irrelevance": 78.49,
"organization": "Google",
"license": "Proprietary",
"provider": "google",
},
]
supported_top_tool_models = [
{
"cost": item["cost"],
"model": item["model"],
"type": item["type"],
"provider": item["provider"],
}
for item in leaderboard_data
if item["provider"] in VALID_PROVIDERS
]
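A minimal usage sketch for the filtered list above. The helper below is illustrative and not part of this diff; it only assumes initialize_llm and validate_provider_env from ra_aid.llm, both of which this commit already uses:

from ra_aid.llm import initialize_llm, validate_provider_env
from ra_aid.tool_leaderboard import supported_top_tool_models

def pick_fallback_model(max_candidates: int = 5):
    """Return the first leaderboard candidate whose provider is configured."""
    # Candidates are already sorted by overall_acc and restricted to
    # providers RA.Aid knows how to initialize (VALID_PROVIDERS).
    for candidate in supported_top_tool_models[:max_candidates]:
        if validate_provider_env(candidate["provider"]):
            model = initialize_llm(candidate["provider"], candidate["model"])
            return candidate, model
    return None, None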

View File

@ -10,7 +10,7 @@ from rich.panel import Panel
from rich.text import Text
from ra_aid.logging_config import get_logger
from ra_aid.models_params import models_params, DEFAULT_BASE_LATENCY
from ra_aid.models_params import DEFAULT_BASE_LATENCY, models_params
from ra_aid.proc.interactive import run_interactive_command
from ra_aid.text.processing import truncate_output
from ra_aid.tools.memory import _global_memory, log_work_event
@ -142,7 +142,11 @@ def run_programming_task(
# Get provider/model specific latency coefficient
provider = _global_memory.get("config", {}).get("provider", "")
model = _global_memory.get("config", {}).get("model", "")
latency = models_params.get(provider, {}).get(model, {}).get("latency_coefficient", DEFAULT_BASE_LATENCY)
latency = (
models_params.get(provider, {})
.get(model, {})
.get("latency_coefficient", DEFAULT_BASE_LATENCY)
)
result = run_interactive_command(command, expected_runtime_seconds=latency)
print()
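For reference, a hypothetical models_params entry this lookup would pick up. The provider, model, and coefficient values below are made up for illustration; only the latency_coefficient key and the DEFAULT_BASE_LATENCY fallback come from the diff:

from ra_aid.models_params import DEFAULT_BASE_LATENCY

example_models_params = {
    "anthropic": {
        "claude-3-5-sonnet-20241022": {
            "latency_coefficient": 3.0,  # illustrative value only
        },
    },
}

# Same chained-.get() pattern as above: a missing provider, model, or key
# falls back to DEFAULT_BASE_LATENCY.
latency = (
    example_models_params.get("anthropic", {})
    .get("claude-3-5-sonnet-20241022", {})
    .get("latency_coefficient", DEFAULT_BASE_LATENCY)
)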

View File

@ -82,7 +82,8 @@ def run_shell_command(
try:
print()
output, return_code = run_interactive_command(
["/bin/bash", "-c", command], expected_runtime_seconds=timeout
["/bin/bash", "-c", command],
expected_runtime_seconds=timeout,
)
print()
result = {

View File

@ -1,5 +1,6 @@
"""Unit tests for agent_utils.py."""
from typing import Any, Dict, Literal
from unittest.mock import Mock, patch
import litellm
@ -127,7 +128,10 @@ def test_create_agent_openai(mock_model, mock_memory):
assert agent == "ciayn_agent"
mock_ciayn.assert_called_once_with(
mock_model, [], max_tokens=models_params["openai"]["gpt-4"]["token_limit"]
mock_model,
[],
max_tokens=models_params["openai"]["gpt-4"]["token_limit"],
config={"provider": "openai", "model": "gpt-4"},
)
@ -141,7 +145,10 @@ def test_create_agent_no_token_limit(mock_model, mock_memory):
assert agent == "ciayn_agent"
mock_ciayn.assert_called_once_with(
mock_model, [], max_tokens=DEFAULT_TOKEN_LIMIT
mock_model,
[],
max_tokens=DEFAULT_TOKEN_LIMIT,
config={"provider": "unknown", "model": "unknown-model"},
)
@ -158,6 +165,7 @@ def test_create_agent_missing_config(mock_model, mock_memory):
mock_model,
[],
max_tokens=DEFAULT_TOKEN_LIMIT,
config={"provider": "openai"},
)
@ -201,7 +209,10 @@ def test_create_agent_with_checkpointer(mock_model, mock_memory):
assert agent == "ciayn_agent"
mock_ciayn.assert_called_once_with(
mock_model, [], max_tokens=models_params["openai"]["gpt-4"]["token_limit"]
mock_model,
[],
max_tokens=models_params["openai"]["gpt-4"]["token_limit"],
config={"provider": "openai", "model": "gpt-4"},
)
@ -275,3 +286,118 @@ def test_get_model_token_limit_planner(mock_memory):
mock_get_info.return_value = {"max_input_tokens": 120000}
token_limit = get_model_token_limit(config, "planner")
assert token_limit == 120000
# New tests for private helper methods in agent_utils.py
def test_setup_and_restore_interrupt_handling():
import signal
from ra_aid.agent_utils import (
_request_interrupt,
_restore_interrupt_handling,
_setup_interrupt_handling,
)
original_handler = signal.getsignal(signal.SIGINT)
handler = _setup_interrupt_handling()
# Verify the SIGINT handler is set to _request_interrupt
assert signal.getsignal(signal.SIGINT) == _request_interrupt
_restore_interrupt_handling(handler)
# Verify the SIGINT handler is restored to the original
assert signal.getsignal(signal.SIGINT) == original_handler
def test_increment_and_decrement_agent_depth():
from ra_aid.agent_utils import (
_decrement_agent_depth,
_global_memory,
_increment_agent_depth,
)
_global_memory["agent_depth"] = 10
_increment_agent_depth()
assert _global_memory["agent_depth"] == 11
_decrement_agent_depth()
assert _global_memory["agent_depth"] == 10
def test_run_agent_stream(monkeypatch):
from ra_aid.agent_utils import _global_memory, _run_agent_stream
# Create a dummy agent that yields one chunk
class DummyAgent:
def stream(self, input_data, cfg: dict):
yield {"content": "chunk1"}
dummy_agent = DummyAgent()
# Set flags so that _run_agent_stream will reset them
_global_memory["plan_completed"] = True
_global_memory["task_completed"] = True
_global_memory["completion_message"] = "existing"
call_flag = {"called": False}
def fake_print_agent_output(
chunk: Dict[str, Any], agent_type: Literal["CiaynAgent", "React"]
):
call_flag["called"] = True
monkeypatch.setattr(
"ra_aid.agent_utils.print_agent_output", fake_print_agent_output
)
_run_agent_stream(dummy_agent, [HumanMessage("dummy prompt")], {})
assert call_flag["called"]
assert _global_memory["plan_completed"] is False
assert _global_memory["task_completed"] is False
assert _global_memory["completion_message"] == ""
def test_execute_test_command_wrapper(monkeypatch):
from ra_aid.agent_utils import _execute_test_command_wrapper
# Patch execute_test_command to return a testable tuple
def fake_execute(config, orig, tests, auto):
return (True, "new prompt", auto, tests + 1)
monkeypatch.setattr("ra_aid.agent_utils.execute_test_command", fake_execute)
result = _execute_test_command_wrapper("orig", {}, 0, False)
assert result == (True, "new prompt", False, 1)
def test_handle_api_error_valueerror():
import pytest
from ra_aid.agent_utils import _handle_api_error
# ValueError not containing "code" or "429" should be re-raised
with pytest.raises(ValueError):
_handle_api_error(ValueError("some error"), 0, 5, 1)
def test_handle_api_error_max_retries():
import pytest
from ra_aid.agent_utils import _handle_api_error
# When attempt reaches max retries, a RuntimeError should be raised
with pytest.raises(RuntimeError):
_handle_api_error(Exception("error code 429"), 4, 5, 1)
def test_handle_api_error_retry(monkeypatch):
import time
from ra_aid.agent_utils import _handle_api_error
# Patch time.monotonic and time.sleep to simulate immediate delay expiration
fake_time = [0]
def fake_monotonic():
fake_time[0] += 0.5
return fake_time[0]
monkeypatch.setattr(time, "monotonic", fake_monotonic)
monkeypatch.setattr(time, "sleep", lambda s: None)
# Should not raise error when attempt is lower than max retries
_handle_api_error(Exception("error code 429"), 0, 5, 1)
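The three tests above pin down the retry contract. The following is a reconstruction consistent with those tests, not the actual _handle_api_error in ra_aid/agent_utils.py; the real backoff formula and wait loop may differ:

import time

def _handle_api_error(e, attempt, max_retries, base_delay):
    # Plain ValueErrors that are not rate-limit related are not retryable.
    if isinstance(e, ValueError) and "code" not in str(e) and "429" not in str(e):
        raise e
    # Give up once the final attempt has failed.
    if attempt >= max_retries - 1:
        raise RuntimeError(f"Max retries ({max_retries}) exceeded: {e}")
    # Otherwise wait with a capped exponential backoff before retrying.
    delay = min(base_delay * (2**attempt), 60)
    start = time.monotonic()
    while time.monotonic() - start < delay:
        time.sleep(0.1)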

View File

@ -1,11 +1,38 @@
import unittest
from unittest.mock import Mock
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from ra_aid.agents.ciayn_agent import CiaynAgent, validate_function_call_pattern
from ra_aid.exceptions import ToolExecutionError
# Dummy tool function for testing retry and fallback behavior
def dummy_tool():
dummy_tool.attempt += 1
if dummy_tool.attempt < 3:
raise Exception("Simulated failure")
return "dummy success"
dummy_tool.attempt = 0
class DummyTool:
def __init__(self, func):
self.func = func
class DummyModel:
def invoke(self, _messages: list[BaseMessage]):
return AIMessage("dummy_tool()")
def bind_tools(self, tools, tool_choice):
pass
# Fixtures from the source file
@pytest.fixture
def mock_model():
"""Create a mock language model."""
@ -21,6 +48,7 @@ def agent(mock_model):
return CiaynAgent(mock_model, tools, max_history_messages=3)
# Trimming test functions
def test_trim_chat_history_preserves_initial_messages(agent):
"""Test that initial messages are preserved during trimming."""
initial_messages = [
@ -33,9 +61,7 @@ def test_trim_chat_history_preserves_initial_messages(agent):
HumanMessage(content="Chat 3"),
AIMessage(content="Chat 4"),
]
result = agent._trim_chat_history(initial_messages, chat_history)
# Verify initial messages are preserved
assert result[:2] == initial_messages
# Verify only last 3 chat messages are kept (due to max_history_messages=3)
@ -47,9 +73,7 @@ def test_trim_chat_history_under_limit(agent):
"""Test trimming when chat history is under the maximum limit."""
initial_messages = [HumanMessage(content="Initial")]
chat_history = [HumanMessage(content="Chat 1"), AIMessage(content="Chat 2")]
result = agent._trim_chat_history(initial_messages, chat_history)
# Verify no trimming occurred
assert len(result) == 3
assert result == initial_messages + chat_history
@ -65,9 +89,7 @@ def test_trim_chat_history_over_limit(agent):
AIMessage(content="Chat 4"),
HumanMessage(content="Chat 5"),
]
result = agent._trim_chat_history(initial_messages, chat_history)
# Verify correct trimming
assert len(result) == 4 # initial + max_history_messages
assert result[0] == initial_messages[0] # Initial message preserved
@ -83,9 +105,7 @@ def test_trim_chat_history_empty_initial(agent):
HumanMessage(content="Chat 3"),
AIMessage(content="Chat 4"),
]
result = agent._trim_chat_history(initial_messages, chat_history)
# Verify only last 3 messages are kept
assert len(result) == 3
assert result == chat_history[-3:]
@ -98,9 +118,7 @@ def test_trim_chat_history_empty_chat(agent):
AIMessage(content="Initial 2"),
]
chat_history = []
result = agent._trim_chat_history(initial_messages, chat_history)
# Verify initial messages are preserved and no trimming occurred
assert result == initial_messages
assert len(result) == 2
@ -109,16 +127,13 @@ def test_trim_chat_history_empty_chat(agent):
def test_trim_chat_history_token_limit():
"""Test trimming based on token limit."""
agent = CiaynAgent(Mock(), [], max_history_messages=10, max_tokens=25)
initial_messages = [HumanMessage(content="Initial")] # ~2 tokens
chat_history = [
HumanMessage(content="A" * 40), # ~10 tokens
AIMessage(content="B" * 40), # ~10 tokens
HumanMessage(content="C" * 40), # ~10 tokens
]
result = agent._trim_chat_history(initial_messages, chat_history)
# Should keep initial message (~2 tokens) and last message (~10 tokens)
assert len(result) == 2
assert result[0] == initial_messages[0]
@ -128,16 +143,13 @@ def test_trim_chat_history_token_limit():
def test_trim_chat_history_no_token_limit():
"""Test trimming with no token limit set."""
agent = CiaynAgent(Mock(), [], max_history_messages=2, max_tokens=None)
initial_messages = [HumanMessage(content="Initial")]
chat_history = [
HumanMessage(content="A" * 1000),
AIMessage(content="B" * 1000),
HumanMessage(content="C" * 1000),
]
result = agent._trim_chat_history(initial_messages, chat_history)
# Should keep initial message and last 2 messages (max_history_messages=2)
assert len(result) == 3
assert result[0] == initial_messages[0]
@ -147,7 +159,6 @@ def test_trim_chat_history_no_token_limit():
def test_trim_chat_history_both_limits():
"""Test trimming with both message count and token limits."""
agent = CiaynAgent(Mock(), [], max_history_messages=3, max_tokens=35)
initial_messages = [HumanMessage(content="Init")] # ~1 token
chat_history = [
HumanMessage(content="A" * 40), # ~10 tokens
@ -155,9 +166,7 @@ def test_trim_chat_history_both_limits():
HumanMessage(content="C" * 40), # ~10 tokens
AIMessage(content="D" * 40), # ~10 tokens
]
result = agent._trim_chat_history(initial_messages, chat_history)
# Should first apply message limit (keeping last 3)
# Then token limit should further reduce to fit under 15 tokens
assert len(result) == 2 # Initial message + 1 message under token limit
@ -165,6 +174,38 @@ def test_trim_chat_history_both_limits():
assert result[1] == chat_history[-1]
# Fallback tests
class TestCiaynAgentFallback(unittest.TestCase):
def setUp(self):
# Reset dummy_tool attempt counter before each test
dummy_tool.attempt = 0
self.dummy_tool = DummyTool(dummy_tool)
self.model = DummyModel()
# Create a CiaynAgent with the dummy tool
self.agent = CiaynAgent(self.model, [self.dummy_tool])
# def test_retry_logic_with_failure_recovery(self):
# # Test that run_agent_with_retry retries until success
# from ra_aid.agent_utils import run_agent_with_retry
#
# config = {"max_test_cmd_retries": 0, "auto_test": True}
# result = run_agent_with_retry(self.agent, "dummy_tool()", config)
# self.assertEqual(result, "Agent run completed successfully")
def test_switch_models_on_fallback(self):
# Test fallback behavior by making dummy_tool always fail
from langchain_core.messages import HumanMessage
def always_fail():
raise Exception("Persistent failure")
always_fail_tool = DummyTool(always_fail)
agent = CiaynAgent(self.model, [always_fail_tool])
with self.assertRaises(ToolExecutionError):
agent._execute_tool(HumanMessage("always_fail()"))
# Function call validation tests
class TestFunctionCallValidation:
@pytest.mark.parametrize(
"test_input",
@ -221,3 +262,11 @@ class TestFunctionCallValidation:
def test_multiline_responses(self, test_input):
"""Test function calls spanning multiple lines."""
assert not validate_function_call_pattern(test_input)
class TestCiaynAgentNewMethods(unittest.TestCase):
pass
if __name__ == "__main__":
unittest.main()
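The ToolExecutionError asserted here is what the new fallback handler (exercised in the next test file) consumes. A rough sketch of that wiring, with a hypothetical run_step helper standing in for the real agent loop:

from ra_aid.exceptions import ToolExecutionError
from ra_aid.fallback_handler import FallbackHandler

def run_step(agent, fallback_handler: FallbackHandler, msg, msg_list):
    try:
        return agent._execute_tool(msg)
    except ToolExecutionError as e:
        # Below the failure threshold this just records the error; once the
        # threshold is hit, the handler tries the fallback tool-calling models.
        return fallback_handler.handle_failure(e, agent, msg_list)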

View File

@ -0,0 +1,372 @@
import unittest
from ra_aid.exceptions import FallbackToolExecutionError
from ra_aid.fallback_handler import FallbackHandler
class DummyLogger:
def debug(self, msg):
pass
def error(self, msg):
pass
class DummyAgent:
provider = "openai"
tools = []
model = None
class TestFallbackHandler(unittest.TestCase):
def setUp(self):
self.config = {
"max_tool_failures": 2,
"fallback_tool_models": "dummy-fallback-model",
"experimental_fallback_handler": True,
}
self.fallback_handler = FallbackHandler(self.config, [])
self.logger = DummyLogger()
self.agent = DummyAgent()
def dummy_tool():
pass
class DummyToolWrapper:
def __init__(self, func):
self.func = func
self.agent.tools = [DummyToolWrapper(dummy_tool)]
def test_handle_failure_increments_counter(self):
from ra_aid.exceptions import ToolExecutionError
initial_failures = self.fallback_handler.tool_failure_consecutive_failures
error_obj = ToolExecutionError(
"Test error", base_message="dummy_call()", tool_name="dummy_tool"
)
self.fallback_handler.handle_failure(error_obj, self.agent, [])
self.assertEqual(
self.fallback_handler.tool_failure_consecutive_failures,
initial_failures + 1,
)
def test_attempt_fallback_resets_counter(self):
# Monkey-patch dummy functions for fallback components
def dummy_initialize_llm(provider, model_name, temperature=None):
class DummyModel:
def bind_tools(self, tools, tool_choice):
pass
return DummyModel()
def dummy_validate_provider_env(provider):
return True
import ra_aid.llm as llm
original_initialize = llm.initialize_llm
original_validate = llm.validate_provider_env
llm.initialize_llm = dummy_initialize_llm
llm.validate_provider_env = dummy_validate_provider_env
self.fallback_handler.tool_failure_consecutive_failures = 2
with self.assertRaises(FallbackToolExecutionError):
self.fallback_handler.attempt_fallback()
self.assertEqual(self.fallback_handler.tool_failure_consecutive_failures, 0)
llm.initialize_llm = original_initialize
llm.validate_provider_env = original_validate
def test_load_fallback_tool_models(self):
import ra_aid.fallback_handler as fh
original_supported = fh.supported_top_tool_models
fh.supported_top_tool_models = [
{"provider": "dummy", "model": "dummy_model", "type": "prompt"}
]
models = self.fallback_handler._load_fallback_tool_models(self.config)
self.assertIsInstance(models, list)
fh.supported_top_tool_models = original_supported
def test_extract_failed_tool_name(self):
from ra_aid.exceptions import FallbackToolExecutionError, ToolExecutionError
# Case when tool_name is provided
error1 = ToolExecutionError(
"Error", base_message="dummy", tool_name="dummy_tool"
)
name1 = self.fallback_handler.extract_failed_tool_name(error1)
self.assertEqual(name1, "dummy_tool")
# Case when tool_name is not provided but regex works
error2 = ToolExecutionError('error with name="test_tool"')
name2 = self.fallback_handler.extract_failed_tool_name(error2)
self.assertEqual(name2, "test_tool")
# Case when regex fails and exception is raised
error3 = ToolExecutionError("no tool name here")
with self.assertRaises(FallbackToolExecutionError):
self.fallback_handler.extract_failed_tool_name(error3)
def test_find_tool_to_bind(self):
class DummyWrapper:
def __init__(self, func):
self.func = func
def dummy_func(_args):
return "result"
dummy_wrapper = DummyWrapper(dummy_func)
self.agent.tools.append(dummy_wrapper)
tool = self.fallback_handler._find_tool_to_bind(self.agent, dummy_func.__name__)
self.assertIsNotNone(tool)
self.assertEqual(tool.func.__name__, dummy_func.__name__)
def test_bind_tool_model(self):
# Setup a dummy simple_model with bind_tools method
class DummyModel:
def bind_tools(self, tools, tool_choice=None):
self.bound = True
self.tools = tools
self.tool_choice = tool_choice
return self
def with_retry(self, stop_after_attempt):
return self
def invoke(self, msg_list):
return "dummy_response"
dummy_model = DummyModel()
# Set current tool for binding
class DummyTool:
def invoke(self, args):
return "result"
self.fallback_handler.current_tool_to_bind = DummyTool()
self.fallback_handler.current_failing_tool_name = "test_tool"
# Test with force calling ("fc") type
fallback_model_fc = {"type": "fc"}
bound_model_fc = self.fallback_handler._bind_tool_model(
dummy_model, fallback_model_fc
)
self.assertTrue(hasattr(bound_model_fc, "tool_choice"))
self.assertEqual(bound_model_fc.tool_choice, "test_tool")
# Test with prompt type
fallback_model_prompt = {"type": "prompt"}
bound_model_prompt = self.fallback_handler._bind_tool_model(
dummy_model, fallback_model_prompt
)
self.assertTrue(bound_model_prompt.tool_choice is None)
def test_invoke_fallback(self):
import os
from unittest.mock import patch
# Successful fallback scenario with proper API key set
with (
patch.dict(os.environ, {"DUMMY_API_KEY": "dummy_value"}),
patch(
"ra_aid.fallback_handler.supported_top_tool_models",
new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}],
),
patch("ra_aid.fallback_handler.validate_provider_env", return_value=True),
patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm,
):
class DummyModel:
def bind_tools(self, tools, tool_choice=None):
return self
def with_retry(self, stop_after_attempt):
return self
def invoke(self, msg_list):
return DummyResponse()
class DummyResponse:
additional_kwargs = {
"tool_calls": [
{
"id": "1",
"type": "test",
"function": {"name": "dummy_tool", "arguments": '{"a":1}'},
}
]
}
def dummy_initialize_llm(provider, model_name):
return DummyModel()
mock_init_llm.side_effect = dummy_initialize_llm
# Set current tool for fallback
class DummyTool:
def invoke(self, args):
return "tool_result"
self.fallback_handler.current_tool_to_bind = DummyTool()
self.fallback_handler.current_failing_tool_name = "dummy_tool"
# Add dummy tool for lookup in invoke_prompt_tool_call
self.fallback_handler.tools.append(
type(
"DummyToolWrapper",
(),
{
"func": type("DummyToolFunc", (), {"__name__": "dummy_tool"})(),
"invoke": lambda self, args=None: "tool_result",
},
)
)
result = self.fallback_handler.invoke_fallback(
{"provider": "dummy", "model": "dummy_model", "type": "prompt"}
)
self.assertIsInstance(result, list)
self.assertEqual(result[1], "tool_result")
# Failed fallback scenario due to missing API key (simulate by empty environment)
with (
patch.dict(os.environ, {}, clear=True),
patch(
"ra_aid.fallback_handler.supported_top_tool_models",
new=[{"provider": "dummy", "model": "dummy_model", "type": "prompt"}],
),
patch("ra_aid.fallback_handler.validate_provider_env", return_value=False),
patch("ra_aid.fallback_handler.initialize_llm") as mock_init_llm,
):
class FailingDummyModel:
def bind_tools(self, tools, tool_choice=None):
return self
def with_retry(self, stop_after_attempt):
return self
def invoke(self, msg_list):
raise Exception("API key missing")
def failing_initialize_llm(provider, model_name):
return FailingDummyModel()
mock_init_llm.side_effect = failing_initialize_llm
fallback_result = self.fallback_handler.invoke_fallback(
{"provider": "dummy", "model": "dummy_model", "type": "prompt"}
)
self.assertIsNone(fallback_result)
# Test that the overall fallback mechanism raises FallbackToolExecutionError when all models fail
# Set failure count to trigger the fallback attempt in attempt_fallback
from ra_aid.exceptions import FallbackToolExecutionError
self.fallback_handler.tool_failure_consecutive_failures = (
self.fallback_handler.max_failures
)
with self.assertRaises(FallbackToolExecutionError) as cm:
self.fallback_handler.attempt_fallback()
self.assertIn("All fallback models have failed", str(cm.exception))
def test_construct_prompt_msg_list(self):
msgs = self.fallback_handler.construct_prompt_msg_list()
from ra_aid.fallback_handler import HumanMessage, SystemMessage
self.assertTrue(any(isinstance(m, SystemMessage) for m in msgs))
self.assertTrue(any(isinstance(m, HumanMessage) for m in msgs))
# Test with failed_messages added
self.fallback_handler.failed_messages.append("failed_msg")
msgs_with_fail = self.fallback_handler.construct_prompt_msg_list()
self.assertTrue(any("failed_msg" in str(m) for m in msgs_with_fail))
def test_invoke_prompt_tool_call(self):
# Create dummy tool function
def dummy_tool_func(args):
return "invoked_result"
dummy_tool_func.__name__ = "dummy_tool"
# Create wrapper class
class DummyToolWrapper:
def __init__(self, func):
self.func = func
def invoke(self, args):
return self.func(args)
dummy_wrapper = DummyToolWrapper(dummy_tool_func)
self.fallback_handler.tools = [dummy_wrapper]
tool_call_req = {"name": "dummy_tool", "arguments": {"x": 42}}
result = self.fallback_handler.invoke_prompt_tool_call(tool_call_req)
self.assertEqual(result, "invoked_result")
def test_base_message_to_tool_call_dict(self):
dummy_tool_call = {
"id": "123",
"type": "test",
"function": {"name": "dummy_tool", "arguments": '{"x":42}'},
}
DummyResponse = type(
"DummyResponse",
(),
{"additional_kwargs": {"tool_calls": [dummy_tool_call]}},
)
result = self.fallback_handler.base_message_to_tool_call_dict(DummyResponse)
self.assertEqual(result["id"], "123")
self.assertEqual(result["name"], "dummy_tool")
self.assertEqual(result["arguments"], {"x": 42})
def test_parse_tool_arguments(self):
args_str = '{"a": 1}'
parsed = self.fallback_handler._parse_tool_arguments(args_str)
self.assertEqual(parsed, {"a": 1})
args_dict = {"b": 2}
parsed_dict = self.fallback_handler._parse_tool_arguments(args_dict)
self.assertEqual(parsed_dict, {"b": 2})
def test_get_tool_calls(self):
DummyResponse = type("DummyResponse", (), {})()
DummyResponse.additional_kwargs = {"tool_calls": [{"id": "1"}]}
calls = self.fallback_handler.get_tool_calls(DummyResponse)
self.assertEqual(calls, [{"id": "1"}])
DummyResponse2 = type("DummyResponse2", (), {"tool_calls": [{"id": "2"}]})()
calls2 = self.fallback_handler.get_tool_calls(DummyResponse2)
self.assertEqual(calls2, [{"id": "2"}])
dummy_dict = {"additional_kwargs": {"tool_calls": [{"id": "3"}]}}
calls3 = self.fallback_handler.get_tool_calls(dummy_dict)
self.assertEqual(calls3, [{"id": "3"}])
def test_handle_failure_response(self):
from ra_aid.exceptions import ToolExecutionError
def dummy_handle_failure(error, agent):
return ["fallback_response"]
self.fallback_handler.handle_failure = dummy_handle_failure
response = self.fallback_handler.handle_failure_response(
ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "React"
)
from ra_aid.fallback_handler import SystemMessage
self.assertTrue(all(isinstance(m, SystemMessage) for m in response))
response_non = self.fallback_handler.handle_failure_response(
ToolExecutionError("test", tool_name="dummy_tool"), self.agent, "Other"
)
self.assertIsNone(response_non)
def test_init_msg_list_non_overlapping(self):
# Test when the first two and last two messages do not overlap.
full_list = ["msg1", "msg2", "msg3", "msg4", "msg5"]
self.fallback_handler.init_msg_list(full_list)
# Expected merged list: first two ("msg1", "msg2") plus last two ("msg4", "msg5")
self.assertEqual(
self.fallback_handler.msg_list, ["msg1", "msg2", "msg4", "msg5"]
)
def test_init_msg_list_with_overlap(self):
# Test when the last two messages overlap with the first two.
full_list = ["msg1", "msg2", "msg1", "msg3"]
self.fallback_handler.init_msg_list(full_list)
# Expected merged list: first two ("msg1", "msg2") plus "msg3" from the last two, since "msg1" was already present.
self.assertEqual(self.fallback_handler.msg_list, ["msg1", "msg2", "msg3"])
if __name__ == "__main__":
unittest.main()
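Taken together, these tests pin down the failure-counting flow. A condensed, self-contained sketch of that control flow follows; the real FallbackHandler in ra_aid/fallback_handler.py also performs the model binding and prompting that is only stubbed out here:

from ra_aid.exceptions import FallbackToolExecutionError

class FallbackHandlerSketch:
    def __init__(self, config, tools):
        self.tools = tools
        self.max_failures = config.get("max_tool_failures", 3)
        # The real handler builds this via _load_fallback_tool_models(),
        # filtering supported_top_tool_models by provider.
        self.fallback_tool_models = config.get("fallback_tool_models", [])
        self.tool_failure_consecutive_failures = 0

    def invoke_fallback(self, fallback_model):
        """Stub: the real method binds the failing tool to the fallback
        model, prompts it, and returns the tool-call result (or None)."""
        return None

    def handle_failure(self, error, agent, msg_list):
        # Count consecutive tool failures; only escalate at the threshold.
        self.tool_failure_consecutive_failures += 1
        if self.tool_failure_consecutive_failures >= self.max_failures:
            return self.attempt_fallback()
        return None

    def attempt_fallback(self):
        # Reset the counter, then try each candidate model in order.
        self.tool_failure_consecutive_failures = 0
        for fallback_model in self.fallback_tool_models:
            result = self.invoke_fallback(fallback_model)
            if result is not None:
                return result
        raise FallbackToolExecutionError("All fallback models have failed.")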

View File

@ -142,7 +142,9 @@ def test_initialize_expert_openai_compatible(clean_env, mock_openai, monkeypatch
def test_initialize_expert_unsupported_provider(clean_env):
"""Test error handling for unsupported provider in expert mode."""
with pytest.raises(ValueError, match=r"Unsupported provider: unknown"):
with pytest.raises(
ValueError, match=r"Missing required environment variable for provider: unknown"
):
initialize_expert_llm("unknown", "model")
@ -235,7 +237,9 @@ def test_initialize_openai_compatible(clean_env, mock_openai):
def test_initialize_unsupported_provider(clean_env):
"""Test initialization with unsupported provider raises ValueError"""
with pytest.raises(ValueError, match=r"Unsupported provider: unknown"):
with pytest.raises(
ValueError, match=r"Missing required environment variable for provider: unknown"
):
initialize_llm("unknown", "model")
@ -257,15 +261,33 @@ def test_temperature_defaults(clean_env, mock_openai, mock_anthropic, mock_gemin
max_retries=5,
)
# Test error when no temperature provided for models that support it
with pytest.raises(ValueError, match="Temperature must be provided for model"):
initialize_llm("openai", "test-model")
# Test default temperature when none is provided for models that support it
initialize_llm("openai", "test-model")
mock_openai.assert_called_with(
api_key="test-key",
model="test-model",
temperature=0.7,
timeout=180,
max_retries=5,
)
with pytest.raises(ValueError, match="Temperature must be provided for model"):
initialize_llm("anthropic", "test-model")
initialize_llm("anthropic", "test-model")
mock_anthropic.assert_called_with(
api_key="test-key",
model_name="test-model",
temperature=0.7,
timeout=180,
max_retries=5,
)
with pytest.raises(ValueError, match="Temperature must be provided for model"):
initialize_llm("gemini", "test-model")
initialize_llm("gemini", "test-model")
mock_gemini.assert_called_with(
api_key="test-key",
model="test-model",
temperature=0.7,
timeout=180,
max_retries=5,
)
# Test expert models don't require temperature
initialize_expert_llm("openai", "o1")